In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy, json
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer the first CUDA GPU, but fall back to the CPU so that the notebook
# still runs end-to-end on machines without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Seeds used across repeated runs: 147, 258, 369 (currently active).
SEED = 369

# Seed every random number generator this notebook imports, so that a fresh
# "Restart & Run All" starts from the same state.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
import torch.optim as optim
from torch.utils import data

In [5]:
# cifar_train = transforms.Compose([
#     transforms.RandomCrop(size=32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# cifar_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# train_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=True, download=True, transform=cifar_train)
# test_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=False, download=True, transform=cifar_test)

In [6]:
# train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=2)
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=2)

In [7]:
# Per-channel normalisation statistics and on-disk location of CIFAR-100,
# shared by the training and evaluation pipelines below.
CIFAR100_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR100_STD = [0.2009, 0.1984, 0.2023]
CIFAR100_ROOT = "../../../../../_Datasets/cifar100/"

# Training pipeline: random 32x32 crop (after 4px padding) and random
# horizontal flip as augmentation, then tensor conversion + normalisation.
cifar_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

# download=True fetches the archive only if it is not already under the root.
train_dataset = datasets.CIFAR100(root=CIFAR100_ROOT, train=True, download=True, transform=cifar_train)
test_dataset = datasets.CIFAR100(root=CIFAR100_ROOT, train=False, download=True, transform=cifar_test)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Mini-batch iterators (128 samples, 2 worker processes); only the training
# split is shuffled.
train_loader = data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Model

In [9]:
class MlpBLock(nn.Module):
    """Fully connected block whose layer widths are multiples of the input width.

    The block maps ``input_dim`` features back to ``input_dim`` features. One
    ``nn.Linear`` is created per consecutive pair of widths, with an
    activation between linear layers and none after the last one.

    Args:
        input_dim: size of the last dimension of the input (and of the output).
        hidden_layers_ratio: width of each hidden layer expressed as a
            multiple of ``input_dim``. Either a single number or any iterable
            of numbers (list, tuple, ...). Defaults to one hidden layer of
            ratio 2.
        actf: zero-argument callable returning the activation module placed
            between linear layers.
    """

    def __init__(self, input_dim, hidden_layers_ratio=(2,), actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        #### a bare number means "one hidden layer with this ratio"
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]

        #### ratio 1 at both ends: the block starts and ends at input_dim
        self.hlr = [1, *hidden_layers_ratio, 1]
        widths = [int(ratio*self.input_dim) for ratio in self.hlr]

        layers = []
        ### for 1 hidden layer there are 2 (in, out) pairs, hence 2 linear layers
        for i, o in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(i, o))
            layers.append(actf())
        ### drop the trailing activation so the block ends on a linear layer
        self.mlp = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        return self.mlp(x)

In [10]:
MlpBLock(2, [3,4])

MlpBLock(
  (mlp): Sequential(
    (0): Linear(in_features=2, out_features=6, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=6, out_features=8, bias=True)
    (3): GELU(approximate='none')
    (4): Linear(in_features=8, out_features=2, bias=True)
  )
)

## MLP-Mixer 

In [11]:
class MixerBlock(nn.Module):
    """One mixer layer: a patch-mixing MLP followed by a channel-mixing MLP.

    Each sub-layer normalises its input with a ``LayerNorm`` over the channel
    axis and is wrapped in a residual connection.

    Args:
        patch_dim: number of patches (size of axis -2 of the input).
        channel_dim: number of channels per patch (size of axis -1).
    """

    def __init__(self, patch_dim, channel_dim):
        super().__init__()

        # patch (token) mixing
        self.ln0 = nn.LayerNorm(channel_dim)
        self.mlp_patch = MlpBLock(patch_dim, [2])
        # channel mixing
        self.ln1 = nn.LayerNorm(channel_dim)
        self.mlp_channel = MlpBLock(channel_dim, [2])

    def forward(self, x):
        """Mix ``x`` of shape (N, num_patches, num_channels); shape is preserved."""
        # Patch mixing: normalise over channels, then move the patch axis last
        # so the MLP acts across patches, and move it back for the residual.
        across_patches = self.ln0(x).transpose(-1, -2)
        x = x + self.mlp_patch(across_patches).transpose(-1, -2)

        # Channel mixing acts on the last axis directly.
        return x + self.mlp_channel(self.ln1(x))

In [12]:
class MlpMixer(nn.Module):
    """MLP-Mixer image classifier built from a stack of ``MixerBlock`` layers.

    The input is bilinearly resized to the spatial size given by ``image_dim``,
    cut into non-overlapping patches, linearly projected per patch, passed
    through the mixer blocks and classified from the flattened result.

    Args:
        image_dim: ``(C, H, W)`` or ``(H, W)``; a 2-tuple is treated as a
            single-channel image. H and W must be multiples of the patch size.
        patch_size: ``(patch_h, patch_w)``.
        hidden_expansion: per-patch width after the initial projection, as a
            multiple of the raw patch size ``patch_h * patch_w * C``.
        num_blocks: number of ``MixerBlock`` layers.
        num_classes: size of the output logits.
    """

    def __init__(self, image_dim:tuple, patch_size:tuple, hidden_expansion:float, num_blocks:int, num_classes:int):
        super().__init__()

        self.img_dim = image_dim ### (C, H, W) or (H, W)
        height, width = image_dim[-2], image_dim[-1]
        ### without an explicit channel entry assume a single-channel image
        in_channels = image_dim[0] if len(image_dim) > 2 else 1
        self.scaler = nn.UpsamplingBilinear2d(size=(height, width))

        ### number of patches along each axis
        d0 = int(height // patch_size[0])
        d1 = int(width // patch_size[1])
        assert d0*patch_size[0]==height, "Image must be divisible into patch size"
        assert d1*patch_size[1]==width, "Image must be divisible into patch size"

        ### values per raw patch, and per patch after the linear projection
        init_dim = patch_size[0]*patch_size[1]*in_channels
        final_dim = int(init_dim*hidden_expansion)

        ### stride == kernel size -> non-overlapping patches
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
        self.channel_change = nn.Linear(init_dim, final_dim)
        print(f"MLP Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")

        self.channel_dim = final_dim
        self.patch_dim = d0*d1 ## number of patches

        self.mixer_blocks = nn.Sequential(
            *[MixerBlock(self.patch_dim, self.channel_dim) for _ in range(num_blocks)]
        )

        self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)


    def forward(self, x):
        """Return class logits of shape (N, num_classes) for an image batch ``x``."""
        bs = x.shape[0]
        x = self.scaler(x)
        ### (N, values_per_patch, num_patches) -> (N, num_patches, values_per_patch)
        x = self.unfold(x).swapaxes(-1, -2)
        x = self.channel_change(x)
        x = self.mixer_blocks(x)
        ### reshape (unlike view) does not depend on the memory layout of x
        x = self.linear(x.reshape(bs, -1))
        return x

In [13]:
# class MlpMixer(nn.Module):
    
#     def __init__(self, input_channels:int, hidden_image_dim:tuple, patch_size:tuple, num_blocks:int, num_classes:int, use1x1scale=True):
#         super().__init__()
        
#         self.img_dim = hidden_image_dim ### must contain (C, H, W) or (H, W)
        
#         ### find patch dim
#         d0 = int(self.img_dim[-2]/patch_size[0])
#         d1 = int(self.img_dim[-1]/patch_size[1])
#         assert d0*patch_size[0]==self.img_dim[-2], "Image must be divisible into patch size"
#         assert d1*patch_size[1]==self.img_dim[-1], "Image must be divisible into patch size"
        
#         ### find channel dim
#         channel_size = d0*d1 ## number of patches
        
#         ### after the number of channels are changed
#         init_dim = patch_size[0]*patch_size[1]*input_channels ## number of channels in each patch
#         final_dim = patch_size[0]*patch_size[1]*self.img_dim[0]

#         self.scaler = nn.UpsamplingBilinear2d(size=(self.img_dim[-2], self.img_dim[-1]))
        
#         if use1x1scale:
#             self.conv1x1 = nn.Conv2d(input_channels, self.img_dim[0], kernel_size=1, stride=1)
#             if input_channels == self.img_dim[0]:
#                 self.conv1x1 = nn.Identity()
#         else:
#             #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
#             self.channel_change = nn.Linear(init_dim, final_dim) ## apply after unfold
            
#         self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

#         print(f"MLP Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")
        
#         ### after the axis are swapped
#         self.channel_dim = final_dim
#         self.patch_dim = channel_size
        
#         self.mixer_blocks = []
#         for i in range(num_blocks):
#             self.mixer_blocks.append(MixerBlock(self.patch_dim, self.channel_dim))
#         self.mixer_blocks = nn.Sequential(*self.mixer_blocks)
        
#         self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)
        
        
#     def forward(self, x):
#         bs = x.shape[0]
#         x = self.scaler(x) ## scale to high dimension
#         x = self.conv1x1(x)
#         x = self.unfold(x).swapaxes(-1, -2).contiguous() ## convert to patches, bring channels first
# #         x = self.channel_change(x) ## change the number of channels
#         x = self.mixer_blocks(x) ## the mixer architecture
#         x = self.linear(x.view(bs, -1)) ## classify
#         return x

In [14]:
mixer = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.5, num_blocks=1, num_classes=10)
mixer

