In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms

from tqdm import tqdm
import random, time, os, sys

In [3]:
import sparse_nonlinear_lib as snl

In [4]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so that every later `.to(device)` call still works on GPU-less machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## For FMNIST dataset

In [5]:
# Both splits use the same preprocessing: ToTensor() only, no augmentation.
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])

# Fashion-MNIST is downloaded into the shared dataset folder on first use.
train_dataset = datasets.FashionMNIST(
    root="../../../../_Datasets/FMNIST/",
    train=True,
    download=True,
    transform=train_transform,
)
test_dataset = datasets.FashionMNIST(
    root="../../../../_Datasets/FMNIST/",
    train=False,
    download=True,
    transform=test_transform,
)

In [6]:
# Mini-batches of 50 images; only the training split is reshuffled each epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=50,
    shuffle=True,
    num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=50,
    shuffle=False,
    num_workers=2,
)

In [7]:
## demo of train loader
# Use the builtin next(): DataLoader iterators implement __next__, while the
# Python-2 style `.next()` method is not available on current PyTorch versions.
xx, yy = next(iter(train_loader))
xx.shape

torch.Size([50, 1, 28, 28])

# Model Comparison

## Pair Linear Mixing

In [8]:
def img2patch(x, input_dim=(1, 28, 28), patch_size=(7, 4)):
    """Split a batch of images into non-overlapping patches.

    The stride equals the kernel size, so every pixel lands in exactly one
    patch when the image sides are multiples of ``patch_size``.

    Args:
        x: image batch accepted by ``nn.functional.unfold``
            (e.g. ``(2, 1, 28, 28)`` in this notebook).
        input_dim: accepted for symmetry with ``patch2img``; not used here.
        patch_size: ``(height, width)`` of each patch.

    Returns:
        The unfolded tensor; for a ``(2, 1, 28, 28)`` input and the default
        patch size this has shape ``(2, 28, 28)``
        (batch, values per patch, number of patches).
    """
    unfold = nn.functional.unfold
    return unfold(x, patch_size, stride=patch_size)

In [9]:
def patch2img(x, patch_size=(7, 4), input_dim=(1, 28, 28)):
    """Reassemble non-overlapping patches into images (inverse of ``img2patch``).

    Args:
        x: patch tensor as produced by ``img2patch``
            (e.g. ``(2, 28, 28)`` in this notebook).
        patch_size: ``(height, width)`` of each patch; also used as the stride.
        input_dim: image shape; only its last two entries (height, width)
            are read to size the output.

    Returns:
        The folded image batch, e.g. ``(2, 1, 28, 28)`` for the defaults.
    """
    height, width = input_dim[-2], input_dim[-1]
    return nn.functional.fold(
        x,
        output_size=(height, width),
        kernel_size=patch_size,
        stride=patch_size,
    )

In [10]:
img2patch(torch.randn(2, 1, 28, 28)).shape

torch.Size([2, 28, 28])

In [11]:
patch2img(torch.randn(2, 7* 4, 7*4)).shape

torch.Size([2, 1, 28, 28])

In [12]:
(7*4)**2, (8*4)**2

(784, 1024)

1. Linearize by expanding the dimension of folded image.

## Final Model

