In [None]:
import torch
from functions.train_test import test
from models.vit import ViT
from models.CNN_ViT import CNN_ViT
from models.CNN_ViT_dynamic import CNN_ViT_dynamic
from models.ViT_CNN_early_exit_caca import CNN_ViT_early_exit
from functions.helpers import count_parameters
from functions.plotter import plot_feature_maps, plot_loss_accuracy
import functions as f


In [None]:
# Seed PyTorch's global random number generator so repeated runs are comparable.
SEED = 42
torch.manual_seed(SEED)

CIFAR-100


In [None]:
# Build the CIFAR-100 train/test loaders.
# The first positional argument appears to be the batch size (a later cell
# passes 1 to obtain single-example batches). The meaning of the second
# positional argument (2) is not visible from this notebook -- presumably the
# number of loader workers; TODO confirm against `get_data_loader`.
cifar100_root = "datasets/cifar-100/cifar-100-python"
train_loader, test_loader = f.data_loader.get_data_loader(
    80,
    2,
    cifar100_root,
    download=True,
)

CNN+ViT

In [None]:
# Baseline ViT configured for CIFAR-100 (num_classes=100).
base_model = ViT(
    image_size=32,
    patch_size=4,
    num_classes=100,
    dim=64,
    depth=4,
    heads=8,
    mlp_dim=256,
    dropout=0.1,
)
base_model = base_model.to("cpu")  # cpu for feature visualisation

# Checkpoint and training history live in the same run directory.
base_model_dir = "save_model/cifar-100/vit_base"
base_model_load_path = f"{base_model_dir}/best_model.pt"

base_model_param_count = f.helpers.count_parameters(base_model)
print(f"Total parameters:{base_model_param_count}")

# Restore the saved weights, then measure accuracy on the test split.
# NOTE(review): the model is created on the CPU but evaluated with device
# 'cuda'; this presumably relies on `test` moving the model -- confirm.
f.data_loader.load_checkpoint(base_model, base_model_load_path)
base_model_acc = test(base_model, test_loader, "cuda")

# Plot the loss/accuracy history stored alongside the checkpoint.
base_model_history = f.data_loader.load_lists_from_file(
    f"{base_model_dir}/loss_and_accuracy"
)
base_model_loss_list, base_model_accuracy_list = base_model_history
f.plotter.plot_loss_accuracy(
    base_model_loss_list,
    base_model_accuracy_list,
    "base_model",
)

CNN pre-ViT feature extraction + CNN patch embedding

In [None]:
# CNN + ViT hybrid configured for CIFAR-100 (num_classes=100); it uses a
# shallower transformer (depth=2, heads=4) than the baseline ViT cell above.
CNN_ViT_model = CNN_ViT(
    image_size=32,
    patch_size=4,
    num_classes=100,
    dim=64,
    depth=2,
    heads=4,
    mlp_dim=256,
    dropout=0.1,
)
CNN_ViT_model = CNN_ViT_model.to("cpu")  # cpu for feature visualisation

# Checkpoint and training history live in the same run directory.
CNN_ViT_model_dir = "save_model/cifar-100/CNN_ViT2"
CNN_ViT_model_load_path = f"{CNN_ViT_model_dir}/best_model.pt"

CNN_ViT_model_param_count = f.helpers.count_parameters(CNN_ViT_model)
print(f"Total parameters:{CNN_ViT_model_param_count}")

# Restore the saved weights, then measure accuracy on the test split.
# NOTE(review): created on the CPU but evaluated with device 'cuda';
# presumably `test` moves the model -- confirm.
f.data_loader.load_checkpoint(CNN_ViT_model, CNN_ViT_model_load_path)
CNN_ViT_model_acc = test(CNN_ViT_model, test_loader, "cuda")

# Plot the loss/accuracy history stored alongside the checkpoint.
CNN_ViT_model_history = f.data_loader.load_lists_from_file(
    f"{CNN_ViT_model_dir}/loss_and_accuracy"
)
CNN_ViT_model_loss_list, CNN_ViT_model_accuracy_list = CNN_ViT_model_history
f.plotter.plot_loss_accuracy(
    CNN_ViT_model_loss_list,
    CNN_ViT_model_accuracy_list,
    "CNN_ViT_model",
)

