In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy
from tqdm import tqdm

In [2]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
import torch.optim as optim
from torch.utils import data

In [4]:
# Both splits only apply ToTensor (no augmentation). Two separate Compose objects are
# kept so either pipeline can later be changed without affecting the other.
FMNIST_ROOT = "../../../../../_Datasets/FMNIST/"

train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])


def _load_fmnist(is_train, tf):
    """Build the FashionMNIST train or test split under FMNIST_ROOT (downloading if needed)."""
    return datasets.FashionMNIST(root=FMNIST_ROOT, train=is_train, download=True, transform=tf)


train_dataset = _load_fmnist(True, train_transform)
test_dataset = _load_fmnist(False, test_transform)

In [5]:
# Same batch size and worker count for both splits; only the training split is reshuffled.
_loader_kwargs = dict(batch_size=50, num_workers=2)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **_loader_kwargs)

In [6]:
## Demo of the train loader: pull a single batch and inspect the image tensor shape.
# Use the built-in next() (Python 3 iterator protocol); the `.next()` method alias is
# not available on DataLoader iterators in newer PyTorch versions.
xx, yy = next(iter(train_loader))
xx.shape

torch.Size([50, 1, 28, 28])

# Model

## MLP-Mixer 

In [7]:
class MlpBLock(nn.Module):
    """Stack of Linear layers with an activation between consecutive layers (none at the end).

    Layer widths are multiples of ``input_dim``:
    ``input_dim -> r_1*input_dim -> ... -> r_k*input_dim -> input_dim``, so the output's
    last dimension has the same size as the input's.

    Args:
        input_dim: size of the last dimension of the input (and of the output).
        hidden_layers_ratio: a single number or a sequence (list/tuple) of numbers, one width
            multiplier per hidden layer. Defaults to one hidden layer of width ``2*input_dim``.
        actf: activation module class; a fresh instance is placed between consecutive Linears.
    """

    def __init__(self, input_dim, hidden_layers_ratio=(2,), actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        # Accept a bare number as well as any sequence of ratios.
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]
        # Fresh list: neither the default nor the caller's sequence is shared or mutated.
        self.hlr = [1, *hidden_layers_ratio, 1]

        layers = []
        # One Linear per consecutive pair of ratios: a single hidden layer gives two Linears.
        for r_in, r_out in zip(self.hlr[:-1], self.hlr[1:]):
            layers.append(nn.Linear(int(r_in * self.input_dim), int(r_out * self.input_dim)))
            layers.append(actf())
        # Drop the trailing activation so the block ends with a Linear layer.
        self.mlp = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        return self.mlp(x)

In [8]:
MlpBLock(input_dim=2, hidden_layers_ratio=[3, 4])

MlpBLock(
  (mlp): Sequential(
    (0): Linear(in_features=2, out_features=6, bias=True)
    (1): GELU()
    (2): Linear(in_features=6, out_features=8, bias=True)
    (3): GELU()
    (4): Linear(in_features=8, out_features=2, bias=True)
  )
)

In [9]:
class MixerBlock(nn.Module):
    """One MLP-Mixer layer: a patch-mixing MLP followed by a channel-mixing MLP.

    Expects input whose last two dimensions are ``(patch_dim, channel_dim)`` (e.g.
    ``(N, patches, channels)``) and returns a tensor of the same shape. Each sub-layer is
    pre-normalised with LayerNorm over the channel dimension and wrapped in a residual add.

    Args:
        patch_dim: number of patches (size of dim -2).
        channel_dim: number of features per patch (size of dim -1).
        patch_hidden_ratio: hidden-width multiplier (a number) of the patch-mixing MLP.
        channel_hidden_ratio: hidden-width multiplier (a number) of the channel-mixing MLP.
    """

    def __init__(self, patch_dim, channel_dim, patch_hidden_ratio=2, channel_hidden_ratio=2):
        super().__init__()
        # Submodule names are unchanged so previously saved state_dicts still load.
        self.ln0 = nn.LayerNorm(channel_dim)
        self.mlp_patch = MlpBLock(patch_dim, [patch_hidden_ratio])
        self.ln1 = nn.LayerNorm(channel_dim)
        self.mlp_channel = MlpBLock(channel_dim, [channel_hidden_ratio])

    def forward(self, x):
        # x: (N, P, C) with P = patches, C = channels / hidden features.
        # Author note: the same mixer could also be written for an (N, C, P) layout.

        # Patch mixing: LayerNorm normalises each patch over its channels (last dim); then the
        # patch axis is moved last so the MLP mixes across patches, separately per channel.
        y = self.ln0(x)
        y = self.mlp_patch(y.transpose(-1, -2)).transpose(-1, -2)
        x = x + y

        # Channel mixing: the MLP acts on the last (channel) axis, separately per patch.
        x = x + self.mlp_channel(self.ln1(x))
        return x

