<a href="https://colab.research.google.com/github/z-arabi/pytorch-transformer/blob/main/VisionTransformers/Swin_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# the source code is: https://github.com/berniwal/swin-transformer-pytorch

In [1]:
!pip install torch==1.8.1
!pip install einops==0.3.0

[31mERROR: Could not find a version that satisfies the requirement torch==1.8.1 (from versions: 1.11.0, 1.12.0, 1.12.1, 1.13.0, 1.13.1, 2.0.0, 2.0.1, 2.1.0, 2.1.1, 2.1.2, 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2.4.1)[0m[31m
[0m[31mERROR: No matching distribution found for torch==1.8.1[0m[31m
[0mCollecting einops==0.3.0
  Downloading einops-0.3.0-py2.py3-none-any.whl.metadata (10 kB)
Downloading einops-0.3.0-py2.py3-none-any.whl (25 kB)
Installing collected packages: einops
  Attempting uninstall: einops
    Found existing installation: einops 0.8.0
    Uninstalling einops-0.8.0:
      Successfully uninstalled einops-0.8.0
Successfully installed einops-0.3.0


In [29]:
import torch
from torch import nn, einsum
import numpy as np
from einops import rearrange, repeat
import torch.nn.functional as F


class CyclicShift(nn.Module):
    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        # minus values shows the shifting to the right and down
        # positive values > left and up
        # we have the b,h,w,c > so the dim would be 1 and 2
        return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # for the v2 of sein transformers the layer norm occurs after the attention > self.norm(self.fn(x,**kwargs)
        return self.fn(self.norm(x), **kwargs) # that's prenorm for the v1 of swin transformers


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)


def create_mask(window_size, displacement, upper_lower, left_right):
    mask = torch.zeros(window_size ** 2, window_size ** 2) # 49*49

    if upper_lower:
        mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
        mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')

    if left_right:
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
        mask[:, -displacement:, :, :-displacement] = float('-inf')
        mask[:, :-displacement, :, -displacement:] = float('-inf')
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')

    return mask


def get_relative_distances(window_size):
    indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
    # print('indices: ', indices.shape) # [49,2]
    distances = indices[None, :, :] - indices[:, None, :]
    print("distance", distances.shape) # (49,49,2)
    return distances


class WindowAttention(nn.Module):
    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        # dim=hid_dimension = (96, 192, 384, 768)
        # heads=number_heads = (3,6,12,24)
        # head_dim=32
        # window_size=7
        super().__init__()
        inner_dim = head_dim * heads # inner_dim == C == dim == hid_dimension == (32*3=96, 32*6=192, 32*12=384, 32*24=768)

        self.heads = heads
        self.scale = head_dim ** -0.5 # scaling in the softmax
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted # shifting to the right and down with the value of half of the window size and pad them
        # we have two kind of padding > naive(adding zero) and cyclic padding

        if self.shifted:
            displacement = window_size // 2 # 7//2=3
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            # Persistence: The mask is stored as part of the model's state and will be included when saving/loading the model.
            # not learnable
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True), requires_grad=False)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # for each stage we have the positional embedding because previously we had the image . patchify it > then give the embedding to the model
        if self.relative_pos_embedding:
            # the full matrix distances: 49,49,2
            # to have the indices in the correct range > the most neg > -6 (win-1) > add 6 to start from 0
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            # number of params are 13*13 instead of 49*49
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
        else:
            # we don't need the cls token anymore and also embedding is different is has 49,49
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        # inner_dim == dim >> Wo
        self.to_out = nn.Linear(inner_dim, dim)

        ## for the cosine similiarity
        self.tau = nn.Parameter(torch.tensor(0.01), requires_grad=True)

    def forward(self, x):
        if self.shifted:
            x = self.cyclic_shift(x) #(1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768))

        b, n_h, n_w, _, h = *x.shape, self.heads

        # (1, (56, 28, 14, 7), (56, 28, 14, 7), (96*3, 192*3, 384*3, 768*3))
        qkv = self.to_qkv(x).chunk(3, dim=-1) # chunk the last dim in order to have the same dim for all q,k,v

        nw_h = n_h // self.window_size # (56//7=8, 28//7=4, 14//7=2, 7//7=1)
        nw_w = n_w // self.window_size

        # just to separate the head and dim for each head
        #(b=1, h=(3,6,12,24), (nw_h*nw_w)=(64,16,4,1), (w_h*w_w)=49, d=32) where d=head_dim; h=#heads;
        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv)

        # w stands for how many windows do we have in total
        # d is the hidden dim
        # we want to find the connection between all windows > ij
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale
        # print('dots.size: ', dots.size()) # dots.size:  torch.Size([1, 3, 64, 49, 49])

        # ## if the cosine similiarity:
        # q = F.normalized(q, p=2, dim=-1)
        # k = F.normalized(k, p=2, dim=-1)
        # dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) / self.tau

        # we add the positional embeddings after dot
        if self.relative_pos_embedding:
            # first you find the indexes and then you put each parameter based on its indices
            # you have 13*13 params > [0,0] , [0,1], [0,2] , ...
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        if self.shifted:
            # last row windows > would be the last 8 windows in stage 1
            dots[:, :, -nw_w:] += self.upper_lower_mask
            # last column windows
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)

        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)

        #(1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768))
        return out


class SwinBlock(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim,
                                                                     heads=heads,
                                                                     head_dim=head_dim,
                                                                     shifted=shifted,
                                                                     window_size=window_size,
                                                                     relative_pos_embedding=relative_pos_embedding)))
        self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))

    def forward(self, x):
        x = self.attention_block(x) #(1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768))
        x = self.mlp_block(x) #(1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768))
        return x


