In [2]:
import tensorflow as tf
import torch
import torch.nn as nn
import numpy as np
from VideoSwinTransformer import *
import os
from collections import OrderedDict
import sys

### helpers

In [2]:
def conv_transpose(w):
    return w.transpose(2,3,4,1, 0)
    

def modify_tf_block( tf_component, pt_weight,  pt_bias = None, is_attn=False):
    in_shape = pt_weight.shape

    if isinstance(tf_component, tf.keras.layers.Conv3D) :
      pt_weight = conv_transpose(pt_weight)

    if isinstance(tf_component, tf.keras.layers.Dense) and not is_attn:
      pt_weight =pt_weight.transpose()

    if isinstance(tf_component, (tf.keras.layers.Dense, tf.keras.layers.Conv3D)):
        tf_component.kernel.assign(tf.Variable(pt_weight))
        if pt_bias is not None:
            tf_component.bias.assign(tf.Variable(pt_bias))

    elif isinstance(tf_component, tf.keras.layers.LayerNormalization):

        tf_component.gamma.assign(tf.Variable(pt_weight))

        tf_component.beta.assign(tf.Variable(pt_bias))

    elif isinstance(tf_component, (tf.Variable)):
        tf_component.assign(tf.Variable(pt_weight))

    else:
        return tf.convert_to_tensor(pt_weight)
        
        

    return tf_component

In [3]:
def get_x(shape=(1,3,8,224,224)):
    x_pt = torch.rand(shape) * 255
    x_np = x_pt.numpy()
    x_tf = tf.convert_to_tensor(x_np)

    return x_tf, x_pt


In [4]:
x_tf, x_pt = get_x()


In [5]:
cfg_method = model_configs.MODEL_MAP["swin_tiny_patch244_window877_kinetics400_1k"]
cfg = cfg_method()

name = cfg["name"]
link = cfg['link']
del cfg["name"]
del cfg['link']
del cfg["drop_path_rate"]
# download_weight_command = f"wget {link} -O {name}.pth"
# os.system(download_weight_command)

In [7]:
pt_model = SwinTransformer3D_pt(**cfg,drop_rate=0.4, drop_path_rate=0., isTest= True)
print("--------")
tf_model = SwinTransformer3D(**cfg,drop_rate=0.4, drop_path_rate=0., isTest= True)
x_tf, x_pt = get_x()


print("-------pt")
basic_pt, z= pt_model(x_pt)
print("-------tf")

basic_tf, y = tf_model(x_tf)
checkpoint = torch.load(f'{name}.pth')



new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
    if 'backbone' in k:
        nam = k[9:]
        new_state_dict[nam] = v 

pt_model.load_state_dict(new_state_dict) 
pt_model.eval()

#(64, 196, 96)  (64, 196, 196)

--------
-------pt
-------tf


### check basic4

<VideoSwinTransformer.BasicLayer.BasicLayer object at 0x000001A51A9C7FD0> 768 (2, 768, 2, 7, 7) 2 24 (8, 7, 7) 4.0 True None 0.4 0.0 [0.0, 0.0] <class 'keras.layers.normalization.layer_normalization.LayerNormalization'> None False


### window_partition

In [None]:
from functools import reduce
from operator import mul


In [None]:
def window_partition_pt(x, window_size):
    """
    Args:
        x: (B, D, H, W, C)
        window_size (tuple[int]): window size
    Returns:
        windows: (B*num_windows, window_size*window_size, C)
    """
    print(x.size(), window_size)
    B, D, H, W, C = x.shape
    x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], C)
    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
    return windows

In [None]:
x_tf, x_pt = get_x((1, 4, 56, 56, 96))

In [None]:
x = window_partition(x_tf, (4,7,7))

y = window_partition_pt(x_pt, (4,7,7))

x.shape, y.shape

partition (1, 4, 56, 56, 96) (4, 7, 7)
torch.Size([1, 4, 56, 56, 96]) (4, 7, 7)


(TensorShape([64, 196, 96]), torch.Size([64, 196, 96]))

In [None]:
rtol = 1e-4
etol = 1e-4

np.testing.assert_allclose(x.numpy(), y.detach().numpy(), etol, rtol)

### check WindowAttention

In [None]:
from timm.models.layers import DropPath, trunc_normal_