In [13]:
class FMNIST_BlockMLP(nn.Module):
    """Fashion-MNIST classifier: patch embedding -> block-MLP mixer -> BatchNorm -> linear head.

    Each patch is projected to 64 features; the per-patch features are then
    flattened into one vector per image. The downstream layers are fixed at
    width 1024, which matches the defaults (a 28x28 image cut into sixteen
    7x7 patches, 16 * 64 = 1024).

    Args:
        img_size: ``(channels, height, width)``; only the channel count is
            read, to size the patch projection.
        patch_size: ``(height, width)`` of the non-overlapping patches.
    """

    def __init__(self, img_size=(1, 28, 28), patch_size=(7, 7)):
        super().__init__()
        self.patch_size = patch_size

        values_per_patch = img_size[0] * patch_size[0] * patch_size[1]
        feature_dim = 8 * 4 * 8 * 4  # 1024

        self.patch_linear = nn.Linear(values_per_patch, 8 * 8)
        # Alternatives tried earlier:
        #   snl.DimensionSelector(8*8*4*4, 8*4*8*4) in front of the mixer
        #   snl.BlockMLP_MixerBlock(8*4*8*4, 8*4, hidden_layers_ratio=[4])
        self.block_mlp = snl.BlockMLP_MixerBlock(feature_dim, 2, hidden_layers_ratio=[10])
        self.bn = nn.BatchNorm1d(feature_dim)
        self.fc = nn.Linear(feature_dim, 10)

    def forward(self, x):
        """Return the 10 class logits for an image batch ``x``."""
        batch_size = x.shape[0]
        patches = img2patch(x, input_dim=(1, 28, 28), patch_size=self.patch_size)
        # unfold puts the patch index last; move it forward so the Linear
        # layer acts on the values of one patch at a time.
        embedded = self.patch_linear(patches.transpose(-2, -1))
        flat = embedded.reshape(batch_size, -1)
        mixed = self.block_mlp(flat)
        return self.fc(self.bn(mixed))

