In [1]:
import tensorflow as tf
import torch
import torch.nn as nn
import numpy as np
from VideoSwinTransformer import *
import os
from collections import OrderedDict
import sys

In [2]:
def get_x(shape=(1, 3, 8, 224, 224), seed=None):
    """Create one random input and return it as both a TF and a PyTorch tensor.

    Both tensors hold the same values, so the TensorFlow and PyTorch models
    can be fed identical data and their outputs compared.

    Args:
        shape: Shape of the random tensor. Defaults to ``(1, 3, 8, 224, 224)``,
            the shape this notebook has always used.
        seed: Optional integer seed. When given, the values are drawn from a
            dedicated ``torch.Generator`` so repeated calls (and fresh kernels)
            produce the same input. When ``None`` the global PyTorch RNG is
            used, exactly as before.

    Returns:
        Tuple ``(x_tf, x_pt)``: the TensorFlow tensor and the PyTorch tensor.
    """
    generator = None
    if seed is not None:
        # A private generator keeps the global RNG state untouched.
        generator = torch.Generator().manual_seed(seed)

    x_pt = torch.rand(shape, generator=generator)
    x_np = x_pt.numpy()
    x_tf = tf.convert_to_tensor(x_np)

    return x_tf, x_pt

## Convert Weights

Instantiate the models and load the PyTorch weights

In [3]:
cfg_method = model_configs.MODEL_MAP["swin_tiny_patch244_window877_kinetics400_1k"]
cfg = cfg_method()

# "name" and "link" are removed from the config before it is forwarded to the
# model constructors below.
name = cfg.pop("name")
link = cfg.pop("link")

# Download the checkpoint only when a non-empty copy is not already on disk,
# so that "Restart & Run All" does not re-fetch the weights every time.
weights_path = f"{name}.pth"
download_weight_command = f"wget {link} -O {weights_path}"
if not (os.path.isfile(weights_path) and os.path.getsize(weights_path) > 0):
    exit_status = os.system(download_weight_command)
    if exit_status != 0:
        # Do not leave a partial file behind: it would be mistaken for a
        # valid cached checkpoint on the next run.
        if os.path.exists(weights_path):
            os.remove(weights_path)
        raise RuntimeError(
            f"Downloading weights failed (exit status {exit_status}): "
            f"{download_weight_command}"
        )

pt_model = SwinTransformer3D_pt(**cfg, isTest=True)
tf_model = SwinTransformer3D(**cfg, isTest=True)
x_tf, x_pt = get_x()

# One forward pass through each model.
# NOTE(review): the result names are swapped relative to their contents --
# `basic_pt` holds the TensorFlow outputs and `basic_tf` the PyTorch outputs.
# They are kept as-is because later cells rebind and read these exact names.
basic_pt, y = tf_model(x_tf)
basic_tf, z = pt_model(x_pt)

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted
# source. map_location="cpu" lets a GPU-saved checkpoint load on any machine.
checkpoint = torch.load(weights_path, map_location="cpu")

# Keep only the backbone parameters and strip the "backbone." prefix so the
# keys match the parameter names of the bare SwinTransformer3D module.
backbone_prefix = "backbone."
new_state_dict = OrderedDict(
    (key[len(backbone_prefix):], value)
    for key, value in checkpoint["state_dict"].items()
    if key.startswith(backbone_prefix)
)

pt_model.load_state_dict(new_state_dict)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


<All keys matched successfully>

Conversion functions

In [4]:
def conv_transpose(w):
    """Return ``w`` with its five axes reordered to ``(2, 3, 4, 1, 0)``.

    Used by ``modify_tf_block`` on Conv3D kernels; presumably this converts
    the PyTorch kernel layout into the layout Keras expects (TODO confirm).
    """
    axis_order = (2, 3, 4, 1, 0)
    return w.transpose(*axis_order)
    

def modify_tf_block(tf_component, pt_weight, pt_bias=None, is_attn=False):
    """Copy a PyTorch weight (and optional bias) into a TensorFlow component.

    Args:
        tf_component: Destination. Handled kinds are ``Conv3D`` and ``Dense``
            layers, ``LayerNormalization`` layers and ``tf.Variable`` objects.
            Anything else is treated as a plain value.
        pt_weight: NumPy array holding the PyTorch weight.
        pt_bias: Optional NumPy array holding the PyTorch bias. Required for
            ``LayerNormalization``.
        is_attn: When true, a ``Dense`` kernel is assigned without being
            transposed first.

    Returns:
        ``tf_component`` itself (updated in place) for layers and variables;
        otherwise a new tensor built from ``pt_weight``.

    Raises:
        ValueError: If ``tf_component`` is a ``LayerNormalization`` and
            ``pt_bias`` is ``None``.
    """
    is_conv = isinstance(tf_component, tf.keras.layers.Conv3D)
    is_dense = isinstance(tf_component, tf.keras.layers.Dense)

    # Bring the kernel into the axis order the destination layer stores.
    if is_conv:
        pt_weight = conv_transpose(pt_weight)

    if is_dense and not is_attn:
        pt_weight = pt_weight.transpose()

    # `assign` accepts array-likes directly, so no temporary tf.Variable is
    # created just to carry the value.
    if is_dense or is_conv:
        tf_component.kernel.assign(pt_weight)

        if pt_bias is not None:
            tf_component.bias.assign(pt_bias)

    elif isinstance(tf_component, tf.keras.layers.LayerNormalization):
        if pt_bias is None:
            raise ValueError(
                "pt_bias is required when converting a LayerNormalization layer"
            )
        tf_component.gamma.assign(pt_weight)
        tf_component.beta.assign(pt_bias)

    elif isinstance(tf_component, tf.Variable):
        tf_component.assign(pt_weight)

    else:
        # Not a layer or variable: hand back a fresh tensor for the caller to
        # store in place of the old value.
        return tf.convert_to_tensor(pt_weight)

    return tf_component


