In [23]:
from models import VGG
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.optim as optim

### ALL COMMMENTS WITH '##' ARE MINE AND NOT CHAT GPT

In [None]:
from utils import get_train_cifar10_dataloader, get_test_cifar10_dataloader, run_epochs, get_hyperparams, test

trainloader = get_train_cifar10_dataloader()
testloader = get_test_cifar10_dataloader()

In [3]:
model = VGG("VGG19")

In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torch.optim as optim
from models import VGG

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# === Utils ===

def count_nonzero_parameters(model: nn.Module):
    return sum(torch.count_nonzero(p).item() for p in model.parameters())

## This function is that one that I did with you in the last class
## Where we just apply the pruning to the conv2d and linear layers
def apply_global_pruning(model: nn.Module, amount: int):
    parameters_to_prune = []
    for module in model.modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            parameters_to_prune.append((module, 'weight'))
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

def remove_pruning(model):
    for module in model.modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            if hasattr(module, 'weight_mask'):
                prune.remove(module, 'weight')

# === Pruning Methods ===

# 1. Global Pruning, no retrain
## CAUE - There is nothing special here, just applying the pruning function to the model
def method1_global_pruning_no_retrain(model, trainloader, testloader):
    print("\n=== Method 1: Global Pruning without Retrain ===")
    
    apply_global_pruning(model, amount=0.3)

    acc, _ = test(testloader, model)
    
    params = count_nonzero_parameters(model)
    print(f"Accuracy: {acc:.2f}% | Parameters: {params}")

    remove_pruning(model)

# 2. Global Pruning + Retrain
## CAUE - The same as the previous one, but retraining the model after the pruning
def method2_global_pruning_with_retrain(model, trainloader, testloader):
    print("\n=== Method 2: Global Pruning with Retrain ===")

    apply_global_pruning(model, amount=0.3)

    acc, _, _ = run_epochs(
        model,
        train_loader=trainloader,
        test_loader=testloader,
        hyperparams=HYPERPARAMS,
        n_epochs=1,
    )

    params = count_nonzero_parameters(model)
    print(f"Accuracy after retrain: {acc:.2f}% | Parameters: {params}")

    remove_pruning(model)

# 3. Gradual Pruning por etapas + Retrain
## CAUE - What I understood from this one was that we apply the pruning many times (steps)
## It is possible to change the pruning ratio in each step if we want
def method3_gradual_pruning(model, trainloader, testloader):
    print("\n=== Method 3: Gradual Pruning + Retrain ===")
    
    total_steps = 3
    pruning_amount_per_step = 0.1

    for step in range(total_steps):
        print(f"\n-- Step {step+1}/{total_steps} --")
        apply_global_pruning(model, amount=pruning_amount_per_step)

        acc, _, _ = run_epochs(
            model,
            train_loader=trainloader,
            test_loader=testloader,
            hyperparams=HYPERPARAMS,
            n_epochs=1, 
        )

        params = count_nonzero_parameters(model)
        print(f"Accuracy: {acc:.2f}% | Parameters: {params}")

    remove_pruning(model)

# 4. ThiNet
## CAUE - This one I've looked up and I found an article that describes the state of art
## The article can be found in here: https://arxiv.org/abs/1707.06342
## As far what I've read, it uses the L2 norm to discard the whole filter if it is considered not important
## and not the individual weight like how we have been doing
def method4_thinet_style_pruning(model, trainloader, testloader):
    print("\n=== Method 4: ThiNet Style Pruning ===")
    
    def prune_by_feature_map_norm(model, amount=0.3):
        for module in model.modules():
            if isinstance(module, nn.Conv2d):
                weight = module.weight.detach()
                filter_norms = weight.view(weight.size(0), -1).norm(2, dim=1)
                num_filters_to_prune = int(amount * weight.size(0))
                prune_idx = filter_norms.argsort()[:num_filters_to_prune]

                mask = torch.ones(weight.size(0), device=weight.device)
                mask[prune_idx] = 0

                mask = mask[:, None, None, None]
                module.weight.data.mul_(mask)

    prune_by_feature_map_norm(model, amount=0.3)

    acc, _, _ = run_epochs(
        model,
        train_loader=trainloader,
        test_loader=testloader,
        hyperparams=HYPERPARAMS,
        n_epochs=1,
    )

    params = count_nonzero_parameters(model)
    print(f"Accuracy after ThiNet pruning + retrain: {acc:.2f}% | Parameters: {params}")

In [None]:
model = VGG("VGG19").to(device)
trainloader = get_train_cifar10_dataloader()
testloader = get_test_cifar10_dataloader()

## CAUE - I used your function to get the hyperparameters, but I changed the way the optimiser was defined
## because the pickle returns it as a string
HYPERPARAMS = get_hyperparams()
HYPERPARAMS["criterion"] = nn.CrossEntropyLoss()
HYPERPARAMS["optimiser"] = optim.AdamW(model.parameters(), lr=HYPERPARAMS['lr'], weight_decay=HYPERPARAMS['weight_decay'])

acc_before, _ = test(testloader, model)
print(f"Accuracy before pruning: {acc_before:.2f}%")

method1_global_pruning_no_retrain(model, trainloader, testloader)
method2_global_pruning_with_retrain(model, trainloader, testloader)
method3_gradual_pruning(model, trainloader, testloader)
#method4_thinet_style_pruning(model, trainloader, testloader)

Accuracy before pruning: 10.00%

=== Method 1: Global Pruning without Retrain ===
Accuracy: 10.00% | Parameters: 20035016

=== Method 2: Global Pruning with Retrain ===
Epoch: 0
Saving..
Accuracy after retrain: 58.93% | Parameters: 14033322

=== Method 3: Gradual Pruning + Retrain ===

-- Step 1/3 --
Epoch: 0
Saving..
Accuracy: 71.93% | Parameters: 17954449

-- Step 2/3 --
Epoch: 0
Saving..
Accuracy: 74.17% | Parameters: 17954449

-- Step 3/3 --
Epoch: 0
Saving..
Accuracy: 80.17% | Parameters: 17954449

=== Method 4: ThiNet Style Pruning ===


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

In [45]:
method4_thinet_style_pruning(model, trainloader, testloader)


=== Method 4: ThiNet Style Pruning ===
Epoch: 0
Saving..
Accuracy after ThiNet pruning + retrain: 79.81% | Parameters: 20019550


In [None]:
## CAUE - At some point I had to use this because I got a OutOfMemoryError when running the models
import gc
gc.collect()

25671