# Imports

In [91]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
from torchvision.ops.stochastic_depth import StochasticDepth # Add stochastic depth
from typing import List 

# Patch Partition + Linear Embedding

In [92]:
class SwinEmbedding(nn.Module):

  """
  input shape -> (b,c,h,w)
  output shape -> (b, (h/4 * w/4), C)

  Where:

  b - batch size
  c - number of channels
  h - height of the image
  w - width of the image
  C - number of channels in the output

  """

  def __init__(self, patch_size = 4, C = 96):
      super().__init__()
      self.linear_embedding = nn.Conv2d(3,C, kernel_size=patch_size, stride=patch_size)
      self.layer_norm = nn.LayerNorm(C)
      self.relu = nn.ReLU() # Tej funkcji aktywacji nie ma w oryginalnym modelu, ale jest ona potrzebna do poprawnego działania modelu

  
  def forward(self,x):
    x = self.linear_embedding(x)
    x = rearrange(x, 'b c h w -> b (h w) c')  # spłaszczenie wymiarów przestrzennych obrazu przy pomocy mnożenia h i w
    x = self.layer_norm(x) # normalizacja
    x = self.relu(x) # funkcja aktywacji (dodanie nieliniowości) (nie ma jej w oryginalnym modelu)

    return x



# Patch Merging Layer

In [93]:
class PatchMerging(nn.Module):

  """
  Reduces tokens by a factor of 4 (2x2 patches) and doubles embedding dimension.


  input shape -> (b (h w) c)
  output shape -> (b (h/2 * w/2) C*2)

  Where:

  b - batch size
  c - number of channels
  h - height of the image
  w - width of the image

  """

  def __init__(self, C) -> None:
     super().__init__()
     self.linear_layer = nn.Linear(C*4, C*2) # podwajamy wymiar embeddingów
     self.layer_norm = nn.LayerNorm(2 * C) # normalizacja

  def forward(self, x):
    height = width = int(math.sqrt(x.shape[1])/ 2) # obliczamy nową wysokość i szerokość obrazu
    x = rearrange(x, 'b (h s1 w s2) c -> b (h w) (s2 s1 c)', s1=2, s2=2, h=height, w=width)
    x = self.linear_layer(x)
    x = self.layer_norm(x)
    return x

# Shifted Window Attention Mechanism (TODO)

