In [1]:
import torch
import tensorflow as tf

import tensorflow
import numpy as np

In [2]:
import torch.nn as nn
import torch.nn.functional as F

Create a random integer test tensor of layout (B, D, H, W, C) and mirror it as NumPy and TensorFlow tensors.

In [7]:
# Fixed seed so Restart Kernel -> Run All reproduces exactly the same tensors.
torch.manual_seed(0)

# Integer values in [0, 255) (torch.randint's `high` bound is exclusive), layout (B, D, H, W, C).
x_pt = torch.randint(255, (10, 8, 32, 32, 3))
x_np = x_pt.numpy()
x_tf = tf.convert_to_tensor(x_np)

x_pt.shape, x_tf.shape

(torch.Size([10, 8, 32, 32, 3]), TensorShape([10, 8, 32, 32, 3]))

## Window Partition Function

In [None]:
def get_window_size(x_size, window_size, shift_size=None):
    """Clamp a window (and optionally its shift) to the size of the input.

    For every axis of ``x_size`` whose extent is not larger than the requested
    window, the window on that axis is shrunk to the extent and, if a shift
    was given, the shift on that axis is set to 0.

    Args:
        x_size: per-axis input extents, e.g. (D, H, W).
        window_size: requested per-axis window size.
        shift_size: optional per-axis shift.
    Returns:
        tuple: the clamped window size when ``shift_size`` is None, otherwise a
        pair ``(clamped_window_size, clamped_shift_size)``.
    """
    win = list(window_size)
    shift = None if shift_size is None else list(shift_size)

    for axis, extent in enumerate(x_size):
        if extent <= window_size[axis]:
            win[axis] = extent
            if shift is not None:
                shift[axis] = 0

    if shift is None:
        return tuple(win)
    return tuple(win), tuple(shift)

In [None]:
B, D, H, W, C = x_pt.shape
# Clamp the requested (4, 4, 4) window to the input's (D, H, W) extent.
window_size = get_window_size((D, H, W), (4, 4, 4))
window_size

(4, 4, 4)

### PyTorch Window Partition

In [None]:
from operator import mul 
from functools import reduce


In [None]:
def window_partition_pt(x, window_size):
    """Split a 5-D tensor into non-overlapping 3-D windows.

    Args:
        x: tensor of shape (B, D, H, W, C); D, H and W are expected to be
            divisible by window_size[0], window_size[1] and window_size[2].
        window_size (tuple[int]): window size (wd, wh, ww).
    Returns:
        windows: (B*num_windows, wd*wh*ww, C) with
            num_windows = (D // wd) * (H // wh) * (W // ww). Windows are ordered
            by batch, then depth-, height- and width-window index (width fastest);
            tokens inside a window are in (d, h, w) row-major order.
    """
    batch, depth, height, width, channels = x.shape
    wd, wh, ww = window_size[0], window_size[1], window_size[2]
    tokens_per_window = reduce(mul, window_size)

    blocked = x.view(batch, depth // wd, wd, height // wh, wh, width // ww, ww, channels)
    # -> (B, nD, nH, nW, wd, wh, ww, C): window indices first, intra-window offsets after.
    blocked = blocked.permute(0, 1, 3, 5, 2, 4, 6, 7)
    return blocked.reshape(-1, tokens_per_window, channels)

In [None]:
# Expected shape: (B * num_windows, prod(window_size), C)
x_windows_pt = window_partition_pt(x=x_pt, window_size=window_size)
x_windows_pt.shape

torch.Size([1280, 64, 3])

### TensorFlow Window Partition

In [None]:
from functools import reduce

In [None]:
def window_partition_tf(x, window_size):
    """Split a 5-D tensor into non-overlapping 3-D windows (TensorFlow counterpart of window_partition_pt).

    Args:
        x: tensor of shape (B, D, H, W, C). D, H and W must be statically known
            and divisible by window_size[0], window_size[1] and window_size[2];
            the batch dimension B may be unknown (None), e.g. inside a Keras model.
        window_size (tuple[int]): window size (wd, wh, ww).
    Returns:
        windows: (B*num_windows, wd*wh*ww, C) with
            num_windows = (D // wd) * (H // wh) * (W // ww).
    """
    _, D, H, W, C = x.shape
    # -1 lets TF infer the batch dimension, so this also works when B is None
    # (passing a None B to tf.reshape would fail).
    x = tf.reshape(x, [-1,
                       D // window_size[0], window_size[0],
                       H // window_size[1], window_size[1],
                       W // window_size[2], window_size[2],
                       C])
    # -> (B, nD, nH, nW, wd, wh, ww, C): window indices first, intra-window offsets after.
    x = tf.transpose(x, perm=[0, 1, 3, 5, 2, 4, 6, 7])
    windows = tf.reshape(x, [-1, reduce(lambda a, b: a * b, window_size), C])
    return windows

In [None]:
# Expected shape: (B * num_windows, prod(window_size), C)
x_windows_tf = window_partition_tf(x=x_tf, window_size=window_size)
x_windows_tf.shape

TensorShape([1280, 64, 3])

Comparison

In [None]:
# The input tensor holds integers, so the two window layouts can be compared exactly.
np.array_equal(x_windows_pt.numpy(), x_windows_tf.numpy())

True

## Patch Merging

### PyTorch

In [124]:
class PatchMerging_pt(nn.Module):
    """ Patch Merging Layer.

    Concatenates each 2x2 neighbourhood of the (H, W) plane along the channel
    axis (halving H and W, quadrupling C) and normalizes the result.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer, built as
            ``norm_layer(4 * dim)``.  Default: nn.LayerNorm
        apply_reduction (bool, optional): If True, project the normalized
            4*dim features to 2*dim with the bias-free linear ``self.reduction``.
            Default: False, which returns the normalized 4*dim features (the
            output this notebook compares against the TF layer).
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm, apply_reduction=False):
        super().__init__()
        self.dim = dim
        self.apply_reduction = apply_reduction
        # Always created so the module's parameters do not depend on the flag.
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, D, H, W, C).
        Returns:
            Tensor of size (B, D, ceil(H/2), ceil(W/2), 4*C), or with 2*dim
            channels when ``apply_reduction`` is True.
        """
        _, _, H, W, _ = x.shape

        # Zero-pad odd H/W on the bottom/right. F.pad takes (before, after)
        # pairs starting from the LAST dim: (C), then (W), then (H).
        if H % 2 == 1 or W % 2 == 1:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, :, 0::2, 0::2, :]  # B D H/2 W/2 C
        x1 = x[:, :, 1::2, 0::2, :]  # B D H/2 W/2 C
        x2 = x[:, :, 0::2, 1::2, :]  # B D H/2 W/2 C
        x3 = x[:, :, 1::2, 1::2, :]  # B D H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B D H/2 W/2 4*C

        x = self.norm(x)
        if self.apply_reduction:
            x = self.reduction(x)  # B D H/2 W/2 2*dim

        return x

