In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms

from tqdm import tqdm
import random, time, os, sys, json

In [3]:
# import sparse_nonlinear_lib as snl

In [4]:
# Prefer the second GPU (the original setup), then the first GPU, then CPU,
# so the notebook still runs on machines that do not have a `cuda:1`.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [5]:
# time.sleep(60*60)

## For CIFAR10 dataset

In [6]:
# Per-channel normalisation statistics for CIFAR-10.
# For CIFAR-100 use mean=[0.5071, 0.4865, 0.4409], std=[0.2009, 0.1984, 0.2023].
# NOTE(review): these std values are the widely copied ones; the dataset-wide
# per-pixel std is roughly [0.247, 0.243, 0.262] — kept as-is for comparability.
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]
CIFAR10_ROOT = "../../../../../_Datasets/cifar10/"

# Training pipeline: random crop (with 4px padding) + horizontal flip, then normalise.
cifar_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, same normalisation.
cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

train_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=True, transform=cifar_train)
test_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=False, download=True, transform=cifar_test)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Both splits use the same batch size and worker count; only training data is shuffled.
loader_kwargs = dict(batch_size=32, num_workers=2)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

In [8]:
## demo of train loader: pull a single batch and display its shape
xx, yy = next(iter(train_loader))
xx.shape

torch.Size([32, 3, 32, 32])

# Model Comparison

In [9]:
class MlpBLock(nn.Module):
    """Feed-forward MLP applied to the last axis: input_dim -> hidden... -> input_dim.

    Layers are ``Linear -> actf -> Linear -> ... -> Linear``. There is no activation
    after the final Linear.

    Args:
        input_dim: size of the last axis of the input (and of the output).
        hidden_layers_ratio: width of each hidden layer as a multiple of
            ``input_dim``. A single number gives one hidden layer. A list/tuple
            gives one hidden layer per entry. Widths are truncated with ``int()``.
        actf: activation class; one instance is created per hidden layer.
    """

    def __init__(self, input_dim, hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        # Accept a bare number (int or float) as well as any sequence of ratios.
        # The caller's list is never mutated; a fresh list is built here.
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]
        self.hlr = [1, *hidden_layers_ratio, 1]

        layers = []
        # One Linear per consecutive pair of ratios (n hidden layers -> n+1 Linears).
        for r_in, r_out in zip(self.hlr[:-1], self.hlr[1:]):
            layers.append(nn.Linear(int(r_in * input_dim), int(r_out * input_dim)))
            layers.append(actf())
        # Drop the trailing activation so the block's output is linear.
        self.mlp = nn.Sequential(*layers[:-1])

    def forward(self, x):
        return self.mlp(x)

In [10]:
# class CIFAR10_ImageMonarchMLP(nn.Module):
    
#     def __init__(self, img_size=(3, 32, 32), hidden_layers_ratio=[2], actf=nn.GELU):
#         super().__init__()
        
#         self.block0 = MlpBLock(img_size[0]*img_size[1], hidden_layers_ratio, actf=actf)
#         self.block1 = MlpBLock(img_size[0]*img_size[2], hidden_layers_ratio, actf=actf)
        
# #         self.norm = nn.BatchNorm1d(select)
#         self.norm = nn.LayerNorm(np.prod(img_size))
    
#         ### Can also use normalization per block for efficiency
# #         self.norm1 = nn.LayerNorm(self.block1.input_dim)

#         self.actf = actf()
#         self.fc = nn.Linear(np.prod(img_size), 10)
        
#     def forward(self, x):
#         bs, C, H, W = x.shape
        
#         ### use B, W, C*H
#         x = x.permute(0, 3, 1, 2).contiguous().view(bs, W, -1)
#         x = self.block0(x).view(bs, W, C, H)
#         ### use B, H, C*W
#         x = x.permute(0, 3, 2, 1).contiguous().view(bs, H, -1)
#         x = self.block1(x).view(bs, -1)
        
#         x = self.norm(x)
#         x = self.actf(x)
#         x = self.fc(x)
#         return x

In [11]:
# CIFAR10_ImageMonarchMLP()(torch.randn(1, 3, 32, 32))

