In [None]:
import torch
import torch.nn as nn
import os
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
from train_test import test, train, train_kd, test_batch
from helpers import get_data_loader, quantize_model
from model import ViT
from helpers import load_checkpoint, get_model_size_bytes
import time

import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [None]:
# Seed PyTorch's global RNG up front so repeated runs of the notebook start
# from the same random state.
torch.manual_seed(42)

# Location handed to get_data_loader for the CIFAR-100 files.
cifar100_root = "datasets/cifar-100/cifar-100-python"

# The helper's result is unpacked into the train and test loaders.
# NOTE(review): 1024 and 2 are passed positionally; they look like batch size
# and worker count — confirm against helpers.get_data_loader.
train_loader, test_loader = get_data_loader(1024, 2, cifar100_root, download=True)

In [None]:
# Both networks use exactly the same architecture, so the hyper-parameters are
# written once and unpacked into each constructor call.
vit_kwargs = dict(
    image_size=32,
    patch_size=4,
    num_classes=100,
    dim=128,
    depth=3,
    heads=3,
    mlp_dim=256,
    dropout=0.1,
)

# Construction order is kept (model first, model_pruned second) so that any
# random initialisation happens in the same sequence as before.
model = ViT(**vit_kwargs).to("cuda")
model_pruned = ViT(**vit_kwargs).to("cuda")

In [None]:
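# NOTE: the training run for `model_pruned` is disabled below. While it stays
# commented out, none of the cells in this part of the notebook trains
# `model_pruned` or loads a checkpoint into it, so the cells that follow work
# on whatever parameters its constructor produced.
# train(
#     model_pruned,
#     train_loader,
#     test_loader,
#     epochs=50,
#     learning_rate=0.001,
#     device="cuda",
#     pruning_method="structured",
#     weight_decay=0.0005,
# )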

In [None]:
# Checkpoint file passed to load_checkpoint together with `model`
# (presumably restores the saved weights into it — see helpers.load_checkpoint).
teacher_checkpoint = "save_model/cifar-100/vit_16_teacher_cifar-100/best_model.pt"
load_checkpoint(model, teacher_checkpoint)

# Run the evaluation helper on the test loader; being the last expression, its
# return value is what the cell displays.
test(model, test_loader, device="cuda")

In [None]:
# Evaluate `model_pruned` with the same `test` helper, loader and device as the
# reference model. This runs before the pruning cell, so it reports the
# un-pruned state of `model_pruned`.
test(model_pruned, test_loader, device="cuda")

In [None]:
# Apply structured L2 pruning to every nn.Linear inside `model_pruned`.
for name, module in model_pruned.named_modules():
    # Anything that is not a Linear layer is left untouched.
    if not isinstance(module, nn.Linear):
        continue

    # ln_structured masks 50% of the slices along dim 0 of `weight` (ranked by
    # their L2 norm, n=2); prune.remove then folds the mask into `weight` and
    # drops the pruning re-parametrisation, making the zeros permanent.
    m = prune.remove(
        prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0),
        name="weight",
    )
    print(f"Pruned {name}")

In [None]:
# Display the size of one weight tensor in the first transformer layer
# (presumably the first Linear of that layer's feed-forward net — confirm
# against model.ViT). Pruning via torch.nn.utils.prune zeroes entries rather
# than removing them, so the dimensions are expected to be unchanged.
model_pruned.transformer.layers[0][1].fn.net[0].weight.size()

In [None]:
# Display the `to_qkv` weight of the first transformer layer so the rows zeroed
# by the structured pruning (dim=0) can be inspected by eye.
model_pruned.transformer.layers[0][0].fn.to_qkv.weight

In [None]:
# Replace every parameter of every nn.Linear in `model_pruned` with a sparse
# (COO) copy of itself, so the zeros left by pruning are no longer stored.
#
# NOTE(review): this converts *all* parameters of the layer — `bias` included —
# although only `weight` was pruned above. Confirm that is intended.
with torch.no_grad():  # the copies are fresh Parameters; no autograd history needed
    for layer in model_pruned.modules():
        if not isinstance(layer, nn.Linear):
            continue

        # Snapshot the (name, tensor) pairs first: setattr below rebinds
        # entries of the very dict that named_parameters() walks lazily.
        for param_name, param in list(layer.named_parameters()):
            # Convert on the tensor's current device instead of bouncing it
            # through host memory; .to("cuda") keeps the original guarantee
            # that the result lives on the GPU.
            sparse_param = nn.Parameter(param.to_sparse().to("cuda"))
            setattr(layer, param_name, sparse_param)


In [None]:
# Dump every parameter of `model_pruned`: a header line with its qualified
# name, followed by the tensor's printed representation on the next line.
for name, param in model_pruned.named_parameters():
    print(f"Parameter name: {name}", param, sep="\n")

In [None]:
# Walk the test loader once and keep every batch it yields in a list, so the
# timing cells below iterate over pre-loaded batches instead of the loader.
batch_array = list(test_loader)

In [None]:
# Time one full pass of `model` over the cached test batches.
#
# CUDA kernels are launched asynchronously, so the host clock is only
# meaningful when the device is idle at both ends of the measurement.
torch.cuda.synchronize()  # finish earlier GPU work so it is not billed to this run
start_time = time.perf_counter()  # monotonic, high-resolution interval clock

with torch.no_grad():
    for image_batch in batch_array:
        # Only the first element of each cached batch is forwarded to test_batch
        # (presumably the image tensor — confirm against the data loader).
        test_batch(model, image_batch[0], device="cuda")

torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
inference_time_base = time.perf_counter() - start_time
print(f"model_base_time: {inference_time_base} ")



In [None]:
# Time one full pass of `model_pruned` over the cached test batches.
#
# CUDA kernels are launched asynchronously, so the host clock is only
# meaningful when the device is idle at both ends of the measurement.
torch.cuda.synchronize()  # finish earlier GPU work so it is not billed to this run
start_time = time.perf_counter()  # monotonic, high-resolution interval clock

with torch.no_grad():
    for image_batch in batch_array:
        # Only the first element of each cached batch is forwarded to test_batch
        # (presumably the image tensor — confirm against the data loader).
        test_batch(model_pruned, image_batch[0], device="cuda")

torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
inference_time_pruned = time.perf_counter() - start_time
print(f"model_pruned_time: {inference_time_pruned} ")