In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy, json
# from tqdm.autonotebook import tqdm
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.optim as optim
from torch.utils import data

In [3]:
from transformers_lib import TransformerBlock, \
        Mixer_TransformerBlock_Encoder, \
        PositionalEncoding

In [4]:
class TinyImageNet_Preload(data.Dataset):
    """Tiny-ImageNet dataset backed by ``ImageFolder`` with optional in-memory preloading.

    Adapted from https://gist.github.com/z-a-f/b862013c0dc2b540cf96a123a6766e54

    Args:
        root: dataset root; images are read from ``os.path.join(root, mode)``.
        mode: sub-directory of ``root`` to load (default ``'train'``).
        transform: optional callable applied to every image in ``__getitem__``;
            when ``None`` the image is returned unchanged.
        preload: when True, every (image, label) pair is read once in ``__init__``
            and served from memory afterwards.
    """

    def __init__(self, root, mode='train', transform=None, preload=False):
        super().__init__()
        self.preload = preload
        # The folder dataset is built without a transform so preloaded images stay
        # raw; the transform is applied lazily in __getitem__.
        dataset = datasets.ImageFolder(
            root=os.path.join(root, mode),
            transform=None
        )
        self.transform = transform
        self.images, self.labels = [], []
        print("Dataset Size:",len(dataset))

        if preload:
            for i in tqdm(range(len(dataset))):
                x, y = dataset[i]
                self.images.append(x)
                self.labels.append(y)

        # Kept even when preloading: __len__ reads it and it backs the lazy path.
        self.dataset = dataset

    @staticmethod
    def _add_channels(img, total_channels=3):
        """Return ``img`` as an (H, W, C) array with at least ``total_channels`` channels.

        Missing trailing axes are added first, then the last channel is repeated
        until the channel axis is wide enough. Arrays that already have enough
        channels are returned unchanged.
        """
        while img.ndim < 3:  # third axis is the channels
            img = np.expand_dims(img, axis=-1)
        while img.shape[-1] < total_channels:
            img = np.concatenate([img, img[:, :, -1:]], axis=-1)
        return img

    def __len__(self):
        """Number of samples in the underlying folder dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, applying ``transform`` when one is set."""
        if self.preload:
            img, lbl = self.images[idx], self.labels[idx]
        else:
            img, lbl = self.dataset[idx]
        if self.transform is None:
            # transform defaults to None; calling it would raise TypeError.
            return img, lbl
        return self.transform(img), lbl

# Model

