In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy, json
from tqdm import tqdm

In [2]:
# Use the first CUDA device when one is present and fall back to the CPU
# otherwise, so the notebook also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Alternative seeds, kept commented out:
# SEED = 147
# SEED = 258
SEED = 369

# Seed every random number generator this notebook imports (Python's
# `random`, PyTorch and NumPy) so that all stochastic steps are repeatable.
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
import torch.optim as optim
from torch.utils import data

# Model

In [5]:
class MlpBLock(nn.Module):
    """Feed-forward block mapping ``input_dim`` features back to ``input_dim``.

    Hidden layer widths are given as multiples of ``input_dim``. Every linear
    layer except the last one is followed by an activation.

    Args:
        input_dim: number of input (and output) features.
        hidden_layers_ratio: width multiplier(s) of the hidden layers. Either a
            single number (one hidden layer) or any sequence of numbers, e.g.
            a list or a tuple. Each width is ``int(ratio * input_dim)``.
        actf: zero-argument callable returning the activation module.
    """

    def __init__(self, input_dim, hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        #### a bare number stands for a single hidden layer
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]

        ### unpacking builds a fresh list, so tuples are accepted and the
        ### (shared) default argument is never mutated
        self.hlr = [1, *hidden_layers_ratio, 1]

        widths = [int(ratio*self.input_dim) for ratio in self.hlr]

        layers = []
        ### for 1 hidden layer there are 2 linear layers
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(actf())

        ### no activation after the output layer
        self.mlp = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Apply the MLP over the last axis of ``x``."""
        return self.mlp(x)

In [6]:
MlpBLock(2, [3,4])

MlpBLock(
  (mlp): Sequential(
    (0): Linear(in_features=2, out_features=6, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=6, out_features=8, bias=True)
    (3): GELU(approximate='none')
    (4): Linear(in_features=8, out_features=2, bias=True)
  )
)

## MLP-Mixer 

In [7]:
class MixerBlock(nn.Module):
    """One Mixer layer: a patch-mixing MLP followed by a channel-mixing MLP.

    Each sub-layer is preceded by a LayerNorm over the channel axis and is
    added back to its input (residual connection).

    Args:
        patch_dim: number of patches (size of the second-to-last axis).
        channel_dim: number of channels per patch (size of the last axis).
    """

    def __init__(self, patch_dim, channel_dim):
        super().__init__()

        ## patch-mixing branch
        self.ln0 = nn.LayerNorm(channel_dim)
        self.mlp_patch = MlpBLock(patch_dim, [2])

        ## channel-mixing branch
        self.ln1 = nn.LayerNorm(channel_dim)
        self.mlp_channel = MlpBLock(channel_dim, [2])

    def forward(self, x):
        """Mix ``x`` of shape (N, patches, channels); the shape is preserved."""
        ## patch mixing: normalise over channels, then move the patch axis
        ## last so the MLP acts across patches, and move it back afterwards
        patch_mixed = self.mlp_patch(self.ln0(x).swapaxes(-1, -2)).swapaxes(-1, -2)
        x = x + patch_mixed

        ## channel mixing: the MLP acts directly on the last (channel) axis
        return x + self.mlp_channel(self.ln1(x))

In [8]:
class MlpMixer(nn.Module):
    """MLP-Mixer image classifier built from ``MixerBlock`` layers.

    The input is resized to the configured resolution, cut into
    non-overlapping patches, linearly projected per patch, passed through the
    mixer blocks and finally classified from the flattened features.

    Args:
        image_dim: ``(C, H, W)`` or ``(H, W)``. The two-element form is
            treated as a single-channel image.
        patch_size: ``(height, width)`` of one patch; must divide H and W.
        hidden_expansion: channels per patch after the input projection,
            as a multiple of the flattened patch size ``C * ph * pw``.
        num_blocks: number of ``MixerBlock`` layers.
        num_classes: size of the output logits.

    Raises:
        AssertionError: if the image is not divisible into whole patches.
    """

    def __init__(self, image_dim:tuple, patch_size:tuple, hidden_expansion:float, num_blocks:int, num_classes:int):
        super().__init__()

        self.img_dim = image_dim ### must contain (C, H, W) or (H, W)
        height, width = image_dim[-2], image_dim[-1]
        ### only a 3-element image_dim carries a channel count; for (H, W)
        ### the first entry is the height, so assume one channel instead
        in_channels = image_dim[0] if len(image_dim) > 2 else 1

        self.scaler = nn.UpsamplingBilinear2d(size=(height, width))

        ### number of patches along each axis
        d0 = int(height/patch_size[0])
        d1 = int(width/patch_size[1])
        assert d0*patch_size[0]==height, "Image must be divisible into patch size"
        assert d1*patch_size[1]==width, "Image must be divisible into patch size"

        ### values per flattened patch, before and after the projection
        init_dim = patch_size[0]*patch_size[1]*in_channels
        final_dim = int(init_dim*hidden_expansion)

        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
        self.channel_change = nn.Linear(init_dim, final_dim)
        print(f"MLP Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")

        self.channel_dim = final_dim
        self.patch_dim = d0*d1 ## number of patches

        self.mixer_blocks = nn.Sequential(
            *[MixerBlock(self.patch_dim, self.channel_dim) for _ in range(num_blocks)]
        )

        self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for a batch (N, C, H', W')."""
        bs = x.shape[0]
        x = self.scaler(x)                      ## resize to the configured (H, W)
        x = self.unfold(x).swapaxes(-1, -2)     ## (N, patches, values per patch)
        x = self.channel_change(x)
        x = self.mixer_blocks(x)
        ## reshape (rather than view) also copes with non-contiguous features
        return self.linear(x.reshape(bs, -1))

