In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy, json
from tqdm import tqdm

In [2]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook can still
# be executed top-to-bottom on a machine without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
torch.cuda.get_device_name(0)

'NVIDIA GeForce RTX 3090'

In [4]:
# Alternative seeds that can be swapped in for repeated runs:
# SEED = 147
# SEED = 258
SEED = 369

# Seed the global PyTorch and NumPy random number generators so that the
# cells below start from a reproducible random state.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [5]:
import torch.optim as optim
from torch.utils import data

In [6]:
# Training pipeline: random 32x32 crop out of the 4-pixel padded image and a
# random horizontal flip, followed by tensor conversion and per-channel
# standardisation with the CIFAR-10 statistics.
# (For CIFAR-100 use mean=[0.5071, 0.4865, 0.4409], std=[0.2009, 0.1984, 0.2023].)
cifar_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
    ]
)

# Evaluation pipeline: no augmentation, same standardisation as training.
cifar_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
    ]
)

# CIFAR-10 splits; the archive is downloaded into the root folder if missing.
train_dataset = datasets.CIFAR10(
    root="../../../../../_Datasets/cifar10/",
    train=True,
    download=True,
    transform=cifar_train,
)
test_dataset = datasets.CIFAR10(
    root="../../../../../_Datasets/cifar10/",
    train=False,
    download=True,
    transform=cifar_test,
)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Batches of 32 samples; only the training split is reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)

In [8]:
# cifar_train = transforms.Compose([
#     transforms.RandomCrop(size=32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.5071, 0.4865, 0.4409],
#         std=[0.2009, 0.1984, 0.2023],
#     ),
# ])

# cifar_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.5071, 0.4865, 0.4409],
#         std=[0.2009, 0.1984, 0.2023],
#     ),
# ])

# train_dataset = datasets.CIFAR100(root="../../../../../_Datasets/cifar100/", train=True, download=True, transform=cifar_train)
# test_dataset = datasets.CIFAR100(root="../../../../../_Datasets/cifar100/", train=False, download=True, transform=cifar_test)

In [9]:
# train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=2)
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Model