In [10]:
## An image has shape (C, H, W); after splitting it into patches it has shape (P, C*ph*pw): P patches, each flattened to C*ph*pw values.

In [11]:
mb = MixerBlock(patch_dim=32, channel_dim=8)

In [12]:
mb(torch.randn(1, 32, 8)).size()

torch.Size([1, 32, 8])

In [13]:
class MlpMixer(nn.Module):
    """MLP-Mixer image classifier.

    Pipeline: cut the image into non-overlapping patches with ``nn.Unfold``, linearly project
    each flattened patch to ``channel_dim`` features, run ``num_blocks`` MixerBlocks, then
    flatten all patch features and classify with one Linear layer.

    Args:
        image_dim: ``(C, H, W)`` or ``(H, W)``; the 2-tuple form is treated as single-channel.
        patch_size: ``(ph, pw)``; ``H`` must be divisible by ``ph`` and ``W`` by ``pw``.
        hidden_expansion: features per patch after projection are
            ``int(ph * pw * hidden_expansion)`` (relative to the patch area, not ``C*ph*pw``).
        num_blocks: number of stacked MixerBlocks.
        num_classes: number of output logits.

    Raises:
        AssertionError: if the image cannot be tiled exactly by ``patch_size``.
    """

    def __init__(self, image_dim:tuple, patch_size:tuple, hidden_expansion:float, num_blocks:int, num_classes:int):
        super().__init__()

        self.img_dim = image_dim ### must contain (C, H, W) or (H, W)
        # A 2-tuple image_dim has no channel entry: treat it as a single-channel image.
        in_channels = image_dim[0] if len(image_dim) == 3 else 1

        ### number of patches along each spatial axis
        d0 = image_dim[-2] // patch_size[0]
        d1 = image_dim[-1] // patch_size[1]
        assert d0*patch_size[0]==image_dim[-2], "Image must be divisible into patch size"
        # The width must be compared using the patch count along the width (d1), not d0.
        assert d1*patch_size[1]==image_dim[-1], "Image must be divisible into patch size"
        self.d0, self.d1 = d0, d1

        # Values per flattened patch produced by nn.Unfold: C * ph * pw.
        init_dim = patch_size[0]*patch_size[1]*in_channels
        final_dim = int(patch_size[0]*patch_size[1]*hidden_expansion)
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
        self.channel_change = nn.Linear(init_dim, final_dim)
        print(f"MLP Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")

        self.channel_dim = final_dim  # features per patch seen by the mixer blocks
        self.patch_dim = d0*d1        # number of patches

        self.mixer_blocks = nn.Sequential(
            *[MixerBlock(self.patch_dim, self.channel_dim) for _ in range(num_blocks)]
        )

        self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)

    def forward(self, x):
        """Map images ``(N, C, H, W)`` (or ``(N, H, W)`` for a 2-D ``image_dim``) to logits ``(N, num_classes)``."""
        if len(self.img_dim) == 2 and x.dim() == 3:
            x = x.unsqueeze(1)  # add the implicit single-channel axis nn.Unfold needs
        # Unfold yields (N, C*ph*pw, P); transpose so each row holds one flattened patch.
        x = self.unfold(x).transpose(-1, -2)
        x = self.channel_change(x)
        x = self.mixer_blocks(x)
        # flatten(1) does not require a contiguous tensor (unlike view) and handles N == 0.
        return self.linear(x.flatten(1))