### TensorFlow

In [125]:
from tensorflow.keras.layers import Dense, LayerNormalization, Normalization 

In [126]:
class PatchMerging_tf(tf.keras.layers.Layer):
    """ Patch Merging Layer (TensorFlow counterpart of PatchMerging_pt).

    Concatenates each 2x2 neighbourhood of the (H, W) plane along the channel
    axis (halving H and W, quadrupling C) and normalizes the result.

    Args:
        dim (int): Number of input channels.
        norm_layer: Keras normalization layer class, instantiated as
            ``norm_layer(epsilon=1e-5)`` (presumably to match PyTorch's
            ``nn.LayerNorm`` default eps). Default: LayerNormalization
        apply_reduction (bool, optional): If True, project the normalized
            4*dim features to 2*dim with a bias-free Dense layer. Default:
            False, which returns the normalized 4*dim features.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """
    def __init__(self, dim, norm_layer=LayerNormalization, apply_reduction=False, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.apply_reduction = apply_reduction
        self.reduction = Dense(2 * dim, use_bias=False, activation=None) if apply_reduction else None
        self.norm = norm_layer(epsilon=1e-5)

    def call(self, x):
        """ Forward function.
        Args:
            x: Input feature of shape (B, D, H, W, C); H and W must be statically known.
        Returns:
            Tensor of shape (B, D, ceil(H/2), ceil(W/2), 4*C), or with 2*dim
            channels when ``apply_reduction`` is True.
        """
        _, _, H, W, _ = x.shape

        # Zero-pad odd H/W on the bottom/right. Pads the layer input `x` itself;
        # padding any other tensor here would silently discard the input.
        if H % 2 == 1 or W % 2 == 1:
            x = tf.pad(x, [[0, 0], [0, 0], [0, H % 2], [0, W % 2], [0, 0]])

        x0 = x[:, :, 0::2, 0::2, :]  # B D H/2 W/2 C
        x1 = x[:, :, 1::2, 0::2, :]  # B D H/2 W/2 C
        x2 = x[:, :, 0::2, 1::2, :]  # B D H/2 W/2 C
        x3 = x[:, :, 1::2, 1::2, :]  # B D H/2 W/2 C

        x = tf.concat([x0, x1, x2, x3], axis=-1)  # B D H/2 W/2 4*C

        x = self.norm(x)
        if self.apply_reduction:
            x = self.reduction(x)  # B D H/2 W/2 2*dim

        return x

In [118]:
# Fixed seed so the printed feature vectors further down are reproducible on re-run.
torch.manual_seed(0)

# float32 values in [0, 1), layout (B, D, H, W, C).
x_pt = torch.rand(10, 8, 32, 32, 3)
x_np = x_pt.numpy()
x_tf = tf.convert_to_tensor(x_np)

In [127]:
# Take `dim` from the input's channel axis instead of hard-coding 3, so the
# LayerNorm width (4 * dim) always matches the concatenated features.
in_channels = x_pt.shape[-1]
patchMerging_pt = PatchMerging_pt(in_channels)
patchMerging_tf = PatchMerging_tf(in_channels)

x_pt_merge = patchMerging_pt(x_pt)
x_tf_merge = patchMerging_tf(x_tf)

In [128]:
tuple(t.shape for t in (x_pt_merge, x_tf_merge))

(torch.Size([10, 8, 16, 16, 12]), TensorShape([10, 8, 16, 16, 12]))

In [129]:
# One merged 4*C feature vector from each framework, for eyeballing round-off differences.
print(
    x_pt_merge[0, 0, 0, 1].detach().numpy(),
    "\n\n\n",
    x_tf_merge[0, 0, 0, 1].numpy(),
)

[-0.72882813 -1.2383318  -0.7402359  -0.9829437   0.8310238   1.1086562
 -0.6501629   1.327112    1.4385794  -1.2904475   0.18586369  0.7397152 ] 


 [-0.72882813 -1.2383318  -0.7402359  -0.98294365  0.8310237   1.1086562
 -0.6501629   1.327112    1.4385796  -1.2904475   0.18586373  0.73971534]


In [109]:
# Exact equality fails on float32 round-off (compare the last digits printed above),
# so check element-wise agreement within a small tolerance instead.
np.testing.assert_allclose(x_pt_merge.detach().numpy(), x_tf_merge.numpy(), rtol=1e-5, atol=1e-5)