In [None]:
# Import some useful packages for this homework
import copy
import os
import random

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torchvision.transforms as transforms

from PIL import Image
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset # "ConcatDataset" and "Subset" are possibly useful
from torchvision.datasets import DatasetFolder, VisionDataset
from torchsummary import summary
from tqdm.auto import tqdm

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Per-channel normalization using the widely used ImageNet mean / std values.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Deterministic test-time pipeline:
#   1. resize so the shorter image side is 256 px,
#   2. take the central 224x224 crop,
#   3. convert the PIL image to a float tensor,
#   4. apply the channel normalization above.
# Avoid modifying this if you evaluate the provided teacher model: it is a standard
# test-time transform and is good enough for evaluation.
test_tfm = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)

In [None]:
class FoodDataset(Dataset):
    """Image dataset over the ``.jpg`` files of one food11-hw13 split directory.

    The label is parsed from the file-name prefix before the first underscore
    (e.g. ``3_120.jpg`` -> 3). A file whose name has no integer prefix gets
    label ``-1`` (the test split is unlabelled).

    Args:
        path: directory scanned (non-recursively) for ``.jpg`` files when
            ``files`` is not given.
        tfm: transform applied to each loaded image (defaults to ``test_tfm``).
        files: optional explicit list of image paths. When given, it is used
            as-is and ``path`` is not scanned.
    """

    def __init__(self, path, tfm=test_tfm, files=None):
        super().__init__()
        self.path = path
        if files is not None:
            self.files = files
        else:
            self.files = sorted(
                os.path.join(path, x) for x in os.listdir(path) if x.endswith(".jpg")
            )
        # Sanity printout of one file; skipped for an empty split instead of raising IndexError.
        if self.files:
            print(f"One {path} sample", self.files[0])
        self.transform = tfm

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the ``idx``-th file."""
        fname = self.files[idx]
        # `with` closes the file handle deterministically; convert("RGB") guarantees
        # 3 channels so grayscale/RGBA JPEGs match the 3-channel Normalize step.
        with Image.open(fname) as img:
            im = self.transform(img.convert("RGB"))
        try:
            # basename is separator-agnostic: splitting on "/" breaks on Windows paths
            # (os.path.join uses "\\" there), which silently turned every label into -1.
            label = int(os.path.basename(fname).split("_")[0])
        except ValueError:
            label = -1  # test has no label
        return im, label

In [None]:
# Form the validation dataset and dataloader (fixed order, test-time transform only).
valid_set = FoodDataset(os.path.join('./food11-hw13', "validation"), tfm=test_tfm)
# Pinned host memory only speeds up host->GPU copies. Request it only when CUDA exists;
# otherwise DataLoader warns and the pinning is wasted work.
valid_loader = DataLoader(
    valid_set,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Evaluation runs on the CPU by default. Set USE_CUDA = True to use a GPU when one is available.
USE_CUDA = False
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

In [None]:
def evaluate(model, loader=None, eval_device=None):
    """Compute top-1 accuracy of ``model`` over a labelled data loader.

    Args:
        model: ``torch.nn.Module`` whose output's last dimension indexes classes
            (the prediction is its arg-max). It is moved to the evaluation device
            and switched to eval mode in place.
        loader: iterable of ``(imgs, labels)`` batches. Defaults to the notebook's
            global ``valid_loader`` (resolved at call time).
        eval_device: device to run on. Defaults to the notebook's global
            ``device`` (resolved at call time).

    Returns:
        Fraction of samples whose predicted class equals the label, as a Python float.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
    """
    if loader is None:
        loader = valid_loader
    if eval_device is None:
        eval_device = device

    model.to(eval_device)
    model.eval()

    n_correct = 0
    n_seen = 0
    # Gradients are never needed here; disabling autograd for the whole loop
    # avoids building graphs and speeds up the forward pass.
    with torch.no_grad():
        for imgs, labels in tqdm(loader):
            imgs = imgs.to(eval_device)
            labels = labels.to(eval_device)
            logits = model(imgs)
            n_correct += (logits.argmax(dim=-1) == labels).sum().item()
            n_seen += len(imgs)

    if n_seen == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")
    return n_correct / n_seen

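Suppose you want to prune every parameter named `weight` in all `nn.Conv2d` layers of `model` with a pruning ratio of **0.2**. The cell below shows how to do this with `torch.nn.utils.prune.l1_unstructured`. It generalizes the idea to a sweep of pruning ratios from 0 to 0.95 in steps of 0.05 and records the validation accuracy at each ratio.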

In [None]:
# Sweep pruning ratios 0.00, 0.05, ..., 0.95 and record validation accuracy for each.
# The teacher is built and its checkpoint read only once. Every ratio prunes a fresh
# deep copy, so pruning never accumulates across iterations and torch.hub / disk are
# not hit once per ratio.
teacher_ckpt_path = os.path.join('./food11-hw13', "resnet18_teacher.ckpt")
base_teacher = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False, num_classes=11)
# NOTE: torch.load unpickles the file; only point this at a trusted checkpoint.
base_teacher.load_state_dict(torch.load(teacher_ckpt_path, map_location='cpu'))

valid_acc_list = []

for ratio in np.arange(0, 1, 0.05):
    # np.arange accumulates float error (e.g. 0.15000000000000002); round to 2 decimals.
    ratio = round(float(ratio), 2)
    teacher_model = copy.deepcopy(base_teacher)
    # Prune each Conv2d's `weight` independently: within that layer, the `ratio`
    # fraction of entries with the smallest absolute value is masked to zero.
    for module in teacher_model.modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=ratio)
    valid_acc = evaluate(teacher_model)
    valid_acc_list.append(valid_acc)
    print(f"ratio={ratio:.2f}  valid_acc={valid_acc:.4f}")


In [None]:
# Accuracy-vs-pruning-ratio curve; one marker per ratio evaluated in the sweep above.
plot_ratios = np.arange(0, 1, 0.05)
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(plot_ratios, valid_acc_list, "-o")
ax.grid(ls="--")
ax.set_xticks(plot_ratios)
ax.set_title("Pruning Ratio vs. Model Accuracy")
ax.set_xlabel("Pruning Ratio")
ax.set_ylabel("Model Accuracy")
# To keep the figure, call fig.savefig("pruning.png") before showing it.
plt.show()

In [None]:
# Tabulate each pruning ratio next to its validation accuracy. The slice keeps the table
# valid even if the sweep was interrupted before all 20 ratios were evaluated.
pruning_summary = pd.DataFrame({
    "pruning_ratio": np.round(np.arange(0, 1, 0.05), 2)[: len(valid_acc_list)],
    "valid_acc": valid_acc_list,
})
pruning_summary