In [14]:
class FMNIST_OrdMLP(nn.Module):
    """Baseline MLP: flatten -> Linear -> LayerNorm -> ELU -> Linear(10).

    Args:
        img_size: ``(channels, height, width)`` of the input images; their
            product is the width of the flattened input.
        hidden_dim: width of the single hidden layer. Defaults to 1500
            (354 and 1024 were also tried in earlier experiments).
    """

    def __init__(self, img_size=(1, 28, 28), hidden_dim=1500):
        super().__init__()
        # np.prod returns a NumPy scalar; store a plain Python int so that
        # `input_dim` and `nn.Linear.in_features` are ordinary integers.
        self.input_dim = int(np.prod(img_size))
        self.l0 = nn.Linear(self.input_dim, hidden_dim)
        self.l1 = nn.Linear(hidden_dim, 10)
        self.elu = nn.ELU()
        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Return the 10 class logits for an image batch ``x``."""
        bs = x.shape[0]
        x = self.l0(x.reshape(bs, -1))  # flatten each image to one vector
        x = self.ln(x)
        x = self.elu(x)
        x = self.l1(x)
        return x

In [15]:
class FMNIST_PairBilinear(nn.Module):
    """Fashion-MNIST classifier: patch embedding -> pair-bilinear mixer -> BatchNorm -> linear head.

    Each patch is projected to 64 features and the per-patch features are
    flattened into one vector per image. The mixer, BatchNorm and head are
    fixed at width 1024, which matches the defaults (sixteen 7x7 patches of
    a 28x28 image, 16 * 64 = 1024).

    Args:
        img_size: ``(channels, height, width)``; only the channel count is
            read, to size the patch projection.
        patch_size: ``(height, width)`` of the non-overlapping patches.
        grid_width: forwarded unchanged to ``snl.PairBilinear_MixerBlock``.
    """

    def __init__(self, img_size=(1, 28, 28), patch_size=(7, 7), grid_width=5):
        super().__init__()
        self.patch_size = patch_size

        values_per_patch = img_size[0] * patch_size[0] * patch_size[1]
        feature_dim = 8 * 4 * 8 * 4  # 1024

        self.patch_linear = nn.Linear(values_per_patch, 8 * 8)
        # Alternatives tried earlier: snl.DimensionSelector(8*8*4*4, 8*4*8*4)
        # before the mixer, and nn.LayerNorm(8*4*8*4) in place of BatchNorm.
        self.bn = nn.BatchNorm1d(feature_dim)

        # A float64 variant (`.double()` on the mixer plus casts in forward)
        # was also tried; the mixer currently runs in the default dtype.
        self.block_func = snl.PairBilinear_MixerBlock(feature_dim, feature_dim, grid_width=grid_width)
        self.fc = nn.Linear(feature_dim, 10)

    def forward(self, x):
        """Return the 10 class logits for an image batch ``x``."""
        batch_size = x.shape[0]
        patches = img2patch(x, input_dim=(1, 28, 28), patch_size=self.patch_size)
        # Put the patch index before the per-patch values so the Linear
        # layer embeds one patch at a time, then flatten per image.
        embedded = self.patch_linear(patches.transpose(-2, -1))
        flat = embedded.reshape(batch_size, -1)
        mixed = self.block_func(flat)
        return self.fc(self.bn(mixed))

In [16]:
# Stand-alone float64 mixer block on `device`, used by the inspection cells below.
block_func = snl.PairBilinear_MixerBlock(8*4*8*4, 8*4*8*4, grid_width=10)
block_func = block_func.type(torch.double).to(device)
# block_func.pairwise_mixing[0].pairW.dtype
for param in block_func.parameters():
    print(param.shape, param.dtype, param.device)

torch.Size([1024]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0
torch.Size([512, 2, 2]) torch.float64 cuda:0
torch.Size([512, 2, 10, 10]) torch.float64 cuda:0


In [17]:
block_func(torch.randn(2, 1024).to(device).double())

tensor([[ 0.6331,  0.0728,  0.9186,  ..., -0.0287, -0.1448,  0.8178],
        [-1.3230,  0.3096, -0.4694,  ..., -0.3249,  1.9116,  0.0740]],
       device='cuda:0', dtype=torch.float64, grad_fn=<ViewBackward0>)

In [18]:
# a = torch.randn_like(block_func.pairwise_mixing[0].Y)
# Inspect the grid parameter of the classifier's mixer when a suitable `model`
# already exists; otherwise use the stand-alone `block_func`.
# Only the two expected failures are caught: `model` not defined yet
# (NameError) or a `model` variant without a `block_func` attribute
# (AttributeError). Anything else, including KeyboardInterrupt, propagates.
try:
    a = model.block_func.pairwise_mixing[0].Y
except (NameError, AttributeError):
    a = block_func.pairwise_mixing[0].Y

In [19]:
a

Parameter containing:
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ...,

In [20]:
init0 = a[:,:,:,:1]
init1 = a[:,:,:1,:]
init0.shape, init1.shape

(torch.Size([512, 2, 10, 1]), torch.Size([512, 2, 1, 10]))

In [21]:
a0 = torch.diff(a, dim=-1)
a0.shape

torch.Size([512, 2, 10, 9])

In [22]:
a0[0]

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.float64, grad_fn=<Se

In [23]:
a0.data.max()

tensor(0., device='cuda:0', dtype=torch.float64)

In [24]:
a1 = torch.diff(a, dim=-2)
a1.shape

torch.Size([512, 2, 9, 10])

In [25]:
a1[0]

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.float64, grad_fn=<SelectBackward0>)

In [26]:
a1.data.max()

tensor(0., device='cuda:0', dtype=torch.float64)

In [27]:
# model.block_func.pairwise_mixing[0].dy

In [28]:
_a0 = torch.cat([init0, a0], dim=-1)
_a0.shape

torch.Size([512, 2, 10, 10])

In [29]:
_a1 = torch.cat([init1, a1], dim=-2)
_a1.shape

torch.Size([512, 2, 10, 10])

In [30]:
f0 = torch.cumsum(_a0, dim=-1)

In [31]:
f1 = torch.cumsum(_a1, dim=-2)

In [32]:
f = (f0+f1)/2

In [33]:
torch.abs(f-a).mean()

tensor(0., device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

In [34]:
x = torch.randn(2, 1024).double().to(device)
block_func(x)
# x.dtype

tensor([[ 0.7227,  1.7482,  2.2501,  ...,  0.7609,  0.5850, -0.5791],
        [-0.0686,  1.0620,  1.4424,  ..., -0.2611,  0.4803,  1.3782]],
       device='cuda:0', dtype=torch.float64, grad_fn=<ViewBackward0>)

In [35]:
# Pick the architecture to train; the commented lines are the alternatives.
# model = FMNIST_BlockMLP(img_size=(1, 28, 28), patch_size=(7, 7))
# model = FMNIST_OrdMLP(img_size=(1, 28, 28))
model = FMNIST_PairBilinear(
    img_size=(1, 28, 28),
    patch_size=(7, 7),
    grid_width=7,
).to(device)

In [36]:
model

FMNIST_PairBilinear(
  (patch_linear): Linear(in_features=49, out_features=64, bias=True)
  (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (block_func): PairBilinear_MixerBlock(
    (selector): BiasLayer: [1024]
    (pairwise_mixing): ModuleList(
      (0): PairLinear: [1024 -> 1024] (grid: 7)
      (1): PairLinear: [1024 -> 1024] (grid: 7)
      (2): PairLinear: [1024 -> 1024] (grid: 7)
      (3): PairLinear: [1024 -> 1024] (grid: 7)
      (4): PairLinear: [1024 -> 1024] (grid: 7)
      (5): PairLinear: [1024 -> 1024] (grid: 7)
      (6): PairLinear: [1024 -> 1024] (grid: 7)
      (7): PairLinear: [1024 -> 1024] (grid: 7)
      (8): PairLinear: [1024 -> 1024] (grid: 7)
      (9): PairLinear: [1024 -> 1024] (grid: 7)
    )
    (reducer): Identity()
  )
  (fc): Linear(in_features=1024, out_features=10, bias=True)
)

In [37]:
model(torch.randn(2, 1, 28, 28).to(device)).shape

torch.Size([2, 10])

In [38]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  538762


In [39]:
## BLOCK MLP (1024/64) number of params:  281738
## ORDINARY MLP (784, 1024) number of params:  814090
## PAIR BILINEAR has 1058954

In [40]:
784*784

614656

## Training

In [41]:
 ## debugging to find the good classifier/output distribution.
# model_name = 'block_mlp_mixer_fmnist_v0'
# model_name = 'ord_mlp_mixer_fmnist_v0'
model_name = 'pair_bilinear_mixer_fmnist_v0'

In [42]:
EPOCHS = 50
criterion = nn.CrossEntropyLoss()

# Learning rates tried: 1e-5, 3e-5 (current), 1e-3.
optimizer = optim.Adam(model.parameters(), lr=3e-5)

# Cosine annealing whose period (T_max) is the configured number of epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [43]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print a summary line.

    Uses the notebook-level ``model``, ``train_loader``, ``device``,
    ``optimizer`` and ``criterion``. ``epoch`` only labels the printed line.
    The reported loss is the mean of the per-batch losses; accuracy is
    computed over all samples seen. Returns ``None``.
    """
    model.train()
    loss_sum, hits, seen = 0, 0, 0
    for n_batches, (inputs, targets) in enumerate(tqdm(train_loader), start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        predicted = outputs.max(1)[1]  # index of the largest logit per sample
        seen += targets.size(0)
        hits += predicted.eq(targets).sum().item()

    mean_loss = loss_sum / n_batches
    accuracy = 100. * hits / seen
    print(f"[Train] {epoch} Loss: {mean_loss:.3f} | Acc: {accuracy:.3f} {hits}/{seen}")

In [44]:
# Best test accuracy (in percent) seen so far; updated by test().
best_acc = -1
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and checkpoint it on improvement.

    Uses the notebook-level ``model``, ``test_loader``, ``device``,
    ``criterion`` and ``model_name``. Prints the mean per-batch loss and the
    accuracy. When the accuracy beats ``best_acc``, the model weights,
    accuracy and epoch are saved to ``./models/{model_name}.pth`` and
    ``best_acc`` is updated. Returns ``None``.
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # One call creates the folder if missing and is a no-op if it exists.
        os.makedirs('models', exist_ok=True)
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

In [45]:
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    # map_location places the stored tensors on the current `device`, so a
    # checkpoint written on a GPU can also be opened on a CPU-only machine.
    checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

In [46]:
### Full training run: one training pass and one evaluation pass per epoch.

for epoch in range(start_epoch, start_epoch+EPOCHS): ## runs EPOCHS epochs, counted from start_epoch
    train(epoch)
    test(epoch)
    scheduler.step()  # advance the learning-rate schedule once per epoch

100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.05it/s]


