In [16]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [17]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms

from tqdm import tqdm
import random, time, os, sys, json

In [18]:
# import sparse_nonlinear_lib as snl

In [19]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [20]:
# time.sleep(60*60)

## For CIFAR10 dataset

In [21]:
# Per-channel normalization statistics used for CIFAR-10, defined once so the train and
# test pipelines cannot drift apart.
# (for cifar100: mean=[0.5071, 0.4865, 0.4409], std=[0.2009, 0.1984, 0.2023])
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]
# Directory the dataset is read from (and downloaded to when missing).
CIFAR10_ROOT = "../../../../../_Datasets/cifar10/"

cifar_normalize = transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD)

# Training pipeline: random crop (after 4-pixel padding) and horizontal flip as
# augmentation, then tensor conversion and normalization.
cifar_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar_normalize,
])

# Test pipeline: no augmentation, only tensor conversion and the same normalization.
cifar_test = transforms.Compose([
    transforms.ToTensor(),
    cifar_normalize,
])

train_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=True, transform=cifar_train)
test_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=False, download=True, transform=cifar_test)

Files already downloaded and verified
Files already downloaded and verified


In [22]:
# Mini-batch loaders over the two splits: the training split is shuffled, the test split
# keeps its original order. Both use batches of 32 and two worker processes.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
test_loader = data.DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)

In [23]:
## demo of train loader: fetch a single batch and show the shape of its images.
# next(iter(...)) fails loudly if the loader yields nothing, instead of leaving `xx`
# undefined (or silently stale from an earlier run) as a for/break loop would.
xx, yy = next(iter(train_loader))
xx.shape

torch.Size([32, 3, 32, 32])

# Model Comparison

