In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy, json
# from tqdm.autonotebook import tqdm
from tqdm import tqdm

In [2]:
device = torch.device("cuda:0")
# device = torch.device("cpu")

In [3]:
# SEED = 147
# SEED = 258
SEED = 369

torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
import torch.optim as optim
from torch.utils import data

In [5]:
from tiny_imagenet import TinyImageNetDataset

In [6]:
# tiny_train = transforms.Compose([
#     transforms.RandomCrop(size=64, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.5]*3,
#         std=[0.2]*3,
#     ),
# ])

tiny_train = transforms.Compose([
    transforms.RandAugment(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5]*3,
        std=[0.2]*3,
    ),
])

tiny_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5]*3,
        std=[0.2]*3,
    ),
])

In [7]:
class TinyImageNet_Preload(data.Dataset):
    
    def __init__(self, root, mode='train', transform=None):
        
        dataset = datasets.ImageFolder(
            root=os.path.join(root, mode),
            transform=None
        )
        self.transform = transform
        self.images, self.labels = [], []
        print("Dataset Size:",len(dataset))
        for i in tqdm(range(len(dataset))):
            x, y = dataset[i]
            self.images.append(x)
            self.labels.append(y)
        del dataset
            
    def _add_channels(img, total_channels=3):
        while len(img.shape) < 3:  # third axis is the channels
            img = np.expand_dims(img, axis=-1)
        while(img.shape[-1]) < 3:
            img = np.concatenate([img, img[:, :, -1:]], axis=-1)
        return img
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        img, lbl = self.images[idx], self.labels[idx]
        return self.transform(img), lbl

In [8]:
train_dataset = TinyImageNet_Preload(root="../../../../../_Datasets/tiny-imagenet-200",
                                     mode='train', transform=tiny_train)

Dataset Size: 200000


100%|███████████████████████████████████████████| 200000/200000 [00:21<00:00, 9313.34it/s]


In [9]:
# train_dataset = TinyImageNetDataset(root_dir="../../../../../_Datasets/tiny-imagenet-200/", 
#                                     mode='train',
#                                     transform=tiny_train, 
#                                     preload=False,
#                                     download=False)

# train_dataset = datasets.ImageFolder(
#         root=os.path.join("../../../../../_Datasets/tiny-imagenet-200", 'train'),
#         transform=tiny_train
#     )

In [10]:
# train_dataset.class_to_idx

In [11]:
# test_dataset = TinyImageNetDataset(root_dir="../../../../../_Datasets/tiny-imagenet-200/", 
#                                     mode='val',
#                                     transform=tiny_test, 
#                                     preload=False,
#                                     download=False)

# test_dataset = datasets.ImageFolder(
#         root=os.path.join("../../../../../_Datasets/tiny-imagenet-200", 'val'),
#         transform=tiny_train
#     )

In [12]:
# test_dataset.class_to_idx = train_dataset.class_to_idx

In [13]:
# test_dataset.class_to_idx

In [14]:
# test_dataset.class_to_idx == train_dataset.class_to_idx

In [15]:
# "../../../../../_Datasets/tiny-imagenet-200/train/"

In [16]:
test_dataset = TinyImageNet_Preload(root="../../../../../_Datasets/tiny-imagenet-200",
                                     mode='val', transform=tiny_test)

Dataset Size: 10000


100%|█████████████████████████████████████████████| 10000/10000 [00:01<00:00, 9483.48it/s]


In [17]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=256, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=256, shuffle=False, num_workers=4)

In [18]:
## demo of train loader
xx, yy = iter(train_loader).next()
xx.shape, yy.shape

(torch.Size([256, 3, 64, 64]), torch.Size([256]))

In [19]:
yy

