# A Concise PyTorch Implementation of [Swin Transformer V2](https://arxiv.org/pdf/2111.09883)

In [4]:
from typing import Tuple, Optional
import math
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

## Relative Position Bias Indices

In [3]:
def get_relative_distances(window_size: int) -> torch.Tensor:
    """Return the pairwise (row, col) offsets between all cells of a square window.

    Cells are enumerated in row-major order. Entry ``[a, b]`` of the result is
    ``coords[b] - coords[a]``, i.e. the offset that leads from cell ``a`` to
    cell ``b``.

    Args:
        window_size: Side length of the square window.

    Returns:
        Integer tensor of shape ``[window_size ** 2, window_size ** 2, 2]``.
    """
    # divmod(flat, window_size) recovers the (row, col) pair of a row-major index.
    grid = torch.tensor([divmod(flat, window_size) for flat in range(window_size * window_size)])
    targets = grid[None, :, :]
    sources = grid[:, None, :]
    return targets - sources

# Demo: build the offset table for a 3x3 window and print its shape.
window_size = 3
relative_distances_1 = get_relative_distances(window_size)
print(relative_distances_1.shape)
# Disabled sanity check: indexing a (2 * window_size - 1)-sized table with the raw
# (possibly negative) offsets is compared against indexing the same table, rolled by
# (window_size - 1) along both axes, with the offsets shifted to be non-negative.
# relative_distances_2 = get_relative_distances(window_size) + window_size - 1
# pos_embedding = torch.randn(2 * window_size - 1, 2 * window_size - 1)
# print(pos_embedding[relative_distances_1[:, :, 0], relative_distances_1[:, :, 1]] == pos_embedding.roll((window_size - 1, window_size - 1), dims=(0, 1))[relative_distances_2[:, :, 0], relative_distances_2[:, :, 1]],
#       '\n', relative_distances_2[:, :, 0], '\n', relative_distances_2[:, :, 1])

torch.Size([9, 9, 2])


## Position-wise Feed-Forward Network (FFN)

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self,
                 in_features: int,
                 hidden_features: int,
                 out_features: int,
                 dropout: float = 0.1) -> None:
        """
        Args:
            in_features: Size of the last dimension of the input.
            hidden_features: Width of the hidden layer.
            out_features: Size of the last dimension of the output.
            dropout: Drop probability used after the activation and after the output projection.
        """
        super().__init__()
        # Expansion branch.
        self.linear1 = nn.Linear(in_features, hidden_features)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        # Projection branch.
        self.linear2 = nn.Linear(hidden_features, out_features)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """Map ``[..., in_features]`` to ``[..., out_features]``."""
        hidden = self.linear1(x)
        hidden = self.gelu(hidden)
        hidden = self.dropout1(hidden)
        projected = self.linear2(hidden)
        return self.dropout2(projected)

## Window Attention

