# Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
from torchvision.ops.stochastic_depth import StochasticDepth # Add stochastic depth

# Patch Partition + Linear Embedding

In [3]:
class SwinEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed every patch.

    input shape -> (b, c, h, w)
    output shape -> (b, (h/patch_size * w/patch_size), C)

    Where:

    b - batch size
    c - number of channels of the input image (``in_channels``)
    h - height of the image
    w - width of the image
    C - number of channels in the output (embedding dimension)

    Args:
        patch_size: side of the square patches; used as both kernel size and
            stride of the convolution, so patches do not overlap.
        C: embedding dimension of every output token.
        in_channels: number of channels of the input image (3 for RGB).
    """

    def __init__(self, patch_size=4, C=96, in_channels=3):
        super().__init__()
        # A convolution whose stride equals its kernel size applies one shared
        # linear projection to every patch.
        self.linear_embedding = nn.Conv2d(in_channels, C, kernel_size=patch_size, stride=patch_size)
        self.layer_norm = nn.LayerNorm(C)
        # This activation is not part of the original model, but it is needed
        # for this model to work correctly.
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear_embedding(x)
        # Flatten the spatial dimensions row by row and move channels last:
        # (b, C, h', w') -> (b, h' * w', C)
        x = x.flatten(2).transpose(1, 2)
        x = self.layer_norm(x)  # normalization over the channel dimension
        x = self.relu(x)  # extra non-linearity (not in the original model)

        return x



# Patch Merging Layer

In [4]:
class PatchMerging(nn.Module):
    """Merge every 2x2 group of neighbouring tokens into one wider token.

    The number of tokens shrinks by a factor of 4 while the embedding
    dimension doubles.

    input shape -> (b, (h * w), c)
    output shape -> (b, (h/2 * w/2), c * 2)

    Where:

    b - batch size
    c - number of channels (the ``C`` given to the constructor)
    h - height of the token grid
    w - width of the token grid

    The grid side is derived from the square root of the token count, so the
    tokens are treated as a square grid (h == w).
    """

    def __init__(self, C) -> None:
        super().__init__()
        # Four concatenated C-dimensional tokens are projected down to 2*C.
        self.linear_layer = nn.Linear(4 * C, 2 * C)
        self.layer_norm = nn.LayerNorm(2 * C)

    def forward(self, x):
        num_tokens = x.shape[1]
        # Side length of the token grid after merging (half the current side).
        side = int(math.sqrt(num_tokens) / 2)
        # Gather each 2x2 neighbourhood (s1 x s2) into the channel dimension.
        merged = rearrange(x, 'b (h s1 w s2) c -> b (h w) (s2 s1 c)', s1=2, s2=2, h=side, w=side)
        merged = self.linear_layer(merged)
        return self.layer_norm(merged)

# Shifted Window Attention Mechanism (TODO)

In [5]:
class ShiftedWindowMSA(nn.Module):
    """Multi-head self-attention computed inside (optionally shifted) windows.

    input shape -> (b, (h*w), C)
    output shape -> (b, (h*w), C)

    Where:

    b - batch size
    h - height of the token grid
    w - width of the token grid
    C - number of channels (``embed_dim``)

    The grid side is taken as the square root of the token count, so the
    tokens are treated as a square grid; the side has to be a multiple of
    ``window_size`` and ``embed_dim`` a multiple of ``num_heads`` for the
    rearrange patterns below to apply.

    Args:
        embed_dim: embedding dimension of the tokens.
        num_heads: number of attention heads.
        window_size: side of the square attention windows.
        mask: when True the grid is cyclically shifted by ``window_size // 2``
            before windowing and attention between tokens that were not
            neighbours before the shift is masked out.
    """

    def __init__(self, embed_dim, num_heads, window_size=7, mask=False):
        super().__init__()
        self.embed_dim = embed_dim  # embedding dimension
        self.num_heads = num_heads  # number of heads
        self.window_size = window_size  # window side
        self.mask = mask  # shifted-window mode on/off
        self.proj1 = nn.Linear(embed_dim, 3*embed_dim)  # input projection producing Q, K and V
        self.proj2 = nn.Linear(embed_dim, embed_dim)  # output projection
        # The bias has shape (window_size**2, window_size**2), so it has to be
        # built for the same window size as the attention itself.
        self.embeddings = RelativeEmbeddings(window_size=window_size)

    def forward(self, x):
        h_dim = self.embed_dim / self.num_heads  # dimension of a single head
        height = width = int(math.sqrt(x.shape[1]))
        # Written as -(w // 2) below, not -w // 2: the latter is (-w) // 2, which
        # for odd window sizes is one step larger than the reverse roll.
        shift = self.window_size // 2

        x = self.proj1(x)
        # K = 3 separates the Q, K and V parts of the projection
        x = rearrange(x, 'b (h w) (c K) -> b h w c K', K=3, h=height, w=width)

        if self.mask:  # cyclic shift of the grid by half a window
            x = torch.roll(x, (-shift, -shift), dims=(1, 2))

        # Split the grid into windows and the channels into heads:
        #   H       - attention heads
        #   h, w    - number of windows along the vertical / horizontal axis
        #   (m1 m2) - tokens inside one window
        #   E       - head dimension
        #   K = 3   - Q, K, V
        x = rearrange(x, 'b (h m1) (w m2) (H E) K -> b H h w (m1 m2) E K', H=self.num_heads, m1=self.window_size, m2=self.window_size)

        # split into the Q, K, V matrices
        Q, K, V = x.chunk(3, dim=6)
        Q, K, V = Q.squeeze(-1), K.squeeze(-1), V.squeeze(-1)
        attention_scores = (Q @ K.transpose(4, 5)) / math.sqrt(h_dim)  # scaled dot-product scores
        attention_scores = self.embeddings(attention_scores)  # add the relative position bias

        # shape of attention_scores = (b, H, h, w, (m1*m2), (m1*m2))
        if self.mask:
            # After the shift only the last row and the last column of windows
            # contain tokens wrapped around from the opposite border. Inside
            # those windows attention between the wrapped part and the rest is
            # disabled by adding -inf before the softmax.
            tokens_per_window = self.window_size ** 2
            boundary = self.window_size * shift  # number of wrapped tokens in a window
            # Created on the device/dtype of the scores so the module is not
            # tied to CUDA.
            row_mask = torch.zeros(
                (tokens_per_window, tokens_per_window),
                device=attention_scores.device,
                dtype=attention_scores.dtype,
            )
            row_mask[-boundary:, 0:-boundary] = float('-inf')
            row_mask[0:-boundary, -boundary:] = float('-inf')
            # Same pattern with the roles of rows and columns inside the window swapped.
            column_mask = rearrange(row_mask, '(r w1) (c w2) -> (w1 r) (w2 c)', w1=self.window_size, w2=self.window_size)
            attention_scores[:, :, -1, :] += row_mask
            attention_scores[:, :, :, -1] += column_mask

        attention = F.softmax(attention_scores, dim=-1) @ V  # softmax and weighting of V
        x = rearrange(attention, 'b H h w (m1 m2) E -> b (h m1) (w m2) (H E)', m1=self.window_size, m2=self.window_size)

        if self.mask:  # undo the cyclic shift
            x = torch.roll(x, (shift, shift), (1, 2))

        x = rearrange(x, 'b h w c -> b (h w) c')
        return self.proj2(x)  # output projection


# Relative Position Embeddings (TODO)

In [6]:
class RelativeEmbeddings(nn.Module):
    """Relative position bias that is added to window attention scores.

    A random table of shape (2*window_size-1, 2*window_size-1) holds one value
    per possible (row offset, column offset) between two tokens of a window.
    ``self.embeddings[i, j]`` is the table entry for the offset from token i
    to token j, with the tokens of a window numbered row by row. The resulting
    (window_size**2, window_size**2) matrix is stored as a parameter with
    ``requires_grad=False``, i.e. it is saved with the model but not trained.

    Args:
        window_size: side of the square attention window.
    """

    def __init__(self, window_size=7):
        super().__init__()
        bias_table = torch.randn(2*window_size-1, 2*window_size-1)
        # Integer row/column coordinates of every token in the window. Integer
        # arithmetic is used on purpose: truncating a float difference of
        # flattened positions gives the wrong row offset for token pairs that
        # straddle a row boundary.
        token = torch.arange(window_size ** 2)
        rows = token // window_size
        cols = token % window_size
        rel_rows = rows[None, :] - rows[:, None]
        rel_cols = cols[None, :] - cols[:, None]
        # Offsets lie in [-(window_size-1), window_size-1]; negative values
        # index the table from its end, so every offset gets its own entry.
        self.embeddings = nn.Parameter(bias_table[rel_rows, rel_cols], requires_grad=False)

    def forward(self, x):
        # x: attention scores whose last two dimensions are
        # (window_size**2, window_size**2); the bias broadcasts over the rest.
        return x + self.embeddings

# Swin Transformer Block v2 

The main difference in the second version of the Swin Transformer block is the order of normalization: the normalization layer is applied to the output of each sub-layer (attention or MLP) right before that output is added to the skip connection. This reduces the amplitude of the activations and makes training more stable and efficient.

In addition, a stochastic drop path operation was added to the block to improve regularization. This is especially important for deep models and transformers, where this approach has been shown to perform better according to research.

Explicit initialization of the weights and biases was also introduced; it promotes stable training and speeds up convergence because the initial parameters are properly distributed.

![image.png](/home/wladyka/Swin-Transformer/images/swin_transformer_block_v2.png)

In [7]:
class SwinEncoderBlock(nn.Module):
    """One encoder block: window attention and an MLP, each on a residual branch.

    Both branches have the form ``x + StochasticDepth(LayerNorm(sublayer(x)))``,
    i.e. the normalization is applied to the sub-layer output before it is
    added to the skip connection.

    NOTE(review): the same LayerNorm and StochasticDepth instances are used on
    both branches, so the two branches share the LayerNorm affine parameters.

    Args:
        embed_dim: embedding dimension of the tokens.
        num_heads: number of attention heads.
        window_size: side of the attention windows.
        mask: passed on to ShiftedWindowMSA to select shifted-window attention.
    """

    def __init__(self, embed_dim, num_heads, window_size, mask):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embed_dim)
        # Stochastic depth with drop probability 0.2 in "row" mode
        # (value chosen for the tiny version of the Swin Transformer).
        self.stochastic_depth = StochasticDepth(0.2, "row")

        self.WMSA = ShiftedWindowMSA(embed_dim=embed_dim, num_heads=num_heads, window_size=window_size, mask=mask)
        hidden_dim = embed_dim*4
        self.MLP = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=0.1),  # the torchvision implementation defaults to 0.0
            nn.Linear(hidden_dim, embed_dim)
        )

        # Explicit initialization of the MLP's linear layers: Xavier-uniform
        # weights and biases drawn very close to zero so they have little
        # influence early in training.
        linear_layers = [layer for layer in self.MLP if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        # Attention branch: attention -> LayerNorm -> stochastic depth -> skip connection
        attn_branch = self.WMSA(x)
        attn_branch = self.layer_norm(attn_branch)
        x = x + self.stochastic_depth(attn_branch)

        # MLP branch: MLP -> LayerNorm -> stochastic depth -> skip connection
        mlp_branch = self.MLP(x)
        mlp_branch = self.layer_norm(mlp_branch)
        return x + self.stochastic_depth(mlp_branch)
    
class AlternatingEncoderBlock(nn.Module):
    """Pair of encoder blocks: regular windows first, shifted windows second.

    Args:
        embed_dim: embedding dimension of the tokens.
        num_heads: number of attention heads in both blocks.
        window_size: side of the attention windows in both blocks.
    """

    def __init__(self, embed_dim, num_heads, window_size=7):
        super().__init__()
        shared = dict(embed_dim=embed_dim, num_heads=num_heads, window_size=window_size)
        self.WSA = SwinEncoderBlock(mask=False, **shared)  # plain window attention
        self.SWSA = SwinEncoderBlock(mask=True, **shared)  # shifted window attention

    def forward(self, x):
        windowed = self.WSA(x)
        return self.SWSA(windowed)

# Final Swin-Transformer Class (Review) 

In [8]:
class SwinTransformerTiny(nn.Module):
    """Backbone assembled from the embedding, encoder stages and patch merging.

    Layout: embedding -> stage 1 (96 channels, 3 heads) -> merge ->
    stage 2 (192, 6) -> merge -> three stage-3 block pairs (384, 12) ->
    merge -> stage 4 (768, 24). Every stage entry is an
    AlternatingEncoderBlock, i.e. a pair of encoder blocks.

    For a (1, 3, 224, 224) input the output has shape (1, 49, 768).
    """

    def __init__(self):
        super().__init__()
        self.Embedding = SwinEmbedding()  # patch partition + linear embedding
        self.PatchMerge1 = PatchMerging(96)
        self.PatchMerge2 = PatchMerging(192)
        self.PatchMerge3 = PatchMerging(384)
        self.Stage1 = AlternatingEncoderBlock(96, 3)
        self.Stage2 = AlternatingEncoderBlock(192, 6)
        self.Stage3_1 = AlternatingEncoderBlock(384, 12)
        self.Stage3_2 = AlternatingEncoderBlock(384, 12)
        self.Stage3_3 = AlternatingEncoderBlock(384, 12)
        self.Stage4 = AlternatingEncoderBlock(768, 24)

    def forward(self, x):
        tokens = self.Embedding(x)

        tokens = self.Stage1(tokens)
        tokens = self.PatchMerge1(tokens)

        tokens = self.Stage2(tokens)
        tokens = self.PatchMerge2(tokens)

        for stage in (self.Stage3_1, self.Stage3_2, self.Stage3_3):
            tokens = stage(tokens)
        tokens = self.PatchMerge3(tokens)

        return self.Stage4(tokens)

In [9]:
def main():
    """Build the model, run one random 224x224 RGB image through it and print the output shape."""
    # Use CUDA when it is available instead of requiring it unconditionally.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn((1, 3, 224, 224)).to(device)
    model = SwinTransformerTiny().to(device)
    print(model(x).shape)


if __name__ == '__main__':
    main()

torch.Size([1, 49, 768])