[Train] 0 Loss: 0.783 | Acc: 74.672 44803/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 273.77it/s]


[Test] 0 Loss: 0.515 | Acc: 81.930 8193/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 144.78it/s]


[Train] 1 Loss: 0.455 | Acc: 84.468 50681/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 267.87it/s]


[Test] 1 Loss: 0.428 | Acc: 84.810 8481/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 144.79it/s]


[Train] 2 Loss: 0.388 | Acc: 86.643 51986/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 279.25it/s]


[Test] 2 Loss: 0.388 | Acc: 86.080 8608/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 145.76it/s]


[Train] 3 Loss: 0.351 | Acc: 87.738 52643/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 279.41it/s]


[Test] 3 Loss: 0.369 | Acc: 86.640 8664/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.31it/s]


[Train] 4 Loss: 0.326 | Acc: 88.715 53229/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 271.52it/s]


[Test] 4 Loss: 0.351 | Acc: 87.290 8729/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 145.93it/s]


[Train] 5 Loss: 0.306 | Acc: 89.362 53617/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 272.27it/s]


[Test] 5 Loss: 0.345 | Acc: 87.570 8757/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.17it/s]


[Train] 6 Loss: 0.291 | Acc: 89.972 53983/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 276.35it/s]


[Test] 6 Loss: 0.338 | Acc: 87.810 8781/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.06it/s]


