In [1]:
import tensorflow as tf
import torch
import torch.nn as nn
import numpy as np
from VideoSwinTransformer import *
import os
from collections import OrderedDict

In [2]:
# Inference-only notebook: switch autograd off globally so forward passes record no graph.
torch.set_grad_enabled(mode=False)

<torch.autograd.grad_mode.set_grad_enabled at 0x28576bd8af0>

In [3]:
def get_x(shape=(1, 3, 8, 224, 224), seed=None):
    """Create one random input shared by the PyTorch and TensorFlow models.

    Values are drawn uniformly from [0, 255). Both returned tensors hold the
    same values, so the two models' outputs can be compared element-wise.

    Args:
        shape: shape of the input. The default matches the shape the notebook
            has always used (presumably (batch, channels, frames, height,
            width) — confirm against the model's patch embedding).
        seed: optional int. When given, the input is reproducible across
            kernel restarts; when None (default), it is fresh randomness.

    Returns:
        (x_tf, x_pt): a tf.Tensor and a torch.Tensor with identical contents.
    """
    # A private generator keeps seeding local instead of touching torch's global RNG.
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    x_pt = torch.rand(shape, generator=generator) * 255
    # convert_to_tensor copies, so later edits to x_pt do not leak into x_tf.
    x_tf = tf.convert_to_tensor(x_pt.numpy())

    return x_tf, x_pt

In [4]:
# Draw a random input: a TF tensor and a PyTorch tensor with identical values.
x_tf, x_pt = get_x()

## Convert Weights

Instantiate the PyTorch and TensorFlow models and load the PyTorch weights

In [5]:
# Look up the config for the tiny Video Swin model (patch 2x4x4, window 8x7x7,
# Kinetics-400) and pull out the entries that must not be splatted into the
# model constructors: the checkpoint name and download URL (presumably not
# constructor arguments), and drop_path_rate, which is passed explicitly when
# the models are built.
cfg_method = model_configs.MODEL_MAP["swin_tiny_patch244_window877_kinetics400_1k"]
cfg = cfg_method()

name = cfg.pop("name")
link = cfg.pop("link")
del cfg["drop_path_rate"]
# To fetch the checkpoint if it is missing locally:
#   wget {link} -O {name}.pth

In [6]:
# Build both models from the same config. drop_path_rate was removed from cfg,
# so it is passed explicitly here.
pt_model = SwinTransformer3D_pt(**cfg, drop_rate=0.4, drop_path_rate=0., isTest=True)
tf_model = SwinTransformer3D(**cfg, drop_rate=0.4, drop_path_rate=0., isTest=True)
x_tf, x_pt = get_x()

# One forward pass each. Presumably this is needed so the TF model has created
# its variables before the conversion helpers assign into them.
basic_pt, z = pt_model(x_pt)
print("-------")

basic_tf, y = tf_model(x_tf)

# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load(f"{name}.pth", map_location="cpu")

# Keep only the backbone weights and strip the "backbone." prefix so the keys
# match pt_model's own state_dict. A dedicated name is used for the stripped
# key so that re-running this cell does not overwrite `name` (the checkpoint
# file name used by torch.load above).
backbone_prefix = "backbone."
new_state_dict = OrderedDict(
    (key[len(backbone_prefix):], value)
    for key, value in checkpoint["state_dict"].items()
    if key.startswith(backbone_prefix)
)

pt_model.load_state_dict(new_state_dict)
pt_model.eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


-------


Weight-conversion functions

In [7]:
def conv_transpose(w):
    """Reorder a PyTorch Conv3d weight into the Keras Conv3D kernel layout.

    Moves the axes (0, 1, 2, 3, 4) to (2, 3, 4, 1, 0). In other words, it goes
    from PyTorch's (out, in, kD, kH, kW) to Keras' (kD, kH, kW, in, out).
    Assumes `w` is a NumPy array, since ndarray.transpose accepts the axis order.
    """
    pt_to_tf_axes = (2, 3, 4, 1, 0)
    return w.transpose(pt_to_tf_axes)
    