In [9]:
mixer = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.5, num_blocks=1, num_classes=10)
mixer

MLP Mixer : Channes per patch -> Initial:48 Final:120


MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(32, 32), mode='bilinear')
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=120, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((120,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=64, out_features=128, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=128, out_features=64, bias=True)
        )
      )
      (ln1): LayerNorm((120,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=120, out_features=240, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=240, out_features=120, bias=True)
        )
      )
    )
  )
  (linear): Linear(in_features=7680, out_features=10, bias=True)
)

In [10]:
print("number of params: ", sum(p.numel() for p in mixer.parameters())) 

number of params:  157706


In [11]:
mixer(torch.randn(1, 3, 32, 32))

tensor([[-0.1648, -0.0510, -0.0486, -0.1505, -0.3137, -0.5226,  0.0023,  0.5233,
         -0.2587, -0.5398]], grad_fn=<AddmmBackward0>)

## Patch Mixer

In [12]:
class PatchMixerBlock(nn.Module):
    """Residual MLP applied independently to each patch of an image tensor.

    The input is split into patches with stride equal to the patch size, each
    flattened patch is layer-normalised and passed through an ``MlpBLock``,
    and the result is folded back and added to the input.

    Args:
        patch_size: side length (int) or ``(height, width)`` of one patch.
        num_channel: number of channels of the input tensor.
    """

    def __init__(self, patch_size, num_channel):
        super().__init__()
        self.patch_size = patch_size

        ## number of values in one flattened patch: pixels * channels
        if isinstance(patch_size, int):
            pixels_per_patch = patch_size**2
        else:
            pixels_per_patch = patch_size[0]*patch_size[1]
        patch_features = pixels_per_patch*num_channel

        self.ln0 = nn.LayerNorm(patch_features)
        self.mlp_patch = MlpBLock(patch_features, [2])

    def forward(self, x):
        """Mix within patches; ``x`` is (N, C, H, W) and the shape is preserved."""
        height, width = x.shape[-2], x.shape[-1]
        window = dict(kernel_size=self.patch_size, stride=self.patch_size)

        ## (N, values per patch, patches) -> (N, patches, values per patch)
        patches = nn.functional.unfold(x, **window).swapaxes(-1, -2)

        #### mix per patch, then restore the layout expected by fold
        mixed = self.mlp_patch(self.ln0(patches)).swapaxes(-1, -2)

        residual = nn.functional.fold(mixed, (height, width), **window)
        return x + residual