[Train] 7 Loss: 0.276 | Acc: 90.337 54202/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 270.98it/s]


[Test] 7 Loss: 0.332 | Acc: 88.140 8814/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.99it/s]


[Train] 8 Loss: 0.263 | Acc: 90.990 54594/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 272.59it/s]


[Test] 8 Loss: 0.324 | Acc: 88.370 8837/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 145.99it/s]


[Train] 9 Loss: 0.251 | Acc: 91.212 54727/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 276.11it/s]


[Test] 9 Loss: 0.321 | Acc: 88.430 8843/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 146.11it/s]


[Train] 10 Loss: 0.241 | Acc: 91.683 55010/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 277.11it/s]


[Test] 10 Loss: 0.318 | Acc: 88.600 8860/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.06it/s]


[Train] 11 Loss: 0.232 | Acc: 91.892 55135/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 272.44it/s]


[Test] 11 Loss: 0.317 | Acc: 88.550 8855/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 145.92it/s]


[Train] 12 Loss: 0.222 | Acc: 92.283 55370/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 272.17it/s]


[Test] 12 Loss: 0.313 | Acc: 88.560 8856/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.34it/s]


[Train] 13 Loss: 0.213 | Acc: 92.685 55611/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 279.69it/s]


[Test] 13 Loss: 0.313 | Acc: 88.740 8874/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.74it/s]


[Train] 14 Loss: 0.206 | Acc: 92.887 55732/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 275.25it/s]


[Test] 14 Loss: 0.316 | Acc: 88.450 8845/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.17it/s]


[Train] 15 Loss: 0.198 | Acc: 93.175 55905/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 278.36it/s]


[Test] 15 Loss: 0.314 | Acc: 88.710 8871/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 146.75it/s]


[Train] 16 Loss: 0.191 | Acc: 93.547 56128/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 271.48it/s]


[Test] 16 Loss: 0.311 | Acc: 88.910 8891/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.75it/s]