MLP Mixer : Channes per patch -> Initial:48 Final:120


MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(32, 32), mode=bilinear)
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=120, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((120,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=64, out_features=128, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=128, out_features=64, bias=True)
        )
      )
      (ln1): LayerNorm((120,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=120, out_features=240, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=240, out_features=120, bias=True)
        )
      )
    )
  )
  (linear): Linear(in_features=7680, out_features=10, bias=True)
)

In [15]:
print("number of params: ", sum(p.numel() for p in mixer.parameters())) 

number of params:  157706


In [16]:
mixer(torch.randn(1, 3, 32, 32))

tensor([[-0.1648, -0.0510, -0.0486, -0.1505, -0.3137, -0.5226,  0.0023,  0.5233,
         -0.2587, -0.5398]], grad_fn=<AddmmBackward0>)

## Patch Mixer

In [17]:
class PatchMixerBlock(nn.Module):
    """Residual MLP applied independently to every non-overlapping image patch.

    Args:
        patch_size: side length (int) or ``(height, width)`` of a patch.
        num_channel: number of channels of the input feature map.
    """

    def __init__(self, patch_size, num_channel):
        super().__init__()
        self.patch_size = patch_size

        # number of values in one flattened patch: pixels per patch x channels
        pixels = patch_size**2 if isinstance(patch_size, int) else patch_size[0]*patch_size[1]
        patch_features = pixels*num_channel

        self.ln0 = nn.LayerNorm(patch_features)
        self.mlp_patch = MlpBLock(patch_features, [2])

    def forward(self, x):
        """Mix within patches of ``x`` (N, C, H, W) and add the result to ``x``."""
        height, width = x.shape[-2], x.shape[-1]
        # stride == kernel size, so the patches do not overlap
        window = dict(kernel_size=self.patch_size, stride=self.patch_size)

        # (N, patch_features, num_patches) -> (N, num_patches, patch_features)
        patches = nn.functional.unfold(x, **window).transpose(-1, -2)
        mixed = self.mlp_patch(self.ln0(patches)).transpose(-1, -2)

        # reassemble the patches into an image of the original spatial size
        return x + nn.functional.fold(mixed, (height, width), **window)

In [18]:
pmb = PatchMixerBlock(8, 3)
pmb

PatchMixerBlock(
  (ln0): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  (mlp_patch): MlpBLock(
    (mlp): Sequential(
      (0): Linear(in_features=192, out_features=384, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=384, out_features=192, bias=True)
    )
  )
)

In [19]:
# pmb(torch.randn(1, 3, 35, 35)).shape

In [20]:
def get_factors(n):
    """Return, in ascending order, every divisor of ``n`` from 2 up to ``n`` itself."""
    return [candidate for candidate in range(2, n+1) if n%candidate == 0]

