![MetaFormer](https://miro.medium.com/v2/resize:fit:1400/format:webp/1*djedJeSMHNiRDE8FDq4Y-Q.png)

In [None]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import torchvision.transforms as T

from utils import seed_everything, batch_plot, train_part_challenge, check_accuracy_part34


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Project helper from utils; presumably fixes the RNG seeds for reproducibility — confirm in utils.
seed_everything()

# IPython magic: load the TensorBoard notebook extension.
%load_ext tensorboard

In [None]:
class PatchEmbedding(nn.Module):
    """Project non-overlapping image patches to ``out_channels`` feature maps.

    Implemented as a single convolution whose kernel size and stride both equal
    ``patch_size`` (no padding), so each output position sees exactly one patch.
    """

    def __init__(self, patch_size, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size, padding=0)

    def forward(self, x):
        patches = self.conv(x)
        return patches


class MLPBlock(nn.Module):
    """Channel MLP made of two 1x1 convolutions with a GELU in between.

    The hidden width is ``int(expansion_factor * channels)``. One dropout module
    is shared and applied both after the activation and after the output
    projection.
    """

    def __init__(self, channels, expansion_factor, dropout_rate=0.0):
        super().__init__()
        hidden_channels = int(expansion_factor * channels)
        self.fc1 = nn.Conv2d(channels, hidden_channels, kernel_size=1)
        self.gelu = nn.GELU()
        self.fc2 = nn.Conv2d(hidden_channels, channels, kernel_size=1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        hidden = self.dropout(self.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))


class ChannelLayerNorm(nn.LayerNorm):
    """LayerNorm over the channel axis of a channels-first 4-D tensor.

    The additive bias created by ``nn.LayerNorm`` is always dropped. The input
    is moved to channels-last for the normalisation and moved back afterwards.
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(num_channels, **kwargs)
        # Discard the bias parameter registered by the parent constructor.
        self.bias = None

    def forward(self, x):
        # (batch, channels, height, width) -> (batch, height, width, channels)
        channels_last = x.permute(0, 2, 3, 1)
        normalized = super().forward(channels_last)
        # Restore the original (batch, channels, height, width) layout.
        return normalized.permute(0, 3, 1, 2)


class ConvNormAct(nn.Module):
    """Optional channel dropout followed by convolution, normalisation and activation.

    ``norm`` is a factory called with ``out_channels`` and ``activation`` a
    factory called without arguments; passing ``None`` for either replaces that
    step with an identity. ``nn.Dropout2d`` is only created when
    ``dropout_rate`` is positive.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=0,
        stride=1,
        groups=1,
        bias=False,
        dropout_rate=0.0,
        norm=nn.BatchNorm2d,
        activation=nn.GELU,
    ):
        super().__init__()
        # Sub-modules are registered in the same order forward() applies them.
        if dropout_rate > 0.0:
            self.dropout = nn.Dropout2d(dropout_rate)
        else:
            self.dropout = nn.Identity()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.norm = nn.Identity() if norm is None else norm(out_channels)
        self.act = nn.Identity() if activation is None else activation()

    def forward(self, x):
        for layer in (self.dropout, self.conv, self.norm, self.act):
            x = layer(x)
        return x


