In [None]:
import torch
import torch.nn as nn
import torch.quantization
import os
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
from train_test import test, train, train_kd, test_batch
from helpers import get_data_loader, quantize_model
from model import ViT
from helpers import load_checkpoint, get_model_size_bytes
import time
import copy

import torch.ao.quantization.quantize_fx as quantize_fx
import onnx


import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [None]:
# Seed PyTorch's global RNG so repeated runs of the notebook are comparable.
torch.manual_seed(42)

# Build the CIFAR-100 train/test loaders. The two positional integers are
# forwarded untouched to the project helper (presumably batch size and worker
# count -- confirm against helpers.get_data_loader).
loaders = get_data_loader(
    1024,
    2,
    "datasets/cifar-100/cifar-100-python",
    download=True,
)
train_loader, test_loader = loaders

In [None]:
# Hyper-parameters of the small ViT used throughout this notebook, gathered in
# one mapping so the configuration can be read at a glance.
vit_config = {
    "image_size": 32,
    "patch_size": 4,
    "num_classes": 100,
    "dim": 128,
    "depth": 3,
    "heads": 3,
    "mlp_dim": 256,
    "dropout": 0.1,
}

# Instantiate the network and place it on the CPU.
model = ViT(**vit_config).to("cpu")

In [None]:
# Train on the GPU when one is present; otherwise fall back to the CPU so the
# cell still runs on CPU-only machines (the rest of the notebook is CPU-based).
train_device = "cuda" if torch.cuda.is_available() else "cpu"

# Short training run with structured pruning enabled; all remaining arguments
# are forwarded unchanged to the project-level `train` helper.
train(
    model,
    train_loader,
    test_loader,
    epochs=5,
    learning_rate=0.001,
    device=train_device,
    pruning_method="structured",
    weight_decay=0.0005,
)

In [None]:
# Rebuild a fresh ViT with the same architecture so the saved weights can be
# loaded into it.
teacher_config = {
    "image_size": 32,
    "patch_size": 4,
    "num_classes": 100,
    "dim": 128,
    "depth": 3,
    "heads": 3,
    "mlp_dim": 256,
    "dropout": 0.1,
}
model = ViT(**teacher_config).to("cpu")

# Directory holding the saved teacher model, and the checkpoint file inside it.
model_save_path = "save_model/cifar-100/vit_16_teacher_cifar-100"
model_load_path = model_save_path + "/best_model.pt"

# Restore the stored state into `model` via the project helper.
load_checkpoint(model, model_load_path)

In [None]:
# Post-training static quantization (eager mode) of the loaded model on CPU.
model.to("cpu")
model.eval()

# Select the quantized kernel backend and the matching default qconfig.
backend = "x86"
torch.backends.quantized.engine = backend
model.qconfig = torch.quantization.get_default_qconfig(backend)

# Step 1: insert observers into a copy of the model.
model_prepared = torch.quantization.prepare(model, inplace=False)

# Step 2: calibrate. The observers must see representative activations before
# `convert`, otherwise the quantization parameters are not derived from data.
# A handful of batches is enough to collect activation statistics.
# NOTE(review): this uses `test_loader` because its batch layout (images at
# index 0) is what the timing cell relies on; prefer a held-out slice of the
# training data if one is available.
calibration_batches = 4
with torch.no_grad():
    for batch_idx, calibration_batch in enumerate(test_loader):
        if batch_idx >= calibration_batches:
            break
        model_prepared(calibration_batch[0].to("cpu"))

# Step 3: replace the observed modules with their quantized counterparts.
model_static_quantized = torch.quantization.convert(
    model_prepared, inplace=False
)
model_static_quantized.to("cpu")

In [None]:
# Materialise every test batch up front so that the timing loop below measures
# model inference only, not data loading.
batch_array = list(test_loader)

In [None]:
# Measure the total wall-clock time of 20 full passes over the cached test
# batches with the statically quantized model.
inference_time_pruned = 0.0
with torch.no_grad():
    for _ in range(20):
        # Restart the timer for every pass: accumulating against a single
        # start time taken before the loop would count earlier passes again
        # on each iteration and inflate the total.
        pass_start = time.perf_counter()
        for image_batch in batch_array:
            test_batch(model_static_quantized, image_batch[0], device="cpu")
        inference_time_pruned += time.perf_counter() - pass_start
print(f"model_quantized_time: {inference_time_pruned} ")