In [14]:
mixer = MlpMixer(image_dim=(1, 28, 28), patch_size=(4, 4), hidden_expansion=2, num_blocks=1, num_classes=10)
mixer

MLP Mixer : Channes per patch -> Initial:16 Final:32


MlpMixer(
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=16, out_features=32, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=49, out_features=98, bias=True)
          (1): GELU()
          (2): Linear(in_features=98, out_features=49, bias=True)
        )
      )
      (ln1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=32, out_features=64, bias=True)
          (1): GELU()
          (2): Linear(in_features=64, out_features=32, bias=True)
        )
      )
    )
  )
  (linear): Linear(in_features=1568, out_features=10, bias=True)
)

In [15]:
# Smoke test: push 3 random 1x28x28 inputs through the untrained demo mixer and show the (3, 10) logits.
mixer(torch.randn(3, 1, 28, 28))

tensor([[-0.3522,  0.2494,  0.4296, -0.4552, -0.3652, -0.4657,  0.0391,  0.4927,
          0.0118, -0.2612],
        [ 0.0070, -0.2540,  0.2559, -0.1609,  0.4585, -0.0064,  0.5060,  0.2739,
         -0.1393, -1.3371],
        [-0.4300,  0.2414,  0.3529,  0.5127,  0.2702,  0.0775,  0.0328, -0.3654,
          0.1663, -0.0572]], grad_fn=<AddmmBackward>)

#### Final Model

In [16]:
# Model trained below: 4x4 patches, no per-patch feature expansion, 5 mixer blocks.
model = MlpMixer((1, 28, 28), (4, 4), hidden_expansion=1, num_blocks=5, num_classes=10).to(device)

MLP Mixer : Channes per patch -> Initial:16 Final:16


In [17]:
n_params = sum(param.numel() for param in model.parameters())
print("number of params: ", n_params)

number of params:  62557


## Training

In [18]:
 # Name of this run; the best checkpoint is written to ./models/<model_name>.pth.
model_name = "mlp_mixer_fmnist_v0"