def modify_swin_blocks(np_state_dict, pt_weights_prefix, tf_block):
    """Copy the PyTorch weights of one Swin stage into its TensorFlow layers.

    Args:
        np_state_dict: Mapping from PyTorch parameter name to NumPy array.
        pt_weights_prefix: Prefix of this stage's keys in ``np_state_dict``
            (the caller passes strings such as ``"layers.0"``).
        tf_block: Iterable of the TensorFlow layers of the stage. It is
            traversed twice, so it must be re-iterable (e.g. a list).

    Returns:
        ``tf_block``, whose layers have been updated in place.
    """
    # Patch-merging (downsample) layers.
    for layer in tf_block:
        if isinstance(layer, PatchMerging):
            _load_patch_merging(np_state_dict, pt_weights_prefix, layer)

    # Swin layers: the n-th SwinTransformerBlock3D encountered maps to
    # "<prefix>.blocks.<n>" on the PyTorch side.
    common_prefix = f"{pt_weights_prefix}.blocks"
    block_idx = 0
    for outer_layer in tf_block:
        if isinstance(outer_layer, SwinTransformerBlock3D):
            _load_swin_block(
                np_state_dict, f"{common_prefix}.{block_idx}", outer_layer
            )
            block_idx += 1

    return tf_block


def _load_patch_merging(np_state_dict, pt_weights_prefix, layer):
    """Load the ``<prefix>.downsample`` weights into one PatchMerging layer."""
    patch_merging_idx = f"{pt_weights_prefix}.downsample"

    layer.reduction = modify_tf_block(
        layer.reduction,
        np_state_dict[f"{patch_merging_idx}.reduction.weight"],
    )
    layer.norm = modify_tf_block(
        layer.norm,
        np_state_dict[f"{patch_merging_idx}.norm.weight"],
        np_state_dict[f"{patch_merging_idx}.norm.bias"],
    )


def _load_swin_block(np_state_dict, block_prefix, swin_block):
    """Load the weights stored under ``block_prefix`` into one Swin block."""
    # Counters for the numbered PyTorch names norm1/norm2 and fc1/fc2; they
    # advance in the order the matching layers appear in `swin_block.layers`.
    layernorm_idx = 1
    mlp_layer_idx = 1

    for inner_layer in swin_block.layers:

        # Layer norm.
        if isinstance(inner_layer, tf.keras.layers.LayerNormalization):
            layer_norm_prefix = f"{block_prefix}.norm{layernorm_idx}"
            modify_tf_block(
                inner_layer,
                np_state_dict[f"{layer_norm_prefix}.weight"],
                np_state_dict[f"{layer_norm_prefix}.bias"],
            )
            layernorm_idx += 1

        # Window attention.
        elif isinstance(inner_layer, WindowAttention3D):
            attn_prefix = f"{block_prefix}.attn"

            # Relative position.
            inner_layer.relative_position_bias_table = modify_tf_block(
                inner_layer.relative_position_bias_table,
                np_state_dict[f"{attn_prefix}.relative_position_bias_table"],
            )
            inner_layer.relative_position_index = modify_tf_block(
                inner_layer.relative_position_index,
                np_state_dict[f"{attn_prefix}.relative_position_index"],
            )

            # QKV.
            inner_layer.qkv = modify_tf_block(
                inner_layer.qkv,
                np_state_dict[f"{attn_prefix}.qkv.weight"],
                np_state_dict[f"{attn_prefix}.qkv.bias"],
            )

            # Projection.
            inner_layer.proj = modify_tf_block(
                inner_layer.proj,
                np_state_dict[f"{attn_prefix}.proj.weight"],
                np_state_dict[f"{attn_prefix}.proj.bias"],
            )

        # MLP.
        elif isinstance(inner_layer, tf.keras.Model):
            mlp_prefix = f"{block_prefix}.mlp"
            for mlp_layer in inner_layer.layers:
                if isinstance(mlp_layer, tf.keras.layers.Dense):
                    # The Dense layer is updated in place; the return value
                    # is the same object and is not needed here.
                    modify_tf_block(
                        mlp_layer,
                        np_state_dict[f"{mlp_prefix}.fc{mlp_layer_idx}.weight"],
                        np_state_dict[f"{mlp_prefix}.fc{mlp_layer_idx}.bias"],
                    )
                    mlp_layer_idx += 1


