In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms

from tqdm import tqdm
import random, time, os, sys

In [3]:
import sparse_nonlinear_lib as snl

In [4]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## For FMNIST dataset

In [5]:
# Single place to change where FashionMNIST is stored/downloaded.
DATA_ROOT = "../../../../_Datasets/FMNIST/"

# Both splits only convert images to tensors (no augmentation, no normalization).
train_transform = transforms.Compose([
    transforms.ToTensor(),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.FashionMNIST(root=DATA_ROOT, train=True, download=True, transform=train_transform)
test_dataset = datasets.FashionMNIST(root=DATA_ROOT, train=False, download=True, transform=test_transform)

In [6]:
# Settings shared by both loaders; only the training split is shuffled.
loader_kwargs = dict(batch_size=50, num_workers=2)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

In [7]:
## demo of train loader: pull one batch to inspect the tensor shapes.
# Use the builtin next(): DataLoader iterators no longer provide a `.next()`
# method in recent PyTorch releases.
xx, yy = next(iter(train_loader))
xx.shape

torch.Size([50, 1, 28, 28])

# Model Comparison

## Pair Linear Mixing

In [8]:
def img2patch(x, input_dim=(1, 28, 28), patch_size=(7, 4)):
    """Cut a batch of images into non-overlapping patches.

    Kernel size and stride are both ``patch_size``, so the patches tile the
    image without overlap. ``input_dim`` is accepted for symmetry with
    ``patch2img`` but is not used here.

    Args:
        x: image batch accepted by ``nn.functional.unfold``
            (presumably (batch, C, H, W)).
        input_dim: unused.
        patch_size: (height, width) of each patch.

    Returns:
        The ``unfold`` result: (batch, C * patch_h * patch_w, num_patches).
    """
    kernel = stride = patch_size
    return nn.functional.unfold(x, kernel_size=kernel, stride=stride)

In [9]:
def patch2img(x, patch_size=(7, 4), input_dim=(1, 28, 28)):
    """Reassemble non-overlapping patches into images (inverse of ``img2patch``).

    Args:
        x: patch tensor in the layout produced by ``nn.functional.unfold``.
        patch_size: (height, width) of each patch; also used as the stride.
        input_dim: target image shape; only the last two entries (H, W) are used.

    Returns:
        The ``fold`` result with spatial size (H, W) taken from ``input_dim``.
    """
    *_, height, width = input_dim
    return nn.functional.fold(x, output_size=(height, width),
                              kernel_size=patch_size, stride=patch_size)

1. Linearize by expanding the dimension of the folded image.

## Final Model

In [241]:
class FMNIST_BlockMLP(nn.Module):
    """FashionMNIST classifier built around ``snl.BlockMLP_MixerBlock``.

    Pipeline: flatten -> ``snl.DimensionSelector`` (constructed with
    prod(img_size) inputs and ``select`` outputs) -> block-MLP mixer ->
    BatchNorm1d -> ELU -> Linear(select, 10).

    Args:
        img_size: per-sample input shape; its product is the selector's input width.
        select: feature width after the selector, used by every later layer.
        block_size: block size forwarded to ``snl.BlockMLP_MixerBlock``.
        hidden_layers_ratio: hidden-layer ratios forwarded to the mixer block.
            ``None`` (the default) means ``[4]``. A fresh list is built per
            instance, so models never share one mutable default list.
        actf: activation class forwarded to the mixer block. NOTE: the
            activation applied after BatchNorm is always ``nn.ELU``.
    """

    def __init__(self, img_size=(1, 28, 28), select=1024, block_size=2, hidden_layers_ratio=None, actf=nn.ELU):
        super().__init__()
        if hidden_layers_ratio is None:
            hidden_layers_ratio = [4]
        self.dim_sel = snl.DimensionSelector(np.prod(img_size), select)

        self.block_mlp = snl.BlockMLP_MixerBlock(select, block_size,
                                                 hidden_layers_ratio=hidden_layers_ratio, actf=actf)
        self.bn = nn.BatchNorm1d(select)
        self.actf = nn.ELU()
        self.fc = nn.Linear(select, 10)

    def forward(self, x):
        """Flatten all non-batch dims of ``x`` and return 10 class logits per sample."""
        bs = x.shape[0]
        x = x.reshape(bs, -1)
        x = self.dim_sel(x)
        x = self.block_mlp(x)
        x = self.bn(x)
        x = self.actf(x)
        x = self.fc(x)
        return x

In [255]:
model = FMNIST_BlockMLP(block_size=32, hidden_layers_ratio=[16])
# Total number of elements across all parameter tensors.
print("number of params: ", sum(map(torch.numel, model.parameters())))

number of params:  2144266


In [249]:
##### USING Hidden_Dim to change PARAMETERS
# class FMNIST_OrdMLP(nn.Module):
    
#     def __init__(self, img_size=(1, 28, 28), hidden_size=1024):
#         super().__init__()
#         self.input_dim = np.prod(img_size)
#         M = hidden_size #354 #1024
#         self.l0 = nn.Linear(self.input_dim, M)
#         self.bn = nn.BatchNorm1d(M)
#         self.elu = nn.ELU()
#         self.l1 = nn.Linear(M, 10)
        
#     def forward(self, x):
#         bs = x.shape[0]
#         x = x.reshape(bs, -1)
#         x = self.l0(x)
#         x = self.bn(x)
#         x = self.elu(x)
#         x = self.l1(x)
#         return x

# model = FMNIST_OrdMLP(hidden_size=688)

# Dense baseline that also goes through DimensionSelector, so its input width
# (`select`) matches the other models in this comparison.

class FMNIST_OrdMLP(nn.Module):
    """Dense MLP baseline for FashionMNIST.

    Pipeline: flatten -> ``snl.DimensionSelector`` -> Linear(select, select)
    -> BatchNorm1d -> ELU -> Linear(select, 10).

    Args:
        img_size: per-sample input shape; its product is the selector's input width.
        select: feature width after the selector and of the hidden layer.
    """

    def __init__(self, img_size=(1, 28, 28), select=1024):
        super().__init__()
        self.input_dim = np.prod(img_size)
        self.dim_sel = snl.DimensionSelector(np.prod(img_size), select)

        self.l0 = nn.Linear(select, select)
        self.bn = nn.BatchNorm1d(select)
        self.actf = nn.ELU()
        self.l1 = nn.Linear(select, 10)

    def forward(self, x):
        """Flatten all non-batch dims of ``x`` and return 10 class logits per sample."""
        bs = x.shape[0]
        x = x.reshape(bs, -1)
        # Map the flattened image to `select` features. l0 expects exactly
        # `select` inputs, so skipping this step fails with a shape mismatch
        # whenever select != prod(img_size) (e.g. 1024 vs 784 by default).
        x = self.dim_sel(x)
        x = self.l0(x)
        x = self.bn(x)
        x = self.actf(x)
        x = self.l1(x)
        return x

In [250]:
model = FMNIST_OrdMLP(select=1024)
n_params = sum(param.numel() for param in model.parameters())
print("number of params: ", n_params)

number of params:  1061898


In [214]:
class FMNIST_SparseMLP(nn.Module):
    """FashionMNIST classifier with a block-sparse linear mixing layer.

    Stages, applied in order: ``snl.DimensionSelector`` ->
    ``snl.BlockLinear_MixerBlock`` -> BatchNorm1d -> ELU -> Linear(select, 10).

    Args:
        img_size: per-sample input shape; its product is the selector's input width.
        select: feature width after the selector.
        block_size: block size forwarded to ``snl.BlockLinear_MixerBlock``.
    """

    def __init__(self, img_size=(1, 28, 28), select=1024, block_size=2):
        super().__init__()
        flat_dim = np.prod(img_size)
        # Submodules are registered in a fixed order so state_dict keys and
        # parameter ordering stay stable.
        self.dim_sel = snl.DimensionSelector(flat_dim, select)
        self.l0 = snl.BlockLinear_MixerBlock(select, block_size)
        self.bn = nn.BatchNorm1d(select)
        self.actf = nn.ELU()
        self.l1 = nn.Linear(select, 10)

    def forward(self, x):
        """Flatten all non-batch dims of ``x`` and run it through every stage."""
        out = x.reshape(x.shape[0], -1)
        for stage in (self.dim_sel, self.l0, self.bn, self.actf, self.l1):
            out = stage(out)
        return out

In [215]:
model = FMNIST_SparseMLP(block_size=2)
model = model.to(device)
print("number of params: ", sum(map(torch.numel, model.parameters())))

number of params:  32778


In [216]:
# Smoke-test the forward pass on a random batch of two 1x28x28 images.
# eval() + no_grad() stop this check from updating BatchNorm running stats or
# building an autograd graph; the previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        out_shape = model(torch.randn(2, 1, 28, 28).to(device)).shape
finally:
    model.train(was_training)
out_shape

torch.Size([2, 10])

In [217]:
class FMNIST_PairBilinear(nn.Module):
    """FashionMNIST classifier built on ``snl.PairBilinear_MixerBlock``.

    Stages, applied in order: ``snl.DimensionSelector`` -> pair-bilinear mixer
    (``select`` -> ``select``) -> BatchNorm1d -> ELU -> Linear(select, 10).

    Args:
        img_size: per-sample input shape; its product is the selector's input width.
        select: feature width after the selector and of the mixer output.
        grid_width: forwarded to ``snl.PairBilinear_MixerBlock`` as ``grid_width``.
    """

    def __init__(self, img_size=(1, 28, 28), select=1024, grid_width=5):
        super().__init__()
        flat_dim = np.prod(img_size)
        # Submodules are registered in a fixed order so state_dict keys and
        # parameter ordering stay stable.
        self.dim_sel = snl.DimensionSelector(flat_dim, select)
        self.block_func = snl.PairBilinear_MixerBlock(select, select, grid_width=grid_width)
        self.bn = nn.BatchNorm1d(select)
        self.actf = nn.ELU()
        self.fc = nn.Linear(select, 10)

    def forward(self, x):
        """Flatten all non-batch dims of ``x`` and run it through every stage."""
        out = x.reshape(x.shape[0], -1)
        for stage in (self.dim_sel, self.block_func, self.bn, self.actf, self.fc):
            out = stage(out)
        return out

In [218]:
model = FMNIST_PairBilinear(grid_width=10)
total_params = sum(t.numel() for t in model.parameters())
print("number of params: ", total_params)

number of params:  1057802


## Create Models

In [219]:
# model = FMNIST_BlockMLP(img_size=(1, 28, 28), patch_size=(7, 7))
# model = FMNIST_OrdMLP(img_size=(1, 28, 28))
# model = FMNIST_PairBilinear(img_size=(1, 28, 28), grid_width=7)



In [227]:
# nn.Module.to() moves parameters and buffers and returns the module itself.
model = model.to(device=device)
model

FMNIST_BlockMLP(
  (dim_sel): DimensionSelector: [+=240]
  (block_mlp): BlockMLP_MixerBlock(
    (facto_nets): ModuleList(
      (0): BlockMLP(
        (mlp): Sequential(
          (0): BlockLinear: [torch.Size([512, 2, 8])]
          (1): ELU(alpha=1.0)
          (2): BlockLinear: [torch.Size([512, 8, 2])]
        )
      )
      (1): BlockMLP(
        (mlp): Sequential(
          (0): BlockLinear: [torch.Size([512, 2, 8])]
          (1): ELU(alpha=1.0)
          (2): BlockLinear: [torch.Size([512, 8, 2])]
        )
      )
      (2): BlockMLP(
        (mlp): Sequential(
          (0): BlockLinear: [torch.Size([512, 2, 8])]
          (1): ELU(alpha=1.0)
          (2): BlockLinear: [torch.Size([512, 8, 2])]
        )
      )
      (3): BlockMLP(
        (mlp): Sequential(
          (0): BlockLinear: [torch.Size([512, 2, 8])]
          (1): ELU(alpha=1.0)
          (2): BlockLinear: [torch.Size([512, 8, 2])]
        )
      )
      (4): BlockMLP(
        (mlp): Sequential(
          (0)

In [228]:
# Smoke-test the forward pass on a random batch of two 1x28x28 images.
# eval() + no_grad() stop this check from updating BatchNorm running stats or
# building an autograd graph; the previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        out_shape = model(torch.randn(2, 1, 28, 28).to(device)).shape
finally:
    model.train(was_training)
out_shape

torch.Size([2, 10])

In [229]:
param_count = sum(map(torch.numel, model.parameters()))
print("number of params: ", param_count)

number of params:  227338


## Training

In [230]:
 ## debugging to find the good classifier/output distribution.
# Checkpoint file stem used by the training cells (./models/<model_name>.pth).
# Other names used for earlier runs: 'block_mlp_mixer_fmnist_v0', 'ord_mlp_mixer_fmnist_v0'.
# NOTE(review): make sure this matches the class currently bound to `model`.
model_name = 'pair_bilinear_mixer_fmnist_v0'

In [236]:
EPOCHS = 50
criterion = nn.CrossEntropyLoss()

# Earlier trials used Adam with lr=1e-5 and lr=3e-5, and SGD with lr=1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Cosine learning-rate decay over EPOCHS scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [237]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    """Run one training epoch over ``train_loader`` and print mean loss / accuracy.

    Uses the notebook globals ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``.

    Args:
        epoch: epoch index, used only for logging.

    Returns:
        ``(mean_loss, accuracy_percent)`` for the epoch. Both are 0.0 when the
        loader yields no batches (that case previously crashed).
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in tqdm(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    # max(..., 1) guards the averages against an empty loader.
    mean_loss = train_loss / max(num_batches, 1)
    acc = 100. * correct / max(total, 1)
    print(f"[Train] {epoch} Loss: {mean_loss:.3f} | Acc: {acc:.3f} {correct}/{total}")
    return mean_loss, acc

In [238]:
best_acc = -1  # best test accuracy (%) so far; -1 guarantees the first evaluation is saved


def test(epoch):
    """Evaluate ``model`` on ``test_loader``, print metrics, and checkpoint on improvement.

    Uses the notebook globals ``model``, ``test_loader``, ``criterion``,
    ``device`` and ``model_name``, and updates the global ``best_acc``.
    When accuracy beats ``best_acc``, ``{'model', 'acc', 'epoch'}`` is saved to
    ``./models/<model_name>.pth``.

    Args:
        epoch: epoch index, used for logging and stored in the checkpoint.

    Returns:
        Test accuracy in percent, or ``None`` if ``test_loader`` yielded no samples.
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    if total == 0:
        # Nothing was evaluated: avoid dividing by zero and writing a bogus checkpoint.
        print(f"[Test] {epoch} test_loader yielded no samples; skipping")
        return None

    acc = 100. * correct / total
    print(f"[Test] {epoch} Loss: {test_loss/num_batches:.3f} | Acc: {acc:.3f} {correct}/{total}")

    # Save checkpoint.
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('models', exist_ok=True)
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc
    return acc

In [239]:
start_epoch = 0  # first epoch index of the training loop; advanced when resuming
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    # map_location lets a checkpoint written from a GPU run load on the current device.
    checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    # The checkpoint is written after epoch `checkpoint['epoch']` finished,
    # so continue with the next epoch instead of repeating it.
    # NOTE(review): optimizer/scheduler state is not checkpointed, so the LR
    # schedule restarts from the beginning on resume.
    start_epoch = checkpoint['epoch'] + 1

In [240]:
### Main loop: train, evaluate (checkpointing on improvement), then step the LR schedule.
end_epoch = start_epoch + EPOCHS
for epoch in range(start_epoch, end_epoch):  # EPOCHS epochs in total
    train(epoch)
    test(epoch)
    scheduler.step()

100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 155.39it/s]


[Train] 0 Loss: 0.447 | Acc: 83.993 50396/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 386.59it/s]


[Test] 0 Loss: 0.420 | Acc: 84.980 8498/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.69it/s]


[Train] 1 Loss: 0.371 | Acc: 86.605 51963/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 388.20it/s]


[Test] 1 Loss: 0.394 | Acc: 86.120 8612/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 154.77it/s]


[Train] 2 Loss: 0.342 | Acc: 87.503 52502/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 382.70it/s]


[Test] 2 Loss: 0.370 | Acc: 86.860 8686/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 155.28it/s]


[Train] 3 Loss: 0.322 | Acc: 88.285 52971/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 379.02it/s]


[Test] 3 Loss: 0.351 | Acc: 87.730 8773/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 154.62it/s]


[Train] 4 Loss: 0.309 | Acc: 88.830 53298/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 381.28it/s]


[Test] 4 Loss: 0.336 | Acc: 87.970 8797/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 154.61it/s]


[Train] 5 Loss: 0.293 | Acc: 89.208 53525/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 382.39it/s]


[Test] 5 Loss: 0.341 | Acc: 88.110 8811/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.51it/s]


[Train] 6 Loss: 0.280 | Acc: 89.660 53796/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 396.18it/s]


[Test] 6 Loss: 0.339 | Acc: 87.950 8795/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 155.74it/s]


[Train] 7 Loss: 0.270 | Acc: 90.058 54035/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 384.78it/s]


[Test] 7 Loss: 0.327 | Acc: 88.800 8880/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 154.19it/s]


[Train] 8 Loss: 0.264 | Acc: 90.218 54131/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 376.76it/s]


[Test] 8 Loss: 0.353 | Acc: 87.870 8787/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 158.67it/s]


[Train] 9 Loss: 0.254 | Acc: 90.518 54311/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 375.54it/s]


[Test] 9 Loss: 0.326 | Acc: 88.660 8866/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.93it/s]


[Train] 10 Loss: 0.248 | Acc: 90.802 54481/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 385.05it/s]


[Test] 10 Loss: 0.325 | Acc: 88.670 8867/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 159.03it/s]


[Train] 11 Loss: 0.236 | Acc: 91.245 54747/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 388.08it/s]


[Test] 11 Loss: 0.314 | Acc: 88.900 8890/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.29it/s]


[Train] 12 Loss: 0.232 | Acc: 91.377 54826/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 383.04it/s]


[Test] 12 Loss: 0.330 | Acc: 88.600 8860/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 156.96it/s]


[Train] 13 Loss: 0.222 | Acc: 91.675 55005/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 387.50it/s]


[Test] 13 Loss: 0.330 | Acc: 88.780 8878/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.56it/s]


[Train] 14 Loss: 0.216 | Acc: 91.930 55158/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 388.16it/s]


[Test] 14 Loss: 0.348 | Acc: 88.540 8854/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.27it/s]


[Train] 15 Loss: 0.209 | Acc: 92.195 55317/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 382.29it/s]


[Test] 15 Loss: 0.318 | Acc: 89.390 8939/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 154.32it/s]


[Train] 16 Loss: 0.200 | Acc: 92.595 55557/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 381.98it/s]


[Test] 16 Loss: 0.317 | Acc: 89.090 8909/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.82it/s]


[Train] 17 Loss: 0.195 | Acc: 92.770 55662/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 398.42it/s]


[Test] 17 Loss: 0.339 | Acc: 88.540 8854/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 153.90it/s]


[Train] 18 Loss: 0.189 | Acc: 92.922 55753/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 387.39it/s]


[Test] 18 Loss: 0.327 | Acc: 89.480 8948/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 156.56it/s]


[Train] 19 Loss: 0.181 | Acc: 93.267 55960/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 388.19it/s]


[Test] 19 Loss: 0.334 | Acc: 89.310 8931/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 153.95it/s]


[Train] 20 Loss: 0.175 | Acc: 93.565 56139/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 381.70it/s]


[Test] 20 Loss: 0.329 | Acc: 89.330 8933/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 154.39it/s]


[Train] 21 Loss: 0.166 | Acc: 93.788 56273/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 382.08it/s]


[Test] 21 Loss: 0.328 | Acc: 89.230 8923/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 154.46it/s]


[Train] 22 Loss: 0.160 | Acc: 94.052 56431/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 383.00it/s]


[Test] 22 Loss: 0.333 | Acc: 89.280 8928/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 156.28it/s]


[Train] 23 Loss: 0.153 | Acc: 94.257 56554/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 381.45it/s]


[Test] 23 Loss: 0.329 | Acc: 89.750 8975/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 153.48it/s]


[Train] 24 Loss: 0.146 | Acc: 94.508 56705/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 373.32it/s]


[Test] 24 Loss: 0.356 | Acc: 89.500 8950/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 154.62it/s]


[Train] 25 Loss: 0.141 | Acc: 94.757 56854/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 383.84it/s]


[Test] 25 Loss: 0.345 | Acc: 89.610 8961/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.50it/s]


[Train] 26 Loss: 0.135 | Acc: 94.880 56928/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 393.04it/s]


[Test] 26 Loss: 0.351 | Acc: 89.540 8954/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 156.99it/s]


[Train] 27 Loss: 0.130 | Acc: 95.207 57124/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 384.11it/s]


[Test] 27 Loss: 0.344 | Acc: 89.770 8977/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.19it/s]


[Train] 28 Loss: 0.120 | Acc: 95.677 57406/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 389.72it/s]


[Test] 28 Loss: 0.360 | Acc: 89.230 8923/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 155.68it/s]


[Train] 29 Loss: 0.116 | Acc: 95.780 57468/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 386.16it/s]


[Test] 29 Loss: 0.369 | Acc: 89.540 8954/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 156.80it/s]


[Train] 30 Loss: 0.109 | Acc: 95.995 57597/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 391.80it/s]


[Test] 30 Loss: 0.365 | Acc: 89.660 8966/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.34it/s]


[Train] 31 Loss: 0.104 | Acc: 96.280 57768/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 385.74it/s]


[Test] 31 Loss: 0.365 | Acc: 89.680 8968/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.51it/s]


[Train] 32 Loss: 0.098 | Acc: 96.382 57829/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 384.35it/s]


[Test] 32 Loss: 0.373 | Acc: 89.570 8957/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 155.06it/s]


[Train] 33 Loss: 0.093 | Acc: 96.653 57992/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 386.74it/s]


[Test] 33 Loss: 0.368 | Acc: 89.680 8968/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 154.79it/s]


[Train] 34 Loss: 0.089 | Acc: 96.800 58080/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 385.38it/s]


[Test] 34 Loss: 0.377 | Acc: 89.550 8955/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 159.57it/s]


[Train] 35 Loss: 0.083 | Acc: 97.143 58286/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 388.82it/s]


[Test] 35 Loss: 0.383 | Acc: 89.450 8945/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.48it/s]


[Train] 36 Loss: 0.078 | Acc: 97.297 58378/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 375.30it/s]


[Test] 36 Loss: 0.384 | Acc: 89.520 8952/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 156.33it/s]


[Train] 37 Loss: 0.074 | Acc: 97.532 58519/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 388.22it/s]


[Test] 37 Loss: 0.389 | Acc: 89.490 8949/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 158.27it/s]


[Train] 38 Loss: 0.071 | Acc: 97.603 58562/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 385.89it/s]


[Test] 38 Loss: 0.388 | Acc: 89.740 8974/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.77it/s]


[Train] 39 Loss: 0.067 | Acc: 97.850 58710/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 385.33it/s]


[Test] 39 Loss: 0.398 | Acc: 89.390 8939/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 153.82it/s]


[Train] 40 Loss: 0.064 | Acc: 97.953 58772/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 377.72it/s]


[Test] 40 Loss: 0.399 | Acc: 89.370 8937/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 153.48it/s]


[Train] 41 Loss: 0.061 | Acc: 98.035 58821/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 381.37it/s]


[Test] 41 Loss: 0.401 | Acc: 89.420 8942/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 156.78it/s]


[Train] 42 Loss: 0.060 | Acc: 98.112 58867/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 388.77it/s]


[Test] 42 Loss: 0.401 | Acc: 89.360 8936/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.65it/s]


[Train] 43 Loss: 0.057 | Acc: 98.212 58927/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 383.22it/s]


[Test] 43 Loss: 0.404 | Acc: 89.560 8956/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:07<00:00, 157.82it/s]


[Train] 44 Loss: 0.056 | Acc: 98.275 58965/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 377.29it/s]


[Test] 44 Loss: 0.409 | Acc: 89.280 8928/10000


 91%|███████████████████████████████████████████▊    | 1094/1200 [00:06<00:00, 158.78it/s]


KeyboardInterrupt: 

In [None]:
# Best test accuracy (%) currently held in `best_acc` (set by test() or from a loaded checkpoint).
best_acc

In [None]:
# Reload the best checkpoint saved under ./models/<model_name>.pth and show its
# stored accuracy and epoch. map_location lets a checkpoint written from a GPU
# run load on the current device.
checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

best_acc, start_epoch

In [49]:
# strict=True (the default) requires the checkpoint keys to match the model's exactly.
model.load_state_dict(checkpoint['model'], strict=True)

<All keys matched successfully>

In [50]:
# Display the module structure of whatever model object is currently bound to `model`.
model

FMNIST_PairBilinear(
  (patch_linear): Linear(in_features=49, out_features=64, bias=True)
  (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (block_func): PairBilinear_MixerBlock(
    (selector): BiasLayer: [1024]
    (pairwise_mixing): ModuleList(
      (0): PairLinear: [1024 -> 1024] (grid: 7)
      (1): PairLinear: [1024 -> 1024] (grid: 7)
      (2): PairLinear: [1024 -> 1024] (grid: 7)
      (3): PairLinear: [1024 -> 1024] (grid: 7)
      (4): PairLinear: [1024 -> 1024] (grid: 7)
      (5): PairLinear: [1024 -> 1024] (grid: 7)
      (6): PairLinear: [1024 -> 1024] (grid: 7)
      (7): PairLinear: [1024 -> 1024] (grid: 7)
      (8): PairLinear: [1024 -> 1024] (grid: 7)
      (9): PairLinear: [1024 -> 1024] (grid: 7)
    )
    (reducer): Identity()
  )
  (fc): Linear(in_features=1024, out_features=10, bias=True)
)

## Test possible dimensions