**Task**:

Compare a regular Vision Transformer (ViT) with two Performer variants that replace softmax attention with attention kernels built from deterministic kernel features: Performer-ReLU and Performer-exp.

For all three Transformer types, record training time, inference time, and classification accuracy on the evaluation (test) set.

**Note**:

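A Performer-$f$ variant is a Transformer that replaces the regular softmax attention kernel
$$K(q, k) = \exp\left(\frac{q k^{\top}}{\sqrt{d_{QK}}}\right)$$
with the factorized kernel
$$K(q, k) = f(q)\,f(k)^{\top},$$
where $f$ is a deterministic feature map ($f = \mathrm{ReLU}$ for Performer-ReLU, $f = \exp$ for Performer-exp).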

Libraries

In [1]:
# Colab-only setup (disabled): mount Google Drive, cd into the notebooks folder,
# and install Optuna. Not needed when running locally from the notebooks/ folder.
# NOTE(review): if re-enabled, prefer `%pip install -q optuna==<pinned version>` so the
# install targets this kernel's environment and the run is reproducible.
# import os
# from google.colab import drive
# drive.mount('/content/drive')
# os.chdir("/content/drive/My Drive/2024_fall/data-mining/notebooks")

# !pip install -q optuna

In [2]:
import os
import sys
import time
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import torch
from torch.utils.data import DataLoader

optuna.logging.set_verbosity(optuna.logging.WARNING)

# Make the repository root importable so `src.*` resolves when running from notebooks/.
sys.path.append(os.path.join(os.getcwd(), '..'))
from src.util import get_data
from src.vit import Trainer

Data

In [3]:
### DEFINE DATASET ###
dataset_name = 'mnist'
### DEFINE DATASET ###

### GETTING THE DATA ###
# Root of all data/outputs, relative to the notebooks/ folder.
data_dir = '../data'
# Forwarded to get_data as `dwn`; presumably toggles downloading — confirm in src.util.
dwn = False

dataset_dir = f'{data_dir}/{dataset_name}'
train_dir = f'{dataset_dir}/train'
test_dir = f'{dataset_dir}/test'

checkpoint_dir = f'{data_dir}/checkpoints'
log_dir = f'{data_dir}/logs'

# Create every directory up front. `exist_ok=True` keeps the cell idempotent on
# re-run without a check-then-create race, and the loop variable is not named
# `dir`, so the builtin `dir()` stays usable in the kernel afterwards.
for directory in (data_dir, dataset_dir, train_dir, test_dir, checkpoint_dir, log_dir):
    os.makedirs(directory, exist_ok=True)

# Load both splits plus the dataset metadata (channels, side length, class count)
# through the project helper, using the paths prepared in the previous cell.
(
    train_set, test_set,
    channels, image_size, num_classes,
) = get_data(dataset_name=dataset_name,
             train_dir=train_dir,
             test_dir=test_dir,
             dwn=dwn)

print(f'Channels: {channels}, Image size: {image_size}, Num. classes: {num_classes}.')
# Displayed as the cell result: (train size, test size).
len(train_set), len(test_set)
### GETTING THE DATA ###

Channels: 1, Image size: 28, Num. classes: 10.


(60000, 10000)

Hyperparameter Optimization (Optuna, currently disabled)

In [4]:
# Optuna hyperparameter search (disabled). Each trial samples the batch size `bs`,
# the number of heads `he` and the depth `de` (other settings are fixed), trains via
# `train_valid` and returns `avg_vacc` — presumably the average validation accuracy,
# since the study maximizes it.
# NOTE(review): re-enabling this cell needs names that are not defined in this
# notebook: `ViT`, `train_valid`, `device` and `dl_kwargs` (the imports cell only
# brings `get_data` and `Trainer` in from `src`). Port it to `Trainer` before uncommenting.
# def print_best_callback(study, trial):
#     print(f'Best value: {study.best_value}, Best params: {study.best_trial.params}')

# def objective(trial, device, train_set, test_set, model_kwargs, train_kwargs, dl_kwargs):
#     bs = trial.suggest_categorical('bs', [64, 128])
#     lr = 0.001
#     gm = 0.95

#     ps = 4

#     di = 256
#     dh = 256
#     md = 512

#     he = trial.suggest_categorical('he', [3, 6])
#     de = trial.suggest_categorical('de', [1, 2])

#     train_loader = DataLoader(dataset=train_set, batch_size=bs, **dl_kwargs)
#     test_loader  = DataLoader(dataset=test_set , batch_size=bs, **dl_kwargs)

#     model = ViT(patch_size=ps, dim=di, depth=de, heads=he, mlp_dim=md, dim_head=dh, **model_kwargs).to(device)

#     avg_vacc = train_valid(trial=trial, device=device, model=model, lr=lr, gamma=gm,
#         train_dataset=train_loader, test_dataset=test_loader, **train_kwargs)

#     return avg_vacc

# objective = partial(
#     objective,
#     device=device, train_set=train_set, test_set=test_set,
#     model_kwargs=dict(
#         image_size=image_size, num_classes=num_classes, pool='cls', channels=channels,
#         dropout=0.10, emb_dropout=0.10, attn_type='learn', la_depth=1
#     ),
#     train_kwargs=dict(
#         epochs=5, save_freq=11, checkpoint_name='temp', verbose=False
#     ),
#     dl_kwargs=dl_kwargs
# )

# study = optuna.create_study(direction='maximize')
# study.optimize(objective, callbacks=[print_best_callback], n_trials=8, show_progress_bar=True, n_jobs=-1)

In [5]:
# Best trial parameters and parameter-importance plot (disabled; requires `study`
# from the hyperparameter-search cell above).
# print( study.best_trial.params )

# fig = optuna.visualization.plot_param_importances(study)
# fig.show()

In [6]:
# Contour plot of the objective over batch size `bs` and head count `he`
# (disabled; requires `study` from the hyperparameter-search cell above).
# fig = optuna.visualization.plot_contour(study, params=['bs', 'he'])
# fig.show()

Training and Validation (hand-set parameters — the Optuna search above is disabled, so these are not its optimum)

In [13]:
### MODEL PARAMS ###
# Architecture settings forwarded to src.vit.Trainer.
# NOTE(review): far smaller than the space explored by the disabled Optuna cell
# (dim 256, heads 3/6, patch 4) — these look like smoke-test values; confirm
# before reporting results.
patch_size = 7      # with the 28x28 MNIST images printed above -> a 4x4 patch grid
dim = 2
dim_head = 2
mlp_dim = 4
heads = 1
depth = 1

pool = 'cls'
dropout = 0.10
emb_dropout = 0.10

# Passed straight through to Trainer; their meaning is defined in src.vit.
la_depth = 1
la_exp = False
### MODEL PARAMS ###

### OTHER PARAMS ###
lr = 0.001
gamma = 0.95        # presumably a learning-rate decay factor — confirm in src.vit
epochs = 1

batch_size = 128
### OTHER PARAMS ###

In [15]:
# Fixed seed so every model starts from the same RNG state (weight init, dropout,
# batch order) and the comparison is reproducible under Restart & Run All.
SEED = 42

# Generator driving the training sampler; it is re-seeded per model below so each
# model sees the same sequence of mini-batches.
shuffle_gen = torch.Generator()

# Shuffle the training data every epoch (a fixed sample order biases SGD); the
# evaluation order does not affect metrics. Pinned memory only helps with CUDA.
pin = torch.cuda.is_available()
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True,
                          generator=shuffle_gen, num_workers=0, pin_memory=pin)
valid_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False,
                          num_workers=0, pin_memory=pin)

# Regular ViT ('softmax') plus the two Performer variants the task asks for.
# 'learn' (used in the Optuna cell) can be appended to include it as well.
all_model_names = ('softmax', 'relu', 'exp')

# model_name -> (trainer, hist, train_time_s): the [0]/[1]/[2] layout the timing
# and plotting cells below index into. NOTE: [0] is the Trainer, not the bare model.
final_results = {}

for model_name in all_model_names:
    checkpoint_name = f'{checkpoint_dir}/{model_name}_{dataset_name}'
    log_name = f'{log_dir}/{model_name}_{dataset_name}.csv'

    # Reset the RNGs so runs differ mainly in the attention type.
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    shuffle_gen.manual_seed(SEED)

    t = Trainer(
        lr=lr, gamma=gamma,
        image_size=image_size, patch_size=patch_size, num_classes=num_classes, dim=dim,
        depth=depth, heads=heads, mlp_dim=mlp_dim, pool=pool, channels=channels,
        dim_head=dim_head, dropout=dropout, emb_dropout=emb_dropout,
        attn_type=model_name, la_depth=la_depth, la_exp=la_exp
    )

    # Wall-clock time of train_valid (presumably training plus per-epoch
    # validation — confirm in src.vit).
    start = time.perf_counter()
    hist = t.train_valid(
        epochs=epochs, train_loader=train_loader, valid_loader=valid_loader,
        log_dir=log_name, checkpoint_dir=checkpoint_name, model_name=model_name,
        notebook=True
    )
    train_time_s = time.perf_counter() - start

    final_results[model_name] = (t, hist, train_time_s)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Training and Validation Time

In [9]:
# Report each model's training wall-clock time in minutes (disabled). Reads the
# elapsed seconds from final_results[model_name][2].
# NOTE(review): verify the training cell stores a tuple with the elapsed seconds
# at index 2 in `final_results` before re-enabling.
# for model_name in all_model_names:
#     elapsed = final_results[model_name][2]
#     print(f'Model: {model_name}, Time: {round( elapsed / (60) , 4)} min.')

Accuracy / Loss Visualization

In [10]:
# Per-epoch train (red) vs test (blue) curves for every model, one line style per
# model (disabled). Expects final_results[model_name][1] to be a history exposing
# `avg_<col>` / `avg_v<col>` sequences of length `epochs` — confirm against
# Trainer.train_valid.
# NOTE(review): `'acc' if False else 'loss'` always selects 'loss'; set `col`
# directly to switch metrics. Also add a title and axis labels before publishing.
# fig, ax = plt.subplots(1, 1, figsize=(12, 4))
# col = 'acc' if False else 'loss'

# for model_name, linestyle in zip(all_model_names, ('-', ':', '-.', '--')[:len(all_model_names)]):
#     hist = final_results[model_name][1]

#     ax.plot(range(1, epochs+1), hist[f'avg_{col}'], label=f'{model_name} train {col}',
#         color='red', linestyle=linestyle)
#     ax.plot(range(1, epochs+1), hist[f'avg_v{col}'], label=f'{model_name} test {col}',
#         color='blue', linestyle=linestyle)

# plt.legend(); plt.tight_layout(); plt.show()

Inference Time

In [11]:
# Time one full no-grad pass over the evaluation loader per model, in eval mode
# (disabled).
# NOTE(review): before re-enabling:
#   - `test_loader` is not defined in this notebook; the evaluation loader is `valid_loader`.
#   - `device` is not defined in this notebook, and `torch` must be in scope.
#   - final_results[model_name][0] must be an nn.Module (it is called with `.eval()`).
#   - prefer time.perf_counter(); on GPU call torch.cuda.synchronize() before
#     reading the clock so queued kernels are included.
# for model_name in all_model_names:
#     model = final_results[model_name][0]
#     model.eval()

#     start = time.time()
#     with torch.no_grad():
#         for vinputs, _ in test_loader:
#             model(vinputs.to(device))
#     elapsed = time.time() - start # seconds
#     print(f'Model: {model_name}, Time: {round( elapsed / (60) , 4)} min.')