class PatchMlpMixer(nn.Module):
    """Classifier that stacks ``PatchMixerBlock`` layers of alternating patch sizes.

    The input is bilinearly resized to a square whose side is the product of
    all ``patch_sizes``, so every patch size tiles it exactly.

    Args:
        image_dim: ``(C, H, W)``. Only ``C`` is used here; the working
            resolution is derived from ``patch_sizes``.
        patch_sizes: patch side lengths; one ``PatchMixerBlock`` per entry is
            created in each of the ``num_blocks`` repetitions.
        hidden_channels: number of channels the mixer blocks operate on.
        num_blocks: how many times the sequence of patch sizes is repeated.
        num_classes: size of the output logits.
    """

    def __init__(self, image_dim:tuple, patch_sizes:tuple, hidden_channels:int, num_blocks:int, num_classes:int):
        super().__init__()

        self.img_dim = image_dim ### must contain (C, H, W)
        ### plain Python int: np.prod returns a NumPy scalar, which would
        ### otherwise leak into interpolate(size=...) and nn.Linear(in_features=...)
        self.target_dim = int(np.prod(patch_sizes))

        num_channel = image_dim[0]

        ### 1x1 convolution to change the channel count. It is constructed
        ### before the check below (and then discarded when not needed) so that
        ### module construction order is unchanged from earlier seeded runs.
        self.conv1x1 = nn.Conv2d(num_channel, hidden_channels, kernel_size=1, stride=1)
        if num_channel == hidden_channels:
            self.conv1x1 = nn.Identity()

        self.mixer_blocks = nn.Sequential(*[
            PatchMixerBlock(ps, hidden_channels)
            for _ in range(num_blocks)
            for ps in patch_sizes
        ])
        self.linear = nn.Linear(self.target_dim*self.target_dim*hidden_channels, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for an image batch ``x``."""
        bs = x.shape[0]

        ### resize to (target_dim, target_dim)
        x = nn.functional.interpolate(x, size=self.target_dim, mode='bilinear', align_corners=True)

        x = self.conv1x1(x)
        x = self.mixer_blocks(x)
        ### reshape (unlike view) does not depend on the memory layout of x
        x = self.linear(x.reshape(bs, -1))
        return x

In [21]:
4*3*5

60

In [22]:
patch_mixer = PatchMlpMixer((3, 35, 35), patch_sizes=[5, 7], hidden_channels=3, num_blocks=1, num_classes=10)

In [23]:
patch_mixer

PatchMlpMixer(
  (conv1x1): Identity()
  (mixer_blocks): Sequential(
    (0): PatchMixerBlock(
      (ln0): LayerNorm((75,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=75, out_features=150, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=150, out_features=75, bias=True)
        )
      )
    )
    (1): PatchMixerBlock(
      (ln0): LayerNorm((147,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=147, out_features=294, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=294, out_features=147, bias=True)
        )
      )
    )
  )
  (linear): Linear(in_features=3675, out_features=10, bias=True)
)

In [24]:
print("number of params: ", sum(p.numel() for p in patch_mixer.parameters())) 

number of params:  146806


In [25]:
patch_mixer(torch.randn(1, 3, 32, 32)).shape

torch.Size([1, 10])

#### Final Model

In [26]:
### Model 1 : balanced with (1) of PatchOnly ; higher image scaling, low channel scaling for similar params
#### it has large dimension of final hidden unit.
## 81, 150
# model = MlpMixer((3, 5*9, 5*9), (5, 5), hidden_expansion=2, num_blocks=10, num_classes=10)
### balanced model with 4*9 expansion as well
## 81, 144
model = MlpMixer((3, 4*9, 4*9), (4, 4), hidden_expansion=3.0, num_blocks=10, num_classes=10)

### Model 2 : use simple method of expanding only hidden dimension/channel
## 64, 135
# model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=3.2, num_blocks=10, num_classes=10)


model = model.to(device)

MLP Mixer : Channes per patch -> Initial:48 Final:144


In [27]:
model

MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(36, 36), mode=bilinear)
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=144, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=81, out_features=162, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=162, out_features=81, bias=True)
        )
      )
      (ln1): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=144, out_features=288, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=288, out_features=144, bias=True)
        )
      )
    )
    (1): MixerBlock(
      (ln0): LayerNorm((144,), eps=1e-05, elementwise_affine=True)

In [28]:
# print("number of params: ", sum(p.numel() for p in model.parameters())) 
print("number of params: ", sum(p.numel() for p in model.mixer_blocks.parameters())) 

number of params:  1104390


In [29]:
model(torch.randn(1, 3, 32, 32).to(device)).shape

torch.Size([1, 10])

In [30]:
### Model 1 : balanced using best settings for default patch mixer without scaling hidden channels.
model = PatchMlpMixer((3, 35, 35), patch_sizes=[5,7], hidden_channels=3, num_blocks=10, num_classes=10)

model = model.to(device)

In [31]:
model

PatchMlpMixer(
  (conv1x1): Identity()
  (mixer_blocks): Sequential(
    (0): PatchMixerBlock(
      (ln0): LayerNorm((75,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=75, out_features=150, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=150, out_features=75, bias=True)
        )
      )
    )
    (1): PatchMixerBlock(
      (ln0): LayerNorm((147,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=147, out_features=294, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=294, out_features=147, bias=True)
        )
      )
    )
    (2): PatchMixerBlock(
      (ln0): LayerNorm((75,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=75, out_features=150, bias=True)
          (1): GELU(a

In [32]:
print("number of params: ", sum(p.numel() for p in model.mixer_blocks.parameters())) 

number of params:  1100460


In [33]:
model(torch.randn(1, 3, 32, 32).to(device)).shape

torch.Size([1, 10])

In [34]:
print("number of params: ", sum(p.numel() for p in model.parameters())) 
## Patch ||  1137220
## Mixer ||  1141703

number of params:  1137220


## Training

In [35]:
# model_name = f'mlp_mixer_c10_s{SEED}'
# model_name = f'temp_c10_s{SEED}'

In [36]:
# EPOCHS = 200
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [37]:
# STAT ={'train_stat':[], 'test_stat':[]}

In [38]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    """Run one optimisation pass over the global ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``. Appends ``(epoch, mean batch loss, accuracy in %)`` to
    ``STAT['train_stat']`` and prints the same summary.
    """
    model.train()
    running_loss, num_correct, num_seen = 0, 0, 0

    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(1).indices  # class with the highest logit
        num_seen += targets.size(0)
        num_correct += predicted.eq(targets).sum().item()

    # batch_idx is zero-based, so batch_idx+1 is the number of batches seen
    mean_loss = running_loss/(batch_idx+1)
    accuracy = 100.*num_correct/num_seen
    STAT['train_stat'].append((epoch, mean_loss, accuracy)) ### (Epochs, Loss, Acc)
    print(f"[Train] {epoch} Loss: {mean_loss:.3f} | Acc: {accuracy:.3f} {num_correct}/{num_seen}")

In [39]:
best_acc = -1
def test(epoch):
    """Evaluate the global ``model`` on ``test_loader`` and persist the results.

    Appends ``(epoch, mean batch loss, accuracy in %, mean forward time in s)``
    to ``STAT['test_stat']``, saves a checkpoint to ``./models/{model_name}.pth``
    whenever the accuracy beats the global ``best_acc``, and rewrites
    ``./output/{model_name}_data.json`` with the full ``STAT`` dictionary.
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    time_taken = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)

            # CUDA kernels run asynchronously: wait for pending work before
            # starting the clock and for the forward pass before stopping it,
            # otherwise only the launch overhead is measured.
            if inputs.is_cuda:
                torch.cuda.synchronize(inputs.device)
            start = time.perf_counter()

            outputs = model(inputs)

            if inputs.is_cuda:
                torch.cuda.synchronize(inputs.device)
            time_taken.append(time.perf_counter()-start)

            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    STAT['test_stat'].append((epoch, test_loss/(batch_idx+1), 100.*correct/total, np.mean(time_taken))) ### (Epochs, Loss, Acc, time)
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch
        }
        os.makedirs('models', exist_ok=True)
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

    # Make sure the statistics directory exists, so a missing folder cannot
    # abort the run after an epoch has already been trained.
    os.makedirs('./output', exist_ok=True)
    with open(f"./output/{model_name}_data.json", 'w') as f:
        json.dump(STAT, f, indent=0)

In [40]:
# start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# resume = False