### Convert

In [5]:
# PyTorch parameters as NumPy arrays, keyed by their PyTorch names.
# `.detach().cpu()` makes the conversion independent of device and autograd.
np_state_dict = {
    key: tensor.detach().cpu().numpy()
    for key, tensor in pt_model.state_dict().items()
}

# modify_tf_block updates a layer's variables in place and returns that same
# layer, so its result does not need to be written back into `.layers`.

# Patch embedding: projection layer followed by its normalisation layer.
modify_tf_block(
    tf_model.projection.layers[0],
    np_state_dict["patch_embed.proj.weight"],
    np_state_dict["patch_embed.proj.bias"],
)

modify_tf_block(
    tf_model.projection.layers[1],
    np_state_dict["patch_embed.norm.weight"],
    np_state_dict["patch_embed.norm.bias"],
)

# The last layer of the TF model receives the PyTorch "norm" parameters.
layer_normalization_idx = -1

modify_tf_block(
    tf_model.layers[layer_normalization_idx],
    np_state_dict["norm.weight"],
    np_state_dict["norm.bias"],
)

# swin layers: every layer from index 2 up to (not including) the last one
# is a stage; the n-th of them maps to the PyTorch prefix "layers.<n>".
for stage_idx, stage in enumerate(tf_model.layers[2:-1]):
    modify_swin_blocks(np_state_dict, f"layers.{stage_idx}", stage.layers)

In [6]:
# Re-run both models on the same input now that the weights have been copied.
# NOTE(review): the names are swapped relative to their contents --
# `basic_pt` receives the TensorFlow model's first output and `basic_tf`
# receives the PyTorch model's first output. The names are left unchanged
# because the next cell reads them.
basic_pt, y = tf_model(x_tf)
basic_tf, z= pt_model(x_pt)

# Displayed as the cell output: the shapes of the two final outputs.
y.shape, z.shape

(TensorShape([1, 768, 4, 7, 7]), torch.Size([1, 768, 4, 7, 7]))

In [7]:
# The variable names are swapped relative to their contents: `basic_pt` was
# returned by tf_model and `basic_tf` by pt_model. Pair each value with the
# label of the framework that actually produced it.
for layer in basic_pt:
    print(
        "--------------", layer, "-------------\n TF: ",
        basic_pt[layer].numpy(),
        "\n PT: ",
        basic_tf[layer].detach().numpy(),
        "\n",
    )

-------------- PatchEmbed -------------
 TF:  [[[[[ 9.49883819e-01  8.18407297e-01  8.50878477e-01 ...
      9.61524844e-01  8.64977002e-01  5.85084915e-01]
    [ 5.38985372e-01  1.10915351e+00  8.26075912e-01 ...
     -1.87682509e-02  5.95614314e-01  1.19097888e-01]
    [ 1.05364227e+00  6.69517040e-01  1.05497873e+00 ...
      8.82818103e-01  5.51013231e-01  8.38700294e-01]
    ...
    [ 7.81873941e-01  3.01905096e-01  8.88370752e-01 ...
      5.46061516e-01  7.31691480e-01 -2.30017722e-01]
    [ 6.22520924e-01  3.80051732e-01  5.35116792e-01 ...
      3.82915378e-01  1.25588787e+00  2.92785645e-01]
    [-1.67807341e-02  4.30829406e-01  3.58594894e-01 ...
      6.61100030e-01  8.97313833e-01  6.37549758e-01]]

   [[ 8.56279731e-01  6.70972466e-01  5.05105972e-01 ...
      1.01178658e+00  6.83066130e-01 -1.16000652e-01]
    [ 5.56352139e-01  1.03463113e+00  5.27025819e-01 ...
      5.79492092e-01  3.62853527e-01  7.77755737e-01]
    [ 7.39050627e-01  1.01775050e+00  2.61735559e-01 ...

### Comparison of PyTorch basic-layer outputs across repeated runs

In [8]:
# attempts = 10
# x_pt = torch.rand((1,3,8,224,224))

# outputs = []

# for i in range(attempts):
#     layer_out , result = pt_model(x_pt)
#     outputs.append(layer_out)

# i = 0
# for layer in outputs[0]:
#     print("--------------", layer, "---------------")
#     for idx, layer_out in enumerate(outputs) :
#         print(f"attempt {idx} : ",layer_out[layer].detach().numpy(), "\n")
#     print()



-------------- PatchEmbed ---------------
attempt 0 :  