In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy, json
from tqdm import tqdm

In [2]:
# Prefer GPU index 1 (the card the recorded runs used); fall back to the first
# GPU and then to the CPU so the notebook still runs on other machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Report the accelerator actually selected in `device` rather than a fixed index.
torch.cuda.get_device_name(device) if device.type == "cuda" else "cpu"

'NVIDIA TITAN Xp'

In [4]:
# SEED = 147
# SEED = 258
SEED = 369

# Seed every RNG this notebook imports (Python, NumPy, PyTorch) so that a
# fresh "Restart & Run All" starts from the same random state.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
import torch.optim as optim
from torch.utils import data

In [6]:
# Normalisation constants used for CIFAR-10.
# (For CIFAR-100 use mean=[0.5071, 0.4865, 0.4409], std=[0.2009, 0.1984, 0.2023].)
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]
CIFAR10_ROOT = "../../../../../_Datasets/cifar10/"

# Training pipeline: padded random crop + horizontal flip, then tensor + normalise.
cifar_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, same normalisation.
cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

train_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=True, transform=cifar_train)
test_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=False, download=True, transform=cifar_test)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Mini-batch iterators for the two splits; only the training split is shuffled.
_loader_kwargs = dict(batch_size=32, num_workers=2)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **_loader_kwargs)

In [8]:
# cifar_train = transforms.Compose([
#     transforms.RandomCrop(size=32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.5071, 0.4865, 0.4409],
#         std=[0.2009, 0.1984, 0.2023],
#     ),
# ])

# cifar_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.5071, 0.4865, 0.4409],
#         std=[0.2009, 0.1984, 0.2023],
#     ),
# ])

# train_dataset = datasets.CIFAR100(root="../../../../../_Datasets/cifar100/", train=True, download=True, transform=cifar_train)
# test_dataset = datasets.CIFAR100(root="../../../../../_Datasets/cifar100/", train=False, download=True, transform=cifar_test)

In [9]:
# train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=2)
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Model

In [10]:
class MlpBLock(nn.Module):
    """Dense MLP whose layer widths are multiples of ``input_dim``.

    The network maps ``input_dim -> ratio_0*input_dim -> ... -> input_dim`` with
    an activation between consecutive linear layers (none after the last one).

    Args:
        input_dim: width of the input and of the output.
        hidden_layers_ratio: one ratio (int or float) or any sequence of ratios;
            each hidden width is ``int(ratio * input_dim)``.
        actf: zero-argument callable returning the activation module.
    """

    def __init__(self, input_dim, hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        #### a bare number means a single hidden layer
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]

        # list(...) copies the argument and lets tuples / other sequences through.
        self.hlr = [1] + list(hidden_layers_ratio) + [1]

        widths = [int(ratio * self.input_dim) for ratio in self.hlr]
        layers = []
        ### for 1 hidden layer, we iterate 2 times
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(actf())
        layers = layers[:-1]  # no activation after the output projection

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``, whose last dimension must equal ``input_dim``."""
        return self.mlp(x)

In [11]:
MlpBLock(2, [3,4])

MlpBLock(
  (mlp): Sequential(
    (0): Linear(in_features=2, out_features=6, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=6, out_features=8, bias=True)
    (3): GELU(approximate='none')
    (4): Linear(in_features=8, out_features=2, bias=True)
  )
)

In [12]:
import sparse_nonlinear_lib_minimal as snl

In [13]:
snl.BlockMLP_MixerBlock(256, 16, hidden_layers_ratio=[1])

BlockMLP_MixerBlock(
  (facto_nets): ModuleList(
    (0-1): 2 x BlockMLP(
      (mlp): Sequential(
        (0): BlockLinear: [16, 16, 16]
        (1): ELU(alpha=1.0)
        (2): BlockLinear: [16, 16, 16]
      )
    )
  )
)

In [14]:
snl.BlockMLP_MixerBlock(256, 16, hidden_layers_ratio=[1])(torch.randn(1, 256)).shape

torch.Size([1, 256])

In [15]:
snl.BlockLinear_MixerBlock(256, 16)

BlockLinear_MixerBlock(
  (facto_nets): ModuleList(
    (0-1): 2 x BlockWeight: [16, 16, 16]
  )
)

In [16]:
snl.BlockLinear_MixerBlock(256, 16)(torch.randn(1, 256)).shape

torch.Size([1, 256])

In [17]:
class SparseResMlp(nn.Module):
    """Sum of ``hidden_expansion`` parallel sparse branches plus one shared bias.

    Each branch is ``snl.BlockLinear_MixerBlock -> activation ->
    snl.BlockLinear_MixerBlock``; the first mixer is built with ``bias=True``
    and the second with ``bias=False``. The branch outputs are added together
    and a learnable ``(1, input_dim)`` bias is added to the sum.
    """

    def __init__(self, input_dim, block_dim, hidden_expansion=2, actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        self.hex = hidden_expansion

        # All input-side mixers are created before any output-side mixer, so the
        # order in which parameters are initialised stays fixed.
        input_mixers = []
        for _ in range(hidden_expansion):
            input_mixers.append(snl.BlockLinear_MixerBlock(input_dim, block_dim, bias=True))
        output_mixers = []
        for _ in range(hidden_expansion):
            output_mixers.append(snl.BlockLinear_MixerBlock(input_dim, block_dim, bias=False))

        self.bias = nn.Parameter(torch.zeros(1, input_dim))

        # Registration order (layers1s, actf, layers2s) determines the repr.
        self.layers1s = nn.ModuleList(input_mixers)
        self.actf = actf()
        self.layers2s = nn.ModuleList(output_mixers)

    def forward(self, x):
        """Return the summed branch outputs for ``x`` plus the shared bias."""
        total = 0
        for branch in range(self.hex):
            hidden = self.actf(self.layers1s[branch](x))
            total += self.layers2s[branch](hidden)
        return total + self.bias

In [18]:
model = SparseResMlp(256, 16)
# model.layers1s[0].bias
model

SparseResMlp(
  (layers1s): ModuleList(
    (0-1): 2 x BlockLinear_MixerBlock(
      (facto_nets): ModuleList(
        (0-1): 2 x BlockWeight: [16, 16, 16]
      )
    )
  )
  (actf): GELU(approximate='none')
  (layers2s): ModuleList(
    (0-1): 2 x BlockLinear_MixerBlock(
      (facto_nets): ModuleList(
        (0-1): 2 x BlockWeight: [16, 16, 16]
      )
    )
  )
)

In [19]:
SparseResMlp(256, 16)(torch.randn(1, 256)).shape

torch.Size([1, 256])

In [20]:
# asdasd

### Block MLP - without res

In [21]:
class BlockMLP(nn.Module):
    """MLP built from ``snl.BlockLinear`` layers applied to blocks of the input.

    The input features are split into ``input_dim // layer_dims[0]`` blocks of
    width ``layer_dims[0]``; consecutive entries of ``layer_dims`` give the
    per-block layer widths. An activation follows every layer except the last.

    Args:
        input_dim: total feature width; must be a multiple of ``layer_dims[0]``.
        layer_dims: per-block widths, ``layer_dims[0]`` being the block size.
        actf: zero-argument callable returning the activation module.
    """

    def __init__(self, input_dim, layer_dims, actf=nn.GELU):
        super().__init__()
        self.block_dim = layer_dims[0]

        assert input_dim % self.block_dim == 0, \
            f"input_dim ({input_dim}) must be divisible by the block size layer_dims[0] ({self.block_dim})"
        ### Create a block MLP
        n_blocks = input_dim // self.block_dim
        layers = []
        for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:]):
            layers.append(snl.BlockLinear(n_blocks, d_in, d_out))
            layers.append(actf())
        layers = layers[:-1]  # drop the activation after the final layer
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block MLP to ``x`` of shape ``(batch, input_dim)``."""
        bs = x.shape[0]
        # (batch, n_blocks, block_dim) -> (n_blocks, batch, block_dim).
        # NOTE(review): presumably snl.BlockLinear expects the block index first - confirm in snl.
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(bs, -1, self.block_dim).transpose(0, 1)
        x = self.mlp(x)
        x = x.transpose(1, 0).reshape(bs, -1)
        return x
    
