In [None]:
import copy

import torch

from models.sparse_vit import Sparse_ViT
from models.vit import ViT
from functions.data_loader import get_data_loader,load_lists_from_file,load_checkpoint,get_random_image
from functions.train_test import test,train
from functions.helpers import count_parameters
from functions.plotter import plot_loss_accuracy,plot_accuracy_comparison,plot_feature_maps_strided

# Sparse attention reference: https://github.com/kyegomez/SparseAttention

In [None]:
# Seed torch's global RNG so that anything drawing from it (e.g. model weight
# initialisation, and loader shuffling if get_data_loader uses torch's default
# generator) is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# CIFAR-100 train/test loaders built from the local copy (download=False).
# NOTE(review): the positional 80 and 2 are presumably batch size and worker
# count — confirm against get_data_loader's signature.
train_loader, test_loader = get_data_loader(80, 2, "datasets/cifar-100/cifar-100-python", download=False)

In [None]:
# Sparse-attention ViT sized for CIFAR-100 (image_size=32, 100 classes).
# NOTE(review): the model is created on CPU, while the train/test cells below
# pass device "cuda" — presumably those helpers move it; confirm.
model = Sparse_ViT(
    # input / output geometry
    image_size=32,
    patch_size=4,
    num_classes=100,
    # transformer capacity
    dim=64,
    depth=4,
    heads=8,
    mlp_dim=256,
    # regularisation and token pooling
    dropout=0.1,
    pool='mean',
).to("cpu")

In [None]:
# Visualise feature maps of the best saved sparse ViT on an image returned by
# get_random_image('bee').
# The checkpoint is loaded into a deep copy so that `model` keeps its fresh
# initialisation. Loading into `model` directly would make the training cell
# below silently resume from these saved weights, and the training result
# would depend on whether this cell happened to be run first.
x, img = get_random_image('bee')
viz_model = copy.deepcopy(model)
load_checkpoint(viz_model, 'save_model/cifar-100/sparse_vit/best_model.pt')
plot_feature_maps_strided(viz_model, x, img)

In [None]:
# Output directory for the sparse ViT run; later cells read best_model.pt and
# loss_and_accuracy back from here.
sparse_vit_save_path = "save_model/cifar-100/sparse_vit"

In [None]:
# Train the sparse ViT for 30 epochs (lr 1e-3, weight decay 5e-4).
# Artifacts go to sparse_vit_save_path; presumably this is where the
# best_model.pt and loss_and_accuracy files read by later cells come from.
train(
    model,
    train_loader,
    test_loader,
    device="cuda",
    save_path=sparse_vit_save_path,
    epochs=30,
    learning_rate=1e-3,
    weight_decay=5e-4,
)

In [None]:
# Baseline ViT with the same geometry and capacity as the sparse model, for a
# like-for-like comparison. Unlike the sparse model it is placed on CUDA and
# uses ViT's default pooling (no `pool` argument).
base_model = ViT(
    # input / output geometry
    image_size=32,
    patch_size=4,
    num_classes=100,
    # transformer capacity
    dim=64,
    depth=4,
    heads=8,
    mlp_dim=256,
    # regularisation
    dropout=0.1,
).to("cuda")

In [None]:
# Train the baseline ViT with the same hyperparameters as the sparse run.
# Its artifacts are persisted under base_save_path, mirroring the sparse run.
# Previously the save_path argument was commented out, which left
# base_save_path unused and discarded this run's results.
# NOTE(review): the comparison cell below reads the base history from
# 'save_model/cifar-100/vit_base', not this 'vit_base2' directory — confirm
# which run the comparison is meant to show.
base_save_path = 'save_model/cifar-100/vit_base2'
train(
    base_model,
    train_loader,
    test_loader,
    epochs=30,
    learning_rate=0.001,
    device="cuda",
    weight_decay=0.0005,
    save_path=base_save_path,
)

In [None]:
# Load the saved (loss, accuracy) histories of both runs for comparison.
# NOTE(review): the base history comes from 'vit_base', whereas the base
# training cell defines base_save_path as 'vit_base2' — confirm this is intended.
loss, base_model_accuracy = load_lists_from_file(
    "save_model/cifar-100/vit_base/loss_and_accuracy"
)
loss2, sparse_model_accuracy = load_lists_from_file(
    "save_model/cifar-100/sparse_vit/loss_and_accuracy"
)

In [None]:
# Plot the two loaded accuracy histories side by side, labelled per model.
plot_accuracy_comparison(
    base_model_accuracy,
    sparse_model_accuracy,
    'base_ViT',
    'sparse_ViT',
)

In [None]:
# Restore the best saved sparse ViT weights into `model` and evaluate them on
# the test loader.
load_checkpoint(model, sparse_vit_save_path + "/best_model.pt")
test(model, test_loader, 'cuda')

In [None]:
# Parameter count of the sparse ViT (compare with the baseline below).
count_parameters(model)


In [None]:
# Parameter count of the baseline ViT (compare with the sparse ViT above).
count_parameters(base_model)