In [5]:
class Mixer_ViT_Classifier(nn.Module):
    """Patch-based image classifier built from ``Mixer_TransformerBlock_Encoder`` blocks.

    The image is cut into non-overlapping patches, every flattened patch is linearly
    projected to ``hidden_channel`` features, the patch sequence goes through
    ``num_blocks`` encoder blocks and the flattened result is mapped to class logits.

    Args:
        image_dim: ``(C, H, W)`` or ``(H, W)``; an ``(H, W)`` image is treated as
            single-channel.
        patch_size: ``(patch_h, patch_w)``; must divide H and W exactly.
        hidden_channel: number of features each patch is projected to.
        num_blocks: number of encoder blocks.
        num_classes: size of the output logits.
        block_seq_size: forwarded to every encoder block; when ``None`` or < 2 it
            defaults to the smallest power of two >= sqrt(number of patches).
        block_mlp_size: forwarded unchanged to every encoder block.
        forward_expansion: forwarded to every encoder block.
        pos_emb: when False the positional encoding is replaced by ``nn.Identity``.
        dropout: forwarded to every encoder block.
    """

    def __init__(self, image_dim:tuple, patch_size:tuple, hidden_channel:int, num_blocks:int, num_classes:int, block_seq_size:int, block_mlp_size:int, forward_expansion:float=2.0, pos_emb=True, dropout:float=0.0):
        super().__init__()

        self.img_dim = image_dim ### must contain (C, H, W) or (H, W)

        ### number of patches along each axis
        d0 = image_dim[-2] // patch_size[0]
        d1 = image_dim[-1] // patch_size[1]
        assert d0*patch_size[0]==image_dim[-2], "Image must be divisible into patch size"
        assert d1*patch_size[1]==image_dim[-1], "Image must be divisible into patch size"

        ### an (H, W) image has no channel entry, so image_dim[0] would be the height
        in_channels = image_dim[0] if len(image_dim) == 3 else 1
        patch_features = patch_size[0]*patch_size[1]*in_channels ## values in each flattened patch

        ### sequence length = number of patches
        num_patches = d0*d1

        init_dim = patch_features
        final_dim = hidden_channel
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        #### per-patch linear projection (not image preserving, unlike bilinear interpolation)
        self.channel_change = nn.Linear(init_dim, final_dim)
        print(f"ViT Mixer : Channels per patch -> Initial:{init_dim} Final:{final_dim}")

        self.channel_dim = final_dim
        self.patch_dim = num_patches

        ### number of heads: the factor of channel_dim closest to sqrt(channel_dim)
        f = self.get_factors(self.channel_dim)
        print(f)
        if f:
            fi = np.abs(np.array(f) - np.sqrt(self.channel_dim)).argmin()
            _n_heads = f[fi]
        else:
            _n_heads = 1  ## channel_dim == 1 has no factor >= 2

        ### resolve the default block size BEFORE logging it, so the log shows the value in use
        if block_seq_size is None or block_seq_size<2:
            block_seq_size = int(2**np.ceil(np.log2(np.sqrt(self.patch_dim))))

        print(f'Sequence len: {self.patch_dim} ; Block size: {block_seq_size}')
        print('Channel dim:', self.channel_dim, 'num heads:',_n_heads)
        print(f'MLP dim: {self.channel_dim} ; Block size: {block_mlp_size}')

        blocks = []
        for _ in range(num_blocks):
            blocks.append(Mixer_TransformerBlock_Encoder(self.patch_dim, block_seq_size, self.channel_dim, _n_heads, dropout, forward_expansion, nn.GELU, block_mlp_size))
        self.transformer_blocks = nn.Sequential(*blocks)

        self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)

        ### built unconditionally and then replaced, exactly as before, so module
        ### construction order (and any RNG use inside PositionalEncoding) is unchanged
        self.positional_encoding = PositionalEncoding(self.channel_dim, dropout=0)
        if not pos_emb:
            self.positional_encoding = nn.Identity()


    def get_factors(self, n):
        """Return every divisor of ``n`` in ``[2, n]`` in increasing order (empty for n < 2)."""
        return [i for i in range(2, n+1) if n%i == 0]

    def forward(self, x):
        """Map a batch of images to logits of shape ``(batch, num_classes)``.

        ``x`` is ``(B, C, H, W)``; a model built with an ``(H, W)`` ``image_dim``
        also accepts ``(B, H, W)``.
        """
        if len(self.img_dim) == 2 and x.dim() == 3:
            x = x.unsqueeze(1)  ## add the channel axis nn.Unfold needs
        bs = x.shape[0]
        x = self.unfold(x).swapaxes(-1, -2)  ## (B, patches, values per patch)
        x = self.channel_change(x)
        x = self.positional_encoding(x)
        x = self.transformer_blocks(x)
        ## reshape (not view): works even if the blocks return a non-contiguous tensor
        x = self.linear(x.reshape(bs, -1))
        return x

In [11]:
# device = torch.device('cuda:0')
# Device used by the interactive model cells below; swap with the line above to use the first GPU.
device = torch.device('cpu')

In [12]:
# torch.cuda.device_count()

In [13]:
# torch.cuda.get_device_name(0)

In [14]:
# torch.cuda.memory_allocated()

In [45]:
# 3x32x32 input cut into 2x2 patches (256 patches), 64 channels per patch, 2 blocks, no positional encoding.
model = Mixer_ViT_Classifier([3, 32, 32], [2, 2], 64, num_blocks=2, num_classes=10, 
                            block_seq_size=16, block_mlp_size=None, pos_emb=False).to(device)

ViT Mixer : Channels per patch -> Initial:12 Final:64
[2, 4, 8, 16, 32, 64]
Sequence len: 256 ; Block size: 16
Channel dim: 64 num heads: 8
MLP dim: 64 ; Block size: None


In [42]:
# model

In [49]:
# Same configuration with 1x1 patches: every pixel is a token (1024 patches); rebinds `model`.
model = Mixer_ViT_Classifier([3, 32, 32], [1, 1], 64, num_blocks=2, num_classes=10, 
                            block_seq_size=32, block_mlp_size=None, pos_emb=False).to(device)

ViT Mixer : Channels per patch -> Initial:3 Final:64
[2, 4, 8, 16, 32, 64]
Sequence len: 1024 ; Block size: 32
Channel dim: 64 num heads: 8
MLP dim: 64 ; Block size: None


