# / Configuration/Setup Section \

In [1]:
############################################
# 1. Imports & Setup
############################################

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models.resnet import Bottleneck
import time
import gc
from copy import deepcopy
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score
from google.colab import drive

# Attach Google Drive so the checkpoint paths under /content/drive resolve.
drive.mount('/content/drive')
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")



############################################
# 2. Model Definitions
############################################

# RESNET101
class BasicBlock(nn.Module):
    """Two-convolution 3x3 residual block used by ResNet18 (CIFAR variant).

    conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN, plus a shortcut, then ReLU.
    The shortcut is the identity unless the stride or the channel count changes,
    in which case a 1x1 conv + BatchNorm projection makes the shapes match.
    Attribute names (conv1, bn1, conv2, bn2, shortcut) become the state_dict keys
    that pretrained weights are loaded into.
    """
    expansion = 1  # output channels = expansion * planes
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        # Only the first conv carries the stride (spatial downsampling).
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(planes)
        # An empty Sequential acts as the identity shortcut.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            # Projection shortcut when input and output shapes differ.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # residual add happens before the final ReLU
        out = F.relu(out)
        return out

# NOTE: this class shadows the `Bottleneck` imported from torchvision.models.resnet
# at the top of the cell; ResNet101 below is built from this definition.
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block used by ResNet101.

    The block narrows to ``planes`` channels and then widens to
    ``expansion * planes`` (4x) output channels. The 3x3 conv carries the stride.
    A 1x1 conv + BatchNorm projection replaces the identity shortcut when the
    stride or the channel count changes.
    """
    expansion = 4  # output channels = expansion * planes
    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        # An empty Sequential acts as the identity shortcut.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            # Projection shortcut when input and output shapes differ.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))  # no ReLU until after the residual add
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stride-1 stem, four residual stages, linear head.

    Args:
        block: residual block class (BasicBlock or Bottleneck); its ``expansion``
            attribute sets each stage's output width.
        num_blocks: number of blocks in each of the four stages.
        num_classes: size of the final linear layer.
    """
    def __init__(self, block, num_blocks, num_classes):
        super(ResNet, self).__init__()
        # Running channel count; _make_layer reads and advances it, so the stages
        # must be built in order.
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64,  num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)
    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        strides = [stride] + [1]*(num_blocks - 1)
        layers = []
        for st in strides:
            layers.append(block(self.in_planes, planes, st))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Stage strides 1,2,2,2 turn a 32x32 input into a 4x4 map, so this is a
        # global average pool for CIFAR-sized inputs (other sizes change the
        # flattened width expected by self.linear).
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
def ResNet18(num_classes):
    """Return a CIFAR-style ResNet-18 (BasicBlock, two blocks per stage).

    Only needed by load_resnet_for_dataset('cifar10'); remove if that path is unused.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, num_classes)
def ResNet101(num_classes):
    """Return a CIFAR-style ResNet-101 (Bottleneck blocks, 3/4/23/3 per stage)."""
    blocks_per_stage = [3, 4, 23, 3]
    return ResNet(Bottleneck, blocks_per_stage, num_classes)

# Pretrained ResNet weight files on the mounted Google Drive.
# The ResNet18/CIFAR10 file is only needed by load_resnet_for_dataset('cifar10').
# NOTE(review): the model cells load checkpoints from "/content/drive/MyDrive/...";
# confirm that this "My Drive" directory actually exists on the mount.
resnet18_weights_path, resnet101_weights_path = (
    "/content/drive/My Drive/Models/resnet18-cifar10.pth",
    "/content/drive/My Drive/Models/resnet101-cifar100.pth",
)

def load_resnet_for_dataset(dataset_name):
    """Build the ResNet matching ``dataset_name`` and load its pretrained weights.

    'cifar10' (case-insensitive) -> ResNet18 with 10 classes, weights from
    resnet18_weights_path; 'cifar100' -> ResNet101 with 100 classes, weights from
    resnet101_weights_path. The model is moved to ``device`` and the weights are
    mapped onto it.

    Raises:
        ValueError: for any other dataset name.
    """
    key = dataset_name.lower()
    if key == 'cifar10':
        model = ResNet18(num_classes=10).to(device)
        weights_file = resnet18_weights_path
    elif key == 'cifar100':
        model = ResNet101(num_classes=100).to(device)
        weights_file = resnet101_weights_path
    else:
        raise ValueError("Unsupported dataset for ResNet model")
    model.load_state_dict(torch.load(weights_file, map_location=device))
    return model

# MOBILENETV2
class MobileNetV2_Block(nn.Module):
    """MobileNetV2 inverted-residual block.

    1x1 expansion conv (in_planes -> expansion * in_planes) -> BN -> ReLU,
    3x3 depthwise conv (groups == channels, carries the stride) -> BN -> ReLU,
    1x1 projection conv -> BN, with no activation afterwards.
    A residual is added only when stride == 1: the identity if the channel count
    is unchanged, otherwise a 1x1 conv + BN projection.
    """
    def __init__(self, in_planes, out_planes, expansion, stride):
        super(MobileNetV2_Block, self).__init__()
        self.stride = stride  # read in forward() to decide whether to add the residual
        planes = expansion * in_planes  # width of the expanded hidden representation
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1   = nn.BatchNorm2d(planes)
        # Depthwise: one 3x3 filter per channel.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1,
                               groups=planes, bias=False)
        self.bn2   = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3   = nn.BatchNorm2d(out_planes)
        # An empty Sequential acts as the identity shortcut.
        self.shortcut = nn.Sequential()
        if stride == 1 and in_planes != out_planes:
            # Channel-matching projection; stride-2 blocks get no residual at all.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_planes)
            )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))  # linear bottleneck: no ReLU here
        if self.stride == 1:
            out = out + self.shortcut(x)
        return out

class MobileNetV2(nn.Module):
    """MobileNetV2 adapted to 32x32 CIFAR inputs (stride-1 stem and first stages).

    Args:
        num_classes: output size of the final linear layer (10 for CIFAR10,
            100 for CIFAR100, as returned by get_dataloaders).
    """
    # Each entry: (expansion, out_planes, num_blocks, stride of the first block).
    cfg = [(1,  16, 1, 1),
           (6,  24, 2, 1),  # stride 1 instead of 2 so small CIFAR images keep resolution
           (6,  32, 3, 2),
           (6,  64, 4, 2),
           (6,  96, 3, 1),
           (6, 160, 3, 2),
           (6, 320, 1, 1)]
    def __init__(self, num_classes): # num_classes is passed in rather than hard-coded to 10 or 100
        super(MobileNetV2, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2   = nn.BatchNorm2d(1280)
        self.linear = nn.Linear(1280, num_classes)
    def _make_layers(self, in_planes):
        """Expand ``cfg`` into a flat Sequential of MobileNetV2_Block modules."""
        layers = []
        for expansion, out_planes, num_blocks, stride in self.cfg:
            # Only the first block of each group uses the configured stride.
            strides = [stride] + [1]*(num_blocks - 1)
            for s in strides:
                layers.append(MobileNetV2_Block(in_planes, out_planes, expansion, s))
                in_planes = out_planes
        return nn.Sequential(*layers)
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = F.relu(self.bn2(self.conv2(out)))
        # Total stride is 8 (three stride-2 groups), so a 32x32 input ends at 4x4
        # and this pool is global; other input sizes change the flattened width.
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

# VGG16
# Layer plan consumed by VGG._make_layers: an int is a 3x3 conv with that many
# output channels, 'M' is a 2x2 max-pool.
vgg_cfg = dict(
    VGG16=[64, 64, 'M',
           128, 128, 'M',
           256, 256, 256, 'M',
           512, 512, 512, 'M',
           512, 512, 512, 'M'],
)
class VGG(nn.Module):
    """VGG with BatchNorm for 32x32 inputs, built from ``vgg_cfg[vgg_name]``.

    The five 'M' max-pools in VGG16 reduce a 32x32 input to 1x1x512, which is
    why the classifier's first layer takes 512 features (assumes 32x32 inputs).

    Args:
        num_classes: output size of the last classifier layer.
        vgg_name: key into ``vgg_cfg``.
    """
    def __init__(self, num_classes, vgg_name="VGG16"): # num_classes passed in, as with MobileNetV2
        super(VGG, self).__init__()
        self.features = self._make_layers(vgg_cfg[vgg_name])
        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=False),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=False),
            nn.Dropout(),
            nn.Linear(4096, num_classes)
        )
    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out
    def _make_layers(self, cfg):
        """Translate a cfg list into Conv-BN-ReLU triples and 2x2 max-pools."""
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=False)]
                in_channels = x
        # 1x1 window, stride 1: an identity operation on the feature map.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)



############################################
# 3. Mask-Based Pruning Engine
############################################
class pruning_engine_base:
    """Low-level masking helpers shared by the pruning engines.

    Pruned entries are not removed physically; they are overwritten in place
    with ``mask_number`` (0.0), so every tensor keeps its shape.
    """

    def __init__(self, pruning_ratio, pruning_method):
        # Stored as the fraction to KEEP (1 - requested pruning ratio).
        self.pruning_ratio = 1 - pruning_ratio
        self.pruning_method = pruning_method
        self.mask_number = 0.0
        self.device = device

    def base_remove_filter_by_index(self, weight, remove_filter_idx,
                                    bias=None, mean=None, var=None, linear=False):
        """Overwrite entry ``[i]`` of each given tensor for every index tensor ``i``.

        Returns ``(weight, bias, mean, var)`` when ``mean`` is given,
        ``(weight, bias)`` when only ``bias`` is given, else just ``weight``.
        The tensors are modified in place; ``linear`` is unused.
        """
        if mean is not None:
            targets = (weight, bias, mean, var)
        elif bias is not None:
            targets = (weight, bias)
        else:
            targets = (weight,)
        with torch.no_grad():
            for index_tensor in remove_filter_idx:
                row = index_tensor.item()
                for tensor in targets:
                    tensor[row] = self.mask_number
        return targets if len(targets) > 1 else weight

    def base_remove_kernel_by_index(self, weight, remove_filter_idx, linear=False):
        """Overwrite column ``[:, i]`` of ``weight`` in place for every index; ``linear`` is unused."""
        with torch.no_grad():
            for index_tensor in remove_filter_idx:
                column = index_tensor.item()
                weight[:, column] = self.mask_number
        return weight

class L1norm:
    """L1-norm importance criterion used by pruning_engine."""

    def L1norm_pruning(self, layer):
        """Rank the output units of ``layer`` by L1 norm, most important first.

        The importance of output unit ``i`` is ``sum(|layer.weight[i, ...]|)``,
        taken over every dimension except the first. This is the same dimension
        that pruning_engine_base.base_remove_filter_by_index masks
        (``weight[i] = 0``), for conv (4-D) and linear (2-D) weights alike.
        Summing a 2-D weight over dim 0 would instead score input columns. Those
        indices would mask the wrong rows, or fall out of range when
        in_features > out_features.

        Returns:
            1-D index tensor into dim 0 of the weight, sorted by decreasing L1 norm.
        """
        weight = layer.weight.data.clone()
        magnitude = torch.abs(weight)
        if magnitude.dim() == 4:
            # Conv filters: reduce over (in_channels, kH, kW).
            importance = torch.sum(magnitude, dim=(1, 2, 3))
        elif magnitude.dim() >= 2:
            # Linear (and other N-D) weights: reduce over every non-output dim.
            importance = magnitude.flatten(start_dim=1).sum(dim=1)
        else:
            # 1-D weights: each element is its own unit.
            importance = magnitude
        _, sorted_idx = torch.sort(importance, dim=0, descending=True)
        return sorted_idx


class pruning_engine(pruning_engine_base):
    """Mask-based filter pruning for conv layers, ranked by L1 norm.

    Typical use per layer::

        engine.set_layer(layer, main_layer=True)       # rank filters, pick the ones to drop
        new_layer = engine.remove_conv_filter_kernel(conv_name=name, model=model)
        set_module_by_name(model, name, new_layer)

    Pruned filters are zeroed rather than removed, so shapes are unchanged.
    ``remove_conv_filter_kernel`` masks a deep copy of the layer, while the
    paired BatchNorm (looked up through ``conv_to_bn_map``) is masked in place.

    NOTE: the returned layer is a copy with new Parameter objects. An optimizer
    constructed over ``model.parameters()`` before the swap still references the
    replaced parameters, so build (or rebuild) the optimizer after pruning.
    """

    def __init__(self, pruning_method, pruning_ratio=0.0, individual=False,
                 conv_to_bn_map=None):
        super().__init__(pruning_ratio, pruning_method)
        # conv module name -> name of the BatchNorm that normalizes its output.
        self.conv_to_bn_map = conv_to_bn_map if conv_to_bn_map else {}
        self.remove_filter_idx_history = {"previous_layer": None, "current_layer": None}
        # individual=True resets the history each time a new main layer is set.
        self.individual = individual
        self.l1norm_pruning = L1norm()
        # Only the L1-norm criterion exists; pruning_method is stored but not consulted.
        self.pruning_criterion = self.l1norm_pruning.L1norm_pruning

    def set_layer(self, layer, main_layer=False):
        """Snapshot ``layer`` and, for a main layer, choose the filters to prune.

        The criterion orders filters from most to least important. The first
        ``int(n * self.pruning_ratio)`` are kept (``self.pruning_ratio`` holds the
        keep fraction); the rest are stored in ``self.remove_filter_idx`` and as
        ``remove_filter_idx_history["current_layer"]``. Always returns True.
        """
        self.copy_layer = deepcopy(layer)
        if main_layer:
            if self.individual:
                self.remove_filter_idx_history = {"previous_layer": None, "current_layer": None}
            self.remove_filter_idx_history["previous_layer"] = self.remove_filter_idx_history["current_layer"]
            self.remove_filter_idx_history["current_layer"] = None
            remove_filter_idx = self.pruning_criterion(self.copy_layer)
            # Despite its name, num_prune is the number of filters KEPT; the tail is pruned.
            num_prune = int(len(remove_filter_idx) * self.pruning_ratio)
            self.remove_filter_idx = remove_filter_idx[num_prune:]
            if self.remove_filter_idx_history["previous_layer"] is None:
                self.remove_filter_idx_history["previous_layer"] = self.remove_filter_idx
            self.remove_filter_idx_history["current_layer"] = self.remove_filter_idx
        return True

    def remove_conv_filter_kernel(self, conv_name=None, model=None):
        """Zero the selected output filters (and bias entries) of the copied layer.

        When ``conv_name`` is given, ``model`` is not None and ``conv_name`` has an
        entry in ``conv_to_bn_map``, the matching BatchNorm in ``model`` is also
        masked in place. Returns the masked copy for the caller to install.
        """
        current_idx = self.remove_filter_idx_history["current_layer"]
        if self.copy_layer.bias is not None:
            w, b = self.base_remove_filter_by_index(
                weight=self.copy_layer.weight.data,
                remove_filter_idx=current_idx,
                bias=self.copy_layer.bias.data
            )
            self.copy_layer.weight.data = w
            self.copy_layer.bias.data = b
        else:
            w = self.base_remove_filter_by_index(
                weight=self.copy_layer.weight.data,
                remove_filter_idx=current_idx
            )
            self.copy_layer.weight.data = w
        # Explicit None check: nn.Module containers such as nn.Sequential define
        # __len__, so an empty one would be falsy.
        if conv_name and model is not None and conv_name in self.conv_to_bn_map:
            bn_name = self.conv_to_bn_map[conv_name]
            bn_layer = get_module_by_name(model, bn_name)
            self.remove_bn_layer(bn_layer, current_idx)
        return self.copy_layer

    def remove_bn_layer(self, bn_layer, remove_filter_idx):
        """Zero weight, bias, running_mean and running_var of ``bn_layer`` at the given channels, in place."""
        w, b, m, v = self.base_remove_filter_by_index(
            weight=bn_layer.weight.data,
            remove_filter_idx=remove_filter_idx,
            bias=bn_layer.bias.data,
            mean=bn_layer.running_mean.data,
            var=bn_layer.running_var.data
        )
        bn_layer.weight.data = w
        bn_layer.bias.data = b
        bn_layer.running_mean.data = m
        bn_layer.running_var.data = v



############################################
# 4. Training & Evaluation Functions
############################################
def train_model(model, dataloader, optimizer, criterion, num_epochs=1):
    """Train ``model`` in place for ``num_epochs`` passes over ``dataloader``.

    After each epoch, prints the mean per-batch loss and the epoch's wall-clock
    time. Returns the same model object.
    """
    model.train()
    for epoch_idx in range(num_epochs):
        loss_sum = 0.0
        tic = time.time()
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_inputs), batch_labels)
            batch_loss.backward()
            optimizer.step()
            loss_sum += batch_loss.item()
        mean_loss = loss_sum / len(dataloader)
        print(f"Epoch [{epoch_idx+1}/{num_epochs}] Loss: {mean_loss:.4f} Time: {time.time()-tic:.2f}s")
    return model

def evaluate_model(model, dataloader, topk=(1, 5)):
    """Evaluate ``model`` on ``dataloader``; print and return classification metrics.

    Args:
        model: classifier whose outputs are ranked along dim 1 (one score per class).
        dataloader: iterable of ``(inputs, labels)`` batches.
        topk: the largest value ``k`` selects the reported top-k accuracy
            (default 5). ``k`` is clamped to the number of model outputs, so models
            with fewer than ``k`` classes do not make ``Tensor.topk`` fail.

    Returns:
        ``(acc, top1_acc, topk_acc, precision, recall, f1)``. Accuracies are
        percentages; precision/recall/F1 are sklearn macro averages with
        ``zero_division=0``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    k_requested = max(topk)
    k_used = k_requested
    total = 0
    correct = 0
    top1_correct = 0
    topk_correct = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total += labels.size(0)

            # Arg-max prediction drives plain accuracy and the precision/recall/F1 inputs.
            predicted = outputs.argmax(dim=1)
            correct += predicted.eq(labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Top-k accuracy; k cannot exceed the number of classes.
            k_used = min(k_requested, outputs.size(1))
            _, pred = outputs.topk(k_used, dim=1, largest=True, sorted=True)
            top1_correct += pred[:, 0].eq(labels).sum().item()
            topk_correct += pred.eq(labels.view(-1, 1)).any(dim=1).sum().item()

    if total == 0:
        raise ValueError("evaluate_model received an empty dataloader")

    acc = 100.0 * correct / total
    top1_acc = 100.0 * top1_correct / total
    topk_acc = 100.0 * topk_correct / total
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

    print("\nEvaluation Metrics:")
    print(f"Accuracy: {acc:.2f}%")
    print(f"Top-1 Accuracy: {top1_acc:.2f}%")
    print(f"Top-{k_used} Accuracy: {topk_acc:.2f}%")
    print(f"Precision (macro): {precision:.4f}")
    print(f"Recall (macro):    {recall:.4f}")
    print(f"F1 Score (macro):  {f1:.4f}")

    return acc, top1_acc, topk_acc, precision, recall, f1

def get_module_by_name(model, access_string):
    """Resolve a dotted attribute path such as ``"layers.0.conv1"`` starting at ``model``."""
    head, sep, rest = access_string.partition('.')
    child = getattr(model, head)
    return get_module_by_name(child, rest) if sep else child

def set_module_by_name(model, access_string, new_module):
    """Install ``new_module`` at the dotted attribute path ``access_string`` under ``model``."""
    parent_path, sep, leaf = access_string.rpartition('.')
    parent = model
    if sep:
        for attr in parent_path.split('.'):
            parent = getattr(parent, attr)
    setattr(parent, leaf, new_module)



######################################################################################
# 5. More Evaluation Metrics (Compression Ratio, Inference Speed)
######################################################################################
def count_nonzero_params(model):
    """Print and return parameter-level sparsity statistics for ``model``.

    Returns:
        ``(compression_ratio, sparsity)``. compression_ratio is
        total / nonzero parameter count (``inf`` when every parameter is zero);
        sparsity is the percentage of parameters that are exactly zero.
    """
    total_params = 0
    nonzero_params = 0
    for param in model.parameters():
        total_params += param.numel()
        nonzero_params += torch.count_nonzero(param).item()

    compression_ratio = total_params / nonzero_params if nonzero_params else float('inf')
    sparsity = (1 - (nonzero_params / total_params)) * 100

    print(f"Total Parameters: {total_params}")
    print(f"Nonzero Parameters: {nonzero_params}")
    print(f"Compression Ratio: {compression_ratio:.3f}x")
    print(f"Sparsity: {sparsity:.3f}%")

    return compression_ratio, sparsity

def measure_inference_speed(model, testloader, device, num_batches=10):
    """Estimate forward-pass latency and throughput of ``model``.

    Times ``model(inputs)`` under ``torch.no_grad()`` for up to ``num_batches``
    batches of ``testloader`` and prints the averages. CUDA kernels launch
    asynchronously, so on a CUDA device the device is synchronized before
    starting and before stopping the clock. Without this, only the kernel-launch
    overhead would be timed. Host-to-device copies are excluded from the timing.

    Returns:
        ``(avg_time_per_sample, avg_fps)``, or ``(0, 0)`` if no samples were processed.
    """
    model.eval()
    dev = torch.device(device)
    use_cuda = dev.type == 'cuda'
    total_time = 0.0
    num_samples = 0

    with torch.no_grad():
        for i, (inputs, _) in enumerate(testloader):
            if i >= num_batches:
                break

            inputs = inputs.to(device)
            if use_cuda:
                torch.cuda.synchronize(dev)  # finish the input copy before timing
            start_time = time.perf_counter()
            _ = model(inputs)
            if use_cuda:
                torch.cuda.synchronize(dev)  # wait for the forward kernels to complete
            total_time += time.perf_counter() - start_time
            num_samples += inputs.size(0)

    if num_samples == 0:
        print("Warning: No samples processed in measure_inference_speed. Returning default values.")
        return 0, 0

    avg_time_per_sample = total_time / num_samples
    # A clock that reports zero elapsed time would otherwise divide by zero.
    avg_fps = num_samples / total_time if total_time > 0 else float('inf')

    print(f"Inference Speed: {avg_time_per_sample:.6f} sec/sample ({avg_fps:.2f} FPS)")
    return avg_time_per_sample, avg_fps


####################################################################
# 6. Example dataloader usage for ResNet, MobileNetV2, and VGG16
####################################################################
def get_dataloaders(dataset, batch_size=128):
    """Build CIFAR-10/100 datasets and DataLoaders (downloading to ./data if needed).

    Training images get a random 32x32 crop with 4-pixel padding and a random
    horizontal flip. Both splits are converted to tensors and normalized with
    per-dataset channel mean/std. The train loader shuffles with ``batch_size``;
    the test loader uses fixed batches of 100 in order.

    Returns:
        ``(trainset, trainloader, testloader, num_classes)``

    Raises:
        ValueError: for anything other than 'cifar10' or 'cifar100'.
    """
    if dataset not in ('cifar10', 'cifar100'):
        raise ValueError("Unsupported dataset name")

    # name -> (torchvision dataset class name, (mean, std), number of classes)
    specs = {
        'cifar10': ('CIFAR10', ((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), 10),
        'cifar100': ('CIFAR100', ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 100),
    }
    class_name, (mean, std), num_classes = specs[dataset]
    dataset_cls = getattr(torchvision.datasets, class_name)

    def build_transform(augment):
        steps = []
        if augment:
            steps += [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
        steps += [transforms.ToTensor(), transforms.Normalize(mean, std)]
        return transforms.Compose(steps)

    trainset = dataset_cls(root='./data', train=True, download=True,
                           transform=build_transform(True))
    testset = dataset_cls(root='./data', train=False, download=True,
                          transform=build_transform(False))

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

    return trainset, trainloader, testloader, num_classes

# Reaching this line means every definition in the setup cell loaded without error.
print("\nAll done!", end="\n")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

All done!


In [2]:
# Load CIFAR10: datasets, loaders and the class count (10) used to size the models.
DATASET = "cifar10"
trainset, trainloader, testloader, num_classes = get_dataloaders(dataset=DATASET)

# / Model Tests \

<hr>

# CIFAR10 PRETRAINED MODELS SECTION
### MobileNetV2

In [3]:
## PRE-TRAINED TEST
# Pretrained MobileNetV2 on CIFAR10: baseline metrics, one-shot pruning, then
# iterative pruning with fine-tuning (applied to the already one-shot-pruned model).
print("\n=== MobileNetV2 Example (CIFAR10) ===")

# 1) Define model architecture with correct num_classes
mobilenet_model = MobileNetV2(num_classes=num_classes).to(device)

# 2) Load pretrained weights from Google Drive (the checkpoint stores a "state_dict" entry)
model_path = "/content/drive/MyDrive/Models/CIFAR10/Model@Mobilenetv2_ACC@95.82.pt"
checkpoint = torch.load(model_path, map_location=device)
mobilenet_model.load_state_dict(checkpoint["state_dict"])

# 3) Optimizer + loss (optional fine-tuning).
# NOTE: pruning installs deep-copied conv modules (new Parameter objects), so the
# optimizer is rebuilt after every pruning pass below; otherwise it would keep
# stepping the replaced parameters and the pruned convs would never be fine-tuned.
optimizer_mb = optim.Adam(mobilenet_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 4) Baseline evaluation
print("MobileNetV2 baseline evaluation:")
(mobileNet_acc_before, mobileNet_top1_acc_before, mobileNet_top5_acc_before,
 mobileNet_precision_before, mobileNet_recall_before, mobileNet_f1score_before
) = evaluate_model(mobilenet_model, testloader)

# 5) Choose the convs to prune. Convs inside the inverted-residual blocks
# ('layers.*', including their 'shortcut' projections) are skipped, so only the
# top-level stem conv1 and the final 1x1 conv2 are pruned.
layer_store_mb = [m for m in mobilenet_model.modules() if isinstance(m, nn.Conv2d)]
layers_to_prune_mb = [
    name for name, module in mobilenet_model.named_modules()
    if isinstance(module, nn.Conv2d) and 'shortcut' not in name and 'layers' not in name
]
# Pair every conv with the first BatchNorm registered after it (named_modules order).
conv_to_bn_map_mb = {}
prev_conv = None
for name, module in mobilenet_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prev_conv = name
    elif isinstance(module, nn.BatchNorm2d) and prev_conv:
        conv_to_bn_map_mb[prev_conv] = name
        prev_conv = None

# 6) Check compression + speed BEFORE pruning
print('\nCompression Details BEFORE Pruning')
mobileNet_compression_ratio_before, mobileNet_sparsity_before = count_nonzero_params(mobilenet_model)

# measure_inference_speed prints its own "Inference Speed" summary line.
print("\nMobileNetV2 inference speed BEFORE pruning:")
mobileNet_avg_time_per_sample_before, mobileNet_avg_fps_before = measure_inference_speed(mobilenet_model, testloader, device)


# --------------------------------------------------------- #
# ONE-SHOT PRUNING                                          #
# --------------------------------------------------------- #
print("\n--- MobileNetV2 One-Shot Pruning (CIFAR10) ---")

prune_engine_mb = pruning_engine(
    pruning_method="L1norm",
    pruning_ratio=0.2,           # 20% of filters are pruned
    conv_to_bn_map=conv_to_bn_map_mb
)

for layer_name in layers_to_prune_mb:
    print(f"Pruning MobileNetV2 layer: {layer_name}")
    orig_layer = get_module_by_name(mobilenet_model, layer_name)
    prune_engine_mb.set_layer(orig_layer, main_layer=True)
    # Passing conv_name/model also zeroes the paired BatchNorm channels in place.
    masked_layer = prune_engine_mb.remove_conv_filter_kernel(conv_name=layer_name, model=mobilenet_model)
    set_module_by_name(mobilenet_model, layer_name, masked_layer)

print("MobileNetV2 evaluation AFTER one-shot Pruning:")
(mobileNet_acc_oneshot, mobileNet_top1_acc_oneshot, mobileNet_top5_acc_oneshot,
 mobileNet_precision_oneshot, mobileNet_recall_oneshot, mobileNet_f1score_oneshot
) = evaluate_model(mobilenet_model, testloader)

# Checking stats after One-Shot
print('\nCompression Details AFTER One-Shot Pruning')
mobileNet_compression_ratio_oneshot, mobileNet_sparsity_oneshot = count_nonzero_params(mobilenet_model)

print("\nMobileNetV2 inference speed AFTER One-Shot pruning:")
mobileNet_avg_time_per_sample_oneshot, mobileNet_avg_fps_oneshot = measure_inference_speed(mobilenet_model, testloader, device)


# --------------------------------------------------------- #
# ITERATIVE PRUNING                                         #
# --------------------------------------------------------- #
# NOTE(review): each pass re-ranks filters by their current L1 norm. Filters that
# are still all-zero rank lowest and are selected first, so a 10% pass may add no
# new sparsity. Check the printed sparsity to confirm.
print("\n--- MobileNetV2 Iterative Pruning (CIFAR10) ---")
num_iter = 2
iter_mask_ratio = 0.1  # 10% pruned each iteration

for it in range(num_iter):
    print(f"\nIteration {it+1} for MobileNetV2")

    prune_engine_iter_mb = pruning_engine(
        pruning_method="L1norm",
        pruning_ratio=iter_mask_ratio,
        conv_to_bn_map=conv_to_bn_map_mb
    )

    for layer_name in layers_to_prune_mb:
        print(f"Pruning MobileNetV2 layer: {layer_name}")
        current_layer = get_module_by_name(mobilenet_model, layer_name)
        prune_engine_iter_mb.set_layer(current_layer, main_layer=True)
        # Also zeroes the paired BatchNorm channels (conv_name/model supplied).
        masked_layer = prune_engine_iter_mb.remove_conv_filter_kernel(conv_name=layer_name, model=mobilenet_model)
        set_module_by_name(mobilenet_model, layer_name, masked_layer)

    # Rebuild the optimizer so it tracks the conv modules that were just installed.
    optimizer_mb = optim.Adam(mobilenet_model.parameters(), lr=0.001)

    # Fine-tune
    print("Fine-tuning MobileNetV2 after iteration...")
    train_model(mobilenet_model, trainloader, optimizer_mb, criterion, num_epochs=2)

    # Evaluate
    mobileNet_acc_iter, mobileNet_top1_acc_iter, mobileNet_top5_acc_iter, \
    mobileNet_precision_iter, mobileNet_recall_iter, mobileNet_f1score_iter = evaluate_model(mobilenet_model, testloader)


# Final metrics after iterative pruning
print('Compression Details AFTER Iterative Pruning')
mobileNet_compression_ratio_iter, mobileNet_sparsity_iter = count_nonzero_params(mobilenet_model)

print("\nMobileNetV2 Model inference speed AFTER Iterative pruning:")
mobileNet_avg_time_per_sample_iter, mobileNet_avg_fps_iter = measure_inference_speed(mobilenet_model, testloader, device)



=== MobileNetV2 Example (CIFAR10) ===
MobileNetV2 baseline evaluation:

Evaluation Metrics:
Accuracy: 95.80%
Top-1 Accuracy: 95.80%
Top-5 Accuracy: 99.78%
Precision (macro): 0.9580
Recall (macro):    0.9580
F1 Score (macro):  0.9579

Compression Details BEFORE Pruning
Total Parameters: 2296922
Nonzero Parameters: 2296922
Compression Ratio: 1.000x
Sparsity: 0.000%

MobileNetV2 inference speed BEFORE pruning:
Inference Speed: 0.000318 sec/sample (3142.96 FPS)
Inference Speed: 0.000318 sec/sample (3142.96 FPS)

--- MobileNetV2 One-Shot Pruning (CIFAR10) ---
Pruning MobileNetV2 layer: conv1
Pruning MobileNetV2 layer: conv2
MobileNetV2 evaluation AFTER one-shot Pruning:

Evaluation Metrics:
Accuracy: 95.50%
Top-1 Accuracy: 95.50%
Top-5 Accuracy: 99.80%
Precision (macro): 0.9550
Recall (macro):    0.9550
F1 Score (macro):  0.9549

Compression Details AFTER One-Shot Pruning
Total Parameters: 2296922
Nonzero Parameters: 2214287
Compression Ratio: 1.037x
Sparsity: 3.598%

MobileNetV2 inference

### VGG16

In [4]:
## PRE-TRAINED TEST
# Pretrained VGG16 on CIFAR10: baseline metrics, one-shot pruning, then
# iterative pruning with fine-tuning (applied to the already one-shot-pruned model).
print("\n=== VGG16 Example (CIFAR10) ===")

# 1) Define VGG16 with correct num_classes
vgg_model = VGG(num_classes=num_classes, vgg_name="VGG16").to(device)

# 2) Load pretrained weights (the checkpoint stores a "state_dict" entry)
model_path = "/content/drive/MyDrive/Models/CIFAR10/Model@VGG16_ACC@95.26.pt"
checkpoint = torch.load(model_path, map_location=device)
vgg_model.load_state_dict(checkpoint["state_dict"])

# 3) Set model to evaluation mode (evaluate_model also does this)
vgg_model.eval()

# 4) Optimizer + loss.
# NOTE: every conv is replaced by a deep copy during pruning, so the optimizer is
# rebuilt after each pruning pass below. Otherwise this optimizer would only ever
# update the BatchNorm and classifier parameters.
optimizer_vgg = optim.Adam(vgg_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 5) Baseline evaluation
print("VGG16 baseline evaluation:")
vgg_acc_before, vgg_top1_acc_before, vgg_top5_acc_before, vgg_precision_before, vgg_recall_before, vgg_f1score_before = evaluate_model(vgg_model, testloader)

# Every Conv2d layer (including the first) is pruned.
layer_store_vgg = [m for m in vgg_model.modules() if isinstance(m, nn.Conv2d)]
layers_to_prune_vgg = [
    name for name, module in vgg_model.named_modules()
    if isinstance(module, nn.Conv2d)
]

# Pair every conv with the first BatchNorm registered after it (named_modules order).
conv_to_bn_map_vgg = {}
prev_conv = None
for name, module in vgg_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prev_conv = name
    elif isinstance(module, nn.BatchNorm2d) and prev_conv:
        conv_to_bn_map_vgg[prev_conv] = name
        prev_conv = None

# --- Before pruning stats ---
print('\nCompression Details BEFORE Pruning')
vgg_compression_ratio_before, vgg_sparsity_before = count_nonzero_params(vgg_model)

# measure_inference_speed prints its own "Inference Speed" summary line.
print("\nVGG16 inference speed BEFORE pruning:")
vgg_avg_time_per_sample_before, vgg_avg_fps_before = measure_inference_speed(vgg_model, testloader, device)

# --------------------------------------------------------- #
# ONE-SHOT PRUNING                                          #
# --------------------------------------------------------- #
print("\n--- VGG16 One-Shot Pruning (CIFAR10) ---")

prune_engine_vgg = pruning_engine(
    pruning_method="L1norm",
    pruning_ratio=0.2,  # 20% pruned
    conv_to_bn_map=conv_to_bn_map_vgg
)

for layer_name in layers_to_prune_vgg:
    print(f"Pruning VGG16 layer: {layer_name}")
    orig_layer = get_module_by_name(vgg_model, layer_name)
    prune_engine_vgg.set_layer(orig_layer, main_layer=True)
    # Passing conv_name/model also zeroes the paired BatchNorm channels in place.
    masked_layer = prune_engine_vgg.remove_conv_filter_kernel(conv_name=layer_name, model=vgg_model)
    set_module_by_name(vgg_model, layer_name, masked_layer)

print("VGG16 evaluation after one-shot Pruning:")
vgg_acc_oneshot, vgg_top1_acc_oneshot, vgg_top5_acc_oneshot, vgg_prec_oneshot, vgg_rec_oneshot, vgg_f1_oneshot = evaluate_model(vgg_model, testloader)

# After one-shot stats
print('\nCompression Details AFTER One-Shot Pruning')
vgg_compression_ratio_oneshot, vgg_sparsity_oneshot = count_nonzero_params(vgg_model)

print("\nVGG16 inference speed AFTER One-Shot pruning:")
vgg_avg_time_per_sample_oneshot, vgg_avg_fps_oneshot = measure_inference_speed(vgg_model, testloader, device)


# --------------------------------------------------------- #
# ITERATIVE PRUNING                                         #
# --------------------------------------------------------- #
# NOTE(review): each pass re-ranks filters by their current L1 norm. Filters that
# are still all-zero rank lowest and are selected first, so a 10% pass may add no
# new sparsity. Check the printed sparsity to confirm.
print("\n--- VGG16 Iterative Pruning (CIFAR10) ---")
num_iter = 2
iter_mask_ratio = 0.1  # 10% each iteration

for it in range(num_iter):
    print(f"\nIteration {it+1} for VGG16")

    prune_engine_iter_vgg = pruning_engine(
        pruning_method="L1norm",
        pruning_ratio=iter_mask_ratio,
        conv_to_bn_map=conv_to_bn_map_vgg
    )

    for layer_name in layers_to_prune_vgg:
        print(f"Pruning VGG16 layer: {layer_name}")
        current_layer = get_module_by_name(vgg_model, layer_name)
        prune_engine_iter_vgg.set_layer(current_layer, main_layer=True)
        masked_layer = prune_engine_iter_vgg.remove_conv_filter_kernel(conv_name=layer_name, model=vgg_model)
        set_module_by_name(vgg_model, layer_name, masked_layer)

    # Rebuild the optimizer so it tracks the conv modules that were just installed.
    optimizer_vgg = optim.Adam(vgg_model.parameters(), lr=0.001)

    print("Fine-tuning VGG16 after iteration...")
    vgg_model = train_model(vgg_model, trainloader, optimizer_vgg, criterion, num_epochs=2)
    # Keep the latest iteration's metrics, mirroring the MobileNetV2 cell.
    (vgg_acc_iter, vgg_top1_acc_iter, vgg_top5_acc_iter,
     vgg_precision_iter, vgg_recall_iter, vgg_f1score_iter) = evaluate_model(vgg_model, testloader)

# Final stats after iterative pruning
print('Compression Details After Iterative Pruning')
vgg_compression_ratio_iter, vgg_sparsity_iter = count_nonzero_params(vgg_model)

print("\nVGG16 inference speed AFTER Iterative pruning:")
vgg_avg_time_per_sample_iter, vgg_avg_fps_iter = measure_inference_speed(vgg_model, testloader, device)



=== VGG16 Example (CIFAR10) ===
VGG16 baseline evaluation:

Evaluation Metrics:
Accuracy: 95.43%
Top-1 Accuracy: 95.43%
Top-5 Accuracy: 99.85%
Precision (macro): 0.9544
Recall (macro):    0.9543
F1 Score (macro):  0.9542

Compression Details BEFORE Pruning
Total Parameters: 33646666
Nonzero Parameters: 33646666
Compression Ratio: 1.000x
Sparsity: 0.000%

VGG16 inference speed BEFORE pruning:
Inference Speed: 0.000061 sec/sample (16442.64 FPS)
Inference Speed: 0.000061 sec/sample (16442.64 FPS)

--- VGG16 One-Shot Pruning (CIFAR10) ---
Pruning VGG16 layer: features.0
Pruning VGG16 layer: features.3
Pruning VGG16 layer: features.7
Pruning VGG16 layer: features.10
Pruning VGG16 layer: features.14
Pruning VGG16 layer: features.17
Pruning VGG16 layer: features.20
Pruning VGG16 layer: features.24
Pruning VGG16 layer: features.27
Pruning VGG16 layer: features.30
Pruning VGG16 layer: features.34
Pruning VGG16 layer: features.37
Pruning VGG16 layer: features.40
VGG16 evaluation after one-shot 

<hr>

# CIFAR100 PRETRAINED MODELS SECTION


In [5]:
# Loading CIFAR100: every cell below reuses these loaders and the class count.
DATASET = "cifar100"
(trainset, trainloader,
 testloader, num_classes) = get_dataloaders(DATASET)

### ResNet101

In [6]:
print("\n=== ResNet101 Example (CIFAR-100) ===")

# 1) Define model
resnet_model = ResNet101(num_classes=num_classes).to(device)

# 2) Load pretrained weights
model_path = "/content/drive/My Drive/Models/CIFAR100/Model@ResNet101_ACC@83.41.pt"
checkpoint = torch.load(model_path, map_location=device)
resnet_model.load_state_dict(checkpoint)
resnet_model.eval()

# 3) Optimizer + Loss
optimizer_res = optim.Adam(resnet_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# === Baseline Evaluation === #
print("ResNet101 baseline evaluation:")
res_acc_before, res_top1_before, res_top5_before, res_precision_before, res_recall_before, res_f1_before = evaluate_model(resnet_model, testloader)

print("\nCompression Details BEFORE Pruning:")
res_compression_before, res_sparsity_before = count_nonzero_params(resnet_model)

print("\nResNet101 inference speed BEFORE pruning:")
res_time_before, res_fps_before = measure_inference_speed(resnet_model, testloader, device)
print(f"Inference Speed: {res_time_before:.6f} sec/sample ({res_fps_before:.2f} FPS)")

# 4) Setup for Pruning
# Every Conv2d except the residual-projection ('shortcut') convolutions.
layers_to_prune_res = [
    name for name, module in resnet_model.named_modules()
    if isinstance(module, nn.Conv2d) and 'shortcut' not in name
]

# Pair each Conv2d with the first BatchNorm2d that follows it in named_modules() order.
conv_to_bn_map_res = {}
prev_conv = None
for name, module in resnet_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prev_conv = name
    elif isinstance(module, nn.BatchNorm2d) and prev_conv:
        conv_to_bn_map_res[prev_conv] = name
        prev_conv = None

# Keep a CPU copy of the unpruned model. The iterative run below restarts from it,
# so one-shot (20%) and iterative (2 x 10%) pruning are compared from the same
# baseline instead of iteratively pruning an already one-shot-pruned model.
resnet_baseline_model = deepcopy(resnet_model).cpu()

# --------------------------------------------------------- #
# ONE-SHOT PRUNING                                          #
# --------------------------------------------------------- #
print("\n--- ResNet101 One-Shot Pruning (CIFAR-100) ---")
prune_engine_res = pruning_engine(
    pruning_method="L1norm",
    pruning_ratio=0.2,
    conv_to_bn_map=conv_to_bn_map_res
)

for layer_name in layers_to_prune_res:
    print(f"Pruning layer: {layer_name}")
    orig_layer = get_module_by_name(resnet_model, layer_name)
    prune_engine_res.set_layer(orig_layer, main_layer=True)
    masked_layer = prune_engine_res.remove_conv_filter_kernel(conv_name=layer_name, model=resnet_model)
    set_module_by_name(resnet_model, layer_name, masked_layer)

print("ResNet101 evaluation AFTER One-Shot Pruning:")
res_acc_oneshot, res_top1_oneshot, res_top5_oneshot, res_precision_oneshot, res_recall_oneshot, res_f1_oneshot = evaluate_model(resnet_model, testloader)

print("\nCompression Details AFTER One-Shot Pruning:")
res_compression_oneshot, res_sparsity_oneshot = count_nonzero_params(resnet_model)

print("\nResNet101 inference speed AFTER One-Shot Pruning:")
res_time_oneshot, res_fps_oneshot = measure_inference_speed(resnet_model, testloader, device)
print(f"Inference Speed: {res_time_oneshot:.6f} sec/sample ({res_fps_oneshot:.2f} FPS)")

# --------------------------------------------------------- #
# ITERATIVE PRUNING                                         #
# --------------------------------------------------------- #
print("\n--- ResNet101 Iterative Pruning (CIFAR-100) ---")

# Restart from the unpruned snapshot taken before one-shot pruning.
resnet_model = resnet_baseline_model.to(device)
del resnet_baseline_model

num_iter = 2
iter_mask_ratio = 0.1

for it in range(num_iter):
    print(f"\nIteration {it + 1}")

    prune_engine_iter = pruning_engine(
        pruning_method="L1norm",
        pruning_ratio=iter_mask_ratio,
        conv_to_bn_map=conv_to_bn_map_res
    )

    for layer_name in layers_to_prune_res:
        print(f"Pruning layer: {layer_name}")
        current_layer = get_module_by_name(resnet_model, layer_name)
        prune_engine_iter.set_layer(current_layer, main_layer=True)
        masked_layer = prune_engine_iter.remove_conv_filter_kernel(conv_name=layer_name, model=resnet_model)
        set_module_by_name(resnet_model, layer_name, masked_layer)

    # resnet_model was restored from a copy, and set_module_by_name may install new
    # module objects. The old optimizer would then step parameters that are no
    # longer in the model. Rebuild it with the same optimizer class/hyperparameters.
    optimizer_res = type(optimizer_res)(resnet_model.parameters(), **optimizer_res.defaults)

    # NOTE(review): the outputs show total params unchanged while nonzero params
    # drop, i.e. pruning zeroes weights. Unless pruning_engine/train_model re-apply
    # the masks, fine-tuning can regrow pruned weights -- confirm.
    print("Fine-tuning ResNet101 AFTER pruning iteration...")
    resnet_model = train_model(resnet_model, trainloader, optimizer_res, criterion, num_epochs=2)
    res_acc_iter, res_top1_iter, res_top5_iter, res_precision_iter, res_recall_iter, res_f1_iter = evaluate_model(resnet_model, testloader)

# Final compression + speed
print("\nCompression Details AFTER Iterative Pruning:")
res_compression_iter, res_sparsity_iter = count_nonzero_params(resnet_model)

print("\nResNet101 inference speed AFTER Iterative Pruning:")
res_time_iter, res_fps_iter = measure_inference_speed(resnet_model, testloader, device)
print(f"Inference Speed: {res_time_iter:.6f} sec/sample ({res_fps_iter:.2f} FPS)")

print("\nAll done!")



=== ResNet101 Example (CIFAR-100) ===
ResNet101 baseline evaluation:

Evaluation Metrics:
Accuracy: 83.38%
Top-1 Accuracy: 83.38%
Top-5 Accuracy: 96.20%
Precision (macro): 0.8353
Recall (macro):    0.8338
F1 Score (macro):  0.8336

Compression Details BEFORE Pruning:
Total Parameters: 42697380
Nonzero Parameters: 42697380
Compression Ratio: 1.000x
Sparsity: 0.000%

ResNet101 inference speed BEFORE pruning:
Inference Speed: 0.000190 sec/sample (5261.21 FPS)
Inference Speed: 0.000190 sec/sample (5261.21 FPS)

--- ResNet101 One-Shot Pruning (CIFAR-100) ---
Pruning layer: conv1
Pruning layer: layer1.0.conv1
Pruning layer: layer1.0.conv2
Pruning layer: layer1.0.conv3
Pruning layer: layer1.1.conv1
Pruning layer: layer1.1.conv2
Pruning layer: layer1.1.conv3
Pruning layer: layer1.2.conv1
Pruning layer: layer1.2.conv2
Pruning layer: layer1.2.conv3
Pruning layer: layer2.0.conv1
Pruning layer: layer2.0.conv2
Pruning layer: layer2.0.conv3
Pruning layer: layer2.1.conv1
Pruning layer: layer2.1.conv

### MobileNetV2

In [7]:
print("\n=== MobileNetV2 Example (CIFAR-100)  ===")

# 1) Define model for CIFAR-100
mobilenet_model = MobileNetV2(num_classes=num_classes).to(device)

# 2) Load pretrained weights
model_path = "/content/drive/My Drive/Models/CIFAR100/Model@Mobilenetv2_ACC@79.32.pt"
checkpoint = torch.load(model_path, map_location=device)
mobilenet_model.load_state_dict(checkpoint)

mobilenet_model.eval()

# 3) Setup optimizer/loss
optimizer_mb = optim.Adam(mobilenet_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 4) Baseline evaluation
print("MobileNetV2 baseline evaluation (CIFAR-100):")
mb_acc_before, mb_top1_before, mb_top5_before, mb_precision_before, mb_recall_before, mb_f1_before = evaluate_model(mobilenet_model, testloader)

print("\nCompression Details Before Pruning:")
mb_compression_before, mb_sparsity_before = count_nonzero_params(mobilenet_model)

print("\nInference Speed Before Pruning:")
mb_time_before, mb_fps_before = measure_inference_speed(mobilenet_model, testloader, device)
print(f"Inference Speed: {mb_time_before:.6f} sec/sample ({mb_fps_before:.2f} FPS)")

# 5) Setup for Pruning
layer_store_mb = [m for m in mobilenet_model.modules() if isinstance(m, nn.Conv2d)]
# Only Conv2d modules outside 'layers' and 'shortcut' are selected (the run output
# shows this is just conv1 and conv2).
layers_to_prune_mb = [
    name for name, module in mobilenet_model.named_modules()
    if isinstance(module, nn.Conv2d) and 'shortcut' not in name and 'layers' not in name
]

# Pair each Conv2d with the first BatchNorm2d that follows it in named_modules() order.
conv_to_bn_map_mb = {}
prev_conv = None
for name, module in mobilenet_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prev_conv = name
    elif isinstance(module, nn.BatchNorm2d) and prev_conv:
        conv_to_bn_map_mb[prev_conv] = name
        prev_conv = None

# Keep a CPU copy of the unpruned model. The iterative run below restarts from it,
# so one-shot (20%) and iterative (2 x 10%) pruning are compared from the same
# baseline instead of iteratively pruning an already one-shot-pruned model.
mobilenet_baseline_model = deepcopy(mobilenet_model).cpu()

# --------------------------------------------------------- #
# ONE-SHOT PRUNING                                          #
# --------------------------------------------------------- #
print("\n--- MobileNetV2 One-Shot Pruning (CIFAR-100) ---")
prune_engine_mb = pruning_engine(
    pruning_method="L1norm",
    pruning_ratio=0.2,           # 20% pruned
    conv_to_bn_map=conv_to_bn_map_mb
)

for layer_name in layers_to_prune_mb:
    print(f"Pruning layer: {layer_name}")
    current_layer = get_module_by_name(mobilenet_model, layer_name)
    prune_engine_mb.set_layer(current_layer, main_layer=True)
    masked_layer = prune_engine_mb.remove_conv_filter_kernel(conv_name=layer_name, model=mobilenet_model)
    set_module_by_name(mobilenet_model, layer_name, masked_layer)

    # Optionally also prune BN
    # NOTE(review): the engine already receives conv_to_bn_map, and this reads the
    # fixed key "current_layer" from remove_filter_idx_history. Confirm the engine
    # stores the latest indices under that key and does not also prune the BN itself.
    if layer_name in conv_to_bn_map_mb:
        bn_name = conv_to_bn_map_mb[layer_name]
        bn_layer = get_module_by_name(mobilenet_model, bn_name)
        prune_engine_mb.remove_bn_layer(bn_layer, prune_engine_mb.remove_filter_idx_history["current_layer"])

# Evaluate after One-Shot
print("MobileNetV2 evaluation AFTER One-Shot Pruning (CIFAR-100):")
mb_acc_oneshot, mb_top1_oneshot, mb_top5_oneshot, mb_precision_oneshot, mb_recall_oneshot, mb_f1_oneshot = evaluate_model(mobilenet_model, testloader)

print("\nCompression Details AFTER One-Shot Pruning:")
mb_compression_oneshot, mb_sparsity_oneshot = count_nonzero_params(mobilenet_model)

print("\nInference Speed AFTER One-Shot Pruning:")
mb_time_oneshot, mb_fps_oneshot = measure_inference_speed(mobilenet_model, testloader, device)
print(f"Inference Speed: {mb_time_oneshot:.6f} sec/sample ({mb_fps_oneshot:.2f} FPS)")

# --------------------------------------------------------- #
# ITERATIVE PRUNING                                         #
# --------------------------------------------------------- #
print("\n--- MobileNetV2 Iterative Pruning (CIFAR-100) ---")

# Restart from the unpruned snapshot taken before one-shot pruning.
mobilenet_model = mobilenet_baseline_model.to(device)
del mobilenet_baseline_model

num_iter = 2
iter_mask_ratio = 0.1

for it in range(num_iter):
    print(f"\nIteration {it + 1}")

    prune_engine_iter_mb = pruning_engine(
        pruning_method="L1norm",
        pruning_ratio=iter_mask_ratio,
        conv_to_bn_map=conv_to_bn_map_mb
    )

    for layer_name in layers_to_prune_mb:
        print(f"Pruning layer: {layer_name}")
        current_layer = get_module_by_name(mobilenet_model, layer_name)
        prune_engine_iter_mb.set_layer(current_layer, main_layer=True)
        masked_layer = prune_engine_iter_mb.remove_conv_filter_kernel(conv_name=layer_name, model=mobilenet_model)
        set_module_by_name(mobilenet_model, layer_name, masked_layer)

        if layer_name in conv_to_bn_map_mb:
            bn_name = conv_to_bn_map_mb[layer_name]
            bn_layer = get_module_by_name(mobilenet_model, bn_name)
            prune_engine_iter_mb.remove_bn_layer(bn_layer, prune_engine_iter_mb.remove_filter_idx_history["current_layer"])

    # mobilenet_model was restored from a copy, and set_module_by_name may install
    # new module objects. The old optimizer would then step parameters that are no
    # longer in the model. Rebuild it with the same optimizer class/hyperparameters.
    optimizer_mb = type(optimizer_mb)(mobilenet_model.parameters(), **optimizer_mb.defaults)

    # NOTE(review): the outputs show total params unchanged while nonzero params
    # drop, i.e. pruning zeroes weights. Unless pruning_engine/train_model re-apply
    # the masks, fine-tuning can regrow pruned weights -- confirm.
    print("Fine-tuning MobileNetV2 after pruning iteration...")
    mobilenet_model = train_model(mobilenet_model, trainloader, optimizer_mb, criterion, num_epochs=2)

# Final evaluation
print("MobileNetV2 evaluation after Iterative Pruning:")
mb_acc_iter, mb_top1_iter, mb_top5_iter, mb_precision_iter, mb_recall_iter, mb_f1_iter = evaluate_model(mobilenet_model, testloader)

print("\nCompression Details After Iterative Pruning:")
mb_compression_iter, mb_sparsity_iter = count_nonzero_params(mobilenet_model)

print("\nInference Speed After Iterative Pruning:")
mb_time_iter, mb_fps_iter = measure_inference_speed(mobilenet_model, testloader, device)
print(f"Inference Speed: {mb_time_iter:.6f} sec/sample ({mb_fps_iter:.2f} FPS)")



=== MobileNetV2 Example (CIFAR-100)  ===
MobileNetV2 baseline evaluation (CIFAR-100):

Evaluation Metrics:
Accuracy: 78.80%
Top-1 Accuracy: 78.80%
Top-5 Accuracy: 95.22%
Precision (macro): 0.7899
Recall (macro):    0.7880
F1 Score (macro):  0.7878

Compression Details Before Pruning:
Total Parameters: 2412212
Nonzero Parameters: 2412212
Compression Ratio: 1.000x
Sparsity: 0.000%

Inference Speed Before Pruning:
Inference Speed: 0.000201 sec/sample (4965.83 FPS)
Inference Speed: 0.000201 sec/sample (4965.83 FPS)

--- MobileNetV2 One-Shot Pruning (CIFAR-100) ---
Pruning layer: conv1
Pruning layer: conv2
MobileNetV2 evaluation AFTER One-Shot Pruning (CIFAR-100):

Evaluation Metrics:
Accuracy: 78.29%
Top-1 Accuracy: 78.29%
Top-5 Accuracy: 94.92%
Precision (macro): 0.7888
Recall (macro):    0.7829
F1 Score (macro):  0.7832

Compression Details AFTER One-Shot Pruning:
Total Parameters: 2412212
Nonzero Parameters: 2329577
Compression Ratio: 1.035x
Sparsity: 3.426%

Inference Speed AFTER One-

### VGG16

In [8]:
print("\n=== VGG16 Example (CIFAR-100) ===")

# 1) Define model
vgg_model = VGG(num_classes=num_classes, vgg_name="VGG16").to(device)

# 2) Load pretrained weights
model_path = "/content/drive/My Drive/Models/CIFAR100/Model@VGG16_ACC@76.41.pt"
checkpoint = torch.load(model_path, map_location=device)
vgg_model.load_state_dict(checkpoint)

# 3) Set to eval mode
vgg_model.eval()

# 4) Optimizer & Loss
optimizer_vgg = optim.Adam(vgg_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# === Baseline Evaluation === #
print("VGG16 baseline evaluation:")
vgg_acc_before, vgg_top1_acc_before, vgg_top5_acc_before, vgg_precision_before, vgg_recall_before, vgg_f1score_before = evaluate_model(vgg_model, testloader)

# Compression
print('\nCompression Details Before Pruning:')
vgg_compression_ratio_before, vgg_sparsity_before = count_nonzero_params(vgg_model)

# Inference Speed
print("\nVGG16 inference speed BEFORE pruning:")
vgg_avg_time_per_sample_before, vgg_avg_fps_before = measure_inference_speed(vgg_model, testloader, device)
print(f"Inference Speed: {vgg_avg_time_per_sample_before:.6f} sec/sample ({vgg_avg_fps_before:.2f} FPS)")

# Prepare Layers for Pruning
layer_store_vgg = [m for m in vgg_model.modules() if isinstance(m, nn.Conv2d)]
layers_to_prune_vgg = [name for name, module in vgg_model.named_modules() if isinstance(module, nn.Conv2d)]

# Pair each Conv2d with the first BatchNorm2d that follows it in named_modules() order.
conv_to_bn_map_vgg = {}
prev_conv = None
for name, module in vgg_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prev_conv = name
    elif isinstance(module, nn.BatchNorm2d) and prev_conv:
        conv_to_bn_map_vgg[prev_conv] = name
        prev_conv = None

# Keep a CPU copy of the unpruned model. The iterative run below restarts from it,
# so one-shot (20%) and iterative (2 x 10%) pruning are compared from the same
# baseline instead of iteratively pruning an already one-shot-pruned model.
vgg_baseline_model = deepcopy(vgg_model).cpu()

# --------------------------------------------------------- #
# ONE-SHOT PRUNING                                          #
# --------------------------------------------------------- #
print("\n--- VGG16 One-Shot Pruning ---")
prune_engine_vgg = pruning_engine(
    pruning_method="L1norm",
    pruning_ratio=0.2,   # 20% pruned
    conv_to_bn_map=conv_to_bn_map_vgg
)

for layer_name in layers_to_prune_vgg:
    print(f"Pruning VGG16 layer: {layer_name}")
    orig_layer = get_module_by_name(vgg_model, layer_name)
    prune_engine_vgg.set_layer(orig_layer, main_layer=True)
    masked_layer = prune_engine_vgg.remove_conv_filter_kernel(conv_name=layer_name, model=vgg_model)
    set_module_by_name(vgg_model, layer_name, masked_layer)

# === Eval After One-Shot ===
print("VGG16 evaluation AFTER One-Shot Pruning:")
vgg_acc_oneshot, vgg_top1_acc_oneshot, vgg_top5_acc_oneshot, vgg_precision_oneshot, vgg_recall_oneshot, vgg_f1score_oneshot = evaluate_model(vgg_model, testloader)

print('\nCompression Details AFTER One-Shot Pruning:')
vgg_compression_ratio_oneshot, vgg_sparsity_oneshot = count_nonzero_params(vgg_model)

print("\nVGG16 inference speed AFTER One-Shot pruning:")
vgg_avg_time_per_sample_oneshot, vgg_avg_fps_oneshot = measure_inference_speed(vgg_model, testloader, device)
print(f"Inference Speed: {vgg_avg_time_per_sample_oneshot:.6f} sec/sample ({vgg_avg_fps_oneshot:.2f} FPS)")


# ---------------------------------------------------------
# ITERATIVE PRUNING
# ---------------------------------------------------------
print("\n--- VGG16 Iterative Pruning ---")

# Restart from the unpruned snapshot taken before one-shot pruning.
vgg_model = vgg_baseline_model.to(device)
del vgg_baseline_model

num_iter = 2
iter_mask_ratio = 0.1

for it in range(num_iter):
    print(f"\nIteration {it+1} for VGG16")

    prune_engine_iter_vgg = pruning_engine(
        pruning_method="L1norm",
        pruning_ratio=iter_mask_ratio,
        conv_to_bn_map=conv_to_bn_map_vgg
    )

    for layer_name in layers_to_prune_vgg:
        print(f"Pruning VGG16 layer: {layer_name}")
        current_layer = get_module_by_name(vgg_model, layer_name)
        prune_engine_iter_vgg.set_layer(current_layer, main_layer=True)
        masked_layer = prune_engine_iter_vgg.remove_conv_filter_kernel(conv_name=layer_name, model=vgg_model)
        set_module_by_name(vgg_model, layer_name, masked_layer)

    # vgg_model was restored from a copy, and set_module_by_name may install new
    # module objects. The old optimizer would then step parameters that are no
    # longer in the model. Rebuild it with the same optimizer class/hyperparameters.
    optimizer_vgg = type(optimizer_vgg)(vgg_model.parameters(), **optimizer_vgg.defaults)

    # NOTE(review): the outputs show total params unchanged while nonzero params
    # drop, i.e. pruning zeroes weights. Unless pruning_engine/train_model re-apply
    # the masks, fine-tuning can regrow pruned weights -- confirm.
    print("Fine-tuning VGG16 after iteration...")
    vgg_model = train_model(vgg_model, trainloader, optimizer_vgg, criterion, num_epochs=2)

# === Final Eval After Iterative Pruning ===
print("VGG16 evaluation AFTER Iterative Pruning:")
vgg_acc_iter, vgg_top1_acc_iter, vgg_top5_acc_iter, vgg_precision_iter, vgg_recall_iter, vgg_f1score_iter = evaluate_model(vgg_model, testloader)

print('\nCompression Details AFTER Iterative Pruning:')
vgg_compression_ratio_iter, vgg_sparsity_iter = count_nonzero_params(vgg_model)

print("\nVGG16 inference speed AFTER Iterative pruning:")
vgg_avg_time_per_sample_iter, vgg_avg_fps_iter = measure_inference_speed(vgg_model, testloader, device)
print(f"Inference Speed: {vgg_avg_time_per_sample_iter:.6f} sec/sample ({vgg_avg_fps_iter:.2f} FPS)")



=== VGG16 Example (CIFAR-100) ===
VGG16 baseline evaluation:

Evaluation Metrics:
Accuracy: 76.51%
Top-1 Accuracy: 76.51%
Top-5 Accuracy: 93.48%
Precision (macro): 0.7680
Recall (macro):    0.7651
F1 Score (macro):  0.7653

Compression Details Before Pruning:
Total Parameters: 34015396
Nonzero Parameters: 34015396
Compression Ratio: 1.000x
Sparsity: 0.000%

VGG16 inference speed BEFORE pruning:
Inference Speed: 0.000080 sec/sample (12458.06 FPS)
Inference Speed: 0.000080 sec/sample (12458.06 FPS)

--- VGG16 One-Shot Pruning ---
Pruning VGG16 layer: features.0
Pruning VGG16 layer: features.3
Pruning VGG16 layer: features.7
Pruning VGG16 layer: features.10
Pruning VGG16 layer: features.14
Pruning VGG16 layer: features.17
Pruning VGG16 layer: features.20
Pruning VGG16 layer: features.24
Pruning VGG16 layer: features.27
Pruning VGG16 layer: features.30
Pruning VGG16 layer: features.34
Pruning VGG16 layer: features.37
Pruning VGG16 layer: features.40
VGG16 evaluation AFTER One-Shot Pruning