In [10]:
class MlpBLock(nn.Module):
    """Dense MLP that maps ``input_dim`` features back to ``input_dim`` features.

    The widths of the hidden layers are given as multiples of ``input_dim``.

    Args:
        input_dim: number of input (and output) features.
        hidden_layers_ratio: a single number, or any sequence of numbers
            (list, tuple, ...); each entry ``r`` creates one hidden layer of
            width ``int(r * input_dim)``.
        actf: zero-argument callable returning the activation module that is
            inserted between consecutive linear layers (not after the last).
    """

    def __init__(self, input_dim, hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        # A bare number (int or float) means "one hidden layer with this
        # ratio". Any other sequence is copied, so tuples work too and neither
        # the caller's object nor the shared default list is ever aliased.
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]

        # Ratios of every layer boundary, including the input and the output.
        self.hlr = [1] + list(hidden_layers_ratio) + [1]

        layers = []
        # One Linear (+ activation) per pair of neighbouring ratios; a single
        # hidden layer therefore produces two Linear layers.
        for ratio_in, ratio_out in zip(self.hlr[:-1], self.hlr[1:]):
            layers.append(
                nn.Linear(int(ratio_in * self.input_dim), int(ratio_out * self.input_dim))
            )
            layers.append(actf())

        # The output layer is linear: drop the trailing activation.
        self.mlp = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension must be ``input_dim``)."""
        return self.mlp(x)

In [11]:
MlpBLock(2, [3,4])

MlpBLock(
  (mlp): Sequential(
    (0): Linear(in_features=2, out_features=6, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=6, out_features=8, bias=True)
    (3): GELU(approximate='none')
    (4): Linear(in_features=8, out_features=2, bias=True)
  )
)

In [12]:
import sparse_nonlinear_lib_minimal as snl

In [13]:
snl.BlockMLP_MixerBlock(256, 16, hidden_layers_ratio=[1])

BlockMLP_MixerBlock(
  (facto_nets): ModuleList(
    (0-1): 2 x BlockMLP(
      (mlp): Sequential(
        (0): BlockLinear: [16, 16, 16]
        (1): ELU(alpha=1.0)
        (2): BlockLinear: [16, 16, 16]
      )
    )
  )
)

In [14]:
snl.BlockMLP_MixerBlock(256, 16, hidden_layers_ratio=[1])(torch.randn(1, 256)).shape

torch.Size([1, 256])

In [15]:
snl.BlockLinear_MixerBlock(256, 16)

BlockLinear_MixerBlock(
  (facto_nets): ModuleList(
    (0-1): 2 x BlockWeight: [16, 16, 16]
  )
)

In [16]:
snl.BlockLinear_MixerBlock(256, 16)(torch.randn(1, 256)).shape

torch.Size([1, 256])

In [17]:
class SparseResMlp(nn.Module):
    """Sum of ``hidden_expansion`` parallel sparse two-layer branches.

    Every branch is ``BlockLinear_MixerBlock -> activation ->
    BlockLinear_MixerBlock``; the branch outputs are added together and a
    single shared bias is added to the sum.

    Args:
        input_dim: feature width of the input and the output.
        block_dim: block size forwarded to ``snl.BlockLinear_MixerBlock``.
        hidden_expansion: number of parallel branches.
        actf: zero-argument callable returning the activation module.
    """

    def __init__(self, input_dim, block_dim, hidden_expansion=2, actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        self.hex = hidden_expansion

        # Build all first-stage mixers, then all second-stage mixers.
        # Only the first stage carries its own bias.
        first_stage = []
        for _ in range(hidden_expansion):
            first_stage.append(snl.BlockLinear_MixerBlock(input_dim, block_dim, bias=True))
        second_stage = []
        for _ in range(hidden_expansion):
            second_stage.append(snl.BlockLinear_MixerBlock(input_dim, block_dim, bias=False))

        # Output bias shared by all branches.
        self.bias = nn.Parameter(torch.zeros(1, input_dim))

        self.layers1s = nn.ModuleList(first_stage)
        self.actf = actf()
        self.layers2s = nn.ModuleList(second_stage)

    def forward(self, x):
        """Return the summed branch outputs plus the shared bias."""
        total = 0
        for branch in range(self.hex):
            hidden = self.actf(self.layers1s[branch](x))
            total = total + self.layers2s[branch](hidden)
        return total + self.bias

In [18]:
model = SparseResMlp(256, 16)
# model.layers1s[0].bias
model

SparseResMlp(
  (layers1s): ModuleList(
    (0-1): 2 x BlockLinear_MixerBlock(
      (facto_nets): ModuleList(
        (0-1): 2 x BlockWeight: [16, 16, 16]
      )
    )
  )
  (actf): GELU(approximate='none')
  (layers2s): ModuleList(
    (0-1): 2 x BlockLinear_MixerBlock(
      (facto_nets): ModuleList(
        (0-1): 2 x BlockWeight: [16, 16, 16]
      )
    )
  )
)

In [19]:
SparseResMlp(256, 16)(torch.randn(1, 256)).shape

torch.Size([1, 256])

In [20]:
lin = snl.BlockLinear_MixerBlock(16, 4, bias=False)
torch.det(lin(torch.eye(16)))

tensor(-1.9978e-26, grad_fn=<LinalgDetBackward0>)

### Sparse Linear with Hidden Expansion 

In [21]:
snl.BlockLinear(4, 2, 4)

BlockLinear: [4, 2, 4]

In [22]:
class MLP_SparseLinear_Monarch_Deform(nn.Module):
    """Two-layer MLP in which each "linear" layer is a pair of block-linear maps.

    Structure: (linear0_0 -> regroup -> linear0_1) -> activation ->
    (linear1_0 -> regroup -> linear1_1). Between the two block-linear maps of
    a pair the features are regrouped (view + permute) so that the second map
    mixes across the blocks of the first one.

    Args:
        input_dim: feature width of the input and output; must equal block_dim**2.
        block_dim: size of one block.
        hidden_expansion: the hidden width is input_dim * hidden_expansion.
        actf: zero-argument callable returning the activation module.

    NOTE(review): snl.BlockLinear is defined outside this file; it is assumed
    to take (num_blocks, in_block_dim, out_block_dim, bias=...) and to act on
    tensors shaped [num_blocks, BS, in_block_dim] -- confirm in
    sparse_nonlinear_lib_minimal.
    """
    
    def __init__(self, input_dim, block_dim, hidden_expansion=2, actf=nn.ELU):
        super().__init__()
        assert input_dim%block_dim == 0, "Input dim must be divisible by block dim"
        assert np.sqrt(input_dim) == block_dim, "Input dim must be square of block dim"
        assert hidden_expansion >= 1
        
        self.input_dim = input_dim
        self.block_dim = block_dim
        self.hidden_expansion = hidden_expansion
        
        ## width of the activation between the two factorized linear layers
        self.hidden_dim = input_dim*hidden_expansion
        
        
        def log_base(a, base):
            return np.log(a) / np.log(base)
        
        ## number of block-linear factors needed to mix all of input_dim
        num_layers = int(np.ceil(log_base(input_dim, base=block_dim)))
        assert num_layers == 2, "Num layers > 2 does not contribute to monarch"
        
        ### first factorized linear: input_dim -> hidden_dim
        self.linear0_0 = snl.BlockLinear(self.input_dim//block_dim, block_dim, block_dim*hidden_expansion, bias=False)
        self.linear0_1 = snl.BlockLinear(self.hidden_dim//block_dim, block_dim, block_dim, bias=True)
        ## per-block output width of linear0_0 == number of blocks of linear0_1 / linear1_0
        self.stride0 = block_dim*hidden_expansion
        
        self.actf = actf()
        ### second factorized linear: hidden_dim -> input_dim
        self.linear1_0 = snl.BlockLinear(self.hidden_dim//block_dim, block_dim, block_dim, bias=False)
        self.linear1_1 = snl.BlockLinear(self.input_dim//block_dim, block_dim*hidden_expansion, block_dim, bias=True)
        
    def forward(self, x):
        """Map x of shape [BS, input_dim] to a tensor of shape [BS, input_dim]."""
        ## Shape comments use the example: x is [BS, 121], block_dim 11, hidden expansion 2
        
        bs = x.shape[0] ## BS, input_dim
        y = x
        
        y = y.view(bs, -1, self.block_dim) ## BS, num_blocks, block_dim ; [bs, 11, 11]
        y = y.transpose(0,1).contiguous()  ## num_blocks, BS, block_dim ; [11, bs, 11]
        y = self.linear0_0(y) ## num_blocks, BS, block_dim*hidden_expansion ; [11, bs, 22]
        y = y.transpose(0,1).contiguous() ## BS, num_blocks, block_dim*hidden_expansion ; [bs, 11, 22]
        y = y.view(bs, -1)  ## BS, hidden_dim ; [bs, 242]
        
        ## regroup: one new block per output slot of linear0_0, gathering that slot from every old block
        y = y.view(bs, self.block_dim, self.stride0).permute(2,0,1).contiguous() ## stride0, BS, block_dim; [22, bs, 11]
        y = self.linear0_1(y) ## stride0, BS, block_dim ; [22, bs, 11]
        y = y.transpose(0,1).contiguous() ## BS, stride0, block_dim ; [bs, 22, 11]
        y = y.view(bs, -1)  ## BS, hidden_dim ; [bs, 242]
        
        ### First linear complete
        y = self.actf(y)
        
        
        y = y.view(bs, -1, self.block_dim) ## BS, stride0, block_dim ; [bs, 22, 11]
        y = y.transpose(0,1).contiguous()  ## stride0, BS, block_dim ; [22, bs, 11]
        y = self.linear1_0(y) ## stride0, BS, block_dim (no expansion here) ; [22, bs, 11]
        y = y.transpose(0,1).contiguous() ## BS, stride0, block_dim ; [bs, 22, 11]
        y = y.view(bs, -1)  ## BS, hidden_dim ; [bs, 242]
        
        ## regroup: block_dim blocks, each holding stride0 features
        y = y.view(bs, self.stride0, self.block_dim).permute(2,0,1).contiguous() ## block_dim, BS, stride0; [11, bs, 22]
        y = self.linear1_1(y) ## block_dim, BS, block_dim ; [11, bs, 11]
        y = y.transpose(0,1).contiguous() ## BS, block_dim, block_dim ; [bs, 11, 11]
        y = y.view(bs, -1)  ## BS, input_dim ; [bs, 121]
        
        return y

In [23]:
monarch = MLP_SparseLinear_Monarch_Deform(4, 2, hidden_expansion=1)
monarch

MLP_SparseLinear_Monarch_Deform(
  (linear0_0): BlockLinear: [2, 2, 2]
  (linear0_1): BlockLinear: [2, 2, 2]
  (actf): ELU(alpha=1.0)
  (linear1_0): BlockLinear: [2, 2, 2]
  (linear1_1): BlockLinear: [2, 2, 2]
)

In [24]:
torch.det(monarch(torch.eye(4))) ## remove residual before testing

tensor(1.6140e-09, grad_fn=<LinalgDetBackward0>)

In [25]:
monarch(torch.eye(4))

tensor([[ 0.5158, -0.0804,  0.3653,  0.1062],
        [ 0.5260, -0.0833,  0.3730,  0.1205],
        [ 0.5209, -0.0811,  0.3733,  0.1179],
        [ 0.5216, -0.0814,  0.3731,  0.1170]], grad_fn=<ViewBackward0>)

In [26]:
# asdasd

### Block MLP - without res

In [27]:
class BlockMLP(nn.Module):
    """Independent small MLPs applied to consecutive blocks of the features.

    The input of width ``input_dim`` is cut into ``input_dim // layer_dims[0]``
    blocks of width ``layer_dims[0]``; a stack of ``snl.BlockLinear`` layers
    with widths ``layer_dims`` (activation between layers, none after the last)
    is applied to the blocks.

    Args:
        input_dim: total feature width; must be a multiple of ``layer_dims[0]``.
        layer_dims: per-block widths of every layer boundary, starting with
            the block size itself.
        actf: zero-argument callable returning the activation module.
    """

    def __init__(self, input_dim, layer_dims, actf=nn.GELU):
        super().__init__()
        self.block_dim = layer_dims[0]

        # The check is divisibility by the block size (the previous message
        # wrongly talked about an "even number").
        assert input_dim%self.block_dim == 0, "Input dim must be divisible by block dim"

        n_blocks = input_dim//layer_dims[0]
        layers = []
        for dim_in, dim_out in zip(layer_dims[:-1], layer_dims[1:]):
            layers.append(snl.BlockLinear(n_blocks, dim_in, dim_out))
            layers.append(actf())
        # No activation after the final block-linear layer.
        self.mlp = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Apply the block MLP to ``x`` of shape [BS, input_dim]."""
        bs = x.shape[0]
        # [BS, D] -> [num_blocks, BS, block_dim], the layout the block layers receive.
        x = x.view(bs, -1, self.block_dim).transpose(0,1)
        x = self.mlp(x)
        # back to [BS, num_blocks * out_block_dim]
        x = x.transpose(1,0).reshape(bs, -1)
        return x
    