class WindowAttention3D_pt(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The temporal length, height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        print( dim, window_size, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)

        self.dim = dim
        self.window_size = window_size  # Wd, Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads))  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_d = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))  # 3, Wd, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1

        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, N, N) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B_, nH, N, C

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape(
            N, N, -1)  # Wd*Wh*Ww,Wd*Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wd*Wh*Ww, Wd*Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


In [None]:
atten_tf = WindowAttention3D(96, (8, 7, 7), 3, True, None, 0.0, 0.4)
atten_pt = WindowAttention3D_pt(96, (8, 7, 7), 3, True, None, 0.0, 0.4)

x_tf, x_pt = get_x((64, 196, 96))


96 (8, 7, 7) 3 True None 0.0 0.4


In [None]:
from VideoSwinTransformer.SwinTransformer3D_pt import compute_mask


mask_pt = compute_mask(4,56,56, (4,7,7), (0,3,3), None)


from VideoSwinTransformer import compute_mask


mask_tf = compute_mask(4,56,56, (4,7,7), (0,3,3))
mask_tf.shape , mask_pt.shape

compute mask pt
torch.Size([1, 4, 56, 56, 1]) (4, 7, 7)
compute mask cm tf
partition (1, 4, 56, 56, 1) (4, 7, 7)


(TensorShape([64, 196, 196]), torch.Size([64, 196, 196]))

In [None]:
x = atten_tf(x_tf, mask_tf)
y = atten_pt(x_pt, mask_pt)
x[:1,:1,:10] , y[:1,:1,:10]

(<tf.Tensor: shape=(1, 1, 10), dtype=float32, numpy=
 array([[[ 73.772285,  64.29088 , -50.178078,   4.166691,  -7.296245,
           52.691998, 111.41441 ,  56.152546,  20.024395, -20.882938]]],
       dtype=float32)>,
 tensor([[[  13.6281,  -74.2625,   43.2334,  -21.4849, -141.7670,  -80.8282,
             11.9369,   48.6482,  -66.5065,  -21.7590]]],
        grad_fn=<SliceBackward0>))

In [None]:
atten_pt.eval()
np_state_dict = atten_pt.state_dict()
np_state_dict = {k: np_state_dict[k].numpy() for k in np_state_dict}
np_state_dict.keys()

dict_keys(['relative_position_bias_table', 'relative_position_index', 'qkv.weight', 'qkv.bias', 'proj.weight', 'proj.bias'])

In [None]:
def modify_atten_block(inner_layer, np_state_dict):


    # Relative position.
    inner_layer.relative_position_bias_table = (
        modify_tf_block(
            inner_layer.relative_position_bias_table,
            np_state_dict[
                f"relative_position_bias_table"
            ] 
        )
    )
    inner_layer.relative_position_index = (
        modify_tf_block(
            inner_layer.relative_position_index,
            np_state_dict[
                f"relative_position_index"
            ]

        )
    )

    # QKV.
    inner_layer.qkv = modify_tf_block(
        inner_layer.qkv,
        np_state_dict[f"qkv.weight"],
        np_state_dict[f"qkv.bias"],

    )

    # Projection.
    inner_layer.proj = modify_tf_block(
        inner_layer.proj,
        np_state_dict[f"proj.weight"],
        np_state_dict[f"proj.bias"],

    )

In [None]:
_ = modify_atten_block(atten_tf, np_state_dict)

In [None]:
x = atten_tf(x_tf, mask_tf)
y = atten_pt(x_pt, mask_pt)
x[:1,:1,:10] , y[:1,:1,:10]

(<tf.Tensor: shape=(1, 1, 10), dtype=float32, numpy=
 array([[[  8.176866, -44.557507,  25.940012, -12.890912, -85.06018 ,
          -48.496895,   7.162156,  29.188894, -39.903904, -13.055373]]],
       dtype=float32)>,
 tensor([[[  8.1769, -44.5575,  25.9400, -12.8909, -85.0602, -48.4969,   7.1622,
            29.1889, -39.9039, -13.0554]]], grad_fn=<SliceBackward0>))

In [None]:
rtol = 1e-4
etol = 1e-4

np.testing.assert_allclose(x.numpy(), y.detach().numpy(), etol, rtol)

### check roll (passed)

In [None]:
x_tf, x_pt = get_x((1, 4, 56, 56, 96))
shift_size = (4,3,3)