In [52]:
# print("number of params: ", sum(p.numel() for p in model.parameters()))
# Parameter count of each direct child module of the current `model`.
for name, m in model.named_children():
#     print(name)
    print(f"{name}: {sum(p.numel() for p in m.parameters())}")

unfold: 0
channel_change: 256
transformer_blocks: 133888
linear: 655370
positional_encoding: 0


In [47]:
# model

In [30]:
# Total number of parameters of the current `model`.
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  789514


In [13]:
# xx = torch.randn(32, 3, 32, 32).to(device)

In [14]:
# %timeit model(xx).mean().backward()

In [15]:
# %timeit model(xx)

In [16]:
# with torch.no_grad():
#     %timeit model(xx)

In [17]:
# model = Mixer_ViT_Classifier([3, 32, 32], [1, 1], 64, num_blocks=4, num_classes=10, 
#                             block_seq_size=1024, block_mlp_size=None, pos_emb=False).to(device)

In [18]:
# model

In [19]:
# print("number of params: ", sum(p.numel() for p in model.parameters()))

In [20]:
# %timeit model(xx).mean().backward()

In [21]:
# %timeit model(xx)

In [22]:
# with torch.no_grad():
#     %timeit model(xx)

In [23]:
# asas

In [24]:
def experiment_skip(model_name, ep):
    """Tell whether the benchmark for ``model_name`` already ran for all ``ep`` epochs.

    Reads ``./output/benchmark/{model_name}_data.json`` and compares the epoch index
    of the last ``train_stat`` entry with ``ep - 1``.

    Args:
        model_name: name used to build the statistics file path.
        ep: total number of epochs the experiment is supposed to run.

    Returns:
        True when the file exists and its last recorded epoch is >= ``ep - 1``.
        False when the file is missing, or when it cannot be parsed or holds no
        training entry (e.g. a run interrupted while writing), so the experiment
        is re-run instead of crashing the sweep.
    """
    filename = f"./output/benchmark/{model_name}_data.json"
    if not os.path.exists(filename):
        return False

    try:
        with open(filename) as f:
            stats = json.load(f)
        ## stats consists of lists and dicts; each train_stat entry starts with its epoch index
        last_epoch = stats['train_stat'][-1][0]
    except (json.JSONDecodeError, KeyError, IndexError, TypeError):
        return False

    return last_epoch >= ep-1

