In [None]:
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
from train_test import train, test, train_kd
from vit import ViT
from vit_CNNFF import ViTCNNFF
from helpers import count_parameters
from plotter import plot_feature_maps, plot_loss_accuracy
from data_loader import (
    get_data_loader,
    load_checkpoint,
    load_lists_from_file,
    get_random_image,
)


In [None]:
# Seed PyTorch's random number generator so repeated runs of this notebook
# start from the same state.
# NOTE(review): only torch is seeded here; if the project helpers draw from
# Python's `random` or NumPy, those remain unseeded — confirm.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Build the train/test loaders shared by every evaluation cell below.
# NOTE(review): the dataset path says "cifar-10", while the models in this
# notebook use num_classes=100 and save under "save_model/cifar-100" —
# confirm which dataset is actually intended.
train_loader, test_loader = get_data_loader(
    80,  # presumably the batch size — TODO confirm against data_loader
    2,  # presumably the number of loader workers — TODO confirm
    "datasets/cifar-10/cifar-10-python",
    download=True,
)

In [None]:
# Baseline Vision Transformer used as the reference model in this notebook.
base_model = ViT(
    image_size=32, patch_size=4, num_classes=100,
    dim=64, depth=3, heads=4, mlp_dim=128,
    dropout=0.1,
)
# cpu for feature visualisation
base_model = base_model.to("cpu")

# Checkpoint locations: the directory handed to train() as save_path, and the
# checkpoint file that load_checkpoint() reads back.
base_model_save_path = "save_model/cifar-100/vit16_base"
base_model_load_path = base_model_save_path + "/best_model.pt"

print(f"Total parameters:{count_parameters(base_model)}")

In [None]:
# ViTCNNFF variant, built with the same hyper-parameters as the baseline ViT
# so the two models can be compared.
CNNFF_model = ViTCNNFF(
    image_size=32, patch_size=4, num_classes=100,
    dim=64, depth=3, heads=4, mlp_dim=128,
    dropout=0.1,
)
# cpu for feature visualisation
CNNFF_model = CNNFF_model.to("cpu")

# Checkpoint locations: the directory handed to train() as save_path, and the
# checkpoint file that load_checkpoint() reads back.
CNNFF_model_save_path = "save_model/cifar-100/vit16_CNNFF"
CNNFF_model_load_path = CNNFF_model_save_path + "/best_model.pt"

print(f"Total parameters:{count_parameters(CNNFF_model)}")

In [None]:
# Pick one sample for the "bee" class; both returned values are passed to
# plot_feature_maps() further down.
# NOTE(review): presumably `x` is the model-ready tensor and `img` the
# displayable image — confirm against data_loader.get_random_image.
x, img = get_random_image("bee")

In [None]:
# Training run for the baseline ViT, left disabled (commented out) in this
# notebook; the next cell loads weights from base_model_load_path instead.
# NOTE(review): the call uses device="cuda" while base_model is placed on
# "cpu" at construction — presumably train() moves the model itself; confirm
# before re-enabling.
# train(
#     base_model,
#     train_loader,
#     test_loader,
#     epochs=100,
#     learning_rate=0.001,
#     device="cuda",
#     weight_decay=0.0005,
#     save_path=base_model_save_path,
# )

In [None]:
# Load the saved baseline checkpoint into base_model, then evaluate it on the
# test loader on the CPU.
load_checkpoint(
    base_model,
    base_model_load_path,
)
test(
    base_model,
    test_loader,
    device="cpu",
)

In [None]:
# Visualise the baseline ViT's feature maps for the sampled image on the CPU.
plot_feature_maps(
    base_model,
    x,
    img,
    device="cpu",
)

In [None]:
# Training run for the ViTCNNFF model, left disabled (commented out) in this
# notebook; the next cell loads weights from CNNFF_model_load_path instead.
# NOTE(review): the call uses device="cuda" while CNNFF_model is placed on
# "cpu" at construction — presumably train() moves the model itself; confirm
# before re-enabling.
# train(
#     CNNFF_model,
#     train_loader,
#     test_loader,
#     epochs=100,
#     learning_rate=0.001,
#     device="cuda",
#     weight_decay=0.0005,
#     save_path=CNNFF_model_save_path,
# )

In [None]:
# Load the saved ViTCNNFF checkpoint into CNNFF_model, then evaluate it on the
# test loader on the CPU.
load_checkpoint(
    CNNFF_model,
    CNNFF_model_load_path,
)
test(
    CNNFF_model,
    test_loader,
    device="cpu",
)

In [None]:
# Visualise the ViTCNNFF model's feature maps for the same sampled image, so
# they can be compared with the baseline ViT plot earlier in the notebook.
# This cell previously plotted `base_model` a second time (copy-paste slip)
# and omitted the device argument; it now plots CNNFF_model and passes
# device="cpu" explicitly, matching the baseline call and the model's CPU
# placement.
plot_feature_maps(CNNFF_model, x, img, device="cpu")