############################################################################
############################################################################

class BlockMLP_MixerBlock(nn.Module):
    """Stack of ``BlockMLP`` layers with a feature regrouping between them.

    Layer ``i`` regroups the features with a stride of ``block_dim**i`` before
    applying its ``BlockMLP``, so successive layers mix progressively more
    distant features.

    Args:
        input_dim: feature width; must be a multiple of ``block_dim``.
        block_dim: size of one block (must be at least 2).
        hidden_layers_ratio: sequence of hidden widths of every ``BlockMLP``,
            expressed as multiples of ``block_dim``.
        actf: zero-argument callable returning the activation module.
    """

    def __init__(self, input_dim, block_dim, hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()

        assert input_dim%block_dim == 0, "Input dim must be divisible by block dim"
        self.input_dim = input_dim
        self.block_dim = block_dim

        num_layers = self._num_mixing_layers(input_dim, block_dim)
        # Copy into a fresh list so tuples are accepted and the shared default
        # argument is never aliased.
        ratios = [1] + list(hidden_layers_ratio) + [1]

        block_layer_dims = [int(a*block_dim) for a in ratios]
        self.facto_nets = nn.ModuleList(
            [BlockMLP(self.input_dim, block_layer_dims, actf) for _ in range(num_layers)]
        )

    @staticmethod
    def _num_mixing_layers(input_dim, block_dim):
        """Smallest ``L`` with ``block_dim**L >= input_dim`` (i.e. ceil(log_b(n))).

        Computed with integers only: ``ceil(log(n) / log(b))`` in floating
        point can come out one too large when ``n`` is an exact power of
        ``b`` (for example ``log(125) / log(5)`` may evaluate to slightly
        more than 3), which would add a spurious layer.

        Raises:
            ValueError: if ``block_dim`` is smaller than 2.
        """
        if block_dim < 2:
            raise ValueError("block_dim must be at least 2")
        layers, reach = 0, 1
        while reach < input_dim:
            reach *= block_dim
            layers += 1
        return layers

    def forward(self, x):
        """Apply all mixing layers to ``x`` of shape [BS, input_dim]."""
        bs = x.shape[0]
        y = x
        for i, fn in enumerate(self.facto_nets):
            # Gather features that are block_dim**i apart into the same block ...
            y = y.view(-1, self.block_dim, self.block_dim**i).permute(0, 2, 1).contiguous().view(bs, -1)
            y = fn(y)
            # ... and undo the regrouping after the block MLP.
            y = y.view(-1, self.block_dim**i, self.block_dim).permute(0, 2, 1).contiguous()

        y = y.view(bs, -1)
        return y

## MLP-Mixer 

In [28]:
class MixerBlock(nn.Module):
    """One MLP-Mixer block: patch (token) mixing followed by channel mixing.

    Both mixing steps are pre-normalised with a LayerNorm over the channel
    axis and wrapped in a residual connection.

    Args:
        patch_dim: number of patches (length of the token axis).
        channel_dim: number of channels per patch.
        patch_mixing: one of "dense", "sparse_linear", "sparse_mlp".
        channel_mixing: one of "dense", "sparse_linear", "sparse_mlp".
    """

    def __init__(self, patch_dim, channel_dim, patch_mixing="dense", channel_mixing="dense"):
        super().__init__()

        self.valid_functions = ["dense", "sparse_linear", "sparse_mlp"]
        for requested in (patch_mixing, channel_mixing):
            assert requested in self.valid_functions

        self.patch_dim = patch_dim
        self.channel_dim = channel_dim

        # Both LayerNorms normalise over the channel axis.
        self.ln0 = nn.LayerNorm(channel_dim)
        self.mlp_patch = self.get_mlp(patch_dim, patch_mixing)

        self.ln1 = nn.LayerNorm(channel_dim)
        self.mlp_channel = self.get_mlp(channel_dim, channel_mixing)

    def get_mlp(self, dim, mixing_function, actf=nn.GELU):
        """Build the mixing network of width ``dim`` for the requested variant.

        ``dim`` has to be a perfect square for every variant, because its
        square root is used as the block size of the sparse variants.
        """
        side = int(np.sqrt(dim))
        assert side**2 == dim, "Sparsifying dimension must be a square number"

        HIDDEN_EXPANSION = 1 #2
        dense_name, sparse_linear_name, sparse_mlp_name = self.valid_functions
        if mixing_function == dense_name:
            mlp = MlpBLock(dim, [HIDDEN_EXPANSION], actf)
        elif mixing_function == sparse_linear_name:
            mlp = MLP_SparseLinear_Monarch_Deform(dim, side, HIDDEN_EXPANSION, actf)
        elif mixing_function == sparse_mlp_name:
            mlp = BlockMLP_MixerBlock(dim, side, [HIDDEN_EXPANSION], actf)
        return mlp

    def forward(self, x):
        """Mix ``x`` of shape [N, patch_dim, channel_dim]; the shape is preserved."""
        # --- patch mixing: put the patch axis last and flatten to 2-D ---
        tokens = self.ln0(x).transpose(-1, -2).contiguous()
        tokens = self.mlp_patch(tokens.view(-1, self.patch_dim))
        tokens = tokens.view(-1, self.channel_dim, self.patch_dim).transpose(-1, -2)
        x = x + tokens

        # --- channel mixing: channels are already last ---
        channels = self.ln1(x).view(-1, self.channel_dim)
        channels = self.mlp_channel(channels).view(-1, self.patch_dim, self.channel_dim)
        return x + channels

In [29]:
# class MixerBlock(nn.Module):
    
#     def __init__(self, patch_dim, channel_dim):
#         super().__init__()
        
#         self.ln0 = nn.LayerNorm(channel_dim)
#         self.mlp_patch = MlpBLock(patch_dim, [2])
#         self.ln1 = nn.LayerNorm(channel_dim)
#         self.mlp_channel = MlpBLock(channel_dim, [2])
    
#     def forward(self, x):
#         ## x has shape-> N, nP, nC/hidden_dims; C=Channel, P=Patch
        
#         ######## !!!! Can use same mixer on shape of -> N, C, P;
        
#         #### mix per patch
#         y = self.ln0(x) ### per channel layer normalization ?? 
#         y = torch.swapaxes(y, -1, -2)
#         y = self.mlp_patch(y)
#         y = torch.swapaxes(y, -1, -2)
#         x = x+y
        
#         #### mix per channel 
#         y = self.ln1(x)
#         y = self.mlp_channel(y)
#         x = x+y
#         return x

In [30]:
model = MixerBlock(2*2*4, 16*16, channel_mixing="sparse_linear", patch_mixing="sparse_linear")
# model = MixerBlock(2*2*4, 16*16)
model

MixerBlock(
  (ln0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (mlp_patch): MLP_SparseLinear_Monarch_Deform(
    (linear0_0): BlockLinear: [4, 4, 4]
    (linear0_1): BlockLinear: [4, 4, 4]
    (actf): GELU(approximate='none')
    (linear1_0): BlockLinear: [4, 4, 4]
    (linear1_1): BlockLinear: [4, 4, 4]
  )
  (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (mlp_channel): MLP_SparseLinear_Monarch_Deform(
    (linear0_0): BlockLinear: [16, 16, 16]
    (linear0_1): BlockLinear: [16, 16, 16]
    (actf): GELU(approximate='none')
    (linear1_0): BlockLinear: [16, 16, 16]
    (linear1_1): BlockLinear: [16, 16, 16]
  )
)

In [31]:
model(torch.randn(1, 2*2*4, 16*16)).shape

torch.Size([1, 16, 256])

In [32]:
class MlpMixer(nn.Module):
    """MLP-Mixer image classifier built from ``MixerBlock`` layers.

    The image is resized to ``image_dim``, cut into non-overlapping patches,
    every flattened patch is linearly projected to
    ``int(C * ph * pw * hidden_expansion)`` channels, ``num_blocks`` mixer
    blocks are applied and a final linear layer classifies the flattened
    result.

    Args:
        image_dim: image size; the last two entries are (H, W) and entry 0 is
            used as the number of image channels.
        patch_size: (height, width) of one patch; must divide H and W.
        hidden_expansion: multiplier for the per-patch channel width.
        num_blocks: number of mixer blocks.
        num_classes: number of output logits.
        patch_mixing: mixing variant forwarded to ``MixerBlock``.
        channel_mixing: mixing variant forwarded to ``MixerBlock``.
    """

    def __init__(self, image_dim:tuple, patch_size:tuple, hidden_expansion:float, num_blocks:int, num_classes:int,
                patch_mixing:str, channel_mixing:str):
        super().__init__()

        self.img_dim = image_dim
        height, width = self.img_dim[-2], self.img_dim[-1]
        self.scaler = nn.UpsamplingBilinear2d(size=(height, width))

        # Number of patches along each image axis.
        patches_h = int(image_dim[-2] / patch_size[0])
        patches_w = int(image_dim[-1] / patch_size[1])
        assert patches_h * patch_size[0] == image_dim[-2], "Image must be divisible into patch size"
        assert patches_w * patch_size[1] == image_dim[-1], "Image must be divisible into patch size"

        # Values in one flattened patch before / after the channel projection.
        init_dim = patch_size[0] * patch_size[1] * image_dim[0]
        final_dim = int(init_dim * hidden_expansion)

        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        # Learned per-patch projection that changes the channel width.
        self.channel_change = nn.Linear(init_dim, final_dim)
        print(f"MLP Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")

        self.channel_dim = final_dim
        self.patch_dim = patches_h * patches_w

        blocks = [
            MixerBlock(self.patch_dim, self.channel_dim, patch_mixing, channel_mixing)
            for _ in range(num_blocks)
        ]
        self.mixer_blocks = nn.Sequential(*blocks)

        self.linear = nn.Linear(self.patch_dim * self.channel_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape [N, num_classes] for an image batch ``x``."""
        batch = x.shape[0]
        resized = self.scaler(x)
        # Unfold yields [N, C*ph*pw, num_patches]; move the patch axis in front of the channels.
        patches = self.unfold(resized).transpose(-1, -2)
        mixed = self.mixer_blocks(self.channel_change(patches))
        return self.linear(mixed.view(batch, -1))

In [33]:
mixer = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=1, num_classes=10, patch_mixing="sparse_linear", channel_mixing="sparse_linear")
mixer

MLP Mixer : Channes per patch -> Initial:48 Final:121


MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(32, 32), mode='bilinear')
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=121, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((121,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MLP_SparseLinear_Monarch_Deform(
        (linear0_0): BlockLinear: [8, 8, 8]
        (linear0_1): BlockLinear: [8, 8, 8]
        (actf): GELU(approximate='none')
        (linear1_0): BlockLinear: [8, 8, 8]
        (linear1_1): BlockLinear: [8, 8, 8]
      )
      (ln1): LayerNorm((121,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MLP_SparseLinear_Monarch_Deform(
        (linear0_0): BlockLinear: [11, 11, 11]
        (linear0_1): BlockLinear: [11, 11, 11]
        (actf): GELU(approximate='none')
        (linear1_0): BlockLinear: [11, 11, 11]
        (linear1_1): BlockLinear: [11, 11, 11]
      )
    )
  )
  (linear): L

In [34]:
print("number of params: ", sum(p.numel() for p in mixer.parameters())) 

number of params:  91605


In [35]:
mixer(torch.randn(1, 3, 32, 32))

tensor([[-0.2760, -0.1265,  0.2318, -0.4155, -0.4572,  0.2440, -0.0724,  0.2644,
         -0.2106, -0.2772]], grad_fn=<AddmmBackward0>)

#### Final Model

In [36]:
model = MlpMixer((3, 4*9, 4*9), (4, 4), hidden_expansion=3.0, num_blocks=10, num_classes=10,
                patch_mixing="sparse_linear", channel_mixing="sparse_linear")
#                 patch_mixing="sparse_mlp", channel_mixing="sparse_mlp")
#                 patch_mixing="dense", channel_mixing="dense")
                 

model = model.to(device)

MLP Mixer : Channes per patch -> Initial:48 Final:144


In [37]:
model

MlpMixer(
  (scaler): UpsamplingBilinear2d(size=(36, 36), mode='bilinear')
  (unfold): Unfold(kernel_size=(4, 4), dilation=1, padding=0, stride=(4, 4))
  (channel_change): Linear(in_features=48, out_features=144, bias=True)
  (mixer_blocks): Sequential(
    (0): MixerBlock(
      (ln0): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_patch): MLP_SparseLinear_Monarch_Deform(
        (linear0_0): BlockLinear: [9, 9, 9]
        (linear0_1): BlockLinear: [9, 9, 9]
        (actf): GELU(approximate='none')
        (linear1_0): BlockLinear: [9, 9, 9]
        (linear1_1): BlockLinear: [9, 9, 9]
      )
      (ln1): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (mlp_channel): MLP_SparseLinear_Monarch_Deform(
        (linear0_0): BlockLinear: [12, 12, 12]
        (linear0_1): BlockLinear: [12, 12, 12]
        (actf): GELU(approximate='none')
        (linear1_0): BlockLinear: [12, 12, 12]
        (linear1_1): BlockLinear: [12, 12, 12]
      )
    )
    (1): MixerBlo

In [38]:
# print("number of params: ", sum(p.numel() for p in model.parameters())) 
print("number of params: ", sum(p.numel() for p in model.mixer_blocks.parameters())) 

### Dense: 1,104,390
### Sparse MLP: 215820
### Sparse Linear: 209,070

number of params:  108540


In [39]:
model(torch.randn(1, 3, 32, 32).to(device)).shape

torch.Size([1, 10])

In [40]:
# asdasd

## Training

In [41]:
# model_name = f'mlp_mixer_sparse-mlp_c10_s{SEED}'

In [42]:
# EPOCHS = 200
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [43]:
STAT ={'train_stat':[], 'test_stat':[]}

In [44]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    """Run one training epoch and record its statistics.

    Works on the notebook globals ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``. Appends ``(epoch, mean loss, accuracy %)``
    to ``STAT['train_stat']`` and prints a one-line summary.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard optimisation step.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Running statistics for the epoch summary.
        running_loss += loss.item()
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

    mean_loss = running_loss / (batch_idx + 1)
    accuracy = 100. * n_correct / n_seen
    STAT['train_stat'].append((epoch, mean_loss, accuracy)) ### (Epochs, Loss, Acc)
    print(f"[Train] {epoch} Loss: {mean_loss:.3f} | Acc: {accuracy:.3f} {n_correct}/{n_seen}")

In [45]:
best_acc = -1
def test(epoch):
    """Evaluate the model on the test set, checkpoint the best one, dump stats.

    Works on the notebook globals ``model``, ``criterion``, ``test_loader``,
    ``device``, ``model_name``, ``STAT`` and ``best_acc``.

    Side effects:
        * appends ``(epoch, mean loss, accuracy %, mean forward time in s)``
          to ``STAT['test_stat']``;
        * when the accuracy beats ``best_acc``, saves the model state to
          ``./models_v1/{model_name}.pth`` and updates ``best_acc``;
        * rewrites ``./models_v1/stats/{model_name}_data.json`` with ``STAT``.
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    time_taken = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)

            # CUDA kernels run asynchronously: without synchronising before
            # and after the forward pass the timer would mostly measure the
            # kernel launch overhead instead of the actual compute time.
            if inputs.is_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()

            outputs = model(inputs)

            if inputs.is_cuda:
                torch.cuda.synchronize()
            time_taken.append(time.perf_counter() - start)

            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100.*correct/total
    STAT['test_stat'].append((epoch, test_loss/(batch_idx+1), acc, np.mean(time_taken))) ### (Epochs, Loss, Acc, time)
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {acc:.3f} {correct}/{total}")

    # Create the directories that are actually written to. The previous code
    # created './models' but saved into './models_v1', so a fresh checkout
    # failed on the first save.
    os.makedirs('./models_v1/stats', exist_ok=True)

    # Save checkpoint.
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch
        }
        torch.save(state, f'./models_v1/{model_name}.pth')
        best_acc = acc

    with open(f"./models_v1/stats/{model_name}_data.json", 'w') as f:
        json.dump(STAT, f, indent=0)

In [46]:
# start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# resume = False

# if resume:
#     # Load checkpoint.
#     print('==> Resuming from checkpoint..')
#     assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
#     checkpoint = torch.load(f'./models/{model_name}.pth')
#     model.load_state_dict(checkpoint['model'])
#     best_acc = checkpoint['acc']
#     start_epoch = checkpoint['epoch']

In [47]:
# ### Train the whole damn thing

# for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
#     train(epoch)
#     test(epoch)
#     scheduler.step()

In [48]:
# best_acc

In [49]:
# checkpoint = torch.load(f'./models/{model_name}.pth')
# best_acc = checkpoint['acc']
# start_epoch = checkpoint['epoch']

# best_acc, start_epoch

In [50]:
# model.load_state_dict(checkpoint['model'])

In [51]:
# model

In [52]:
# STAT

In [53]:
# train_stat = np.array(STAT['train_stat'])
# test_stat = np.array(STAT['test_stat'])

In [54]:
# plt.plot(train_stat[:,1], label='train')
# plt.plot(test_stat[:,1], label='test')
# plt.ylabel("Loss")
# plt.legend()
# plt.savefig(f"./output/plots/{model_name}_loss.svg")
# plt.show()

In [55]:
# plt.plot(train_stat[:,2], label='train')
# plt.plot(test_stat[:,2], label='test')
# plt.ylabel("Accuracy")
# plt.legend()
# plt.savefig(f"./output/plots/{model_name}_accs.svg")
# plt.show()

## Benchmark Training

In [56]:
def get_data_loaders(seed, ds):
    """Create freshly seeded train/test loaders over the global datasets.

    Args:
        seed: value passed to ``torch.manual_seed`` before the loaders are built.
        ds: dataset tag; ``'c100'`` selects batch size 128, anything else 64.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    batch_size = 128 if ds == 'c100' else 64
    torch.manual_seed(seed)
    loaders = tuple(
        torch.utils.data.DataLoader(dataset=split, batch_size=batch_size, shuffle=do_shuffle, num_workers=2)
        for split, do_shuffle in ((train_dataset, True), (test_dataset, False))
    )
    return loaders[0], loaders[1]

In [57]:
# ! mkdir models_v1/
# ! mkdir models_v1/stats

In [58]:
def benchmark():
    """Train the selected MLP-Mixer variants for several seeds.

    For every seed / depth / variant combination a fresh model is built,
    compiled and trained for EPOCHS epochs with Adam and a cosine learning
    rate schedule. ``train`` and ``test`` are driven through module-level
    globals, which is why this function rebinds ``model``, ``optimizer``,
    ``train_loader``, ``test_loader``, ``model_name``, ``criterion``, ``STAT``
    and ``best_acc``.

    Note: ``DS`` only selects the batch size and the number of classes; the
    data itself always comes from the global ``train_dataset`` /
    ``test_dataset`` used by ``get_data_loaders``.

    Returns:
        0 once all experiments have finished.
    """
    global model, optimizer, train_loader, test_loader, model_name, criterion, STAT, best_acc
    EPOCHS = 200
    criterion = nn.CrossEntropyLoss()
    lr = 0.001
#     DS = 'c100'
    DS = 'c10'
#     for SEED in [147, 258, 369]:
    for SEED in [741, 852, 963, 159, 357]:
        for num_layers in [7]: ##[7, 10]
#             for i in range(3): ## 3 models training
            ## i selects the variant: 0 = dense, 1 = sparse linear, 2 = sparse MLP
            for i in [1]: ## Test Linear Only

                print("Experiment index:", i)
                train_loader, test_loader = get_data_loaders(SEED, DS)
                ## re-seed so that the weight initialisation is reproducible per seed
                torch.manual_seed(SEED)
                num_cls = 10
                if DS=='c100': num_cls = 100
                ### FOR ORIGINAL MIXER
                if i == 0:
                    model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=num_layers, 
                                     num_classes=num_cls, patch_mixing="dense", channel_mixing="dense")
                    model_name = f'mlp_mixer_dense_l{num_layers}_{DS}_s{SEED}'
                elif i == 1:
                    model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=num_layers, 
                                     num_classes=num_cls, patch_mixing="sparse_linear", channel_mixing="sparse_linear")
                    model_name = f'mlp_mixer_sparseLinear_l{num_layers}_{DS}_s{SEED}'
                elif i == 2:
                    model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=num_layers, 
                                     num_classes=num_cls, patch_mixing="sparse_mlp", channel_mixing="sparse_mlp")
                    model_name = f'mlp_mixer_sparseMlp_l{num_layers}_{DS}_s{SEED}'
                else:
                    print("JPT........!!!!")
                    continue
                    
                model = model.to(device)
                ## NOTE(review): the compiled wrapper is what test() checkpoints; its state_dict
                ## keys may carry an '_orig_mod.' prefix -- confirm before loading into a plain model.
                model = torch.compile(model)
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

                num_params = sum(p.numel() for p in model.parameters())

                ## experiment prefix/suffix used for the checkpoint and stats file names
                model_name = "03.1_" + model_name + "_h2"
                print(f"EXPERIMENTING FOR : {model_name} | params: {num_params}  .......\n.......")
                
#                 continue
                ## fresh statistics and best-accuracy tracker for every experiment
                STAT ={'train_stat':[], 'test_stat':[], 'num_params':num_params}
                best_acc = -1
                for epoch in range(0, EPOCHS): ## for 200 epochs
                    train(epoch)
                    test(epoch)
                    scheduler.step()
                print(f"Training finished\n")
                pass
            pass
        pass
    return 0           

In [59]:
## warning - Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
torch.set_float32_matmul_precision('high')

## Flops

In [87]:
# Report multiply-accumulate counts and parameter counts (via ptflops) for the
# dense, sparse-linear and sparse-MLP mixer variants on a 3x32x32 input.
from ptflops import get_model_complexity_info

# No training happens here; -1 only fills the seed slot of the model name.
SEED = -1
for num_cls in [10]:
    DS = f"c{num_cls}"
    for num_layers in [7]:
        # i selects the variant: 0 = dense, 1 = sparse linear, 2 = sparse MLP
        for i in range(3):
            if i == 0:
                model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=num_layers, 
                                 num_classes=num_cls, patch_mixing="dense", channel_mixing="dense")
                model_name = f'mlp_mixer_dense_l{num_layers}_{DS}_s{SEED}'
            elif i == 1:
                model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=num_layers, 
                                 num_classes=num_cls, patch_mixing="sparse_linear", channel_mixing="sparse_linear")
                model_name = f'mlp_mixer_sparseLinear_l{num_layers}_{DS}_s{SEED}'
            elif i == 2:
                model = MlpMixer((3, 32, 32), (4, 4), hidden_expansion=2.53, num_blocks=num_layers, 
                                 num_classes=num_cls, patch_mixing="sparse_mlp", channel_mixing="sparse_mlp")
                model_name = f'mlp_mixer_sparseMlp_l{num_layers}_{DS}_s{SEED}'
            else:
                print("JPT........!!!!")
                continue

            # NOTE(review): ignore_modules is given a module *name* here; check in the ptflops
            # documentation whether it expects module classes instead, in which case this is a no-op.
            macs, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, ignore_modules=['channel_change'],
                                           print_per_layer_stat=False, verbose=False)
            
            print(model_name)