############################################################################
############################################################################

class BlockMLP_MixerBlock(nn.Module):
    """Stack of ``BlockMLP`` layers with a feature permutation around each one.

    The number of stacked ``BlockMLP`` layers is the smallest ``n`` such that
    ``block_dim**n >= input_dim``.

    Args:
        input_dim: feature width; must be a multiple of ``block_dim``.
        block_dim: block size handled by each ``BlockMLP`` (must be > 1).
        hidden_layers_ratio: one ratio or a sequence of ratios; each hidden
            per-block width is ``int(ratio * block_dim)``.
        actf: zero-argument callable returning the activation module.
    """

    def __init__(self, input_dim, block_dim, hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()

        assert block_dim > 1, f"block_dim must be greater than 1, got {block_dim}"
        assert input_dim % block_dim == 0, \
            f"input_dim ({input_dim}) must be divisible by block_dim ({block_dim})"
        self.input_dim = input_dim
        self.block_dim = block_dim

        # ceil(log_{block_dim}(input_dim)) in exact integer arithmetic; the
        # floating-point quotient log(a)/log(b) can land just above an integer
        # (e.g. 125 with base 5) and would then add a spurious extra layer.
        num_layers, reach = 0, 1
        while reach < input_dim:
            reach *= block_dim
            num_layers += 1

        # Same convenience as MlpBLock: a bare number means one hidden layer.
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]
        ratios = [1] + list(hidden_layers_ratio) + [1]

        block_layer_dims = [int(a*block_dim) for a in ratios]
        self.facto_nets = nn.ModuleList(
            [BlockMLP(self.input_dim, block_layer_dims, actf) for _ in range(num_layers)]
        )

    def forward(self, x):
        """Apply every block MLP in turn to ``x`` of shape ``(batch, input_dim)``."""
        bs = x.shape[0]
        y = x
        for i, fn in enumerate(self.facto_nets):
            # Regroup features with stride block_dim**i before layer i ...
            y = y.view(-1, self.block_dim, self.block_dim**i).permute(0, 2, 1).contiguous().view(bs, -1)
            y = fn(y)
            # ... and undo that regrouping afterwards.
            y = y.view(-1, self.block_dim**i, self.block_dim).permute(0, 2, 1).contiguous()

        y = y.view(bs, -1)
        return y