# MobileNetV2: https://arxiv.org/abs/1801.04381
class DepthSeparableConv(nn.Module):
    """Point-wise expansion, depth-wise convolution, then point-wise compression.

    The expansion step widens the input to ``int(expansion_factor * in_channels)``
    channels and is skipped (identity) when that width equals ``in_channels``.
    The depth-wise step uses one group per expanded channel. The compression
    step maps to ``out_channels`` with normalisation but without activation or
    dropout.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        expansion_factor=2,
        kernel_size=3,
        padding=1,
        stride=1,
        bias=False,
        dropout_rate=0.0,
        norm=nn.BatchNorm2d,
        activation=nn.GELU,
    ):
        super().__init__()
        expanded_channels = int(expansion_factor * in_channels)
        # Options common to all three ConvNormAct stages.
        shared = dict(bias=bias, norm=norm)

        if in_channels == expanded_channels:
            self.point_wise_expander = nn.Identity()
        else:
            self.point_wise_expander = ConvNormAct(
                in_channels,
                expanded_channels,
                kernel_size=1,
                activation=activation,
                dropout_rate=dropout_rate,
                **shared,
            )

        self.depth_with_conv = ConvNormAct(
            expanded_channels,
            expanded_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            groups=expanded_channels,  # one filter group per channel
            activation=activation,
            dropout_rate=dropout_rate,
            **shared,
        )

        self.point_wise_compressor = ConvNormAct(
            expanded_channels,
            out_channels,
            kernel_size=1,
            activation=None,
            **shared,
        )

    def forward(self, x):
        expanded = self.point_wise_expander(x)
        filtered = self.depth_with_conv(expanded)
        return self.point_wise_compressor(filtered)


class MetaFormerBlock(nn.Module):
    """Two residual sub-layers: a convolutional token mixer, then a channel MLP.

    Each sub-layer normalises its input with ``norm(channels)`` first and adds
    its result back onto the un-normalised input. The token mixer is a
    ``DepthSeparableConv`` padded with ``kernel_size // 2``.
    """

    def __init__(self, channels, expansion_factor, dropout_rate=0.0, kernel_size=3, norm=nn.BatchNorm2d):
        super().__init__()
        mixer_padding = kernel_size // 2
        self.norm1 = norm(channels)
        self.token_mixer = DepthSeparableConv(
            channels,
            channels,
            expansion_factor=expansion_factor,
            kernel_size=kernel_size,
            padding=mixer_padding,
            dropout_rate=dropout_rate,
            norm=norm,
        )
        self.norm2 = norm(channels)
        self.mlp2 = MLPBlock(channels, expansion_factor, dropout_rate=dropout_rate)

    def forward(self, x):
        # x is expected as (batch_size, channels, image_size, image_size).
        mixed = self.token_mixer(self.norm1(x))
        x = x + mixed
        projected = self.mlp2(self.norm2(x))
        return x + projected


class MetaFormerStage(nn.Module):
    """An optional stride-2 down-sampling convolution followed by ``num_blocks`` blocks.

    With ``down_sample=True`` a 3x3, stride-2 convolution maps ``in_channels`` to
    ``out_channels``; otherwise the input is passed through unchanged. Every
    ``MetaFormerBlock`` in the stage works on ``out_channels`` channels.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        expansion_factor,
        num_blocks,
        dropout_rate=0.0,
        kernel_size=3,
        down_sample=True,
        norm=nn.BatchNorm2d,
    ):
        super().__init__()
        if down_sample:
            self.down_sampler = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        else:
            self.down_sampler = nn.Identity()

        # Every block in the stage is built from the same configuration.
        block_config = dict(
            channels=out_channels,
            expansion_factor=expansion_factor,
            dropout_rate=dropout_rate,
            kernel_size=kernel_size,
            norm=norm,
        )
        self.mixer_blocks = nn.ModuleList([MetaFormerBlock(**block_config) for _ in range(num_blocks)])

    def forward(self, x):
        features = self.down_sampler(x)
        for block in self.mixer_blocks:
            features = block(features)
        return features


