In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy, json
from tqdm import tqdm

In [2]:
device = torch.device("cuda:0")
# device = torch.device("cpu")

In [3]:
# SEED = 147
# SEED = 258
SEED = 369

torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
import torch.optim as optim
from torch.utils import data

In [5]:
cifar_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
        std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
    ),
])

cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
        std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
    ),
])

train_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=True, download=True, transform=cifar_train)
test_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=False, download=True, transform=cifar_test)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=2)

In [7]:
# cifar_train = transforms.Compose([
#     transforms.RandomCrop(size=32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.5071, 0.4865, 0.4409],
#         std=[0.2009, 0.1984, 0.2023],
#     ),
# ])

# cifar_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.5071, 0.4865, 0.4409],
#         std=[0.2009, 0.1984, 0.2023],
#     ),
# ])

# train_dataset = datasets.CIFAR100(root="../../../../../_Datasets/cifar100/", train=True, download=True, transform=cifar_train)
# test_dataset = datasets.CIFAR100(root="../../../../../_Datasets/cifar100/", train=False, download=True, transform=cifar_test)

In [8]:
# train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=2)
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Model

In [9]:
class MlpBLock(nn.Module):
    
    def __init__(self, input_dim, hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        #### convert hidden layers ratio to list if integer is inputted
        if isinstance(hidden_layers_ratio, int):
            hidden_layers_ratio = [hidden_layers_ratio]
            
        self.hlr = [1]+hidden_layers_ratio+[1]
        
        self.mlp = []
        ### for 1 hidden layer, we iterate 2 times
        for h in range(len(self.hlr)-1):
            i, o = int(self.hlr[h]*self.input_dim),\
                    int(self.hlr[h+1]*self.input_dim)
            self.mlp.append(nn.Linear(i, o))
            self.mlp.append(actf())
        self.mlp = self.mlp[:-1]
        
        self.mlp = nn.Sequential(*self.mlp)
        
    def forward(self, x):
        return self.mlp(x)

In [10]:
MlpBLock(2, [3,4])

MlpBLock(
  (mlp): Sequential(
    (0): Linear(in_features=2, out_features=6, bias=True)
    (1): GELU()
    (2): Linear(in_features=6, out_features=8, bias=True)
    (3): GELU()
    (4): Linear(in_features=8, out_features=2, bias=True)
  )
)

In [11]:
import sparse_nonlinear_lib as snl

In [12]:
snl.BlockMLP_MixerBlock(256, 16, hidden_layers_ratio=[1])

BlockMLP_MixerBlock(
  (facto_nets): ModuleList(
    (0): BlockMLP(
      (mlp): Sequential(
        (0): BlockLinear: [16, 16, 16]
        (1): ELU(alpha=1.0)
        (2): BlockLinear: [16, 16, 16]
      )
    )
    (1): BlockMLP(
      (mlp): Sequential(
        (0): BlockLinear: [16, 16, 16]
        (1): ELU(alpha=1.0)
        (2): BlockLinear: [16, 16, 16]
      )
    )
  )
)

In [86]:
snl.BlockMLP_MixerBlock(256, 16, hidden_layers_ratio=[1])(torch.randn(1, 256)).shape

torch.Size([1, 256])

In [87]:
snl.BlockLinear_MixerBlock(256, 16)

BlockLinear_MixerBlock(
  (facto_nets): ModuleList(
    (0): BlockWeight: [16, 16, 16]
    (1): BlockWeight: [16, 16, 16]
  )
)

In [88]:
snl.BlockLinear_MixerBlock(256, 16)(torch.randn(1, 256)).shape

torch.Size([1, 256])

In [89]:
class SparseResMlp(nn.Module):
    
    def __init__(self, input_dim, block_dim, hidden_expansion=2, actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        self.hex = hidden_expansion
            
        self.layers1s = [snl.BlockLinear_MixerBlock(input_dim, block_dim, bias=True) for _ in range(hidden_expansion)]
        self.layers2s = [snl.BlockLinear_MixerBlock(input_dim, block_dim, bias=False) for _ in range(hidden_expansion)]
        self.bias = nn.Parameter(torch.zeros(1, input_dim))
        
        self.layers1s = nn.ModuleList(self.layers1s)
        self.actf = actf()
        self.layers2s = nn.ModuleList(self.layers2s)
        
    def forward(self, x):
        y = x+self.bias
        for i in range(self.hex):
            h = self.layers1s[i](x)
            h = self.actf(h)
            h = self.layers2s[i](x)
            y += h
        return h

In [90]:
model = SparseResMlp(256, 16)
# model.layers1s[0].bias
model

SparseResMlp(
  (layers1s): ModuleList(
    (0): BlockLinear_MixerBlock(
      (facto_nets): ModuleList(
        (0): BlockWeight: [16, 16, 16]
        (1): BlockWeight: [16, 16, 16]
      )
    )
    (1): BlockLinear_MixerBlock(
      (facto_nets): ModuleList(
        (0): BlockWeight: [16, 16, 16]
        (1): BlockWeight: [16, 16, 16]
      )
    )
  )
  (actf): GELU()
  (layers2s): ModuleList(
    (0): BlockLinear_MixerBlock(
      (facto_nets): ModuleList(
        (0): BlockWeight: [16, 16, 16]
        (1): BlockWeight: [16, 16, 16]
      )
    )
    (1): BlockLinear_MixerBlock(
      (facto_nets): ModuleList(
        (0): BlockWeight: [16, 16, 16]
        (1): BlockWeight: [16, 16, 16]
      )
    )
  )
)

In [91]:
SparseResMlp(256, 16)(torch.randn(1, 256)).shape

torch.Size([1, 256])

In [92]:
asdasd

NameError: name 'asdasd' is not defined

### Block MLP - without res

In [202]:
class BlockMLP(nn.Module):
    def __init__(self, input_dim, layer_dims, actf=nn.ELU):
        super().__init__()
        self.block_dim = layer_dims[0]
        
        assert input_dim%self.block_dim == 0, "Input dim must be even number"
        ### Create a block MLP
        self.mlp = []
        n_blocks = input_dim//layer_dims[0]
        for i in range(len(layer_dims)-1):
            l = snl.BlockLinear(n_blocks, layer_dims[i], layer_dims[i+1])
            a = actf()
            self.mlp.append(l)
            self.mlp.append(a)
        self.mlp = self.mlp[:-1]
        self.mlp = nn.Sequential(*self.mlp)
        
    def forward(self, x):
        bs, dim = x.shape[0], x.shape[1]
        x = x.view(bs, -1, self.block_dim).transpose(0,1)
        x = self.mlp(x)
        x = x.transpose(1,0).reshape(bs, -1)
        return x
    
############################################################################
############################################################################

class BlockMLP_MixerBlock(nn.Module):
    
    def __init__(self, input_dim, block_dim, hidden_layers_ratio=[2], actf=nn.ELU):
        super().__init__()
        
        assert input_dim%block_dim == 0, "Input dim must be even number"
        self.input_dim = input_dim
        self.block_dim = block_dim
        
        def log_base(a, base):
            return np.log(a) / np.log(base)
        
        num_layers = int(np.ceil(log_base(input_dim, base=block_dim)))
        hidden_layers_ratio = [1] + hidden_layers_ratio + [1]
        
        block_layer_dims = [int(a*block_dim) for a in hidden_layers_ratio]
        self.facto_nets = []
        for i in range(num_layers):
            net = BlockMLP(self.input_dim, block_layer_dims, actf)
            self.facto_nets.append(net)
            
        self.facto_nets = nn.ModuleList(self.facto_nets)
            
    def forward(self, x):
        bs = x.shape[0]
        y = x
        for i, fn in enumerate(self.facto_nets):
            y = y.view(-1, self.block_dim, self.block_dim**i).permute(0, 2, 1).contiguous().view(bs, -1)
            y = fn(y)
            y = y.view(-1, self.block_dim**i, self.block_dim).permute(0, 2, 1).contiguous()

        y = y.view(bs, -1)
        return y

## MLP-Mixer 

In [203]:
class MixerBlock(nn.Module):
    
    def __init__(self, patch_dim, channel_dim, patch_mixing="dense", channel_mixing="dense"):
        super().__init__()
        
        self.valid_functions = ["dense", "sparse_linear", "sparse_mlp"]
        assert patch_mixing in self.valid_functions
        assert channel_mixing in self.valid_functions
        
        self.patch_dim = patch_dim
        self.channel_dim = channel_dim
        
        self.ln0 = nn.LayerNorm(channel_dim)
        self.mlp_patch = self.get_mlp(patch_dim, patch_mixing)    
        
        self.ln1 = nn.LayerNorm(channel_dim)
        self.mlp_channel = self.get_mlp(channel_dim, channel_mixing)
    
    def get_mlp(self, dim, mixing_function, actf=nn.GELU):
        block_dim = int(np.sqrt(dim))
        assert block_dim**2 == dim, "Sparsifying dimension must be a square number"
        
        if mixing_function == self.valid_functions[0]:
            mlp = MlpBLock(dim, [2], actf)
        elif mixing_function == self.valid_functions[1]:
            mlp = SparseResMlp(dim, block_dim, 2, actf)
        elif mixing_function == self.valid_functions[2]:
            mlp = BlockMLP_MixerBlock(dim, block_dim, [2], actf)
        return mlp
    
    def forward(self, x):
        ## x has shape-> N, nP, nC/hidden_dims; C=Channel, P=Patch
        
        ######## !!!! Can use same mixer on shape of -> N, C, P;
        
        #### mix per patch
        y = self.ln0(x) ### per channel layer normalization ?? 
        y = torch.swapaxes(y, -1, -2).contiguous()
        
        y = y.view(-1, self.patch_dim)
        y = self.mlp_patch(y)
        y = y.view(-1, self.channel_dim, self.patch_dim)
        
        y = torch.swapaxes(y, -1, -2)
        x = x+y
        
        #### mix per channel 
        y = self.ln1(x)
        y = y.view(-1, self.channel_dim)
        y = self.mlp_channel(y)
        y = y.view(-1, self.patch_dim, self.channel_dim)
        
        x = x+y
        return x

In [204]:
# class MixerBlock(nn.Module):
    
#     def __init__(self, patch_dim, channel_dim):
#         super().__init__()
        
#         self.ln0 = nn.LayerNorm(channel_dim)
#         self.mlp_patch = MlpBLock(patch_dim, [2])
#         self.ln1 = nn.LayerNorm(channel_dim)
#         self.mlp_channel = MlpBLock(channel_dim, [2])
    
#     def forward(self, x):
#         ## x has shape-> N, nP, nC/hidden_dims; C=Channel, P=Patch
        
#         ######## !!!! Can use same mixer on shape of -> N, C, P;
        
#         #### mix per patch
#         y = self.ln0(x) ### per channel layer normalization ?? 
#         y = torch.swapaxes(y, -1, -2)
#         y = self.mlp_patch(y)
#         y = torch.swapaxes(y, -1, -2)
#         x = x+y
        
#         #### mix per channel 
#         y = self.ln1(x)
#         y = self.mlp_channel(y)
#         x = x+y
#         return x

In [205]:
model = MixerBlock(2*2*4, 16*16, channel_mixing="sparse_mlp", patch_mixing="sparse_linear")
# model = MixerBlock(2*2*4, 16*16)
model

MixerBlock(
  (ln0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (mlp_patch): SparseResMlp(
    (layers1s): ModuleList(
      (0): BlockLinear_MixerBlock(
        (facto_nets): ModuleList(
          (0): BlockWeight: [4, 4, 4]
          (1): BlockWeight: [4, 4, 4]
        )
      )
      (1): BlockLinear_MixerBlock(
        (facto_nets): ModuleList(
          (0): BlockWeight: [4, 4, 4]
          (1): BlockWeight: [4, 4, 4]
        )
      )
    )
    (actf): GELU()
    (layers2s): ModuleList(
      (0): BlockLinear_MixerBlock(
        (facto_nets): ModuleList(
          (0): BlockWeight: [4, 4, 4]
          (1): BlockWeight: [4, 4, 4]
        )
      )
      (1): BlockLinear_MixerBlock(
        (facto_nets): ModuleList(
          (0): BlockWeight: [4, 4, 4]
          (1): BlockWeight: [4, 4, 4]
        )
      )
    )
  )
  (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (mlp_channel): BlockMLP_MixerBlock(
    (facto_nets): ModuleList(
      (0): BlockMLP(
 

In [206]:
model(torch.randn(1, 2*2*4, 16*16)).shape

torch.Size([1, 16, 256])

In [207]:
class MlpMixer(nn.Module):
    
    def __init__(self, image_dim:tuple, patch_size:tuple, hidden_expansion:float, num_blocks:int, num_classes:int,
                patch_mixing:str, channel_mixing:str):
        super().__init__()
        
        self.img_dim = image_dim ### must contain (C, H, W) or (H, W)
        self.scaler = nn.UpsamplingBilinear2d(size=(self.img_dim[-2], self.img_dim[-1]))
        
        ### find patch dim
        d0 = int(image_dim[-2]/patch_size[0])
        d1 = int(image_dim[-1]/patch_size[1])
        assert d0*patch_size[0]==image_dim[-2], "Image must be divisible into patch size"
        assert d1*patch_size[1]==image_dim[-1], "Image must be divisible into patch size"
#         self.d0, self.d1 = d0, d1 ### number of patches in each axis
        __patch_size = patch_size[0]*patch_size[1]*image_dim[0] ## number of channels in each patch
    
        ### find channel dim
        channel_size = d0*d1 ## number of patches
        
        ### after the number of channels are changed
        init_dim = __patch_size
#         final_dim = int(patch_size[0]*patch_size[1]*hidden_expansion)
        final_dim = int(init_dim*hidden_expansion)

        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
        self.channel_change = nn.Linear(init_dim, final_dim)
        print(f"MLP Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")
        
        
        self.channel_dim = final_dim
        self.patch_dim = channel_size
        
        self.mixer_blocks = []
        for i in range(num_blocks):
            self.mixer_blocks.append(MixerBlock(self.patch_dim, self.channel_dim, patch_mixing, channel_mixing))
        self.mixer_blocks = nn.Sequential(*self.mixer_blocks)
        
        self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)
        
        
    def forward(self, x):
        bs = x.shape[0]
        x = self.scaler(x)
        x = self.unfold(x).swapaxes(-1, -2)
        x = self.channel_change(x)
        x = self.mixer_blocks(x)
        x = self.linear(x.view(bs, -1))
        return x

In [208]:
mixer = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=1, num_classes=10, patch_mixing="sparse_linear", channel_mixing="sparse_linear")
mixer

MLP Mixer : Channes per patch -> Initial:48 Final:121


MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(32, 32), mode=bilinear)
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=121, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((121,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): SparseResMlp(
        (layers1s): ModuleList(
          (0): BlockLinear_MixerBlock(
            (facto_nets): ModuleList(
              (0): BlockWeight: [8, 8, 8]
              (1): BlockWeight: [8, 8, 8]
            )
          )
          (1): BlockLinear_MixerBlock(
            (facto_nets): ModuleList(
              (0): BlockWeight: [8, 8, 8]
              (1): BlockWeight: [8, 8, 8]
            )
          )
        )
        (actf): GELU()
        (layers2s): ModuleList(
          (0): BlockLinear_MixerBlock(
            (facto_nets): ModuleList(
              (0): BlockWeight: [8, 8, 8]
              (1): BlockWeight: [8, 8

In [209]:
print("number of params: ", sum(p.numel() for p in mixer.parameters())) 

number of params:  99162


In [210]:
mixer(torch.randn(1, 3, 32, 32))

tensor([[ 2.0674,  1.0478, -1.3613, -1.0466, -0.0820,  1.5648, -2.3207, -0.5586,
         -0.8303,  0.3042]], grad_fn=<AddmmBackward>)

#### Final Model

In [230]:
model = MlpMixer((3, 4*9, 4*9), (4, 4), hidden_expansion=3.0, num_blocks=10, num_classes=10,
#                 patch_mixing="sparse_linear", channel_mixing="sparse_linear")
                patch_mixing="sparse_mlp", channel_mixing="sparse_mlp")
#                 patch_mixing="dense", channel_mixing="dense")
                 

model = model.to(device)

MLP Mixer : Channes per patch -> Initial:48 Final:144


In [231]:
model

MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(36, 36), mode=bilinear)
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=144, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): BlockMLP_MixerBlock(
        (facto_nets): ModuleList(
          (0): BlockMLP(
            (mlp): Sequential(
              (0): BlockLinear: [9, 9, 18]
              (1): GELU()
              (2): BlockLinear: [9, 18, 9]
            )
          )
          (1): BlockMLP(
            (mlp): Sequential(
              (0): BlockLinear: [9, 9, 18]
              (1): GELU()
              (2): BlockLinear: [9, 18, 9]
            )
          )
        )
      )
      (ln1): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): BlockMLP_MixerBlock(
        (facto_nets): ModuleList(
          (0): BlockMLP(
            

In [232]:
# print("number of params: ", sum(p.numel() for p in model.parameters())) 
print("number of params: ", sum(p.numel() for p in model.mixer_blocks.parameters())) 

### Dense: 1,104,390
### Sparse MLP: 215820
### Sparse Linear: 209,070

number of params:  215820


In [214]:
model(torch.randn(1, 3, 32, 32).to(device)).shape

torch.Size([1, 10])

## Training

In [215]:
model_name = f'mlp_mixer_sparse-mlp_c10_s{SEED}'

In [216]:
EPOCHS = 200
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [217]:
STAT ={'train_stat':[], 'test_stat':[]}

In [218]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
    STAT['train_stat'].append((epoch, train_loss/(batch_idx+1), 100.*correct/total)) ### (Epochs, Loss, Acc)
    print(f"[Train] {epoch} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    return

In [219]:
best_acc = -1
def test(epoch):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    time_taken = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)

            start = time.time()

            outputs = model(inputs)

            start = time.time()-start
            time_taken.append(start)

            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    STAT['test_stat'].append((epoch, test_loss/(batch_idx+1), 100.*correct/total, np.mean(time_taken))) ### (Epochs, Loss, Acc, time)
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch
        }
        if not os.path.isdir('models'):
            os.mkdir('models')
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc
        
    with open(f"./output/{model_name}_data.json", 'w') as f:
        json.dump(STAT, f, indent=0)

In [220]:
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(f'./models/{model_name}.pth')
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

In [221]:
### Train the whole damn thing

for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
    train(epoch)
    test(epoch)
    scheduler.step()

  1%|          | 14/1563 [00:06<12:09,  2.12it/s]


KeyboardInterrupt: 

In [None]:
# best_acc

In [None]:
# checkpoint = torch.load(f'./models/{model_name}.pth')
# best_acc = checkpoint['acc']
# start_epoch = checkpoint['epoch']

# best_acc, start_epoch

In [None]:
# model.load_state_dict(checkpoint['model'])

In [None]:
# model

In [None]:
# STAT

In [None]:
# train_stat = np.array(STAT['train_stat'])
# test_stat = np.array(STAT['test_stat'])

In [None]:
# plt.plot(train_stat[:,1], label='train')
# plt.plot(test_stat[:,1], label='test')
# plt.ylabel("Loss")
# plt.legend()
# plt.savefig(f"./output/plots/{model_name}_loss.svg")
# plt.show()

In [None]:
# plt.plot(train_stat[:,2], label='train')
# plt.plot(test_stat[:,2], label='test')
# plt.ylabel("Accuracy")
# plt.legend()
# plt.savefig(f"./output/plots/{model_name}_accs.svg")
# plt.show()

## Benchmark Training

In [223]:
def get_data_loaders(seed, ds):
    BS = 64
    if ds == 'c100': BS = 128
    torch.manual_seed(seed)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BS, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BS, shuffle=False, num_workers=2)
    return train_loader, test_loader

In [None]:
def benchmark():
    global model, optimizer, train_loader, test_loader, model_name, criterion, STAT, best_acc
    EPOCHS = 1
    criterion = nn.CrossEntropyLoss()
    lr = 0.001
    DS = 'c100'
#     DS = 'c10'
    for SEED in [369]:
        for num_layers in [7, 10]:

            for i in range(3): ## 3 models training
                print("Experiment index:", i)
                train_loader, test_loader = get_data_loaders(SEED, DS)
                num_cls = 10
                if DS=='c100': num_cls = 100
                ### FOR ORIGINAL MIXER
                if i==0:
                    model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=num_layers, 
                                     num_classes=num_cls, patch_mixing="dense", channel_mixing="dense")
                    model_name = f'mlp_mixer_dense_l{num_layers}_{DS}_s{SEED}'
                elif i == 1:
                    model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=num_layers, 
                                     num_classes=num_cls, patch_mixing="sparse_linear", channel_mixing="sparse_linear")
                    model_name = f'mlp_mixer_sparseLinear_l{num_layers}_{DS}_s{SEED}'
                elif i == 2:
                    model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=num_layers, 
                                     num_classes=num_cls, patch_mixing="sparse_mlp", channel_mixing="sparse_mlp")
                    model_name = f'mlp_mixer_sparseMlp_l{num_layers}_{DS}_s{SEED}'
                else:
                    print("JPT........!!!!")
                    continue
                    
                model = model.to(device)
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

                num_params = sum(p.numel() for p in model.parameters())

                model_name = "03.0_" + model_name
                print(f"EXPERIMENTING FOR : {model_name} | params: {num_params}  .......\n.......")
                
#                 continue
                STAT ={'train_stat':[], 'test_stat':[], 'num_params':num_params}
                best_acc = -1
                for epoch in range(0, EPOCHS): ## for 200 epochs
                    train(epoch)
                    test(epoch)
                    scheduler.step()
                print(f"Training finished\n")
                pass
            pass
        pass
    return 0           

In [None]:
benchmark()