#             print(model)
            print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
            print('{:<30}  {:<8}'.format('Number of parameters: ', params))
            print('')
            pass
        pass
    pass
pass

MLP Mixer : Channes per patch -> Initial:48 Final:121
mlp_mixer_dense_l7_c10_s-1
Computational complexity:       20.73 MMac
Number of parameters:           351.68 k

MLP Mixer : Channes per patch -> Initial:48 Final:121
mlp_mixer_sparseLinear_l7_c10_s-1
Computational complexity:       5.0 MMac
Number of parameters:           140.96 k

MLP Mixer : Channes per patch -> Initial:48 Final:121
mlp_mixer_sparseMlp_l7_c10_s-1
Computational complexity:       5.33 MMac
Number of parameters:           143.55 k



In [61]:
# benchmark()

In [62]:
"""
The initialization on three methods are different.
Try to remove that gap.

1) Initializing everything to kaiming_uniform_ similar to nn.Linear
    - Might have better initialization (unclear how pytorch handles init for with with block-sparse)
"""

'\nThe initialization on three methods are different.\nTry to remove that gap.\n\n1) Initializing everything to kaiming_uniform_ similar to nn.Linear\n    - Might have better initialization (unclear how pytorch handles init for with with block-sparse)\n'

In [63]:
# !nvidia-smi

