In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

from data_loader import get_data_loader, load_checkpoint
from helpers import count_parameters
from model import ViT
from train_test import train, test, train_kd

In [None]:
# Seed every RNG the notebook may draw from (Python, NumPy, PyTorch) so that a
# fresh "Restart & Run All" is reproducible. Previously only torch was seeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the generators on all devices, CUDA included

In [None]:
# Build the train / test loaders from the local CIFAR-10 copy (download=True is
# forwarded to get_data_loader).
# NOTE(review): the two leading positional arguments look like batch size (80) and
# number of workers (2) — confirm against get_data_loader's signature.
# NOTE(review): the models below use num_classes=100 although this data is CIFAR-10.
train_loader, test_loader = get_data_loader(
    80,
    2,
    "datasets/cifar-10/cifar-10-python",
    download=True,
)

In [None]:
# Teacher and distillation student share every ViT hyper-parameter except
# `depth` and `heads`; keeping the shared ones in one dict stops them drifting apart.
# NOTE(review): num_classes=100, but the loaders above read CIFAR-10 — confirm.
_shared_vit_kwargs = dict(
    image_size=32,
    patch_size=4,
    num_classes=100,
    dim=64,
    mlp_dim=128,
    distilling=False,
    dropout=0.1,
    CNN_FF=True,
)

# Teacher: depth=3, heads=4. Built first so weight-init RNG draws keep their order.
teacher_model = ViT(depth=3, heads=4, **_shared_vit_kwargs).to("cuda")

teacher_save_path = "save_model/cifar-10/vit16_teacher_cifar-10"
teacher_load_path = f"{teacher_save_path}/best_model.pt"

student_kd_save_path = "save_model/cifar-10/vit16_student_kd_cifar-10"

# Distillation student: smaller than the teacher (depth=2, heads=2).
student_kd = ViT(depth=2, heads=2, **_shared_vit_kwargs).to("cuda")


In [None]:

# Distil `student_kd` from the teacher checkpoint at `teacher_load_path`
# (presumably train_kd loads the teacher weights from `load_path_teacher`).
# The teacher-training cell sits *below* this one, so on a fresh
# "Restart & Run All" that checkpoint may not exist yet. Fail fast with an
# actionable message rather than distilling against a missing teacher.
if not os.path.isfile(teacher_load_path):
    raise FileNotFoundError(
        f"Teacher checkpoint not found at {teacher_load_path!r}; "
        "train the teacher first (see the teacher-training cell)."
    )

train_kd(
    student_kd,
    teacher_model,
    train_loader,
    test_loader,
    T=20,
    alpha=0.3,
    epochs=100,
    learning_rate=0.001,
    device='cuda',
    weight_decay=0.0005,
    save_path=student_kd_save_path,
    load_path_teacher=teacher_load_path,
)

In [None]:
# Train the teacher and save checkpoints under `teacher_save_path`; the KD cell
# reads `best_model.pt` from that same directory. No load_path is passed, so
# this does not resume from an existing checkpoint.
train(
    teacher_model,
    train_loader,
    test_loader,
    device="cuda",
    epochs=500,
    learning_rate=0.0001,
    weight_decay=0.0005,
    save_path=teacher_save_path,
)

In [None]:
# Report the size of the distillation student.
print("Total number of parameters: {}".format(count_parameters(student_kd)))

In [None]:
# Baseline student: same ViT configuration as `student_kd`, but trained with
# plain `train` (no teacher) for comparison against the distilled student.
# The loaders read CIFAR-10, so the checkpoint directory follows the same
# save_model/cifar-10/vit16_*_cifar-10 scheme as the teacher and KD student
# (it previously pointed at cifar-100, mislabelling the artefact).
# NOTE(review): num_classes=100 while the data is CIFAR-10 — confirm.
student_base_save_path = "save_model/cifar-10/vit16_student_base_cifar-10"
student_base_model = ViT(
    image_size=32,
    patch_size=4,
    num_classes=100,
    dim=64,
    depth=2,
    heads=2,
    mlp_dim=128,
    distilling=False,
    dropout=0.1,
    CNN_FF=True
).to("cuda")

train(
    student_base_model,
    train_loader,
    test_loader,
    epochs=100,
    learning_rate=0.001,
    weight_decay=0.0005,
    device="cuda",
    save_path=student_base_save_path,
)

