In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms

from tqdm import tqdm
import random, time, os, sys

In [3]:
import sparse_nonlinear_lib as snl

In [4]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs top-to-bottom on machines without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## For FMNIST dataset

In [5]:
# Both splits only convert the images to tensors; no augmentation or normalisation is applied.
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])

# Shared on-disk location of the Fashion-MNIST files (downloaded there on first use).
FMNIST_ROOT = "../../../../_Datasets/FMNIST/"

train_dataset = datasets.FashionMNIST(
    root=FMNIST_ROOT,
    train=True,
    download=True,
    transform=train_transform,
)
test_dataset = datasets.FashionMNIST(
    root=FMNIST_ROOT,
    train=False,
    download=True,
    transform=test_transform,
)

In [6]:
# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=50, num_workers=2)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

In [7]:
## demo of train loader
# Pull a single batch to inspect its shape. The built-in next() is used because
# the `.next()` method is not part of the Python 3 iterator protocol.
xx, yy = next(iter(train_loader))
xx.shape

torch.Size([50, 1, 28, 28])

# Model Comparison

## Pair Linear Mixing

In [8]:
def img2patch(x, input_dim=(1, 28, 28), patch_size=(7, 4)):
    """Cut a batch of images into non-overlapping patches and flatten each patch.

    The stride equals the kernel size, so patches tile the image without overlap.

    Args:
        x: image batch passed straight to ``nn.functional.unfold``.
        input_dim: accepted for symmetry with ``patch2img``; not used here.
        patch_size: (height, width) of each patch.

    Returns:
        The unfolded tensor; for a (2, 1, 28, 28) input and the default
        patch size this has shape (2, 28, 28).
    """
    unfold_kwargs = dict(kernel_size=patch_size, stride=patch_size)
    return nn.functional.unfold(x, **unfold_kwargs)

In [9]:
def patch2img(x, patch_size=(7, 4), input_dim=(1, 28, 28)):
    """Reassemble flattened non-overlapping patches into images (inverse of img2patch).

    Args:
        x: patch tensor in the layout produced by ``nn.functional.unfold``.
        patch_size: (height, width) of each patch; also used as the stride.
        input_dim: target image dimensions; only the last two entries
            (height, width) are read.

    Returns:
        The folded tensor; for a (2, 28, 28) input and the defaults this has
        shape (2, 1, 28, 28).
    """
    height, width = input_dim[-2], input_dim[-1]
    return nn.functional.fold(x, (height, width), kernel_size=patch_size, stride=patch_size)

In [10]:
# Sanity check: unfolding two 1x28x28 images with the default 7x4 patches.
img2patch(torch.randn(2, 1, 28, 28)).shape

torch.Size([2, 28, 28])

In [11]:
# Sanity check: folding (2, 28, 28) patch tensors back into 1x28x28 images.
patch2img(torch.randn(2, 7* 4, 7*4)).shape

torch.Size([2, 1, 28, 28])

In [12]:
# Flattened sizes being compared: 28*28 = 784 versus 32*32 = 1024.
(7*4)**2, (8*4)**2

(784, 1024)

1. Linearize by expanding the dimension of the folded image.

## Final Model

In [13]:
class FMNIST_BlockMLP(nn.Module):
    """Patch embedding followed by an ``snl.BlockMLP_MixerBlock`` and a 10-way linear head.

    Each flattened patch is projected to 8*8 = 64 features. With the default
    28x28 image and 7x7 patches there are 16 patches, so the flattened
    embedding has 16 * 64 = 1024 = 8*4*8*4 features, which is the width the
    mixer block and the classifier are built for.
    """

    def __init__(self, img_size=(1, 28, 28), patch_size=(7, 7)):
        super().__init__()
        patch_features = img_size[0] * patch_size[0] * patch_size[1]
        mixer_width = 8*4*8*4

        self.patch_size = patch_size
        self.patch_linear = nn.Linear(patch_features, 8*8)
        self.block_mlp = snl.BlockMLP_MixerBlock(mixer_width, 8*4, hidden_layers_ratio=[4])
        self.fc = nn.Linear(mixer_width, 10)

    def forward(self, x):
        """Return class logits for an image batch ``x``."""
        batch = x.shape[0]
        patches = img2patch(x, input_dim=(1, 28, 28), patch_size=self.patch_size)
        # unfold yields (batch, patch_features, n_patches); the linear layer
        # acts on the last axis, so move the patch features there first.
        tokens = self.patch_linear(patches.transpose(-2, -1))
        flat = tokens.reshape(batch, -1)
        return self.fc(self.block_mlp(flat))

