In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
from train_test import train_kd_pruning,train_pruning,test
from helpers import get_data_loader
from model import ViT
from helpers import load_checkpoint

In [2]:
# Seed every RNG this notebook has in scope before the loaders are built, so
# a fresh Restart & Run All starts from the same random state. Seeding torch
# alone leaves NumPy's global generator unseeded.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Build the CIFAR-100 train/test loaders through the project helper.
# NOTE(review): the two positional values are presumably the batch size (3000)
# and the number of loader workers (2) — confirm against helpers.get_data_loader.
train_loader, test_loader = get_data_loader(
    3000, 2, "datasets/cifar-100/cifar-100-python", download=True
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Directory handed to the training helper as `save_path` for the teacher.
teacher_save_path = "save_model/cifar-100/vit_16_teacher_cifar-100"

# Constructor arguments for the teacher network, gathered in one mapping so
# they can be compared side by side with the student's configuration.
teacher_config = {
    "image_size": 32,
    "patch_size": 4,
    "num_classes": 100,
    "dim": 768,
    "depth": 7,
    "heads": 12,
    "mlp_dim": 512,
    "dropout": 0.1,
}

# Instantiate the teacher, then move it onto the GPU.
teacher_model = ViT(**teacher_config)
teacher_model = teacher_model.to("cuda")

In [None]:

# Directory handed to the distillation helper as `save_path` for the student.
student_kd_save_path = "save_model/cifar-100/vit_16_student_kd_cifar-100"

# Constructor arguments for the student: same settings as the teacher except
# for a smaller depth (6 vs 7) and fewer heads (6 vs 12).
student_config = {
    "image_size": 32,
    "patch_size": 4,
    "num_classes": 100,
    "dim": 768,
    "depth": 6,
    "heads": 6,
    "mlp_dim": 512,
    "dropout": 0.1,
}
student_kd = ViT(**student_config).to("cuda")

# Teacher checkpoint passed to the distillation helper.
# NOTE(review): this file is written under teacher_save_path, so the teacher
# training cell presumably has to have produced it first — confirm.
teacher_load_path = f"{teacher_save_path}/best_model.pt"

# Keyword options for the distillation run.
# NOTE(review): T is presumably the softmax temperature and the two weights
# the mixing factors of the soft-target and cross-entropy losses — confirm in
# train_test.train_kd_pruning.
distillation_options = {
    "T": 2,
    "soft_target_loss_weight": 0.3,
    "ce_loss_weight": 0.7,
    "epochs": 100,
    "learning_rate": 0.0001,
    "device": "cuda",
    "save_path": student_kd_save_path,
    "load_path_teacher": teacher_load_path,
}

train_kd_pruning(
    student_kd,
    teacher_model,
    train_loader,
    test_loader,
    **distillation_options,
)

In [None]:
# Hand the student's best_model.pt file to the project checkpoint helper.
student_kd_checkpoint = (
    "save_model/cifar-100/vit_16_student_kd_cifar-100/best_model.pt"
)
load_checkpoint(student_kd, student_kd_checkpoint)

In [4]:

# Keyword options for the teacher's training run with structured pruning.
# A `load_path=teacher_load_path` option is intentionally not passed here
# (presumably it would start from an existing checkpoint — confirm in
# train_test.train_pruning before enabling it).
teacher_pruning_options = {
    "epochs": 100,
    "learning_rate": 0.0001,
    "device": "cuda",
    "save_path": teacher_save_path,
    "pruning_method": "structured",
}

train_pruning(
    teacher_model,
    train_loader,
    test_loader,
    **teacher_pruning_options,
)

Epoch 1/100, Loss: 4.325147797079647
Current Learning Rate: 9.997532801828658e-05
Test Accuracy: 9.11%
max_test_accuracy : 9.11
Epoch 2/100, Loss: 3.8460660401512596
Current Learning Rate: 9.990133642141359e-05
Test Accuracy: 14.90%
max_test_accuracy : 14.9
Epoch 3/100, Loss: 3.5764001537771786
Current Learning Rate: 9.977809823015401e-05
Test Accuracy: 17.90%
max_test_accuracy : 17.9
Epoch 4/100, Loss: 3.3905582007239845
Current Learning Rate: 9.960573506572391e-05
Test Accuracy: 20.53%
max_test_accuracy : 20.53
Epoch 5/100, Loss: 3.232750654220581
Current Learning Rate: 9.93844170297569e-05
Test Accuracy: 22.00%
max_test_accuracy : 22.0
Epoch 6/100, Loss: 3.089136235854205
Current Learning Rate: 9.911436253643445e-05
Test Accuracy: 24.54%
max_test_accuracy : 24.54
Epoch 7/100, Loss: 2.9554501982296215
Current Learning Rate: 9.879583809693738e-05
Test Accuracy: 26.11%
max_test_accuracy : 26.11
Epoch 8/100, Loss: 2.839271910050336
Current Learning Rate: 9.842915805643157e-05
Test Accur

KeyboardInterrupt: 

In [8]:
# Evaluate the in-memory teacher on the test loader; the bare name on the last
# line keeps the returned value as the cell's displayed output.
# NOTE(review): the recorded result (20.71%) is below the 26.11% logged during
# the interrupted training run — worth confirming why.
teacher_accuracy = test(teacher_model, test_loader, device="cuda")
teacher_accuracy

Test Accuracy: 20.71%


20.71

In [9]:
# Hand the teacher's best_model.pt file to the project checkpoint helper, then
# evaluate the teacher again on the test loader.
# NOTE(review): the recorded accuracy is 20.71% both before and after this
# load — confirm the checkpoint on disk is the one expected.
teacher_checkpoint = "save_model/cifar-100/vit_16_teacher_cifar-100/best_model.pt"
load_checkpoint(teacher_model, teacher_checkpoint)

reloaded_teacher_accuracy = test(teacher_model, test_loader, device="cuda")
reloaded_teacher_accuracy

Test Accuracy: 20.71%


20.71