In [None]:
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torch.optim as optim

from helpers import get_data_loader,quantize_model
from helpers import load_checkpoint,get_model_size_bytes
from model import ViT
from train_test import test,train,train_kd,test_compare_inference

In [None]:
# Fix the torch RNG so weight initialisation (and anything else drawing from
# torch's global generator) is repeatable across runs.
SEED = 42
CIFAR100_DIR = "datasets/cifar-100/cifar-100-python"

torch.manual_seed(SEED)

# NOTE(review): the two leading positional numbers are presumably the batch
# size and the number of loader workers -- confirm against
# helpers.get_data_loader before tuning them.
train_loader, test_loader = get_data_loader(1024, 2, CIFAR100_DIR, download=True)

In [None]:
# One hyper-parameter set shared by both networks, so the baseline and the
# network that gets pruned are guaranteed to have the same architecture.
vit_kwargs = dict(
    image_size=32,
    patch_size=4,
    num_classes=100,
    dim=128,
    depth=3,
    heads=3,
    mlp_dim=256,
    dropout=0.1,
)

# Built in the same order as before (model first, then model_pruned) so each
# one draws the same random initial weights from the seeded generator.
model = ViT(**vit_kwargs).to("cuda")
model_pruned = ViT(**vit_kwargs).to("cuda")


In [None]:

# Training settings for the network that is going to be pruned.
# NOTE(review): a later cell rebinds `model_pruned` to the checkpointed
# `model`, so the result of this run is dropped there unless train() saves it
# -- confirm that this is intended.
train_settings = {
    "epochs": 50,
    "learning_rate": 0.001,
    "device": "cuda",
    "pruning_method": "structured",
    "weight_decay": 0.0005,
}

train(model_pruned, train_loader, test_loader, **train_settings)

In [None]:
# Restore the saved weights into `model` and evaluate it as the baseline.
# NOTE(review): the return value of load_checkpoint is ignored, so it is
# presumably loading into `model` in place -- confirm against helpers.
load_checkpoint(
    model, "save_model/cifar-100/vit_16_teacher_cifar-100/best_model.pt"
)

test(model, test_loader, device="cuda")

# Start the pruned network from an independent copy of the baseline.
# A plain `model_pruned = model` would make both names refer to the same
# module, so every later pruning step would also alter `model` and the
# baseline-vs-pruned comparison further down would compare a model to itself.
model_pruned = copy.deepcopy(model)

In [None]:
# Evaluate model_pruned on the test loader (run before any pruning cell below,
# so this is its starting-point score).
test(model_pruned, test_loader, device="cuda")


In [None]:

def prune_tensors(model):
    """Physically remove all-zero rows from the attention and MLP weights.

    For every entry ``layer`` of ``model.transformer.layers`` the four modules
    ``layer[0].fn.to_qkv``, ``layer[0].fn.to_out[0]``, ``layer[1].fn.net[0]``
    and ``layer[1].fn.net[3]`` are visited in that order.  Each row of the
    module's ``weight`` whose elements are all exactly zero is dropped and the
    smaller matrix is stored back on the module.

    All layers present are processed (the depth is no longer fixed at 3).

    NOTE(review): only the visited module itself is resized.  Modules that
    consume its output are not adjusted here -- confirm that the model still
    runs end to end after rows have been removed.

    Args:
        model: Object exposing ``transformer.layers`` with the layout
            described above.

    Returns:
        None.  ``model`` is modified in place; the name of each visited
        weight is printed as it is processed.
    """
    for layer in model.transformer.layers:
        targets = (
            ("to_qkv", layer[0].fn.to_qkv),
            ("to_out", layer[0].fn.to_out[0]),
            ("net0", layer[1].fn.net[0]),
            ("net3", layer[1].fn.net[3]),
        )
        for tensor_name, linear in targets:
            print(tensor_name)
            _drop_zero_rows(linear)