In [None]:
# Plot the accuracy histories of the CNN+ViT hybrid and the baseline ViT together.
f.plotter.plot_accuracy_comparison(
    CNN_ViT_model_accuracy_list,
    base_model_accuracy_list,
    "CNN_ViT_model",
    "ViT_base_model",
)

In [None]:
# Tabulate the baseline ViT against the CNN+ViT hybrid; the two lists are
# index-aligned (model i goes with accuracy i).
compared_models = [base_model, CNN_ViT_model]
compared_accuracies = [base_model_acc, CNN_ViT_model_acc]
f.data_loader.create_comparison_table(compared_models, compared_accuracies)

In [None]:
# Make sure the baseline lives on the CPU before visualising its feature maps.
base_model.to("cpu")

# Fetch one random 'bee' sample; `x` and `img` are reused by the next cell so
# both models are visualised on the same example.
bee_sample = f.data_loader.get_random_image("bee")
x, img = bee_sample

plot_feature_maps(base_model, x, img, device="cpu")

In [None]:
# Visualise the CNN+ViT hybrid's feature maps on the CPU, reusing the `x` / `img`
# sample drawn in the previous cell.
CNN_ViT_model.to("cpu")
plot_feature_maps(
    CNN_ViT_model,
    x,
    img,
    device="cpu",
)

CNN_ViT dynamic model

In [None]:
# Dynamic CNN+ViT built with inference=False and evaluated on the full test loader.
# NOTE(review): this cell uses heads=4, while the batch=1 inference cell further
# down rebuilds the model with heads=8 and loads the same checkpoint -- confirm
# which value matches the saved weights.
CNN_ViT_dynamic_model = CNN_ViT_dynamic(
    image_size=32,
    dim=64,
    patch_size=4,
    num_classes=100,
    depth=4,
    heads=4,
    mlp_dim=256,
    dropout=0.1,
    inference=False,
)
CNN_ViT_dynamic_model = CNN_ViT_dynamic_model.to("cuda")

# Checkpoint and training history live in the same run directory.
CNN_ViT_dynamic_dir = "save_model/cifar-100/CNN_ViT_dynamic"
CNN_ViT_dynamic_load_path = f"{CNN_ViT_dynamic_dir}/best_model.pt"

# TODO: explain the number of parameters depending on the path taken.
# TODO: compare the percentage of 'early_exits'.
# TODO: check inference time with 1 example for this model and the base model.

CNN_ViT_dynamic_param_count = f.helpers.count_parameters(CNN_ViT_dynamic_model)
print(f"Total parameters:{CNN_ViT_dynamic_param_count}")

# Restore the saved weights, then measure accuracy on the test split.
f.data_loader.load_checkpoint(CNN_ViT_dynamic_model, CNN_ViT_dynamic_load_path)
CNN_ViT_dynamic_model_acc = test(CNN_ViT_dynamic_model, test_loader, "cuda")

# Plot the loss/accuracy history stored alongside the checkpoint.
CNN_ViT_dynamic_history = f.data_loader.load_lists_from_file(
    f"{CNN_ViT_dynamic_dir}/loss_and_accuracy"
)
CNN_ViT_dynamic_model_loss_list, CNN_ViT_dynamic_model_accuracy_list = (
    CNN_ViT_dynamic_history
)
f.plotter.plot_loss_accuracy(
    CNN_ViT_dynamic_model_loss_list,
    CNN_ViT_dynamic_model_accuracy_list,
    "CNN_ViT_dynamic_model",
)

Test dynamic model inference time with batch=1 and compare to base model

In [None]:
# Rebuild the dynamic CNN+ViT with inference=True and reload the same checkpoint.
# This rebinds `CNN_ViT_dynamic_model`, replacing the inference=False instance
# created in the previous code cell.
# NOTE(review): heads=8 here, but the previous cell built the model with heads=4
# for the same checkpoint -- confirm which value matches the saved weights.
dynamic_inference_config = dict(
    image_size=32,
    dim=64,
    patch_size=4,
    num_classes=100,
    depth=4,
    heads=8,
    mlp_dim=256,
    dropout=0.1,
    inference=True,
)
CNN_ViT_dynamic_model = CNN_ViT_dynamic(**dynamic_inference_config).to("cuda")
f.data_loader.load_checkpoint(CNN_ViT_dynamic_model, CNN_ViT_dynamic_load_path)


Knowledge distillation

Teacher logits based KD