# if resume:
#     # Load checkpoint.
#     print('==> Resuming from checkpoint..')
#     assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
#     checkpoint = torch.load(f'./models/{model_name}.pth')
#     model.load_state_dict(checkpoint['model'])
#     best_acc = checkpoint['acc']
#     start_epoch = checkpoint['epoch']

In [41]:
# ### Train the whole damn thing

# for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
#     train(epoch)
#     test(epoch)
#     scheduler.step()

In [42]:
# best_acc

In [43]:
# checkpoint = torch.load(f'./models/{model_name}.pth')
# best_acc = checkpoint['acc']
# start_epoch = checkpoint['epoch']

# best_acc, start_epoch

In [44]:
# model.load_state_dict(checkpoint['model'])

In [45]:
# model

In [46]:
# STAT

In [47]:
# train_stat = np.array(STAT['train_stat'])
# test_stat = np.array(STAT['test_stat'])

In [48]:
# plt.plot(train_stat[:,1], label='train')
# plt.plot(test_stat[:,1], label='test')
# plt.ylabel("Loss")
# plt.legend()
# plt.savefig(f"./output/plots/{model_name}_loss.svg")
# plt.show()

In [49]:
# plt.plot(train_stat[:,2], label='train')
# plt.plot(test_stat[:,2], label='test')
# plt.ylabel("Accuracy")
# plt.legend()
# plt.savefig(f"./output/plots/{model_name}_accs.svg")
# plt.show()

## Benchmark Training

