![MLP-Mixer](https://production-media.paperswithcode.com/methods/Screen_Shot_2021-07-20_at_12.09.16_PM_aLnxO7E.png)

In [None]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import torchvision.transforms as T

from conv import generate_output_size


from utils import seed_everything, batch_plot, train_part_challenge, check_accuracy_part34


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed_everything()

%load_ext tensorboard

In [None]:
class PatchEmbedding(nn.Module):
    """Split a batch of images into non-overlapping square patches.

    Args:
        patch_size: Side length of each square patch. It is also the stride, so
            patches do not overlap.
        flatten: If True (default), return a tensor of shape
            ``(batch_size, patches, channel * patch_size**2)``. Patches are ordered
            row-major over the patch grid. Each patch vector is channel-major: all
            pixels of channel 0, then channel 1, and so on. If False, return the
            un-flattened view ``(batch_size, channel, rows, cols, patch_size, patch_size)``.

    Pixel rows/columns that do not fill a whole patch (image side not a multiple of
    ``patch_size``) are dropped.

    Raises:
        ValueError: If the input is not 4-D, or either spatial side is smaller than
            ``patch_size``.
    """

    def __init__(self, patch_size, flatten=True):
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten

    def forward(self, x):
        # x.shape = (batch_size, channel, image_height, image_width)
        _, _, image_height, image_width = x.shape
        p = self.patch_size
        if image_height < p or image_width < p:
            raise ValueError(
                f"image of size {image_height}x{image_width} is smaller than patch_size={p}"
            )
        # unfold(dim, size=p, step=p) slices each spatial axis into non-overlapping
        # windows. Unlike as_strided, it is bounds-checked and needs no hand-computed
        # output size.
        # Result: (batch_size, channel, rows, cols, p, p), where
        # element [..., i, j, u, v] == x[..., i*p + u, j*p + v].
        x = x.unfold(2, p, p).unfold(3, p, p)
        if self.flatten:
            # (B, C, rows*cols, p, p) -> (B, patches, C, p, p) -> (B, patches, C*p*p)
            return x.flatten(2, 3).transpose(1, 2).flatten(2)
        return x


class MLPBlock(nn.Module):
    """Two-layer feed-forward network applied over the last dimension of its input.

    Computes ``fc1 -> GELU -> dropout -> fc2 -> dropout``. The width goes
    ``channels -> expansion_factor * channels -> channels``.

    Args:
        channels: Size of the last input/output dimension.
        expansion_factor: Multiplier for the hidden width.
        dropout_rate: Probability used by the (shared) dropout module.
    """

    def __init__(self, channels, expansion_factor, dropout_rate=0.0):
        super().__init__()
        hidden_width = channels * expansion_factor
        # Registration order (fc1, gelu, fc2, dropout) is preserved so that seeded
        # parameter initialisation and state_dict keys stay the same.
        self.fc1 = nn.Linear(channels, hidden_width)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_width, channels)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        expanded = self.dropout(self.gelu(self.fc1(x)))
        return self.dropout(self.fc2(expanded))