[Train] 17 Loss: 0.183 | Acc: 93.853 56312/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 276.65it/s]


[Test] 17 Loss: 0.311 | Acc: 89.120 8912/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 149.24it/s]


[Train] 18 Loss: 0.177 | Acc: 93.957 56374/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 276.04it/s]


[Test] 18 Loss: 0.312 | Acc: 88.780 8878/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.43it/s]


[Train] 19 Loss: 0.170 | Acc: 94.355 56613/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 283.57it/s]


[Test] 19 Loss: 0.310 | Acc: 88.970 8897/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.63it/s]


[Train] 20 Loss: 0.164 | Acc: 94.547 56728/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 269.83it/s]


[Test] 20 Loss: 0.313 | Acc: 89.050 8905/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.23it/s]


[Train] 21 Loss: 0.160 | Acc: 94.698 56819/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 280.33it/s]


[Test] 21 Loss: 0.312 | Acc: 89.050 8905/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.52it/s]


[Train] 22 Loss: 0.154 | Acc: 94.912 56947/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 278.40it/s]


[Test] 22 Loss: 0.314 | Acc: 89.080 8908/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.87it/s]


[Train] 23 Loss: 0.148 | Acc: 95.290 57174/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 278.90it/s]


[Test] 23 Loss: 0.315 | Acc: 89.090 8909/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.58it/s]


[Train] 24 Loss: 0.144 | Acc: 95.332 57199/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 277.70it/s]


[Test] 24 Loss: 0.317 | Acc: 89.030 8903/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 145.47it/s]


[Train] 25 Loss: 0.140 | Acc: 95.487 57292/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 274.55it/s]


[Test] 25 Loss: 0.317 | Acc: 88.940 8894/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 144.58it/s]


[Train] 26 Loss: 0.136 | Acc: 95.697 57418/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 267.11it/s]


[Test] 26 Loss: 0.315 | Acc: 89.040 8904/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 144.41it/s]


[Train] 27 Loss: 0.131 | Acc: 95.908 57545/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 271.65it/s]


[Test] 27 Loss: 0.317 | Acc: 88.960 8896/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.01it/s]


[Train] 28 Loss: 0.128 | Acc: 96.003 57602/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 276.70it/s]


[Test] 28 Loss: 0.319 | Acc: 89.040 8904/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 145.73it/s]


[Train] 29 Loss: 0.124 | Acc: 96.233 57740/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 282.99it/s]


[Test] 29 Loss: 0.318 | Acc: 89.120 8912/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 144.97it/s]


[Train] 30 Loss: 0.121 | Acc: 96.238 57743/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 272.02it/s]


[Test] 30 Loss: 0.319 | Acc: 89.020 8902/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 143.88it/s]


[Train] 31 Loss: 0.117 | Acc: 96.445 57867/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 267.91it/s]


[Test] 31 Loss: 0.319 | Acc: 89.070 8907/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.32it/s]


[Train] 32 Loss: 0.115 | Acc: 96.558 57935/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 281.40it/s]


[Test] 32 Loss: 0.321 | Acc: 88.960 8896/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.48it/s]


[Train] 33 Loss: 0.112 | Acc: 96.672 58003/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 271.94it/s]


[Test] 33 Loss: 0.323 | Acc: 88.950 8895/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 149.23it/s]


[Train] 34 Loss: 0.111 | Acc: 96.758 58055/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 281.01it/s]


[Test] 34 Loss: 0.322 | Acc: 89.090 8909/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.28it/s]


[Train] 35 Loss: 0.107 | Acc: 96.893 58136/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 273.90it/s]


[Test] 35 Loss: 0.321 | Acc: 89.210 8921/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 146.89it/s]


[Train] 36 Loss: 0.107 | Acc: 96.948 58169/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 274.69it/s]


[Test] 36 Loss: 0.324 | Acc: 89.070 8907/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 145.82it/s]


[Train] 37 Loss: 0.104 | Acc: 97.100 58260/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 276.72it/s]