## MLP-Mixer 

In [22]:
class MixerBlock(nn.Module):
    """One Mixer layer: a patch-mixing MLP then a channel-mixing MLP, each residual.

    Args:
        patch_dim: number of patches (size of the second input axis).
        channel_dim: number of channels per patch (size of the last input axis).
        patch_mixing: ``"dense"``, ``"sparse_linear"`` or ``"sparse_mlp"``.
        channel_mixing: same choices, for the channel-mixing MLP.
    """

    def __init__(self, patch_dim, channel_dim, patch_mixing="dense", channel_mixing="dense"):
        super().__init__()

        self.valid_functions = ["dense", "sparse_linear", "sparse_mlp"]
        assert patch_mixing in self.valid_functions
        assert channel_mixing in self.valid_functions

        self.patch_dim = patch_dim
        self.channel_dim = channel_dim

        self.ln0 = nn.LayerNorm(channel_dim)
        self.mlp_patch = self.get_mlp(patch_dim, patch_mixing)

        self.ln1 = nn.LayerNorm(channel_dim)
        self.mlp_channel = self.get_mlp(channel_dim, channel_mixing)

    def get_mlp(self, dim, mixing_function, actf=nn.GELU):
        """Build the mixing network of width ``dim`` for the requested mode.

        Raises:
            ValueError: if ``mixing_function`` is not in ``self.valid_functions``.
            AssertionError: if a sparse mode is requested and ``dim`` is not a
                perfect square (the block size is ``sqrt(dim)``).
        """
        HIDDEN_EXPANSION = 1 #2
        dense, sparse_linear, sparse_mlp = self.valid_functions

        if mixing_function == dense:
            # The dense MLP has no block structure, so any width is fine.
            return MlpBLock(dim, [HIDDEN_EXPANSION], actf)
        if mixing_function not in (sparse_linear, sparse_mlp):
            raise ValueError(
                f"Unknown mixing function {mixing_function!r}; expected one of {self.valid_functions}"
            )

        block_dim = int(round(np.sqrt(dim)))
        assert block_dim**2 == dim, "Sparsifying dimension must be a square number"
        if mixing_function == sparse_linear:
            return SparseResMlp(dim, block_dim, HIDDEN_EXPANSION, actf)
        return BlockMLP_MixerBlock(dim, block_dim, [HIDDEN_EXPANSION], actf)

    def forward(self, x):
        """Mix across patches, then across channels; ``x`` keeps its shape."""
        ## x has shape-> N, nP, nC/hidden_dims; C=Channel, P=Patch

        ######## !!!! Can use same mixer on shape of -> N, C, P;

        #### mix per patch
        y = self.ln0(x) ### per channel layer normalization ??
        # (N, P, C) -> (N, C, P), flattened to rows of length P for the patch MLP.
        # reshape copies only when the memory layout requires it.
        y = torch.swapaxes(y, -1, -2).reshape(-1, self.patch_dim)
        y = self.mlp_patch(y)
        y = y.reshape(-1, self.channel_dim, self.patch_dim)

        y = torch.swapaxes(y, -1, -2)
        x = x+y

        #### mix per channel
        y = self.ln1(x)
        y = y.reshape(-1, self.channel_dim)
        y = self.mlp_channel(y)
        y = y.reshape(-1, self.patch_dim, self.channel_dim)

        x = x+y
        return x

In [23]:
# class MixerBlock(nn.Module):
    
#     def __init__(self, patch_dim, channel_dim):
#         super().__init__()
        
#         self.ln0 = nn.LayerNorm(channel_dim)
#         self.mlp_patch = MlpBLock(patch_dim, [2])
#         self.ln1 = nn.LayerNorm(channel_dim)
#         self.mlp_channel = MlpBLock(channel_dim, [2])
    
#     def forward(self, x):
#         ## x has shape-> N, nP, nC/hidden_dims; C=Channel, P=Patch
        
#         ######## !!!! Can use same mixer on shape of -> N, C, P;
        
#         #### mix per patch
#         y = self.ln0(x) ### per channel layer normalization ?? 
#         y = torch.swapaxes(y, -1, -2)
#         y = self.mlp_patch(y)
#         y = torch.swapaxes(y, -1, -2)
#         x = x+y
        
#         #### mix per channel 
#         y = self.ln1(x)
#         y = self.mlp_channel(y)
#         x = x+y
#         return x

In [24]:
model = MixerBlock(2*2*4, 16*16, channel_mixing="sparse_mlp", patch_mixing="sparse_linear")
# model = MixerBlock(2*2*4, 16*16)
model