class MixerBlock(nn.Module):
    """One Mixer layer: a token-mixing MLP followed by a channel-mixing MLP.

    Each sub-layer applies LayerNorm over the channel axis first, then its MLP, and
    adds the result back to its input (residual). Input and output shape:
    ``(batch_size, patches, channels)``.

    Args:
        patches: Number of patches (tokens); the input width of the token-mixing MLP.
        channels: Features per patch; the input width of the channel-mixing MLP.
        expansion_factor: Hidden-width multiplier shared by both MLPs.
        dropout_rate: Dropout probability used inside both MLPs.
    """

    def __init__(self, patches, channels, expansion_factor, dropout_rate=0.0):
        super().__init__()
        # Kept in the order norm1, mlp1, norm2, mlp2 so that seeded parameter
        # initialisation and state_dict keys are unchanged.
        self.norm1 = nn.LayerNorm(channels)
        self.mlp1 = MLPBlock(patches, expansion_factor, dropout_rate)
        self.norm2 = nn.LayerNorm(channels)
        self.mlp2 = MLPBlock(channels, expansion_factor, dropout_rate)

    def forward(self, x):
        # Token mixing: the MLP acts on the last axis, so move the patch axis there
        # and back.
        token_update = self.mlp1(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        x = x + token_update
        # Channel mixing: the MLP acts on each patch's feature vector independently.
        channel_update = self.mlp2(self.norm2(x))
        return x + channel_update


class MLPMixer(nn.Module):
    """MLP-Mixer image classifier.

    Pipeline:
        1. Cut the image into ``patch_size`` x ``patch_size`` patches.
        2. Linearly project each flattened patch to ``hidden`` features.
        3. Apply ``num_blocks`` MixerBlocks.
        4. LayerNorm, then average over patches.
        5. Classify with a linear head.

    Args:
        image_size: Side length of the (assumed square) input image. Used to derive
            the number of patches, ``(image_size // patch_size) ** 2``.
        patch_size: Side length of each square patch.
        hidden: Feature width per patch after projection.
        expansion_factor: Hidden-width multiplier for both token- and
            channel-mixing MLPs.
        num_blocks: Number of stacked MixerBlocks.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability inside every MLP.
        in_channels: Number of input image channels (default 3, e.g. RGB).
            Added as a trailing keyword so existing callers are unaffected.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        hidden,
        expansion_factor,
        num_blocks,
        num_classes=10,
        dropout_rate=0.0,
        in_channels=3,
    ):
        super().__init__()
        # Each flattened patch holds in_channels * patch_size**2 values.
        self.projection = nn.Linear(patch_size**2 * in_channels, hidden)
        self.patch_embedding = PatchEmbedding(patch_size)
        num_patches = (image_size // patch_size) ** 2
        self.mixer_blocks = nn.ModuleList(
            [
                MixerBlock(
                    patches=num_patches,
                    channels=hidden,
                    expansion_factor=expansion_factor,
                    dropout_rate=dropout_rate,
                )
                for _ in range(num_blocks)
            ]
        )
        self.norm = nn.LayerNorm(hidden)
        self.fc = nn.Linear(hidden, num_classes)

    def forward(self, x):
        # Assumes x is (batch, in_channels, image_size, image_size) — TODO confirm
        # callers never pass non-square images.
        x = self.projection(self.patch_embedding(x))  # (batch, patches, hidden)
        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)
        # Global average pooling over the patch axis after the final LayerNorm.
        x = self.norm(x).mean(dim=1)  # (batch, hidden)
        return self.fc(x)  # (batch, num_classes)

In [None]:
# Sanity check: a 32x32 RGB image cut into 4x4 patches gives (32 // 4) ** 2 = 64 patches.
# Each patch is flattened to 3 * 4 * 4 = 48 values.
batch_size, channel, image_size = 2, 3, 32
x = torch.randn(batch_size, channel, image_size, image_size)
patch_size = 4
num_patches = (image_size // patch_size) ** 2
assert PatchEmbedding(patch_size)(x).shape == (
    batch_size,
    num_patches,
    channel * patch_size**2,
)

In [None]:
NUM_TRAIN = 49000
data_path = "../code/cs231n/datasets"

# Per-channel normalisation statistics, defined once and shared by the train and
# eval transforms so the two cannot drift apart.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

transform_train = T.Compose(
    [T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(CIFAR10_MEAN, CIFAR10_STD)]
)

transform = T.Compose([T.ToTensor(), T.Normalize(CIFAR10_MEAN, CIFAR10_STD)])

# The CIFAR-10 training split is divided into two parts:
#   - indices [0, NUM_TRAIN) for training;
#   - the remaining indices for validation.
cifar10_train = dset.CIFAR10(data_path, train=True, download=True, transform=transform_train)
num_total = len(cifar10_train)
loader_train = DataLoader(cifar10_train, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

cifar10_val = dset.CIFAR10(data_path, train=True, download=True, transform=transform)
loader_val = DataLoader(
    cifar10_val, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, num_total))
)

cifar10_test = dset.CIFAR10(data_path, train=False, download=True, transform=transform)
loader_test = DataLoader(cifar10_test, batch_size=64)

# Un-normalised copy (ToTensor only) of the validation images, used for plotting.
cifar10_visualize = dset.CIFAR10(data_path, train=True, download=True, transform=T.ToTensor())
loader_visualize = DataLoader(
    cifar10_visualize, batch_size=16, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, num_total))
)

images, labels = next(iter(loader_visualize))
# Convert (N, C, H, W) to (N, H, W, C); batch_plot presumably expects channels-last
# arrays — TODO confirm.
batch_plot(images.permute(0, 2, 3, 1).numpy(), labels.numpy())

In [None]:
# Visualise how PatchEmbedding tiles each image into a grid of 8x8 patches.
patch_size = 8
# Read the channel count from the batch itself rather than relying on `channel`
# from the earlier sanity-check cell (hidden notebook state).
num_channels = images.shape[1]
# (batch, C, rows, cols, p, p) -> (batch, rows, cols, p, p, C): channels-last
# patches for plotting.
patches = PatchEmbedding(patch_size, flatten=False)(images).permute(0, 2, 3, 4, 5, 1)

for image_patches in patches[:4]:  # one plot per image, first four images only
    batch_plot(
        image_patches.reshape(-1, patch_size, patch_size, num_channels).numpy(),
        with_border=False,
        cmap="gray",
        tight_layout=None,
        wspace=0.01,
        hspace=0.01,
        imgsize=2,
        vmin=0,
        vmax=1,
    )

In [None]:
# Train an MLP-Mixer on CIFAR-10 with Adam and a cosine-annealed learning rate.
epoch = 10
# Passed to train_part_challenge; presumably where TensorBoard logs go — TODO confirm.
log_dir = Path(data_path) / "runs"

model = MLPMixer(
    image_size=32,
    patch_size=4,
    hidden=150,
    expansion_factor=2,
    num_blocks=4,
    num_classes=10,
    dropout_rate=0.25,
)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
# T_max equals the epoch count. This assumes train_part_challenge steps the
# scheduler once per epoch — TODO confirm.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch)

check_point_name = train_part_challenge(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=epoch,
    device=device,
    train_loader=loader_train,
    valid_loader=loader_val,
    log_dir=log_dir,
)

In [None]:
# Reload the best checkpoint returned by train_part_challenge and score it on the test split.
# map_location remaps tensors onto the current device, so a checkpoint saved on a
# GPU can still be loaded in a CPU-only session.
checkpoint = torch.load(check_point_name, map_location=device)
print(
    f"best model with train accuracy: {checkpoint['train_accuracy']:.2f} and valid accuracy {checkpoint['valid_accuracy']:.2f}"
)
model.load_state_dict(checkpoint["model_state_dict"])
check_accuracy_part34(loader_test, model, device=device)

<img src="images/training_process_mlp_mixer.png">

```
best model with train accuracy: 80.81 and valid accuracy 76.10
Checking accuracy on test set
Got 7456 / 10000 correct (74.56)
```