Convert pretrained TensorFlow models (.h5) to PyTorch

In [1]:
import h5py
import numpy as np
import requests
import torch
from torch import Tensor

In [2]:
def download_file(url, to, timeout=30.0):
    """Download ``url`` and write the response body to the local path ``to``.

    The body is streamed to disk in 8 KiB chunks, so it is never held in
    memory all at once.

    Args:
        url: Address to fetch with an HTTP GET request.
        to: Destination file path. It is opened in binary write mode, so an
            existing file is overwritten.
        timeout: Value forwarded to ``requests.get`` as its ``timeout`` so a
            stalled server cannot hang the notebook indefinitely.

    Returns:
        None. A success or failure message is printed instead.
    """
    # Using the response as a context manager releases the streamed
    # connection even if writing the file fails part-way through.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Failed to download file: {response.status_code}")
            return

        # Open a local file in write-binary mode and copy the body chunk by chunk.
        with open(to, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    print(f"File downloaded successfully: {to}")

In [3]:
# Release asset holding the Xception weights (.h5). The URL is split across
# adjacent string literals only to keep the lines short.
xception_url = (
    "https://github.com/fchollet/deep-learning-models/"
    "releases/download/v0.4/"
    "xception_weights_tf_dim_ordering_tf_kernels.h5"
)
# Local destination; a relative path, so it resolves against the current
# working directory.
xception_file = "../assets/xception_weights.h5"
download_file(xception_url, xception_file)

File downloaded successfully: ../assets/xception_weights.h5


In [4]:
def flatten_h5(h5: h5py.File | h5py.Group) -> dict[str, h5py.Dataset]:
    """Collect every dataset under ``h5`` into one flat dictionary.

    Groups are walked recursively. Each dataset is keyed by its path relative
    to ``h5``, with the path components joined by ``/``.

    Args:
        h5: An open HDF5 file, or a group inside one.

    Returns:
        Mapping from slash-separated path to the dataset found at that path,
        in the order the members are iterated.

    Raises:
        ValueError: If a member is neither a group nor a dataset.
    """
    contents: dict[str, h5py.Dataset] = {}
    for name, member in h5.items():
        if isinstance(member, (h5py.File, h5py.Group)):
            # Prefix the nested paths with this group's name.
            for sub_path, dataset in flatten_h5(member).items():
                contents[f"{name}/{sub_path}"] = dataset
        elif isinstance(member, h5py.Dataset):
            contents[name] = member
        else:
            # Name the offending key and type so the failure is actionable.
            raise ValueError(
                f"Unknown value at {name!r}: {member!r} ({type(member).__name__})"
            )
    return contents

In [5]:
# The file is opened read-only and is not closed in this cell; the datasets
# collected below are read again when h5_to_pth runs.
xception_h5: h5py.File = h5py.File(xception_file, "r")
xception_items = flatten_h5(xception_h5)
# List every dataset and check that each one converts to a tensor.
for k, v in xception_items.items():
    print(k, v.shape, v.dtype)
    v = np.array(v)  # converting to numpy first is much faster
    # The tensor is discarded; only the conversion itself is being exercised.
    torch.tensor(v)

batchnormalization_1/batchnormalization_1_beta:0 (32,) float32
batchnormalization_1/batchnormalization_1_gamma:0 (32,) float32
batchnormalization_1/batchnormalization_1_running_mean:0 (32,) float32
batchnormalization_1/batchnormalization_1_running_std:0 (32,) float32
batchnormalization_10/batchnormalization_10_beta:0 (728,) float32
batchnormalization_10/batchnormalization_10_gamma:0 (728,) float32
batchnormalization_10/batchnormalization_10_running_mean:0 (728,) float32
batchnormalization_10/batchnormalization_10_running_std:0 (728,) float32
batchnormalization_11/batchnormalization_11_beta:0 (728,) float32
batchnormalization_11/batchnormalization_11_gamma:0 (728,) float32
batchnormalization_11/batchnormalization_11_running_mean:0 (728,) float32
batchnormalization_11/batchnormalization_11_running_std:0 (728,) float32
batchnormalization_12/batchnormalization_12_beta:0 (728,) float32
batchnormalization_12/batchnormalization_12_gamma:0 (728,) float32
batchnormalization_12/batchnormalizatio

In [6]:
from pathlib import Path
import sys


# Put the parent of the current working directory on the import path so the
# `src` package imported below can be resolved.
sys.path.append(str(Path() / ".."))
from src.pixseg.models.backbones import (
    xception_original,
    Xception_Weights,
)

# Build the PyTorch model and list the entries of its state_dict.
# NOTE(review): 1000 is presumably the number of output classes — confirm
# against the signature of xception_original.
pytorch_model = xception_original(1000)
pytorch_state_dict = pytorch_model.state_dict()
for k, v in pytorch_state_dict.items():
    print(k, v.shape, v.dtype)

entry_flow.0.weight torch.Size([32, 3, 3, 3]) torch.float32
entry_flow.1.weight torch.Size([32]) torch.float32
entry_flow.1.bias torch.Size([32]) torch.float32
entry_flow.1.running_mean torch.Size([32]) torch.float32
entry_flow.1.running_var torch.Size([32]) torch.float32
entry_flow.1.num_batches_tracked torch.Size([]) torch.int64
entry_flow.3.weight torch.Size([64, 32, 3, 3]) torch.float32
entry_flow.4.weight torch.Size([64]) torch.float32
entry_flow.4.bias torch.Size([64]) torch.float32
entry_flow.4.running_mean torch.Size([64]) torch.float32
entry_flow.4.running_var torch.Size([64]) torch.float32
entry_flow.4.num_batches_tracked torch.Size([]) torch.int64
entry_flow.6.residual_branch.0.weight torch.Size([128, 64, 1, 1]) torch.float32
entry_flow.6.residual_branch.1.weight torch.Size([128]) torch.float32
entry_flow.6.residual_branch.1.bias torch.Size([128]) torch.float32
entry_flow.6.residual_branch.1.running_mean torch.Size([128]) torch.float32
entry_flow.6.residual_branch.1.running_

In [7]:
# Compare the number of datasets in the .h5 file with the number of entries in
# the PyTorch state_dict.
print(len(xception_items))
print(len(pytorch_state_dict))
# The 40 in this message is hard-coded; it is not computed from the two counts.
print("Difference = 40 because keras doesn't store batchnorm's num_batches_tracked")

236
276
Difference = 40 because keras doesn't store batchnorm's num_batches_tracked


In [8]:
from torch import nn

from src.pixseg.models.backbones.xception import SeparableConv


def _dataset_to_tensor(
    dataset: "h5py.Dataset", permutation: tuple[int, ...] | None = None
) -> Tensor:
    """Read ``dataset`` into a float32 tensor, optionally permuting its axes."""
    # Go through numpy first rather than handing the dataset to torch directly.
    tensor = torch.tensor(np.array(dataset), dtype=torch.float32)
    if permutation is not None:
        # permute() only returns a strided view; materialise it so the saved
        # checkpoint stores the values in their final layout.
        tensor = tensor.permute(*permutation).contiguous()
    return tensor


def h5_to_pth(items: dict[str, h5py.Dataset], model: nn.Module) -> dict[str, Tensor]:
    """Map the datasets of a flattened .h5 weight file onto ``model``'s parameter names.

    Layers are matched purely by position: the i-th module of a given kind in
    ``model.named_modules()`` receives the datasets named ``<kind>_{i+1}`` in
    ``items``. Convolutions that are direct children of a ``SeparableConv`` are
    handled through their parent, not as plain convolutions.

    Args:
        items: Flat mapping from ``"<layer>/<layer>_<param>:0"`` to dataset, as
            produced by ``flatten_h5``.
        model: Module whose ``Conv2d``, ``BatchNorm2d``, ``SeparableConv`` and
            ``Linear`` submodules define the target key names.

    Returns:
        Dictionary from state_dict key to float32 tensor. It contains no
        ``num_batches_tracked`` entries because ``items`` has none to copy.

    Raises:
        AssertionError: If the number of datasets does not match what the
            model's layers require, or the model does not have exactly one
            ``Linear`` layer.
        KeyError: If an expected dataset name is missing from ``items``.
    """
    conv_keys: list[str] = []
    batch_norm_keys: list[str] = []
    separable_conv_keys: list[str] = []
    linear_keys: list[str] = []

    # Independent checks (not elif): a module is recorded under every kind it
    # is an instance of.
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            conv_keys.append(name)
        if isinstance(module, nn.BatchNorm2d):
            batch_norm_keys.append(name)
        if isinstance(module, SeparableConv):
            separable_conv_keys.append(name)
        if isinstance(module, nn.Linear):
            linear_keys.append(name)

    # Drop convolutions whose parent is a SeparableConv; those are converted
    # together with the parent below.
    separable_names = set(separable_conv_keys)
    conv_keys = [
        key for key in conv_keys if key.rsplit(".", maxsplit=1)[0] not in separable_names
    ]
    print(
        len(conv_keys), len(batch_norm_keys), len(separable_conv_keys), len(linear_keys)
    )

    # Datasets per layer: conv -> 1, batch norm -> 4, separable conv -> 2, linear -> 2.
    expected_items = (
        len(conv_keys) * 1
        + len(batch_norm_keys) * 4
        + len(separable_conv_keys) * 2
        + len(linear_keys) * 2
    )
    assert (
        len(items) == expected_items
    ), f"expected {expected_items} datasets for this model, found {len(items)}"
    assert (
        len(linear_keys) == 1
    ), f"expected exactly one Linear layer, found {len(linear_keys)}"

    state_dict: dict[str, Tensor] = {}

    # Plain convolutions: move the last two axes to the front in reverse order.
    # Assumes kernels are stored as (height, width, in, out) — TODO confirm.
    for i, k in enumerate(conv_keys):
        state_dict[f"{k}.weight"] = _dataset_to_tensor(
            items[f"convolution2d_{i+1}/convolution2d_{i+1}_W:0"], (3, 2, 0, 1)
        )

    # NOTE(review): the file names the last entry "running_std", yet it is
    # copied into running_var unchanged. This assumes the stored values are
    # already variances — confirm against the framework that wrote the file.
    key_pairs = [
        ("gamma", "weight"),
        ("beta", "bias"),
        ("running_mean", "running_mean"),
        ("running_std", "running_var"),
    ]
    for i, k in enumerate(batch_norm_keys):
        for keras_name, torch_name in key_pairs:
            state_dict[f"{k}.{torch_name}"] = _dataset_to_tensor(
                items[
                    f"batchnormalization_{i+1}/batchnormalization_{i+1}_{keras_name}:0"
                ]
            )

    # (dataset suffix, child index inside SeparableConv, axis permutation)
    separable_pairs = [
        ("depthwise_kernel", "0", (2, 3, 0, 1)),
        ("pointwise_kernel", "1", (3, 2, 0, 1)),
    ]
    for i, k in enumerate(separable_conv_keys):
        for keras_name, torch_name, resize in separable_pairs:
            state_dict[f"{k}.{torch_name}.weight"] = _dataset_to_tensor(
                items[
                    f"separableconvolution2d_{i+1}/separableconvolution2d_{i+1}_{keras_name}:0"
                ],
                resize,
            )

    # Use the discovered Linear layer's name instead of a hard-coded key, the
    # same way every other layer kind is handled. The dataset name "dense_2"
    # is specific to this weight file.
    (linear_key,) = linear_keys
    state_dict[f"{linear_key}.weight"] = _dataset_to_tensor(
        items["dense_2/dense_2_W:0"], (1, 0)
    )
    state_dict[f"{linear_key}.bias"] = _dataset_to_tensor(items["dense_2/dense_2_b:0"])

    return state_dict

In [9]:
state_dict = h5_to_pth(xception_items, pytorch_model)
# Every converted tensor must have the same shape as the same-named entry in
# the model's own state_dict (a missing key raises KeyError here).
for k, v in state_dict.items():
    assert pytorch_state_dict[k].shape == v.shape
# The converted dict must have as many entries as the .h5 file has datasets.
assert len(state_dict) == len(xception_items)
torch.save(state_dict, "../assets/xception_imagenet.pth")

6 40 34 1


In [10]:
# Smoke test: build the model from the IMAGENET1K weights entry and display it.
# NOTE(review): presumably this entry points at the checkpoint saved by the
# previous cell — confirm in the definition of Xception_Weights.
weights = Xception_Weights.IMAGENET1K
xception_original(weights=weights)

Xception(
  (entry_flow): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ResidualBlock(
      (residual_branch): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (main_branch): Sequential(
        (0): ReLU()
        (1): SeparableConv(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True,