def modify_tf_block(tf_component, pt_weight, pt_bias=None, is_attn=False):
    """Copy a PyTorch weight (and optional bias) into a TensorFlow component.

    Args:
        tf_component: the TF target. Handled types:
            * tf.keras.layers.Conv3D: the kernel gets `pt_weight` reordered by
              `conv_transpose`; the bias gets `pt_bias` if given.
            * tf.keras.layers.Dense: the kernel gets `pt_weight` transposed
              (unless `is_attn`); the bias gets `pt_bias` if given.
            * tf.keras.layers.LayerNormalization: gamma gets `pt_weight`; beta
              gets `pt_bias` if given.
            * tf.Variable: assigned `pt_weight` unchanged.
            Any other type is not modified (see Returns).
        pt_weight: the PyTorch weight as a NumPy array (callers pass
            `state_dict` tensors converted with `.numpy()`).
        pt_bias: optional NumPy bias/beta. When None, the bias/beta variable
            keeps its current value.
        is_attn: when True, a Dense kernel is copied without transposing.
            NOTE(review): no caller in this notebook passes True; presumably
            meant for weights already in (in, out) layout - confirm.

    Returns:
        `tf_component` itself after its variables were updated in place, or,
        for an unhandled type, a new tensor built from `pt_weight`.
    """
    if isinstance(tf_component, tf.keras.layers.Conv3D):
        pt_weight = conv_transpose(pt_weight)
    elif isinstance(tf_component, tf.keras.layers.Dense) and not is_attn:
        # PyTorch Linear stores (out, in); a Keras Dense kernel is (in, out).
        pt_weight = pt_weight.transpose()

    # `assign` accepts NumPy arrays directly, so no intermediate tf.Variable is needed.
    if isinstance(tf_component, (tf.keras.layers.Dense, tf.keras.layers.Conv3D)):
        tf_component.kernel.assign(pt_weight)
        if pt_bias is not None:
            tf_component.bias.assign(pt_bias)

    elif isinstance(tf_component, tf.keras.layers.LayerNormalization):
        tf_component.gamma.assign(pt_weight)
        # Same rule as the Dense/Conv3D branch: with no bias given, beta is
        # left alone rather than crashing on tf.Variable(None).
        if pt_bias is not None:
            tf_component.beta.assign(pt_bias)

    elif isinstance(tf_component, tf.Variable):
        tf_component.assign(pt_weight)

    else:
        # Not a variable holder (e.g. a constant tensor): return a replacement
        # tensor that the caller must store back itself.
        return tf.convert_to_tensor(pt_weight)

    return tf_component


def _load_patch_merging(np_state_dict, pt_weights_prefix, layer):
    """Copy `<prefix>.downsample.{reduction,norm}` weights into a TF PatchMerging layer."""
    downsample_prefix = f"{pt_weights_prefix}.downsample"
    # The reduction projection is copied without a bias (no `.reduction.bias` key is read).
    layer.reduction = modify_tf_block(
        layer.reduction,
        np_state_dict[f"{downsample_prefix}.reduction.weight"],
    )
    layer.norm = modify_tf_block(
        layer.norm,
        np_state_dict[f"{downsample_prefix}.norm.weight"],
        np_state_dict[f"{downsample_prefix}.norm.bias"],
    )


def _load_window_attention(np_state_dict, attn_prefix, attn):
    """Copy `<prefix>` attention weights and buffers into a TF WindowAttention3D."""
    # These two are stored back because modify_tf_block returns a new tensor
    # when the target is neither a Keras layer nor a tf.Variable.
    attn.relative_position_bias_table = modify_tf_block(
        attn.relative_position_bias_table,
        np_state_dict[f"{attn_prefix}.relative_position_bias_table"],
    )
    attn.relative_position_index = modify_tf_block(
        attn.relative_position_index,
        np_state_dict[f"{attn_prefix}.relative_position_index"],
    )

    attn.qkv = modify_tf_block(
        attn.qkv,
        np_state_dict[f"{attn_prefix}.qkv.weight"],
        np_state_dict[f"{attn_prefix}.qkv.bias"],
    )
    attn.proj = modify_tf_block(
        attn.proj,
        np_state_dict[f"{attn_prefix}.proj.weight"],
        np_state_dict[f"{attn_prefix}.proj.bias"],
    )