class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x):
        b, c, h, w = x.shape
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
        x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
        x = self.linear(x)
        return x

class PatchMerging_Conv(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()

        # not overlapping kernels
        self.patch_merge = nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=downscaling_factor,
                                     stride=downscaling_factor,
                                     padding=0)

    def forward(self, x):
        print('x.size in patch_merging: ', x.size()) # (1, (3,96,192,384), (224,56,28,14), (224,56,28,14))
        #self.patch_merge(x) # (1, (96, 192, 384, 768), (56, 28, 14, 7), (56, 28, 14, 7))

        # conv2d layer format input: (N, C, H, W) >> why permuting for the output???
        x = self.patch_merge(x).permute(0, 2, 3, 1) # (1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)) > b,h,w,c
        print('x.size after patch_merging: ', x.size())
        return x

class StageModule(nn.Module):
    def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        # layers > how many swin block in one stage > we have to assign the pairs of the regular and shifting ones > it must be even
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'

        self.patch_partition = PatchMerging_Conv(in_channels=in_channels, out_channels=hidden_dimension,
                                            downscaling_factor=downscaling_factor)

        self.layers = nn.ModuleList([])
        for _ in range(layers // 2):
            self.layers.append(nn.ModuleList([
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
            ]))

    def forward(self, x):
        # the values for each stage > bs,3,224,224 > bs,96,56,56 > bs,192,28,28 > bs,14,14,382
        # the format is #(1, (3, 96, 192, 384), (224, 56, 28, 14), (224, 56, 28, 14))
        # change the size and create the hierarchical > The input of the transformer has the same shape of the output of that
        # for stage one you can merge the patch_partition to the linear embedding > the target: change the dimensions
        # What is the difference between patch partition and the linear embedding ????
        x = self.patch_partition(x)
        #(1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768))
        # for conv > b,c,h,w
        # for the multihead attention > b,h,w,c
        for regular_block, shifted_block in self.layers:
            x = regular_block(x) #(1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768))
            x = shifted_block(x) #(1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768))
        return x.permute(0, 3, 1, 2) # Final output for the last stage (1, 768, 7, 7)


class SwinTransformer(nn.Module):
    # downscaling factor for the H and W > /4 > /8 > /16 > /32
    # window_size > last feature should have a size of 7 > H/32=7
    # channel sizes > hid_dim=96 > *2 > *4 > *8=768
    def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        # four stages of swin blocks
        # layers=(2, 2, 6, 2), heads=(3, 6, 12, 24)
        # initial hid_dim = 96 > *2 > *4 > *8
        self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim, layers=layers[0],
                                  downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)

        self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim * 2, layers=layers[1],
                                  downscaling_factor=downscaling_factors[1], num_heads=heads[1], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)

        self.stage3 = StageModule(in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4, layers=layers[2],
                                  downscaling_factor=downscaling_factors[2], num_heads=heads[2], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)

        self.stage4 = StageModule(in_channels=hidden_dim * 4, hidden_dimension=hidden_dim * 8, layers=layers[3],
                                  downscaling_factor=downscaling_factors[3], num_heads=heads[3], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)

        # classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim * 8),
            nn.Linear(hidden_dim * 8, num_classes)
        )

    def forward(self, img):
        # what is the shape of the img? for each block we are giving the image size > BS H W 3 > bs,3,224,224
        print("Input image shape:", img.shape)
        x = self.stage1(img)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        # why mean??? [it's ok to feed the 1,768,49 to the MLP]
        print("X of the last stage", x.shape) #1, 768, 7, 7
        x = x.mean(dim=[2, 3])
        print("X after mean", x.shape) #1, 768
        return self.mlp_head(x)