In [13]:
pmb = PatchMixerBlock(8, 3)
pmb

PatchMixerBlock(
  (ln0): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  (mlp_patch): MlpBLock(
    (mlp): Sequential(
      (0): Linear(in_features=192, out_features=384, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=384, out_features=192, bias=True)
    )
  )
)

In [14]:
# pmb(torch.randn(1, 3, 35, 35)).shape

In [15]:
def get_factors(n):
    """Return all divisors of ``n`` that are greater than 1, in ascending order."""
    return [candidate for candidate in range(2, n+1) if n % candidate == 0]

class PatchMlpMixer(nn.Module):
    """Classifier that stacks ``PatchMixerBlock`` layers of several patch sizes.

    Inputs are resized to a square whose side is the product of all patch
    sizes, optionally projected to ``hidden_channels`` with a 1x1 convolution,
    run through the mixer blocks and classified from the flattened result.

    Args:
        image_dim: ``(C, H, W)``; only the channel count ``C`` is read here.
        patch_sizes: patch sizes used, in this order, within every block group.
        hidden_channels: number of channels seen by the mixer blocks.
        num_blocks: how many times the full list of patch sizes is repeated.
        num_classes: size of the output logits.
    """

    def __init__(self, image_dim:tuple, patch_sizes:tuple, hidden_channels:int, num_blocks:int, num_classes:int):
        super().__init__()

        self.img_dim = image_dim ### must contain (C, H, W)
        ### side length of the square working resolution
        self.target_dim = np.prod(patch_sizes)

        in_channels = image_dim[0]

        ### the 1x1 convolution is always constructed; it is replaced by an
        ### Identity when the channel count already matches
        projection = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, stride=1)
        self.conv1x1 = nn.Identity() if in_channels == hidden_channels else projection

        ### one block per patch size, the whole sequence repeated num_blocks times
        self.mixer_blocks = nn.Sequential(*[
            PatchMixerBlock(ps, hidden_channels)
            for _ in range(num_blocks)
            for ps in patch_sizes
        ])
        self.linear = nn.Linear(self.target_dim*self.target_dim*hidden_channels, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for a batch (N, C, H', W')."""
        bs = x.shape[0]

        ### any input resolution is accepted: it is resized to target_dim x target_dim
        resized = nn.functional.interpolate(x, size=self.target_dim, mode='bilinear', align_corners=True)

        features = self.mixer_blocks(self.conv1x1(resized))
        return self.linear(features.view(bs, -1))

In [16]:
4*3*5

60

In [17]:
patch_mixer = PatchMlpMixer((3, 35, 35), patch_sizes=[5, 7], hidden_channels=3, num_blocks=1, num_classes=10)

In [18]:
patch_mixer

PatchMlpMixer(
  (conv1x1): Identity()
  (mixer_blocks): Sequential(
    (0): PatchMixerBlock(
      (ln0): LayerNorm((75,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=75, out_features=150, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=150, out_features=75, bias=True)
        )
      )
    )
    (1): PatchMixerBlock(
      (ln0): LayerNorm((147,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=147, out_features=294, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=294, out_features=147, bias=True)
        )
      )
    )
  )
  (linear): Linear(in_features=3675, out_features=10, bias=True)
)

In [19]:
print("number of params: ", sum(p.numel() for p in patch_mixer.parameters())) 

number of params:  146806


In [20]:
# Smoke test with a 32x32 input. PatchMlpMixer.forward interpolates every
# input to target_dim x target_dim (the product of patch_sizes, i.e. 35x35
# for patch_sizes=[5, 7]), so the input size need not match image_dim.
patch_mixer(torch.randn(1, 3, 32, 32)).shape

torch.Size([1, 10])

#### Final Model

In [21]:
from ptflops import get_model_complexity_info

# In this cell SEED only appears in the printed model names; note that this
# rebinds the global SEED.
SEED = -1

# Report MACs and parameter counts of the three architectures for every
# (number of classes, number of layers) combination.
for num_cls in [10, 100]:
    for num_layers in [7, 10]:
        # (name prefix, zero-argument builder); each model is only built when
        # its turn comes, in the order listed here.
        variants = (
            ## hard core ignore
            ('original_mixer0',
             lambda: MlpMixer((3, 4*9, 4*9), (4, 4), hidden_expansion=3.0, num_blocks=num_layers, num_classes=num_cls)),
            ### FOR ORIGINAL MIXER V1
            ('original_mixer1',
             lambda: MlpMixer((3, 32, 32), (4, 4), hidden_expansion=3.2, num_blocks=num_layers, num_classes=num_cls)),
            ('patchonly_mixer0',
             lambda: PatchMlpMixer((3, 35, 35), patch_sizes=[5,7], hidden_channels=3, num_blocks=num_layers, num_classes=num_cls)),
        )
        for i, (prefix, build_model) in enumerate(variants):
            model = build_model()
            model_name = f'{prefix}_l{num_layers}_{num_cls}_s{SEED}'

            # NOTE(review): ptflops documents ignore_modules as a list of module
            # *types*; the string 'channel_change' presumably matches nothing.
            # The 60.48 MMac printed for original_mixer1 equals the total of the
            # per-layer report that includes channel_change. Confirm the intent.
            macs, params = get_model_complexity_info(
                model,
                (3, 32, 32),
                as_strings=True,
                ignore_modules=['channel_change'],
                print_per_layer_stat=False,
                verbose=False,
            )

            print(model_name)
            for label, value in (('Computational complexity: ', macs),
                                 ('Number of parameters: ', params)):
                print('{:<30}  {:<8}'.format(label, value))
            print('')

MLP Mixer : Channes per patch -> Initial:48 Final:144
original_mixer0_l7_10_s-1
Computational complexity:       74.65 MMac
Number of parameters:           896.78 k

MLP Mixer : Channes per patch -> Initial:48 Final:153
original_mixer1_l7_10_s-1
Computational complexity:       60.48 MMac
Number of parameters:           884.41 k

patchonly_mixer0_l7_10_s-1
Computational complexity:       23.04 MMac
Number of parameters:           807.08 k

MLP Mixer : Channes per patch -> Initial:48 Final:144
original_mixer0_l10_10_s-1
Computational complexity:       106.36 MMac
Number of parameters:           1.23 M  

MLP Mixer : Channes per patch -> Initial:48 Final:153
original_mixer1_l10_10_s-1
Computational complexity:       86.16 MMac
Number of parameters:           1.22 M  

patchonly_mixer0_l10_10_s-1
Computational complexity:       32.9 MMac
Number of parameters:           1.14 M  

MLP Mixer : Channes per patch -> Initial:48 Final:144
original_mixer0_l7_100_s-1
Computational complexity:       

In [22]:
'''
original_mixer0_l7_10_s-1
Computational complexity:       74.65 MMac
Number of parameters:           896.78 k
channel_change: 560.02 KMac

original_mixer1_l7_10_s-1
Computational complexity:       60.48 MMac
Number of parameters:           884.41 k
channel_change: 470.17 KMac

patchonly_mixer0_l7_10_s-1
Computational complexity:       23.04 MMac
Number of parameters:           807.08 k

original_mixer0_l10_10_s-1
Computational complexity:       106.36 MMac
Number of parameters:           1.23 M  

original_mixer1_l10_10_s-1
Computational complexity:       86.16 MMac
Number of parameters:           1.22 M  

patchonly_mixer0_l10_10_s-1
Computational complexity:       32.9 MMac
Number of parameters:           1.14 M  

original_mixer0_l7_100_s-1
Computational complexity:       75.7 MMac
Number of parameters:           1.95 M  

original_mixer1_l7_100_s-1
Computational complexity:       61.36 MMac
Number of parameters:           1.77 M  

patchonly_mixer0_l7_100_s-1
Computational complexity:       23.37 MMac
Number of parameters:           1.14 M  

original_mixer0_l10_100_s-1
Computational complexity:       107.41 MMac
Number of parameters:           2.28 M  

original_mixer1_l10_100_s-1
Computational complexity:       87.04 MMac
Number of parameters:           2.1 M   

patchonly_mixer0_l10_100_s-1
Computational complexity:       33.23 MMac
Number of parameters:           1.47 M  
'''

'\noriginal_mixer0_l7_10_s-1\nComputational complexity:       74.65 MMac\nNumber of parameters:           896.78 k\nchannel_change: 560.02 KMac\n\noriginal_mixer1_l7_10_s-1\nComputational complexity:       60.48 MMac\nNumber of parameters:           884.41 k\nchannel_change: 470.17 KMac\n\npatchonly_mixer0_l7_10_s-1\nComputational complexity:       23.04 MMac\nNumber of parameters:           807.08 k\n\noriginal_mixer0_l10_10_s-1\nComputational complexity:       106.36 MMac\nNumber of parameters:           1.23 M  \n\noriginal_mixer1_l10_10_s-1\nComputational complexity:       86.16 MMac\nNumber of parameters:           1.22 M  \n\npatchonly_mixer0_l10_10_s-1\nComputational complexity:       32.9 MMac\nNumber of parameters:           1.14 M  \n\noriginal_mixer0_l7_100_s-1\nComputational complexity:       75.7 MMac\nNumber of parameters:           1.95 M  \n\noriginal_mixer1_l7_100_s-1\nComputational complexity:       61.36 MMac\nNumber of parameters:           1.77 M  \n\npatchonly_mix

In [23]:
# model

In [24]:
L, C = 7, 10  # number of mixer blocks, number of classes

# Exactly one candidate is active at a time; the alternatives stay commented out.

# model = MlpMixer((3, 4*9, 4*9), (4, 4), hidden_expansion=3.0, num_blocks=L, num_classes=C)
# model_name = f'original_mixer0_l7_c10'

model = MlpMixer(
    image_dim=(3, 32, 32),
    patch_size=(4, 4),
    hidden_expansion=3.2,
    num_blocks=L,
    num_classes=C,
)
# model_name = f'original_mixer1_l7_c10'

# model = PatchMlpMixer((3, 35, 35), patch_sizes=[5,7], hidden_channels=3, num_blocks=L, num_classes=C)
# model_name = f'patchonly_mixer0_l7_c10'

MLP Mixer : Channes per patch -> Initial:48 Final:153


In [25]:
model

MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(32, 32), mode='bilinear')
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=153, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((153,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=64, out_features=128, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=128, out_features=64, bias=True)
        )
      )
      (ln1): LayerNorm((153,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=153, out_features=306, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=306, out_features=153, bias=True)
        )
      )
    )
    (1): MixerBlock(
      (ln0): LayerNorm((153,), eps=1e-05, elementwise_affine=Tru

In [26]:
from ptflops import get_model_complexity_info

In [27]:
# Per-layer MAC / parameter breakdown of the currently selected `model`,
# followed by the two totals.
macs, params = get_model_complexity_info(
    model,
    (3, 32, 32),
    as_strings=True,
    print_per_layer_stat=True,
    verbose=True,
)
for label, value in (('Computational complexity: ', macs),
                     ('Number of parameters: ', params)):
    print('{:<30}  {:<8}'.format(label, value))

MlpMixer(
  884.41 k, 100.000% Params, 60.48 MMac, 100.000% MACs, 
  (scaler): UpsamplingBilinear2d(0, 0.000% Params, 0.0 Mac, 0.000% MACs, size=(32, 32), mode='bilinear')
  (unfold): Unfold(0, 0.000% Params, 0.0 Mac, 0.000% MACs, kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(7.5 k, 0.848% Params, 470.17 KMac, 0.777% MACs, in_features=48, out_features=153, bias=True)
  (mixer_blocks): Sequential(
    778.98 k, 88.079% Params, 59.91 MMac, 99.061% MACs, 
    (0): MixerBlock(
      111.28 k, 12.583% Params, 8.56 MMac, 14.152% MACs, 
      (ln0): LayerNorm(306, 0.035% Params, 9.79 KMac, 0.016% MACs, (153,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        16.58 k, 1.874% Params, 2.53 MMac, 4.177% MACs, 
        (mlp): Sequential(
          16.58 k, 1.874% Params, 2.53 MMac, 4.177% MACs, 
          (0): Linear(8.32 k, 0.941% Params, 1.25 MMac, 2.073% MACs, in_features=64, out_features=128, bias=True)
          (1): GELU(0, 0.000

In [28]:
# from thop import profile

In [29]:
# dummy_inputs = torch.rand(1, 3, 32, 32)
# macs, params, info = profile(model, inputs=(dummy_inputs, ), ret_layer_info=True)

In [30]:
# info

In [31]:
# from flopth import flopth

In [32]:
# # Or use input tensors
# dummy_inputs = torch.rand(1, 3, 32, 32)
# flops, params = flopth(model, inputs=(dummy_inputs,), show_detail=True)

In [33]:
# print(flops, params)

In [34]:
## mlpmixer0 -> 996.543K 896.779K
## mlpmixer1 -> 899.374K 884.408K
## patchmixer -> 233.464K 807.082K

In [35]:
# NOTE(review): num_layers and num_cls are not assigned in this cell. The only
# visible assignments are the loop variables of the complexity sweep, so this
# relies on whatever values that loop left behind. Confirm that this is
# intended (L and C may have been meant instead).
model = MlpMixer((3, 4*9, 4*9), (4, 4), hidden_expansion=3.0, num_blocks=num_layers, num_classes=num_cls)
model

MLP Mixer : Channes per patch -> Initial:48 Final:144


MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(36, 36), mode='bilinear')
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=144, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=81, out_features=162, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=162, out_features=81, bias=True)
        )
      )
      (ln1): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=144, out_features=288, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=288, out_features=144, bias=True)
        )
      )
    )
    (1): MixerBlock(
      (ln0): LayerNorm((144,), eps=1e-05, elementwise_affine=Tru

In [36]:
### FOR ORIGINAL MIXER V1
# NOTE(review): num_layers and num_cls are not assigned in this cell. The only
# visible assignments are the loop variables of the complexity sweep, so this
# relies on whatever values that loop left behind. Confirm that this is
# intended (L and C may have been meant instead).
model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=3.2, num_blocks=num_layers, num_classes=num_cls)
model

MLP Mixer : Channes per patch -> Initial:48 Final:153


MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(32, 32), mode='bilinear')
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=153, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((153,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=64, out_features=128, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=128, out_features=64, bias=True)
        )
      )
      (ln1): LayerNorm((153,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=153, out_features=306, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=306, out_features=153, bias=True)
        )
      )
    )
    (1): MixerBlock(
      (ln0): LayerNorm((153,), eps=1e-05, elementwise_affine=Tru

In [37]:
# NOTE(review): num_layers and num_cls are not assigned in this cell. The only
# visible assignments are the loop variables of the complexity sweep, so this
# relies on whatever values that loop left behind. Confirm that this is
# intended (L and C may have been meant instead).
model = PatchMlpMixer((3, 35, 35), patch_sizes=[5,7], hidden_channels=3, num_blocks=num_layers, num_classes=num_cls)
model

PatchMlpMixer(
  (conv1x1): Identity()
  (mixer_blocks): Sequential(
    (0): PatchMixerBlock(
      (ln0): LayerNorm((75,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=75, out_features=150, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=150, out_features=75, bias=True)
        )
      )
    )
    (1): PatchMixerBlock(
      (ln0): LayerNorm((147,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=147, out_features=294, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=294, out_features=147, bias=True)
        )
      )
    )
    (2): PatchMixerBlock(
      (ln0): LayerNorm((75,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=75, out_features=150, bias=True)
          (1): GELU(a