In [64]:
# !pip list

In [65]:
# exit(0)

## Convert BlockLinear implementation using conv1d

In [66]:
conv1d = nn.Conv2d(10, 30, kernel_size=1, groups=5)

In [67]:
conv1d.weight.shape

torch.Size([30, 2, 1, 1])

In [68]:
conv1d.weight.view(5, 2, 6)

tensor([[[-0.5987, -0.6935,  0.6045,  0.2860,  0.5444,  0.1846],
         [-0.0949,  0.2955,  0.4700, -0.4096,  0.1172, -0.1096]],

        [[-0.4105,  0.2790,  0.1393, -0.4721, -0.1584, -0.1914],
         [ 0.1679,  0.3108,  0.2105, -0.4167, -0.2007,  0.1860]],

        [[-0.5702,  0.1191,  0.4164,  0.3362,  0.1124, -0.2719],
         [ 0.0805,  0.0746, -0.1390, -0.0561, -0.4196, -0.3399]],

        [[-0.4491,  0.4099, -0.2549,  0.3952, -0.1426, -0.2481],
         [-0.7063,  0.0291, -0.2250, -0.6278,  0.5068, -0.1535]],

        [[ 0.1714, -0.2993, -0.3794,  0.3074,  0.3262,  0.5486],
         [-0.6626,  0.1814,  0.3607,  0.6895,  0.4803, -0.1974]]],
       grad_fn=<ViewBackward0>)