In [14]:
class FMNIST_OrdMLP(nn.Module):
    """Baseline dense network: flatten -> Linear -> LayerNorm -> ELU -> Linear(10)."""

    def __init__(self, img_size=(1, 28, 28)):
        super().__init__()
        hidden_width = 512  # other widths noted in the original: 354, 1024
        self.input_dim = np.prod(img_size)
        self.l0 = nn.Linear(self.input_dim, hidden_width)
        self.l1 = nn.Linear(hidden_width, 10)
        self.elu = nn.ELU()
        self.ln = nn.LayerNorm(hidden_width)

    def forward(self, x):
        """Return class logits for an image batch ``x``."""
        flat = x.reshape(x.shape[0], -1)
        hidden = self.elu(self.ln(self.l0(flat)))
        return self.l1(hidden)

In [37]:
class FMNIST_PairBilinear(nn.Module):
    """Patch embedding followed by an ``snl.PairBilinear_MixerBlock`` and a 10-way linear head.

    Each flattened patch is projected to 8*8 = 64 features. With the default
    28x28 image and 7x7 patches there are 16 patches, giving a flattened
    embedding of 16 * 64 = 1024 = 8*4*8*4 features. That vector is
    layer-normalised, passed through the mixer block, batch-normalised and
    classified.
    """

    def __init__(self, img_size=(1, 28, 28), patch_size=(7, 7)):
        super().__init__()
        self.patch_size = patch_size
        self.patch_linear = nn.Linear(img_size[0]*patch_size[0]*patch_size[1], 8*8)
        self.ln = nn.LayerNorm(8*4*8*4)
        self.bn = nn.BatchNorm1d(8*4*8*4)

        # Fix: the original line ended with a dangling "." (a SyntaxError).
        self.block_func = snl.PairBilinear_MixerBlock(8*4*8*4, 8*4*8*4, grid_width=10)
        self.fc = nn.Linear(8*4*8*4, 10)

    def forward(self, x):
        """Return class logits for an image batch ``x``."""
        bs = x.shape[0]
        x = img2patch(x, input_dim=(1, 28, 28), patch_size=self.patch_size)
        # unfold yields (batch, patch_features, n_patches); the linear layer
        # acts on the last axis, so move the patch features there first.
        x = self.patch_linear(x.transpose(-2, -1)).reshape(bs, -1)
        x = self.ln(x)
        x = self.block_func(x)
        x = self.bn(x)
        x = self.fc(x)
        return x

In [66]:
# Build a stand-alone mixer block, cast it to float64 and move it to the
# target device, then list every parameter's shape, dtype and device.
block_func = snl.PairBilinear_MixerBlock(8*4*8*4, 8*4*8*4, grid_width=10)
block_func = block_func.type(torch.double).to(device)
# block_func.pairwise_mixing[0].pairW.dtype
for param in block_func.parameters():
    print(param.shape, param.dtype, param.device)