def _load_swin_block(np_state_dict, block_prefix, block):
    """Copy one PyTorch Swin block (`<prefix>.norm1/attn/norm2/mlp`) into its TF counterpart."""
    # The k-th LayerNormalization met is normK; the k-th Dense inside a
    # Keras Model sub-layer (the MLP) is fcK. Both counters are per block.
    layernorm_idx = 1
    mlp_layer_idx = 1

    for inner_layer in block.layers:
        if isinstance(inner_layer, tf.keras.layers.LayerNormalization):
            norm_prefix = f"{block_prefix}.norm{layernorm_idx}"
            modify_tf_block(
                inner_layer,
                np_state_dict[f"{norm_prefix}.weight"],
                np_state_dict[f"{norm_prefix}.bias"],
            )
            layernorm_idx += 1

        elif isinstance(inner_layer, WindowAttention3D):
            _load_window_attention(np_state_dict, f"{block_prefix}.attn", inner_layer)

        # Checked after WindowAttention3D so attention is never treated as the MLP.
        elif isinstance(inner_layer, tf.keras.Model):
            mlp_prefix = f"{block_prefix}.mlp"
            for mlp_layer in inner_layer.layers:
                if isinstance(mlp_layer, tf.keras.layers.Dense):
                    modify_tf_block(
                        mlp_layer,
                        np_state_dict[f"{mlp_prefix}.fc{mlp_layer_idx}.weight"],
                        np_state_dict[f"{mlp_prefix}.fc{mlp_layer_idx}.bias"],
                    )
                    mlp_layer_idx += 1


def modify_swin_blocks(np_state_dict, pt_weights_prefix, tf_block):
    """Copy one PyTorch Swin stage's weights into the matching TF sub-layers.

    Args:
        np_state_dict: mapping from PyTorch state_dict key to NumPy array.
        pt_weights_prefix: PyTorch key prefix of the stage, e.g. "layers.0".
        tf_block: iterable of the TF stage's sub-layers. PatchMerging and
            SwinTransformerBlock3D instances are loaded; anything else is skipped.

    Returns:
        `tf_block` itself; the weights are updated in place.
    """
    for layer in tf_block:
        if isinstance(layer, PatchMerging):
            _load_patch_merging(np_state_dict, pt_weights_prefix, layer)

    # The i-th SwinTransformerBlock3D in the stage maps to PyTorch's `blocks.i`.
    swin_blocks = [layer for layer in tf_block if isinstance(layer, SwinTransformerBlock3D)]
    for block_idx, block in enumerate(swin_blocks):
        _load_swin_block(np_state_dict, f"{pt_weights_prefix}.blocks.{block_idx}", block)

    return tf_block


### Convert the weights

In [8]:
# PyTorch weights as NumPy arrays, keyed by the PyTorch parameter/buffer names.
np_state_dict = {key: tensor.numpy() for key, tensor in pt_model.state_dict().items()}

# Patch embedding: Conv3D projection, then its layer norm. modify_tf_block
# updates these layers in place, so its return value is not stored. (In
# tf.keras, `Model.layers` returns a fresh list, so assigning into
# `.layers[i]` would have no effect anyway.)
modify_tf_block(
    tf_model.projection.layers[0],
    np_state_dict["patch_embed.proj.weight"],
    np_state_dict["patch_embed.proj.bias"],
)
modify_tf_block(
    tf_model.projection.layers[1],
    np_state_dict["patch_embed.norm.weight"],
    np_state_dict["patch_embed.norm.bias"],
)

# Final layer norm is the last top-level layer of the TF model.
layer_normalization_idx = -1
modify_tf_block(
    tf_model.layers[layer_normalization_idx],
    np_state_dict["norm.weight"],
    np_state_dict["norm.bias"],
)

