In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class SwinTransformerBlock(nn.Module):
    """Pre-norm residual block: an attention sub-layer followed by an MLP sub-layer.

    Each sub-layer is preceded by its own ``LayerNorm`` and wrapped in a
    residual connection, i.e. ``x = x + f(norm(x))``.

    Args:
        dim: Channel width of the incoming features (size of the last axis).
        num_heads: Number of attention heads handed to ``WindowAttention``.
        window_size: Window side length handed to ``WindowAttention``.
        shift_size: Shift amount handed to ``WindowAttention``.
        mlp_ratio: Hidden width of the MLP expressed as a multiple of ``dim``.
        dropout: Dropout probability forwarded to both sub-layers.
    """

    def __init__(self, dim, num_heads, window_size, shift_size=0, mlp_ratio=4.0, dropout=0.1):
        super().__init__()

        # Hyper-parameters are kept as plain attributes for later inspection.
        self.dim, self.num_heads = dim, num_heads
        self.window_size, self.shift_size = window_size, shift_size
        self.mlp_ratio, self.dropout = mlp_ratio, dropout

        hidden_dim = int(dim * mlp_ratio)

        # Sub-layers are created in a fixed order (attn, ffn, norm1, norm2) so
        # that parameter registration order stays stable.
        self.attn = WindowAttention(dim, num_heads, window_size, shift_size, dropout)
        self.ffn = MLP(dim, hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Run both residual sub-layers and return a tensor shaped like ``x``."""
        # Attention branch: normalise first, then add the result back.
        attended = x + self.attn(self.norm1(x))
        # Feed-forward branch, again normalise-then-add.
        return attended + self.ffn(self.norm2(attended))

In [None]:
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention (W-MSA / SW-MSA).

    The token sequence is interpreted as a square grid, split into
    non-overlapping ``window_size x window_size`` windows, and self-attention
    is computed independently inside every window with a learned relative
    position bias. When ``shift_size > 0`` the grid is cyclically shifted
    before partitioning and an attention mask prevents tokens that were not
    neighbours before the shift from attending to each other.

    Args:
        dim: Channel width of the input tokens.
        num_heads: Number of attention heads; must divide ``dim``.
        window_size: Side length of a square attention window.
        shift_size: Cyclic shift applied before partitioning
            (``0 <= shift_size < window_size``). ``0`` disables shifting.
        dropout: Dropout probability for attention weights and the output
            projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads`` or
            ``shift_size`` is outside ``[0, window_size)``.
    """

    def __init__(self, dim, num_heads, window_size, shift_size=0, dropout=0.1):
        super(WindowAttention, self).__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        if not 0 <= shift_size < window_size:
            raise ValueError(
                f"shift_size ({shift_size}) must satisfy 0 <= shift_size < window_size ({window_size})"
            )

        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        # Standard scaled dot-product factor 1 / sqrt(head_dim).
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

        # One learnable bias per head for every possible (dy, dx) offset
        # between two tokens of the same window: (2w - 1)^2 offsets.
        self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size - 1) ** 2, num_heads))
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # Lookup table mapping every (query token, key token) pair inside a
        # window to its row in the bias table.
        coords = torch.arange(window_size)
        rel = coords[:, None] - coords[None, :] + (window_size - 1)  # values in [0, 2w - 2]
        # Axes of `index`: (query_y, query_x, key_y, key_x).
        index = rel[:, None, :, None] * (2 * window_size - 1) + rel[None, :, None, :]
        # Non-persistent so the state_dict layout is not altered by this buffer.
        self.register_buffer(
            "relative_position_index",
            index.reshape(window_size * window_size, window_size * window_size),
            persistent=False,
        )

    @staticmethod
    def _shifted_window_mask(side, window_size, shift, device, dtype):
        """Build the additive (num_windows, T, T) mask used after a cyclic shift."""
        region = torch.zeros(side, side, device=device)
        # After rolling by `shift`, the grid consists of 3 x 3 regions whose
        # tokens originate from different places; label each region.
        bounds = (slice(0, -window_size), slice(-window_size, -shift), slice(-shift, None))
        label = 0
        for rows in bounds:
            for cols in bounds:
                region[rows, cols] = label
                label += 1
        n_win = side // window_size
        region = region.view(n_win, window_size, n_win, window_size)
        region = region.permute(0, 2, 1, 3).reshape(n_win * n_win, window_size * window_size)
        # Token pairs carrying different labels must not attend to each other.
        differs = region[:, None, :] != region[:, :, None]
        return torch.zeros(differs.shape, device=device, dtype=dtype).masked_fill(differs, -100.0)

    def forward(self, x):
        """Apply windowed self-attention.

        Args:
            x: Tensor of shape ``(B, N, C)`` where ``N`` is a perfect square
                (tokens of a square grid in row-major order) whose side is a
                multiple of ``window_size``.

        Returns:
            Tensor of shape ``(B, N, C)``.

        Raises:
            ValueError: If ``N`` is not a perfect square or the grid side is
                not divisible by ``window_size``.
        """
        B, N, C = x.shape
        heads = self.num_heads
        ws = self.window_size

        side = int(round(N ** 0.5))
        if side * side != N:
            raise ValueError(f"expected a square token grid, got N={N}")
        if side % ws != 0:
            raise ValueError(f"grid side ({side}) must be divisible by window_size ({ws})")
        # With a single window covering the whole grid, shifting is pointless.
        shift = self.shift_size if side > ws else 0
        n_win = side // ws
        tokens = ws * ws

        grid = x.reshape(B, side, side, C)
        if shift > 0:
            grid = torch.roll(grid, shifts=(-shift, -shift), dims=(1, 2))

        # Partition into windows: (B * num_windows, ws * ws, C).
        windows = grid.reshape(B, n_win, ws, n_win, ws, C).permute(0, 1, 3, 2, 4, 5)
        windows = windows.reshape(B * n_win * n_win, tokens, C)

        # Create Q, K, V: each (B * num_windows, heads, tokens, head_dim).
        qkv = self.qkv(windows).reshape(-1, tokens, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attention_map = torch.matmul(q * self.scale, k.transpose(-2, -1))

        # Relative position bias, gathered to (heads, tokens, tokens).
        bias = self.relative_position_bias_table[self.relative_position_index.reshape(-1)]
        bias = bias.reshape(tokens, tokens, heads).permute(2, 0, 1)
        attention_map = attention_map + bias.unsqueeze(0)

        if shift > 0:
            mask = self._shifted_window_mask(side, ws, shift, x.device, attention_map.dtype)
            attention_map = attention_map.reshape(B, n_win * n_win, heads, tokens, tokens)
            attention_map = attention_map + mask[None, :, None, :, :]
            attention_map = attention_map.reshape(-1, heads, tokens, tokens)

        attention_map = F.softmax(attention_map, dim=-1)
        attention_map = self.attn_drop(attention_map)

        # Apply attention to values and merge the heads back into C channels.
        out = torch.matmul(attention_map, v)
        out = out.transpose(1, 2).reshape(-1, tokens, C)
        out = self.proj(out)
        out = self.proj_drop(out)

        # Undo the window partition and the cyclic shift.
        out = out.reshape(B, n_win, n_win, ws, ws, C).permute(0, 1, 3, 2, 4, 5)
        out = out.reshape(B, side, side, C)
        if shift > 0:
            out = torch.roll(out, shifts=(shift, shift), dims=(1, 2))
        return out.reshape(B, N, C)

In [None]:
class MLP(nn.Module):
    """Two-layer position-wise feed-forward network.

    Computes ``fc2(dropout(gelu(fc1(x))))`` over the last axis of the input.

    Args:
        in_dim: Size of the last axis of the input; also the output size.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, in_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        # NOTE: the attribute is called ``relu`` but the activation is GELU.
        self.relu = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, in_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``hidden_dim``, activate, drop, and project back to ``in_dim``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(self.dropout(hidden))


# Patch Merging Layer
class PatchMerging(nn.Module):
    """Down-sample a token grid by merging ``stride x stride`` neighbourhoods.

    The features of every ``stride x stride`` group of spatially adjacent
    tokens are concatenated along the channel axis (in row-major order inside
    the group) and linearly projected to ``out_dim``.

    Args:
        dim: Channel width of the incoming tokens.
        out_dim: Channel width after merging.
        stride: Side length of the neighbourhood that is merged into one token.
    """

    def __init__(self, dim, out_dim, stride=2):
        super(PatchMerging, self).__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.stride = stride

        # stride * stride neighbours are concatenated (4 * dim for the default).
        self.merge = nn.Linear(stride * stride * dim, out_dim)

    def forward(self, x):
        """Merge neighbouring tokens.

        Args:
            x: Either ``(B, N, C)`` tokens of a square grid in row-major order,
                or an explicit ``(B, H, W, C)`` grid.

        Returns:
            Tensor of shape ``(B, (H // stride) * (W // stride), out_dim)``.

        Raises:
            ValueError: If the input rank is not 3 or 4, a 3-D input does not
                form a square grid, or the grid is not divisible by ``stride``.
        """
        s = self.stride
        if x.dim() == 3:
            B, N, C = x.shape
            H = W = int(round(N ** 0.5))
            if H * W != N:
                raise ValueError(f"expected a square token grid, got N={N}")
        elif x.dim() == 4:
            B, H, W, C = x.shape
        else:
            raise ValueError(f"expected a 3-D or 4-D input, got {x.dim()}-D")
        if H % s != 0 or W % s != 0:
            raise ValueError(f"grid ({H}x{W}) must be divisible by stride ({s})")

        # Axes after reshape: (B, H/s, dy, W/s, dx, C); bring (dy, dx, C)
        # together so each output token holds its whole neighbourhood.
        x = x.reshape(B, H // s, s, W // s, s, C).permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, (H // s) * (W // s), s * s * C)
        return self.merge(x)

class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of patch tokens with a strided convolution.

    The convolution uses ``kernel_size == stride == patch_size``, so each
    output location sees exactly one non-overlapping patch.

    Args:
        img_size: Nominal input image size; stored only, not used in the computation.
        patch_size: Side length of a patch (kernel size and stride of the conv).
        in_channels: Number of channels of the input image.
        embed_dim: Number of output channels, i.e. the token width.
    """

    def __init__(self, img_size=224, patch_size=4, in_channels=3, embed_dim=96):
        super().__init__()
        self.patch_size, self.img_size = patch_size, img_size

        self.conv = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # Collapses every axis from index 2 onward into a single token axis.
        self.flatten = nn.Flatten(2)

    def forward(self, x):
        """Embed ``x`` (expected ``(B, in_channels, H, W)``) as ``(B, num_patches, embed_dim)``."""
        feature_map = self.conv(x)
        token_major = self.flatten(feature_map)
        # Move channels last so the result is a sequence of token vectors.
        return token_major.transpose(1, 2)

In [None]:
class SwinTransformer(nn.Module):
    """Hierarchical Swin-style image classifier.

    The image is embedded into patch tokens and then processed by a sequence
    of stages. Stage ``i`` consists of ``depths[i]`` transformer blocks of
    width ``embed_dim * 2**i``; between consecutive stages a ``PatchMerging``
    layer doubles the channel width. The final tokens are layer-normalised,
    averaged over the token axis and fed to a linear classification head.

    ``self.layers`` alternates between a ``nn.ModuleList`` of blocks (one per
    stage) and the ``PatchMerging`` module that follows that stage.

    Args:
        img_size: Nominal input image size, forwarded to ``PatchEmbedding``.
        patch_size: Patch side length, forwarded to ``PatchEmbedding``.
        in_channels: Number of input image channels.
        embed_dim: Token width of the first stage.
        depths: Number of blocks in each stage.
        num_heads: Number of attention heads in each stage; same length as
            ``depths``.
        window_size: Attention window side length used by every block.
        mlp_ratio: MLP hidden width as a multiple of the stage width.
        num_classes: Size of the classification output.

    Raises:
        ValueError: If ``depths`` and ``num_heads`` differ in length.
    """

    def __init__(self, img_size=224, patch_size=4, in_channels=3, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), window_size=7, mlp_ratio=4.0, num_classes=1000):
        super(SwinTransformer, self).__init__()
        if len(depths) != len(num_heads):
            raise ValueError(
                f"depths and num_heads must have the same length, got {len(depths)} and {len(num_heads)}"
            )

        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.img_size = img_size
        self.num_classes = num_classes

        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        self.layers = nn.ModuleList()
        num_layers = len(depths)
        for i in range(num_layers):
            stage_dim = embed_dim * (2 ** i)
            self.layers.append(
                nn.ModuleList([
                    SwinTransformerBlock(
                        dim=stage_dim,
                        num_heads=num_heads[i],
                        window_size=window_size,
                        # Alternate regular and shifted windows so that
                        # information can flow across window borders.
                        shift_size=0 if j % 2 == 0 else window_size // 2,
                        mlp_ratio=mlp_ratio
                    ) for j in range(depths[i])
                ])
            )

            if i < num_layers - 1:
                self.layers.append(PatchMerging(stage_dim, embed_dim * (2 ** (i + 1))))

        final_dim = embed_dim * (2 ** (num_layers - 1))
        self.norm = nn.LayerNorm(final_dim)

        # Final classification head
        self.head = nn.Linear(final_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images; returns logits with ``num_classes`` as the last axis."""
        x = self.patch_embedding(x)

        for layer in self.layers:
            if isinstance(layer, nn.ModuleList):
                # A stage: run its blocks one after another.
                for block in layer:
                    x = block(x)
            else:
                # A down-sampling module sitting between two stages.
                x = layer(x)

        x = self.norm(x)
        # Global average pooling over the token axis.
        x = x.mean(dim=1)
        x = self.head(x)
        return x
        