In [50]:
def get_data_loaders(seed, ds):
    """Seed PyTorch and build fresh loaders over the global datasets.

    Args:
        seed: value passed to ``torch.manual_seed`` before the loaders are built.
        ds: dataset tag; ``'c100'`` selects batch size 128, anything else 64.
            The tag only affects the batch size; the loaders always wrap the
            notebook-level ``train_dataset`` and ``test_dataset``.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    batch_size = 128 if ds == 'c100' else 64
    torch.manual_seed(seed)

    make_loader = torch.utils.data.DataLoader
    train_iter = make_loader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_iter = make_loader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_iter, test_iter

In [51]:
def benchmark():
    """Train and evaluate the three mixer variants back to back.

    For every (seed, depth) combination three models are trained for 200
    epochs each with Adam and a cosine-annealed learning rate:

        0. ``MlpMixer`` on 36x36 images, hidden expansion 3.0
        1. ``MlpMixer`` on 32x32 images, hidden expansion 3.2
        2. ``PatchMlpMixer`` with patch sizes 5 and 7

    ``train`` and ``test`` read their state from module-level names, so the
    model, optimizer, loaders, criterion, statistics and best accuracy are
    published as globals before each run. Always returns 0.
    """
    global model, optimizer, train_loader, test_loader, model_name, criterion, STAT, best_acc
    EPOCHS = 200
    learning_rate = 0.001
    DS = 'c100'  # any other tag (e.g. 'c10') selects 10 output classes
    criterion = nn.CrossEntropyLoss()

    for SEED in [369]:
        for num_layers in [10]:
            for variant in range(3): ## 3 models training
                print("Experiment index:", variant)
                # re-seeds torch, so each variant is built from the same RNG state
                train_loader, test_loader = get_data_loaders(SEED, DS)
                num_cls = 100 if DS=='c100' else 10

                if variant == 0:
                    ### original mixer, v0: upscaled image
                    model = MlpMixer((3, 4*9, 4*9), (4, 4), hidden_expansion=3.0, num_blocks=num_layers, num_classes=num_cls)
                    model_name = f'original_mixer0_l{num_layers}_{DS}_s{SEED}'
                elif variant == 1:
                    ### original mixer, v1: native image size, wider hidden channels
                    model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=3.2, num_blocks=num_layers, num_classes=num_cls)
                    model_name = f'original_mixer1_l{num_layers}_{DS}_s{SEED}'
                elif variant == 2:
                    ### patch-only mixer
                    model = PatchMlpMixer((3, 35, 35), patch_sizes=[5,7], hidden_channels=3, num_blocks=num_layers, num_classes=num_cls)
                    model_name = f'patchonly_mixer0_l{num_layers}_{DS}_s{SEED}'
                else:
                    print("JPT........!!!!")

                model = model.to(device)
                optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

                num_params = sum(p.numel() for p in model.parameters())

                # experiment-series prefix for checkpoint / statistics file names
                model_name = "03.0_" + model_name
                print(f"EXPERIMENTING FOR : {model_name} | params: {num_params}  .......\n.......")

                # fresh statistics and best accuracy for this run
                STAT ={'train_stat':[], 'test_stat':[], 'num_params':num_params}
                best_acc = -1
                for epoch in range(EPOCHS):
                    train(epoch)
                    test(epoch)
                    scheduler.step()
                print(f"Training finished\n")
    return 0

In [None]:
benchmark()

Experiment index: 0
MLP Mixer : Channes per patch -> Initial:48 Final:144
EXPERIMENTING FOR : 03.0_original_mixer0_l10_c100_s369 | params: 2277946  .......
.......


100%|███████████████████████████████████████████████████| 391/391 [00:13<00:00, 29.19it/s]


[Train] 0 Loss: 3.723 | Acc: 15.414 7707/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 53.85it/s]


[Test] 0 Loss: 3.171 | Acc: 24.090 2409/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.43it/s]


[Train] 1 Loss: 3.098 | Acc: 25.172 12586/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 54.19it/s]


[Test] 1 Loss: 2.862 | Acc: 30.060 3006/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.84it/s]


[Train] 2 Loss: 2.823 | Acc: 30.350 15175/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.43it/s]


[Test] 2 Loss: 2.615 | Acc: 34.840 3484/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.69it/s]


[Train] 3 Loss: 2.613 | Acc: 34.572 17286/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 53.97it/s]


[Test] 3 Loss: 2.514 | Acc: 37.100 3710/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.85it/s]


[Train] 4 Loss: 2.441 | Acc: 38.094 19047/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 62.27it/s]


[Test] 4 Loss: 2.445 | Acc: 38.760 3876/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.64it/s]


[Train] 5 Loss: 2.288 | Acc: 41.498 20749/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 53.21it/s]


[Test] 5 Loss: 2.285 | Acc: 42.750 4275/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.62it/s]


[Train] 6 Loss: 2.150 | Acc: 44.430 22215/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.95it/s]


[Test] 6 Loss: 2.214 | Acc: 44.620 4462/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.55it/s]


[Train] 7 Loss: 2.029 | Acc: 46.924 23462/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 61.65it/s]


[Test] 7 Loss: 2.136 | Acc: 45.980 4598/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.83it/s]


[Train] 8 Loss: 1.918 | Acc: 49.632 24816/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.68it/s]


[Test] 8 Loss: 2.150 | Acc: 46.940 4694/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.67it/s]


[Train] 9 Loss: 1.814 | Acc: 52.066 26033/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.13it/s]


[Test] 9 Loss: 2.146 | Acc: 47.650 4765/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.49it/s]


[Train] 10 Loss: 1.705 | Acc: 54.364 27182/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 53.07it/s]


[Test] 10 Loss: 2.132 | Acc: 48.690 4869/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.79it/s]


[Train] 11 Loss: 1.616 | Acc: 56.414 28207/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.61it/s]


[Test] 11 Loss: 2.137 | Acc: 48.730 4873/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.50it/s]


[Train] 12 Loss: 1.510 | Acc: 58.868 29434/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 53.19it/s]


[Test] 12 Loss: 2.158 | Acc: 49.920 4992/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.54it/s]


[Train] 13 Loss: 1.398 | Acc: 61.522 30761/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.03it/s]


[Test] 13 Loss: 2.177 | Acc: 50.170 5017/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.86it/s]


[Train] 14 Loss: 1.306 | Acc: 63.718 31859/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.88it/s]


[Test] 14 Loss: 2.258 | Acc: 50.450 5045/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.51it/s]


[Train] 15 Loss: 1.189 | Acc: 66.702 33351/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 53.27it/s]


[Test] 15 Loss: 2.305 | Acc: 50.580 5058/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.57it/s]


[Train] 16 Loss: 1.108 | Acc: 68.524 34262/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.09it/s]


[Test] 16 Loss: 2.388 | Acc: 50.720 5072/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.62it/s]


[Train] 17 Loss: 1.012 | Acc: 71.278 35639/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 58.18it/s]


[Test] 17 Loss: 2.484 | Acc: 50.030 5003/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.48it/s]


[Train] 18 Loss: 0.940 | Acc: 73.008 36504/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.09it/s]


[Test] 18 Loss: 2.584 | Acc: 50.100 5010/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.50it/s]


[Train] 19 Loss: 0.864 | Acc: 75.256 37628/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.37it/s]


[Test] 19 Loss: 2.688 | Acc: 51.080 5108/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.45it/s]


[Train] 20 Loss: 0.788 | Acc: 77.208 38604/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 60.31it/s]


[Test] 20 Loss: 2.803 | Acc: 51.080 5108/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.52it/s]


[Train] 21 Loss: 0.737 | Acc: 78.702 39351/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.42it/s]


[Test] 21 Loss: 2.826 | Acc: 51.920 5192/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.52it/s]


[Train] 22 Loss: 0.679 | Acc: 80.342 40171/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.44it/s]


[Test] 22 Loss: 3.011 | Acc: 50.310 5031/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.48it/s]


[Train] 23 Loss: 0.630 | Acc: 81.800 40900/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 55.90it/s]


[Test] 23 Loss: 3.167 | Acc: 51.280 5128/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.84it/s]


[Train] 24 Loss: 0.614 | Acc: 82.440 41220/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.13it/s]


[Test] 24 Loss: 3.239 | Acc: 51.350 5135/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.55it/s]


[Train] 25 Loss: 0.562 | Acc: 83.664 41832/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.47it/s]


[Test] 25 Loss: 3.423 | Acc: 51.460 5146/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.45it/s]


[Train] 26 Loss: 0.543 | Acc: 84.500 42250/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.80it/s]


[Test] 26 Loss: 3.479 | Acc: 51.040 5104/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.03it/s]


[Train] 27 Loss: 0.511 | Acc: 85.408 42704/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 54.55it/s]


[Test] 27 Loss: 3.667 | Acc: 51.060 5106/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.63it/s]


[Train] 28 Loss: 0.486 | Acc: 86.220 43110/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.00it/s]


[Test] 28 Loss: 3.637 | Acc: 51.470 5147/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.57it/s]


[Train] 29 Loss: 0.458 | Acc: 87.042 43521/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.07it/s]


[Test] 29 Loss: 3.754 | Acc: 51.470 5147/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.57it/s]


[Train] 30 Loss: 0.446 | Acc: 87.518 43759/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 62.73it/s]


[Test] 30 Loss: 3.895 | Acc: 51.980 5198/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.63it/s]


[Train] 31 Loss: 0.453 | Acc: 87.506 43753/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.51it/s]


[Test] 31 Loss: 3.970 | Acc: 51.420 5142/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.58it/s]


[Train] 32 Loss: 0.423 | Acc: 88.496 44248/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.29it/s]


[Test] 32 Loss: 4.215 | Acc: 51.340 5134/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.60it/s]


[Train] 33 Loss: 0.403 | Acc: 89.172 44586/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 55.40it/s]


[Test] 33 Loss: 4.253 | Acc: 50.880 5088/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.93it/s]


[Train] 34 Loss: 0.381 | Acc: 89.804 44902/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.90it/s]


[Test] 34 Loss: 4.270 | Acc: 51.340 5134/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.60it/s]


[Train] 35 Loss: 0.370 | Acc: 89.944 44972/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 53.79it/s]


[Test] 35 Loss: 4.528 | Acc: 51.350 5135/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.47it/s]


[Train] 36 Loss: 0.367 | Acc: 90.364 45182/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.82it/s]


[Test] 36 Loss: 4.527 | Acc: 50.910 5091/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.81it/s]


[Train] 37 Loss: 0.376 | Acc: 90.098 45049/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 55.94it/s]


[Test] 37 Loss: 4.576 | Acc: 51.360 5136/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.58it/s]


[Train] 38 Loss: 0.340 | Acc: 90.960 45480/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.05it/s]


[Test] 38 Loss: 4.733 | Acc: 50.650 5065/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.67it/s]


[Train] 39 Loss: 0.340 | Acc: 91.128 45564/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.92it/s]


[Test] 39 Loss: 4.750 | Acc: 51.540 5154/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.51it/s]


[Train] 40 Loss: 0.339 | Acc: 91.450 45725/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 61.72it/s]


[Test] 40 Loss: 4.923 | Acc: 51.070 5107/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.52it/s]


[Train] 41 Loss: 0.340 | Acc: 91.274 45637/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.39it/s]


[Test] 41 Loss: 5.097 | Acc: 51.210 5121/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.49it/s]


[Train] 42 Loss: 0.321 | Acc: 91.890 45945/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 53.34it/s]


[Test] 42 Loss: 5.022 | Acc: 51.550 5155/10000


100%|███████████████████████████████████████████████████| 391/391 [00:16<00:00, 24.33it/s]


[Train] 43 Loss: 0.311 | Acc: 92.132 46066/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 55.76it/s]


[Test] 43 Loss: 5.083 | Acc: 51.860 5186/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.77it/s]


[Train] 44 Loss: 0.296 | Acc: 92.648 46324/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.16it/s]


[Test] 44 Loss: 5.167 | Acc: 51.390 5139/10000


100%|███████████████████████████████████████████████████| 391/391 [00:16<00:00, 24.42it/s]


[Train] 45 Loss: 0.272 | Acc: 93.152 46576/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.93it/s]


[Test] 45 Loss: 5.426 | Acc: 51.410 5141/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.54it/s]


[Train] 46 Loss: 0.293 | Acc: 92.710 46355/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.24it/s]


[Test] 46 Loss: 5.496 | Acc: 51.660 5166/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.03it/s]


[Train] 47 Loss: 0.280 | Acc: 93.030 46515/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.60it/s]


[Test] 47 Loss: 5.515 | Acc: 51.040 5104/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.54it/s]


[Train] 48 Loss: 0.294 | Acc: 92.920 46460/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.18it/s]


[Test] 48 Loss: 5.828 | Acc: 50.670 5067/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.57it/s]


[Train] 49 Loss: 0.277 | Acc: 93.394 46697/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 53.19it/s]


[Test] 49 Loss: 5.612 | Acc: 52.450 5245/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.71it/s]


[Train] 50 Loss: 0.266 | Acc: 93.712 46856/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 58.81it/s]


[Test] 50 Loss: 5.855 | Acc: 51.180 5118/10000


100%|███████████████████████████████████████████████████| 391/391 [00:16<00:00, 24.42it/s]


[Train] 51 Loss: 0.258 | Acc: 93.870 46935/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.46it/s]


[Test] 51 Loss: 5.906 | Acc: 51.900 5190/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.45it/s]


[Train] 52 Loss: 0.259 | Acc: 93.892 46946/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.11it/s]


[Test] 52 Loss: 6.043 | Acc: 51.040 5104/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.58it/s]


[Train] 53 Loss: 0.258 | Acc: 93.948 46974/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 59.48it/s]


[Test] 53 Loss: 6.015 | Acc: 51.630 5163/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.76it/s]


[Train] 54 Loss: 0.238 | Acc: 94.372 47186/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.27it/s]


[Test] 54 Loss: 5.978 | Acc: 51.670 5167/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.72it/s]


[Train] 55 Loss: 0.232 | Acc: 94.710 47355/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.26it/s]


[Test] 55 Loss: 6.175 | Acc: 51.050 5105/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.52it/s]


[Train] 56 Loss: 0.242 | Acc: 94.416 47208/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.87it/s]


[Test] 56 Loss: 6.257 | Acc: 50.730 5073/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 25.03it/s]


[Train] 57 Loss: 0.215 | Acc: 95.160 47580/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.85it/s]


[Test] 57 Loss: 6.384 | Acc: 51.250 5125/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.61it/s]


[Train] 58 Loss: 0.232 | Acc: 94.770 47385/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.69it/s]


[Test] 58 Loss: 6.421 | Acc: 51.120 5112/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.62it/s]


[Train] 59 Loss: 0.218 | Acc: 95.044 47522/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.33it/s]


[Test] 59 Loss: 6.648 | Acc: 51.540 5154/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.86it/s]


[Train] 60 Loss: 0.205 | Acc: 95.294 47647/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 59.72it/s]


[Test] 60 Loss: 6.496 | Acc: 52.340 5234/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.47it/s]


[Train] 61 Loss: 0.207 | Acc: 95.354 47677/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.56it/s]


[Test] 61 Loss: 6.761 | Acc: 51.530 5153/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.61it/s]


[Train] 62 Loss: 0.205 | Acc: 95.402 47701/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.93it/s]


[Test] 62 Loss: 6.571 | Acc: 52.050 5205/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.51it/s]


[Train] 63 Loss: 0.185 | Acc: 95.816 47908/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 58.23it/s]


[Test] 63 Loss: 6.759 | Acc: 52.020 5202/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.72it/s]


[Train] 64 Loss: 0.197 | Acc: 95.650 47825/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.47it/s]


[Test] 64 Loss: 6.890 | Acc: 51.700 5170/10000


100%|███████████████████████████████████████████████████| 391/391 [00:16<00:00, 24.30it/s]


[Train] 65 Loss: 0.208 | Acc: 95.436 47718/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.50it/s]


[Test] 65 Loss: 6.994 | Acc: 51.530 5153/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.57it/s]


[Train] 66 Loss: 0.197 | Acc: 95.748 47874/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.43it/s]


[Test] 66 Loss: 6.944 | Acc: 52.320 5232/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.94it/s]


[Train] 67 Loss: 0.184 | Acc: 95.960 47980/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.79it/s]


[Test] 67 Loss: 7.042 | Acc: 52.250 5225/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.56it/s]


[Train] 68 Loss: 0.181 | Acc: 96.076 48038/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 53.41it/s]


[Test] 68 Loss: 7.237 | Acc: 51.610 5161/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.53it/s]


[Train] 69 Loss: 0.177 | Acc: 96.124 48062/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 52.30it/s]


[Test] 69 Loss: 7.173 | Acc: 52.200 5220/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.81it/s]


[Train] 70 Loss: 0.183 | Acc: 96.252 48126/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 58.46it/s]


[Test] 70 Loss: 7.320 | Acc: 51.950 5195/10000


100%|███████████████████████████████████████████████████| 391/391 [00:15<00:00, 24.67it/s]


[Train] 71 Loss: 0.166 | Acc: 96.424 48212/50000


100%|█████████████████████████████████████████████████████| 79/79 [00:01<00:00, 51.31it/s]


[Test] 71 Loss: 7.163 | Acc: 53.030 5303/10000
Saving..


 39%|████████████████████                               | 154/391 [00:06<00:09, 24.17it/s]

In [None]:
# NOTE(review): `asdfasdfwdf` is not defined in the visible part of the notebook, so
# executing this cell raises NameError. Presumably a deliberate "stop here" guard so
# that Run All halts before the experimental cells below -- confirm, and consider an
# explicit `raise` with a message instead.
asdfasdfwdf

# Models -- New Test 
1. The original MLP-Mixer scales differently, so it is not useful for benchmarking.
2. MLP-Mixer uses channel-wise mixing; here the channels are the number of patches, which changes differently from the patch size when scaling.

In [None]:
class MlpBLock(nn.Module):
    """Feed-forward block whose hidden widths are multiples of the input width.

    The block maps ``input_dim -> r_1*input_dim -> ... -> r_k*input_dim -> input_dim``
    with an activation after every linear layer except the last one, so the
    output has the same trailing dimension as the input.

    Args:
        input_dim: Size of the last dimension of the tensors given to ``forward``.
        hidden_layers_ratio: Width multiplier(s) of the hidden layers. A bare
            number (int or float) means a single hidden layer; any sequence
            (list, tuple, ...) gives one hidden layer per entry. Each width is
            ``int(ratio * input_dim)``.
        actf: Zero-argument callable returning an activation module; it is
            called once per hidden layer.
    """
    
    def __init__(self, input_dim, hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        #### a bare number (int or float ratio) stands for a single hidden layer
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]
            
        ### list(...) copies the argument, so the shared default list is never
        ### mutated and tuples / other sequences are accepted as well
        self.hlr = [1] + list(hidden_layers_ratio) + [1]
        
        layers = []
        ### k hidden layers need k+1 linear maps: pair every ratio with its successor
        for r_in, r_out in zip(self.hlr[:-1], self.hlr[1:]):
            layers.append(nn.Linear(int(r_in*self.input_dim), int(r_out*self.input_dim)))
            layers.append(actf())
        
        ### drop the trailing activation so the block ends with a linear layer
        self.mlp = nn.Sequential(*layers[:-1])
        
    def forward(self, x):
        """Apply the MLP over the last dimension of ``x`` and return the result."""
        return self.mlp(x)

In [None]:
# Smoke test: show the layer layout for input_dim=2 with hidden ratios 3 and 4,
# i.e. linear layers 2 -> 6 -> 8 -> 2.
MlpBLock(2, [3,4])

## MLP-Mixer 

In [None]:
class MixerBlock(nn.Module):
    """One mixer layer: a patch-mixing MLP followed by a channel-mixing MLP.

    Each sub-layer is preceded by a LayerNorm over the last dimension
    (``channel_dim``) and wrapped in a residual connection.

    Args:
        patch_dim: Size of the second-to-last input dimension (number of patches).
        channel_dim: Size of the last input dimension (features per patch).
    """

    def __init__(self, patch_dim, channel_dim):
        super().__init__()

        ### both norms act on the last axis, i.e. on the channel features
        self.ln0 = nn.LayerNorm(channel_dim)
        self.mlp_patch = MlpBLock(patch_dim, [2])
        self.ln1 = nn.LayerNorm(channel_dim)
        self.mlp_channel = MlpBLock(channel_dim, [2])

    def forward(self, x):
        """Mix across patches, then across channels; the output shape equals the input shape.

        ``x`` is expected as (N, num_patches, num_channels).
        """
        #### patch mixing: move the patch axis last so the MLP acts on it, then move it back
        across_patches = self.ln0(x).transpose(-1, -2)
        across_patches = self.mlp_patch(across_patches).transpose(-1, -2)
        x = x + across_patches

        #### channel mixing: channels are already the last axis
        return x + self.mlp_channel(self.ln1(x))

In [None]:
class MlpMixer(nn.Module):
    """MLP-Mixer style classifier working on non-overlapping image patches.

    Pipeline of ``forward``: bilinear resize to the hidden image size, change
    the number of channels, cut the image into patches, run ``num_blocks``
    ``MixerBlock`` layers on the (patch, channel) matrix, flatten and classify
    with a single linear layer.

    Args:
        input_channels: Number of channels of the incoming images.
        hidden_image_dim: ``(C, H, W)`` of the internal image. ``C`` (index 0)
            is the number of channels after the channel change; ``H`` and ``W``
            must be multiples of the patch size.
        patch_size: ``(patch_height, patch_width)``.
        num_blocks: Number of stacked ``MixerBlock`` layers.
        num_classes: Number of output logits.
        use1x1scale: If True, the channel count is changed with a 1x1
            convolution before patch extraction. If False, a linear layer
            maps every flattened patch to the hidden width after extraction.

    Raises:
        AssertionError: If ``H`` or ``W`` is not divisible by the patch size.
    """
    
    def __init__(self, input_channels:int, hidden_image_dim:tuple, patch_size:tuple, num_blocks:int, num_classes:int, use1x1scale=True):
        super().__init__()
        
        self.img_dim = hidden_image_dim ### (C, H, W); index 0 is used as the hidden channel count
        self.use1x1scale = use1x1scale ### forward() must take the same branch as the constructor
        
        ### number of patches along each spatial axis
        d0 = self.img_dim[-2]//patch_size[0]
        d1 = self.img_dim[-1]//patch_size[1]
        assert d0*patch_size[0]==self.img_dim[-2], "Image must be divisible into patch size"
        assert d1*patch_size[1]==self.img_dim[-1], "Image must be divisible into patch size"
        
        num_patches = d0*d1
        
        ### flattened patch length before / after the channel change
        init_dim = patch_size[0]*patch_size[1]*input_channels
        final_dim = patch_size[0]*patch_size[1]*self.img_dim[0]

        self.scaler = nn.UpsamplingBilinear2d(size=(self.img_dim[-2], self.img_dim[-1]))
        
        if use1x1scale:
            ### The conv is built even when it is replaced right away.
            ### NOTE(review): building it presumably draws from the global RNG, so
            ### skipping it could change the seeded initial weights of the layers
            ### created afterwards -- confirm before simplifying.
            self.conv1x1 = nn.Conv2d(input_channels, self.img_dim[0], kernel_size=1, stride=1)
            if input_channels == self.img_dim[0]:
                self.conv1x1 = nn.Identity()
        else:
            #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
            self.channel_change = nn.Linear(init_dim, final_dim) ## applied after unfold
            
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

        print(f"MLP Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")
        
        ### unfold output is transposed in forward(): patches become axis -2, features axis -1
        self.channel_dim = final_dim
        self.patch_dim = num_patches
        
        blocks = [MixerBlock(self.patch_dim, self.channel_dim) for _ in range(num_blocks)]
        self.mixer_blocks = nn.Sequential(*blocks)
        
        self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)
        
        
    def forward(self, x):
        """Return class logits of shape (N, num_classes) for an image batch (N, input_channels, h, w)."""
        bs = x.shape[0]
        x = self.scaler(x) ## resize to the hidden image size
        if self.use1x1scale:
            x = self.conv1x1(x) ## change the channel count per pixel
        x = self.unfold(x).swapaxes(-1, -2).contiguous() ## (N, num_patches, features per patch)
        if not self.use1x1scale:
            ### without the 1x1 conv the patches still carry init_dim features;
            ### bring them to the width the mixer blocks and classifier expect
            x = self.channel_change(x)
        x = self.mixer_blocks(x) ## the mixer architecture
        x = self.linear(x.reshape(bs, -1)) ## classify
        return x

In [None]:
# Small MlpMixer for a shape check: 3 input channels, hidden image (C=5, H=35, W=35),
# 5x5 patches (7*7 = 49 patches) and a single mixer block.
# use1x1scale=False makes the constructor create `channel_change` instead of `conv1x1`.
mixer = MlpMixer(3, (5, 35, 35), (5, 5), num_blocks=1, num_classes=10, use1x1scale=False)
mixer

In [None]:
# Forward one random 3x32x32 image; the model's `scaler` first resizes it to 35x35.
mixer(torch.randn(1, 3, 32, 32))

## Patch Mixer

In [None]:
class PatchMixerBlock(nn.Module):
    """Residual MLP applied independently to every non-overlapping image patch.

    The image is cut into ``patch_size`` x ``patch_size`` tiles; each tile,
    flattened together with all its channels, is layer-normalised and passed
    through an ``MlpBLock``. The tiles are then put back in place and added to
    the input.

    Args:
        patch_size: Tile side length as an int, or a ``(height, width)`` pair.
        num_channel: Number of channels of the input image.
    """

    def __init__(self, patch_size, num_channel):
        super().__init__()
        self.patch_size = patch_size

        ### flattened length of one tile: tile height * tile width * channels
        if isinstance(patch_size, int):
            tile_area = patch_size**2
        else:
            tile_area = patch_size[0]*patch_size[1]
        tile_features = tile_area*num_channel

        self.ln0 = nn.LayerNorm(tile_features)
        self.mlp_patch = MlpBLock(tile_features, [2])

    def forward(self, x):
        """Return ``x`` plus the per-tile MLP output; ``x`` has shape (N, C, H, W)."""
        height, width = x.shape[-2], x.shape[-1]
        ### the same window is used to cut the tiles out and to put them back
        window = dict(kernel_size=self.patch_size, stride=self.patch_size)

        #### unfold gives (N, tile_features, num_tiles); the MLP needs features last
        tiles = nn.functional.unfold(x, **window).transpose(-1, -2)
        tiles = self.mlp_patch(self.ln0(tiles))

        #### back to (N, tile_features, num_tiles) and reassemble the image
        update = nn.functional.fold(tiles.transpose(-1, -2), (height, width), **window)
        return x + update

In [None]:
# Patch size 7 with 5 channels: the LayerNorm / MLP width is 7*7*5 = 245.
pmb = PatchMixerBlock(7, 5)
pmb

In [None]:
# Shape check on a 5-channel 35x35 input (35 is a multiple of the patch size 7).
pmb(torch.randn(1, 5, 35, 35)).shape

In [None]:
def get_factors(n):
    """Return every divisor of ``n`` from 2 up to ``n`` itself, in ascending order.

    The trivial divisor 1 is not included, so the result is empty for ``n < 2``.
    """
    return [candidate for candidate in range(2, n+1) if n%candidate == 0]

class PatchMlpMixer(nn.Module):
    """Classifier built from ``PatchMixerBlock`` layers with varying patch sizes.

    Pipeline of ``forward``: bilinear resize to the hidden image size, 1x1
    convolution to the hidden channel count, ``num_blocks`` rounds of one
    ``PatchMixerBlock`` per entry of ``patch_sizes`` (in the given order),
    flatten, and a single linear classification layer.

    Args:
        input_channels: Number of channels of the incoming images.
        hidden_image_dim: ``(C, H, W)`` of the internal image.
        patch_sizes: Patch side lengths; their product must equal both ``H`` and ``W``.
        num_blocks: How many times the whole sequence of patch sizes is repeated.
        num_classes: Number of output logits.

    Raises:
        AssertionError: If the product of ``patch_sizes`` differs from ``H`` or ``W``.
    """
    
    def __init__(self, input_channels:int, hidden_image_dim:tuple, patch_sizes:list, num_blocks:int, num_classes:int):
        super().__init__()
        
        self.img_dim = hidden_image_dim ### must contain (C, H, W)
        self.scaler = nn.UpsamplingBilinear2d(size=(self.img_dim[-2], self.img_dim[-1]))
        
        assert np.prod(patch_sizes) == hidden_image_dim[-1], "The product of patches must equal image dimension"
        assert np.prod(patch_sizes) == hidden_image_dim[-2], "The product of patches must equal image dimension"
        
        ### 1x1 conv brings the input to the hidden channel count. It is built even
        ### when it is replaced by Identity right away.
        ### NOTE(review): building it presumably draws from the global RNG, so
        ### skipping it could change the seeded initial weights of the layers
        ### created afterwards -- confirm before simplifying.
        self.conv1x1 = nn.Conv2d(input_channels, self.img_dim[0], kernel_size=1, stride=1)
        if input_channels == self.img_dim[0]:
            self.conv1x1 = nn.Identity()
        
        ### one block per patch size, the whole sequence repeated num_blocks times
        blocks = [
            PatchMixerBlock(ps, self.img_dim[0])
            for _ in range(num_blocks)
            for ps in patch_sizes
        ]
        self.mixer_blocks = nn.Sequential(*blocks)
        
        ### np.prod returns a numpy integer; cast it so in_features is a plain Python int
        self.linear = nn.Linear(int(np.prod(self.img_dim)), num_classes)
    
    def forward(self, x):
        """Return class logits of shape (N, num_classes) for an image batch (N, input_channels, h, w)."""
        bs = x.shape[0]
        x = self.scaler(x) ## resize to the hidden image size
        x = self.conv1x1(x) ## change the channel count
        x = self.mixer_blocks(x) ## image shape is preserved by every block
        x = self.linear(x.reshape(bs, -1)) ## classify
        return x

In [None]:
# Patch sizes 5 and 7 multiply to 35, matching the 35x35 hidden image as the constructor asserts.
patch_mixer = PatchMlpMixer(3, (5, 35, 35), patch_sizes=[5, 7], num_blocks=1, num_classes=10)

In [None]:
# Display the module tree of the small PatchMlpMixer.
patch_mixer

In [None]:
# Shape check: the 32x32 input is resized to 35x35 inside the model; one logit row
# with num_classes=10 entries is expected for the single image.
patch_mixer(torch.randn(1, 3, 32, 32)).shape

#### Final Model

In [None]:
# MlpMixer candidate: 9 hidden channels on a 32x32 image, 4x4 patches (8*8 = 64 patches), 10 blocks.
# NOTE(review): num_classes=10 here, while the active transforms earlier in the notebook use
# the normalisation statistics annotated as CIFAR-100 -- confirm the intended dataset.
model = MlpMixer(3, (9, 32, 32), patch_size=(4, 4), num_blocks=10, num_classes=10)
model = model.to(device)

In [None]:
# Total number of parameter elements of the current `model`.
print("number of params: ", sum(p.numel() for p in model.parameters())) 

In [None]:
# Display the module tree of the current `model`.
model

In [None]:
# PatchMlpMixer candidate: 3 hidden channels (equal to the input channels, so `conv1x1`
# becomes Identity) on a 35x35 image, patch sizes 5 and 7, 10 rounds.
# NOTE: this rebinds `model`; any model previously stored under that name is replaced.
model = PatchMlpMixer(3, (3, 35, 35), patch_sizes=[5, 7], num_blocks=10, num_classes=10)
model = model.to(device)

In [None]:
# Total number of parameter elements of the current `model`.
print("number of params: ", sum(p.numel() for p in model.parameters())) 

In [None]:
# Display the module tree of the current `model`.
model

In [None]:
# Total number of parameter elements of whichever model `model` currently refers to.
print("number of params: ", sum(p.numel() for p in model.parameters())) 
## Parameter counts noted by hand:
## Patch ||  1137220   (equals the count of PatchMlpMixer(3, (3, 35, 35), [5, 7], num_blocks=10, num_classes=10))
## Mixer ||  1141703   (NOTE(review): presumably from an MlpMixer run -- confirm which configuration produced it)