tensor([ 10,   9, 106,  24, 153, 161,  50, 122, 164, 181,  71, 138,   4,  48,
         31, 134,  79,  77,   2, 120, 100,  18,  29, 166, 108, 108,  33,  52,
        127,  90, 158,  44, 159,  69, 194,  96,  92,  64,  70,  62, 162, 141,
         39,  20, 159, 153,  42, 170,  42, 103,  88,  74, 108,  51,  34,  99,
        136, 172,  18,  64, 137,  21, 146, 102,  10,  62,   6, 114, 167, 149,
        199, 175, 104, 143,  93, 135,  59, 186,  26, 180,  65, 138, 190,  86,
         36,  59, 143,  49, 163,  10, 161, 108, 160,  13,  20,  56,  60, 117,
         19,  27,   0,  16, 133, 180,  67, 102, 143,  64, 115, 102,  88,   1,
         33, 121,   2,  31,  75, 191, 197,  43, 141,  46,  28,  81, 184,  19,
        111, 188,  26, 153,  69,  34,  14, 189, 113,  17, 149,   4,  92, 110,
        127,  40, 134, 125,  96, 139,  16, 112,  28,  32, 175, 164,  93,  35,
        156,   9, 173, 162,  25, 114,   2,   2, 161, 197,  59,  38,  40,  70,
         80, 191, 133,  48, 164,  45, 165, 140, 194,  22,  87, 1

# Model

In [18]:
from transformers_lib import TransformerBlock, \
        Mixer_TransformerBlock_Encoder, \
        PositionalEncoding, \
        ViT_Classifier

In [19]:
class Mixer_ViT_Classifier(nn.Module):
    
    def __init__(self, image_dim:tuple, patch_size:tuple, hidden_expansion:float, num_blocks:int, num_classes:int, pos_emb=True):
        super().__init__()
        
        self.img_dim = image_dim ### must contain (C, H, W) or (H, W)
        
        ### find patch dim
        d0 = int(image_dim[-2]/patch_size[0])
        d1 = int(image_dim[-1]/patch_size[1])
        assert d0*patch_size[0]==image_dim[-2], "Image must be divisible into patch size"
        assert d1*patch_size[1]==image_dim[-1], "Image must be divisible into patch size"
#         self.d0, self.d1 = d0, d1 ### number of patches in each axis
        __patch_size = patch_size[0]*patch_size[1]*image_dim[0] ## number of channels in each patch
    
        ### find channel dim
        channel_size = d0*d1 ## number of patches
        
        ### after the number of channels are changed
        init_dim = __patch_size
        final_dim = int(__patch_size*hidden_expansion/2)*2
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        #### rescale the patches (patch wise image non preserving transform, unlike bilinear interpolation)
        self.channel_change = nn.Linear(init_dim, final_dim)
        print(f"ViT Mixer : Channels per patch -> Initial:{init_dim} Final:{final_dim}")
        
        
        self.channel_dim = final_dim
        self.patch_dim = channel_size
        
        self.transformer_blocks = []
        
        f = self.get_factors(self.channel_dim)
        print(f)
        fi = np.abs(np.array(f) - np.sqrt(self.channel_dim)).argmin()
        
        _n_heads = f[fi]
        
        ## number of dims per channel -> channel_dim
        print('Num patches', self.patch_dim)
        print(self.channel_dim, _n_heads)
        
        ### Find the block size for sequence:
        block_seq_size = int(2**np.ceil(np.log2(np.sqrt(self.patch_dim))))
        print(f'Mixing with block: {block_seq_size}')
        
#         block = int(np.ceil(np.sqrt(self.patch_dim)))
        for i in range(num_blocks):
            L = Mixer_TransformerBlock_Encoder(self.patch_dim, block_seq_size, self.channel_dim, _n_heads, 0, 2)
            self.transformer_blocks.append(L)
        self.transformer_blocks = nn.Sequential(*self.transformer_blocks)
        
        self.linear = nn.Linear(self.patch_dim*self.channel_dim, num_classes)
        
        self.positional_encoding = PositionalEncoding(self.channel_dim, dropout=0)
        if not pos_emb:
            self.positional_encoding = nn.Identity()
        
        
    def get_factors(self, n):
        facts = []
        for i in range(2, n+1):
            if n%i == 0:
                facts.append(i)
        return facts
    
    def forward(self, x):
        bs = x.shape[0]
        x = self.unfold(x).swapaxes(-1, -2)
        x = self.channel_change(x)
        x = self.positional_encoding(x)
        x = self.transformer_blocks(x)
        x = self.linear(x.view(bs, -1))
        return x

In [53]:
int(2**np.ceil(np.log2(np.sqrt(16)))) ## 32/4 * 32/4 = 64 total patches.. mixed at the block of 4

4

In [54]:
vit_mixer = Mixer_ViT_Classifier((3, 64, 64), patch_size=[4, 4], hidden_expansion=2.4, num_blocks=3, num_classes=200)

ViT Mixer : Channes per patch -> Initial:12 Final:28
[2, 4, 7, 14, 28]
Num patches 1024
28 4
Mixing with block: 32


In [55]:
vit_mixer

Mixer_ViT_Classifier(
  (unfold): Unfold(kernel_size=[2, 2], dilation=1, padding=0, stride=[2, 2])
  (channel_change): Linear(in_features=12, out_features=28, bias=True)
  (transformer_blocks): Sequential(
    (0): Mixer_TransformerBlock_Encoder(
      (sparse_transformers): ModuleList(
        (0): Sparse_TransformerBlock(
          (attention): SelfAttention_Sparse(
            (values): Linear(in_features=28, out_features=28, bias=True)
            (keys): Linear(in_features=28, out_features=28, bias=True)
            (queries): Linear(in_features=28, out_features=28, bias=True)
            (fc_out): Linear(in_features=28, out_features=28, bias=True)
          )
          (norm1): LayerNorm((28,), eps=1e-05, elementwise_affine=True)
          (feed_forward): Sequential(
            (0): Linear(in_features=28, out_features=56, bias=True)
            (1): GELU()
            (2): Linear(in_features=56, out_features=28, bias=True)
          )
          (norm2): LayerNorm((28,), eps=1e-0

In [56]:
vit_mixer(torch.randn(1, 3, 64, 64)).shape

torch.Size([1, 200])

In [28]:
asdasd

NameError: name 'asdasd' is not defined

#### Final Model

In [30]:
torch.manual_seed(SEED)
# model = ViT_Classifier((3, 64, 64), patch_size=[4, 4], hidden_expansion=2.4, num_blocks=8, num_classes=200, pos_emb=True)
# model = Mixer_ViT_Classifier((3, 64, 64), patch_size=[4, 4], hidden_expansion=2.4, num_blocks=4, num_classes=200, pos_emb=True)

model = ViT_Classifier((3, 64, 64), patch_size=[2, 2], hidden_expansion=3, num_blocks=8, num_classes=200, pos_emb=True)
# model = Mixer_ViT_Classifier((3, 64, 64), patch_size=[2, 2], hidden_expansion=3, num_blocks=4, num_classes=200, pos_emb=True)


model = model.to(device)

ViT Mixer : Channes per patch -> Initial:12 Final:36
[2, 3, 4, 6, 9, 12, 18, 36]
36 6


In [31]:
model

ViT_Classifier(
  (unfold): Unfold(kernel_size=[2, 2], dilation=1, padding=0, stride=[2, 2])
  (channel_change): Linear(in_features=12, out_features=36, bias=True)
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (attention): SelfAttention(
        (values): Linear(in_features=36, out_features=36, bias=True)
        (keys): Linear(in_features=36, out_features=36, bias=True)
        (queries): Linear(in_features=36, out_features=36, bias=True)
        (fc_out): Linear(in_features=36, out_features=36, bias=True)
      )
      (norm1): LayerNorm((36,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((36,), eps=1e-05, elementwise_affine=True)
      (feed_forward): Sequential(
        (0): Linear(in_features=36, out_features=72, bias=True)
        (1): GELU()
        (2): Linear(in_features=72, out_features=36, bias=True)
      )
      (dropout): Dropout(p=0, inplace=False)
    )
    (1): TransformerBlock(
      (attention): SelfAttention(
        (values): 

In [36]:
model.transformer_blocks[0].attention.heads

6

In [37]:
model.transformer_blocks[0].attention.embed_size

36

In [22]:
print("number of params: ", sum(p.numel() for p in model.parameters())) 

########### 4x4 ###########
## ViT   ||  6684362 (0->8) 
## SMViT ||  6684362 (0->2*4=8)

number of params:  7459580


## Training

In [23]:
model_name = f'vit_mixer_tiny_s{SEED}'
# model_name = f'vit_sparse_mixer_tiny_s{SEED}' ## sparse but with 8 layers total

In [24]:
EPOCHS = 200
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [25]:
STAT ={'train_stat':[], 'test_stat':[]}

In [26]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
    STAT['train_stat'].append((epoch, train_loss/(batch_idx+1), 100.*correct/total)) ### (Epochs, Loss, Acc)
    print(f"[Train] {epoch} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    return

In [27]:
best_acc = -1
def test(epoch):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    STAT['test_stat'].append((epoch, test_loss/(batch_idx+1), 100.*correct/total)) ### (Epochs, Loss, Acc)
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch
        }
        if not os.path.isdir('models'):
            os.mkdir('models')
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc
        
    with open(f"./output/{model_name}_data.json", 'w') as f:
        json.dump(STAT, f, indent=0)

In [28]:
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(f'./models/{model_name}.pth')
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

In [29]:
### Train the whole damn thing

for epoch in range(start_epoch, start_epoch+EPOCHS): ## for 200 epochs
    train(epoch)
    test(epoch)
    scheduler.step()

  0%|                                                             | 0/782 [00:00<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 6.00 GiB (GPU 0; 7.93 GiB total capacity; 6.26 GiB already allocated; 918.50 MiB free; 6.27 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [36]:
len(train_dataset)//256

781

In [37]:
len(test_dataset)//256

39

In [36]:
xx.shape

torch.Size([256, 3, 64, 64])

In [None]:
### Training Speeds

####### 4*4 Patch ; total of 256 patches
## Sparse (256/16) -> 782/782 [04:55<00:00, 3.37it/s]  || 40/40 [00:06<00:00, 6.65it/s]
## Dense (256) ->     782/782 [07:39<00:00, 2.18it/s]  || 40/40 [00:09<00:00, 4.38it/s]

####### 2*2 Patch ; total of 1024 patches
## Sparse (1024/32) -> 48/782 [00:35<08:57,  1.37it/s]  || 
## Dense (1024)     -> 

In [None]:
best_acc

In [None]:
checkpoint = torch.load(f'./models/{model_name}.pth')
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

best_acc, start_epoch

In [None]:
### the expansion is 2.4
### 83.69 for 12 layers sparse vit
### 82.46 for 12 layers vit
### 82.57 for 10 layers vit
### 84.84 for 9 = (3*3) layers sparse vit
### 82.47 for 9 layers vit
### 83.88 for 6 = (2*3) layers sparse vit
### 81.36 for 6 layers vit
### 81.68 for 3 = (1*3) layers sparse vit
### 81.50 for 3 layers vit

In [None]:
model.load_state_dict(checkpoint['model'])

In [None]:
model

In [None]:
with open(f"./output/{model_name}_data.json", 'r') as f:
    STAT = json.load(f)

In [None]:
STAT

In [None]:
train_stat = np.array(STAT['train_stat'])
test_stat = np.array(STAT['test_stat'])

In [None]:
plt.plot(train_stat[:,1], label='train')
plt.plot(test_stat[:,1], label='test')
plt.ylabel("Loss")
plt.legend()
plt.savefig(f"./output/plots/{model_name}_loss.svg")
plt.show()

In [None]:
plt.plot(train_stat[:,2], label='train')
plt.plot(test_stat[:,2], label='test')
plt.ylabel("Accuracy")
plt.legend()
plt.savefig(f"./output/plots/{model_name}_accs.svg")
plt.show()

In [None]:
### TODO: Experiments
'''
1. Sparse Attention
2. Sparse MLP
3. Sparse Att + MLP

Datasets:
A. Cifar-10/100 -> 4x4 vs 2x2 vs 1x1 patch
B. Tiny-Imagenet -> 16x16 vs 4x4 vs 2x2 patch
'''