# Swin stages. The loop assumes tf_model.layers is: two stem layers (patch
# embedding, dropout), then the stages, then the final norm. So TF layer
# `first_stage_idx + k` holds PyTorch's `layers.k`.
first_stage_idx = 2
for i in range(first_stage_idx, len(tf_model.layers) - 1):
    modify_swin_blocks(np_state_dict, f"layers.{i - first_stage_idx}", tf_model.layers[i].layers)

In [9]:
# Fresh random input (identical values for TF and PyTorch) to compare the converted models.
x_tf, x_pt = get_x()

In [10]:
# Run both models on the same input. Each returns a dict of per-layer outputs
# plus the final output.
layers_output_tf, y = tf_model(x_tf)
print("------")
layers_output_pt, z = pt_model(x_pt)

# Values can only be compared if the final outputs agree in shape.
assert tuple(y.shape) == tuple(z.shape), (y.shape, z.shape)

y.shape, z.shape

------


(TensorShape([1, 768, 4, 7, 7]), torch.Size([1, 768, 4, 7, 7]))

In [11]:
# Explore the output of all layers: print a 10-value slice of each layer's
# TF output next to the matching PyTorch output.
for layer_name, pt_layer_out in layers_output_pt.items():
    tf_head = layers_output_tf[layer_name].numpy()[:1, :1, :1, :1, :10]
    pt_head = pt_layer_out.detach().numpy()[:1, :1, :1, :1, :10]
    print(f"-------------- {layer_name} -------------\n TF:  {tf_head} \n PT:  {pt_head} \n")

-------------- PatchEmbed -------------
 TF:  [[[[[ 0.9755405   0.56495506  1.3553491  -0.28581792 -1.968935
      0.23740548  0.56161475 -0.34389305 -0.13261688 -0.4307437 ]]]]] 
 PT:  [[[[[ 0.97554076  0.564955    1.3553493  -0.285818   -1.968935
      0.23740548  0.56161475 -0.34389305 -0.132617   -0.43074358]]]]] 

-------------- drop_out -------------
 TF:  [[[[[ 0.9755405   0.56495506  1.3553491  -0.28581792 -1.968935
      0.23740548  0.56161475 -0.34389305 -0.13261688 -0.4307437 ]]]]] 
 PT:  [[[[[ 0.97554076  0.564955    1.3553493  -0.285818   -1.968935
      0.23740548  0.56161475 -0.34389305 -0.132617   -0.43074358]]]]] 

-------------- basic layer1 -------------
 TF:  [[[[[ 0.34182435  0.4569996   0.8055373   0.7067946   0.06197317
      0.7964725  -0.14857802  0.36377767  0.5041112  -0.36719453]]]]] 
 PT:  [[[[[ 0.34182394  0.45699912  0.80553675  0.7067941   0.06197314
      0.7964725  -0.1485779   0.363778    0.50411075 -0.36719453]]]]] 

-------------- basic layer2 -----

In [12]:
# Compare every intermediate output (rtol = atol = 1e-4), stopping at the first
# layer that differs. In the recorded run the layers up to "basic layer3"
# matched and the assertion was raised at "basic layer4".
for layer in layers_output_pt:
    print("Testing", layer)
    np.testing.assert_allclose(
        layers_output_tf[layer].numpy(),
        layers_output_pt[layer].detach().numpy(),
        rtol=1e-4,
        atol=1e-4,
    )

Testing PatchEmbed
Testing drop_out
Testing basic layer1
Testing basic layer2
Testing basic layer3
Testing basic layer4


AssertionError: 
Not equal to tolerance rtol=0.0001, atol=0.0001