def _drop_zero_rows(linear):
    """Drop the all-zero rows of ``linear.weight`` in place; return how many were removed."""
    # If torch.nn.utils.prune re-parametrised this module, `weight` is
    # recomputed from weight_orig * weight_mask before every forward pass.
    # Make the mask permanent first so `weight` is an ordinary Parameter again.
    if hasattr(linear, "weight_orig"):
        prune.remove(linear, "weight")

    weight = linear.weight
    # A row is kept when at least one of its elements is non-zero.
    keep = (weight.detach() != 0).any(dim=1)
    kept_rows = int(keep.sum())
    removed_rows = weight.size(0) - kept_rows
    if removed_rows == 0:
        return 0

    # Module attributes registered as parameters only accept nn.Parameter, so
    # the sliced tensor has to be wrapped before it is assigned back.
    linear.weight = nn.Parameter(
        weight.detach()[keep], requires_grad=weight.requires_grad
    )

    # Keep the module self-consistent: one bias entry per remaining row.
    bias = getattr(linear, "bias", None)
    if bias is not None:
        linear.bias = nn.Parameter(
            bias.detach()[keep], requires_grad=bias.requires_grad
        )
    if hasattr(linear, "out_features"):
        linear.out_features = kept_rows
    return removed_rows




In [None]:
# Shape of the first block's first MLP weight before prune_tensors is applied.
model_pruned.transformer.layers[0][1].fn.net[0].weight.shape

In [None]:
# Remove all-zero weight rows from model_pruned's attention and MLP layers in place.
prune_tensors(model_pruned)

In [None]:
# Same weight again after prune_tensors; compare with the shape shown earlier.
model_pruned.transformer.layers[0][1].fn.net[0].weight.shape

In [None]:
# NOTE(review): as written this loop leaves model_pruned unchanged. Assigning
# to the loop variable `param` only rebinds that name; the newly created
# Parameter is never stored on the model and is discarded on the next
# iteration.
# The new tensor is built from `param.shape` alone (no values are passed), so
# it would presumably be an empty sparse tensor rather than a sparse copy of
# the weights -- confirm the intent before wiring it into the model.
for param in model_pruned.parameters():
    param = nn.Parameter(torch.sparse_coo_tensor(param.shape).to("cuda"))

In [None]:
# Apply structured L2-norm pruning to every nn.Linear inside model_pruned.
# NOTE(review): this visits *all* linear layers; if the model has a linear
# classification head, its rows are pruned too -- confirm this is wanted.
for name, module in model_pruned.named_modules():
    if not isinstance(module, nn.Linear):
        continue
    # amount=0.2 masks 20% of the weight's rows (dim=0), choosing the rows
    # with the smallest L2 norm (n=2).
    prune.ln_structured(module, name="weight", amount=0.2, n=2, dim=0)
    print(f"Pruned {name}")

In [None]:
# Dump every parameter of model_pruned: its name on one line, then the tensor.
for name, param in model_pruned.named_parameters():
    print(f"Parameter name: {name}", param, sep="\n")

In [None]:
# Print only the parameters that are still trainable (requires_grad is True).
for param_name, param in model_pruned.named_parameters():
    if not param.requires_grad:
        continue
    print(param)

In [None]:
# Take one batch from the test loader and move the images to the GPU once, so
# the host-to-device copy is not part of either measurement.
image_batch, label_batch = next(iter(test_loader))
image_batch = image_batch.to("cuda")

# CUDA kernels are launched asynchronously: without a synchronize() before
# reading the clock, the timer can stop while the GPU is still working.
# perf_counter() is used because it is monotonic and high resolution.
# NOTE(review): the comparison is only meaningful when `model` and
# `model_pruned` are two distinct objects.

# Time the baseline model.
torch.cuda.synchronize()
start_time = time.perf_counter()
with torch.no_grad():
    test_compare_inference(model, image_batch,device="cuda")
torch.cuda.synchronize()
inference_time_base = time.perf_counter() - start_time
print(f"model_base_time: {inference_time_base} ")

# Time the pruned model on the same batch.
torch.cuda.synchronize()
start_time = time.perf_counter()
with torch.no_grad():
    test_compare_inference(model_pruned, image_batch,device="cuda")
torch.cuda.synchronize()
inference_time_pruned = time.perf_counter() - start_time
print(f"model_pruned: {inference_time_pruned} ")