In [None]:
# Distilled student ViT (depth=3, heads=6) for CIFAR-100.
# NOTE(review): the surrounding heading says logits-based KD, but the checkpoint
# is read from the `vit_featurekd` run directory -- confirm this is the intended run.
student_kd = ViT(
    image_size=32,
    patch_size=4,
    num_classes=100,
    dim=64,
    depth=3,
    heads=6,
    mlp_dim=256,
    dropout=0.1,
    # feature_distill=True
)
student_kd = student_kd.to("cuda")

# Checkpoint and training history live in the same run directory.
student_kd_dir = "save_model/cifar-100/vit_featurekd"
student_load_path = f"{student_kd_dir}/best_model.pt"

student_kd_param_count = f.helpers.count_parameters(student_kd)
print(f"Total parameters:{student_kd_param_count}")

# Restore the saved weights, then measure accuracy on the test split.
f.data_loader.load_checkpoint(student_kd, student_load_path)
student_kd_acc = test(student_kd, test_loader, "cuda")

# Plot the loss/accuracy history stored alongside the checkpoint.
student_kd_history = f.data_loader.load_lists_from_file(
    f"{student_kd_dir}/loss_and_accuracy"
)
student_kd_loss_list, student_kd_accuracy_list = student_kd_history
f.plotter.plot_loss_accuracy(
    student_kd_loss_list,
    student_kd_accuracy_list,
    "student_kd",
)

In [None]:
# Student-sized ViT with the same hyperparameters as `student_kd`, loaded from
# the `vit_16_student_base_cifar-100` run so the two can be compared.
student_base_model = ViT(
    image_size=32,
    patch_size=4,
    num_classes=100,
    dim=64,
    depth=3,
    heads=6,
    mlp_dim=256,
    dropout=0.1,
)
student_base_model = student_base_model.to("cuda")

# Checkpoint and training history live in the same run directory.
student_base_dir = "save_model/cifar-100/vit_16_student_base_cifar-100"
student_base_load_path = f"{student_base_dir}/best_model.pt"

student_base_param_count = f.helpers.count_parameters(student_base_model)
print(f"Total parameters:{student_base_param_count}")

# Restore the saved weights, then measure accuracy on the test split.
f.data_loader.load_checkpoint(student_base_model, student_base_load_path)
student_base_model_acc = test(student_base_model, test_loader, "cuda")

# Plot the loss/accuracy history stored alongside the checkpoint.
student_base_history = f.data_loader.load_lists_from_file(
    f"{student_base_dir}/loss_and_accuracy"
)
student_base_model_loss_list, student_base_model_accuracy_list = student_base_history
f.plotter.plot_loss_accuracy(
    student_base_model_loss_list,
    student_base_model_accuracy_list,
    "student_base_model",
)


In [None]:
# Plot the accuracy histories of the distilled student and the plain student together.
f.plotter.plot_accuracy_comparison(
    student_kd_accuracy_list,
    student_base_model_accuracy_list,
    "student_kd",
    "student_base_model",
)

In [None]:
# Tabulate both students against the baseline ViT; the two lists are
# index-aligned (model i goes with accuracy i).
kd_compared_models = [student_kd, student_base_model, base_model]
kd_compared_accuracies = [student_kd_acc, student_base_model_acc, base_model_acc]
f.data_loader.create_comparison_table(kd_compared_models, kd_compared_accuracies)

In [None]:
# TODO: add an inference-time comparison and a CO2-eq estimate.

CNN_ViT early exit


In [1]:
import torch
from functions.train_test import test
from models.vit import ViT
from models.CNN_ViT import CNN_ViT
from models.ViT_early_exit import ViT_early_exit
# from models.prueba_early_exit import prueba_early_exit
from functions.helpers import count_parameters
from functions.plotter import plot_feature_maps, plot_loss_accuracy
import functions as f


In [2]:
# Seed PyTorch's global random number generator for the early-exit section.
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f68646cbb10>