In [24]:
class MlpBLock(nn.Module):
    """Fully connected block whose output width equals its input width.

    The layer widths are ``input_dim`` multiplied by each entry of
    ``[1] + hidden_layers_ratio + [1]`` (truncated to ``int``). An activation is placed
    between consecutive linear layers; there is none after the last layer.

    Args:
        input_dim: size of the last dimension of the input (and of the output).
        hidden_layers_ratio: width multiplier(s) of the hidden layer(s) relative to
            ``input_dim``. Either a single number or any iterable of numbers
            (list, tuple, ...).
        actf: zero-argument callable returning the activation module.
    """

    def __init__(self, input_dim, hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()
        self.input_dim = input_dim
        #### wrap a bare number; copy any other iterable into a fresh list so tuples
        #### work too and the shared default argument is never aliased
        if isinstance(hidden_layers_ratio, (int, float)):
            hidden_layers_ratio = [hidden_layers_ratio]
        self.hlr = [1] + list(hidden_layers_ratio) + [1]

        widths = [int(ratio * self.input_dim) for ratio in self.hlr]

        ### for 1 hidden layer, we iterate 2 times
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(actf())
        ### the final linear layer is not followed by an activation
        self.mlp = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        return self.mlp(x)

In [25]:
# class CIFAR10_ImageMonarchMLP(nn.Module):
    
#     def __init__(self, img_size=(3, 32, 32), hidden_layers_ratio=[2], actf=nn.GELU):
#         super().__init__()
        
#         self.block0 = MlpBLock(img_size[0]*img_size[1], hidden_layers_ratio, actf=actf)
#         self.block1 = MlpBLock(img_size[0]*img_size[2], hidden_layers_ratio, actf=actf)
        
# #         self.norm = nn.BatchNorm1d(select)
#         self.norm = nn.LayerNorm(np.prod(img_size))
    
#         ### Can also use normalization per block for efficiency
# #         self.norm1 = nn.LayerNorm(self.block1.input_dim)

#         self.actf = actf()
#         self.fc = nn.Linear(np.prod(img_size), 10)
        
#     def forward(self, x):
#         bs, C, H, W = x.shape
        
#         ### use B, W, C*H
#         x = x.permute(0, 3, 1, 2).contiguous().view(bs, W, -1)
#         x = self.block0(x).view(bs, W, C, H)
#         ### use B, H, C*W
#         x = x.permute(0, 3, 2, 1).contiguous().view(bs, H, -1)
#         x = self.block1(x).view(bs, -1)
        
#         x = self.norm(x)
#         x = self.actf(x)
#         x = self.fc(x)
#         return x

In [26]:
# CIFAR10_ImageMonarchMLP()(torch.randn(1, 3, 32, 32))

In [27]:
class RowColMixer(nn.Module):
    """Residual block that mixes an image along its columns, then along its rows.

    ``block0`` acts on vectors of length ``C*H`` (one per column index), ``block1`` on
    vectors of length ``C*W`` (one per row index). Each is preceded by a LayerNorm over
    that vector, and the block input is added back to the result.

    Args:
        img_size: ``(C, H, W)`` of the images the block will receive.
        hidden_layers_ratio: forwarded to both ``MlpBLock`` instances.
        actf: activation factory forwarded to both ``MlpBLock`` instances.
    """

    def __init__(self, img_size=(3, 32, 32), hidden_layers_ratio=[2], actf=nn.GELU):
        super().__init__()
        channels, height, width = img_size[0], img_size[1], img_size[2]

        # Sub-modules are created in the order block0, norm0, block1, norm1.
        self.block0 = MlpBLock(channels * height, hidden_layers_ratio, actf=actf)
        self.norm0 = nn.LayerNorm(self.block0.input_dim)
        self.block1 = MlpBLock(channels * width, hidden_layers_ratio, actf=actf)
        self.norm1 = nn.LayerNorm(self.block1.input_dim)

    def forward(self, x):
        """Map a ``(B, C, H, W)`` tensor to a tensor of the same shape."""
        bs, C, H, W = x.shape
        skip = x

        ### (B, C, H, W) -> (B, W, C*H): one vector per column index
        cols = x.permute(0, 3, 1, 2).reshape(bs, W, C * H)
        cols = self.block0(self.norm0(cols))

        ### (B, W, C*H) -> (B, H, C*W): one vector per row index
        rows = cols.view(bs, W, C, H).permute(0, 3, 2, 1).reshape(bs, H, C * W)
        rows = self.block1(self.norm1(rows))

        ### (B, H, C*W) -> (B, C, H, W), then the residual connection
        out = rows.view(bs, H, C, W).permute(0, 2, 1, 3).contiguous()
        return out + skip
    
class CIFAR10_RowColMixer(nn.Module):
    """Image classifier: optional channel expansion, stacked RowColMixer blocks, linear head.

    The input is widened from ``img_size[0]`` to ``channel_expand`` channels by appending
    copies of randomly chosen input channels. The result passes through ``layers``
    ``RowColMixer`` blocks, is flattened, layer-normalised, activated and projected to
    10 outputs.

    Args:
        img_size: ``(C, H, W)`` of the input images.
        hidden_layers_ratio: forwarded to every ``RowColMixer``.
        layers: number of stacked ``RowColMixer`` blocks.
        channel_expand: number of channels after expansion; must be ``>= img_size[0]``.
        actf: zero-argument callable returning the activation module.

    Note:
        The channel selection is drawn once at construction. It is registered as a
        buffer so it follows ``.to(device)``, and it is stored in the ``state_dict``
        whenever channels are actually added, so a reloaded checkpoint reuses the same
        selection. With no expansion the ``state_dict`` keys are unchanged.
    """

    def __init__(self, img_size=(3, 32, 32), hidden_layers_ratio=[2], layers=1, channel_expand=3, actf=nn.GELU):
        super().__init__()
        in_channels = img_size[0]
        assert in_channels <= channel_expand, "Can't reduce channels than original size"

        # Indices of the input channels that get duplicated. The modulus is the real
        # channel count (it used to be a literal 3, which indexes out of range for
        # inputs with fewer than 3 channels).
        n_extra = channel_expand - in_channels
        select = torch.randperm(n_extra) % in_channels
        self.register_buffer("select_channels", select, persistent=n_extra > 0)

        img_size = (channel_expand, img_size[1], img_size[2])

        self.blocks = nn.Sequential(
            *[RowColMixer(img_size, hidden_layers_ratio, actf=actf) for _ in range(layers)]
        )

#         self.norm = nn.BatchNorm1d(select)
        # np.prod returns a numpy integer; use a plain int for the layer sizes.
        n_features = int(np.prod(img_size))
        self.norm = nn.LayerNorm(n_features)
        self.actf = actf()
        self.fc = nn.Linear(n_features, 10)

    def forward(self, x):
        """Map a ``(B, C, H, W)`` batch to a ``(B, 10)`` tensor."""
        bs = x.shape[0]
        # Append the duplicated channels (nothing to do when no expansion was requested).
        if self.select_channels.numel() > 0:
            x = torch.cat((x, x[:, self.select_channels, :, :]), dim=1)

        x = self.blocks(x).reshape(bs, -1)
        x = self.norm(x)
        x = self.actf(x)
        x = self.fc(x)
        return x

In [28]:
# Smoke test: a freshly initialised model with channel expansion to 20 channels is
# applied to one random 3x32x32 input; the displayed result holds 10 values.
CIFAR10_RowColMixer(channel_expand=20)(torch.randn(1, 3, 32, 32))

tensor([[ 0.5932, -0.3034,  0.6858, -0.0429, -0.2959, -0.1132,  0.1623,  0.3577,
         -0.2613, -0.0198]], grad_fn=<AddmmBackward>)

## Create Models

In [29]:
# model = CIFAR10_ImageMonarchMLP()
# Two stacked RowColMixer blocks; all other constructor arguments keep their defaults.
model = CIFAR10_RowColMixer(layers=2)

In [30]:
# Move the model to the selected device; the bare `model` expression displays its module tree.
model = model.to(device)
model

CIFAR10_RowColMixer(
  (blocks): Sequential(
    (0): RowColMixer(
      (block0): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=96, out_features=192, bias=True)
          (1): GELU()
          (2): Linear(in_features=192, out_features=96, bias=True)
        )
      )
      (norm0): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (block1): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=96, out_features=192, bias=True)
          (1): GELU()
          (2): Linear(in_features=192, out_features=96, bias=True)
        )
      )
      (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    )
    (1): RowColMixer(
      (block0): MlpBLock(
        (mlp): Sequential(
          (0): Linear(in_features=96, out_features=192, bias=True)
          (1): GELU()
          (2): Linear(in_features=192, out_features=96, bias=True)
        )
      )
      (norm0): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (block1): MlpB

In [31]:
# Sanity check: a batch of two random 3x32x32 inputs should give an output of shape (2, 10).
model(torch.randn(2, 3, 32, 32).to(device)).shape

torch.Size([2, 10])

In [32]:
# Total number of scalar entries over all parameter tensors of the model.
print("number of params: ", sum(map(torch.numel, model.parameters())))

number of params:  186250


## Training

In [33]:
## debugging to find the good classifier/output distribution.
# model_name = 'RowColumn_Mixer_CIFAR10_v0'

In [34]:
# EPOCHS = 50
# criterion = nn.CrossEntropyLoss()
# # optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
# # optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# # optimizer = torch.optim.Adam(model.parameters(), lr=0.00003)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [35]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch, model, optimizer):
    """Run one optimisation pass over the module-level ``train_loader``.

    Args:
        epoch: epoch index; not used inside the function.
        model: network to train; switched to training mode.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(loss, acc)``: the mean of the per-batch losses and the accuracy in percent.

    Relies on the module-level ``train_loader``, ``criterion`` and ``device``.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard step: clear gradients, forward, backward, parameter update.
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Running statistics for the epoch summary.
        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

    mean_loss = running_loss / (batch_idx + 1)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

In [36]:
# best_acc = -1
def test(epoch, model, optimizer, best_acc, model_name):
    """Evaluate ``model`` on the module-level ``test_loader``; checkpoint it on improvement.

    Args:
        epoch: epoch index, stored in the checkpoint.
        model: network to evaluate; switched to eval mode.
        optimizer: not used; kept so existing call sites keep working.
        best_acc: best test accuracy (percent) seen so far.
        model_name: the checkpoint is written to ``./models/{model_name}.pth``.

    Returns:
        ``(loss, acc, best_acc, latency)``: the mean of the per-batch losses, the
        accuracy in percent, the (possibly updated) best accuracy, and a list holding
        the forward-pass wall time in seconds of every batch.

    Relies on the module-level ``test_loader``, ``criterion`` and ``device``.
    """
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    latency = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # CUDA kernels run asynchronously, so without synchronisation the timer
            # would only capture the kernel launch. Wait for pending work before
            # starting the clock and for the forward pass before stopping it.
            if inputs.is_cuda:
                torch.cuda.synchronize(inputs.device)
            start = time.perf_counter()
            outputs = model(inputs)
            if inputs.is_cuda:
                torch.cuda.synchronize(inputs.device)
            ttaken = time.perf_counter() - start

            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            latency.append(ttaken)

    loss = test_loss/(batch_idx+1)
    acc = 100.*correct/total

    # Save checkpoint whenever the test accuracy improves.
    if acc > best_acc:
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('models', exist_ok=True)
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

    return loss, acc, best_acc, latency

In [75]:
# start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# resume = False

# if resume:
#     # Load checkpoint.
#     print('==> Resuming from checkpoint..')
#     assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
#     checkpoint = torch.load(f'./models/{model_name}.pth')
#     model.load_state_dict(checkpoint['model'])
#     best_acc = checkpoint['acc']
#     start_epoch = checkpoint['epoch']

In [76]:
# # ### Train the whole damn thing

# best_acc = -1
# for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
#     trloss, tracc = train(epoch, model, optimizer)
#     teloss, teacc, best_acc, latency = test(epoch, model, optimizer, best_acc, model_name)
#     scheduler.step()

In [77]:
# best_acc ## 90.42 for ordinary, 89.59 for sparse, 89.82 for 32bMLP, 

### Do all experiments in repeat

In [78]:
def train_model(model, lr, model_name, epochs=200, seed=0):
    """Train ``model`` for ``epochs`` epochs and write its statistics to disk.

    Rebuilds the module-level ``train_loader`` / ``test_loader`` and sets the
    module-level ``criterion`` (``train`` and ``test`` read them as globals). Training
    uses Adam with a cosine-annealed learning rate.

    Args:
        model: network to train; moved to the module-level ``device``.
        lr: initial learning rate for Adam.
        model_name: used for the checkpoint written by ``test`` and for the stats file
            ``./models/stats/{model_name}_stats.json``.
        epochs: number of epochs; also the cosine schedule's ``T_max``.
        seed: value passed to ``torch.manual_seed`` before the loaders are created.

    Returns:
        ``(stats, best_acc)``: ``stats`` holds the parameter count, the per-epoch
        train/test loss and accuracy lists, and the mean/std of all per-batch test
        latencies; ``best_acc`` is the best test accuracy in percent.
    """
    global criterion, train_loader, test_loader

    # Create the output directory before training: a missing ./models/stats would
    # otherwise only fail at the final open(), after every epoch has already run.
    os.makedirs('./models/stats', exist_ok=True)

    torch.manual_seed(seed)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=2)

    best_acc = -1
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    n_params = sum(p.numel() for p in model.parameters())
    stats = {'num_param':n_params, 'latency': [],
             'train_acc':[], 'train_loss':[],
             'test_acc':[], 'test_loss':[]
            }

    print(f"Begin Training for {model_name}")
    print(f"Num Parameters: {n_params}")

    for epoch in tqdm(range(epochs)):
        trloss, tracc = train(epoch, model, optimizer)
        teloss, teacc, best_acc, laten = test(epoch, model, optimizer, best_acc, model_name)
        scheduler.step()

        stats['latency'] += laten
        stats['train_acc'].append(tracc)
        stats['test_acc'].append(teacc)
        stats['train_loss'].append(trloss)
        stats['test_loss'].append(teloss)

    print()

    # Collapse the per-batch latencies of all epochs into mean/std; plain floats keep
    # the JSON dump independent of numpy scalar types.
    latency = np.array(stats['latency'])
    mu, std = float(np.mean(latency)), float(np.std(latency))
    stats['latency'] = {'mean':mu, 'std':std}
    ### Save stats of the model
    with open(f'./models/stats/{model_name}_stats.json', 'w') as f:
        json.dump(stats, f)

    return stats, best_acc

In [41]:
###### FOR 1 MONARCH LAYER
# hidden_scale = [2, 4, 8]
# # SEEDS = [147, 258, 369]
# EPOCHS = 200
# LR = 0.001


# def benchmark_cifar10():
#     for seed in [147]:
#         ### First test MLP with allowed dimension mixing
        
#         for h in hidden_scale:
#             torch.manual_seed(seed)
            
#             model = CIFAR10_ImageMonarchMLP(hidden_layers_ratio=[h])
#             n_params = sum(p.numel() for p in model.parameters())
#             print(f"\t\t{n_params}\tMonarchMLP")
#             model_name = f"cifar10_MonarchMLP_h{h}_s{seed}"
            
#             train_model(model, LR, model_name, EPOCHS)

In [42]:
# benchmark_cifar10()

		111178	MonarchMLP
Begin Training for cifar10_MonarchMLP_h2_s147
Num Parameters: 111178


100%|███████████████████████████████████████████████████| 200/200 [22:58<00:00,  6.89s/it]



		185290	MonarchMLP
Begin Training for cifar10_MonarchMLP_h4_s147
Num Parameters: 185290


100%|███████████████████████████████████████████████████| 200/200 [22:59<00:00,  6.90s/it]



		333514	MonarchMLP
Begin Training for cifar10_MonarchMLP_h8_s147
Num Parameters: 333514


100%|███████████████████████████████████████████████████| 200/200 [22:59<00:00,  6.90s/it]







In [81]:
###### FOR MULTI-LAYERED MONARCH/ROW-COL VISION MIXER
hidden_scale = [2, 4]  # hidden-layer width ratios swept by benchmark_cifar10
EPOCHS = 200  # epochs per configuration, passed to train_model
LR = 0.001  # learning rate passed to train_model
seed = 147  # set via torch.manual_seed before each model is built; also part of the run name

def benchmark_cifar10():
    """Train one CIFAR10_RowColMixer per (depth, hidden ratio) combination.

    Sweeps 2, 3 and 4 layers against every ratio in the module-level ``hidden_scale``.
    Reads the module-level ``seed``, ``LR`` and ``EPOCHS``; each run is named
    ``cifar10_RowColMixer_h{h}_l{layers}_s{seed}``.
    """
    for layers in [2, 3, 4]:
        for h in hidden_scale:
            # Seed before construction so the initial weights are reproducible.
            torch.manual_seed(seed)

            model = CIFAR10_RowColMixer(layers=layers, hidden_layers_ratio=[h])
            n_params = sum(p.numel() for p in model.parameters())
            print(f"\t\t{n_params}\tRowCol-MLP-Mixer")
            model_name = f"cifar10_RowColMixer_h{h}_l{layers}_s{seed}"

            # Pass the seed on: train_model re-seeds torch itself and would otherwise
            # use its default of 0, contradicting the "_s{seed}" in the run name.
            train_model(model, LR, model_name, EPOCHS, seed=seed)

In [82]:
# Runs the full sweep; in the recorded output below each configuration took roughly 23-29 minutes.
benchmark_cifar10()

		186250	RowCol-MLP-Mixer
Begin Training for cifar10_RowColMixer_h2_l2_s147
Num Parameters: 186250


100%|███████████████████████████████████████████████████| 200/200 [23:05<00:00,  6.93s/it]



		334474	RowCol-MLP-Mixer
Begin Training for cifar10_RowColMixer_h4_l2_s147
Num Parameters: 334474


100%|███████████████████████████████████████████████████| 200/200 [23:06<00:00,  6.93s/it]



		260938	RowCol-MLP-Mixer
Begin Training for cifar10_RowColMixer_h2_l3_s147
Num Parameters: 260938


100%|███████████████████████████████████████████████████| 200/200 [24:12<00:00,  7.26s/it]



		483274	RowCol-MLP-Mixer
Begin Training for cifar10_RowColMixer_h4_l3_s147
Num Parameters: 483274


100%|███████████████████████████████████████████████████| 200/200 [24:14<00:00,  7.27s/it]



		335626	RowCol-MLP-Mixer
Begin Training for cifar10_RowColMixer_h2_l4_s147
Num Parameters: 335626


100%|███████████████████████████████████████████████████| 200/200 [28:53<00:00,  8.67s/it]



		632074	RowCol-MLP-Mixer
Begin Training for cifar10_RowColMixer_h4_l4_s147
Num Parameters: 632074


100%|███████████████████████████████████████████████████| 200/200 [28:57<00:00,  8.69s/it]