In [25]:
def benchmark(dataset:str, patch_size:int, num_layers:int, SEED:int, sparse_att:bool=False, sparse_mlp:bool=False, pos_emb:bool=False, cuda:int=0):
    """Train a ``Mixer_ViT_Classifier`` on one dataset/configuration and record its statistics.

    Per-epoch statistics are written to ``./output/benchmark/{model_name}_data.json``,
    the best-accuracy checkpoint to ``./models/benchmark/{model_name}.pth`` and the
    loss/accuracy curves to ``./output/benchmark/plots/``. A configuration whose
    statistics file already covers every epoch is skipped (see ``experiment_skip``).

    Args:
        dataset: ``'tiny'``, ``'cifar10'`` or ``'cifar100'``.
        patch_size: side of the square patches; must be a key of the expansion
            table below (16, 8, 4, 2 or 1).
        num_layers: nominal depth. With ``sparse_att`` it must be even and only
            ``num_layers // 2`` blocks are built.
        SEED: seed for torch and numpy.
        sparse_att: use a sequence block size of the smallest power of two
            >= sqrt(sequence length) instead of the full sequence length.
        sparse_mlp: same reduction applied to the MLP block size.
        pos_emb: forwarded to the model.
        cuda: index of the CUDA device to train on.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
        AssertionError: if ``sparse_att`` is set and ``num_layers`` is odd.
        KeyError: if ``patch_size`` is not in the expansion table.
    """
    # Fail early with a clear message instead of an UnboundLocalError further down.
    if dataset not in ('tiny', 'cifar10', 'cifar100'):
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'tiny', 'cifar10' or 'cifar100'")

    device = torch.device(f"cuda:{cuda}")
    
    if sparse_att:
        assert num_layers%2 == 0, 'number of blocks on sparse transformer is (x2)/2 hence it must be even'
        num_layers_ = num_layers//2
    else:
        num_layers_ = num_layers
    
    BS = 256
    NC = -1
    EPOCHS = 300
    imsize = (3, 32, 32)
    # hidden channel size used for each patch size
    expansion_dict = {16:1024, 8:256, 4:128, 2:64, 1:64}
    expansion = expansion_dict[patch_size]

    
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    ##### Data Transforms
    if dataset == 'tiny':
        NC = 200
        EPOCHS = 400
        imsize = (3, 64, 64)
        tiny_train = transforms.Compose([
        transforms.RandAugment(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5]*3,
            std=[0.2]*3,
            ),
        ])

        tiny_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5]*3,
                std=[0.2]*3,
            ),
        ])
        
        train_dataset = TinyImageNet_Preload(root="../../../../../_Datasets/tiny-imagenet-200",
                                     mode='train', transform=tiny_train)
        test_dataset = TinyImageNet_Preload(root="../../../../../_Datasets/tiny-imagenet-200",
                                     mode='val', transform=tiny_test)
        
    elif dataset == 'cifar10':
        NC = 10
        BS = 32
        EPOCHS = 200
        
        cifar_train = transforms.Compose([
            transforms.RandomCrop(size=32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
                std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
            ),
        ])

        cifar_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
                std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
            ),
        ])

        train_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=True, download=True, transform=cifar_train)
        test_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=False, download=True, transform=cifar_test)

    elif dataset == 'cifar100':
        NC = 100
        BS = 128
        EPOCHS = 300

        cifar_train = transforms.Compose([
            transforms.RandomCrop(size=32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5071, 0.4865, 0.4409],
                std=[0.2009, 0.1984, 0.2023],
            ),
        ])

        cifar_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5071, 0.4865, 0.4409],
                std=[0.2009, 0.1984, 0.2023],
            ),
        ])

        train_dataset = datasets.CIFAR100(root="../../../../../_Datasets/cifar100/", train=True, download=True, transform=cifar_train)
        test_dataset = datasets.CIFAR100(root="../../../../../_Datasets/cifar100/", train=False, download=True, transform=cifar_test)
        
    ##### Now create data loaders
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BS, shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BS, shuffle=False, num_workers=4)
    
    ### Now create models
    
    seq_len = (imsize[-1]*imsize[-2])//(patch_size*patch_size)
    mlp_dim = expansion
    print(seq_len, mlp_dim)
    
    # Sparse variants use a block size of the smallest power of two >= sqrt(full size).
    if sparse_att:
        seq_len = int(2**np.ceil(np.log2(np.sqrt(seq_len))))
    if sparse_mlp:
        mlp_dim = int(2**np.ceil(np.log2(np.sqrt(expansion))))
    
    
    # Re-seed so the weight initialisation does not depend on the data setup above.
    torch.manual_seed(SEED)
    model = Mixer_ViT_Classifier(imsize, 
                                 patch_size=[patch_size]*2, 
                                 hidden_channel=expansion, 
                                 num_blocks=num_layers_, 
                                 num_classes=NC, 
                                 block_seq_size=seq_len, 
                                 block_mlp_size=mlp_dim,
                                 pos_emb=pos_emb).to(device)
    
    num_params = sum(p.numel() for p in model.parameters())
    print("number of params: ", num_params)
    
    _a, _b, _c = 'att', 'mlp', 'nPE'
    if sparse_att: _a = 'sAtt'
    if sparse_mlp: _b = 'sMlp'
    if pos_emb: _c = 'PE'
    model_name = f'01.3_ViT_{_c}_{dataset}_patch{patch_size}_l{num_layers}_{_a}_{_b}_s{SEED}'
    print(f"Model Name: {model_name}")
    
    if experiment_skip(model_name, EPOCHS):
        print(f'EXPERIMENT DONE... SKIPPING : {model_name}')
        return
    
    # Create every output directory written below; previously only ./models was
    # created, so the first checkpoint/JSON/plot write failed on a fresh checkout.
    os.makedirs('./models/benchmark', exist_ok=True)
    os.makedirs('./output/benchmark/plots', exist_ok=True)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    STAT ={'train_stat':[], 'test_stat':[], 'params':num_params, }

    ## Following is copied from 
    ### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

    # Training
    def train(epoch):
        """Run one training epoch and append (epoch, mean loss, accuracy %) to STAT."""
        model.train()
        train_loss = 0
        correct = 0
        total = 0

        for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        STAT['train_stat'].append((epoch, train_loss/(batch_idx+1), 100.*correct/total)) ### (Epochs, Loss, Acc)
        print(f"[Train] {epoch} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
        return

    # NOTE: best_acc is deliberately left as a module-level global (as before) so that
    # code outside this function that reads it keeps working.
    global best_acc
    best_acc = -1
    def test(epoch):
        """Evaluate on the test set, checkpoint on a new best accuracy, and dump STAT to JSON."""
        global best_acc
        model.eval()
        test_loss = 0
        correct = 0
        total = 0
        time_taken = []
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
                inputs, targets = inputs.to(device), targets.to(device)

                # CUDA kernels launch asynchronously: synchronise around the forward
                # pass so the timer measures execution, not just kernel launch.
                torch.cuda.synchronize(device)
                start = time.perf_counter()

                outputs = model(inputs)

                torch.cuda.synchronize(device)
                time_taken.append(time.perf_counter()-start)

                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        STAT['test_stat'].append((epoch, test_loss/(batch_idx+1), 100.*correct/total, float(np.mean(time_taken)))) ### (Epochs, Loss, Acc, time)
        print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")

        # Save checkpoint.
        acc = 100.*correct/total
        if acc > best_acc:
            print('Saving..')
            state = {
                'model': model.state_dict(),
                'acc': acc,
                'epoch': epoch
            }
            torch.save(state, f'./models/benchmark/{model_name}.pth')
            best_acc = acc

        # Rewritten after every epoch so an interrupted run still leaves its history.
        with open(f"./output/benchmark/{model_name}_data.json", 'w') as f:
            json.dump(STAT, f, indent=0)

    ### Train the whole thing
    for epoch in range(0, EPOCHS):
        train(epoch)
        test(epoch)
        scheduler.step()
        
    
    train_stat = np.array(STAT['train_stat'])
    test_stat = np.array(STAT['test_stat'])

    plt.plot(train_stat[:,1], label='train')
    plt.plot(test_stat[:,1], label='test')
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(f"./output/benchmark/plots/{model_name}_loss.svg")
    plt.show()

    plt.plot(train_stat[:,2], label='train')
    plt.plot(test_stat[:,2], label='test')
    plt.ylabel("Accuracy")
    plt.legend()
    plt.savefig(f"./output/benchmark/plots/{model_name}_accs.svg")
    plt.show()
    
    del model, optimizer
    return

In [26]:
# benchmark(dataset='tiny', 
#           patch_size=4, num_layers=10, SEED=123, sparse_att=True, sparse_mlp=True, cuda=0
#          )

In [27]:
# ### Automate the benchmark
# ###### for tiny
# cuda_idx = 0
# # for seed in [147, 258, 369]:
# for seed in [147]:
#     for patch_size in [16, 8, 4]:
#         for sparse_attention in [False, True]:
#             for sparse_mlp in [False, True]:
# #                 for nlayers in [6, 10, 14]:
#                 for nlayers in [6]:
#                     print(f'''
#                         Experimenting on Tiny Dataset 
#                         patch:{patch_size},
#                         sparse_att: {sparse_attention},
#                         sparse_mlp: {sparse_mlp},
#                         num_layers : {nlayers},
#                         seed: {seed}
#                     ''')
            
#                     benchmark(dataset='tiny', 
#                               patch_size=patch_size, 
#                               num_layers=nlayers, 
#                               SEED=seed, 
#                               sparse_att=sparse_attention, sparse_mlp=sparse_mlp, 
#                               cuda=cuda_idx
#                              )


In [28]:
# ### Automate the benchmark
# ###### for tiny

# not_working = [
#     (4, False, True, 6),
#     (4, False, False, 10),
#     (4, False, False, 14),
#     (4, False, True, 10),
#     (4, False, True, 14),
#     (4, True, True, 14),
# ]

# cuda_idx = 0
# # for seed in [147, 258, 369]:
# for seed in [147]:
#     for patch_size in [16, 8, 4]:
#         for sparse_attention in [False, True]:
#             for sparse_mlp in [False, True]:
#                 for nlayers in [6, 10, 14]:

#                     print(f'''
#                         Experimenting on Tiny Dataset 
#                         patch:{patch_size},
#                         sparse_att: {sparse_attention},
#                         sparse_mlp: {sparse_mlp},
#                         num_layers : {nlayers},
#                         seed: {seed}
#                     ''')
            
#                 ### check if config is in not_working case
#                     exit = False
#                     for nw in not_working:
#                         if patch_size==nw[0] and \
#                             sparse_attention==nw[1] and \
#                             sparse_mlp==nw[2] and\
#                             nlayers==nw[3]:
                            
#                             exit=True
#                             break
#                     if exit:
#                         print(f'Exiting as the config is in NOT WORKING')
#                         continue


#                     benchmark(dataset='tiny', 
#                               patch_size=patch_size, 
#                               num_layers=nlayers, 
#                               SEED=seed, 
#                               sparse_att=sparse_attention, sparse_mlp=sparse_mlp, 
#                               cuda=cuda_idx
#                              )


In [29]:
# ### Automate the benchmark
# ###### for c10

# not_working = [
# ]

# cuda_idx = 0
# for dataset in ['cifar10']:
#     for seed in [147]:
#         for nlayers in [12]:
#             for patch_size in [2]:
#                 for sparse_attention in [True]:
#                     for sparse_mlp in [False]:
#                         for PE in [False]:

#                             print(f'''
#                                 Experimenting on {dataset} Dataset 
#                                 patch:{patch_size},
#                                 sparse_att: {sparse_attention},
#                                 sparse_mlp: {sparse_mlp},
#                                 num_layers: {nlayers},
#                                 pos_embed: {PE},
#                                 seed: {seed}
#                             ''')

#                         ### check if config is in not_working case
#                             exit = False
#                             for nw in not_working:
#                                 if patch_size==nw[0] and \
#                                     sparse_attention==nw[1] and \
#                                     sparse_mlp==nw[2] and\
#                                     nlayers==nw[3]:

#                                     exit=True
#                                     break
#                             if exit:
#                                 print(f'Exiting as the config is in NOT WORKING')
#                                 continue


#                             benchmark(dataset=dataset, 
#                                       patch_size=patch_size, 
#                                       num_layers=nlayers, 
#                                       SEED=seed, 
#                                       sparse_att=sparse_attention, sparse_mlp=sparse_mlp, 
#                                       pos_emb=PE,
#                                       cuda=cuda_idx
#                                      )

In [30]:
# asfdas

In [31]:
# ### Automate the benchmark
# ###### for c100

# not_working = [
#     (1, False, False, 4),
#     (1, False, False, 8),
#     (1, False, False, 12),
#     (1, False, True, 4),
#     (1, False, True, 8),
#     (1, False, True, 12),
# ]

# cuda_idx = 0
# dataset = 'cifar100'
# # for seed in [147, 258, 369]:
# for seed in [147]:
#     for patch_size in [1, 2, 4, 8]:
#         for sparse_attention in [False, True]:
#             for sparse_mlp in [False, True]:
#                 for nlayers in [4, 8, 12]:

#                     print(f'''
#                         Experimenting on {dataset} Dataset 
#                         patch:{patch_size},
#                         sparse_att: {sparse_attention},
#                         sparse_mlp: {sparse_mlp},
#                         num_layers : {nlayers},
#                         seed: {seed}
#                     ''')
            
#                 ### check if config is in not_working case
#                     exit = False
#                     for nw in not_working:
#                         if patch_size==nw[0] and \
#                             sparse_attention==nw[1] and \
#                             sparse_mlp==nw[2] and\
#                             nlayers==nw[3]:
                            
#                             exit=True
#                             break
#                     if exit:
#                         print(f'Exiting as the config is in NOT WORKING')
#                         continue


#                     benchmark(dataset=dataset, 
#                               patch_size=patch_size, 
#                               num_layers=nlayers, 
#                               SEED=seed, 
#                               sparse_att=sparse_attention, sparse_mlp=sparse_mlp, 
#                               cuda=cuda_idx
#                              )

In [32]:
## CIFAR-10 runs with every patch size on a 3090 using a batch size of 64.
## CIFAR-100 does not run with patch size 1 for the dense (non-sparse) model.

## Benchmark Memory and Time CIFAR

In [33]:
import nvidia_smi

MB = 1024*1024
def get_memory_used(gpu_index=1):
    """Return the memory in use on one GPU as reported by NVML, divided by ``MB``.

    Args:
        gpu_index: NVML device index to query. Defaults to 1, the index this
            notebook was originally hard-coded to.

    Returns:
        ``info.used / MB`` as a float.
    """
    nvidia_smi.nvmlInit()
    try:
        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(gpu_index)
        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
        return info.used/MB
    finally:
        # Always release NVML, even when the index is invalid or the query fails.
        nvidia_smi.nvmlShutdown()

In [34]:
# Sanity check of the NVML-based reading.
get_memory_used()

313.625

In [35]:
import subprocess as sp
import os

def get_memory_used(gpu_index=1):
    """Return the used memory of one GPU as printed by ``nvidia-smi``.

    Runs ``nvidia-smi --query-gpu=memory.used --format=csv``, drops the CSV header
    and parses the leading integer of each remaining row.

    Args:
        gpu_index: row (GPU) to report, in nvidia-smi order. Defaults to 1, the
            index this notebook was originally hard-coded to.

    Returns:
        The integer at the start of the selected row.

    Raises:
        IndexError: if ``gpu_index`` is beyond the number of rows reported.
    """
    command = "nvidia-smi --query-gpu=memory.used --format=csv"
    output = sp.check_output(command.split()).decode('ascii')
    # The last split element is the empty string after the final newline and the
    # first one is the CSV header; neither holds a reading.
    rows = output.split('\n')[:-1][1:]
    used_per_gpu = [int(row.split()[0]) for row in rows]
    return used_per_gpu[gpu_index]

In [36]:
# Sanity check of the nvidia-smi-based reading (this definition shadows the NVML one above).
get_memory_used()

5

In [37]:
# asdfsadf

In [38]:
def benchmark_memory(dataset:str, patch_size:int, num_layers:int, SEED:int, sparse_att:bool=False, sparse_mlp:bool=False, pos_emb:bool=False, cuda:int=0):
    """Measure GPU memory and step time of one ``Mixer_ViT_Classifier`` configuration.

    A model is built for 3x32x32 inputs and run on one random batch of 32 images:
    50 training steps (forward, backward, Adam update) followed by 50 inference
    passes. The result is stored under the model name in
    ``./output/benchmark_memory_data.json`` with the keys ``memory`` (difference of
    ``get_memory_used()`` between just before model creation and just after the
    training loop), ``time_train`` and ``time_test`` (mean seconds per iteration).

    Args:
        dataset: ``'cifar10'`` or ``'cifar100'``; only selects the number of classes.
        patch_size: side of the square patches; must be a key of the expansion
            table below (16, 8, 4, 2 or 1).
        num_layers: nominal depth. With ``sparse_att`` it must be even and only
            ``num_layers // 2`` blocks are built.
        SEED: seed for torch and numpy.
        sparse_att: use a sequence block size of the smallest power of two
            >= sqrt(sequence length) instead of the full sequence length.
        sparse_mlp: same reduction applied to the MLP block size.
        pos_emb: forwarded to the model.
        cuda: index of the CUDA device to run on.

    Raises:
        ValueError: if ``dataset`` is not ``'cifar10'`` or ``'cifar100'``.
        AssertionError: if ``sparse_att`` is set and ``num_layers`` is odd.
        KeyError: if ``patch_size`` is not in the expansion table.
    """
    # Resolve the class count first so an unknown dataset fails with a clear error
    # instead of building a classifier with num_classes=-1.
    if dataset == 'cifar10':
        NC = 10
    elif dataset == 'cifar100':
        NC = 100
    else:
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'cifar10' or 'cifar100'")

    device = torch.device(f"cuda:{cuda}")
    
    if sparse_att:
        assert num_layers%2 == 0, 'number of blocks on sparse transformer is (x2)/2 hence it must be even'
        num_layers_ = num_layers//2
    else:
        num_layers_ = num_layers
    
    BS = 32
    imsize = (3, 32, 32)
    # hidden channel size used for each patch size
    expansion_dict = {16:1024, 8:256, 4:128, 2:64, 1:64}
    expansion = expansion_dict[patch_size]

    
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    ### Now create models
    
    seq_len = (imsize[-1]*imsize[-2])//(patch_size*patch_size)
    mlp_dim = expansion
    print(seq_len, mlp_dim)
    
    # Sparse variants use a block size of the smallest power of two >= sqrt(full size).
    if sparse_att:
        seq_len = int(2**np.ceil(np.log2(np.sqrt(seq_len))))
    if sparse_mlp:
        mlp_dim = int(2**np.ceil(np.log2(np.sqrt(expansion))))
    
    mem_begin = get_memory_used()
    
    torch.manual_seed(SEED)
    model = Mixer_ViT_Classifier(imsize, 
                                 patch_size=[patch_size]*2, 
                                 hidden_channel=expansion, 
                                 num_blocks=num_layers_, 
                                 num_classes=NC, 
                                 block_seq_size=seq_len, 
                                 block_mlp_size=mlp_dim,
                                 pos_emb=pos_emb).to(device)
    
    _x = torch.randn(BS, *imsize)
    # Targets span every class of the selected dataset (previously always 0..9).
    _y = torch.randint(NC, (BS,))
    num_params = sum(p.numel() for p in model.parameters())
    print("number of params: ", num_params)
    
    _a, _b, _c = 'att', 'mlp', 'nPE'
    if sparse_att: _a = 'sAtt'
    if sparse_mlp: _b = 'sMlp'
    if pos_emb: _c = 'PE'
    model_name = f'01.3_ViT_{_c}_{dataset}_patch{patch_size}_l{num_layers}_{_a}_{_b}_s{SEED}'
    print(f"Model Name: {model_name}")
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    # Training
    model.train()
    inputs, targets = _x.to(device), _y.to(device)
    ### time taken over multiple iterations
    time_taken = []
    for i in range(50):
        # CUDA kernels launch asynchronously: synchronise around the timed region
        # so the timer measures execution, not just kernel launch.
        torch.cuda.synchronize(device)
        start = time.perf_counter()
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        torch.cuda.synchronize(device)
        time_taken.append(time.perf_counter()-start)
    train_time = float(np.mean(time_taken))
        
    mem_end = get_memory_used()
    print(f"mem begin: {mem_begin}  end: {mem_end}")

    model.eval()
    time_taken = []
    with torch.no_grad():
        for i in range(50):
            torch.cuda.synchronize(device)
            start = time.perf_counter()
            outputs = model(inputs)
            torch.cuda.synchronize(device)
            time_taken.append(time.perf_counter()-start)
            
    test_time = float(np.mean(time_taken))
    
    os.makedirs('./output', exist_ok=True)
    filename = f"./output/benchmark_memory_data.json"
    file_data = {}
    if os.path.exists(filename):
        with open(filename) as f:
            file_data = json.load(f)
    file_data[f"{model_name}"] = {'memory':mem_end-mem_begin, 'time_train':train_time, 'time_test':test_time}
    # Rewrite the whole file: overwriting in place ('r+' + seek(0)) without
    # truncating left stale trailing bytes whenever the new JSON was shorter.
    with open(filename, 'w') as f:
        json.dump(file_data, f, indent = 0)

    del model, optimizer
    return

In [39]:
# ### Automate the benchmark
# ###### for c10

# not_working = [
# ]

# cuda_idx = 0
# seed = 147
# PE = False
# for dataset in ['cifar10']:
#     for nlayers in [8, 4]:
#         for patch_size in [2, 4]:
#             for sparse_attention in [False, True]:
#                 for sparse_mlp in [False]:

#                     print(f'''
#                         Experimenting on {dataset} Dataset 
#                         patch:{patch_size},
#                         sparse_att: {sparse_attention},
#                         sparse_mlp: {sparse_mlp},
#                         num_layers: {nlayers},
#                         pos_embed: {PE},
#                         seed: {seed}
#                     ''')

#                     benchmark_memory(dataset=dataset, 
#                               patch_size=patch_size, 
#                               num_layers=nlayers, 
#                               SEED=seed, 
#                               sparse_att=sparse_attention, sparse_mlp=sparse_mlp, 
#                               pos_emb=PE,
#                               cuda=cuda_idx
#                              )

In [40]:
# Single memory/time measurement: CIFAR-100, 2x2 patches, sparse attention, dense MLP.
benchmark_memory(dataset='cifar100', 
                  patch_size=2, 
                  num_layers=4, 
                  SEED=147, 
                  sparse_att=True, 
                  sparse_mlp=False, 
                  pos_emb=False,
                  cuda=0
                 )

256 64
ViT Mixer : Channels per patch -> Initial:12 Final:64
[2, 4, 8, 16, 32, 64]
Sequence len: 256 ; Block size: 16
Channel dim: 64 num heads: 8
MLP dim: 64 ; Block size: 64
number of params:  1773220
Model Name: 01.3_ViT_nPE_cifar100_patch2_l4_sAtt_mlp_s147
mem begin: 5  end: 565
