In [1]:
from typing import List, Tuple, Any, Union, Optional, Callable
import math, copy
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

In [2]:
def clones(module: nn.Module, N: int) -> nn.ModuleList:
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

    Every copy starts with exactly the same parameter values as ``module``.
    """
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)

# Relative Position Bias Indices

In [None]:
def get_relative_distances(window_size: int) -> torch.Tensor:
    """Pairwise 2-D offsets between the tokens of a ``window_size`` x ``window_size`` window.

    Tokens are enumerated row-major as ``(row, col)``. The result has shape
    ``[window_size**2, window_size**2, 2]`` and ``out[a, b] = coord[b] - coord[a]``.
    Entries lie in ``[-(window_size - 1), window_size - 1]``.
    """
    axis = torch.arange(window_size)
    # Same (row, col) order as a nested "for row: for col:" loop.
    coords = torch.cartesian_prod(axis, axis)
    return coords.unsqueeze(0) - coords.unsqueeze(1)

window_size = 3
offset = window_size - 1
relative_distances_1 = get_relative_distances(window_size)
relative_distances_2 = relative_distances_1 + offset
pos_embedding = torch.randn(2 * window_size - 1, 2 * window_size - 1)
# Raw distances can be negative, and negative indices wrap around the table.
# Rolling the table by `offset` and indexing with the shifted (non-negative)
# distances must select exactly the same entries, so every element prints True.
lookup_raw = pos_embedding[relative_distances_1[:, :, 0], relative_distances_1[:, :, 1]]
rolled_table = pos_embedding.roll((offset, offset), dims=(0, 1))
lookup_shifted = rolled_table[relative_distances_2[:, :, 0], relative_distances_2[:, :, 1]]
print(lookup_raw == lookup_shifted,
      '\n', relative_distances_2[:, :, 0], '\n', relative_distances_2[:, :, 1])

tensor([[True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True]]) 
 tensor([[2, 2, 2, 3, 3, 3, 4, 4, 4],
        [2, 2, 2, 3, 3, 3, 4, 4, 4],
        [2, 2, 2, 3, 3, 3, 4, 4, 4],
        [1, 1, 1, 2, 2, 2, 3, 3, 3],
        [1, 1, 1, 2, 2, 2, 3, 3, 3],
        [1, 1, 1, 2, 2, 2, 3, 3, 3],
        [0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 0, 0, 1, 1, 1, 2, 2, 2]]) tensor([[2, 3, 4, 2, 3, 4, 2, 3, 4],
        [1, 2, 3, 1, 2, 3, 1, 2, 3],
        [0, 1

# Add & Norm

In [None]:
class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    NOTE(review): appears unused in this notebook (SwinBlock inlines the same pattern).
    """

    def __init__(self,
                 in_features: int,
                 dropout: float) -> None:
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(in_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                x: torch.Tensor,
                sublayer: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
        """Apply ``sublayer`` to the layer-normalised ``x`` and add the (dropped-out) result back onto ``x``.

        Args:
            x: input tensor whose last dimension is ``in_features``.
            sublayer: tensor -> tensor callable whose output has the same shape as ``x``.
        """
        return x + self.dropout(sublayer(self.norm(x)))

# Positionwise FFN

In [3]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Acts on the last dimension only, so any leading shape (e.g. ``[B, H, W]``) is preserved.
    """

    def __init__(self,
                 in_features: int,
                 hidden_features: int,
                 out_features: int,
                 dropout: float=0.1) -> None:
        super().__init__()
        # Registration order (linear1 before linear2) fixes init RNG order and state_dict key order.
        self.linear1 = nn.Linear(in_features, hidden_features)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_features, out_features)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        hidden = self.linear1(x)
        hidden = self.dropout1(self.gelu(hidden))
        projected = self.linear2(hidden)
        return self.dropout2(projected)

# Window Attention

In [4]:
class CyclicShift(nn.Module):
    """Cyclically roll a ``[B, H, W, C]`` tensor by ``displacement`` along both H and W.

    A negative displacement rolls towards the top-left; applying ``CyclicShift(-d)``
    followed by ``CyclicShift(d)`` restores the original tensor.
    """

    def __init__(self, displacement: int):
        super().__init__()
        self.displacement = displacement

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        step = self.displacement
        spatial_dims = (1, 2)
        return x.roll(shifts=(step, step), dims=spatial_dims)

In [27]:
def create_mask(window_size: int, displacement: int, upper_lower: bool=False, left_right: bool=False) -> torch.Tensor:
    """Token-pair mask for one shifted window, over tokens flattened row-major.

    Returns a uint8 tensor of shape ``[window_size**2, window_size**2]`` where
    entry ``[i, j]`` is 1 if token ``i`` may attend to token ``j`` and 0 otherwise.
    For ``0 < displacement < window_size``:

    - ``upper_lower`` zeroes pairs with one token in the last ``displacement`` rows
      of the window and the other token outside them.
    - ``left_right`` does the same for the last ``displacement`` columns.

    With neither flag set the mask is all ones.
    """
    tokens = window_size ** 2
    allowed = torch.ones(tokens, tokens, dtype=torch.uint8)
    if upper_lower:
        # Flattened index where the last `displacement` rows of the window begin (as a negative offset).
        split = -displacement * window_size
        allowed[split:, :split] = 0
        allowed[:split, split:] = 0
    if left_right:
        # 4-D view [row_i, col_i, row_j, col_j] sharing storage with `allowed`.
        grid = allowed.view(window_size, window_size, window_size, window_size)
        grid[:, -displacement:, :, :-displacement] = 0
        grid[:, :-displacement, :, -displacement:] = 0
    return allowed

In [5]:
class WindowAttention(nn.Module):
    """Multi-head self-attention computed independently inside non-overlapping square windows.

    Input and output are channels-last ``[B, H, W, C]`` with ``C == in_features``; ``H`` and ``W``
    must be multiples of ``window_size``. When ``shift_size`` is non-zero, the map is cyclically
    rolled by ``-shift_size`` before attention and rolled back afterwards. In that case ``mask``
    must be given as ``(upper_lower_mask, left_right_mask)``, each ``[window_size**2, window_size**2]``.
    Zeros in them block attention in the bottom row and right-most column of windows respectively.
    """

    def __init__(self,
                 in_features: int,
                 window_size: int,
                 number_of_heads: int,
                 shift_size: int,
                 dropout_attention: float = 0.1,
                 dropout_projection: float = 0.1) -> None:
        super(WindowAttention, self).__init__()
        assert in_features % number_of_heads == 0, 'The num of input features (in_features) must be divisible by the number of heads'
        self.in_features = in_features
        self.window_size = window_size
        self.number_of_heads = number_of_heads
        self.shift_size = shift_size
        self.d_k = in_features // number_of_heads

        if self.shift_size:
            self.cyclic_shift = CyclicShift(-shift_size)
            self.cyclic_back_shift = CyclicShift(shift_size)

        # Query, key, value and output projections (indices 0..3).
        # Each Linear is constructed separately rather than deep-copied from one instance,
        # so the four projections get independent random initialisations.
        self.linears = nn.ModuleList([nn.Linear(in_features, in_features) for _ in range(4)])

        # TODO: relative position bias (see get_relative_distances) is not implemented yet.

        self.dropout_attention = nn.Dropout(dropout_attention)
        self.dropout_projection = nn.Dropout(dropout_projection)

    def forward(self,
                x: torch.Tensor,
                mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor:
        """Attend within windows.

        Args:
            x: ``[B, H, W, C]`` input.
            mask: ``(upper_lower_mask, left_right_mask)``; required when ``shift_size`` is non-zero.

        Returns:
            ``[B, H, W, C]`` output.

        Raises:
            ValueError: if a shifted block gets no mask, or if H/W are not multiples of ``window_size``.
        """
        if self.shift_size:
            if mask is None:
                raise ValueError('shifted WindowAttention requires mask=(upper_lower_mask, left_right_mask)')
            x = self.cyclic_shift(x)
        b, h, w, _ = x.shape
        ws = self.window_size
        if h % ws or w % ws:
            raise ValueError(f'spatial size ({h}, {w}) must be divisible by window_size={ws}')
        nw_h, nw_w = h // ws, w // ws

        def to_windows(t: torch.Tensor) -> torch.Tensor:
            # [b, h, w, C] -> [b, heads, nw_h * nw_w, ws * ws, d_k]; windows are ordered row-major.
            return (t.view(b, nw_h, ws, nw_w, ws, self.number_of_heads, self.d_k)
                     .permute(0, 5, 1, 3, 2, 4, 6)
                     .contiguous()
                     .view(b, self.number_of_heads, nw_h * nw_w, ws * ws, self.d_k))

        query, key, value = (to_windows(lin(x)) for lin in self.linears[:3])
        # scores: [b, heads, nw_h * nw_w, ws * ws, ws * ws]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)

        if self.shift_size:
            upper_lower_mask, left_right_mask = mask
            # Most negative representable value: softmax weight becomes 0, and unlike -1e9
            # it cannot overflow in reduced precision (fp16/bf16).
            fill = torch.finfo(scores.dtype).min
            # Bottom row of windows = the last nw_w windows in row-major order.
            scores[:, :, -nw_w:].masked_fill_(upper_lower_mask == 0, fill)
            # Right-most column of windows = every nw_w-th window starting at nw_w - 1.
            scores[:, :, nw_w - 1::nw_w].masked_fill_(left_right_mask == 0, fill)
        p_attn = self.dropout_attention(F.softmax(scores, dim=-1))
        # out: [b, heads, nw_h * nw_w, ws * ws, d_k]
        out = torch.matmul(p_attn, value)

        # Undo the window partition: back to [b, h, w, heads * d_k].
        out = (out.view(b, self.number_of_heads, nw_h, nw_w, ws, ws, self.d_k)
                  .permute(0, 2, 4, 3, 5, 1, 6)
                  .contiguous()
                  .view(b, h, w, -1))
        out = self.dropout_projection(self.linears[-1](out))
        if self.shift_size:
            out = self.cyclic_back_shift(out)
        return out

# Patch Merging & Swin Block & Stage

In [6]:
class SwinBlock(nn.Module):
    """Pre-norm transformer block on ``[B, H, W, C]`` maps using (optionally shifted) window attention.

    Computes ``x + dropout1(attention(norm1(x), mask))`` and then
    ``x + dropout2(feed_forward(norm2(x)))``.

    NOTE(review): WindowAttention already applies ``dropout`` to its projection output, and
    FeedForward already ends with a dropout. Each residual branch here therefore goes through
    dropout twice; confirm this is intended.
    """

    def __init__(self,
                 in_features: int,
                 window_size: int,
                 number_of_heads: int,
                 shift_size: int = 0,
                 ffn_feature_ratio: int = 4,
                 dropout_attention: float = 0.1,
                 dropout: float = 0.1,
                 ) -> None:
        super().__init__()
        hidden_features = in_features * ffn_feature_ratio
        # Attention is built before the FFN so parameter-init RNG order is unchanged.
        self.norm1 = nn.LayerNorm(in_features)
        self.attention = WindowAttention(in_features, window_size, number_of_heads, shift_size,
                                         dropout_attention=dropout_attention,
                                         dropout_projection=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(in_features)
        self.feed_forward = FeedForward(in_features, hidden_features, in_features, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self,
                x: torch.Tensor,
                mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor:
        attended = self.attention(self.norm1(x), mask)
        x = x + self.dropout1(attended)
        transformed = self.feed_forward(self.norm2(x))
        return x + self.dropout2(transformed)

In [7]:
class PatchMerging(nn.Module):
    """Downsample by gathering non-overlapping d x d patches and projecting each with a Linear.

    Input ``[B, C, H, W]`` (channels first) -> output ``[B, H // d, W // d, out_channels]``
    (channels last), where ``d = downscaling_factor``.
    """

    def __init__(self, in_channels: int, out_channels: int, downscaling_factor: int):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, _, height, width = x.shape
        factor = self.downscaling_factor
        out_h, out_w = height // factor, width // factor
        # Unfold gives [B, C * d * d, out_h * out_w]; fold the patch axis back into a grid.
        columns = self.patch_merge(x)
        grid = columns.view(batch, -1, out_h, out_w)
        # Move the patch-feature axis last so the Linear projects it.
        return self.linear(grid.permute(0, 2, 3, 1))

In [8]:

class StageModule(nn.Module):
    """One Swin stage: patch merging followed by ``layers`` Swin blocks.

    Regular-window and shifted-window blocks alternate, so ``layers`` must be even.
    Input is ``[B, in_channels, H, W]``. Output is ``[B, in_features, H // downscaling_factor,
    W // downscaling_factor]`` (channels first).
    """

    def __init__(self,
                 in_channels: int,
                 in_features: int,
                 window_size: int,
                 number_of_heads: int,
                 ffn_feature_ratio: int,
                 layers: int,
                 downscaling_factor: int,
                 dropout_attention: float=0.1,
                 dropout: float=0.1):
        super(StageModule, self).__init__()
        assert layers % 2 == 0, 'number of layers must be divisible by 2 for regular and shifted block'
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.patch_merging = PatchMerging(in_channels, in_features, downscaling_factor)
        self.layers_regular = nn.ModuleList([SwinBlock(in_features, window_size, number_of_heads, 0, ffn_feature_ratio, dropout_attention, dropout) for _ in range(layers // 2)])
        # Shifted blocks use the same shift the masks below are built for.
        self.layers_shifted = nn.ModuleList([SwinBlock(in_features, window_size, number_of_heads, self.shift_size, ffn_feature_ratio, dropout_attention, dropout) for _ in range(layers // 2)])
        self.__create_mask()

    def __create_mask(self) -> None:
        """Build and register the two masks consumed by the shifted-window blocks.

        Both are uint8 ``[window_size**2, window_size**2]`` tables over token pairs of one
        window, flattened row-major; 0 marks a pair that must not attend. ``left_right_mask``
        separates the last ``shift_size`` columns from the rest. ``upper_lower_mask`` separates
        the last ``shift_size`` rows from the rest.
        """
        ws, shift = self.window_size, self.shift_size
        tokens = ws ** 2
        # create left_right_mask via a [row_i, col_i, row_j, col_j] layout
        left_right_mask: torch.Tensor = torch.ones(ws, ws, ws, ws, dtype=torch.uint8)
        left_right_mask[:, -shift:, :, :-shift] = 0
        left_right_mask[:, :-shift, :, -shift:] = 0
        # Flatten using this stage's own window size (must not depend on any notebook-global).
        left_right_mask = left_right_mask.view(tokens, tokens)
        # create upper_lower_mask directly in flattened layout
        upper_lower_mask: torch.Tensor = torch.ones(tokens, tokens, dtype=torch.uint8)
        upper_lower_mask[-shift * ws:, :-shift * ws] = 0
        upper_lower_mask[:-shift * ws, -shift * ws:] = 0
        self.register_buffer('left_right_mask', left_right_mask)
        self.register_buffer('upper_lower_mask', upper_lower_mask)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        # shape of input: [B, C, H, W]
        x = self.patch_merging(x)  # -> [B, H', W', C'] (channels last)
        for layer_regular, layer_shifted in zip(self.layers_regular, self.layers_shifted):
            x = layer_regular(x)
            x = layer_shifted(x, (self.upper_lower_mask, self.left_right_mask))
        # shape of output: [B, C', H', W']
        return x.permute(0, 3, 1, 2)

# Swin Transformer

In [9]:
class SwinTransformer(nn.Module):
    """Four-stage Swin Transformer image classifier.

    Stage ``k`` (1-based) downsamples by ``downscaling_factors[k-1]`` and works at width
    ``in_features * 2**(k-1)`` with ``heads[k-1]`` heads and ``num_layers[k-1]`` blocks.
    The last feature map is averaged over H and W, layer-normed and linearly classified.
    Input ``[B, in_channels, H, W]`` -> logits ``[B, num_classes]``.
    """

    def __init__(self,
                 *,
                 in_channels: int = 3,
                 num_classes: int=10,
                 in_features: int = 96,
                 window_size: int=7,
                 heads: Tuple[int, int, int, int] = (3, 6, 12, 24),
                 ffn_feature_ratio: int = 4,
                 num_layers: Tuple[int, int, int, int] = (2, 2, 6, 2),
                 downscaling_factors: Tuple[int, int, int, int] = (4, 2, 2, 2),
                 dropout_attention: float=0.1,
                 dropout: float=0.1) -> None:
        super().__init__()
        widths = [in_features * multiplier for multiplier in (1, 2, 4, 8)]
        stage_inputs = [in_channels] + widths[:-1]
        for idx, (c_in, c_out) in enumerate(zip(stage_inputs, widths)):
            # Registered as stage1..stage4, in order, so state_dict keys and init RNG order are preserved.
            setattr(self, f"stage{idx + 1}",
                    StageModule(c_in, c_out, window_size, heads[idx], ffn_feature_ratio,
                                num_layers[idx], downscaling_factors[idx], dropout_attention, dropout))

        head_width = widths[-1]
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(head_width),
            nn.Linear(head_width, num_classes)
        )

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            x = stage(x)
        pooled = x.mean(dim=[2, 3])  # global average over H and W
        return self.mlp_head(pooled)

In [33]:
def _build_swin(d_model: int,
                num_layers: Tuple[int, int, int, int],
                heads: Tuple[int, int, int, int],
                **kwargs: Any) -> SwinTransformer:
    """Shared constructor for the presets below.

    ``d_model`` is the stage-1 embedding width, which SwinTransformer names ``in_features``.
    Extra keyword arguments (e.g. ``in_channels``, ``num_classes``) are forwarded unchanged.
    Patch size 4 and window 7 come from SwinTransformer's defaults.
    """
    return SwinTransformer(in_features=d_model, heads=heads, num_layers=num_layers, **kwargs)


def swin_tiny_patch4_window7_224(d_model: int=96, num_layers: Tuple[int, int, int, int]=(2, 2, 6, 2), heads: Tuple[int, int, int, int]=(3, 6, 12, 24), **kwargs: Any) -> SwinTransformer:
    """Swin-T preset."""
    return _build_swin(d_model, num_layers, heads, **kwargs)


def swin_small_patch4_window7_224(d_model: int=96, num_layers: Tuple[int, int, int, int]=(2, 2, 18, 2), heads: Tuple[int, int, int, int]=(3, 6, 12, 24), **kwargs: Any) -> SwinTransformer:
    """Swin-S preset."""
    return _build_swin(d_model, num_layers, heads, **kwargs)


def swin_base_patch4_window7_224(d_model: int=128, num_layers: Tuple[int, int, int, int]=(2, 2, 18, 2), heads: Tuple[int, int, int, int]=(4, 8, 16, 32), **kwargs: Any) -> SwinTransformer:
    """Swin-B preset."""
    return _build_swin(d_model, num_layers, heads, **kwargs)


def swin_large_patch4_window7_224(d_model: int=192, num_layers: Tuple[int, int, int, int]=(2, 2, 18, 2), heads: Tuple[int, int, int, int]=(6, 12, 24, 48), **kwargs: Any) -> SwinTransformer:
    """Swin-L preset."""
    return _build_swin(d_model, num_layers, heads, **kwargs)

# Train

In [10]:
batch_size = 128
# Upsampled so that the default downscaling factors (4, 2, 2, 2) give 24 -> 12 -> 6 -> 3 pixel
# feature maps, each divisible by the window size of 3 configured in the next cell.
img_size = 96
transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])
data_train, data_val = (
    datasets.FashionMNIST(root="../data", train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)
loader_train, loader_val = (
    DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    for dataset, shuffle in ((data_train, True), (data_val, False))
)

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 0
torch.manual_seed(SEED)  # fix the RNG so weight initialisation is reproducible
d_model = 96
num_layers = (2, 2, 4, 2)
heads = (3, 6, 12, 24)
window_size = 3
dropout = 0.1
num_classes = 10
in_channels = 1
# Pass every setting explicitly so that editing the values above actually changes the model.
model = SwinTransformer(in_channels=in_channels,
                        num_classes=num_classes,
                        in_features=d_model,
                        window_size=window_size,
                        heads=heads,
                        num_layers=num_layers,
                        dropout=dropout).to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
loss = F.cross_entropy


In [12]:
max_epochs = 50


def run_epoch(model: nn.Module,
              loader: DataLoader,
              loss_fn: Callable[..., torch.Tensor],
              device: torch.device,
              optimizer: Optional[torch.optim.Optimizer] = None) -> Tuple[float, float]:
    """Make one pass over ``loader`` and return ``(mean loss per sample, accuracy)``.

    When ``optimizer`` is given the model is put in train mode and updated after every batch.
    Otherwise it is put in eval mode and run with gradients disabled.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = total_correct = total_count = 0
    with torch.set_grad_enabled(training):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            # Summed loss so the per-sample average below is exact even for a smaller last batch.
            # NOTE: backward also uses the sum, so gradient magnitude scales with batch size.
            batch_loss = loss_fn(y_pred, y, reduction="sum")
            if training:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
            total_loss += batch_loss.item()
            total_count += y.size(0)
            total_correct += (y_pred.argmax(1) == y).sum().item()
    count = max(total_count, 1)  # guard against an empty loader
    return total_loss / count, total_correct / count


for epoch in range(max_epochs):
    train_loss, train_acc = run_epoch(model, loader_train, loss, device, optim)
    val_loss, val_acc = run_epoch(model, loader_val, loss, device)
    print(f"Epoch: {epoch + 1:03d}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

Epoch: 001, Train Loss: 0.6883, Train Acc: 0.7469, Val Loss: 0.4597, Val Acc: 0.8376
Epoch: 002, Train Loss: 0.3808, Train Acc: 0.8591, Val Loss: 0.3746, Val Acc: 0.8605
Epoch: 003, Train Loss: 0.3269, Train Acc: 0.8798, Val Loss: 0.3281, Val Acc: 0.8779
Epoch: 004, Train Loss: 0.3001, Train Acc: 0.8884, Val Loss: 0.3089, Val Acc: 0.8848
Epoch: 005, Train Loss: 0.2753, Train Acc: 0.8971, Val Loss: 0.2954, Val Acc: 0.8878
Epoch: 006, Train Loss: 0.2586, Train Acc: 0.9013, Val Loss: 0.2887, Val Acc: 0.8940
Epoch: 007, Train Loss: 0.2430, Train Acc: 0.9072, Val Loss: 0.2869, Val Acc: 0.8951
Epoch: 008, Train Loss: 0.2273, Train Acc: 0.9146, Val Loss: 0.2832, Val Acc: 0.8972
Epoch: 009, Train Loss: 0.2138, Train Acc: 0.9185, Val Loss: 0.2726, Val Acc: 0.8999
Epoch: 010, Train Loss: 0.2019, Train Acc: 0.9235, Val Loss: 0.2803, Val Acc: 0.8990
Epoch: 011, Train Loss: 0.1923, Train Acc: 0.9265, Val Loss: 0.2577, Val Acc: 0.9093
Epoch: 012, Train Loss: 0.1799, Train Acc: 0.9314, Val Loss: 0.26

KeyboardInterrupt: 

In [13]:
# Persist only the learned parameters and registered buffers (not the module code).
checkpoint_path = "model1.pth"
torch.save(model.state_dict(), checkpoint_path)