In [19]:
EPOCHS = 20
LEARNING_RATE = 1e-3

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Cosine learning-rate decay over EPOCHS scheduler steps (scheduler.step() runs once per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [20]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    """Run one training epoch over ``train_loader`` and print the loss and accuracy.

    Reads the notebook globals ``model``, ``train_loader``, ``device``, ``optimizer`` and
    ``criterion``.

    Args:
        epoch: epoch index; used only in the printed log line.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the epoch, or ``None`` when the loader
        yielded no batches.
    """
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in tqdm(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    # An empty loader previously crashed on the undefined `batch_idx` / division by zero.
    if num_batches == 0 or total == 0:
        print(f"[Train] {epoch} no batches")
        return None

    mean_loss = train_loss / num_batches  # average of the per-batch losses
    acc = 100. * correct / total
    print(f"[Train] {epoch} Loss: {mean_loss:.3f} | Acc: {acc:.3f} {correct}/{total}")
    return mean_loss, acc

In [21]:
best_acc = -1
def test(epoch):
    """Evaluate ``model`` on ``test_loader``, print metrics and checkpoint on a new best accuracy.

    Reads the notebook globals ``model``, ``test_loader``, ``device``, ``criterion`` and
    ``model_name``, and updates the global ``best_acc``. The checkpoint written to
    ``./models/<model_name>.pth`` holds the model ``state_dict``, the test accuracy (percent)
    and ``epoch``.

    Args:
        epoch: index of the epoch just trained; logged and stored in the checkpoint.

    Returns:
        Test accuracy in percent, or ``None`` when the loader yielded no batches.
    """
    global best_acc
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    # An empty loader previously crashed on the undefined `batch_idx` / division by zero.
    if num_batches == 0 or total == 0:
        print(f"[Test] {epoch} no batches")
        return None

    acc = 100.*correct/total
    print(f"[Test] {epoch} Loss: {test_loss/num_batches:.3f} | Acc: {acc:.3f} {correct}/{total}")

    # Save checkpoint.
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir()+mkdir().
        os.makedirs('models', exist_ok=True)
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc
    return acc

In [22]:
start_epoch = 0  # first epoch index of the next training run (0, or the epoch after the checkpoint)
resume = False   # set True to continue from ./models/<model_name>.pth

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    # map_location lets a checkpoint written from a GPU run be loaded onto the current device.
    checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    # The stored epoch was already trained and evaluated, so continue with the next one.
    start_epoch = checkpoint['epoch'] + 1
    # NOTE(review): the checkpoint stores no optimizer or LR-scheduler state, so Adam's
    # moments and the cosine schedule start over when resuming.

In [23]:
### Train for EPOCHS epochs: each epoch trains, evaluates (checkpointing on a new best
### test accuracy), then advances the cosine learning-rate schedule.
end_epoch = start_epoch + EPOCHS
for epoch in range(start_epoch, end_epoch):
    train(epoch)
    test(epoch)
    scheduler.step()

100%|██████████| 1200/1200 [00:39<00:00, 30.17it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

[Train] 0 Loss: 0.445 | Acc: 83.617 50170/60000


100%|██████████| 200/200 [00:02<00:00, 70.76it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

[Test] 0 Loss: 0.402 | Acc: 85.170 8517/10000
Saving..


100%|██████████| 1200/1200 [00:39<00:00, 30.24it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

[Train] 1 Loss: 0.336 | Acc: 87.622 52573/60000


100%|██████████| 200/200 [00:02<00:00, 70.22it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

[Test] 1 Loss: 0.358 | Acc: 87.050 8705/10000
Saving..


100%|██████████| 1200/1200 [00:39<00:00, 30.27it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

[Train] 2 Loss: 0.300 | Acc: 88.883 53330/60000


100%|██████████| 200/200 [00:02<00:00, 70.19it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

[Test] 2 Loss: 0.331 | Acc: 87.960 8796/10000
Saving..


100%|██████████| 1200/1200 [00:39<00:00, 30.34it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

[Train] 3 Loss: 0.271 | Acc: 89.918 53951/60000


100%|██████████| 200/200 [00:02<00:00, 70.28it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

[Test] 3 Loss: 0.334 | Acc: 87.400 8740/10000


100%|██████████| 1200/1200 [00:39<00:00, 30.20it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

[Train] 4 Loss: 0.250 | Acc: 90.712 54427/60000


100%|██████████| 200/200 [00:02<00:00, 70.23it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

[Test] 4 Loss: 0.337 | Acc: 87.790 8779/10000


100%|██████████| 1200/1200 [00:39<00:00, 30.24it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

[Train] 5 Loss: 0.228 | Acc: 91.363 54818/60000


100%|██████████| 200/200 [00:02<00:00, 70.00it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

[Test] 5 Loss: 0.316 | Acc: 88.760 8876/10000
Saving..


 16%|█▋        | 196/1200 [00:06<00:34, 29.51it/s]


KeyboardInterrupt: 

In [24]:
# Best test accuracy (percent) seen so far; test() updates it whenever it saves a checkpoint.
best_acc

88.76

In [25]:
# Inspect the saved best checkpoint. map_location keeps this working on a CPU-only
# machine even when the checkpoint was written from a GPU run.
checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
best_acc = checkpoint['acc']
best_epoch = checkpoint['epoch']  # epoch at which best_acc was reached
start_epoch = best_epoch + 1      # that epoch is already trained; resume with the next one

best_acc, best_epoch

(88.76, 5)

In [26]:
model.load_state_dict(checkpoint['model'], strict=True)

<All keys matched successfully>

In [27]:
# Show the architecture of the trained model (5 mixer blocks, 16 features per patch).
model

MlpMixer(
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=16, out_features=16, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=49, out_features=98, bias=True)
          (1): GELU()
          (2): Linear(in_features=98, out_features=49, bias=True)
        )
      )
      (ln1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=16, out_features=32, bias=True)
          (1): GELU()
          (2): Linear(in_features=32, out_features=16, bias=True)
        )
      )
    )
    (1): MixerBlock(
      (ln0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=49, out_features=9