In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms

from tqdm import tqdm
import random, time, os, sys, json

In [3]:
import sparse_nonlinear_lib as snl

In [4]:
# Use the first CUDA GPU when one is usable, otherwise fall back to the CPU so
# the notebook still runs end-to-end on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## For FMNIST dataset

In [5]:
# Both splits get the same preprocessing: ToTensor only, no augmentation or
# normalisation step is applied.
train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

# Fashion-MNIST train (train=True) and test (train=False) splits, both stored
# under the same relative `root`; download=True is passed for both.
train_dataset = datasets.FashionMNIST(root="../../../../_Datasets/FMNIST/", train=True, download=True, transform=train_transform)
test_dataset = datasets.FashionMNIST(root="../../../../_Datasets/FMNIST/", train=False, download=True, transform=test_transform)

In [6]:
# Mini-batches of 50 samples with 2 worker processes each; only the training
# loader shuffles, the test loader keeps the dataset order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=50, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=50, shuffle=False, num_workers=2)

In [7]:
## demo of train loader
# Fetch one batch with the built-in next(): DataLoader iterators implement the
# standard iterator protocol (__next__), and the Python-2-style `.next()` alias
# is not available in recent PyTorch releases.
xx, yy = next(iter(train_loader))
xx.shape

torch.Size([50, 1, 28, 28])

# Model Comparison

## Pair Linear Mixing

In [8]:
def img2patch(x, input_dim=(1, 28, 28), patch_size=(7, 4)):
    """Split an image batch into patches with `nn.functional.unfold`.

    Args:
        x: tensor handed directly to `nn.functional.unfold`.
        input_dim: not used by this function.
        patch_size: (height, width) of a patch; used as both the kernel size
            and the stride.

    Returns:
        The tensor produced by `nn.functional.unfold`.
    """
    # stride == kernel size, so consecutive windows do not overlap
    return nn.functional.unfold(x, kernel_size=patch_size, stride=patch_size)

In [9]:
def patch2img(x, patch_size=(7, 4), input_dim=(1, 28, 28)):
    """Reassemble patches into images with `nn.functional.fold`.

    Args:
        x: tensor of patches handed directly to `nn.functional.fold`.
        patch_size: (height, width) of a patch; used as both the kernel size
            and the stride.
        input_dim: only its last two entries are read; they give the
            (height, width) of the folded output.

    Returns:
        The tensor produced by `nn.functional.fold`.
    """
    out_hw = (input_dim[-2], input_dim[-1])
    return nn.functional.fold(x, out_hw, kernel_size=patch_size, stride=patch_size)

1. Linearize by expanding the dimension of folded image.

## Final Model

In [10]:
class FMNIST_BlockMLP(nn.Module):
    """Classifier built around `snl.BlockMLP_MixerBlock`.

    Forward pipeline: flatten each sample -> `snl.DimensionSelector` ->
    `snl.BlockMLP_MixerBlock` -> LayerNorm -> activation -> Linear(select, 10).
    """

    def __init__(self, img_size=(1, 28, 28), select=1024, block_size=2, hidden_layers_ratio=[4], actf=nn.GELU):
        """Build the layers.

        Args:
            img_size: input shape; only its product is used, as the first
                argument of `snl.DimensionSelector`.
            select: second argument of the selector and the width of the mixer
                block, the LayerNorm and the input of the final linear layer.
            block_size: forwarded to `snl.BlockMLP_MixerBlock`.
            hidden_layers_ratio: forwarded unchanged to
                `snl.BlockMLP_MixerBlock` (see `snl` for its meaning).
            actf: activation class; forwarded to the mixer block and also
                instantiated once for the activation after the norm.
        """
        super().__init__()
        flat_dim = np.prod(img_size)
        self.dim_sel = snl.DimensionSelector(flat_dim, select)
        self.block_mlp = snl.BlockMLP_MixerBlock(
            select,
            block_size,
            hidden_layers_ratio=hidden_layers_ratio,
            actf=actf,
        )
        # LayerNorm is the active choice; the BatchNorm1d alternative is disabled.
        self.norm = nn.LayerNorm(select)
        self.actf = actf()
        self.fc = nn.Linear(select, 10)

    def forward(self, x):
        """Return the 10-way output of the final linear layer for batch `x`."""
        flat = x.reshape(x.shape[0], -1)
        mixed = self.block_mlp(self.dim_sel(flat))
        return self.fc(self.actf(self.norm(mixed)))

