In [1]:
!pip install einops



In [2]:
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from einops.layers.torch import Rearrange
from torchvision import datasets, transforms

In [3]:
class MlpBlock(nn.Module):
    def __init__(self, in_dim, hidden_dim, p=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(hidden_dim, in_dim),
            nn.Dropout(p)
        )

    def forward(self, x):
        return self.net(x)

In [4]:
class MixerBlock(nn.Module):
    def __init__(self, num_patches, chn_dim, tok_hid_dim, chn_hid_dim, p=0.):
        super().__init__()
        self.token_mixing = nn.Sequential(
                                nn.LayerNorm(chn_dim),
                                Rearrange('b t d -> b d t'),
                                MlpBlock(num_patches, tok_hid_dim, p),
                                Rearrange('b d t -> b t d')
                            )
        self.channel_mixing = nn.Sequential(
                                nn.LayerNorm(chn_dim),
                                MlpBlock(chn_dim, chn_hid_dim, p)
                            )
        
    def forward(self, x):
        x = x + self.token_mixing(x)
        x = x + self.channel_mixing(x)
        return x

In [5]:
class MlpMixer(nn.Module):
    def __init__(self, in_channels, img_size, chn_dim, patch_size, num_blocks, tok_hid_dim, chn_hid_dim, num_classes, p=0.):
        super().__init__()
        assert img_size % patch_size == 0, 'image size must be divisible by patch size!!'
        num_patches = (img_size // patch_size) ** 2
        self.patch_embedding = nn.Sequential(
                                    nn.Conv2d(in_channels, chn_dim, kernel_size=patch_size, stride=patch_size),
                                    Rearrange('b c h w -> b (h w) c')
                                )
        self.mixer_blocks = nn.ModuleList([MixerBlock(num_patches, chn_dim, tok_hid_dim, chn_dim, p) for _ in range(num_blocks)])
        self.ln = nn.LayerNorm(chn_dim)
        self.fc_out = nn.Linear(chn_dim, num_classes)
    
    def forward(self, x):
        x = self.patch_embedding(x)
        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)
        x = self.ln(x)
        x = x.mean(1)
        return self.fc_out(x)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_epochs = 10
in_channels = 3
img_size = 224
chn_dim = 512
patch_size = 32
num_blocks = 8
tok_hid_dim = 256
chn_hid_dim = 2048
num_classes = 10
p = 0.
batch_size = 64
lr = 3e-4
T = transforms.Compose(
    [
     transforms.Resize((img_size, img_size)),
     transforms.ToTensor()
    ]
)
print(device)

cuda


In [7]:
train_data = datasets.CIFAR10("data/", train=True, download=True, transform=T)
val_data = datasets.CIFAR10("data/", train=False, download=True, transform=T)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
x, y = next(iter(train_loader))
print(len(train_data), x.shape, y.shape)

Files already downloaded and verified
Files already downloaded and verified
50000 torch.Size([64, 3, 224, 224]) torch.Size([64])


In [8]:
net = MlpMixer(in_channels, img_size, chn_dim, patch_size, num_blocks, tok_hid_dim, chn_hid_dim, num_classes, p).to(device)
inp = torch.randn(1, 3, 224, 224).to(device)
out = net(inp)
print(out.shape)
del inp, out

torch.Size([1, 10])


In [9]:
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
def get_accuracy(preds, y):
    preds = preds.argmax(dim=1, keepdim=True)
    correct = preds.squeeze(1).eq(y)
    acc = correct.sum() / torch.FloatTensor([y.shape[0]]).to(device)
    return acc

In [10]:
def loop(net, loader, is_train):
    net.train(is_train)
    losses = []
    accs = []
    pbar = tqdm(loader, total=len(loader))
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
        with torch.set_grad_enabled(is_train):
            preds = net(x)
            loss = loss_fn(preds, y)
            acc = get_accuracy(preds, y)
            losses.append(loss.item())
            accs.append(acc.item())
        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        pbar.set_description(f'epoch={epoch}, train={int(is_train)}')
        pbar.set_postfix(loss=f'{np.mean(losses):.4f}', acc=f'{np.mean(accs):.4f}')

In [11]:
for epoch in range(n_epochs):
    loop(net, train_loader, True)
    loop(net, val_loader, False)

epoch=0, train=1: 100%|██████████| 782/782 [01:19<00:00,  9.89it/s, acc=0.4245, loss=1.5849]
epoch=0, train=0: 100%|██████████| 157/157 [00:14<00:00, 11.16it/s, acc=0.5237, loss=1.3279]
epoch=1, train=1: 100%|██████████| 782/782 [01:18<00:00,  9.90it/s, acc=0.5642, loss=1.2171]
epoch=1, train=0: 100%|██████████| 157/157 [00:14<00:00, 11.11it/s, acc=0.5903, loss=1.1324]
epoch=2, train=1: 100%|██████████| 782/782 [01:19<00:00,  9.87it/s, acc=0.6265, loss=1.0543]
epoch=2, train=0: 100%|██████████| 157/157 [00:14<00:00, 11.18it/s, acc=0.6266, loss=1.0482]
epoch=3, train=1: 100%|██████████| 782/782 [01:19<00:00,  9.85it/s, acc=0.6669, loss=0.9388]
epoch=3, train=0: 100%|██████████| 157/157 [00:14<00:00, 11.08it/s, acc=0.6485, loss=0.9915]
epoch=4, train=1: 100%|██████████| 782/782 [01:19<00:00,  9.78it/s, acc=0.7047, loss=0.8340]
epoch=4, train=0: 100%|██████████| 157/157 [00:14<00:00, 11.06it/s, acc=0.6559, loss=0.9702]
epoch=5, train=1: 100%|██████████| 782/782 [01:19<00:00,  9.81it/s, ac