In [None]:
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import os
import time
import argparse
import datetime
import numpy as np
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter
import matplotlib.pyplot as plt
from timm.scheduler.scheduler import Scheduler
from train_test import train, test
from helpers import get_data_loader
from model import ViT
from helpers import load_checkpoint

In [None]:
# Seed every global RNG this notebook has imported so that a
# Restart Kernel -> Run All is repeatable from a single constant.
SEED = 42
# NumPy's global RNG is independent of torch's and was previously left unseeded.
np.random.seed(SEED)
# Kept as the last expression so the cell still displays the returned Generator.
torch.manual_seed(SEED)

In [None]:
# Build the train and test loaders through the project helper, reading
# (and downloading, if requested by the helper) the data under the path below.
# NOTE(review): 3000 and 2 are passed positionally; presumably batch size and
# worker count — confirm against helpers.get_data_loader.
cifar100_root = "datasets/cifar-100/cifar-100-python"
loaders = get_data_loader(3000, 2, cifar100_root, download=True)
train_loader, test_loader = loaders

In [None]:
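# Scratch (disabled): weight norm of the first transformer layer.
# NOTE: `student_model` is not defined in any visible cell; point this at an
# existing model (and confirm the layer exposes `.weight`) before enabling it.
# torch.norm(student_model.transformer.layers[0].weight).item()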

In [None]:
# Teacher network: the depth-7, 12-head ViT that is later handed to the
# distillation run as `teacher_model`.
teacher_save_path = "save_model/cifar-100/vit_16_teacher_cifar-100"
teacher_load_path = f"{teacher_save_path}/best_model.pt"

teacher_config = dict(
    image_size=224,
    patch_size=16,
    num_classes=100,
    dim=768,
    depth=7,
    heads=12,
    mlp_dim=512,
    dropout=0.1,
)
teacher_model = ViT(**teacher_config).to("cuda")

# Train the teacher; artefacts are written under `teacher_save_path`.
# NOTE(review): the disabled `load_path` argument presumably resumes from an
# existing checkpoint — confirm in train_test.train before re-enabling.
train(
    teacher_model,
    train_loader,
    test_loader,
    epochs=100,
    learning_rate=1e-4,
    device="cuda",
    save_path=teacher_save_path,
    # load_path=teacher_load_path
)

In [None]:
# Evaluate the teacher from its `best_model.pt` checkpoint on the test loader.
# NOTE(review): load_checkpoint presumably restores the saved weights into
# `teacher_model` in place (its return value is unused here) — confirm in helpers.
teacher_checkpoint = "save_model/cifar-100/vit_16_teacher_cifar-100/best_model.pt"
load_checkpoint(teacher_model, teacher_checkpoint)
test(teacher_model, test_loader, device="cuda")

In [None]:
# Baseline student: a smaller ViT (depth 6, 6 heads) trained directly with
# `train`, i.e. without the teacher, to compare against the distilled student.
student_base_save_path = "save_model/cifar-100/vit_16_student_base_cifar-100"

student_base_config = dict(
    image_size=224,
    patch_size=16,
    num_classes=100,
    dim=768,
    depth=6,
    heads=6,
    mlp_dim=512,
    dropout=0.1,
)
student_base_model = ViT(**student_base_config).to("cuda")

# Same epochs and learning rate as the teacher run; artefacts are written
# under `student_base_save_path`.
train(
    student_base_model,
    train_loader,
    test_loader,
    epochs=100,
    learning_rate=1e-4,
    device="cuda",
    save_path=student_base_save_path,
)

In [None]:
# Evaluate the baseline student from its `best_model.pt` checkpoint.
# NOTE(review): load_checkpoint presumably restores the saved weights into
# `student_base_model` in place — confirm in helpers.
load_checkpoint(
    student_base_model,
    "save_model/cifar-100/vit_16_student_base_cifar-100/best_model.pt",
)
# Pass the device explicitly, as the teacher evaluation does, so every model
# is scored under the same setting instead of relying on test()'s default.
test(student_base_model, test_loader, device="cuda")

In [None]:
# Distilled student: same architecture as the baseline student (depth 6,
# 6 heads), but trained with `train_kd` against the teacher.
student_kd_save_path = "save_model/cifar-100/vit_16_student_kd_cifar-100"
student_kd = ViT(
    image_size=224,
    patch_size=16,
    num_classes=100,
    dim=768,
    depth=6,
    heads=6,
    mlp_dim=512,
    dropout=0.1,
).to("cuda")

teacher_load_path = f"{teacher_save_path}/best_model.pt"

# `train_kd` is not imported by the notebook's import cell, so this cell
# raised NameError on a fresh kernel.
# NOTE(review): assumed to live in train_test next to train/test — confirm the
# module, then move this import into the top import cell.
from train_test import train_kd

# T, soft_target_loss_weight and ce_loss_weight are forwarded unchanged.
# NOTE(review): presumably the softmax temperature and the weights of the
# soft-target and cross-entropy loss terms — confirm in train_kd.
train_kd(
    student_kd,
    teacher_model,
    train_loader,
    test_loader,
    T=2,
    soft_target_loss_weight=0.3,
    ce_loss_weight=0.7,
    epochs=100,
    learning_rate=0.0001,
    save_path=student_kd_save_path,
    load_path_teacher=teacher_load_path,
)

In [None]:
# Evaluate the distilled student from its `best_model.pt` checkpoint.
# NOTE(review): load_checkpoint presumably restores the saved weights into
# `student_kd` in place — confirm in helpers.
load_checkpoint(
    student_kd, "save_model/cifar-100/vit_16_student_kd_cifar-100/best_model.pt"
)
# Pass the device explicitly, as the teacher evaluation does, so every model
# is scored under the same setting instead of relying on test()'s default.
test(student_kd, test_loader, device="cuda")

In [None]:
# Footprint of the distilled student's trainable parameters.
trainable_params = [param for param in student_kd.parameters() if param.requires_grad]

# Number of trainable scalar weights (what the original expression computed).
num_trainable_params = sum(param.numel() for param in trainable_params)

# numel() counts elements, not bytes: scale each tensor by its per-element
# width so the value actually matches the name `model_size_bytes`.
model_size_bytes = sum(
    param.numel() * param.element_size() for param in trainable_params
)

In [None]:
# `model` was never assigned in the notebook, so this cell and the later
# inspection/pruning cells that read `model` raised NameError on a fresh
# kernel. Bind it explicitly to the network being inspected.
# NOTE(review): student_kd is assumed to be the intended target — confirm.
model = student_kd

# Print the name of each direct child module of the model.
for child_name, _child in model.named_children():
    print(child_name)

In [None]:
# Take the first entry of the transformer's layer stack for inspection.
layer_stack = model.transformer.layers
module = layer_stack[0]

In [None]:
# Element 0 of the selected block exposes a wrapped sub-module as `.fn`.
# NOTE(review): presumably the block's attention layer — confirm in model.ViT.
first_sublayer = module[0]
attention = first_sublayer.fn

In [None]:
# List each parameter of `attention` by name and shape. Printing the raw
# (name, tensor) pairs dumps tensor values and buries the names and shapes,
# which are what matter when choosing a pruning target.
for param_name, param in attention.named_parameters():
    print(f"{param_name}: {tuple(param.shape)}")

In [None]:
# Bind the `to_qkv` sub-module of `attention`; this is the object the pruning
# cell operates on.
# NOTE(review): presumably an nn.Linear producing the q/k/v projections — confirm in model.ViT.
linear_layer = attention.to_qkv

In [None]:
# Display the module's repr as the cell output.
linear_layer

In [None]:
# Randomly prune 30% of the entries of `linear_layer.weight` (unstructured).
# Guarded so that re-running the cell does not stack a second random mask on
# top of the first, which would silently push the sparsity beyond 30%.
if not prune.is_pruned(linear_layer):
    prune.random_unstructured(linear_layer, name="weight", amount=0.3)

# Keep showing the module as the cell output, as the unguarded call did.
linear_layer

In [None]:
# Display the `to_qkv` weight of the first block, reached from the model root
# (the same attribute path the earlier cells walked step by step).
first_block = model.transformer.layers[0]
first_block[0].fn.to_qkv.weight