In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.nn.utils.prune as prune

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision import models
from torchmetrics import Accuracy

import quantus
import captum
from captum.attr import Saliency, IntegratedGradients, NoiseTunnel
from cleverhans.torch.attacks.projected_gradient_descent import (projected_gradient_descent)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# import seaborn as sns

import random
import copy
import gc
import math

import warnings
warnings.filterwarnings('ignore')

# from lisa import LISA
from itertools import chain
from pathlib import Path
from ranger import Ranger

In [None]:
from torch.utils.data import Dataset

class SubsetLISA(Dataset):
    """
    Map-style dataset over a LISA subset stored as two tensor files.

    Both tensors are read with ``torch.load`` and then reduced, class by class,
    to either the training portion or the held-out test portion.
    """
    def __init__(self, image_tensor_path, label_tensor_path,train:bool, transform=None):
        """
        Args:
            image_tensor_path (str): Path to the images tensor file.
            label_tensor_path (str): Path to the labels tensor file.
            train (bool): Keep the training portion when True, the test portion otherwise.
            transform (callable, optional): A function/transform to apply to the images.

        Raises:
            AssertionError: If the two files hold a different number of samples.
        """
        self.images = torch.load(image_tensor_path)
        self.labels = torch.load(label_tensor_path)
        self.transform = transform
        self.train = train

        # Validate BEFORE splitting: once both tensors have been indexed with the
        # same index list their lengths always agree, so a later check is vacuous.
        assert len(self.images) == len(self.labels), "Images and labels length mismatch"

        self._train_test_split()

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.labels)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index of the sample.

        Returns:
            tuple: (image, label) where image is the (optionally transformed)
            input tensor and label is the target tensor.
        """
        image = self.images[index]
        target = self.labels[index]

        if self.transform:
            image = self.transform(image)

        return image, target

    def _train_test_split(self, test_percent: float = 0.16):
        """
        Restrict ``self.images`` / ``self.labels`` to the train or test portion.

        For every class (in order of first appearance) the first
        ``int(count * test_percent)`` samples form the test portion and the
        remaining samples form the training portion. The split is deterministic.

        Args:
            test_percent (float): Fraction of each class reserved for testing.

        Raises:
            ValueError: If ``test_percent`` is outside ``[0, 1]``.
        """
        if not 0.0 <= test_percent <= 1.0:
            raise ValueError(f"test_percent must be within [0, 1], got {test_percent}")

        # Sample positions grouped by label value, preserving the stored order.
        per_class = {}
        for position, cl in enumerate(self.labels.tolist()):
            per_class.setdefault(cl, []).append(position)

        train, test = [], []
        for positions in per_class.values():
            split_index = int(len(positions) * test_percent)
            test.extend(positions[:split_index])
            train.extend(positions[split_index:])

        # An explicit LongTensor index also behaves well when the selection is empty.
        chosen = torch.tensor(train if self.train else test, dtype=torch.long)
        self.images, self.labels = self.images[chosen], self.labels[chosen]

In [None]:
%run models.ipynb
%run utils.ipynb

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyper-parameters shared by every experiment in this notebook.
epochs = 100
batch_size = 128

# Three-channel mean / std used to standardise the images.
normalize = transforms.Normalize(
    mean=[0.4563, 0.4076, 0.3895],
    std=[0.2298, 0.2144, 0.2259],
)

# Each stored sample is converted to a PIL image, back to a tensor, then normalised.
lisa_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    normalize,
])

# Locations of the saved subset tensors; replace with your actual file paths.
image_tensor_path = "datasets/lisa-batches/subset_images.tensor"
label_tensor_path = "datasets/lisa-batches/subset_labels2.tensor"


In [None]:
# Two views over the same pair of tensor files: the training part and the test part.
train_dataset = SubsetLISA(
    image_tensor_path, label_tensor_path, train=True, transform=lisa_transforms
)
test_dataset = SubsetLISA(
    image_tensor_path, label_tensor_path, train=False, transform=lisa_transforms
)

# Only the training batches are shuffled (num_workers is left at its default).
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Normal Model

In [None]:
# Baseline experiment: fresh network, cross-entropy loss, SGD with momentum.
learning_rate = 0.01
model = vgg16()
criterion = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

In [None]:
def train_model(model, epochs):
    """
    Train ``model`` for ``epochs`` passes over ``train_dataloader``.

    Uses the notebook-level ``optimizer``, ``criterion``, ``device``,
    ``train_dataloader`` and ``test_dataloader``; the optimizer must have been
    built over this model's parameters.

    Args:
        model: Network to optimise (updated in place).
        epochs (int): Number of passes over the training set.

    Returns:
        The same ``model`` object after training.
    """
    for epoch in range(epochs):
        # Re-assert training mode each epoch: evaluate_model lives in utils.ipynb and
        # may switch the network to eval mode during the periodic check below.
        model.train()
        for x_batch, y_batch in train_dataloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            output = model(x_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()

        # Report on every 10th epoch. The condition has to look at the running
        # epoch counter; testing the constant `epochs` is either always or never true.
        if (epoch + 1) % 10 == 0:
            predictions, labels = evaluate_model(model, test_dataloader, device)
            test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
            # The reported loss is the loss of the last training batch of this epoch.
            print(f"Epoch {epoch+1}/{epochs} - test accuracy: {(100 * test_acc):.2f}% and CE loss {loss.item():.2f}")
    return model

In [None]:
# Place the baseline network on the target device, then fit it.
model = model.to(device)
model_normal = train_model(model=model, epochs=epochs)

In [None]:
# Score the baseline on the held-out split: fraction of arg-max predictions equal to the label.
predictions, labels = evaluate_model(model_normal, test_dataloader, device)
test_acc = np.mean(
    predictions.cpu().numpy().argmax(axis=1) == labels.cpu().numpy()
)
print(f"Test accuracy for VGG LISA Normal is: {(100 * test_acc):.2f}%")

In [None]:
# Checkpoint directory (created on demand) and target file for the baseline weights.
model_path = Path("models")
model_name = "lisa_normal.pth"
model_save_path = model_path / model_name
model_path.mkdir(parents=True, exist_ok=True)

print(f"Saving the model: {model_save_path}")
torch.save(model_normal.state_dict(), model_save_path)

# Adversarial Model

In [None]:
# Adversarial experiment: fresh network plus the candidate perturbation budgets.
learning_rate = 0.01
eps = [0.01, 0.03, 0.06, 0.1, 0.3, 0.5]
model = vgg16()
criterion = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

In [None]:
def train_adv(model, epsilon, epochs):
    """
    Adversarially train ``model``: every batch is replaced by its PGD perturbation.

    Uses the notebook-level ``optimizer``, ``criterion``, ``device``,
    ``train_dataloader`` and ``test_dataloader``.

    Args:
        model: Network to optimise (updated in place).
        epsilon (float): L-infinity perturbation budget handed to PGD.
        epochs (int): Number of passes over the training set.

    Returns:
        The same ``model`` object after training.
    """
    for epoch in range(epochs):
        # Re-assert training mode each epoch: evaluate_model lives in utils.ipynb and
        # may switch the network to eval mode during the periodic check below.
        model.train()
        for x_batch, y_batch in train_dataloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            # PGD with step size epsilon/10, 40 iterations, L-inf norm. No labels are
            # passed, so the attack is crafted against the model's own predictions.
            x_batch = projected_gradient_descent(model, x_batch, epsilon, epsilon/10, 40, np.inf)
            # Clear gradients only after the attack so nothing it may have left on
            # the parameters leaks into the update.
            optimizer.zero_grad()
            output = model(x_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()

        # Report on every 10th epoch. The condition has to look at the running
        # epoch counter; testing the constant `epochs` is either always or never true.
        if (epoch + 1) % 10 == 0:
            predictions, labels = evaluate_model(model, test_dataloader, device)
            test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
            # The reported loss is the loss of the last (adversarial) batch of this epoch.
            print(f"Epoch {epoch+1}/{epochs} - test accuracy: {(100 * test_acc):.2f}% and CE loss {loss.item():.2f}")
    return model


In [None]:
# Adversarial training with the fourth budget in `eps` (0.1).
model = model.to(device)
model_adversarial = train_adv(model=model, epsilon=eps[3], epochs=epochs)

In [None]:
# Put the adversarially trained network on the target device in inference mode.
# The trailing expression is the module itself, so the cell displays its layout.
model_adversarial.to(device).eval()

In [None]:
# Score the adversarially trained network on the (clean) held-out split.
predictions, labels = evaluate_model(model_adversarial, test_dataloader, device)
test_acc = np.mean(
    predictions.cpu().numpy().argmax(axis=1) == labels.cpu().numpy()
)
print(f"Adversarial Model test accuracy: {(100 * test_acc):.2f}%")

In [None]:
# Checkpoint directory (created on demand) and target file for the adversarial weights.
model_path = Path("models")
model_name = "lisa_adv.pth"
model_save_path = model_path / model_name
model_path.mkdir(parents=True, exist_ok=True)

print(f"Saving the model: {model_save_path}")
torch.save(model_adversarial.state_dict(), model_save_path)

# L1 Pre

In [None]:
# L1-pre experiment: start from an untrained network and report its sparsity.
model = vgg16()
criterion = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)
print(f"VGG-16 global sparsity = {compute_sparsity_vgg(model):.2f}%")

In [None]:
print(f"VGG-16 global sparsity = {compute_sparsity_vgg(model):.2f}%")

for name, module in model.named_modules():
    # Per-layer L1 pruning: every Conv2d loses 20% of its weights, every Linear
    # 10%; all other modules are left untouched.
    if isinstance(module, torch.nn.Conv2d):
        amount = 0.2
    elif isinstance(module, torch.nn.Linear):
        amount = 0.1
    else:
        continue
    prune.l1_unstructured(module=module, name='weight', amount=amount)

print(f"VGG-16 global sparsity = {compute_sparsity_vgg(model):.2f}%")

In [None]:
# Train the network that was pruned before any training took place.
model = model.to(device)
model_l1_unstructured = train_model(model=model, epochs=epochs)

In [None]:
# Checkpoint directory (created on demand) and target file for the L1-pre weights.
model_path = Path("models")
model_name = "lisa_l1_pre.pth"
model_save_path = model_path / model_name
model_path.mkdir(parents=True, exist_ok=True)

print(f"Saving the model: {model_save_path}")
torch.save(model_l1_unstructured.state_dict(), model_save_path)

In [None]:
# Score the L1-pre network on the held-out split.
predictions, labels = evaluate_model(model_l1_unstructured, test_dataloader, device)
test_acc = np.mean(
    predictions.cpu().numpy().argmax(axis=1) == labels.cpu().numpy()
)
print(f"Test accuracy for VGG LISA L1 pre is: {(100 * test_acc):.2f}%")


# L1 Post No Tune

In [None]:
# Reload the trained baseline into a fresh network.
model_path = Path("models")
model_path.mkdir(parents=True, exist_ok=True)

model_name = "lisa_normal.pth"
model_save_path = model_path / model_name
model = vgg16().to(device)
# map_location lets a checkpoint written on a GPU machine load on the CPU fallback.
model.load_state_dict(torch.load(model_save_path, map_location=device))

print(f"VGG-16 global sparsity = {compute_sparsity_vgg(model):.2f}%")

for name, module in model.named_modules():
    # Per-layer L1 pruning: 20% of the weights of every convolution...
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.2)

    # ...and 10% of the weights of every linear layer.
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.1)

print(f"VGG-16 global sparsity = {compute_sparsity_vgg(model):.2f}%")

# The network is already on `device`; only the mode needs switching.
model.eval()

# Check test set performance.
predictions, labels = evaluate_model(model, test_dataloader, device)
test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
print(f"Test accuracy for VGG-16 pretrained no tuning is: {(100 * test_acc):.2f}%")

# NOTE(review): while pruning is active, torch.nn.utils.prune keeps `weight_orig`
# and `weight_mask` in the state_dict instead of `weight`, so whoever loads this
# file presumably has to apply the same pruning to the network first — confirm.
model_name = "lisa_l1_post_notune.pth"
model_save_path = model_path / model_name

print(f"Saving the model: {model_save_path}")
torch.save(obj=model.state_dict(), f=model_save_path)

# L1 Post Tuned

In [None]:
# Continue with the pruned network left in `model` by the previous cell.
model = model.to(device)
criterion = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

def train_model(model, epochs):
    """
    Train ``model`` for ``epochs`` passes over ``train_dataloader``.

    This redefinition shadows the function of the same name defined earlier in
    the notebook. It uses the notebook-level ``optimizer``, ``criterion``,
    ``device``, ``train_dataloader`` and ``test_dataloader``.

    Args:
        model: Network to optimise (updated in place).
        epochs (int): Number of passes over the training set.

    Returns:
        The same ``model`` object after training.
    """
    for epoch in range(epochs):
        # Re-assert training mode each epoch: evaluate_model lives in utils.ipynb and
        # may switch the network to eval mode during the periodic check below.
        model.train()
        for x_batch, y_batch in train_dataloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            output = model(x_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()

        # Report on every 10th epoch. The condition has to look at the running
        # epoch counter; testing the constant `epochs` is either always or never true.
        if (epoch + 1) % 10 == 0:
            predictions, labels = evaluate_model(model, test_dataloader, device)
            test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
            # The reported loss is the loss of the last training batch of this epoch.
            print(f"Epoch {epoch+1}/{epochs} - test accuracy: {(100 * test_acc):.2f}% and CE loss {loss.item():.2f}")
    return model

# Fine-tune the L1-pruned network, then score and checkpoint it.
model_tuned = train_model(model=model.to(device), epochs=epochs)

# Check test set performance.
predictions, labels = evaluate_model(model_tuned, test_dataloader, device)
test_acc = np.mean(
    predictions.cpu().numpy().argmax(axis=1) == labels.cpu().numpy()
)
print(f"Test accuracy for VGG LISA L1 tuned post  is: {(100 * test_acc):.2f}%")

model_path = Path("models")
model_name = "lisa_l1_post_tuned.pth"
model_save_path = model_path / model_name
model_path.mkdir(parents=True, exist_ok=True)

print(f"Saving the model: {model_save_path}")
torch.save(model_tuned.state_dict(), model_save_path)

# Global Pre

In [None]:
# Global-pre experiment: untrained network, loss and optimiser.
model = vgg16()
criterion = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

print(f"VGG-16 global sparsity = {compute_sparsity_vgg(model):.2f}%")

# (module, parameter-name) pairs that take part in global pruning: the listed
# entries of model.features followed by the listed entries of model.classifier.
parameters_to_prune = tuple(
    (layer, 'weight')
    for layer in (
        *(model.features[i] for i in (0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28)),
        *(model.classifier[i] for i in (1, 4, 6)),
    )
)

# Pruning amount per round; the training function below runs a single round.
prune_rates_global = [0.2, 0.3, 0.4, 0.5, 0.6]

def train_global_pruned(model, epochs):
    """
    Apply one round of global L1 pruning, then train the pruned network.

    Uses the notebook-level ``parameters_to_prune`` and ``prune_rates_global``
    for pruning, and ``optimizer``, ``criterion``, ``device``,
    ``train_dataloader`` and ``test_dataloader`` for training.
    ``parameters_to_prune`` must refer to layers of the ``model`` passed in.

    Args:
        model: Network to prune and optimise (updated in place).
        epochs (int): Number of training passes after pruning.

    Returns:
        The same ``model`` object after pruning and training.
    """
    # Only the first entry of prune_rates_global is used (a single round).
    for iter_prune_round in range(1):
        print(f"\n\nIterative Global pruning round = {iter_prune_round + 1}")

        # Unstructured L1 pruning with one threshold shared across all listed tensors.
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method = prune.L1Unstructured,
            amount = prune_rates_global[iter_prune_round]
        )

        # Print current global sparsity level-
        print(f"VGG global sparsity = {compute_sparsity_vgg(model):.2f}%")

        # Fine-training loop-
        print("\nFine-tuning pruned model to recover model's performance\n")
        for epoch in range(epochs):
            # Re-assert training mode each epoch: evaluate_model lives in utils.ipynb
            # and may switch the network to eval mode during the periodic check below.
            model.train()
            for x_batch, y_batch in train_dataloader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                optimizer.zero_grad()
                output = model(x_batch)
                loss = criterion(output, y_batch)
                loss.backward()
                optimizer.step()

            # Report on every 10th epoch. The condition has to look at the running
            # epoch counter; testing the constant `epochs` is either always or never true.
            if (epoch + 1) % 10 == 0:
                predictions, labels = evaluate_model(model, test_dataloader, device)
                test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
                print(f"Epoch {epoch+1}/{epochs} - test accuracy: {(100 * test_acc):.2f}% and CE loss {loss.item():.2f}")
    return model

# Prune globally and train, then score and checkpoint the result.
model_global = train_global_pruned(model=model.to(device), epochs=epochs)

# Check test set performance.
predictions, labels = evaluate_model(model_global, test_dataloader, device)
test_acc = np.mean(
    predictions.cpu().numpy().argmax(axis=1) == labels.cpu().numpy()
)
print(f"Test accuracy for VGG LISA Global pre is: {(100 * test_acc):.2f}%")

model_path = Path("models")
model_name = "lisa_global_pre.pth"
model_save_path = model_path / model_name
model_path.mkdir(parents=True, exist_ok=True)

print(f"Saving the model: {model_save_path}")
torch.save(model_global.state_dict(), model_save_path)

# Global Post No Tune

In [None]:
# Reload the trained baseline into a fresh network.
model_path = Path("models")
model_path.mkdir(parents=True, exist_ok=True)

model_name = "lisa_normal.pth"
model_save_path = model_path / model_name
model = vgg16().to(device)
# map_location lets a checkpoint written on a GPU machine load on the CPU fallback.
model.load_state_dict(torch.load(model_save_path, map_location=device))

print(f"VGG-16 global sparsity = {compute_sparsity_vgg(model):.2f}%")

# (module, parameter-name) pairs that take part in global pruning.
parameters_to_prune = (
    (model.features[0], 'weight'),
    (model.features[2], 'weight'),
    (model.features[5], 'weight'),
    (model.features[7], 'weight'),
    (model.features[10], 'weight'),
    (model.features[12], 'weight'),
    (model.features[14], 'weight'),
    (model.features[17], 'weight'),
    (model.features[19], 'weight'),
    (model.features[21], 'weight'),
    (model.features[24], 'weight'),
    (model.features[26], 'weight'),
    (model.features[28], 'weight'),
    (model.classifier[1], 'weight'),
    (model.classifier[4], 'weight'),
    (model.classifier[6], 'weight')
)

# Pruning amount per round; only the first entry is used (a single round).
prune_rates_global = [0.2, 0.3, 0.4, 0.5, 0.6]

for iter_prune_round in range(1):
    print(f"\n\nIterative Global pruning round = {iter_prune_round + 1}")

    # Unstructured L1 pruning with one threshold shared across all listed tensors.
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method = prune.L1Unstructured,
        amount = prune_rates_global[iter_prune_round]
    )

    # Print current global sparsity level-
    print(f"VGG-16 global sparsity = {compute_sparsity_vgg(model):.2f}%")

# The network is already on `device`; only the mode needs switching.
model.eval()

# Check test set performance.
predictions, labels = evaluate_model(model, test_dataloader, device)
test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
print(f"Test accuracy for VGG-16 pretrained no tuning is: {(100 * test_acc):.2f}%")

# NOTE(review): while pruning is active, torch.nn.utils.prune keeps `weight_orig`
# and `weight_mask` in the state_dict instead of `weight`, so whoever loads this
# file presumably has to apply the same pruning to the network first — confirm.
model_name = "lisa_global_post_notune.pth"
model_save_path = model_path / model_name

print(f"Saving the model: {model_save_path}")
torch.save(obj=model.state_dict(), f=model_save_path)

# Global Post Tuned

In [None]:
# Continue with the globally pruned network left in `model` by the previous cell.
model = model.to(device)
criterion = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

def train_model(model, epochs):
    """
    Train ``model`` for ``epochs`` passes over ``train_dataloader``.

    This redefinition shadows the function of the same name defined earlier in
    the notebook. It uses the notebook-level ``optimizer``, ``criterion``,
    ``device``, ``train_dataloader`` and ``test_dataloader``.

    Args:
        model: Network to optimise (updated in place).
        epochs (int): Number of passes over the training set.

    Returns:
        The same ``model`` object after training.
    """
    for epoch in range(epochs):
        # Re-assert training mode each epoch: evaluate_model lives in utils.ipynb and
        # may switch the network to eval mode during the periodic check below.
        model.train()
        for x_batch, y_batch in train_dataloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            output = model(x_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()

        # Report on every 10th epoch. The condition has to look at the running
        # epoch counter; testing the constant `epochs` is either always or never true.
        if (epoch + 1) % 10 == 0:
            predictions, labels = evaluate_model(model, test_dataloader, device)
            test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
            # The reported loss is the loss of the last training batch of this epoch.
            print(f"Epoch {epoch+1}/{epochs} - test accuracy: {(100 * test_acc):.2f}% and CE loss {loss.item():.2f}")
    return model

# Fine-tune the globally pruned network, then score and checkpoint it.
model_tuned = train_model(model=model.to(device), epochs=epochs)

# Check test set performance.
predictions, labels = evaluate_model(model_tuned, test_dataloader, device)
test_acc = np.mean(
    predictions.cpu().numpy().argmax(axis=1) == labels.cpu().numpy()
)
print(f"Test accuracy for VGG LISA Global post tuned is: {(100 * test_acc):.2f}%")

model_path = Path("models")
model_name = "lisa_global_post_tuned.pth"
model_save_path = model_path / model_name
model_path.mkdir(parents=True, exist_ok=True)

print(f"Saving the model: {model_save_path}")
torch.save(model_tuned.state_dict(), model_save_path)

# Layered Pre

In [None]:
# Layered-pre experiment: untrained network already placed on the target device.
model = vgg16().to(device)
criterion = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

def train_layered_pruned(model, epochs):
    """
    Apply one round of layer-wise structured pruning, then train the network.

    Each listed layer is pruned with ``prune.ln_structured`` on its ``weight``
    (amount 0.1, n=2, dim=0). Training uses the notebook-level ``optimizer``,
    ``criterion``, ``device``, ``train_dataloader`` and ``test_dataloader``.

    Args:
        model: Network to prune and optimise (updated in place); it must expose
            indexable ``features`` and ``classifier`` containers.
        epochs (int): Number of training passes after pruning.

    Returns:
        The same ``model`` object after pruning and training.
    """
    for iter_prune_round in range(1):
        # The message names the scheme actually applied in this function.
        print(f"\n\nIterative Layered pruning round = {iter_prune_round + 1}")

        # Layers to prune, in the same order as the original call sequence.
        # NOTE(review): classifier[6] looks like the output layer; with dim=0 this
        # zeroes whole rows of it — confirm that pruning it is intended.
        prunable_layers = (
            [model.features[i] for i in (0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28)]
            + [model.classifier[i] for i in (1, 4, 6)]
        )

        # Prune layer-wise in a structured manner-
        for layer in prunable_layers:
            prune.ln_structured(layer, name = "weight", amount = 0.1, n = 2, dim = 0)

        # Print current global sparsity level-
        print(f"VGG global sparsity = {compute_sparsity_vgg(model):.2f}%")

        # Fine-training loop-
        print("\nFine-tuning pruned model to recover model's performance\n")
        for epoch in range(epochs):
            # Re-assert training mode each epoch: evaluate_model lives in utils.ipynb
            # and may switch the network to eval mode during the periodic check below.
            model.train()
            for x_batch, y_batch in train_dataloader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                optimizer.zero_grad()
                output = model(x_batch)
                loss = criterion(output, y_batch)
                loss.backward()
                optimizer.step()

            # Report on every 10th epoch. The condition has to look at the running
            # epoch counter; testing the constant `epochs` is either always or never true.
            if (epoch + 1) % 10 == 0:
                predictions, labels = evaluate_model(model, test_dataloader, device)
                test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
                print(f"Epoch {epoch+1}/{epochs} - test accuracy: {(100 * test_acc):.2f}% and CE loss {loss.item():.2f}")
    return model

# Prune layer by layer and train, then score and checkpoint the result.
model_layered_structured = train_layered_pruned(model=model.to(device), epochs=epochs)

# Check test set performance.
predictions, labels = evaluate_model(model_layered_structured, test_dataloader, device)
test_acc = np.mean(
    predictions.cpu().numpy().argmax(axis=1) == labels.cpu().numpy()
)
print(f"Test accuracy for VGG LISA Layered pre is: {(100 * test_acc):.2f}%")

model_path = Path("models")
model_name = "lisa_layered_pre.pth"
model_save_path = model_path / model_name
model_path.mkdir(parents=True, exist_ok=True)

print(f"Saving the model: {model_save_path}")
torch.save(model_layered_structured.state_dict(), model_save_path)

# Layered Post No Tune

In [None]:
# Reload the trained baseline into a fresh network.
model_path = Path("models")
model_path.mkdir(parents=True, exist_ok=True)

model_name = "lisa_normal.pth"
model_save_path = model_path / model_name
model = vgg16().to(device)
# map_location lets a checkpoint written on a GPU machine load on the CPU fallback.
model.load_state_dict(torch.load(model_save_path, map_location=device))

print(f"VGG-16 global sparsity = {compute_sparsity_vgg(model):.2f}%")

for iter_prune_round in range(1):
    # The message names the scheme actually applied in this cell.
    print(f"\n\nIterative Layered pruning round = {iter_prune_round + 1}")

    # Layers to prune, in the same order as the original call sequence.
    # NOTE(review): classifier[6] looks like the output layer; with dim=0 this
    # zeroes whole rows of it — confirm that pruning it is intended.
    prunable_layers = (
        [model.features[i] for i in (0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28)]
        + [model.classifier[i] for i in (1, 4, 6)]
    )

    # Prune layer-wise in a structured manner-
    for layer in prunable_layers:
        prune.ln_structured(layer, name = "weight", amount = 0.1, n = 2, dim = 0)

    # Print current global sparsity level-
    print(f"VGG global sparsity = {compute_sparsity_vgg(model):.2f}%")

# The network is already on `device`; only the mode needs switching.
model.eval()

# Check test set performance.
predictions, labels = evaluate_model(model, test_dataloader, device)
test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
print(f"Test accuracy for VGG-16 post no tuning is: {(100 * test_acc):.2f}%")

# NOTE(review): while pruning is active, torch.nn.utils.prune keeps `weight_orig`
# and `weight_mask` in the state_dict instead of `weight`, so whoever loads this
# file presumably has to apply the same pruning to the network first — confirm.
model_name = "lisa_layered_post_notune.pth"
model_save_path = model_path / model_name

print(f"Saving the model: {model_save_path}")
torch.save(obj=model.state_dict(), f=model_save_path)

# Layered Post Tuned

In [None]:
# Continue with the layer-wise pruned network left in `model` by the previous cell.
model = model.to(device)
criterion = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

def train_model(model, epochs):
    """
    Train ``model`` for ``epochs`` passes over ``train_dataloader``.

    This redefinition shadows the function of the same name defined earlier in
    the notebook. It uses the notebook-level ``optimizer``, ``criterion``,
    ``device``, ``train_dataloader`` and ``test_dataloader``.

    Args:
        model: Network to optimise (updated in place).
        epochs (int): Number of passes over the training set.

    Returns:
        The same ``model`` object after training.
    """
    for epoch in range(epochs):
        # Re-assert training mode each epoch: evaluate_model lives in utils.ipynb and
        # may switch the network to eval mode during the periodic check below.
        model.train()
        for x_batch, y_batch in train_dataloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            output = model(x_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()

        # Report on every 10th epoch. The condition has to look at the running
        # epoch counter; testing the constant `epochs` is either always or never true.
        if (epoch + 1) % 10 == 0:
            predictions, labels = evaluate_model(model, test_dataloader, device)
            test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
            # The reported loss is the loss of the last training batch of this epoch.
            print(f"Epoch {epoch+1}/{epochs} - test accuracy: {(100 * test_acc):.2f}% and CE loss {loss.item():.2f}")
    return model

# Fine-tune the layer-wise pruned network, then score and checkpoint it.
model_tuned = train_model(model=model.to(device), epochs=epochs)

# Check test set performance.
predictions, labels = evaluate_model(model_tuned, test_dataloader, device)
test_acc = np.mean(
    predictions.cpu().numpy().argmax(axis=1) == labels.cpu().numpy()
)
print(f"Test accuracy for VGG LISA Layered post tuned is: {(100 * test_acc):.2f}%")

model_path = Path("models")
model_name = "lisa_layered_post_tuned.pth"
model_save_path = model_path / model_name
model_path.mkdir(parents=True, exist_ok=True)

print(f"Saving the model: {model_save_path}")
torch.save(model_tuned.state_dict(), model_save_path)