torch.Size([1024]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0


In [68]:
# Feed a float64 batch through the float64-cast mixer block.
# NOTE(review): the recorded run of this cell raised
# "RuntimeError: expected scalar type Double but found Float"; presumably some
# tensor created inside snl stays float32 -- TODO confirm in sparse_nonlinear_lib.
x = torch.randn(2, 1024).double().to(device)
block_func(x)
# x.dtype

RuntimeError: expected scalar type Double but found Float

In [38]:
# Pick exactly one architecture to train; the alternatives are kept for quick switching.
# model = FMNIST_BlockMLP(img_size=(1, 28, 28), patch_size=(7, 7))
# model = FMNIST_OrdMLP(img_size=(1, 28, 28))
model = FMNIST_PairBilinear(img_size=(1, 28, 28), patch_size=(7, 7)).to(device)

In [39]:
# Display the module tree of the selected model.
model

FMNIST_PairBilinear(
  (patch_linear): Linear(in_features=49, out_features=64, bias=True)
  (ln): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (block_func): PairBilinear_MixerBlock(
    (selector): BiasLayer: [1024]
    (pairwise_mixing): ModuleList(
      (0): PairLinear: [1024 -> 1024] (grid: 10)
      (1): PairLinear: [1024 -> 1024] (grid: 10)
      (2): PairLinear: [1024 -> 1024] (grid: 10)
      (3): PairLinear: [1024 -> 1024] (grid: 10)
      (4): PairLinear: [1024 -> 1024] (grid: 10)
      (5): PairLinear: [1024 -> 1024] (grid: 10)
      (6): PairLinear: [1024 -> 1024] (grid: 10)
      (7): PairLinear: [1024 -> 1024] (grid: 10)
      (8): PairLinear: [1024 -> 1024] (grid: 10)
      (9): PairLinear: [1024 -> 1024] (grid: 10)
    )
    (reducer): Identity()
  )
  (fc): Linear(in_features=1024, out_features=10, bias=True)
)

In [40]:
# Smoke test: two random 1x28x28 images should produce logits of shape (2, 10).
model(torch.randn(2, 1, 28, 28).to(device)).shape

torch.Size([2, 10])

In [41]:
# Total number of scalar parameters in the selected model.
n_params = sum(weight.numel() for weight in model.parameters())
print("number of params: ", n_params)

number of params:  1063050


In [42]:
## BLOCK MLP (1024/64) number of params:  281738
## ORDINARY MLP (784, 1024) number of params:  814090
## PAIR BILINEAR has 1058954 (1063050 including the LayerNorm and BatchNorm affine parameters, 2 * 2048)

In [43]:
# Presumably the weight count of a dense 784 -> 784 layer, for comparison -- TODO confirm intent.
784*784

614656

## Training

In [44]:
 ## debugging to find the good classifier/output distribution.
# model_name is used as the checkpoint file stem: ./models/{model_name}.pth
# Keep it in sync with the architecture selected when `model` was built.
# model_name = 'block_mlp_mixer_fmnist_v0'
# model_name = 'ord_mlp_mixer_fmnist_v0'
model_name = 'pair_bilinear_mixer_fmnist_v0'

In [45]:
# Training hyper-parameters: Adam at lr 1e-3 with a cosine schedule whose
# period equals the number of epochs (the scheduler is stepped once per epoch).
EPOCHS = 20
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [46]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print loss and accuracy.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``, ``device``
    and ``train_loader``.

    Args:
        epoch: epoch index, used only in the printed summary.

    Returns:
        None.
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0  # explicit counter, so an empty loader cannot leave it unbound
    for inputs, targets in tqdm(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    if total == 0:
        # Nothing was seen; avoid dividing by zero in the summary below.
        print(f"[Train] {epoch} no samples in train_loader; nothing to do")
        return
    print(f"[Train] {epoch} Loss: {train_loss/num_batches:.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    return

In [47]:
best_acc = -1
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and checkpoint it when accuracy improves.

    Uses the notebook-level ``model``, ``criterion``, ``device``,
    ``test_loader`` and ``model_name``. Updates the global ``best_acc``.

    Args:
        epoch: epoch index, printed in the summary and stored in the checkpoint.

    Returns:
        None. When accuracy beats ``best_acc`` a dict with keys ``'model'``,
        ``'acc'`` and ``'epoch'`` is written to ``./models/{model_name}.pth``.
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0  # explicit counter, so an empty loader cannot leave it unbound
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    if total == 0:
        # No evidence to compare against best_acc, so do not checkpoint.
        print(f"[Test] {epoch} no samples in test_loader; skipping checkpoint")
        return

    print(f"[Test] {epoch} Loss: {test_loss/num_batches:.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('models', exist_ok=True)
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

In [48]:
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU can still be restored on a CPU-only machine.
    checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

In [49]:
### Train the whole damn thing

# One train pass, one evaluation (which may checkpoint) and one LR-schedule
# step per epoch.
for epoch in range(start_epoch, start_epoch+EPOCHS): ## runs EPOCHS epochs, starting at start_epoch
    train(epoch)
    test(epoch)
    scheduler.step()

100%|████████████████████████████████████████████████| 1200/1200 [00:06<00:00, 189.93it/s]


[Train] 0 Loss: nan | Acc: 10.762 6457/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 411.71it/s]


[Test] 0 Loss: nan | Acc: 10.000 1000/10000
Saving..


 24%|███████████▉                                     | 291/1200 [00:01<00:05, 180.45it/s]


KeyboardInterrupt: 

In [None]:
# Best test accuracy (in percent) recorded so far.
best_acc

In [None]:
# Re-read the best checkpoint from disk. map_location remaps the stored
# tensors onto the active device, so a GPU-written file also loads on CPU.
checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

best_acc, start_epoch

In [None]:
# Restore the checkpointed weights into the live model.
model.load_state_dict(checkpoint['model'])

In [None]:
# Display the module tree of the model after restoring the weights.
model