# Imports

In [129]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
from torchvision.ops.stochastic_depth import StochasticDepth # Add stochastic depth

# Patch Partition + Linear Embedding

In [130]:
class SwinEmbedding(nn.Module):

  """
  input shape -> (b,c,h,w)
  output shape -> (b, h/4 , w/4, C)

  Where:

  b - batch size
  h - height of the image
  w - width of the image
  C - number of channels

  """

  def __init__(self, patch_size = 4, C = 96):
      super().__init__()
      self.linear_embedding = nn.Conv2d(3,C, kernel_size=patch_size, stride=patch_size)
      self.layer_norm = nn.LayerNorm(C)
      self.relu = nn.ReLU() # activation function (not present in the torchvision model)

  
  def forward(self,x):
    x = self.linear_embedding(x) # image partitioning into patches
    x = rearrange(x, 'b c h w -> b h w c')  # change the shape of the tensor
    x = self.layer_norm(x) 
    x = self.relu(x) # activation function (not present in the torchvision model)

    return x



# Patch Merging Layer

In [131]:
class PatchMerging(nn.Module):

  """
  Reduces tokens by a factor of 4 (2x2 patches) and doubles embedding dimension.


  input shape -> (b h w c)
  output shape -> (b h/2 w/2 C*2)

  Where:

  b - batch size
  c - number of channels
  h - height of the image
  w - width of the image

  """

  def __init__(self, C) -> None:
     super().__init__()
     self.linear_layer = nn.Linear(C*4, C*2) # Doubles the embedding dimension
     self.layer_norm = nn.LayerNorm(2 * C) # Layer normalization

  def forward(self, x):
    x = rearrange(x, 'b (h ph) (w pw) c -> b h w (ph pw c)', ph=2, pw=2) # Merge patches and double the embedding dimension
    x = self.linear_layer(x) 
    x = self.layer_norm(x)
    return x

# Shifted Window Attention Mechanism

In order for the model to effectively process input data, the feature maps are padded before applying the attention mechanism. This is necessary to ensure that the feature map dimensions are divisible by the window size, as the model operates by partitioning the input into non-overlapping windows, and padding ensures that all elements are considered. After padding, the feature map is divided into windows, allowing attention to be computed within each window. This step prevents potential distortions that could occur if the window sizes were not uniform and ensures stable operation of the model.

Next, to address the issue of a few pixel pairs dominating the self-attention mechanism when the model size increases, the method of **scaled cosine attention** is proposed. Instead of computing attention using the standard dot product of the query and key vectors, attention is computed using the cosine similarity between these vectors. This similarity is then scaled by a learnable parameter $\tau$, which is different for each head and layer. This scaling helps control the magnitude of the attention values, ensuring the stability of the model, especially in larger versions. The formula for computing the similarity between pixels $i$ and $j$ is as follows:

$$

\text{Sim}(q_i, k_j) = \frac{\cos(q_i, k_j)}{\tau} + B_{ij}

$$

where $q_i$ and $k_j$ are the query and key vectors, and $\tau$ is the learnable scaling factor. An important addition is $B_{ij}$, which represents the relative position bias. This bias is added to account for the relative positions of pixels within the window, allowing the model to better handle varying window sizes.

Thus, the relative position bias $B_{ij}$ plays a key role in improving the attention computation. Unlike the fixed positional encoding used in traditional transformers, this bias is computed dynamically for each pair of pixels based on their relative positions within the window. Introducing **relative position bias** into the attention formula enables the model to be more adaptive when working with images of different resolutions and window sizes, improving its performance and stability.

![image](../images/Pasted%20image%20(2).png)