Mismatched elements: 149050 / 150528 (99%)
Max absolute difference: 7.2183533
Max relative difference: 9455.028
 x: array([[[[[ 8.325021e-01,  6.014856e-01,  8.798330e-01, ...,
            1.851248e-01,  1.317621e-01,  5.927004e-01],
          [ 6.490970e-02,  5.751448e-01,  1.031788e-01, ...,...
 y: array([[[[[ 0.831596,  0.592121,  0.878625, ...,  0.174236,  0.156727,
            0.581311],
          [ 0.045499,  0.590599,  0.11038 , ...,  0.936142,  0.23714 ,...

In [None]:
e = 1e-4
# Compare only the last Swin stage, using `e` for both rtol and atol.
np.testing.assert_allclose(
    layers_output_tf["basic layer4"].numpy(),
    layers_output_pt["basic layer4"].detach().numpy(),
    rtol=e,
    atol=e,
)


In [None]:
# Final outputs (y from TF, z from PyTorch), checked at a looser rtol = atol = 1e-2.
np.testing.assert_allclose(y.numpy(), z.detach().numpy(), rtol=1e-2, atol=1e-2)


In [None]:
# Re-run the per-layer comparison (rtol = atol = 1e-4). It stops at the first
# layer whose TF and PyTorch outputs differ; the last "Testing ..." line names it.
for layer in layers_output_pt:
    print("Testing", layer)
    np.testing.assert_allclose(
        layers_output_tf[layer].numpy(),
        layers_output_pt[layer].detach().numpy(),
        rtol=1e-4,
        atol=1e-4,
    )

### PyTorch determinism check: basic-layer outputs across repeated runs

In [None]:
# Run the PyTorch model several times on the same `x_pt` (the most recent one
# drawn) and keep each run's per-layer outputs. This checks whether repeated
# forward passes give identical results.
attempts = 10

outputs = []
for _attempt in range(attempts):
    layer_out, result = pt_model(x_pt)
    outputs.append(layer_out)

# For every layer, print a 10-value slice of its output, one line per attempt.
for layer in outputs[0]:
    print("--------------", layer, "---------------")
    for idx, layer_out in enumerate(outputs):
        print(f"attempt {idx} : ", layer_out[layer].detach().numpy()[:1, :1, :1, :1, :10], "\n")
    print()


### `allclose` tests across attempts

Compare the first attempt with another attempt. Set `attempt_no` in the following cell to choose which attempt to compare against.

In [None]:
attempt_no = 2        # Index into `outputs` (0 .. attempts - 1) of the run compared with attempt 0


In [None]:
# PatchEmbed layer: attempt 0 vs. attempt `attempt_no`
output1, output2 = (outputs[k]["PatchEmbed"] for k in (0, attempt_no))

np.testing.assert_allclose(output1.detach().numpy(), output2.detach().numpy(), rtol=1e-4, atol=1e-4)

In [None]:
# Basic layer 1: attempt 0 vs. attempt `attempt_no`
output1, output2 = (outputs[k]["basic layer1"] for k in (0, attempt_no))

np.testing.assert_allclose(output1.detach().numpy(), output2.detach().numpy(), rtol=1e-4, atol=1e-4)

In [None]:
# Basic layer 2: attempt 0 vs. attempt `attempt_no`
output1, output2 = (outputs[k]["basic layer2"] for k in (0, attempt_no))

np.testing.assert_allclose(output1.detach().numpy(), output2.detach().numpy(), rtol=1e-4, atol=1e-4)

In [None]:
# Basic layer 3: attempt 0 vs. attempt `attempt_no`
output1, output2 = (outputs[k]["basic layer3"] for k in (0, attempt_no))

np.testing.assert_allclose(output1.detach().numpy(), output2.detach().numpy(), rtol=1e-4, atol=1e-4)

In [None]:
# Basic layer 4: attempt 0 vs. attempt `attempt_no`
output1, output2 = (outputs[k]["basic layer4"] for k in (0, attempt_no))

np.testing.assert_allclose(output1.detach().numpy(), output2.detach().numpy(), rtol=1e-4, atol=1e-4)

In [None]:
# Final output: attempt 0 vs. attempt `attempt_no`
output1, output2 = (outputs[k]["Final Output"] for k in (0, attempt_no))

np.testing.assert_allclose(output1.detach().numpy(), output2.detach().numpy(), rtol=1e-4, atol=1e-4)