In [4]:
class CyclicShift(nn.Module):
    """Roll a ``[B, H, W, C]`` tensor by the same amount along its H and W axes."""

    def __init__(self, displacement: int):
        """
        Args:
            displacement: Number of positions to roll along dims 1 and 2 (negative rolls backwards).
        """
        super().__init__()
        self.displacement = displacement

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` rolled with wrap-around along dims 1 and 2; the shape is unchanged."""
        step = self.displacement
        return x.roll(shifts=(step, step), dims=(1, 2))

In [None]:

class WindowAttention(nn.Module):
    def __init__(self,
                 in_features: int,
                 window_size: int,
                 number_of_heads: int,
                 shift_size: int,
                 dropout_attention: float = 0.1,
                 dropout_proj: float = 0.1,
                 meta_network_hidden_features: int = 256,
                 sequential_self_attention: bool = False) -> None:
        super(WindowAttention, self).__init__()
        assert in_features % number_of_heads == 0, 'The num of input features (in_features) must be divisible by the number of heads'
        self.in_features = in_features
        self.window_size = window_size
        self.number_of_heads = number_of_heads
        self.shift_size = shift_size
        self.d_k = in_features // number_of_heads
        
        if self.shift_size:
            self.cyclic_shift = CyclicShift(-shift_size)
            self.cyclic_back_shift = CyclicShift(shift_size)
        
        self.linears = nn.ModuleList([nn.Linear(in_features, in_features) for _ in range(4)])
        self.__set_relative_distances()
        self.meta_network: nn.Module = nn.Sequential(
            nn.Linear(in_features=2, out_features=meta_network_hidden_features),
            nn.ReLU(),
            nn.Linear(in_features=meta_network_hidden_features, out_features=number_of_heads),
        )
        self.tau = nn.Parameter(torch.ones(1, number_of_heads, 1, 1))
        self.dropout_attention = nn.Dropout(dropout_attention)
        self.dropout_proj = nn.Dropout(dropout_proj)

    def __set_relative_distances(self) -> torch.Tensor:
        indices = torch.arange(self.window_size)
        coordinates = torch.stack(torch.meshgrid([indices, indices]), dim=0)
        coordinates = coordinates.flatten(1).transpose(0, 1)
        relative_distances = coordinates[None, :, :] - coordinates[:, None, :]
        # shape of relative_distances_log: [window_size ** 2, window_size ** 2, 2]
        relative_distances_log = torch.sign(relative_distances) * torch.log(torch.abs(relative_distances) + 1)
        self.register_buffer('relative_distances_log', relative_distances_log)
    
    def __get_positional_encoding(self) -> torch.Tensor:
        relative_position_bias: torch.Tensor = self.meta_network(self.relative_distances_log)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().view(self.num_of_heads, 1, self.window_size ** 2, self.window_size ** 2)
        # shape of return: [num_heads, 1, window_size ** 2, window_size ** 2]
        return relative_position_bias
    
    def forward(self,
                x: torch.Tensor,
                mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor:
        # shape of input: [B, H, W, C]
        if self.shift_size:
            x = self.cyclic_shift(x)
        b, h, w, _,  = x.shape
        nw_h, nw_w = h // self.window_size, w // self.window_size
        # shape of query, key, value: [b, h, nw_h, nw_w, window_size, window_size, d_k]
        query, key, value = [lin(x).view(b, nw_h, self.window_size, nw_w, self.window_size, self.number_of_heads, self.d_k).permute(0, 5, 1, 3, 2, 4, 6)
                             .contiguous().view(b, self.number_of_heads, nw_h * nw_w, self.window_size * self.window_size, self.d_k)
                        for lin, x in zip(self.linears, (x, x, x))]
        # shape of scores: [b, h, nw_h * nw_w, window_size ** 2, d_k]
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.maximum(torch.norm(query, dim=-1, keepdim=True) * torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1), 1e-6)
        scores /= max(self.tau, 0.01)
        scores += self.__get_positional_encoding()
        
        if self.shift_size:
            scores[:, :, -nw_w:].masked_fill_(mask[0] == 0, -1e9)
            scores[:, :, nw_w - 1::nw_w].masked_fill_(mask[1] == 0, -1e9)
        p_attn = self.dropout_attention(F.softmax(scores, dim=-1))
        # shape of x: [b, h, nw_h, nw_w, window_size, window_size, d_k]
        x = torch.matmul(p_attn, value)

        x = x.view(b, self.number_of_heads, nw_h, nw_w, self.window_size, self.window_size, self.d_k).permute(0, 2, 4, 3, 5, 1, 6).contiguous().view(b, h, w, -1)
        x = self.dropout_proj(self.linears[-1](x))
        if self.shift_size:
            x = self.cyclic_back_shift(x)
        # shape of output: [B, H, W, C]
        return x

## Swin Block & Add & Norm

In [None]:
class DropPath(nn.Module):
    """Stochastic depth: during training, zero whole samples of the input at random."""

    def __init__(self,
                 drop_prob: float = 0.,
                 scale_by_keep: bool = True):
        """
        Args:
            drop_prob: Probability of zeroing a sample.
            scale_by_keep: If true, surviving samples are divided by ``1 - drop_prob``.
        """
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged in eval mode or when ``drop_prob`` is 0; otherwise drop samples."""
        active = self.training and self.drop_prob != 0.
        if not active:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over every remaining dimension.
        per_sample_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        survivors = x.new_empty(per_sample_shape).bernoulli_(keep_prob)
        if self.scale_by_keep and keep_prob > 0.0:
            survivors.div_(keep_prob)
        return x * survivors

In [None]:
class SwinBlock(nn.Module):
    """One transformer block: windowed attention and an FFN, each in a residual branch.

    LayerNorm is applied to the input of each branch, and DropPath is applied to the
    branch output before it is added back to the residual stream.
    """

    def __init__(self,
                 in_features: int,
                 window_size: int,
                 number_of_heads: int,
                 shift_size: int = 0,
                 ffn_feature_ratio: int = 4,
                 dropout_attention: float = 0.1,
                 dropout: float = 0.1,
                 dropout_path: float = 0.) -> None:
        """
        Args:
            in_features: Channel count of the input and output.
            window_size: Side length of the attention windows.
            number_of_heads: Number of attention heads.
            shift_size: Cyclic shift used by the attention layer (0 for a regular block).
            ffn_feature_ratio: Hidden width of the FFN as a multiple of ``in_features``.
            dropout_attention: Drop probability on the attention weights.
            dropout: Drop probability on the attention projection and inside the FFN.
            dropout_path: Drop probability of the DropPath applied to both branches.
        """
        super().__init__()
        # Attention branch.
        self.norm1 = nn.LayerNorm(in_features)
        self.attention = WindowAttention(
            in_features=in_features,
            window_size=window_size,
            number_of_heads=number_of_heads,
            shift_size=shift_size,
            dropout_attention=dropout_attention,
            dropout_proj=dropout,
        )
        self.drop_path = DropPath(dropout_path)
        # Feed-forward branch.
        self.norm2 = nn.LayerNorm(in_features)
        self.feed_forward = FeedForward(
            in_features=in_features,
            hidden_features=in_features * ffn_feature_ratio,
            out_features=in_features,
            dropout=dropout,
        )

    def forward(self,
                x: torch.Tensor,
                mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor:
        """Run both residual branches on a ``[B, H, W, C]`` input; ``mask`` is forwarded to the attention layer."""
        attended = self.attention(self.norm1(x), mask)
        x = x + self.drop_path(attended)
        transformed = self.feed_forward(self.norm2(x))
        return x + self.drop_path(transformed)

## Patch Merging

In [7]:
class PatchMerging(nn.Module):
    """Downsample by folding each ``factor x factor`` patch into the channel axis, then project linearly."""

    def __init__(self, in_channels: int, out_channels: int, downscaling_factor: int):
        """
        Args:
            in_channels: Channel count of the input feature map.
            out_channels: Channel count after the linear projection.
            downscaling_factor: Side length (and stride) of the merged patches.
        """
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(
            kernel_size=downscaling_factor,
            stride=downscaling_factor,
            padding=0,
        )
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[B, C, H, W]`` to ``[B, H // factor, W // factor, out_channels]``."""
        batch, _, height, width = x.shape
        factor = self.downscaling_factor
        # Unfold yields [B, C * factor ** 2, L]; restore the 2-D grid of patch positions.
        patches = self.patch_merge(x)
        grid = patches.view(batch, -1, height // factor, width // factor)
        # Channels-last layout so the linear layer acts on the merged patch features.
        return self.linear(grid.permute(0, 2, 3, 1))

## Stage

In [None]:

class StageModule(nn.Module):
    """One stage: patch merging followed by alternating regular and shifted Swin blocks."""

    def __init__(self,
                 in_channels: int,
                 in_features: int,
                 window_size: int,
                 number_of_heads: int, 
                 ffn_feature_ratio: int,
                 layers: int,
                 downscaling_factor: int,
                 dropout_attention: float=0.1,
                 dropout: float=0.1,
                 dropout_path=0.1) -> None:
        """
        Args:
            in_channels: Channel count of the incoming ``[B, C, H, W]`` feature map.
            in_features: Channel count used inside the stage (output of patch merging).
            window_size: Side length of the attention windows.
            number_of_heads: Number of attention heads in every block.
            ffn_feature_ratio: FFN hidden width as a multiple of ``in_features``.
            layers: Total number of blocks; must be even (half regular, half shifted).
            downscaling_factor: Spatial reduction applied by patch merging.
            dropout_attention: Drop probability on attention weights.
            dropout: Drop probability on projections and inside the FFN.
            dropout_path: DropPath probability used by every block.
        """
        super().__init__()
        assert layers % 2 == 0, 'number of layers must be divisible by 2 for regular and shifted block'
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.patch_merging = PatchMerging(in_channels, in_features, downscaling_factor)

        pairs = layers // 2
        self.layers_regular = nn.ModuleList([
            SwinBlock(in_features, window_size, number_of_heads, 0,
                      ffn_feature_ratio, dropout_attention, dropout, dropout_path)
            for _ in range(pairs)
        ])
        self.layers_shifted = nn.ModuleList([
            SwinBlock(in_features, window_size, number_of_heads, window_size // 2,
                      ffn_feature_ratio, dropout_attention, dropout, dropout_path)
            for _ in range(pairs)
        ])
        self.__create_mask()

    def __create_mask(self) -> None:
        """Register the two ``[window_size ** 2, window_size ** 2]`` uint8 masks used by shifted blocks.

        A zero marks a (query, key) pair that must not attend: the pair straddles the
        boundary between the last ``shift_size`` columns (``left_right_mask``) or rows
        (``upper_lower_mask``) of a window and the remainder of that window.
        """
        ws, shift = self.window_size, self.shift_size
        tokens = ws ** 2

        # Column boundary: index order is (query_row, query_col, key_row, key_col).
        left_right_mask: torch.Tensor = torch.ones(ws, ws, ws, ws, dtype=torch.uint8)
        left_right_mask[:, -shift:, :, :-shift] = 0
        left_right_mask[:, :-shift, :, -shift:] = 0
        left_right_mask = left_right_mask.view(tokens, tokens)

        # Row boundary: with row-major tokens, the last `shift` rows are the last shift * ws tokens.
        tail = shift * ws
        upper_lower_mask: torch.Tensor = torch.ones(tokens, tokens, dtype=torch.uint8)
        upper_lower_mask[-tail:, :-tail] = 0
        upper_lower_mask[:-tail, -tail:] = 0

        self.register_buffer('left_right_mask', left_right_mask)
        self.register_buffer('upper_lower_mask', upper_lower_mask)

    def forward(self, 
                x: torch.Tensor) -> torch.Tensor:
        """Map ``[B, C, H, W]`` to ``[B, C', H', W']``."""
        x = self.patch_merging(x)
        masks = (self.upper_lower_mask, self.left_right_mask)
        for regular, shifted in zip(self.layers_regular, self.layers_shifted):
            x = shifted(regular(x), masks)
        # Blocks work channels-last; convert back to channels-first for the next stage.
        return x.permute(0, 3, 1, 2)

## Swin Transformer

In [None]:
class SwinTransformer(nn.Module):
    """Four-stage hierarchical classifier: stages, global average pooling, LayerNorm and a linear head."""

    def __init__(self,
                 *,
                 in_channels: int = 3,
                 num_classes: int = 10,
                 in_features: int = 96,
                 window_size: int = 7,
                 heads: Tuple[int, int, int, int] = (3, 6, 12, 24),
                 ffn_feature_ratio: int = 4,
                 num_layers: Tuple[int, int, int, int] = (2, 2, 6, 2),
                 downscaling_factors: Tuple[int, int, int, int] = (4, 2, 2, 2),
                 dropout_attention: float = 0.1,
                 dropout: float = 0.1,
                 dropout_path: float = 0.1) -> None:
        """
        Args:
            in_channels: Channel count of the input images.
            num_classes: Size of the classification output.
            in_features: Width of the first stage; each later stage doubles it.
            window_size: Side length of the attention windows in every stage.
            heads: Number of attention heads per stage.
            ffn_feature_ratio: FFN hidden width as a multiple of the stage width.
            num_layers: Number of blocks per stage.
            downscaling_factors: Spatial reduction applied at the start of each stage.
            dropout_attention: Drop probability on attention weights.
            dropout: Drop probability on projections and inside the FFNs.
            dropout_path: DropPath probability used by every block.
        """
        super().__init__()

        def build_stage(index: int, stage_in: int, stage_out: int) -> StageModule:
            # All stages share the regularisation settings; only widths and depths differ.
            return StageModule(stage_in, stage_out, window_size, heads[index], ffn_feature_ratio,
                               num_layers[index], downscaling_factors[index],
                               dropout_attention, dropout, dropout_path)

        self.stage1 = build_stage(0, in_channels, in_features)
        self.stage2 = build_stage(1, in_features, in_features * 2)
        self.stage3 = build_stage(2, in_features * 2, in_features * 4)
        self.stage4 = build_stage(3, in_features * 4, in_features * 8)

        final_width = in_features * 8
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(final_width),
            nn.Linear(final_width, num_classes)
        )

    def forward(self, 
                x: torch.Tensor) -> torch.Tensor:
        """Map images ``[B, C, H, W]`` to logits ``[B, num_classes]``."""
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            x = stage(x)
        # Global average pooling over the two spatial axes.
        pooled = x.mean(dim=(2, 3))
        return self.mlp_head(pooled)

In [33]:
def swin_tiny_patch4_window7_224(d_model: int = 96, num_layers: Tuple[int, int, int, int] = (2, 2, 6, 2), heads: Tuple[int, int, int, int] = (3, 6, 12, 24)) -> SwinTransformer:
    """Build the "tiny" configuration.

    Args:
        d_model: Width of the first stage (forwarded as ``in_features``).
        num_layers: Number of blocks per stage.
        heads: Number of attention heads per stage.

    Returns:
        A ``SwinTransformer`` with every other argument left at its default.
    """
    # SwinTransformer calls its base width ``in_features``; it has no ``d_model`` keyword.
    return SwinTransformer(in_features=d_model, heads=heads, num_layers=num_layers)

def swin_small_patch4_window7_224(d_model: int = 96, num_layers: Tuple[int, int, int, int] = (2, 2, 18, 2), heads: Tuple[int, int, int, int] = (3, 6, 12, 24)) -> SwinTransformer:
    """Build the "small" configuration.

    Args:
        d_model: Width of the first stage (forwarded as ``in_features``).
        num_layers: Number of blocks per stage.
        heads: Number of attention heads per stage.

    Returns:
        A ``SwinTransformer`` with every other argument left at its default.
    """
    # SwinTransformer calls its base width ``in_features``; it has no ``d_model`` keyword.
    return SwinTransformer(in_features=d_model, heads=heads, num_layers=num_layers)

def swin_base_patch4_window7_224(d_model: int = 128, num_layers: Tuple[int, int, int, int] = (2, 2, 18, 2), heads: Tuple[int, int, int, int] = (4, 8, 16, 32)) -> SwinTransformer:
    """Build the "base" configuration.

    Args:
        d_model: Width of the first stage (forwarded as ``in_features``).
        num_layers: Number of blocks per stage.
        heads: Number of attention heads per stage.

    Returns:
        A ``SwinTransformer`` with every other argument left at its default.
    """
    # SwinTransformer calls its base width ``in_features``; it has no ``d_model`` keyword.
    return SwinTransformer(in_features=d_model, heads=heads, num_layers=num_layers)

def swin_large_patch4_window7_224(d_model: int = 192, num_layers: Tuple[int, int, int, int] = (2, 2, 18, 2), heads: Tuple[int, int, int, int] = (6, 12, 24, 48)) -> SwinTransformer:
    """Build the "large" configuration.

    Args:
        d_model: Width of the first stage (forwarded as ``in_features``).
        num_layers: Number of blocks per stage.
        heads: Number of attention heads per stage.

    Returns:
        A ``SwinTransformer`` with every other argument left at its default.
    """
    # SwinTransformer calls its base width ``in_features``; it has no ``d_model`` keyword.
    return SwinTransformer(in_features=d_model, heads=heads, num_layers=num_layers)

## Train

In [10]:
# Input pipeline settings.
batch_size = 128
img_size = 96

# Resize every image to img_size and convert it to a tensor.
transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])

# FashionMNIST train / validation splits, downloaded into ../data when missing.
data_train = datasets.FashionMNIST(
    root="../data",
    train=True,
    download=True,
    transform=transform,
)
data_val = datasets.FashionMNIST(
    root="../data",
    train=False,
    download=True,
    transform=transform,
)

# Only the training loader shuffles its samples.
loader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(data_val, batch_size=batch_size, shuffle=False)

In [11]:
# Train on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model hyper-parameters for this run.
num_layers = (2, 2, 4, 2)
window_size = 3
in_channels = 1

# Pass the in_channels variable rather than repeating the literal, so the two cannot drift apart.
model = SwinTransformer(in_channels=in_channels, window_size=window_size, num_layers=num_layers).to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
# Called with reduction="sum" by the training loop, which divides by the sample count when reporting.
loss = F.cross_entropy

In [12]:
# Train for a fixed number of epochs, reporting per-sample loss and accuracy on both splits.
max_epochs = 50
for epoch in range(max_epochs):
    # ---- training pass ----
    model.train()
    train_loss = train_count = train_acc = 0
    for x, y in loader_train:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        batch_loss = loss(y_pred, y, reduction="sum")
        optim.zero_grad()
        batch_loss.backward()
        optim.step()
        # .item() returns plain Python numbers, so the running totals hold no graph.
        train_loss += batch_loss.item()
        train_count += y.size(0)
        train_acc += (y_pred.argmax(1) == y).sum().item()

    # ---- validation pass ----
    model.eval()
    val_loss = val_count = val_acc = 0
    with torch.no_grad():
        for x, y in loader_val:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            val_loss += loss(y_pred, y, reduction="sum").item()
            val_count += y.size(0)
            val_acc += (y_pred.argmax(1) == y).sum().item()

    print(f"Epoch: {epoch + 1:03d}, Train Loss: {train_loss / train_count:.4f}, Train Acc: {train_acc / train_count:.4f}, Val Loss: {val_loss / val_count:.4f}, Val Acc: {val_acc / val_count:.4f}")

Epoch: 001, Train Loss: 0.6149, Train Acc: 0.7741, Val Loss: 0.4051, Val Acc: 0.8484
Epoch: 002, Train Loss: 0.3585, Train Acc: 0.8674, Val Loss: 0.3821, Val Acc: 0.8585
Epoch: 003, Train Loss: 0.3142, Train Acc: 0.8823, Val Loss: 0.3202, Val Acc: 0.8794
Epoch: 004, Train Loss: 0.2847, Train Acc: 0.8941, Val Loss: 0.3064, Val Acc: 0.8844
Epoch: 005, Train Loss: 0.2627, Train Acc: 0.9017, Val Loss: 0.2867, Val Acc: 0.8935
Epoch: 006, Train Loss: 0.2417, Train Acc: 0.9082, Val Loss: 0.2700, Val Acc: 0.9024
Epoch: 007, Train Loss: 0.2240, Train Acc: 0.9157, Val Loss: 0.2989, Val Acc: 0.8912
Epoch: 008, Train Loss: 0.2095, Train Acc: 0.9202, Val Loss: 0.2691, Val Acc: 0.9027
Epoch: 009, Train Loss: 0.1950, Train Acc: 0.9261, Val Loss: 0.2655, Val Acc: 0.9051
Epoch: 010, Train Loss: 0.1800, Train Acc: 0.9313, Val Loss: 0.2564, Val Acc: 0.9086
Epoch: 011, Train Loss: 0.1719, Train Acc: 0.9352, Val Loss: 0.2547, Val Acc: 0.9126
Epoch: 012, Train Loss: 0.1564, Train Acc: 0.9399, Val Loss: 0.25

In [13]:
# Save only the model's state dict (its parameters and buffers) to model2.pth.
torch.save(model.state_dict(), "model2.pth")

In [2]:
# Scratch cell: evaluates 26 squared (the notebook shows 676); nothing else in the file uses it.
26 ** 2

676