In [132]:
class ShiftedWindowMSA(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size=8, mask=False, attention_dropout=0.0, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.mask = mask # mask (True/False)
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attention_dropout = nn.Dropout(attention_dropout)
        self.proj_dropout = nn.Dropout(dropout)
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))

        self.relative_embeddings = RelativeEmbeddingsV2(window_size, num_heads)

    def forward(self, input):
        
        B, H, W, C = input.shape

        # pad feature maps to multiples of window size
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
        _, pad_H, pad_W, _ = x.shape
       
        # Cyclic shift
        if self.mask:
            x = torch.roll(x, (-self.window_size//2, -self.window_size//2), dims=(1,2))

        # Partition windows
        num_windows = (pad_H //self.window_size) * (pad_W // self.window_size)
        x = rearrange(
                    x, 
                    'b (h w_h) (w w_w) c -> (b h w) (w_h w_w) c', 
                    w_h=self.window_size, w_w=self.window_size
                )

        # QKV computation
        qkv = F.linear(x, self.qkv.weight)
        qkv = qkv.reshape(x.size(0), x.size(1), 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention with logit scaling 
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)  # cos(qi, kj), cos - normalizetion
        logit_scale = torch.clamp(self.logit_scale, max=math.log(100.0)).exp() # tau 
        attn = attn * logit_scale  # cos(qi, kj) / tau

        # Add relative position bias
        relative_position_bias = self.relative_embeddings()
        attn = attn + relative_position_bias  # cos(qi, kj) / tau + b_ij
       
        if self.mask:
            # Create attention mask
            attn_mask = torch.zeros((pad_H, pad_W), device=x.device)
            
            # Generate coordinates for the mask
            for i in range(0, pad_H, self.window_size):
                for j in range(0, pad_W, self.window_size):
                    attn_mask[i:i + self.window_size, j:j + self.window_size] += 1
            
            # Create mask for each window
            attn_mask = rearrange(
                attn_mask, 
                '(h winh) (w winw) -> (h w) (winh winw)', 
                winh=self.window_size, 
                winw=self.window_size
            )

            # Create mask for each window
            attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)  # Shape: (num_windows, window_size^2, window_size^2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float('-inf')).masked_fill(attn_mask == 0, 0.0)

            # Add a dimension for num_heads
            attn_mask = attn_mask.unsqueeze(1)  # Shape: (num_windows, 1, window_size^2, window_size^2)

            # Broadcast over batch and num_heads
            attn = attn.view(-1, num_windows, self.num_heads, x.size(1), x.size(1))
            attn = attn + attn_mask.unsqueeze(0)  # Broadcasting over batch and num_heads
            attn = attn.view(-1, self.num_heads, x.size(1), x.size(1))


        attn = attn.softmax(dim=-1)
        attn = self.attention_dropout(attn)

        # Attention output
        x = (attn @ v).transpose(1, 2).reshape(B, -1, C) # attn = softmax(cos(qi, kj) / tau + b_ij) @ v
        x = self.proj(x)
        x = self.proj_dropout(x)

        # Reverse cyclic shift
        x = rearrange(
            x, 
            'b (h ws1 w ws2) c -> b (h ws1) (w ws2) c', 
            ws1=self.window_size, 
            ws2=self.window_size,
            h = pad_H // self.window_size,
            w = pad_W // self.window_size
        )
        if self.mask: 
            x = torch.roll(x, (self.window_size//2, self.window_size//2), (1,2))

        #unpad features
        x = x[:, :H, :W, :].contiguous()
        return x

# Relative Position Embeddings 

In the paper, **relative position bias** is introduced as a mechanism to handle varying window resolutions effectively. The traditional approach to relative position bias, which directly optimizes parameterized biases, can become problematic when transferring models across different window sizes. To address this, the paper proposes a **log-spaced continuous position bias approach** that can be smoothly transferred to fine-tuning tasks with arbitrary window sizes.

The main idea is to use a **meta network**, denoted as $G(\Delta x, \Delta y)$, which generates bias values for relative coordinates. Instead of directly optimizing biases, this network computes biases based on the relative positions of elements in the window. In this way, the generated biases can be adapted to different window sizes without requiring retraining. The continuous position bias approach allows the bias values to be precomputed and stored as model parameters, making inference efficient and consistent with the original parameterized bias approach.

Moreover, the paper addresses the challenge of **extrapolating biases** when transferring models across significantly different window sizes. The original approach used linearly spaced coordinates, which led to large extrapolation ratios. To mitigate this, the paper introduces **log-spaced coordinates**, which reduce the required extrapolation when transferring biases across window resolutions. The transformation from linear coordinates $\Delta x, \Delta y$ to log-spaced coordinates $\Delta cx$, $\Delta cy$ is defined as:


$$
\Delta cx = \text{sign}(x) \cdot \log(1 + | \Delta x |),
\quad \Delta cy = \text{sign}(y) \cdot \log(1 + | \Delta y |).
$$


This log transformation helps to keep the extrapolation ratio smaller, which in turn makes the transfer of relative position biases more stable and effective across different window sizes.

![image](../images/Pasted%20image%20(3).png)


In [133]:
class RelativeEmbeddingsV2(nn.Module):
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size  # Size of the window (e.g., 8x8 or 16x16)
        self.num_heads = num_heads  # Number of attention heads

        # Define the MLP used for continuous position bias
        self.cpb_mlp = nn.Sequential(
            nn.Linear(2, 512, bias=True),  # First linear layer (maps 2D position to 512 features)
            nn.ReLU(inplace=True),         # ReLU activation
            nn.Linear(512, num_heads, bias=False),  # Second linear layer (maps 512 features to num_heads)
        )

        # Initialize relative coordinates and relative position index
        self.define_relative_coords()
        self.define_relative_position_index()

    def define_relative_coords(self):
        """
        This method defines the relative coordinates for each pixel within the window. 
        It creates a grid of relative coordinates and applies log transformation.
        """
        # Generate range of relative coordinates along height and width (e.g., -7 to 7 for 8x8 window)
        relative_coords_h = torch.arange(-(self.window_size - 1), self.window_size, dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size - 1), self.window_size, dtype=torch.float32)

        # Create a meshgrid of relative coordinates for height and width
        relative_coords_table = torch.stack(torch.meshgrid(relative_coords_h, relative_coords_w, indexing="ij"))

        # Permute to match the required shape (2D, 2D, 2 -> 1, 2, 2, 2)
        relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0)

        # Normalize the relative coordinates to the range [-1, 1] for both axes
        relative_coords_table[..., 0] /= (self.window_size - 1)
        relative_coords_table[..., 1] /= (self.window_size - 1)

        # Apply scaling to the coordinates (multiply by 8)
        relative_coords_table *= 8

        # Apply the log transformation to the relative coordinates
        relative_coords_table = (
            torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0
        )
        
        # Register the relative coordinates table as a buffer to be used during training
        self.register_buffer("relative_coords_table", relative_coords_table)

    def define_relative_position_index(self):
        """
        This method defines the relative position index for each pixel pair in the window.
        It calculates the differences in positions and generates a unique index for each relative position.
        """
        # Generate coordinates for the height and width of the window
        coords_h = torch.arange(self.window_size)
        coords_w = torch.arange(self.window_size)

        # Create a meshgrid for all the coordinates
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))

        # Flatten the coordinates into a 2D array
        coords_flatten = torch.flatten(coords, 1)

        # Calculate the relative position by subtracting each pair of coordinates
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()

        # Shift the coordinates to ensure positive indices
        relative_coords[:, :, 0] += self.window_size - 1
        relative_coords[:, :, 1] += self.window_size - 1

        # Scale the coordinates to a larger range (for uniqueness)
        relative_coords[:, :, 0] *= 2 * self.window_size - 1

        # Sum the two coordinate differences to get a unique index
        relative_position_index = relative_coords.sum(-1).flatten()

        # Register the relative position index as a buffer to be used during training
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        """
        This method computes the relative position bias using the pre-defined meta network (MLP) and relative position table.
        """
        # Use the relative position index and the relative coordinates table to compute the bias
        relative_position_bias = F.embedding(
            self.relative_position_index,  # Look up bias values from the relative position index
            self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads),  # Apply MLP to relative coords table
        )

        # Reshape the bias values to match the shape of the attention logits (window_size * window_size, window_size * window_size, num_heads)
        relative_position_bias = relative_position_bias.view(
            self.window_size * self.window_size, self.window_size * self.window_size, self.num_heads
        )

        # Permute the bias to match the attention mechanism (num_heads, window_size * window_size, window_size * window_size)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)

        # Apply a sigmoid activation to the bias values to smooth them and scale by a factor of 16
        return 16 * torch.sigmoid(relative_position_bias)