In [16]:
# layers > how many swin blocks in each stage
# input_channel = 3
# but the hidden_dim = 96 > C in the paper
# heads is for the number of heads
def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)


def swin_s(hidden_dim=96, layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)


def swin_b(hidden_dim=128, layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)


def swin_l(hidden_dim=192, layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)


model = swin_t(num_classes=3)
sample = torch.randn(1, 3, 224, 224)
output = model(sample)
print(output.shape)
print(output)


Input image shape: torch.Size([1, 3, 224, 224])
x.size in patch_merging:  torch.Size([1, 3, 224, 224])
x.size after patch_merging:  torch.Size([1, 56, 56, 96])
dots.size:  torch.Size([1, 3, 64, 49, 49])
dots.size:  torch.Size([1, 3, 64, 49, 49])
x.size in patch_merging:  torch.Size([1, 96, 56, 56])
x.size after patch_merging:  torch.Size([1, 28, 28, 192])
dots.size:  torch.Size([1, 6, 16, 49, 49])
dots.size:  torch.Size([1, 6, 16, 49, 49])
x.size in patch_merging:  torch.Size([1, 192, 28, 28])
x.size after patch_merging:  torch.Size([1, 14, 14, 384])
dots.size:  torch.Size([1, 12, 4, 49, 49])
dots.size:  torch.Size([1, 12, 4, 49, 49])
dots.size:  torch.Size([1, 12, 4, 49, 49])
dots.size:  torch.Size([1, 12, 4, 49, 49])
dots.size:  torch.Size([1, 12, 4, 49, 49])
dots.size:  torch.Size([1, 12, 4, 49, 49])
x.size in patch_merging:  torch.Size([1, 384, 14, 14])
x.size after patch_merging:  torch.Size([1, 7, 7, 768])
dots.size:  torch.Size([1, 24, 1, 49, 49])
dots.size:  torch.Size([1, 24, 

In [4]:

from torchsummary import summary
# Total params: 28,247,560
summary(model, (3, 224, 224))

Input image shape: torch.Size([2, 3, 224, 224])
x.size in patch_merging:  torch.Size([2, 3, 224, 224])
x.size after patch_merging:  torch.Size([2, 56, 56, 96])
x.size in patch_merging:  torch.Size([2, 96, 56, 56])
x.size after patch_merging:  torch.Size([2, 28, 28, 192])
x.size in patch_merging:  torch.Size([2, 192, 28, 28])
x.size after patch_merging:  torch.Size([2, 14, 14, 384])
x.size in patch_merging:  torch.Size([2, 384, 14, 14])
x.size after patch_merging:  torch.Size([2, 7, 7, 768])
X of the last stage torch.Size([2, 768, 7, 7])
X after mean torch.Size([2, 768])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 56, 56]           4,704
 PatchMerging_Conv-2           [-1, 56, 56, 96]               0
         LayerNorm-3           [-1, 56, 56, 96]             192
            Linear-4          [-1, 56, 56, 288]          27,648
            Linear-5           [-1, 56

In [7]:
torch.manual_seed(0)
B, H, W, C = 1, 2, 2, 3
input = torch.randn(B, H, W, C) * 100
print('input: ', input.shape)
print(input)
layer_norm = nn.LayerNorm(C)
output = layer_norm(input)
print('output: ', output.shape)
print(output)

# Layer norm has the gamma and beta that these can be the parameters for that
x_mean = input.mean(dim=-1, keepdim=True)
x_std = input.std(dim=-1, keepdim=True)
# normalize the input based on these values
x_ = (input - x_mean) / x_std
print('x_: ', x_.shape)
print(x_)

input:  torch.Size([1, 2, 2, 3])
tensor([[[[ 154.0996,  -29.3429, -217.8789],
          [  56.8431, -108.4522, -139.8595]],

         [[  40.3347,   83.8026,  -71.9258],
          [ -40.3344,  -59.6635,   18.2036]]]])
output:  torch.Size([1, 2, 2, 3])
tensor([[[[ 1.2191,  0.0112, -1.2303],
          [ 1.3985, -0.5173, -0.8813]],

         [[ 0.3495,  1.0120, -1.3615],
          [-0.3948, -0.9787,  1.3735]]]], grad_fn=<NativeLayerNormBackward0>)
x_:  torch.Size([1, 2, 2, 3])
tensor([[[[ 0.9954,  0.0091, -1.0045],
          [ 1.1419, -0.4223, -0.7195]],

         [[ 0.2854,  0.8263, -1.1117],
          [-0.3223, -0.7991,  1.1214]]]])


In [12]:
new_mask = create_mask(window_size=7, displacement=3, upper_lower=True, left_right=False)
print(new_mask)
print(new_mask.shape)

tensor([[0., 0., 0.,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [-inf, -inf, -inf,  ..., 0., 0., 0.],
        [-inf, -inf, -inf,  ..., 0., 0., 0.],
        [-inf, -inf, -inf,  ..., 0., 0., 0.]])
torch.Size([49, 49])


In [23]:
relative_indexes = get_relative_distances(3)
print(relative_indexes[:,:,0])
print("----")
print(relative_indexes[:,:,1])

distance torch.Size([9, 9, 2])
tensor([[ 0,  0,  0,  1,  1,  1,  2,  2,  2],
        [ 0,  0,  0,  1,  1,  1,  2,  2,  2],
        [ 0,  0,  0,  1,  1,  1,  2,  2,  2],
        [-1, -1, -1,  0,  0,  0,  1,  1,  1],
        [-1, -1, -1,  0,  0,  0,  1,  1,  1],
        [-1, -1, -1,  0,  0,  0,  1,  1,  1],
        [-2, -2, -2, -1, -1, -1,  0,  0,  0],
        [-2, -2, -2, -1, -1, -1,  0,  0,  0],
        [-2, -2, -2, -1, -1, -1,  0,  0,  0]])
----
tensor([[ 0,  1,  2,  0,  1,  2,  0,  1,  2],
        [-1,  0,  1, -1,  0,  1, -1,  0,  1],
        [-2, -1,  0, -2, -1,  0, -2, -1,  0],
        [ 0,  1,  2,  0,  1,  2,  0,  1,  2],
        [-1,  0,  1, -1,  0,  1, -1,  0,  1],
        [-2, -1,  0, -2, -1,  0, -2, -1,  0],
        [ 0,  1,  2,  0,  1,  2,  0,  1,  2],
        [-1,  0,  1, -1,  0,  1, -1,  0,  1],
        [-2, -1,  0, -2, -1,  0, -2, -1,  0]])


In [28]:
# pos_embedding       # (13, 13)
p = torch.tensor([[1, 2],
                  [3, 4]])
print(p.size())

# relative_indices is (49, 49, 2)
r = torch.tensor([[[0,0],[0,0],[0,0]],
                  [[1,1],[1,1],[1,1]],
                  [[0,1],[0,1],[0,1]]])
print(r.size())

# self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]] #(49, 49)
print(p[r[:,:,0], r[:,:,1]])

torch.Size([2, 2])
torch.Size([3, 3, 2])
tensor([[1, 1, 1],
        [4, 4, 4],
        [2, 2, 2]])