In [12]:
class RowColMixer(nn.Module):
    """Residual block that mixes an image across columns, then across rows.

    Step 1 treats each of the W columns as a token of size C*H (LayerNorm + MLP).
    Step 2 treats each of the H rows as a token of size C*W (LayerNorm + MLP).
    The input is then added back once, around both steps.

    Args:
        img_size: (C, H, W) of the tensors this block receives.
        hidden_layers_ratio: hidden width ratio(s) passed to each MlpBLock.
        actf: activation class used inside the MLPs.
    """

    def __init__(self, img_size=(3, 32, 32), hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()
        channels, height, width = img_size[0], img_size[1], img_size[2]
        # Registration order (block0, norm0, block1, norm1) fixes parameter init order.
        self.block0 = MlpBLock(channels * height, hidden_layers_ratio, actf=actf)
        self.norm0 = nn.LayerNorm(self.block0.input_dim)
        self.block1 = MlpBLock(channels * width, hidden_layers_ratio, actf=actf)
        self.norm1 = nn.LayerNorm(self.block1.input_dim)

    def forward(self, x):
        bs, C, H, W = x.shape
        shortcut = x

        # Columns as tokens: (B, C, H, W) -> (B, W, C, H) -> (B, W, C*H)
        cols = x.permute(0, 3, 1, 2).reshape(bs, W, C * H)
        cols = self.block0(self.norm0(cols))

        # Rows as tokens: (B, W, C, H) -> (B, H, C, W) -> (B, H, C*W)
        rows = cols.reshape(bs, W, C, H).permute(0, 3, 2, 1).reshape(bs, H, C * W)
        rows = self.block1(self.norm1(rows))

        # Back to (B, C, H, W); keep it contiguous for callers that use .view().
        mixed = rows.reshape(bs, H, C, W).permute(0, 2, 1, 3).contiguous()
        return mixed + shortcut
    
class CIFAR10_RowColMixer(nn.Module):
    """Image classifier: channel expansion -> RowColMixer stack -> BatchNorm -> linear head.

    The input channels are padded up to ``channel_expand`` by appending copies of
    input channels chosen by a random permutation. The expanded image is passed
    through ``layers`` RowColMixer blocks, flattened, batch-normalised and
    projected to 10 logits.

    Args:
        img_size: (channels, height, width) of the input images.
        hidden_layers_ratio: hidden width ratio(s) forwarded to each MlpBLock.
        layers: number of stacked RowColMixer blocks.
        channel_expand: channel count after duplication (must be >= img_size[0]).
        actf: activation class used inside the MLPs.
    """

    def __init__(self, img_size=(3, 32, 32), hidden_layers_ratio=[2], layers=1, channel_expand=3, actf=nn.GELU):
        super().__init__()
        in_channels = img_size[0]
        assert in_channels <= channel_expand, "Can't reduce channels than original size"
        # Indices of the input channels to duplicate. The modulus must be the real
        # input channel count; a hard-coded 3 indexes out of range for inputs with
        # fewer than 3 channels.
        extra = torch.randperm(channel_expand - in_channels) % in_channels
        # Buffer so that model.to(device) moves it along with the weights.
        # persistent=False keeps state_dict keys unchanged, so the indices are NOT
        # stored in checkpoints: rebuild the model under the same torch seed to reload.
        self.register_buffer("select_channels", extra, persistent=False)

        img_size = (channel_expand, img_size[1], img_size[2])
        flat_dim = int(np.prod(img_size))

        self.blocks = nn.Sequential(
            *[RowColMixer(img_size, hidden_layers_ratio, actf=actf) for _ in range(layers)]
        )

        self.norm = nn.BatchNorm1d(flat_dim)
        self.fc = nn.Linear(flat_dim, 10)

    def forward(self, x):
        bs = x.shape[0]
        # (B, C, H, W) -> (B, channel_expand, H, W) by appending duplicated channels.
        x = torch.cat((x, x[:, self.select_channels, :, :]), dim=1)
        x = self.blocks(x).view(bs, -1)
        x = self.norm(x)
        return self.fc(x)

In [13]:
# Smoke test: a fresh 20-channel model on 10 random 3x32x32 images -> (10, 10) logits.
CIFAR10_RowColMixer(channel_expand=20)(torch.randn((10, 3, 32, 32)))

tensor([[-3.3503e-02,  4.4751e-01, -5.6777e-05,  9.5314e-01,  1.1769e+00,
          1.2710e+00,  7.1167e-01,  1.0610e-01,  9.0168e-01,  3.8667e-01],
        [-3.0119e-02, -8.1653e-01,  3.0493e-01,  1.7339e-01, -1.5425e+00,
         -3.9107e-01,  7.6206e-01,  1.0401e+00, -3.7562e-01,  1.3260e-02],
        [ 4.6819e-01, -2.5582e-03, -6.9367e-02,  6.7653e-01,  1.1961e-01,
          5.5355e-01, -1.2394e+00, -1.7039e+00, -8.1242e-01,  2.0579e-01],
        [-2.9483e-01, -9.7842e-01, -4.6154e-01, -1.1676e+00,  9.3388e-01,
         -1.7197e-01, -5.5773e-01,  1.8984e-01, -1.7195e-01, -4.7012e-02],
        [-1.7174e-01,  7.6339e-02,  2.6950e-02, -5.7526e-01, -5.2477e-01,
          3.4190e-01, -3.1683e-01,  1.5339e+00,  8.9587e-01, -7.1919e-01],
        [ 6.3098e-01,  2.1883e-01,  3.2724e-01,  3.8056e-01,  3.2671e-03,
         -5.9096e-01,  5.2233e-01, -1.4360e+00, -6.5502e-01,  7.3660e-01],
        [-1.8325e-01,  2.3869e-02,  7.4325e-01,  4.5781e-02,  3.0987e-01,
          4.6250e-01, -9.3400e-0

## Create Models

In [14]:
# Alternative baseline (its class definition above is commented out):
# model = CIFAR10_ImageMonarchMLP()
model = CIFAR10_RowColMixer(channel_expand=10, layers=2)

In [15]:
# nn.Module.to moves parameters/buffers in place and returns the same module,
# so there is no need to re-bind `model`.
model.to(device)
model

CIFAR10_RowColMixer(
  (blocks): Sequential(
    (0): RowColMixer(
      (block0): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=320, out_features=640, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=640, out_features=320, bias=True)
        )
      )
      (norm0): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      (block1): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=320, out_features=640, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=640, out_features=320, bias=True)
        )
      )
      (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
    (1): RowColMixer(
      (block0): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=320, out_features=640, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=640, out_features=320, bias=True)
        )
      )
      (norm0): LayerNorm((

In [16]:
# Sanity check: two random 3x32x32 images should give logits of shape (2, 10).
# NOTE(review): the model is still in train mode, so this forward also updates BatchNorm running stats.
model(torch.randn(2, 3, 32, 32).to(device)).size()

torch.Size([2, 10])

In [17]:
# Total number of scalar parameters across all weight tensors.
print("number of params: ", sum(map(torch.numel, model.parameters())))

number of params:  1767690


## Training

In [18]:
## debugging to find the good classifier/output distribution.
# model_name = 'RowColumn_Mixer_CIFAR10_v0'

In [19]:
# EPOCHS = 50
# criterion = nn.CrossEntropyLoss()
# # optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
# # optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# # optimizer = torch.optim.Adam(model.parameters(), lr=0.00003)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [20]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch, model, optimizer, loader=None, loss_fn=None, dev=None):
    """Run one training epoch.

    Args:
        epoch: epoch index (unused here; kept for call-site symmetry with ``test``).
        model: module to train (switched to train mode).
        optimizer: optimizer stepped once per batch.
        loader: iterable of ``(inputs, targets)`` batches; defaults to the
            notebook global ``train_loader``.
        loss_fn: loss callable ``loss_fn(outputs, targets)``; defaults to the
            notebook global ``criterion``.
        dev: device batches are moved to; defaults to the notebook global ``device``.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``. The loss is the unweighted mean
        of the per-batch losses.

    Raises:
        ValueError: if the loader yields no batches.
    """
    # Explicit arguments win; otherwise fall back to the notebook globals.
    loader = train_loader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    n_batches = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(dev), targets.to(dev)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train(): loader yielded no batches")

    return running_loss / n_batches, 100. * correct / total

In [21]:
# best_acc = -1
def test(epoch, model, optimizer, best_acc, model_name):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    latency = []
    with torch.no_grad():
#         for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            start = time.time()
            outputs = model(inputs)
            ttaken = time.time()-start
                
            loss = criterion(outputs, targets)
            
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            latency.append(ttaken)
    
    loss = test_loss/(batch_idx+1)
    acc = 100.*correct/total
#     print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
#         print(f'Saving.. Acc: {100.*correct/total:.3f}')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('models'):
            os.mkdir('models')
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc
        
    return loss, acc, best_acc, latency

In [22]:
# start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# resume = False

# if resume:
#     # Load checkpoint.
#     print('==> Resuming from checkpoint..')
#     assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
#     checkpoint = torch.load(f'./models/{model_name}.pth')
#     model.load_state_dict(checkpoint['model'])
#     best_acc = checkpoint['acc']
#     start_epoch = checkpoint['epoch']

In [23]:
# # ### Train the whole damn thing

# best_acc = -1
# for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
#     trloss, tracc = train(epoch, model, optimizer)
#     teloss, teacc, best_acc, latency = test(epoch, model, optimizer, best_acc, model_name)
#     scheduler.step()

In [24]:
# best_acc ## 90.42 for ordinary, 89.59 for sparse, 89.82 for 32bMLP

### Run all experiments in a loop

In [25]:
def train_model(model, lr, model_name, epochs=200, seed=0):
    """Train ``model`` on CIFAR-10 and save per-epoch statistics.

    Rebinds the notebook globals ``criterion``, ``train_loader`` and
    ``test_loader``, which ``train``/``test`` read. ``test`` writes the best
    checkpoint to ``./models/{model_name}.pth``. This function writes the stats
    to ``./models/stats/{model_name}_stats.json``.

    Args:
        model: module to train (moved to the global ``device``).
        lr: Adam learning rate; annealed with cosine over ``epochs``.
        model_name: stem used for checkpoint and stats file names.
        epochs: number of training epochs.
        seed: seed for torch (data shuffling), ``random`` and numpy.

    Returns:
        ``(stats, best_acc)``. ``stats`` holds the parameter count, latency
        mean/std (seconds, over every test batch of every epoch) and per-epoch
        train/test loss and accuracy lists. ``best_acc`` is the best test accuracy (%).
    """
    global criterion, train_loader, test_loader

    # Nothing else creates ./models/stats. Make it before training so the final
    # json dump cannot fail after a multi-hour run.
    stats_dir = os.path.join('.', 'models', 'stats')
    os.makedirs(stats_dir, exist_ok=True)

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=2)

    best_acc = -1
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    n_params = sum(p.numel() for p in model.parameters())
    stats = {'num_param': n_params, 'latency': [],
             'train_acc': [], 'train_loss': [],
             'test_acc': [], 'test_loss': []}

    print(f"Begin Training for {model_name}")
    print(f"Num Parameters: {n_params}")

    for epoch in tqdm(range(epochs)):
        trloss, tracc = train(epoch, model, optimizer)
        teloss, teacc, best_acc, laten = test(epoch, model, optimizer, best_acc, model_name)
        scheduler.step()

        stats['latency'] += laten
        stats['train_acc'].append(tracc)
        stats['test_acc'].append(teacc)
        stats['train_loss'].append(trloss)
        stats['test_loss'].append(teloss)

    print()

    latency = np.array(stats['latency'])
    # Plain Python floats keep the JSON free of numpy scalar types.
    stats['latency'] = {'mean': float(np.mean(latency)), 'std': float(np.std(latency))}
    with open(os.path.join(stats_dir, f'{model_name}_stats.json'), 'w') as f:
        json.dump(stats, f)

    return stats, best_acc

In [26]:
###### FOR MULTI-LAYERED MONARCH/ROW-COL VISION MIXER
hidden_scale = 2
channel_expand = [5, 10, 20]
EPOCHS = 200
LR = 0.001
seed = 147

def benchmark_cifar10():
    """Train one CIFAR10_RowColMixer per (layers, channel_expand) configuration.

    Uses the module-level ``hidden_scale``, ``channel_expand``, ``EPOCHS``,
    ``LR`` and ``seed``. The same ``seed`` drives both weight initialisation
    (seeded here) and data shuffling (passed to ``train_model``), so the
    ``_s{seed}`` suffix in the model name describes the whole run.
    """
    h = hidden_scale
    for layers in [2, 4, 3]:
        for c in channel_expand:
            torch.manual_seed(seed)
            model = CIFAR10_RowColMixer(layers=layers, hidden_layers_ratio=[h], channel_expand=c)
            n_params = sum(p.numel() for p in model.parameters())
            print(f"\t\t{n_params}\tRowCol-MLP-Mixer ChannelExpand")
            model_name = f"cifar10_layered_RowColMixer_l{layers}_c{c}_h{h}_s{seed}"

            train_model(model, LR, model_name, EPOCHS, seed=seed)

In [None]:
# Trains every (layers, channel_expand) configuration for EPOCHS epochs each.
# The progress bars below show each run taking from tens of minutes to over an hour.
benchmark_cifar10()

		474250	RowCol-MLP-Mixer ChannelExpand
Begin Training for cifar10_layered_RowColMixer_l2_c5_h2_s147
Num Parameters: 474250


100%|███████████████████████████████████████████████████| 200/200 [23:12<00:00,  6.96s/it]



		1767690	RowCol-MLP-Mixer ChannelExpand
Begin Training for cifar10_layered_RowColMixer_l2_c10_h2_s147
Num Parameters: 1767690


100%|███████████████████████████████████████████████████| 200/200 [28:55<00:00,  8.68s/it]



		6812170	RowCol-MLP-Mixer ChannelExpand
Begin Training for cifar10_layered_RowColMixer_l2_c20_h2_s147
Num Parameters: 6812170


 68%|█████████████████████████████████                | 135/200 [1:03:21<36:27, 33.65s/it]