In [69]:
conv1d.bias.shape

torch.Size([30])

In [70]:
weight = torch.randn(5, 2, 6)
bias = torch.randn(5, 1, 6)

In [71]:
x = torch.randn(5, 7, 2)
x_conv = x.reshape(7, 5*2, 1, 1)

In [72]:
conv1d(x_conv).shape

torch.Size([7, 30, 1, 1])

In [73]:
(torch.bmm(x, weight) + bias).shape

torch.Size([5, 7, 6])

In [74]:
class BlockLinear_conv(nn.Module):
    """Block-diagonal linear layer implemented as a grouped 1x1 convolution.

    Each of the ``num_blocks`` blocks gets its own independent linear map
    (one convolution group per block).

    Args:
        num_blocks: number of independent blocks (convolution groups).
        input_block_dim: number of input features per block.
        output_block_dim: number of output features per block.
        bias: whether the underlying convolution carries a bias term.

    Shape:
        input:  ``(num_blocks, batch, input_block_dim)``
        output: ``(num_blocks, batch, output_block_dim)``
    """

    def __init__(self, num_blocks, input_block_dim, output_block_dim, bias=True):
        super().__init__()
        self.conv = nn.Conv2d(input_block_dim*num_blocks, output_block_dim*num_blocks,
                              kernel_size=1, groups=num_blocks, bias=bias)

    def forward(self, x):
        """Apply the per-block linear maps.

        Raises:
            ValueError: if ``x`` is not ``(num_blocks, batch, input_block_dim)``.
        """
        # Recover the block geometry from the convolution itself.
        num_blocks = self.conv.groups
        in_dim = self.conv.in_channels // num_blocks
        out_dim = self.conv.out_channels // num_blocks

        # Reject inputs whose element count happens to fit but whose layout is
        # wrong (e.g. (1, batch, num_blocks*in_dim)); they would otherwise be
        # processed without error and return a mis-shaped result.
        if x.dim() != 3 or x.shape[0] != num_blocks or x.shape[2] != in_dim:
            raise ValueError(
                f'expected input of shape ({num_blocks}, batch, {in_dim}), '
                f'got {tuple(x.shape)}')

        bs = x.shape[1]
        # Explicit sizes (no -1) keep the reshapes well-defined for an empty batch.
        # Batch goes first; blocks are flattened into the channel axis so that
        # each convolution group sees exactly one block.
        x = x.transpose(0, 1).reshape(bs, num_blocks * in_dim, 1, 1)
        x = self.conv(x).reshape(bs, num_blocks, out_dim).transpose(0, 1)
        return x

    def __repr__(self):
        S = f'BlockLinear_conv: {list(self.conv.weight.shape)}'
        return S

