In [14]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy, json
from tqdm import tqdm

In [28]:
# device = torch.device("cuda:1")
device = torch.device("cpu")

In [29]:
# SEED = 147
# SEED = 258
SEED = 369

torch.manual_seed(SEED)
np.random.seed(SEED)

In [30]:
import torch.optim as optim
from torch.utils import data

In [31]:
# cifar_train = transforms.Compose([
#     transforms.RandomCrop(size=32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

cifar_train = transforms.Compose([
    transforms.RandAugment(magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
        std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
    ),
])

cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
        std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
    ),
])

train_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=True, download=True, transform=cifar_train)
test_dataset = datasets.CIFAR10(root="../../../../../_Datasets/cifar10/", train=False, download=True, transform=cifar_test)

Files already downloaded and verified
Files already downloaded and verified


In [32]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=2)

In [33]:
## demo of train loader
xx, yy = iter(train_loader).next()
xx.shape, yy.shape

(torch.Size([32, 3, 32, 32]), torch.Size([32]))

# Model

In [34]:
from transformers_lib import TransformerBlock, \
        Mixer_TransformerBlock_Encoder, \
        PositionalEncoding, \
        ViT_Classifier

In [35]:
class Mixer_ViT_Classifier(nn.Module):
    
    def __init__(self, image_dim:tuple, patch_size:tuple, hidden_expansion:float, num_blocks:int, num_classes:int):
        super().__init__()
        
        self.img_dim = image_dim ### must contain (C, H, W) or (H, W)
        
        ### find patch dim
        d0 = int(image_dim[-2]/patch_size[0])
        d1 = int(image_dim[-1]/patch_size[1])
        assert d0*patch_size[0]==image_dim[-2], "Image must be divisible into patch size"
        assert d1*patch_size[1]==image_dim[-1], "Image must be divisible into patch size"
#         self.d0, self.d1 = d0, d1 ### number of patches in each axis
        __patch_size = patch_size[0]*patch_size[1]*image_dim[0] ## number of channels in each patch
    
        ### find channel dim
        channel_size = d0*d1 ## number of patches
        
        ### after the number of channels are changed
        init_dim = __patch_size
        final_dim = int(__patch_size*hidden_expansion/2)*2
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
        self.channel_change = nn.Linear(init_dim, final_dim)
        print(f"ViT Mixer : Channes per patch -> Initial:{init_dim} Final:{final_dim}")
        
        
        self.channel_dim = final_dim
        self.patch_dim = channel_size
        
        self.transformer_blocks = []
        
        f = self.get_factors(self.channel_dim)
        print(f)
        fi = np.abs(np.array(f) - np.sqrt(self.channel_dim)).argmin()
        
        _n_heads = f[fi]
        
        ## number of dims per channel -> channel_dim
        print('Num patches',self.patch_dim)
        print(self.channel_dim, _n_heads)
        
        ### Find the block size for sequence:
#         block_seq_size = int(2**np.ceil(np.log2(np.sqrt(16))))
        block_seq_size = 16

#         block_seq_size = int(2**np.ceil(np.log2(np.sqrt(self.patch_dim))))
        print(f'Mixing with block: {block_seq_size}')
        
#         block = int(np.ceil(np.sqrt(self.patch_dim)))
        for i in range(num_blocks):
            L = Mixer_TransformerBlock_Encoder(self.patch_dim, block_seq_size, self.channel_dim, _n_heads, 0, 2)
            self.transformer_blocks.append(L)
        self.transformer_blocks = nn.Sequential(*self.transformer_blocks)
        
        self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)
        self.positional_encoding = PositionalEncoding(self.channel_dim, dropout=0)
        
        
    def get_factors(self, n):
        facts = []
        for i in range(2, n+1):
            if n%i == 0:
                facts.append(i)
        return facts
    
    def forward(self, x):
        bs = x.shape[0]
        x = self.unfold(x).swapaxes(-1, -2)
        x = self.channel_change(x)
#         x = self.positional_encoding(x)
        x = self.transformer_blocks(x)
        x = self.linear(x.view(bs, -1))
        return x

In [36]:
vit_mixer = Mixer_ViT_Classifier((3, 32, 32), patch_size=[4, 4], hidden_expansion=2.4, num_blocks=1, num_classes=10)

ViT Mixer : Channes per patch -> Initial:48 Final:114
[2, 3, 6, 19, 38, 57, 114]
Num patches 64
114 6
Mixing with block: 16


In [37]:
vit_mixer