In [101]:
class ShiftedWindowMSA(nn.Module):

    """
    input shape -> (b , (h*w), C)
    output shape -> (b , (h*w), C)

    Where:

    b - batch size
    h - height of the image
    w - width of the image
    C - number of channels in the output
    """
      
    def __init__(self, embed_dim, num_heads, window_size=7, mask=False):
        super().__init__()
        self.embed_dim = embed_dim # wymiar embeddingów
        self.num_heads = num_heads # liczba głów
        self.window_size = window_size # rozmiar okna
        self.mask = mask # maska (True/False)
        self.qkv = nn.Linear(embed_dim, 3*embed_dim) # projekcja wejścia
        self.proj = nn.Linear(embed_dim, embed_dim) # projekcja wyjścia
        self.embeddings = RelativeEmbeddings()
        self.embeddings_v2 = RelativeEmbeddingsv2()

    def forward(self, x):
        h_dim = self.embed_dim / self.num_heads # obliczamy wymiar pojedynczej głowy
        height = width = int(math.sqrt(x.shape[1])) 
        x = self.qkv(x) 
        x = rearrange(x, 'b (h w) (c K) -> b h w c K', K=3, h=height, w=width) # zmiana wymiarów, gdzie K to liczba macierzy Q,K,V
 
        if self.mask: # jeśli maska jest True, to wykonujemy przesunięcie okna o połowę
            x = torch.roll(x, (-self.window_size//2, -self.window_size//2), dims=(1,2))

        # zmiana wymiarów
        x = rearrange(x, 'b (h m1) (w m2) (H E) K -> b H h w (m1 m2) E K', H=self.num_heads, m1=self.window_size, m2=self.window_size)
       
        # podział na macierze Q,K,V
        Q, K, V = x.chunk(3, dim=6)
        Q, K, V = Q.squeeze(-1), K.squeeze(-1), V.squeeze(-1)
        attention_scores = (Q @ K.transpose(4,5)) / math.sqrt(h_dim) # obliczamy self-attention score
        # print(attention_scores.shape)
        # print("add embeddings")
        # print( self.embeddings(attention_scores).shape)
        attention_scores = self.embeddings(attention_scores) # dodajemy embeddingsy

        #print(self.embeddings_v2(attention_scores).shape)


        '''
        H - attention heads 
        h,w - vertical and horizontal dimensions of the image
        (m1 m2) - total size of the window
        E - head dimension
        K = 3 - constant to break our matrix into 3 Q,K,V matricies
      
        shape of attention_scores = (b, H, h, w, (m1*m2), (m1*m2))
        we simply have to generate our row/column masks and apply them
        to the last row and columns of windows which are [:,:,-1,:] and [:,:,:,-1]
        
        '''

        if self.mask: # jeśli maska jest True, to wykonujemy maskowanie ostatnich wierszy i kolumn w oknie 
            row_mask = torch.zeros((self.window_size**2, self.window_size**2)).cuda()
            row_mask[-self.window_size * (self.window_size//2):, 0:-self.window_size * (self.window_size//2)] = float('-inf')
            row_mask[0:-self.window_size * (self.window_size//2), -self.window_size * (self.window_size//2):] = float('-inf')
            column_mask = rearrange(row_mask, '(r w1) (c w2) -> (w1 r) (w2 c)', w1=self.window_size, w2=self.window_size).cuda()
            attention_scores[:, :, -1, :] += row_mask
            attention_scores[:, :, :, -1] += column_mask

        attention = F.softmax(attention_scores, dim=-1) @ V # Softmax i mnożenie przez V 
        x = rearrange(attention, 'b H h w (m1 m2) E -> b (h m1) (w m2) (H E)', m1=self.window_size, m2=self.window_size)

        if self.mask: # Z powrotem przesuwamy okno o połowę
            x = torch.roll(x, (self.window_size//2, self.window_size//2), (1,2))

        x = rearrange(x, 'b h w c -> b (h w) c')
        return self.proj(x) # projekcja wyjścia


# Relative Position Embeddings (TODO)

In [95]:
class RelativeEmbeddings(nn.Module):
    def __init__(self, window_size=7):
        super().__init__()
        B = nn.Parameter(torch.randn(2*window_size-1, 2*window_size-1))
        x = torch.arange(1,window_size+1,1/window_size)
        x = (x[None, :]-x[:, None]).int()
        y = torch.concat([torch.arange(1,window_size+1)] * window_size)
        y = (y[None, :]-y[:, None])
        self.embeddings = nn.Parameter((B[x[:,:], y[:,:]]), requires_grad=False)

    def forward(self, x):
        print(f"RelativeEmbeddings shape: {self.embeddings.shape} + x shape: {x.shape} = {(self.embeddings + x).shape}")
        return  self.embeddings  + x 

In [96]:
class RelativeEmbeddingsv2(nn.Module):
    """
    See :func:`shifted_window_attention`.
    """

    def __init__(
        self,
        dim = 96,
        window_size = [7,7],
        qkv_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
        num_heads: int = 8,
    ):
        super().__init__()

        self.window_size = window_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.define_relative_position_bias_table()
        self.define_relative_position_index()


        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        
        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(
                nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
            )
        if qkv_bias:
            length = self.qkv.bias.numel() // 3
            self.qkv.bias[length : 2 * length].data.zero_()

    def _get_relative_position_bias(
        relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
    ) -> torch.Tensor:
        N = window_size[0] * window_size[1]
        relative_position_bias = relative_position_bias_table[relative_position_index]  # type: ignore[index]
        relative_position_bias = relative_position_bias.view(N, N, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
        return relative_position_bias

    def define_relative_position_bias_table(self):
        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
        relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2

        relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
        relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1

        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = (
            torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0
        )
        self.register_buffer("relative_coords_table", relative_coords_table)

    def define_relative_position_index(self):
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1).flatten()  # Wh*Ww*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def get_relative_position_bias(self) -> torch.Tensor:
        relative_position_bias = _get_relative_position_bias(
            self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads),
            self.relative_position_index,  # type: ignore[arg-type]
            self.window_size,
        )
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        return relative_position_bias


    def forward(self, x):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
        """
        relative_position_bias = self.get_relative_position_bias()
        print(f"RelativeEmbeddingsV2 shape: {relative_position_bias.shape} + x shape: {x.shape} = {(relative_position_bias + x).shape}")
        return relative_position_bias + x

# Swin Transformer Block v2 

The main difference of the Swin Transformer block of the second version is the change of the normalization order. The normalization layer was moved before the skip connection adder, which reduced the amplitude of activations and provided more stable and efficient learning.

In addition, a stochastic drop path operation was added to the block to improve regularization. This is especially important for deep models and transformers, where this approach has been shown to perform better according to research.

Initialization of weights and biases was also introduced, which promotes stable learning and accelerates model convergence due to correct distribution of initial parameters

![image.png](/home/wladyka/Swin-Transformer/images/swin_transformer_block_v2.png)

In [97]:
class SwinEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size, mask, sd_prob=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.stochastic_depth = StochasticDepth(sd_prob, "row") # Stochastic Depth with 0.1 probability of dropping out a row for tiny version of Swin Transformer

        self.WMSA = ShiftedWindowMSA(embed_dim=embed_dim, num_heads=num_heads, window_size=window_size, mask=mask)
        self.MLP = nn.Sequential(
            nn.Linear(embed_dim, embed_dim*4),
            nn.GELU(),
            nn.Dropout(p=0.1), # Default dropout probability is 0.0 in the torchvision implementation
            nn.Linear(embed_dim*4, embed_dim)
        )

        # Initialization of weights and biases (bias) in linear layers 
        for m in self.MLP:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight) # Xavier initialization for weights, which prevents the disappearance or explosion of gradients during training.
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6) # Set a small offset, to have a small impact in the initial stages of training.

    def forward(self, x):
        
        # Attention path with pre-normalization 
        res1 = x # Save input for the skip connection
        x = self.stochastic_depth(self.layer_norm(self.WMSA(x))) # Attention block with LayerNorm and Stochastic Depth(more efficient than Dropout for training transformers)
        x = res1 + x # Residual connection

        # MLP path with pre-normalization
        res2 = x  # Save intermediate result for skip connection
        x = self.stochastic_depth(self.layer_norm(self.MLP(x))) # MLP block with LayerNorm and Dropout
        x = res2 + x  # Residual connection

        return x
    
class AlternatingEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, sd_prob, window_size=7):
        super().__init__()
        self.WSA = SwinEncoderBlock(embed_dim=embed_dim, num_heads=num_heads, window_size=window_size, mask=False, sd_prob=sd_prob[0])
        self.SWSA = SwinEncoderBlock(embed_dim=embed_dim, num_heads=num_heads, window_size=window_size, mask=True, sd_prob=sd_prob[1])
    
    def forward(self, x):
        return self.SWSA(self.WSA(x))

# Final Swin-Transformer Class v2 

In [98]:
class SwinTransformerTiny(nn.Module):
    def __init__(self, depth=[2, 2, 6, 2], embed_dim=96, stochastic_depth_prob=0.2):
        super().__init__()
        self.Embedding = SwinEmbedding()  # Embedding layer

        # Calculate total number of blocks
        total_stage_blocks = sum(depth)
        stage_block_id = 0

        self.stages = nn.ModuleList()

        in_channels = embed_dim
        for i_stage, num_blocks in enumerate(depth):
            temp_sd_prob = []
            for _ in range(num_blocks):
                # Calculate probability for the current layer
                sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
                temp_sd_prob.append(sd_prob)
                stage_block_id += 1

            #Add alternating encoder blocks recording to the depth list divided by 2, because each block has 2 sub-blocks
            sd_prob = [temp_sd_prob[i:i+2] for i in range(0, len(temp_sd_prob), 2)]
            for _ in range(int(num_blocks / 2)):
                num_heads = in_channels // 32
                #print(f"AlternatingEncoderBlock({in_channels}, {num_heads}, {sd_prob[0]})") # Debug
                self.stages.append(
                    AlternatingEncoderBlock(in_channels, num_heads, sd_prob[0])
                )
                sd_prob.pop(0)
                    
            # Add patch merging layer if this is not the last stage
            if i_stage < len(depth) - 1:
                self.stages.append(PatchMerging(in_channels))
                #print(f"PatchMerging({in_channels})") # Debug
                in_channels *= 2

    def forward(self, x):
        x = self.Embedding(x) 
        for stage in self.stages: 
            x = stage(x)
        
        return x


In [102]:
def main():
    x = torch.randn((1,3,224,224)).cuda()
    model = SwinTransformerTiny().cuda()
    print(model(x).shape)
    # model = RelativeEmbeddingsv2().cuda()
    # print(model(x).shape)
    # # model = RelativeEmbeddings().cuda()
    # # print(model(x).shape)

if __name__ == '__main__':
    main()

RelativeEmbeddings shape: torch.Size([49, 49]) + x shape: torch.Size([1, 3, 8, 8, 49, 49]) = torch.Size([1, 3, 8, 8, 49, 49])
RelativeEmbeddings shape: torch.Size([49, 49]) + x shape: torch.Size([1, 3, 8, 8, 49, 49]) = torch.Size([1, 3, 8, 8, 49, 49])
RelativeEmbeddings shape: torch.Size([49, 49]) + x shape: torch.Size([1, 6, 4, 4, 49, 49]) = torch.Size([1, 6, 4, 4, 49, 49])
RelativeEmbeddings shape: torch.Size([49, 49]) + x shape: torch.Size([1, 6, 4, 4, 49, 49]) = torch.Size([1, 6, 4, 4, 49, 49])
RelativeEmbeddings shape: torch.Size([49, 49]) + x shape: torch.Size([1, 12, 2, 2, 49, 49]) = torch.Size([1, 12, 2, 2, 49, 49])
RelativeEmbeddings shape: torch.Size([49, 49]) + x shape: torch.Size([1, 12, 2, 2, 49, 49]) = torch.Size([1, 12, 2, 2, 49, 49])
RelativeEmbeddings shape: torch.Size([49, 49]) + x shape: torch.Size([1, 12, 2, 2, 49, 49]) = torch.Size([1, 12, 2, 2, 49, 49])
RelativeEmbeddings shape: torch.Size([49, 49]) + x shape: torch.Size([1, 12, 2, 2, 49, 49]) = torch.Size([1, 12,