In [11]:
model = FMNIST_BlockMLP(block_size=4)
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  201738


In [12]:
model

FMNIST_BlockMLP(
  (dim_sel): DimensionSelector: [+=240]
  (block_mlp): BlockMLP_MixerBlock(
    (facto_nets): ModuleList(
      (0): BlockMLP(
        (mlp): Sequential(
          (0): BlockLinear: [256, 4, 16]
          (1): GELU()
          (2): BlockLinear: [256, 16, 4]
        )
      )
      (1): BlockMLP(
        (mlp): Sequential(
          (0): BlockLinear: [256, 4, 16]
          (1): GELU()
          (2): BlockLinear: [256, 16, 4]
        )
      )
      (2): BlockMLP(
        (mlp): Sequential(
          (0): BlockLinear: [256, 4, 16]
          (1): GELU()
          (2): BlockLinear: [256, 16, 4]
        )
      )
      (3): BlockMLP(
        (mlp): Sequential(
          (0): BlockLinear: [256, 4, 16]
          (1): GELU()
          (2): BlockLinear: [256, 16, 4]
        )
      )
      (4): BlockMLP(
        (mlp): Sequential(
          (0): BlockLinear: [256, 4, 16]
          (1): GELU()
          (2): BlockLinear: [256, 16, 4]
        )
      )
    )
  )
  (norm): LayerNo

In [13]:
# asdfasdf

In [14]:
#### USING DimensionSelector to make comparative

class FMNIST_OrdMLP(nn.Module):
    """Dense baseline with the same front end as the other models.

    Forward pipeline: flatten each sample -> `snl.DimensionSelector` ->
    Linear(select, select) -> LayerNorm -> GELU -> Linear(select, 10).
    """

    def __init__(self, img_size=(1, 28, 28), select=1024):
        """Build the layers.

        Args:
            img_size: input shape; its product is stored as `input_dim` and
                given to `snl.DimensionSelector` as its first argument.
            select: second argument of the selector and the width of the
                hidden linear layer and the LayerNorm.
        """
        super().__init__()
        self.input_dim = np.prod(img_size)
        self.dim_sel = snl.DimensionSelector(self.input_dim, select)

        self.l0 = nn.Linear(select, select)
        self.norm = nn.LayerNorm(select)
        self.actf = nn.GELU()
        self.l1 = nn.Linear(select, 10)

    def forward(self, x):
        """Return the 10-way output of the final linear layer for batch `x`."""
        flat = x.reshape(x.shape[0], -1)
        hidden = self.l0(self.dim_sel(flat))
        return self.l1(self.actf(self.norm(hidden)))

In [15]:
model = FMNIST_OrdMLP(select=1024)
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  1061898


In [16]:
model

FMNIST_OrdMLP(
  (dim_sel): DimensionSelector: [+=240]
  (l0): Linear(in_features=1024, out_features=1024, bias=True)
  (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (actf): GELU()
  (l1): Linear(in_features=1024, out_features=10, bias=True)
)

In [17]:
class FMNIST_SparseMLP(nn.Module):
    """Classifier whose hidden layer is an `snl.BlockLinear_MixerBlock`.

    Forward pipeline: flatten each sample -> `snl.DimensionSelector` ->
    `snl.BlockLinear_MixerBlock` -> LayerNorm -> GELU -> Linear(select, 10).
    """

    def __init__(self, img_size=(1, 28, 28), select=1024, block_size=2):
        """Build the layers.

        Args:
            img_size: input shape; only its product is used, as the first
                argument of `snl.DimensionSelector`.
            select: second argument of the selector and the width of the mixer
                block and the LayerNorm.
            block_size: forwarded to `snl.BlockLinear_MixerBlock`.
        """
        super().__init__()
        flat_dim = np.prod(img_size)
        self.dim_sel = snl.DimensionSelector(flat_dim, select)

        self.l0 = snl.BlockLinear_MixerBlock(select, block_size)
        self.norm = nn.LayerNorm(select)
        self.actf = nn.GELU()
        self.l1 = nn.Linear(select, 10)

    def forward(self, x):
        """Return the 10-way output of the final linear layer for batch `x`."""
        flat = x.reshape(x.shape[0], -1)
        hidden = self.l0(self.dim_sel(flat))
        return self.l1(self.actf(self.norm(hidden)))

In [18]:
class FMNIST_SparseMLP_PWLF(nn.Module):
    """`FMNIST_SparseMLP` variant that applies `snl.PairBilinear` after the norm.

    Forward pipeline: flatten each sample -> `snl.DimensionSelector` ->
    `snl.BlockLinear_MixerBlock` -> LayerNorm -> `snl.PairBilinear` ->
    Linear(select, 10). No GELU is applied in this model.
    """

    def __init__(self, img_size=(1, 28, 28), select=1024, block_size=2):
        """Build the layers.

        Args:
            img_size: input shape; only its product is used, as the first
                argument of `snl.DimensionSelector`.
            select: second argument of the selector and the width of the mixer
                block, the LayerNorm and the PairBilinear layer.
            block_size: forwarded to `snl.BlockLinear_MixerBlock`.
        """
        super().__init__()
        flat_dim = np.prod(img_size)
        self.dim_sel = snl.DimensionSelector(flat_dim, select)

        self.l0 = snl.BlockLinear_MixerBlock(select, block_size)
        self.norm = nn.LayerNorm(select)
        # Takes the place of the GELU used by the sibling models; the second
        # argument is fixed at 5 (presumably the grid size -- confirm in snl).
        self.pwlf = snl.PairBilinear(select, 5)
        self.l1 = nn.Linear(select, 10)

    def forward(self, x):
        """Return the 10-way output of the final linear layer for batch `x`."""
        flat = x.reshape(x.shape[0], -1)
        hidden = self.norm(self.l0(self.dim_sel(flat)))
        return self.l1(self.pwlf(hidden))

In [19]:
# model = FMNIST_SparseMLP(block_size=32).to(device)
model = FMNIST_SparseMLP_PWLF(block_size=32).to(device)

print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  105482


In [20]:
model(torch.randn(2, 1, 28, 28).to(device)).shape

torch.Size([2, 10])

In [21]:
model

FMNIST_SparseMLP_PWLF(
  (dim_sel): DimensionSelector: [+=240]
  (l0): BlockLinear_MixerBlock(
    (facto_nets): ModuleList(
      (0): BlockWeight: [32, 32, 32]
      (1): BlockWeight: [32, 32, 32]
    )
  )
  (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (pwlf): PairBilinear: [1024 -> 1024] (grid: 5)
  (l1): Linear(in_features=1024, out_features=10, bias=True)
)

In [22]:
class FMNIST_PairBilinear(nn.Module):
    """Classifier built around `snl.PairBilinear_MixerBlock`.

    Forward pipeline: flatten each sample -> `snl.DimensionSelector` ->
    `snl.PairBilinear_MixerBlock` -> LayerNorm -> GELU -> Linear(select, 10).
    """

    def __init__(self, img_size=(1, 28, 28), select=1024, grid_width=5):
        """Build the layers.

        Args:
            img_size: input shape; only its product is used, as the first
                argument of `snl.DimensionSelector`.
            select: second argument of the selector; passed twice to
                `snl.PairBilinear_MixerBlock` and used as the LayerNorm width.
            grid_width: forwarded to `snl.PairBilinear_MixerBlock`.
        """
        super().__init__()
        flat_dim = np.prod(img_size)
        self.dim_sel = snl.DimensionSelector(flat_dim, select)
        self.block_func = snl.PairBilinear_MixerBlock(
            select,
            select,
            grid_width=grid_width,
        )
        self.norm = nn.LayerNorm(select)
        # GELU is the active activation; the ELU alternative is disabled.
        self.actf = nn.GELU()
        self.fc = nn.Linear(select, 10)

    def forward(self, x):
        """Return the 10-way output of the final linear layer for batch `x`."""
        flat = x.reshape(x.shape[0], -1)
        mixed = self.block_func(self.dim_sel(flat))
        return self.fc(self.actf(self.norm(mixed)))

In [23]:
model = FMNIST_PairBilinear(grid_width=3)
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  125962


In [24]:
model

FMNIST_PairBilinear(
  (dim_sel): DimensionSelector: [+=240]
  (block_func): PairBilinear_MixerBlock(
    (selector): BiasLayer: [1024]
    (pairwise_mixing): ModuleList(
      (0): PairBilinear: [1024 -> 1024] (grid: 3)
      (1): PairBilinear: [1024 -> 1024] (grid: 3)
      (2): PairBilinear: [1024 -> 1024] (grid: 3)
      (3): PairBilinear: [1024 -> 1024] (grid: 3)
      (4): PairBilinear: [1024 -> 1024] (grid: 3)
      (5): PairBilinear: [1024 -> 1024] (grid: 3)
      (6): PairBilinear: [1024 -> 1024] (grid: 3)
      (7): PairBilinear: [1024 -> 1024] (grid: 3)
      (8): PairBilinear: [1024 -> 1024] (grid: 3)
      (9): PairBilinear: [1024 -> 1024] (grid: 3)
    )
    (reducer): Identity()
  )
  (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (actf): GELU()
  (fc): Linear(in_features=1024, out_features=10, bias=True)
)

## Create Models

In [25]:
# model = FMNIST_BlockMLP(img_size=(1, 28, 28), patch_size=(7, 7))
# model = FMNIST_OrdMLP(img_size=(1, 28, 28))
# model = FMNIST_PairBilinear(img_size=(1, 28, 28), grid_width=7)



In [26]:
model = model.to(device)
model

FMNIST_PairBilinear(
  (dim_sel): DimensionSelector: [+=240]
  (block_func): PairBilinear_MixerBlock(
    (selector): BiasLayer: [1024]
    (pairwise_mixing): ModuleList(
      (0): PairBilinear: [1024 -> 1024] (grid: 3)
      (1): PairBilinear: [1024 -> 1024] (grid: 3)
      (2): PairBilinear: [1024 -> 1024] (grid: 3)
      (3): PairBilinear: [1024 -> 1024] (grid: 3)
      (4): PairBilinear: [1024 -> 1024] (grid: 3)
      (5): PairBilinear: [1024 -> 1024] (grid: 3)
      (6): PairBilinear: [1024 -> 1024] (grid: 3)
      (7): PairBilinear: [1024 -> 1024] (grid: 3)
      (8): PairBilinear: [1024 -> 1024] (grid: 3)
      (9): PairBilinear: [1024 -> 1024] (grid: 3)
    )
    (reducer): Identity()
  )
  (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (actf): GELU()
  (fc): Linear(in_features=1024, out_features=10, bias=True)
)

In [27]:
model(torch.randn(2, 1, 28, 28).to(device)).shape

torch.Size([2, 10])

In [28]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  125962


## Training

In [29]:
 ## debugging to find the good classifier/output distribution.
# model_name = 'block_mlp_mixer_fmnist_v0'
# model_name = 'ord_mlp_mixer_fmnist_v0'
model_name = 'pair_bilinear_mixer_fmnist_v0'

In [30]:
# Settings for the single-model training run on the current `model`.
EPOCHS = 50
criterion = nn.CrossEntropyLoss()
# Alternative optimizer settings, left disabled:
# optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.00003)

# Active optimizer: Adam with lr=0.001 over the parameters of `model`.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Cosine-annealed learning rate with T_max set to the number of epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [31]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch, model, optimizer):
    """Run one training pass over the module-level `train_loader`.

    Reads the globals `train_loader`, `device` and `criterion`.

    Args:
        epoch: epoch index; only used in the printed summary line.
        model: network to optimise; switched to train mode.
        optimizer: optimizer stepped once per batch.

    Returns:
        (mean_loss, accuracy): the sum of per-batch losses divided by the
        number of batches, and the percentage of correctly predicted samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0
    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        # index of the largest output along dim 1 is the predicted class
        predicted = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    mean_loss = loss_sum / (batch_idx + 1)
    accuracy = 100. * n_correct / n_seen
    print(f"[Train] {epoch} Loss: {mean_loss:.3f} | Acc: {accuracy:.3f} {n_correct}/{n_seen}")
    return mean_loss, accuracy

In [32]:
# best_acc = -1
def test(epoch, model, optimizer, best_acc, model_name):
    """Evaluate `model` on the module-level `test_loader` and checkpoint the best.

    Reads the globals `test_loader`, `device` and `criterion`. When the
    accuracy beats `best_acc`, the model state dict, accuracy and epoch are
    saved to `./models/{model_name}.pth`.

    Args:
        epoch: epoch index; printed and stored in the checkpoint.
        model: network to evaluate; switched to eval mode.
        optimizer: not used by this function (kept in the signature).
        best_acc: best accuracy seen so far, in percent.
        model_name: file stem of the checkpoint.

    Returns:
        (loss, acc, best_acc, latency): mean per-batch loss, accuracy in
        percent, the possibly updated best accuracy, and a list with the
        forward-pass wall time in seconds of every batch.
    """
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    latency = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)

            # CUDA work is queued asynchronously: drain pending work before
            # starting the clock and wait for the forward pass to finish before
            # stopping it, otherwise only the launch overhead is measured.
            if inputs.is_cuda:
                torch.cuda.synchronize(inputs.device)
            start = time.perf_counter()
            outputs = model(inputs)
            if inputs.is_cuda:
                torch.cuda.synchronize(inputs.device)
            ttaken = time.perf_counter() - start

            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            latency.append(ttaken)

    loss = test_loss/(batch_idx+1)
    acc = 100.*correct/total
    print(f"[Test] {epoch} Loss: {loss:.3f} | Acc: {acc:.3f} {correct}/{total}")

    # Save checkpoint.
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir()
        os.makedirs('models', exist_ok=True)
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

    return loss, acc, best_acc, latency

In [33]:
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    # map_location re-targets the stored tensors onto `device`, so a checkpoint
    # written on a GPU can also be restored on a CPU-only host.
    checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

In [34]:
# # ### Train the whole damn thing

# best_acc = -1
# for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
#     trloss, tracc = train(epoch, model, optimizer)
#     teloss, teacc, best_acc, latency = test(epoch, model, optimizer, best_acc, model_name)
#     scheduler.step()

In [35]:
# best_acc ## 90.42 for ordinary, 89.59 for sparse, 89.82 for 32bMLP, 

### Do all experiments in repeat

In [36]:
def train_model(model, lr, model_name, epochs=50, seed=0):
    """Train and evaluate `model` for `epochs` epochs and save its statistics.

    Rebinds the module-level `criterion`, `train_loader` and `test_loader`,
    because `train()` and `test()` read them as globals. The best checkpoint
    is written by `test()`; the statistics are written to
    `./models/stats/{model_name}_stats.json`.

    Args:
        model: network to train; moved to the global `device`.
        lr: learning rate for the Adam optimizer.
        model_name: file stem for the checkpoint and the stats file.
        epochs: number of epochs, also used as T_max of the cosine schedule.
        seed: value passed to `torch.manual_seed` before the loaders are built.

    Returns:
        (stats, best_acc): `stats` holds the parameter count, per-epoch
        train/test loss and accuracy, and the mean/std of the per-batch test
        latency; `best_acc` is the best test accuracy in percent.
    """
    global criterion, train_loader, test_loader

    # Create the output directory up front: a missing or unwritable stats
    # directory should fail before training, not after the last epoch.
    stats_dir = './models/stats'
    os.makedirs(stats_dir, exist_ok=True)

    torch.manual_seed(seed)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=50, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=50, shuffle=False, num_workers=2)

    best_acc = -1
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    n_params = sum(p.numel() for p in model.parameters())
    stats = {'num_param':n_params, 'latency': [],
             'train_acc':[], 'train_loss':[],
             'test_acc':[], 'test_loss':[]
            }

    print(f"Begin Training for {model_name}")
    print(f"Num Parameters: {n_params}")

    for epoch in range(epochs):
        trloss, tracc = train(epoch, model, optimizer)
        teloss, teacc, best_acc, laten = test(epoch, model, optimizer, best_acc, model_name)
        scheduler.step()

        stats['latency'] += laten
        stats['train_acc'].append(tracc)
        stats['test_acc'].append(teacc)
        stats['train_loss'].append(trloss)
        stats['test_loss'].append(teloss)

    print()

    # Collapse the per-batch latencies of all epochs into mean/std; plain
    # floats keep the dict JSON-friendly.
    latency = np.array(stats['latency'])
    mu, std = float(np.mean(latency)), float(np.std(latency))
    stats['latency'] = {'mean':mu, 'std':std}
    ### Save stats of the model
    with open(f'{stats_dir}/{model_name}_stats.json', 'w') as f:
        json.dump(stats, f)

    return stats, best_acc

In [37]:
# Sweep grid: each key is used as `block_size`, each value lists the
# `hidden_layers_ratio` entries tried with that block size for FMNIST_BlockMLP.
mlp_dims_scale = {
    32: [2, 4, 8],
    4: [4, 8, 16],
    2: [4, 8, 16],
}
# `grid_width` values tried for FMNIST_PairBilinear.
pwlf_grid_size = [3, 5, 9]

# Random seeds for repeated runs.
SEEDS = [147, 258, 369]

def benchmark_fmnist():
    """Print a parameter-count table for every model configuration.

    No training happens here: each model is only instantiated and its
    parameters are counted. Reads the globals `mlp_dims_scale` and
    `pwlf_grid_size`.
    """
    def count_params(net):
        # total number of elements over all parameters of `net`
        return sum(p.numel() for p in net.parameters())

    for seed in [147]:
        ### First test MLP with allowed dimension mixing
        for dim, hid_dim in mlp_dims_scale.items(): ## For 1024 these are the factors
            for hr in hid_dim:
                n_params = count_params(FMNIST_BlockMLP(block_size=dim, hidden_layers_ratio=[hr]))
                print(f"{dim}\t{hr}\t{n_params}\tBlockMLP")

            n_params = count_params(FMNIST_SparseMLP(block_size=dim))
            print(f"{dim}\t\t{n_params}\tSparseMLP")

            n_params = count_params(FMNIST_SparseMLP_PWLF(block_size=dim))
            print(f"{dim}\t\t{n_params}\tSparseMLP_PWLF")
            print()

        for gsz in pwlf_grid_size:
            n_params = count_params(FMNIST_PairBilinear(grid_width=gsz))
            print(f"{2}\t{gsz}\t{n_params}\tPairPWLF")

        print()
        n_params = count_params(FMNIST_OrdMLP())
        print(f"\t\t{n_params}\tOrdMLP")


benchmark_fmnist()

32	2	280586	BlockMLP
32	4	546826	BlockMLP
32	8	1079306	BlockMLP
32		77834	SparseMLP
32		105482	SparseMLP_PWLF

4	4	201738	BlockMLP
4	8	386058	BlockMLP
4	16	754698	BlockMLP
4		32778	SparseMLP
4		60426	SparseMLP_PWLF

2	4	227338	BlockMLP
2	8	432138	BlockMLP
2	16	841738	BlockMLP
2		32778	SparseMLP
2		60426	SparseMLP_PWLF

2	3	125962	PairPWLF
2	5	289802	PairPWLF
2	9	863242	PairPWLF

		1061898	OrdMLP


In [38]:
# model = FMNIST_OrdMLP()
# sum(p.numel() for p in model.parameters())

## Configuring training and saving functionality

In [41]:
# Sweep grid: each key is used as `block_size`, each value lists the
# `hidden_layers_ratio` entries tried with that block size for FMNIST_BlockMLP.
mlp_dims_scale = {
    32: [2, 4, 8],
    4: [4, 8, 16],
    2: [4, 8, 16],
}
# `grid_width` values tried for FMNIST_PairBilinear.
pwlf_grid_size = [3, 5, 9]

# Every configuration is trained once per seed.
SEEDS = [147, 258, 369]
# Epoch count and default learning rate handed to train_model().
EPOCHS = 50
LR = 0.001

def benchmark_fmnist():
    """Train every model configuration once for each seed in `SEEDS`.

    For each run the torch RNG is seeded, the model is built, and
    `train_model` is called with `EPOCHS` and the same seed. Reads the globals
    `SEEDS`, `mlp_dims_scale`, `pwlf_grid_size`, `EPOCHS` and `LR`.
    """
    def seeded_run(seed, build, model_name, lr):
        # Seed immediately before construction so torch-based weight
        # initialisation starts from the same RNG state for a given seed.
        torch.manual_seed(seed)
        train_model(build(), lr, model_name, EPOCHS, seed)

    for seed in SEEDS:
        ## First test MLP with allowed dimension mixing
        for dim, hid_dim in mlp_dims_scale.items(): ## For 1024 these are the factors
            for hr in hid_dim:
                seeded_run(
                    seed,
                    lambda: FMNIST_BlockMLP(block_size=dim, hidden_layers_ratio=[hr]),
                    f"fmnist_BlockMLP_b{dim}_h{hr}_s{seed}",
                    LR,
                )

            seeded_run(
                seed,
                lambda: FMNIST_SparseMLP(block_size=dim),
                f"fmnist_SparseMLP_b{dim}_s{seed}",
                LR,
            )
            seeded_run(
                seed,
                lambda: FMNIST_SparseMLP_PWLF(block_size=dim),
                f"fmnist_SparseMLP_PWLF_b{dim}_s{seed}",
                LR,
            )

        for gsz in pwlf_grid_size:
            # These models are trained with their own learning rate, not LR.
            seeded_run(
                seed,
                lambda: FMNIST_PairBilinear(grid_width=gsz),
                f"fmnist_PairPWLF_g{gsz}_s{seed}",
                0.00003,
            )

        seeded_run(seed, lambda: FMNIST_OrdMLP(), f"fmnist_OrdinaryMLP_s{seed}", LR)

In [42]:
benchmark_fmnist()

100%|█████████████████████████████████████████████████| 1200/1200 [00:15<00:00, 79.66it/s]


[Train] 28 Loss: 0.070 | Acc: 97.507 58504/60000


100%|██████████████████████████████████████████████████| 200/200 [00:01<00:00, 118.13it/s]


[Test] 28 Loss: 0.446 | Acc: 88.840 8884/10000


100%|█████████████████████████████████████████████████| 1200/1200 [00:18<00:00, 65.89it/s]


[Train] 29 Loss: 0.064 | Acc: 97.822 58693/60000


100%|██████████████████████████████████████████████████| 200/200 [00:01<00:00, 144.39it/s]


[Test] 29 Loss: 0.439 | Acc: 89.140 8914/10000


100%|█████████████████████████████████████████████████| 1200/1200 [00:18<00:00, 64.73it/s]


[Train] 30 Loss: 0.057 | Acc: 98.070 58842/60000


100%|██████████████████████████████████████████████████| 200/200 [00:01<00:00, 131.93it/s]


[Test] 30 Loss: 0.460 | Acc: 89.010 8901/10000


100%|█████████████████████████████████████████████████| 1200/1200 [00:18<00:00, 65.42it/s]


[Train] 31 Loss: 0.051 | Acc: 98.348 59009/60000


100%|██████████████████████████████████████████████████| 200/200 [00:01<00:00, 138.95it/s]


[Test] 31 Loss: 0.468 | Acc: 88.920 8892/10000


100%|█████████████████████████████████████████████████| 1200/1200 [00:17<00:00, 69.01it/s]


[Train] 32 Loss: 0.045 | Acc: 98.592 59155/60000


100%|██████████████████████████████████████████████████| 200/200 [00:01<00:00, 145.92it/s]


[Test] 32 Loss: 0.474 | Acc: 89.060 8906/10000


100%|█████████████████████████████████████████████████| 1200/1200 [00:15<00:00, 78.01it/s]


[Train] 33 Loss: 0.039 | Acc: 98.828 59297/60000


100%|██████████████████████████████████████████████████| 200/200 [00:01<00:00, 147.39it/s]


[Test] 33 Loss: 0.486 | Acc: 89.140 8914/10000


100%|█████████████████████████████████████████████████| 1200/1200 [00:17<00:00, 68.49it/s]


[Train] 34 Loss: 0.034 | Acc: 98.997 59398/60000


100%|██████████████████████████████████████████████████| 200/200 [00:01<00:00, 162.53it/s]


[Test] 34 Loss: 0.503 | Acc: 88.940 8894/10000


100%|█████████████████████████████████████████████████| 1200/1200 [00:16<00:00, 70.83it/s]


[Train] 35 Loss: 0.029 | Acc: 99.197 59518/60000


100%|██████████████████████████████████████████████████| 200/200 [00:01<00:00, 185.16it/s]


[Test] 35 Loss: 0.522 | Acc: 88.760 8876/10000


100%|█████████████████████████████████████████████████| 1200/1200 [00:16<00:00, 74.34it/s]


[Train] 36 Loss: 0.025 | Acc: 99.367 59620/60000


100%|██████████████████████████████████████████████████| 200/200 [00:01<00:00, 145.17it/s]


[Test] 36 Loss: 0.531 | Acc: 88.820 8882/10000


 54%|███████████████████████████▏                      | 653/1200 [00:09<00:07, 71.28it/s]

KeyboardInterrupt