# https://arxiv.org/abs/2111.11418
# https://arxiv.org/abs/2210.13452
class MetaFormer(nn.Module):
    """Convolutional MetaFormer-style image classifier.

    Pipeline: patch embedding -> ``num_layers`` ``MetaFormerStage`` stacks ->
    normalisation -> mean over the two spatial axes -> optional linear
    projection back to ``hidden`` -> linear classifier.

    Stage ``i`` works on ``hidden * f**i`` channels, where ``f`` is 2 when
    ``down_sample`` is true and 1 otherwise. The first stage never down-samples.

    Args:
        patch_size: Kernel size and stride of the patch-embedding convolution.
        hidden: Number of channels produced by the patch embedding.
        expansion_factors: Per-stage expansion factor, indexed by stage.
            Defaults to ``[2] * num_layers``.
        dropout_rates: Per-stage dropout rate, indexed by stage.
            Defaults to ``[0.0] * num_layers``.
        num_blocks: Per-stage number of blocks, indexed by stage.
            Defaults to ``[2] * num_layers``.
        num_classes: Size of the classifier output.
        kernel_size: Kernel size forwarded to every stage.
        num_layers: Number of stages.
        down_sample: Whether stages after the first halve the resolution with
            a stride-2 convolution and double the channel count.
        use_head: When true, a linear layer maps the pooled features back to
            ``hidden`` before the classifier; otherwise the classifier reads
            the pooled features directly.
        norm: Factory called with a channel count to build normalisation layers.
        in_channels: Number of channels of the input images. Defaults to 3,
            the value that used to be hard-coded.
    """

    def __init__(
        self,
        patch_size,
        hidden=64,
        expansion_factors=None,
        dropout_rates=None,
        num_blocks=None,
        num_classes=10,
        kernel_size=3,
        num_layers=2,
        down_sample=True,
        use_head=True,
        norm=nn.BatchNorm2d,
        in_channels=3,
    ):
        super().__init__()
        if expansion_factors is None:
            expansion_factors = [2] * num_layers
        if dropout_rates is None:
            dropout_rates = [0.0] * num_layers
        if num_blocks is None:
            num_blocks = [2] * num_layers
        self.patch_embedding = PatchEmbedding(patch_size, in_channels=in_channels, out_channels=hidden)

        # Channel multiplier applied at every down-sampling stage (2 or 1).
        down_sample_factor = 1 + int(down_sample)
        self.stacks = nn.ModuleList(
            [
                MetaFormerStage(
                    # Input width is the previous stage's output width; stage 0
                    # reads the patch embedding, hence the clamp at exponent 0.
                    in_channels=hidden * down_sample_factor ** (max(0, i - 1)),
                    out_channels=hidden * down_sample_factor**i,
                    expansion_factor=expansion_factors[i],
                    num_blocks=num_blocks[i],
                    dropout_rate=dropout_rates[i],
                    kernel_size=kernel_size,
                    # The first stage keeps the patch-embedding resolution.
                    down_sample=down_sample and i > 0,
                    norm=norm,
                )
                for i in range(num_layers)
            ]
        )

        # Width multiplier of the last stage relative to ``hidden``.
        expanded = down_sample_factor ** (num_layers - 1)
        final_channels = hidden * expanded
        classifier_inputs = hidden if use_head else final_channels
        self.norm = norm(final_channels)
        self.down = nn.Linear(final_channels, hidden) if use_head else nn.Identity()
        self.fc = nn.Linear(classifier_inputs, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        for stack in self.stacks:
            x = stack(x)
        x = self.norm(x)
        # Average over the two spatial axes, leaving (batch, channels).
        x = x.mean(dim=[2, 3])
        x = self.down(x)
        x = self.fc(x)
        return x

In [None]:
# CIFAR-10 pipeline: indices [0, NUM_TRAIN) of the training split are used for
# training and indices [NUM_TRAIN, 50000) of the same split for validation.
NUM_TRAIN = 49000
data_path = "../code/cs231n/datasets"

# Training images additionally get a random horizontal flip.
transform_train = T.Compose(
    [
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# Validation and test images are only converted and normalised.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

cifar10_train = dset.CIFAR10(root=data_path, train=True, transform=transform_train, download=True)
loader_train = DataLoader(
    cifar10_train,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),
    batch_size=64,
)

cifar10_val = dset.CIFAR10(root=data_path, train=True, transform=transform, download=True)
loader_val = DataLoader(
    cifar10_val,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)),
    batch_size=64,
)

cifar10_test = dset.CIFAR10(root=data_path, train=False, transform=transform, download=True)
loader_test = DataLoader(cifar10_test, batch_size=64)

# Un-normalised copy of the validation indices, used only for plotting.
cifra10_visualize = dset.CIFAR10(root=data_path, train=True, transform=T.ToTensor(), download=True)
loader_visualize = DataLoader(
    cifra10_visualize,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)),
    batch_size=16,
)

# Take one batch and hand it to the project plotting helper in channels-last layout.
images, labels = next(iter(loader_visualize))
batch_plot(images.permute(0, 2, 3, 1).numpy(), labels.numpy())

In [None]:
# Build the model, optimiser and learning-rate schedule, then train.
epoch = 10
log_dir = Path(data_path) / "runs"

# Two stages: 3 blocks at width 128, then 2 blocks at width 256 after down-sampling.
model = MetaFormer(
    patch_size=4,
    hidden=128,
    num_layers=2,
    num_blocks=[3, 2],
    expansion_factors=[3, 2],
    dropout_rates=[0.1, 0.2],
    kernel_size=3,
    down_sample=True,
    use_head=True,
    num_classes=10,
)

optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-5)
# Cosine annealing with a period of ``epoch`` scheduler steps.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch)

# Project helper from utils; its return value is later passed to torch.load as the checkpoint.
check_point_name = train_part_challenge(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=loader_train,
    valid_loader=loader_val,
    epochs=epoch,
    device=device,
    log_dir=log_dir,
)

In [None]:
# Reload the checkpoint written during training and evaluate it on the test set.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved from a CUDA run can still be loaded on a CPU-only machine.
checkpoint = torch.load(check_point_name, map_location=device)
print(
    f"best model with train accuracy: {checkpoint['train_accuracy']:.2f} and valid accuracy {checkpoint['valid_accuracy']:.2f}"
)
# Restore the checkpointed weights before measuring test accuracy.
model.load_state_dict(checkpoint["model_state_dict"])
check_accuracy_part34(loader_test, model, device=device)

<img src="images/training_process_conformer.png" alt="Training process">

```
best model with train accuracy: 90.23 and valid accuracy 84.80
Checking accuracy on test set
Got 8240 / 10000 correct (82.40)
```