In [75]:
BlockLinear_conv(7, 5, 6)(torch.randn(7, 1, 5)).shape

torch.Size([7, 1, 6])

In [76]:
## Check if both get same output
# x follows the (num_blocks=7, batch=1, input_block_dim=5) layout that
# BlockLinear_conv.forward unpacks.
x = torch.randn(7, 1, 5)
# Same configuration for both implementations: 7 blocks, 5 -> 6 features per block.
l1 = BlockLinear_conv(7, 5, 6, bias=True)
l2 = snl.BlockLinear(7, 5, 6, bias=True)

In [77]:
l1.conv.weight.shape

torch.Size([42, 5, 1, 1])

In [78]:
l2.weight.shape

torch.Size([7, 5, 6])

In [79]:
# Copy the conv weights into BlockLinear's (num_blocks, in, out) layout.
# The conv stores each group as (out, in), hence reshape(7, 6, 5) + transpose.
# copy_ under no_grad writes the values into l2's own storage; assigning the
# transposed view to `.data` would leave l2.weight non-contiguous and sharing
# memory with l1.conv.weight.
with torch.no_grad():
    l2.weight.copy_(l1.conv.weight.reshape(7, 6, 5).transpose(1, 2))

In [80]:
l1.conv.bias.shape

torch.Size([42])