In [3]:
# Build the CIFAR-100 train/test loaders for the early-exit section.
# The first positional argument appears to be the batch size (a later cell
# passes 1 to obtain single-example batches); the meaning of the second
# positional argument (2) is not visible here -- TODO confirm.
cifar100_root = "datasets/cifar-100/cifar-100-python"
train_loader, test_loader = f.data_loader.get_data_loader(
    80,
    2,
    cifar100_root,
    download=True,
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Build the early-exit ViT (depth=6, heads=6, early_exit=True) and restore its checkpoint.
#
# This cell rebinds the name `ViT_early_exit` from the imported class to the model
# instance, and later cells use that name for the instance. Resolve the class
# first so that re-running this cell after the rebinding constructs a fresh model
# instead of calling the existing instance with constructor arguments.
_ViT_early_exit_cls = (
    ViT_early_exit if isinstance(ViT_early_exit, type) else type(ViT_early_exit)
)

ViT_early_exit = _ViT_early_exit_cls(
    image_size=32,
    dim=64,
    patch_size=4,
    num_classes=100,
    depth=6,
    heads=6,
    mlp_dim=256,
    dropout=0.1,
    early_exit=True,
).to("cuda")

ViT_early_exit_load_path = 'save_model/cifar-100/ViT_early_exit/best_model.pt'
print(f"Total parameters:{f.helpers.count_parameters(ViT_early_exit)}")
f.data_loader.load_checkpoint(ViT_early_exit, ViT_early_exit_load_path)



Total parameters:833068


In [None]:
# Evaluate the early-exit ViT instance built above on the regular test loader.
# The original call passed `CNN_ViT_early_exit`, which is only ever imported as a
# class (never instantiated) and is not imported at all by this section's import
# cell; the model that must be evaluated is the `ViT_early_exit` instance.
# NOTE(review): a later cell unpacks test(...) for this model into
# (accuracy, num_early_exits); if `test` returns a pair here as well, this name
# holds that pair rather than a bare accuracy -- confirm.
ViT_early_exitacc = test(ViT_early_exit, test_loader, 'cuda')


In [5]:
# Loaders built with 1 as the first argument (single-example batches, per the
# `_1_example` naming) for evaluating the early-exit model.
single_example_loaders = f.data_loader.get_data_loader(
    1,
    2,
    "datasets/cifar-100/cifar-100-python",
    download=True,
)
train_loader_1_example, test_loader_1_example = single_example_loaders

# For this model `test` is unpacked into two values: the accuracy and
# `num_early_exits`.
ViT_early_exit_acc, num_early_exits = test(
    ViT_early_exit,
    test_loader_1_example,
    "cuda",
)

# Plot the loss/accuracy history stored alongside the checkpoint.
ViT_early_exit_history = f.data_loader.load_lists_from_file(
    "save_model/cifar-100/ViT_early_exit/loss_and_accuracy"
)
ViT_early_exit_loss_list, ViT_early_exit_model_accuracy_list = ViT_early_exit_history
f.plotter.plot_loss_accuracy(
    ViT_early_exit_loss_list,
    ViT_early_exit_model_accuracy_list,
    "ViT_early_exit",
)

Files already downloaded and verified
Files already downloaded and verified
0.10688401758670807
0.15825140476226807
0.11509288102388382
0.08715715259313583
0.5141652226448059
0.15405693650245667
0.12335602194070816
0.10742839425802231
0.2083924263715744
0.25711947679519653
0.08790196478366852
0.14461620151996613
0.3698786199092865
0.11317052692174911
0.13583198189735413
0.12149877846240997
0.18878218531608582
0.06376760452985764
0.7414701581001282
0.4288255274295807
0.27461060881614685
0.08594335615634918
0.09642641991376877
0.22444334626197815
0.061572588980197906
0.11365935951471329
0.10800616443157196
0.10977625846862793
0.12733446061611176
0.7311151623725891
0.12279388308525085
0.13832597434520721
0.1649239957332611
0.10457748174667358
0.09833556413650513
0.11512923240661621
0.8051317930221558
0.4731251299381256
0.16168950498104095
0.15226659178733826
0.0979204922914505
0.09346020221710205
0.16290158033370972
0.1468125432729721
0.15358027815818787
0.0788307785987854
0.1311475187540

KeyboardInterrupt: 

In [None]:
# Display the second value returned by test() for the single-example run above.
num_early_exits

In [None]:
# Tabulate the baseline ViT against the early-exit ViT; the two lists are
# index-aligned (model i goes with accuracy i).
# The original call referenced `CNN_ViT_early_exit` (only ever imported as a
# class, never instantiated) and `CNN_ViT_early_exit_acc` (never assigned in
# this notebook); the evaluated instance and its accuracy are `ViT_early_exit`
# and `ViT_early_exit_acc`.
f.data_loader.create_comparison_table(
    [base_model, ViT_early_exit],
    [base_model_acc, ViT_early_exit_acc],
)