Mixer_ViT_Classifier(
  (unfold): Unfold(kernel_size=[4, 4], dilation=1, padding=0, stride=[4, 4])
  (channel_change): Linear(in_features=48, out_features=114, bias=True)
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention_Sparse(
            (values): Linear(in_features=114, out_features=114, bias=True)
            (keys): Linear(in_features=114, out_features=114, bias=True)
            (queries): Linear(in_features=114, out_features=114, bias=True)
            (fc_out): Linear(in_features=114, out_features=114, bias=True)
          )
          (norm1): LayerNorm((114,), eps=1e-05, elementwise_affine=True)
          (feed_forward): ResMlpBlock(
            (mlp): Sequential(
              (0): Linear(in_features=114, out_features=228, bias=True)
              (1): GELU()
              (2): Linear(in_features=228, out_features=114, bias=True)
    

In [38]:
vit_mixer(torch.randn(1, 3, 32, 32)).shape

torch.Size([1, 10])

In [39]:
asdasd

NameError: name 'asdasd' is not defined

#### Final Model

In [40]:
torch.manual_seed(SEED)
# model = ViT_Classifier((3, 32, 32), patch_size=[4, 4], hidden_expansion=2.4, num_blocks=3, num_classes=10, pos_emb=False)
# model = Mixer_ViT_Classifier((3, 32, 32), patch_size=[4, 4], hidden_expansion=2.4, num_blocks=3, num_classes=10)

model = ViT_Classifier((3, 32, 32), patch_size=[1, 1], hidden_expansion=2.4, num_blocks=9, num_classes=10, pos_emb=False)
# model = Mixer_ViT_Classifier((3, 32, 32), patch_size=[1, 1], hidden_expansion=2.4, num_blocks=3, num_classes=10)

model = model.to(device)

ViT Mixer : Channes per patch -> Initial:3 Final:6
[2, 3, 6]
6 2


In [None]:
model

In [None]:
print("number of params: ", sum(p.numel() for p in model.parameters())) 
## Patch ||  1137220
## Mixer ||  1141703
## ViT   ||  1130776 / 1341220 (1) / 1025554 (2) / 709888 (3) / 394222 (4)
## SMViT ||  1341220 / 1025554 (1) / 709888 (2) / 394222 (3)

## Training

In [None]:
# model_name = f'vit_mixer_c10_s{SEED}'
# model_name = f'vit_pe_mixer_c10_s{SEED}'

# model_name = f'vit_sparse_mixer_c10_s{SEED}' ## sparse but with 12 layers total
# model_name = f'vit1_mixer_c10_s{SEED}' ## with 12 layers for comparision
# model_name = f'vit1_sparse_mixer_c10_s{SEED}' ## sparse with 9 layers for reference
# model_name = f'vit2_mixer_c10_s{SEED}' ## with 9 layers for reference
# model_name = f'vit2_sparse_mixer_c10_s{SEED}' ## sparse with 6 layers for reference
# model_name = f'vit3_mixer_c10_s{SEED}' ## with 6 layers for reference
# model_name = f'vit3_sparse_mixer_c10_s{SEED}' ## sparse with 3 layers for reference
# model_name = f'vit4_mixer_c10_s{SEED}' ## with 3 layers for reference
# model_name = f'vit4_sparse_mixer_c10_s{SEED}' ## sparse with 3 layers for reference and 1x1 patch

model_name = f'temp_s{SEED}'

In [None]:
EPOCHS = 200
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [None]:
STAT ={'train_stat':[], 'test_stat':[]}

In [None]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
    STAT['train_stat'].append((epoch, train_loss/(batch_idx+1), 100.*correct/total)) ### (Epochs, Loss, Acc)
    print(f"[Train] {epoch} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    return

In [None]:
best_acc = -1
def test(epoch):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    STAT['test_stat'].append((epoch, test_loss/(batch_idx+1), 100.*correct/total)) ### (Epochs, Loss, Acc)
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch
        }
        if not os.path.isdir('models'):
            os.mkdir('models')
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc
        
    with open(f"./output/{model_name}_data.json", 'w') as f:
        json.dump(STAT, f, indent=0)

In [None]:
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(f'./models/{model_name}.pth')
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

In [None]:
### Train the whole damn thing

for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
# for epoch in range(2): ## for 200 epochs
    train(epoch)
    test(epoch)
    scheduler.step()

In [None]:
best_acc

In [None]:
checkpoint = torch.load(f'./models/{model_name}.pth')
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

best_acc, start_epoch

In [None]:
### the expansion is 2.4
### 83.69 for 12 layers sparse vit
### 82.46 for 12 layers vit
### 82.57 for 10 layers vit
### 84.84 for 9 = (3*3) layers sparse vit
### 82.47 for 9 layers vit
### 83.88 for 6 = (2*3) layers sparse vit
### 81.36 for 6 layers vit
### 81.68 for 3 = (1*3) layers sparse vit
### 81.50 for 3 layers vit

In [None]:
model.load_state_dict(checkpoint['model'])

In [None]:
model

In [None]:
with open(f"./output/{model_name}_data.json", 'r') as f:
    STAT = json.load(f)

In [None]:
STAT

In [None]:
train_stat = np.array(STAT['train_stat'])
test_stat = np.array(STAT['test_stat'])

In [None]:
plt.plot(train_stat[:,1], label='train')
plt.plot(test_stat[:,1], label='test')
plt.ylabel("Loss")
plt.legend()
plt.savefig(f"./output/plots/{model_name}_loss.svg")
plt.show()

In [None]:
plt.plot(train_stat[:,2], label='train')
plt.plot(test_stat[:,2], label='test')
plt.ylabel("Accuracy")
plt.legend()
plt.savefig(f"./output/plots/{model_name}_accs.svg")
plt.show()