[Test] 37 Loss: 0.326 | Acc: 89.120 8912/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 149.07it/s]


[Train] 38 Loss: 0.103 | Acc: 97.073 58244/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 284.73it/s]


[Test] 38 Loss: 0.324 | Acc: 89.060 8906/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 146.78it/s]


[Train] 39 Loss: 0.102 | Acc: 97.170 58302/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 281.03it/s]


[Test] 39 Loss: 0.326 | Acc: 89.260 8926/10000
Saving..


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.98it/s]


[Train] 40 Loss: 0.101 | Acc: 97.188 58313/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 275.50it/s]


[Test] 40 Loss: 0.326 | Acc: 89.070 8907/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 145.37it/s]


[Train] 41 Loss: 0.100 | Acc: 97.230 58338/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 282.00it/s]


[Test] 41 Loss: 0.327 | Acc: 89.090 8909/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 147.36it/s]


[Train] 42 Loss: 0.099 | Acc: 97.210 58326/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 281.73it/s]


[Test] 42 Loss: 0.326 | Acc: 89.100 8910/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.44it/s]


[Train] 43 Loss: 0.098 | Acc: 97.310 58386/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 271.03it/s]


[Test] 43 Loss: 0.326 | Acc: 89.120 8912/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 145.24it/s]


[Train] 44 Loss: 0.097 | Acc: 97.340 58404/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 269.84it/s]


[Test] 44 Loss: 0.326 | Acc: 89.090 8909/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 143.96it/s]


[Train] 45 Loss: 0.096 | Acc: 97.382 58429/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 276.43it/s]


[Test] 45 Loss: 0.325 | Acc: 89.140 8914/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 145.52it/s]


[Train] 46 Loss: 0.097 | Acc: 97.393 58436/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 272.12it/s]


[Test] 46 Loss: 0.327 | Acc: 88.950 8895/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 146.62it/s]


[Train] 47 Loss: 0.096 | Acc: 97.298 58379/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 270.84it/s]


[Test] 47 Loss: 0.330 | Acc: 88.910 8891/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 144.83it/s]


[Train] 48 Loss: 0.096 | Acc: 97.377 58426/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 265.56it/s]


[Test] 48 Loss: 0.323 | Acc: 89.130 8913/10000


100%|████████████████████████████████████████████████| 1200/1200 [00:08<00:00, 148.06it/s]


[Train] 49 Loss: 0.096 | Acc: 97.328 58397/60000


100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 270.95it/s]

[Test] 49 Loss: 0.328 | Acc: 89.070 8907/10000





In [47]:
best_acc

89.26

In [48]:
# Reload the best checkpoint written by test(). map_location places the stored
# tensors on the current `device`, so the file also opens on a machine that
# lacks the device it was saved from.
checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

best_acc, start_epoch

(89.26, 39)

In [49]:
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [50]:
model

FMNIST_PairBilinear(
  (patch_linear): Linear(in_features=49, out_features=64, bias=True)
  (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (block_func): PairBilinear_MixerBlock(
    (selector): BiasLayer: [1024]
    (pairwise_mixing): ModuleList(
      (0): PairLinear: [1024 -> 1024] (grid: 7)
      (1): PairLinear: [1024 -> 1024] (grid: 7)
      (2): PairLinear: [1024 -> 1024] (grid: 7)
      (3): PairLinear: [1024 -> 1024] (grid: 7)
      (4): PairLinear: [1024 -> 1024] (grid: 7)
      (5): PairLinear: [1024 -> 1024] (grid: 7)
      (6): PairLinear: [1024 -> 1024] (grid: 7)
      (7): PairLinear: [1024 -> 1024] (grid: 7)
      (8): PairLinear: [1024 -> 1024] (grid: 7)
      (9): PairLinear: [1024 -> 1024] (grid: 7)
    )
    (reducer): Identity()
  )
  (fc): Linear(in_features=1024, out_features=10, bias=True)
)