shifted_x_tf = tf.roll(x_tf, shift=[-shift_size[0], -shift_size[1], -shift_size[2]], axis=[1, 2, 3]) #?

shifted_x_pt = torch.roll(x_pt, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
shifted_x_tf.shape, shifted_x_pt.shape

(TensorShape([1, 4, 56, 56, 96]), torch.Size([1, 4, 56, 56, 96]))

In [None]:
rtol = 1e-6
etol = 1e-6

np.testing.assert_allclose(shifted_x_tf.numpy(), shifted_x_pt.detach().numpy(), etol, rtol)


### check identity (passed)

In [None]:
x_tf, x_pt = get_x((1, 4, 56, 56, 96))

In [None]:

x = tf.identity(x_tf)

drop_path_pt = nn.Identity()
y = drop_path_pt(x_pt)

x.shape, y.shape

(TensorShape([1, 4, 56, 56, 96]), torch.Size([1, 4, 56, 56, 96]))

In [None]:
rtol = 1e-7
etol = 1e-7
np.testing.assert_allclose(x.numpy(), y.detach().numpy(), etol, rtol)



### check mlp (passed)

In [None]:
class Mlp_pt(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, act_layer=nn.GELU, out_features=None,  drop=0.):
        super().__init__()
        print(in_features, hidden_features, act_layer,  out_features, drop)

        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # print('dense',hidden_features, in_features)

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        print("mlp", x.shape)

        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

In [None]:
x_tf, x_pt = get_x((1, 4, 56, 56, 96))

In [None]:
mlp_tf = mlp_block(96, 384 , tf.keras.activations.gelu ,None, 0.4)
print()
mlp_pt = Mlp_pt(96, 384 , nn.GELU ,None, 0.4)



96 384 <class 'torch.nn.modules.activation.GELU'> None 0.4


In [None]:
x = mlp_tf(x_tf)
y = mlp_pt(x_pt)

x[:1,:1,:1,:1,:10], y[:1,:1,:1,:1,:10] 


mlp torch.Size([1, 4, 56, 56, 96])


(<tf.Tensor: shape=(1, 1, 1, 1, 10), dtype=float32, numpy=
 array([[[[[ -55.042526 ,   30.78402  , -280.38745  ,  -72.948235 ,
             -39.252743 ,  -13.429337 ,   -5.7349243, -161.74384  ,
             -73.916504 ,   24.173918 ]]]]], dtype=float32)>,
 tensor([[[[[-125.6045,    0.0000,   77.3646,  124.4250,  -21.5418,  -22.6836,
               50.7710,   63.3386,   21.1723,   -0.0000]]]]],
        grad_fn=<SliceBackward0>))

In [None]:
mlp_pt.eval()
np_state_dict = mlp_pt.state_dict()
np_state_dict = {k: np_state_dict[k].numpy() for k in np_state_dict}

In [None]:
def modify_mlp(inner_layer, np_state_dict):
    mlp_layer_idx = 1
    for mlp_layer in inner_layer.layers:

        if isinstance(mlp_layer, tf.keras.layers.Dense):
            mlp_layer = modify_tf_block(
                mlp_layer,
                np_state_dict[
                    f"fc{mlp_layer_idx}.weight"
                ],
                np_state_dict[
                    f"fc{mlp_layer_idx}.bias"
                ]

            )
            mlp_layer_idx += 1

In [None]:
_ = modify_mlp(mlp_tf, np_state_dict)

In [None]:
x = mlp_tf(x_tf)
y = mlp_pt(x_pt)

x[:1,:1,:1,:1,:10], y[:1,:1,:1,:1,:10] 

mlp torch.Size([1, 4, 56, 56, 96])


(<tf.Tensor: shape=(1, 1, 1, 1, 10), dtype=float32, numpy=
 array([[[[[-10.9844885,   9.016938 , -46.57778  ,  45.277203 ,
             16.67046  ,   2.3475373,  -6.9900846,   2.1323876,
            -10.313778 , -48.950424 ]]]]], dtype=float32)>,
 tensor([[[[[-10.9845,   9.0169, -46.5778,  45.2772,  16.6705,   2.3475,
              -6.9901,   2.1324, -10.3138, -48.9504]]]]],
        grad_fn=<SliceBackward0>))

In [None]:
rtol = 1e-2
etol = 1e-2

np.testing.assert_allclose(x.numpy(), y.detach().numpy(), etol, rtol)

### check 