MixerBlock(
  (ln0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (mlp_patch): SparseResMlp(
    (layers1s): ModuleList(
      (0): BlockLinear_MixerBlock(
        (facto_nets): ModuleList(
          (0-1): 2 x BlockWeight: [4, 4, 4]
        )
      )
    )
    (actf): GELU(approximate='none')
    (layers2s): ModuleList(
      (0): BlockLinear_MixerBlock(
        (facto_nets): ModuleList(
          (0-1): 2 x BlockWeight: [4, 4, 4]
        )
      )
    )
  )
  (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (mlp_channel): BlockMLP_MixerBlock(
    (facto_nets): ModuleList(
      (0-1): 2 x BlockMLP(
        (mlp): Sequential(
          (0): BlockLinear: [16, 16, 16]
          (1): GELU(approximate='none')
          (2): BlockLinear: [16, 16, 16]
        )
      )
    )
  )
)

In [25]:
model(torch.randn(1, 2*2*4, 16*16)).shape

torch.Size([1, 16, 256])

In [26]:
class MlpMixer(nn.Module):
    """Image classifier: patch embedding, a stack of ``MixerBlock``s, linear head.

    Args:
        image_dim: ``(C, H, W)`` or ``(H, W)``; a 2-tuple is treated as one channel.
        patch_size: ``(patch_h, patch_w)``; must divide H and W exactly.
        hidden_expansion: channels per patch after the embedding are
            ``int(C * patch_h * patch_w * hidden_expansion)``.
        num_blocks: number of ``MixerBlock`` layers.
        num_classes: width of the output logits.
        patch_mixing: mixing mode passed to every ``MixerBlock``.
        channel_mixing: mixing mode passed to every ``MixerBlock``.
    """

    def __init__(self, image_dim:tuple, patch_size:tuple, hidden_expansion:float, num_blocks:int, num_classes:int,
                patch_mixing:str, channel_mixing:str):
        super().__init__()

        self.img_dim = image_dim ### must contain (C, H, W) or (H, W)
        # Inputs of any spatial size are resized to (H, W) before patching.
        self.scaler = nn.UpsamplingBilinear2d(size=(self.img_dim[-2], self.img_dim[-1]))

        ### find patch dim
        d0 = image_dim[-2] // patch_size[0]
        d1 = image_dim[-1] // patch_size[1]
        assert d0*patch_size[0]==image_dim[-2], "Image must be divisible into patch size"
        assert d1*patch_size[1]==image_dim[-1], "Image must be divisible into patch size"

        # With an (H, W) image_dim, image_dim[0] is the height, not a channel count.
        in_channels = image_dim[0] if len(image_dim) == 3 else 1
        init_dim = patch_size[0]*patch_size[1]*in_channels ## values in each flattened patch

        ### find channel dim
        channel_size = d0*d1 ## number of patches

        ### after the number of channels are changed
        final_dim = int(init_dim*hidden_expansion)

        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
        self.channel_change = nn.Linear(init_dim, final_dim)
        print(f"MLP Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")

        self.channel_dim = final_dim
        self.patch_dim = channel_size

        blocks = [
            MixerBlock(self.patch_dim, self.channel_dim, patch_mixing, channel_mixing)
            for _ in range(num_blocks)
        ]
        self.mixer_blocks = nn.Sequential(*blocks)

        self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for images ``(N, C, h, w)``."""
        bs = x.shape[0]
        x = self.scaler(x)
        # unfold gives (N, C*ph*pw, num_patches); swap to (N, num_patches, C*ph*pw).
        x = self.unfold(x).swapaxes(-1, -2)
        x = self.channel_change(x)
        x = self.mixer_blocks(x)
        # reshape rather than view: works whatever layout the last block returns.
        x = self.linear(x.reshape(bs, -1))
        return x

In [27]:
mixer = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=1, num_classes=10, patch_mixing="sparse_linear", channel_mixing="sparse_linear")
mixer

MLP Mixer : Channes per patch -> Initial:48 Final:121


MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(32, 32), mode='bilinear')
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=121, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((121,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): SparseResMlp(
        (layers1s): ModuleList(
          (0): BlockLinear_MixerBlock(
            (facto_nets): ModuleList(
              (0-1): 2 x BlockWeight: [8, 8, 8]
            )
          )
        )
        (actf): GELU(approximate='none')
        (layers2s): ModuleList(
          (0): BlockLinear_MixerBlock(
            (facto_nets): ModuleList(
              (0-1): 2 x BlockWeight: [8, 8, 8]
            )
          )
        )
      )
      (ln1): LayerNorm((121,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): SparseResMlp(
        (layers1s): ModuleList(
          (0): BlockLinear_MixerBlock(
            (fact

In [28]:
print("number of params: ", sum(p.numel() for p in mixer.parameters())) 

number of params:  91605


In [29]:
mixer(torch.randn(1, 3, 32, 32))

tensor([[-0.0875,  0.2247,  0.3493, -0.6245,  0.5838, -0.0184, -0.7723, -0.1112,
          0.0543, -0.2550]], grad_fn=<AddmmBackward0>)

#### Final Model

In [30]:
model = MlpMixer((3, 4*9, 4*9), (4, 4), hidden_expansion=3.0, num_blocks=10, num_classes=10,
                patch_mixing="sparse_linear", channel_mixing="sparse_linear")
#                 patch_mixing="sparse_mlp", channel_mixing="sparse_mlp")
#                 patch_mixing="dense", channel_mixing="dense")
                 

model = model.to(device)

MLP Mixer : Channes per patch -> Initial:48 Final:144


In [31]:
model

MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(36, 36), mode='bilinear')
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=144, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): SparseResMlp(
        (layers1s): ModuleList(
          (0): BlockLinear_MixerBlock(
            (facto_nets): ModuleList(
              (0-1): 2 x BlockWeight: [9, 9, 9]
            )
          )
        )
        (actf): GELU(approximate='none')
        (layers2s): ModuleList(
          (0): BlockLinear_MixerBlock(
            (facto_nets): ModuleList(
              (0-1): 2 x BlockWeight: [9, 9, 9]
            )
          )
        )
      )
      (ln1): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): SparseResMlp(
        (layers1s): ModuleList(
          (0): BlockLinear_MixerBlock(
            (fact

In [32]:
# print("number of params: ", sum(p.numel() for p in model.parameters())) 
print("number of params: ", sum(p.numel() for p in model.mixer_blocks.parameters())) 

### Dense: 1,104,390
### Sparse MLP: 215820
### Sparse Linear: 209,070

number of params:  108540


In [33]:
model(torch.randn(1, 3, 32, 32).to(device)).shape

torch.Size([1, 10])

In [34]:
# asdasd

## Training

In [35]:
# model_name = f'mlp_mixer_sparse-mlp_c10_s{SEED}'

In [36]:
# EPOCHS = 200
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [37]:
STAT ={'train_stat':[], 'test_stat':[]}

In [38]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    """Run one training pass over the global ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``. Appends ``(epoch, mean batch loss, accuracy in percent)`` to
    ``STAT['train_stat']`` and prints the same figures.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    for step, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard update: clear gradients, forward, backward, step.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.max(1)[1]  # index of the largest logit per sample
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    mean_loss = running_loss/(step+1)
    accuracy = 100.*n_correct/n_seen
    STAT['train_stat'].append((epoch, mean_loss, accuracy)) ### (Epochs, Loss, Acc)
    print(f"[Train] {epoch} Loss: {mean_loss:.3f} | Acc: {accuracy:.3f} {n_correct}/{n_seen}")

In [39]:
best_acc = -1
def test(epoch):
    """Evaluate the global ``model`` on ``test_loader`` and checkpoint on improvement.

    Appends ``(epoch, mean batch loss, accuracy in percent, mean forward time
    per batch in seconds)`` to ``STAT['test_stat']``. Whenever the accuracy
    beats ``best_acc`` the weights are saved to ``./models_v1/{model_name}.pth``.
    ``STAT`` is rewritten to ``./models_v1/stats/{model_name}_data.json`` on
    every call.
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    time_taken = []
    on_cuda = device.type == "cuda"
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)

            # CUDA kernels are launched asynchronously, so wait for the device on
            # both sides of the forward pass; otherwise only launch time is measured.
            if on_cuda:
                torch.cuda.synchronize(device)
            start = time.perf_counter()

            outputs = model(inputs)

            if on_cuda:
                torch.cuda.synchronize(device)
            time_taken.append(time.perf_counter() - start)

            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    STAT['test_stat'].append((epoch, test_loss/(batch_idx+1), 100.*correct/total, float(np.mean(time_taken)))) ### (Epochs, Loss, Acc, time)
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")

    # Create the directories that are actually written to below.
    os.makedirs('./models_v1/stats', exist_ok=True)

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch
        }
        torch.save(state, f'./models_v1/{model_name}.pth')
        best_acc = acc

    with open(f"./models_v1/stats/{model_name}_data.json", 'w') as f:
        json.dump(STAT, f, indent=0)

In [40]:
# start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# resume = False

# if resume:
#     # Load checkpoint.
#     print('==> Resuming from checkpoint..')
#     assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
#     checkpoint = torch.load(f'./models/{model_name}.pth')
#     model.load_state_dict(checkpoint['model'])
#     best_acc = checkpoint['acc']
#     start_epoch = checkpoint['epoch']

In [41]:
# ### Train the whole damn thing

# for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
#     train(epoch)
#     test(epoch)
#     scheduler.step()

In [42]:
# best_acc

In [43]:
# checkpoint = torch.load(f'./models/{model_name}.pth')
# best_acc = checkpoint['acc']
# start_epoch = checkpoint['epoch']

# best_acc, start_epoch

In [44]:
# model.load_state_dict(checkpoint['model'])

In [45]:
# model

In [46]:
# STAT

In [47]:
# train_stat = np.array(STAT['train_stat'])
# test_stat = np.array(STAT['test_stat'])

In [48]:
# plt.plot(train_stat[:,1], label='train')
# plt.plot(test_stat[:,1], label='test')
# plt.ylabel("Loss")
# plt.legend()
# plt.savefig(f"./output/plots/{model_name}_loss.svg")
# plt.show()

In [49]:
# plt.plot(train_stat[:,2], label='train')
# plt.plot(test_stat[:,2], label='test')
# plt.ylabel("Accuracy")
# plt.legend()
# plt.savefig(f"./output/plots/{model_name}_accs.svg")
# plt.show()

## Benchmark Training

In [50]:
def get_data_loaders(seed, ds):
    """Return ``(train_loader, test_loader)`` over the global datasets.

    The batch size is 128 when ``ds == 'c100'`` and 64 otherwise. The torch RNG
    is seeded with ``seed`` before the loaders are created.
    """
    batch_size = 128 if ds == 'c100' else 64
    torch.manual_seed(seed)
    shared = dict(batch_size=batch_size, num_workers=2)
    training = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **shared)
    evaluation = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **shared)
    return training, evaluation

In [51]:
# ! mkdir models_v1/
# ! mkdir models_v1/stats

In [52]:
def benchmark():
    """Train the selected MLP-Mixer variant(s) for 200 epochs per seed.

    Rebinds the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``test_loader``, ``model_name``, ``criterion``, ``STAT`` and ``best_acc``
    that ``train`` and ``test`` read. Returns 0.
    """
    global model, optimizer, train_loader, test_loader, model_name, criterion, STAT, best_acc

    EPOCHS = 200
    lr = 0.001
    criterion = nn.CrossEntropyLoss()
    DS = 'c10'  # 'c100' selects the 100-class head and the larger batch size

    # experiment index -> (mixing mode for both mixers, tag used in the run name)
    variants = {
        0: ("dense", "dense"),                  ### original mixer
        1: ("sparse_linear", "sparseLinear"),
        2: ("sparse_mlp", "sparseMlp"),
    }

    for SEED in [147, 258, 369]:  # alternative set: [741, 852, 963, 159, 357]
        for num_layers in [7]: ##[7, 10]
            for i in [1]:  # range(3) trains all three variants

                print("Experiment index:", i)
                train_loader, test_loader = get_data_loaders(SEED, DS)
                torch.manual_seed(SEED)
                num_cls = 100 if DS == 'c100' else 10

                if i not in variants:
                    print("JPT........!!!!")
                    continue

                mixing, tag = variants[i]
                model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=num_layers,
                                 num_classes=num_cls, patch_mixing=mixing, channel_mixing=mixing)
                model_name = f'mlp_mixer_{tag}_l{num_layers}_{DS}_s{SEED}'

                model = model.to(device)
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

                num_params = sum(p.numel() for p in model.parameters())

                model_name = "03.0_" + model_name + "_h1"
                print(f"EXPERIMENTING FOR : {model_name} | params: {num_params}  .......\n.......")

                # Fresh statistics and best-accuracy tracker for this run.
                STAT = {'train_stat': [], 'test_stat': [], 'num_params': num_params}
                best_acc = -1
                for epoch in range(EPOCHS):
                    train(epoch)
                    test(epoch)
                    scheduler.step()
                print(f"Training finished\n")
    return 0

In [53]:
## warning - Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
torch.set_float32_matmul_precision('high')

In [None]:
benchmark()

Experiment index: 1
MLP Mixer : Channes per patch -> Initial:48 Final:121
EXPERIMENTING FOR : 03.0_mlp_mixer_sparseLinear_l7_c10_s147_h1 | params: 140961  .......
.......


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.86it/s]


[Train] 0 Loss: 1.791 | Acc: 36.290 18145/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.68it/s]


[Test] 0 Loss: 1.458 | Acc: 47.790 4779/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.85it/s]


[Train] 1 Loss: 1.445 | Acc: 48.208 24104/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 115.61it/s]


[Test] 1 Loss: 1.240 | Acc: 55.740 5574/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.65it/s]


[Train] 2 Loss: 1.289 | Acc: 54.078 27039/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 115.83it/s]


[Test] 2 Loss: 1.127 | Acc: 60.630 6063/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.70it/s]


[Train] 3 Loss: 1.194 | Acc: 57.748 28874/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 112.52it/s]


[Test] 3 Loss: 1.070 | Acc: 62.370 6237/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.59it/s]


[Train] 4 Loss: 1.144 | Acc: 59.620 29810/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.54it/s]


[Test] 4 Loss: 1.044 | Acc: 63.140 6314/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.61it/s]


[Train] 5 Loss: 1.096 | Acc: 61.514 30757/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.60it/s]


[Test] 5 Loss: 0.980 | Acc: 65.460 6546/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.59it/s]


[Train] 6 Loss: 1.066 | Acc: 62.500 31250/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 109.30it/s]


[Test] 6 Loss: 0.960 | Acc: 66.010 6601/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.61it/s]


[Train] 7 Loss: 1.038 | Acc: 63.696 31848/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 114.35it/s]


[Test] 7 Loss: 0.933 | Acc: 67.580 6758/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.56it/s]


[Train] 8 Loss: 1.014 | Acc: 64.522 32261/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 113.08it/s]


[Test] 8 Loss: 0.937 | Acc: 67.660 6766/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.68it/s]


[Train] 9 Loss: 0.997 | Acc: 65.034 32517/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.74it/s]


[Test] 9 Loss: 0.895 | Acc: 68.950 6895/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.56it/s]


[Train] 10 Loss: 0.975 | Acc: 65.900 32950/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.09it/s]


[Test] 10 Loss: 0.886 | Acc: 69.190 6919/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.58it/s]


[Train] 11 Loss: 0.958 | Acc: 66.760 33380/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 109.73it/s]


[Test] 11 Loss: 0.885 | Acc: 68.680 6868/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.61it/s]


[Train] 12 Loss: 0.932 | Acc: 67.460 33730/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 109.87it/s]


[Test] 12 Loss: 0.842 | Acc: 70.930 7093/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.49it/s]


[Train] 13 Loss: 0.923 | Acc: 67.846 33923/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.33it/s]


[Test] 13 Loss: 0.849 | Acc: 70.420 7042/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.47it/s]


[Train] 14 Loss: 0.909 | Acc: 68.390 34195/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.90it/s]


[Test] 14 Loss: 0.814 | Acc: 72.010 7201/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:31<00:00, 25.00it/s]


[Train] 15 Loss: 0.889 | Acc: 69.022 34511/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 110.67it/s]


[Test] 15 Loss: 0.794 | Acc: 72.100 7210/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.29it/s]


[Train] 16 Loss: 0.879 | Acc: 69.212 34606/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 112.50it/s]


[Test] 16 Loss: 0.815 | Acc: 71.720 7172/10000


100%|███████████████████████████████████████████████████| 782/782 [00:31<00:00, 24.71it/s]


[Train] 17 Loss: 0.865 | Acc: 69.804 34902/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 114.70it/s]


[Test] 17 Loss: 0.793 | Acc: 72.520 7252/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.49it/s]


[Train] 18 Loss: 0.854 | Acc: 70.322 35161/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 112.21it/s]


[Test] 18 Loss: 0.796 | Acc: 72.770 7277/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.41it/s]


[Train] 19 Loss: 0.849 | Acc: 70.748 35374/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 113.17it/s]


[Test] 19 Loss: 0.797 | Acc: 72.170 7217/10000


100%|███████████████████████████████████████████████████| 782/782 [00:36<00:00, 21.44it/s]


[Train] 20 Loss: 0.839 | Acc: 71.038 35519/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.83it/s]


[Test] 20 Loss: 0.798 | Acc: 72.160 7216/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.41it/s]


[Train] 21 Loss: 0.833 | Acc: 71.080 35540/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 109.67it/s]


[Test] 21 Loss: 0.758 | Acc: 73.970 7397/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.45it/s]


[Train] 22 Loss: 0.816 | Acc: 71.694 35847/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.74it/s]


[Test] 22 Loss: 0.779 | Acc: 72.630 7263/10000


100%|███████████████████████████████████████████████████| 782/782 [00:32<00:00, 24.07it/s]


[Train] 23 Loss: 0.814 | Acc: 71.940 35970/50000


100%|███████████████████████████████████████████████████| 157/157 [00:01<00:00, 93.04it/s]


[Test] 23 Loss: 0.761 | Acc: 73.690 7369/10000


100%|███████████████████████████████████████████████████| 782/782 [00:32<00:00, 24.20it/s]


[Train] 24 Loss: 0.809 | Acc: 71.964 35982/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 112.63it/s]


[Test] 24 Loss: 0.754 | Acc: 73.970 7397/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.52it/s]


[Train] 25 Loss: 0.796 | Acc: 72.270 36135/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 105.44it/s]


[Test] 25 Loss: 0.774 | Acc: 73.360 7336/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.43it/s]


[Train] 26 Loss: 0.792 | Acc: 72.574 36287/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 107.98it/s]


[Test] 26 Loss: 0.762 | Acc: 73.820 7382/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.46it/s]


[Train] 27 Loss: 0.787 | Acc: 72.770 36385/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 114.39it/s]


[Test] 27 Loss: 0.748 | Acc: 74.320 7432/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:32<00:00, 24.10it/s]


[Train] 28 Loss: 0.778 | Acc: 72.904 36452/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 110.16it/s]


[Test] 28 Loss: 0.750 | Acc: 74.680 7468/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:32<00:00, 24.09it/s]


[Train] 29 Loss: 0.774 | Acc: 73.126 36563/50000


100%|███████████████████████████████████████████████████| 157/157 [00:01<00:00, 87.13it/s]


[Test] 29 Loss: 0.748 | Acc: 74.110 7411/10000


100%|███████████████████████████████████████████████████| 782/782 [00:32<00:00, 24.05it/s]


[Train] 30 Loss: 0.759 | Acc: 73.630 36815/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 112.18it/s]


[Test] 30 Loss: 0.757 | Acc: 73.950 7395/10000


100%|███████████████████████████████████████████████████| 782/782 [00:32<00:00, 24.42it/s]


[Train] 31 Loss: 0.761 | Acc: 73.740 36870/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 110.87it/s]


[Test] 31 Loss: 0.726 | Acc: 74.480 7448/10000


100%|███████████████████████████████████████████████████| 782/782 [00:31<00:00, 24.48it/s]


[Train] 32 Loss: 0.755 | Acc: 73.594 36797/50000


100%|███████████████████████████████████████████████████| 157/157 [00:01<00:00, 92.80it/s]


[Test] 32 Loss: 0.708 | Acc: 75.420 7542/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:35<00:00, 21.83it/s]


[Train] 33 Loss: 0.748 | Acc: 74.064 37032/50000


100%|███████████████████████████████████████████████████| 157/157 [00:01<00:00, 88.83it/s]


[Test] 33 Loss: 0.724 | Acc: 75.160 7516/10000


100%|███████████████████████████████████████████████████| 782/782 [00:31<00:00, 25.10it/s]


[Train] 34 Loss: 0.748 | Acc: 74.018 37009/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 112.58it/s]


[Test] 34 Loss: 0.731 | Acc: 75.160 7516/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.51it/s]


[Train] 35 Loss: 0.737 | Acc: 74.458 37229/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.02it/s]


[Test] 35 Loss: 0.728 | Acc: 75.140 7514/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.58it/s]


[Train] 36 Loss: 0.732 | Acc: 74.548 37274/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 109.83it/s]


[Test] 36 Loss: 0.722 | Acc: 75.210 7521/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.54it/s]


[Train] 37 Loss: 0.728 | Acc: 74.718 37359/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.90it/s]


[Test] 37 Loss: 0.728 | Acc: 74.880 7488/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.56it/s]


[Train] 38 Loss: 0.730 | Acc: 74.702 37351/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.18it/s]


[Test] 38 Loss: 0.717 | Acc: 75.610 7561/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.38it/s]


[Train] 39 Loss: 0.716 | Acc: 75.130 37565/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.29it/s]


[Test] 39 Loss: 0.707 | Acc: 75.920 7592/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.49it/s]


[Train] 40 Loss: 0.719 | Acc: 75.042 37521/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 112.52it/s]


[Test] 40 Loss: 0.705 | Acc: 75.620 7562/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.55it/s]


[Train] 41 Loss: 0.712 | Acc: 75.324 37662/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 111.46it/s]


[Test] 41 Loss: 0.697 | Acc: 76.210 7621/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.49it/s]


[Train] 42 Loss: 0.711 | Acc: 75.454 37727/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 110.74it/s]


[Test] 42 Loss: 0.697 | Acc: 76.670 7667/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.55it/s]


[Train] 43 Loss: 0.709 | Acc: 75.388 37694/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 112.67it/s]


[Test] 43 Loss: 0.704 | Acc: 76.040 7604/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.54it/s]


[Train] 44 Loss: 0.701 | Acc: 75.624 37812/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 113.93it/s]


[Test] 44 Loss: 0.722 | Acc: 75.340 7534/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.49it/s]


[Train] 45 Loss: 0.700 | Acc: 75.714 37857/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 112.58it/s]


[Test] 45 Loss: 0.710 | Acc: 76.020 7602/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.43it/s]


[Train] 46 Loss: 0.687 | Acc: 75.972 37986/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 109.85it/s]


[Test] 46 Loss: 0.687 | Acc: 76.500 7650/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.56it/s]


[Train] 47 Loss: 0.687 | Acc: 76.142 38071/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 112.06it/s]


[Test] 47 Loss: 0.686 | Acc: 76.670 7667/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.57it/s]


[Train] 48 Loss: 0.679 | Acc: 76.164 38082/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 112.18it/s]


[Test] 48 Loss: 0.674 | Acc: 77.160 7716/10000
Saving..


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.51it/s]


[Train] 49 Loss: 0.678 | Acc: 76.484 38242/50000


100%|██████████████████████████████████████████████████| 157/157 [00:01<00:00, 113.55it/s]


[Test] 49 Loss: 0.688 | Acc: 76.320 7632/10000


100%|███████████████████████████████████████████████████| 782/782 [00:30<00:00, 25.53it/s]


[Train] 50 Loss: 0.677 | Acc: 76.586 38293/50000


  1%|▎                                                    | 1/157 [00:00<00:16,  9.74it/s]

## Flops

In [None]:
from ptflops import get_model_complexity_info

# Placeholder seed that only appears in the printed model names below.
# It lives under its own name so this cell does not overwrite the notebook-level
# SEED (set near the top of the notebook and passed to torch.manual_seed /
# np.random.seed); re-running other cells after this one then still sees the
# real seed.
FLOPS_SEED = -1

# Each entry: (value passed as both patch_mixing and channel_mixing,
#              tag embedded in the printed model name).
MIXING_VARIANTS = [
    ("dense", "dense"),
    ("sparse_linear", "sparseLinear"),
    ("sparse_mlp", "sparseMlp"),
]

# Report MACs and parameter count for every (classes, depth, mixing) combination.
for num_cls in [10]:
    DS = f"c{num_cls}"
    for num_layers in [7]:
        for mixing, tag in MIXING_VARIANTS:
            model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=num_layers,
                             num_classes=num_cls, patch_mixing=mixing, channel_mixing=mixing)
            model_name = f'mlp_mixer_{tag}_l{num_layers}_{DS}_s{FLOPS_SEED}'

            # as_strings=True makes ptflops return human-readable strings, which
            # are printed as-is below.
            # NOTE(review): ptflops documents `ignore_modules` as a list of module
            # *types*; a string such as 'channel_change' may be silently ignored —
            # confirm the intended module class and pass it instead.
            macs, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True,
                                                     ignore_modules=['channel_change'],
                                                     print_per_layer_stat=False, verbose=False)

            print(model_name)
            print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
            print('{:<30}  {:<8}'.format('Number of parameters: ', params))
            print('')

In [None]:
"""
The initialization of the three methods is different.
Try to remove that gap.

1) Initializing everything to kaiming_uniform_ similar to nn.Linear
    - Might have better initialization (unclear how PyTorch handles init with block-sparse)
"""

In [None]:
# !nvidia-smi

In [None]:
# !pip list

In [None]:
# Request interpreter/kernel exit with status code 0.
# NOTE(review): presumably a guard so that "Run All" stops here — confirm it is still wanted.
exit(0)