In [81]:
l2.bias.shape

torch.Size([7, 1, 6])

In [82]:
# Copy the conv bias (42 values, grouped per block) into BlockLinear's
# (num_blocks, 1, out) layout. copy_ under no_grad fills l2's own storage;
# assigning the reshaped view to `.data` would make l2.bias share memory with
# l1.conv.bias.
with torch.no_grad():
    l2.bias.copy_(l1.conv.bias.reshape(7, 1, 6))

In [83]:
l1(x)

tensor([[[-1.3229, -0.4138, -0.2152,  0.1704, -1.0895, -0.2707]],

        [[ 0.4223, -1.4276, -0.4511, -0.2083, -0.8804,  0.0695]],

        [[-0.3560, -0.1052,  0.4898,  0.3275, -1.1440,  1.5165]],

        [[-0.3584,  0.2785, -0.4157,  0.4401, -0.5882, -0.3017]],

        [[-0.9296,  0.0218, -0.3808, -1.4007, -0.5351,  0.6662]],

        [[-0.1567, -0.2236, -0.0443,  0.1398, -0.1497, -0.2061]],

        [[-0.2095, -0.9281,  0.2969,  0.2429, -0.1553,  0.6684]]],
       grad_fn=<TransposeBackward0>)

In [84]:
l2(x)

tensor([[[-1.3229, -0.4138, -0.2152,  0.1704, -1.0895, -0.2707]],

        [[ 0.4223, -1.4276, -0.4511, -0.2083, -0.8804,  0.0695]],

        [[-0.3560, -0.1052,  0.4898,  0.3275, -1.1440,  1.5165]],

        [[-0.3584,  0.2785, -0.4157,  0.4401, -0.5882, -0.3017]],

        [[-0.9296,  0.0218, -0.3808, -1.4007, -0.5351,  0.6662]],

        [[-0.1567, -0.2236, -0.0443,  0.1398, -0.1497, -0.2061]],

        [[-0.2095, -0.9281,  0.2969,  0.2429, -0.1553,  0.6684]]],
       grad_fn=<AddBackward0>)

In [85]:
# Monkey-patch the snl module so that subsequent lookups of snl.BlockLinear
# resolve to the conv-based implementation defined above.
# NOTE(review): objects built before this cell (e.g. l2) keep the original class.
snl.BlockLinear = BlockLinear_conv

In [95]:
# Sanity-check the complexity counter on a single bias-free 3x3 convolution
# applied to a 3x32x32 input.
probe_layer = nn.Conv2d(3, 32, 3, bias=False)
macs, params = get_model_complexity_info(
    probe_layer,
    (3, 32, 32),
    as_strings=True,
    ignore_modules=['channel_change'],
    print_per_layer_stat=False,
    verbose=False,
)

# Two-column report: left-aligned label, then the value.
row_format = '{:<30}  {:<8}'
print(row_format.format('Computational complexity: ', macs))
print(row_format.format('Number of parameters: ', params))

Computational complexity:       777.6 KMac
Number of parameters:           864     
