In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### ConvMixer

In [2]:
class ResBlock(nn.Module):
    """Wrap a callable with an identity skip connection.

    The forward pass returns ``fn(x) + x``. The output of ``fn`` must
    therefore be addable to its input (same shape, or broadcast-compatible).

    Args:
        fn: The wrapped transformation. When it is an ``nn.Module`` it is
            registered as a submodule, so its parameters appear under the
            ``fn.`` prefix in ``state_dict``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        """Return ``fn(x)`` plus the unmodified input ``x``."""
        transformed = self.fn(x)
        skip = x
        return transformed + skip

In [3]:
class ConvMixerLayer(nn.Module):
    """One ConvMixer block: residual depthwise conv, then a pointwise conv.

    ``self.model`` is a single ``nn.Sequential`` laid out as:

    0. ``ResBlock`` around: depthwise ``Conv2d`` (``groups=hidden_size``,
       ``padding="same"``) -> ``GELU`` -> ``BatchNorm2d``.
    1. Pointwise (1x1) ``Conv2d``.
    2. ``GELU``.
    3. ``BatchNorm2d``.

    Args:
        hidden_size: Number of input and output channels.
        kernel_size: Spatial kernel size of the depthwise convolution.
    """

    def __init__(self, hidden_size, kernel_size):
        super().__init__()

        self.hidden_size = hidden_size
        self.kernel_size = kernel_size

        channels = self.hidden_size

        # Spatial mixing: one filter per channel (groups == channels).
        # "same" padding keeps H and W unchanged so the residual add in
        # ResBlock lines up with the input.
        spatial_mixing = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                self.kernel_size,
                groups=channels,
                padding="same",
            ),
            nn.GELU(),
            nn.BatchNorm2d(channels),
        )

        # Channel mixing: 1x1 convolution across channels, no skip connection.
        channel_mixing = [
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(channels),
        ]

        self.model = nn.Sequential(ResBlock(spatial_mixing), *channel_mixing)

    def forward(self, x):
        """Apply the block.

        Given ``x`` with ``hidden_size`` channels, the output has the same
        shape as ``x``.
        """
        return self.model(x)

In [None]:
class ConvMixer(nn.Module):
    """ConvMixer image classifier configured from a ``cfg`` object.

    Pipeline: patch embedding (strided ``Conv2d``) -> ``GELU`` ->
    ``BatchNorm2d`` -> ``cfg.num_layers`` x ``ConvMixerLayer`` ->
    global average pool -> flatten -> linear classifier.

    Args:
        cfg: Object exposing the attributes
            ``num_channels`` (channels of the input images),
            ``hidden_size`` (channels of the patch embedding),
            ``patch_size`` (used as both kernel size and stride of the
            embedding convolution),
            ``kernel_size`` (depthwise kernel size in every mixer layer),
            ``num_layers`` (number of mixer layers) and
            ``n_classes`` (number of output logits).
    """

    def __init__(self, cfg):
        super().__init__()

        self.num_channels = cfg.num_channels
        self.hidden_size = cfg.hidden_size
        self.patch_size = cfg.patch_size
        # Needed by the encoder below. It must be set before building the
        # ConvMixerLayer list, otherwise construction raises AttributeError.
        self.kernel_size = cfg.kernel_size
        self.n_classes = cfg.n_classes

        # Patch embedding: kernel == stride, so patches do not overlap.
        # torch.nn has no ``Conv``; the 2-D convolution is ``Conv2d``.
        self.projection = nn.Conv2d(
            self.num_channels,
            self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
        self.prenorm = nn.Sequential(
            nn.GELU(),
            nn.BatchNorm2d(self.hidden_size),
        )
        self.encoder = nn.ModuleList(
            [
                ConvMixerLayer(self.hidden_size, self.kernel_size)
                for _ in range(cfg.num_layers)
            ]
        )
        # (B, hidden, h, w) -> (B, hidden, 1, 1) -> (B, hidden) -> (B, n_classes)
        self.pooler = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(self.hidden_size, self.n_classes),
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Batched image tensor of shape ``(B, num_channels, H, W)``.

        Returns:
            Logits of shape ``(B, n_classes)``.
        """
        x = self.projection(x)
        x = self.prenorm(x)
        for layer in self.encoder:
            x = layer(x)
        return self.pooler(x)