In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class SwinTransformerBlock(nn.Module):
    """Pre-norm residual block: window attention followed by an MLP.

    The forward pass computes ``h = x + attn(norm1(x))`` and returns
    ``h + ffn(norm2(h))``.

    Args:
        dim: Feature size of the last input dimension (used for both
            LayerNorms and passed to the attention and MLP sub-modules).
        num_heads: Forwarded to ``WindowAttention``.
        window_size: Forwarded to ``WindowAttention``.
        shift_size: Forwarded to ``WindowAttention``.
        mlp_ratio: The MLP is built with ``int(dim * mlp_ratio)`` as its
            second argument (presumably the hidden width -- ``MLP`` is
            defined elsewhere; confirm against its signature).
        dropout: Forwarded to both ``WindowAttention`` and ``MLP``.

    NOTE(review): this block hands ``x`` to the attention module unchanged;
    it performs no window partitioning or cyclic shifting itself.
    """

    def __init__(self, dim, num_heads, window_size, shift_size=0, mlp_ratio=4.0, dropout=0.1):
        super().__init__()

        # Keep the constructor arguments around for introspection.
        self.dim, self.num_heads = dim, num_heads
        self.window_size, self.shift_size = window_size, shift_size
        self.mlp_ratio, self.dropout = mlp_ratio, dropout

        # Sub-modules are created in a fixed order (attention, MLP, norms) so
        # that parameter ordering and random initialisation stay reproducible.
        self.attn = WindowAttention(dim, num_heads, window_size, shift_size, dropout)
        mlp_width = int(dim * mlp_ratio)
        self.ffn = MLP(dim, mlp_width, dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a skip connection."""
        # Attention sub-layer: normalise first, then add the result back.
        attended = x + self.attn(self.norm1(x))
        # MLP sub-layer, same normalise-then-residual pattern.
        return attended + self.ffn(self.norm2(attended))

In [None]:
class WindowAttention(nn.Module):
    """Multi-head self-attention over the tokens of a single square window.

    A learned relative position bias (one scalar per head for every possible
    2-D offset between two tokens of the window) is added to the scaled
    dot-product logits before the softmax.

    Args:
        dim: Channel size of the input tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Side length of the square window. ``forward`` expects
            exactly ``window_size ** 2`` tokens per sequence, in row-major
            order.
        shift_size: Stored on the module but not used by ``forward``; any
            window partitioning, cyclic shifting or masking has to be done
            by the caller.
        dropout: Dropout probability for the attention weights and for the
            projected output.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, window_size, shift_size=0, dropout=0.1):
        super(WindowAttention, self).__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        # Standard 1/sqrt(head_dim) scaling of the dot-product logits.
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

        # One learnable bias per head for each of the (2w-1) x (2w-1) possible
        # relative offsets between two positions inside a w x w window.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # Lookup from (query token, key token) to a row of the bias table.
        # Non-persistent: it is derived from window_size, so it stays out of
        # state_dict and existing checkpoints keep loading unchanged.
        self.register_buffer(
            "relative_position_index",
            self._build_relative_position_index(window_size),
            persistent=False,
        )

    @staticmethod
    def _build_relative_position_index(window_size):
        """Return a (w*w, w*w) long tensor of bias-table row indices."""
        axis = torch.arange(window_size)
        # Row / column coordinate of every token, tokens in row-major order.
        rows = axis.repeat_interleave(window_size)
        cols = axis.repeat(window_size)
        # Shift offsets from [-(w-1), w-1] to [0, 2w-2] so they index the table.
        rel_rows = rows[:, None] - rows[None, :] + (window_size - 1)
        rel_cols = cols[:, None] - cols[None, :] + (window_size - 1)
        return rel_rows * (2 * window_size - 1) + rel_cols

    def forward(self, x):
        """Attend within each window.

        Args:
            x: Tensor of shape ``(B, N, C)`` with ``N == window_size ** 2``
                and ``C == dim``.

        Returns:
            Tensor of shape ``(B, N, C)``.

        Raises:
            ValueError: If ``N`` does not equal ``window_size ** 2``.
        """
        B, N, C = x.shape
        H = self.num_heads
        tokens_per_window = self.window_size ** 2
        if N != tokens_per_window:
            raise ValueError(
                f"expected {tokens_per_window} tokens per window "
                f"(window_size={self.window_size}), got {N}"
            )

        # Create Q, K, V, each of shape (B, H, N, C // H).
        qkv = self.qkv(x).reshape(B, N, 3, H, C // H).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product logits: (B, H, N, N).
        attention_map = torch.matmul(q * self.scale, k.transpose(-2, -1))

        # Gather the per-pair bias, (N*N, H) -> (H, N, N), and broadcast it
        # over the batch dimension.
        bias = self.relative_position_bias_table[
            self.relative_position_index.reshape(-1)
        ]
        bias = bias.reshape(N, N, H).permute(2, 0, 1)
        attention_map = attention_map + bias.unsqueeze(0)

        # Normalise over the key axis, then regularise.
        attention_map = F.softmax(attention_map, dim=-1)
        attention_map = self.attn_drop(attention_map)

        # Apply attention to values (v): (B, H, N, C // H).
        out = torch.matmul(attention_map, v)

        # Merge heads and project the output back to the original dimension.
        out = out.permute(0, 2, 1, 3).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out