In [None]:
# Restore the teacher from `best_model.pt` (the same file the KD cell distils from).
# NOTE(review): the return value is discarded, so this assumes load_checkpoint
# loads the weights into `teacher_model` in place — confirm.
load_checkpoint(teacher_model,teacher_load_path)

In [None]:
# Run `test` on the restored teacher over the test loader
# (presumably reports test loss/accuracy — see train_test.test).
test(teacher_model,test_loader,'cuda')

In [None]:
# NOTE(review): `pruning_method` is never referenced by a later cell — the Linear
# pruning cell below calls prune.ln_structured directly. Use it or remove it.
pruning_method = prune.L1Unstructured

In [None]:
# Structured L1 pruning (n=1, dim=0) of every nn.Linear in the teacher: masks the
# output rows of each weight matrix with the smallest L1 norm, a fraction
# LINEAR_PRUNE_AMOUNT of them.
# Modules that already carry a pruning mask are skipped, so re-running this cell
# does not stack a second round of pruning on top of the first.
# NOTE(review): if the classification head is an nn.Linear it is pruned too,
# i.e. most class rows get masked — confirm that is intended.
LINEAR_PRUNE_AMOUNT = 0.9

for name, module in teacher_model.named_modules():
    if not isinstance(module, nn.Linear):
        continue
    if prune.is_pruned(module):
        print(f"skipping {name}: already pruned")
        continue
    prune.ln_structured(module, name='weight', amount=LINEAR_PRUNE_AMOUNT, n=1, dim=0)
    print(name, module)


In [None]:
# Display the teacher's module tree (after any pruning applied above).
teacher_model

In [None]:
# Print the name of each direct child module of the teacher, one per line.
for child_name, _child in teacher_model.named_children():
    print(child_name)

In [None]:
# First entry of the teacher's transformer `layers`, inspected by the cells below.
# NOTE(review): `module` is also used as a loop variable in other cells; running
# those after this one rebinds it, so cells reading `module` depend on run order.
module = teacher_model.transformer.layers[0]

In [None]:
# Zero every teacher parameter whose name contains "bias", in place.
# The previous version rebound the loop variable `param` to a new sparse
# nn.Parameter (an all-zero tensor of the same shape, built with the deprecated
# torch.sparse.FloatTensor constructor). That never touched the model, so the
# cell was a no-op apart from printing. Writing zeros into the existing tensors
# under no_grad actually applies the change.
# NOTE(review): the substring match also catches biases outside nn.Linear
# (e.g. normalisation or conv biases, if present) — confirm that is intended.
with torch.no_grad():
    for name, param in teacher_model.named_parameters():
        if "bias" in name:
            param.zero_()
            print(name, param)


In [None]:
# Show the to_qkv weight of the first transformer layer's first sub-layer.
# The previous loop over named_modules() printed this once (for the root module).
# It then raised AttributeError on the first submodule lacking `.transformer`,
# and it rebound the global `module`, breaking the cells below that read it.
print(teacher_model.transformer.layers[0][0].fn.to_qkv.weight)

In [None]:
# Display `module` — expected to be the first transformer block set earlier
# (depends on no other cell having rebound `module` since).
module

In [None]:
# Unwrap sub-layer 0 of the block via its `.fn` attribute; the `to_qkv` access
# below suggests this is the attention module — verify against model.ViT.
attention = module[0].fn

In [None]:
# Dump the attention module's (name, parameter) pairs.
print([*attention.named_parameters()])

In [None]:
# The attention module's `to_qkv` projection (presumably a single nn.Linear
# producing queries, keys and values together — confirm).
linear_layer = attention.to_qkv

In [None]:
# Display the to_qkv layer (shows any pruning hooks via its repr/state).
linear_layer

In [None]:
# Randomly prune (amount=0.3) entries of the to_qkv weight, in place.
# NOTE(review): if to_qkv is an nn.Linear, the Linear-pruning cell already
# masked it, so this stacks a second mask; re-running stacks yet another.
prune.random_unstructured(module=linear_layer, name="weight", amount=0.3)

In [None]:
# Inspect the to_qkv weight again after pruning; with torch.nn.utils.prune applied,
# `weight` is the masked tensor, so pruned entries show as zeros.
teacher_model.transformer.layers[0][0].fn.to_qkv.weight