# Swin Transformer Block v2 

The main difference of the Swin Transformer block of the second version is the change of the normalization order. The normalization layer was moved before the skip connection adder, which reduced the amplitude of activations and provided more stable and efficient learning.

In addition, a stochastic drop path operation was added to the block to improve regularization. This is especially important for deep models and transformers, where this approach has been shown to perform better according to research.

Initialization of weights and biases was also introduced, which promotes stable learning and accelerates model convergence due to correct distribution of initial parameters

![image.png](/home/wladyka/Swin-Transformer/images/swin_transformer_block_v2.png)

In [134]:
class SwinEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size, mask, sd_prob=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.stochastic_depth = StochasticDepth(sd_prob, "row") # Stochastic Depth with 0.1 probability of dropping out a row for tiny version of Swin Transformer

        self.WMSA = ShiftedWindowMSA(embed_dim=embed_dim, num_heads=num_heads, window_size=window_size, mask=mask)
        self.MLP = nn.Sequential(
            nn.Linear(embed_dim, embed_dim*4),
            nn.GELU(),
            nn.Dropout(p=0.1), # Default dropout probability is 0.0 in the torchvision implementation
            nn.Linear(embed_dim*4, embed_dim)
        )

        # Initialization of weights and biases (bias) in linear layers 
        for m in self.MLP:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight) # Xavier initialization for weights, which prevents the disappearance or explosion of gradients during training.
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6) # Set a small offset, to have a small impact in the initial stages of training.

    def forward(self, x):
        
        # Attention path with pre-normalization 
        res1 = x # Save input for the skip connection
        x = self.stochastic_depth(self.layer_norm(self.WMSA(x))) # Attention block with LayerNorm and Stochastic Depth(more efficient than Dropout for training transformers)
        x = res1 + x # Residual connection

        # MLP path with pre-normalization
        res2 = x  # Save intermediate result for skip connection
        x = self.stochastic_depth(self.layer_norm(self.MLP(x))) # MLP block with LayerNorm and Dropout
        x = res2 + x  # Residual connection

        return x
    
class AlternatingEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, sd_prob, window_size=8):
        super().__init__()
        self.WSA = SwinEncoderBlock(embed_dim=embed_dim, num_heads=num_heads, window_size=window_size, mask=False, sd_prob=sd_prob[0])
        self.SWSA = SwinEncoderBlock(embed_dim=embed_dim, num_heads=num_heads, window_size=window_size, mask=True, sd_prob=sd_prob[1])
    
    def forward(self, x):
        return self.SWSA(self.WSA(x))

# Final Swin-Transformer Class v2 

In [135]:
class SwinTransformer(nn.Module):
    def __init__(self, depth=[2, 2, 6, 2], embed_dim=96, stochastic_depth_prob=0.2, window_size= 8):
        super().__init__()
        self.Embedding = SwinEmbedding()  # Embedding layer

        # Calculate total number of blocks
        total_stage_blocks = sum(depth)
        stage_block_id = 0

        self.stages = nn.ModuleList()

        in_channels = embed_dim
        for i_stage, num_blocks in enumerate(depth):
            temp_sd_prob = []
            for _ in range(num_blocks):
                # Calculate probability for the current layer
                sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
                temp_sd_prob.append(sd_prob)
                stage_block_id += 1

            #Add alternating encoder blocks recording to the depth list divided by 2, because each block has 2 sub-blocks
            sd_prob = [temp_sd_prob[i:i+2] for i in range(0, len(temp_sd_prob), 2)]
            for _ in range(int(num_blocks / 2)):
                num_heads = in_channels // 32
                #print(f"AlternatingEncoderBlock({in_channels}, {num_heads}, {sd_prob[0]})") # Debug
                self.stages.append(
                    AlternatingEncoderBlock(in_channels, num_heads, sd_prob[0], window_size=window_size)
                )
                sd_prob.pop(0)
                    
            # Add patch merging layer if this is not the last stage
            if i_stage < len(depth) - 1:
                self.stages.append(PatchMerging(in_channels))
                #print(f"PatchMerging({in_channels})") # Debug
                in_channels *= 2

    def forward(self, x):
        x = self.Embedding(x) 
        for stage in self.stages: 
            x = stage(x)
            print(x.shape, stage._get_name()) # Debug

        return x


In [136]:
def main():
    x = torch.randn((1,3,256,256)).cuda()
    model = SwinTransformer(depth=[2, 2, 6, 2], embed_dim=96, window_size=8).cuda()
    print(f"Output shape: {model(x).shape}")

if __name__ == '__main__':
    main()

torch.Size([1, 64, 64, 96]) AlternatingEncoderBlock
torch.Size([1, 32, 32, 192]) PatchMerging
torch.Size([1, 32, 32, 192]) AlternatingEncoderBlock
torch.Size([1, 16, 16, 384]) PatchMerging
torch.Size([1, 16, 16, 384]) AlternatingEncoderBlock
torch.Size([1, 16, 16, 384]) AlternatingEncoderBlock
torch.Size([1, 16, 16, 384]) AlternatingEncoderBlock
torch.Size([1, 8, 8, 768]) PatchMerging
torch.Size([1, 8, 8, 768]) AlternatingEncoderBlock
Output shape: torch.Size([1, 8, 8, 768])
