[![deep-learning-notes](https://github.com/semilleroCV/deep-learning-notes/raw/main/assets/banner-notebook.png)](https://github.com/semilleroCV/deep-learning-notes)

## MLP-Mixer

In [1]:
%%capture
#@title **Install required packages**

! pip install torchinfo einops

In [None]:
#@title **Importing libraries**

import torch #2.3.1+cu121
import torch.nn as nn 
import torchinfo #1.8.0

import einops #0.8.0
from einops.layers.torch import Rearrange

In [3]:
# Print the installed version of each core dependency for reproducibility.
# Note: `__version__` is a module attribute (not a method) and not every
# package defines it; torch, torchinfo and einops all do.
for _pkg_name, _pkg in (("torch", torch), ("torchinfo", torchinfo), ("einops", einops)):
    print(f"{_pkg_name} version: {_pkg.__version__}")
# Drop the loop variables so the notebook namespace is left as before.
del _pkg_name, _pkg

torch version: 2.3.1+cu121
torchinfo version: 1.8.0
einops version: 0.8.0


#### MLP-Mixer architecture code



In [6]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    A Conv2d whose kernel size equals its stride applies the same linear
    projection to every ``patch_size x patch_size`` patch; the resulting
    feature map is then flattened into a sequence of patch tokens.

    Args:
        in_channels: Number of channels of the input image.
        patch_size: Side length of each square patch (also the conv stride).
        embed_dim: Dimension of each patch embedding.
    """

    def __init__(self, in_channels: int, patch_size: int, embed_dim: int):
        super().__init__()

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        # (batch, embed_dim, H', W') -> (batch, H' * W', embed_dim): one token per patch.
        self.rearrange = Rearrange('b e h w -> b (h w) e')

    def forward(self, x):
        """Embed a batch of images into patch tokens.

        Args:
            x: Tensor of shape (batch, in_channels, H, W).

        Returns:
            Tensor of shape (batch, (H // patch_size) * (W // patch_size), embed_dim).
            Trailing rows/columns that do not fill a whole patch are dropped
            by the unpadded strided convolution.

        Raises:
            ValueError: If ``x`` is not 4-D.
        """
        # Explicit rank check (replaces an unused ``_, _, H, W = x.size()``
        # unpack that only raised ValueError implicitly for non-4-D input).
        if x.dim() != 4:
            raise ValueError(
                "expected a 4-D (batch, channels, height, width) tensor, "
                f"got shape {tuple(x.shape)}"
            )

        x = self.proj(x)
        x = self.rearrange(x)

        return x


class MLPBlock(nn.Module):
    """Two-layer perceptron applied to the last dimension: Linear -> GELU -> Linear.

    The output has the same last-dimension size as the input, which is what
    lets the Mixer layers wrap it in a residual connection.

    Args:
        input_dim: Size of the last dimension of the input (and of the output).
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()

        # Kept inside a Sequential named ``mlp_blk`` so parameter names
        # (mlp_blk.0.*, mlp_blk.2.*) stay stable for saved state dicts.
        layers = [
            nn.Linear(input_dim, hidden_dim),  # input_dim -> hidden_dim
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim),  # hidden_dim -> input_dim
        ]
        self.mlp_blk = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.mlp_blk(x)


class MixerBlock(nn.Module):
    """One Mixer layer: a token-mixing MLP followed by a channel-mixing MLP.

    Both sub-layers are pre-norm residual branches: LayerNorm is applied
    before each MLP and the branch output is added back to its input.

    Args:
        dim: Channel (embedding) dimension of each token.
        pix_per_patch: Length of the token axis the token-mixing MLP acts on.
            Despite the name, this is the number of patches (tokens) per
            image, not a pixel count.
        token_dim: Hidden width of the token-mixing MLP.
        channel_dim: Hidden width of the channel-mixing MLP.
    """

    def __init__(self, dim: int, pix_per_patch: int, token_dim: int, channel_dim: int):
        super().__init__()

        num_tokens = pix_per_patch

        # Mix across tokens: normalise channels, move the token axis last so
        # the MLP's Linear layers act on it, then move it back.
        token_layers = [
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            MLPBlock(num_tokens, token_dim),
            Rearrange('b c n -> b n c'),
        ]
        self.token_mixer = nn.Sequential(*token_layers)

        # Mix across channels, independently for every token.
        channel_layers = [
            nn.LayerNorm(dim),
            MLPBlock(dim, channel_dim),
        ]
        self.channel_mixer = nn.Sequential(*channel_layers)

    def forward(self, x):
        """Apply token mixing, then channel mixing, each with a residual add.

        ``x`` is expected as (batch, tokens, dim), matching the
        ``'b n c'`` pattern used by the token mixer.
        """
        mixed = x + self.token_mixer(x)
        return mixed + self.channel_mixer(mixed)


class MLPMixer(nn.Module):
    """MLP-Mixer image classifier.

    Pipeline: patch embedding -> ``depth`` MixerBlocks -> LayerNorm ->
    mean over tokens -> linear classification head.

    Args:
        num_classes: Number of output logits.
        hidden_dim: Embedding (channel) dimension of every token.
        depth: Number of stacked MixerBlocks.
        in_channels: Channels of the input image.
        img_size: Side length of the input image; the token count is taken
            as ``(img_size // patch_size) ** 2``, i.e. a square patch grid.
        patch_size: Side length of each square patch.
        token_dim: Hidden width of the token-mixing MLPs.
        channel_dim: Hidden width of the channel-mixing MLPs.
    """

    def __init__(self, num_classes: int, hidden_dim: int, depth: int, in_channels: int = 3, img_size: int = 224,
                 patch_size: int = 16, token_dim: int = 256, channel_dim: int = 256):
        super().__init__()

        self.patch_embed = PatchEmbedding(in_channels, patch_size, hidden_dim)

        # Token count the token-mixing MLPs are sized for. Inputs whose patch
        # count differs from this will not match those Linear layers.
        num_patches = (img_size // patch_size) ** 2

        self.mixer_blks = nn.Sequential()
        for idx in range(depth):
            block = MixerBlock(hidden_dim, num_patches, token_dim, channel_dim)
            self.mixer_blks.add_module(f"MixerBlock_{idx}", block)

        self.norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for images ``x``."""
        tokens = self.mixer_blks(self.patch_embed(x))
        # Global average pooling over the token axis after the final norm.
        pooled = self.norm(tokens).mean(dim=1)
        return self.head(pooled)

In [7]:
# Build an MLP-Mixer for 1000-way classification and show a layer-by-layer
# summary for one 3x224x224 image (torchinfo inserts the batch dim at index 0).
model = MLPMixer(
    num_classes=1000,
    hidden_dim=512,
    depth=8,
    patch_size=16,
    token_dim=256,
    channel_dim=2048,
)
torchinfo.summary(model, input_size=(3, 224, 224), batch_dim=0)

Layer (type:depth-idx)                        Output Shape              Param #
MLPMixer                                      [1, 1000]                 --
├─PatchEmbedding: 1-1                         [1, 196, 512]             --
│    └─Conv2d: 2-1                            [1, 512, 14, 14]          393,728
│    └─Rearrange: 2-2                         [1, 196, 512]             --
├─Sequential: 1-2                             [1, 196, 512]             --
│    └─MixerBlock: 2-3                        [1, 196, 512]             --
│    │    └─Sequential: 3-1                   [1, 196, 512]             101,828
│    │    └─Sequential: 3-2                   [1, 196, 512]             2,100,736
│    └─MixerBlock: 2-4                        [1, 196, 512]             --
│    │    └─Sequential: 3-3                   [1, 196, 512]             101,828
│    │    └─Sequential: 3-4                   [1, 196, 512]             2,100,736
│    └─MixerBlock: 2-5                        [1, 196, 512]       