# Network Pruning

In this practical, we will use PyTorch's `prune` module to implement different types of pruning mechanisms on a CNN trained to classify land use from satellite images. We will use the [EuroSAT (RGB)](https://github.com/phelber/eurosat) dataset for this task. This tutorial works through the following steps:

1. Load and explore the EuroSAT (RGB) dataset, which contains RGB images. Note that EuroSAT also has a larger counterpart with multispectral (MS) images; in this tutorial, we use the RGB images instead. Multispectral satellite images are conveniently handled by [`torchgeo`](https://github.com/microsoft/torchgeo), and EuroSAT (MS) is also available through `torchgeo.datasets`. 
2. Create and train a Deep CNN on EuroSAT (RGB). 
3. Explore various pruning techniques and their Pareto frontier, i.e., the trade-off between accuracy and sparsity. Specifically, we will look at:
    - Random unstructured pruning with prune rate applied locally
    - L1 unstructured pruning with prune rate applied locally
    - L1 unstructured pruning with prune rate applied globally
    - L1 structured pruning

**WARNING**: 

1. This tutorial requires retraining the pruned models, and each round of retraining is time-consuming. To alleviate that, this tutorial comes with the trained unpruned model in `./models/best_DeepCNN.ckpt`, which the function `load_unpruned_model` loads. Thus, you can skip training the unpruned model. 

2. The first two steps are standard for building any machine learning model, so you can skim through them. There are still some TODOs in case you want to run the tutorial from start to finish.

3. Jump to **B. Network Pruning** to start learning how to prune models in PyTorch.

4. If the kernel dies midway, the retrained pruned models are saved in `./models` with their corresponding postfixes. Loading a pruned model in PyTorch requires a workaround; see the end of `train` for how to do so. 

## A. Standard training 

### Basic imports

In [None]:
# basic imports
import numpy as np
import matplotlib.pyplot as plt
import math
import pathlib
import glob

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torch.nn.utils.prune as prune #### NOTE: the `prune` module needs to be imported exactly like this. https://github.com/pytorch/pytorch/issues/32483

from matplotlib.lines import Line2D
from sklearn.model_selection import train_test_split
from collections import defaultdict

# fix seed for reproducibility 
rng = np.random.RandomState(1)
torch.manual_seed(rng.randint(np.iinfo(int).max))

# it is a good practice to define `device` globally
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", device)
else:
    device = torch.device("cpu")
    print("No GPU -> using CPU:", device)

### Load Data

**Note:** There is no train, val, test split in this dataset, so we need to create it ourselves. 

In [None]:
torchvision.datasets.EuroSAT(root="./data", download=True)

In [None]:
data = torchvision.datasets.EuroSAT(root='./data') 

### Explore Data

Look at the documentation of [torchvision.datasets.EuroSAT](https://pytorch.org/vision/stable/generated/torchvision.datasets.EuroSAT.html#torchvision.datasets.EuroSAT) to understand the data structure. 

In [None]:
print("What is the type of data?\n", type(data))
print("\nWhat does an observation in data look like?\n", data[0])
print("Each observation is a tuple of (image, label)")
print("\nWhat does an image in data look like?", data[0][0])
print("Each image is a 64x64 RGB PIL image")
print("\nHow many observations are there?\n", len(data))

We observe that the label is an index. To **understand what these labels mean**, we need to extract the mapping from class names to indices. Can you investigate `dir(data)` to find the attribute that stores this mapping?

In [None]:
print("What is the class label to index mapping?")
### YOUR CODE HERE: access the right attribute of data, i.e., data.<attribute>

Let's look at some of the images in each category. 

In [None]:
label_idx_map = defaultdict(list)
# data.targets stores the label of every image, so we avoid loading each image from disk
for idx, label in enumerate(data.targets):
    label_idx_map[label].append(idx)

In [None]:
n_samples_per_class = 5
n_classes = len(data.class_to_idx)

fig, axs = plt.subplots(nrows=n_classes, ncols=n_samples_per_class , figsize=(10,10), dpi=100)

reverse_label_map = {idx:label for label, idx in data.class_to_idx.items()}
for row in range(n_classes):
    label = reverse_label_map[row]
    sample_img_idx = np.random.choice(label_idx_map[row], size=n_samples_per_class, replace=False)  # distinct samples per class
    
    axs[row][0].set_ylabel(label, fontweight='bold', rotation=0, labelpad=75)
    for j, img_idx in enumerate(sample_img_idx):
        ax = axs[row][j]
        ax.imshow(
            ## YOUR CODE HERE: USE THE RIGHT INDEXING FOR DATA ## 
        )
        ax.axis('off')
    
    axs[row][0].axis('on')
    axs[row][0].xaxis.set_ticks([])
    axs[row][0].yaxis.set_ticks([])
    
_ = fig.suptitle("EuroSAT dataset samples per class")

### Split data for training, validation and testing

Since we are interested in evaluating the generalization performance of the models, we will split the dataset into train, validation, and test sets. 
We will use the train set to train the models, the validation set for model selection (early stopping), and the test set to evaluate the final performance of these models. 
As per standard practice, the normalization statistics are computed on the train set only and then applied unchanged to the validation and test sets. 

In [None]:
# split indices into train, val, and test indices
TRAIN_SPLIT=0.6
VAL_SPLIT=0.2
TEST_SPLIT=0.2
train_idxs, val_idxs, test_idxs = ## YOUR CODE HERE: use torch.utils.data.random_split to get the indices of each split

We will convert each PIL image to a NumPy array using `np.asarray`, and compute the mean and standard deviation of each RGB channel. 

In [None]:
X_train = torch.stack([torch.tensor(np.asarray(data[idx][0])) for idx in train_idxs.indices])

In [None]:
X_train.shape

In [None]:
X_train_scaled = X_train / 255.0  # scale uint8 pixel values to [0, 1], matching transforms.ToTensor
DATA_MEANS = X_train_scaled.mean(dim=(0, 1, 2))  # X_train is (N, H, W, C): reduce N, H, W to get one mean per channel
DATA_STD = X_train_scaled.std(dim=(0, 1, 2))     # likewise, one standard deviation per channel

print(f"per-channel data means: {DATA_MEANS}")
print(f"per-channel data std: {DATA_STD}")


data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=DATA_MEANS, std=DATA_STD)
])

Now we can apply `data_transforms` as each observation is loaded. To do so, we need a `torch.utils.data.Dataset` subclass that loads individual observations and transforms them before handing them to `torch.utils.data.DataLoader`, which batches them together. 

Every custom `torch.utils.data.Dataset` subclass must define `__getitem__`, which retrieves the single observation at `index`, and `__len__`, which returns the total number of observations.
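
To illustrate this contract in isolation, here is a sketch with a hypothetical toy dataset (`SquaresDataset` is not part of the tutorial) whose observation `i` is `(i, i**2)`. `DataLoader` uses `__len__` to plan the batches and calls `__getitem__` once per observation before collating the results.

```python
import torch

class SquaresDataset(torch.utils.data.Dataset):
    """Hypothetical toy dataset: observation i is (tensor([i]), i**2)."""
    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        # Return a single (input, target) pair for position `index`
        x = torch.tensor([float(index)])
        y = index ** 2
        return x, y

    def __len__(self):
        # Total number of observations; DataLoader uses this to build batches
        return self.n

loader = torch.utils.data.DataLoader(SquaresDataset(5), batch_size=2)
batches = list(loader)  # 5 observations -> batches of size 2, 2, 1
xb, yb = batches[0]
print(len(batches), xb.shape, yb)
```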

In [None]:
class TransformedData(torch.utils.data.Dataset):
    def __init__(self, data, indices, transform):
        self.data = data
        self.subset_idxs = indices
        self.transform = transform 
    
    def __getitem__(self, index):
        x,y = self.data[
            ## YOUR CODE HERE: Use the correct indexing 
        ]
        
        if self.transform:
            x = self.transform(x)
        return x, y
    
    def __len__(self):
        return len(self.subset_idxs)

In [None]:
val_data  = TransformedData(data, val_idxs.indices, data_transforms)
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=256)

test_data = TransformedData(data, test_idxs.indices, data_transforms)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=256)

### Simple CNN model


In [None]:
class DeepCNN(nn.Module):
    def __init__(self, c_in, num_classes):
        super().__init__()
        self.input_args = [c_in, num_classes]
        
        # READ THROUGH THE CODE TO UNDERSTAND THE MODEL. This is similar to AlexNet.
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(c_in, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),        
        )
        
        self.conv4 = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),        
        )
        
        self.flatten = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1,1)), # kernel_size and stride are automatically inferred: https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool2d.html
            nn.Flatten(),
        )
        
        self.linear1 = nn.Sequential(
            nn.Linear(16, 256),
            nn.ReLU(),
        )
        
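        # No Softmax here: nn.CrossEntropyLoss (used in `process`) applies log-softmax
        # internally and expects raw logits; argmax predictions are unaffected.
        self.out = nn.Sequential(
            nn.Linear(256, num_classes),
        )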

    def forward(self, x):
        x = self.conv4(self.conv3(self.conv2(self.conv1(x))))
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.out(x)
        return x
                

We will write standard helper functions to report attributes of the model, e.g., memory requirements, number of parameters, per-layer sparsity, and global sparsity. 

In [None]:
## NOTE: the '_orig' checks in the if-else conditions below will become clear in section B of this tutorial. 
def mem_size(model):
    """
    Get model size in GB (as str: "N GB")
    """
    mem_params = sum(
        [param.nelement() * param.element_size() for param in model.parameters()]
    )
    mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
    mem = mem_params + mem_bufs
    return f"{mem / 1e9:.4f} GB"

def num_params(model):
    """
    Return a string with the number of parameters (and their sparsity)
    in each of the model's named children, and the total.
    """
    s = "Number of parameters (sparsity):\n"
    named_buffers = dict(model.named_buffers())
    n_params = 0
    for name, child in model.named_children():
        n = sum(p.numel() for p in child.parameters())
        if n == 0:
            continue

        nz = 0
        for child_name, p in child.named_parameters():
            if '_orig' in child_name:
                nz +=  torch.sum(named_buffers[f"{name}.{child_name.replace('_orig', '_mask')}"]).item()
            else:
                nz += torch.sum(p!=0).item()
#         nz = sum(torch.count_nonzero(p) for p in child.parameters())
        sparsity = (n-nz)/n
        s += f"  • {name:<15}: {n} \t {sparsity*100:2.3f}%\n"
        n_params += n
    s += f"{'total':<19}: {n_params}"

    return s


def compute_sparsity(model):
    """
    Computes global sparsity for the model.
    """
    named_buffers = dict(model.named_buffers())
    total_nnz, total_params = 0, 0
    for n, p in model.named_parameters():
        total_params += p.numel()
        
        if '_orig' in n:
            total_nnz += torch.sum(named_buffers[n.replace('_orig', '_mask')]).item()
        else:
            total_nnz += torch.sum(p!=0).item()

    return (total_params - total_nnz) / total_params


def pp_model_summary(model):
    print(num_params(model))
    print(f"{'Total memory':<18} : {mem_size(model)}")
    
    sparsity = compute_sparsity(model)
    print(f"{'Global sparsity':<18} : {100*sparsity: 2.3f}%")

Instantiate the model and print its attributes.

In [None]:
model = DeepCNN(3, 10)
pp_model_summary(model)

Define a standard training function that can be called on `model`. The remaining arguments are intended for retraining the pruned model, as described in the docstring. Note that this setup is quite standard, and you might have followed it in previous courses (e.g., Modern Network Architecture, Denoising Autoencoders, Attention is all you need, Deep Autoencoders).

In [None]:
def train(model, save_with_pruned_postfix="",n_epochs=100, print_every=1):
    """Trains the model. 
    
    Args:
        model (torch.nn.Module): model to be trained 
        save_with_pruned_postfix (str): The best model is saved with this as postfix
        n_epochs (int): maximum number of epochs to run
        print_every (int): print performance at every these number of epochs
    
    Returns:
        model (torch.nn.Module): best performing model 
        metrics (dict): metrics, e.g., losses, accuracy, etc.
    """
    model.to(device)
    
    pruning_postfix = f"_prune_{save_with_pruned_postfix}" if save_with_pruned_postfix else ""
    
    # fix seed for reproducibility 
    rng = np.random.RandomState(1)
    torch.manual_seed(rng.randint(np.iinfo(int).max))
    
    # create a model directory to store the best model
    model_dir = pathlib.Path("./models").resolve()
    if not model_dir.exists():
        model_dir.mkdir()
        
    epoch_size=n_epochs  # number of batches sampled per epoch (tied to n_epochs here)
    batch_size=64
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []
    best_val_acc = 0
    no_improvement_count = 0
    for epoch in range(n_epochs):
        
        # training loss 
        epoch_indices = rng.choice(train_idxs.indices, epoch_size * batch_size, replace=True)
        train_dataloader = torch.utils.data.DataLoader(
            ## YOUR CODE HERE: Specify the transformed dataset here
            batch_size=batch_size, 
            num_workers=4)
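        model.train()  # BatchNorm: use batch statistics and update running estimates
        train_loss, train_acc = process(model, train_dataloader, optimizer)
        
        # validation loss 
        model.eval()  # BatchNorm: use the running estimates, do not update them
        with torch.no_grad():
            val_loss, val_acc = process(model, val_dataloader, optimizer=None)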
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict()
            }, model_dir / f"best_{model.__class__.__name__}{pruning_postfix}.ckpt")
            no_improvement_count = 0
        else:
            no_improvement_count += 1
            
            if no_improvement_count % 10 == 0:
                print("Early stopping...")
                break
        
        # logging 
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)  
        
        if epoch % print_every == 0:
            print(f"Epoch: {epoch}\t train loss:{train_loss: 0.5f}\t train acc: {100*train_acc:2.3f}%\t val_loss:{val_loss:0.5f}\t val_acc:{100*val_acc:2.3f}%")
        
    print(f"best val acc: {best_val_acc:0.3f}")

    # load the best model
    model = model.__class__(*model.input_args)
    if save_with_pruned_postfix:
        apply_mask_to_loaded_model(model)  # a freshly initialized model has no *_orig/*_mask entries; add them so the pruned state_dict loads
    model.load_state_dict(torch.load(model_dir /  f"best_{model.__class__.__name__}{pruning_postfix}.ckpt")['model_state'])
    model = model.to(device) 
    
    metrics = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs
    }
    return model, metrics


In [None]:
def process(model, dataloader, optimizer=None):
    n_samples = 0
    running_loss, running_acc = 0, 0
    for batch, labels in dataloader:
        # transfer to GPU if available
        batch = batch.to(device)
        labels = labels.to(device)

        n_samples += batch.shape[0]
        
        # forward pass
        outputs = model(batch)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        preds = outputs.argmax(dim=1)
        
        # backward pass 
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        running_loss += loss.item() * batch.shape[0]  # CrossEntropyLoss returns the batch mean; weight by batch size
        running_acc += (preds == labels).sum().float().item()
        
    return running_loss / n_samples, running_acc / n_samples
        
    

**Note:** This section can be skipped. The trained model will be provided in `./models` folder.

In [None]:
model, metrics = train(model)

In [None]:
# load the best model
def load_unpruned_model():
    model = DeepCNN(3, 10)
    model_dir = pathlib.Path("./models").resolve()
    model = model.__class__(*model.input_args)
    model.load_state_dict(torch.load(model_dir /  f"best_{model.__class__.__name__}.ckpt")['model_state'])
    model = model.to(device) 
    return model

def compute_test_accuracy(model):
    ## YOUR CODE HERE (remember to call model.eval() before evaluating)
    return test_acc

model = load_unpruned_model()
unpruned_test_acc = compute_test_accuracy(model)

# 
print(f"Test accuracy of unpruned model:{unpruned_test_acc * 100:2.3f}%")
pp_model_summary(model)

## B. Network Pruning


**How is pruning implemented in PyTorch?**

Recall that PyTorch stores a model's weights in `model.state_dict()`. PyTorch's pruning module changes this `state_dict`. Specifically, it acts on the parameter given by `name` (e.g., `weight` or `bias`) of a `module` (e.g., an `nn.Linear` layer defined in `model`). During pruning, it computes a binary mask in which the entries to be pruned are 0. This mask is registered as a buffer (listed in `model.named_buffers()`) under the key `name_mask` (e.g., `weight_mask` or `bias_mask`). At the same time, the original parameter is renamed to `name_orig` (e.g., `weight_orig` or `bias_orig`), which stores the original unpruned weights. At run time, a forward pre-hook recomputes `name` as `name_orig * name_mask` before each call to `model.forward`, so the pruned weights act as zeros. 
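
To see this re-parametrization concretely, here is a minimal sketch on a toy `nn.Linear` layer rather than on the tutorial's `DeepCNN`. The layer size, seed, and `amount` are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # toy layer: weight has 3 * 4 = 12 entries

# Prune half of the weight entries with the smallest magnitude
prune.l1_unstructured(layer, name="weight", amount=0.5)

param_names = [n for n, _ in layer.named_parameters()]   # 'weight' replaced by 'weight_orig'
buffer_names = [n for n, _ in layer.named_buffers()]     # the mask lives here
print(param_names, buffer_names)

# `weight` is no longer a parameter; it is recomputed as weight_orig * weight_mask
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))
print(int((layer.weight_mask == 0).sum()))  # number of pruned entries
```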

**What are the implications of such an implementation?**

1. **Size of the model**: Pruning with `prune` does not shrink the model. The original weights are kept in `name_orig`, and a mask of the same shape is added as a buffer, so memory usage actually grows. Once pruning is complete, one can call `prune.remove` to make it permanent. This applies the mask to the weights, discards `name_orig` and `name_mask`, and turns `name` back into a regular parameter. The pruned entries are still stored as explicit zeros in a dense tensor. 

2. **Iterative pruning**: The same parameter can be pruned several times without any extra bookkeeping. Each new pruning call is combined with the existing mask; for unstructured methods, it only acts on the entries that are still unpruned. The history of pruning methods is recorded in a `prune.PruningContainer` stored in `module._forward_pre_hooks`. 

3. **Saving a pruned model**: The masks (`name_mask` buffers) and the original weights (`name_orig` parameters) are both part of `model.state_dict()`. Saving the `state_dict`, as we do in `train`, is therefore enough to resume from a pruned model later (e.g., after pre-empted training). Keep in mind that such a checkpoint is larger than that of the unpruned model. 

4. **Loading a pruned model**: Loading a pruned model requires an extra step. A freshly constructed model has a `name` key (e.g., `conv1.0.weight`), whereas the saved `state_dict` of a pruned model contains `name_orig` and `name_mask` instead, so `load_state_dict` fails on the mismatched keys. Calling `prune.identity` on the fresh model first re-parametrizes it in the same way (with an all-ones mask). After that, the saved state, including the masks, loads normally. 
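
The loading workaround can be sketched on a hypothetical one-layer stand-in for `DeepCNN` (`make_model` is introduced here only for illustration). The sketch shows the key mismatch and how `prune.identity` resolves it.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def make_model():
    """Hypothetical stand-in for DeepCNN: one conv layer inside nn.Sequential."""
    return nn.Sequential(nn.Conv2d(3, 8, kernel_size=3))

pruned = make_model()
prune.l1_unstructured(pruned[0], name="weight", amount=0.5)
state = pruned.state_dict()  # contains 0.weight_orig, 0.weight_mask and 0.bias
print(sorted(state))

# A fresh model expects '0.weight', so strict loading fails on the key mismatch
plain_load_failed = False
try:
    make_model().load_state_dict(state)
except RuntimeError:
    plain_load_failed = True

# prune.identity adds weight_orig and an all-ones weight_mask, so the keys now match;
# load_state_dict then overwrites the all-ones mask with the saved one.
fresh = make_model()
prune.identity(fresh[0], name="weight")
fresh.load_state_dict(state)
print(plain_load_failed, torch.equal(fresh[0].weight_mask, pruned[0].weight_mask))
```

Note that `fresh[0].weight` itself is only recomputed from the loaded `weight_orig` and `weight_mask` on the next forward pass.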

**Telling PyTorch which layers to prune**

`prune` requires the user to specify which layers and which parameters (e.g., `weight` or `bias`) are to be pruned. 


-----

In our tutorial, we will focus our pruning efforts only on the convolutional layers defined in `model.conv1`, `model.conv2`, `model.conv3`, and `model.conv4`. And we will focus only on `weight` parameters of these layers. 

**Note**: Each of these `conv` blocks is an `nn.Sequential`, which can be indexed like a `list` to access its layers. 

In [None]:
def get_weights_to_prune(model):
    """Returns the list of (module, name) in model to be pruned."""
    return [
        ## YOUR CODE HERE: Pass the tuple of (module, name) here to be used in the code cells below. 
        ## Pass such tuples for kernel weights in model.conv1, model.conv2, model.conv3, model.conv4. 
        ## There will be a total of 4 tuples.
    ]

def apply_mask_to_loaded_model(model):
    """
    Applies `prune.identity` to a freshly constructed model that is about to receive the state of a pruned model.
    `prune.identity` re-parametrizes each weight as `weight_orig` and an all-ones `weight_mask`,
    so the keys match the saved state_dict, whose mask then overwrites the all-ones mask.
    """
    m_weights_to_prune = get_weights_to_prune(model)
    for module, name in m_weights_to_prune:
        prune.identity(module=module, name=name)

### Random unstructured pruning

As a starting example, we will randomly prune weights in an unstructured fashion. To do so, we will use `prune.random_unstructured`, specifying which modules and which parameters are to be pruned. Read more [here](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#pruning-a-module).

In [None]:
weights_to_prune = get_weights_to_prune(model)

PRUNE_RATE=0.3
for module, name in weights_to_prune:
    prune.random_unstructured(
        ## YOUR CODE HERE: read the documentation above
    )

Let's look at the sparsity induced by above pruning and the resulting test accuracy. 

In [None]:
pruned_test_acc = compute_test_accuracy(model)
print(f"Pruned test acc: {100*pruned_test_acc: 2.3f}%")
pp_model_summary(model)

Let's retrain the pruned model and observe the effect on the test accuracy.

**Q.** Does the test accuracy increase or decrease? Can you reason why?

In [None]:
retrained_model, _ = train(model, save_with_pruned_postfix=f"random_unstructured_{PRUNE_RATE}", n_epochs=50, print_every=10)
pruned_test_acc = compute_test_accuracy(retrained_model)
print(f"Pruned test acc: {100*pruned_test_acc: 2.3f}%")
pp_model_summary(retrained_model)


### Iterative finetuning

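In practice, pruning is applied iteratively. For this tutorial, we will apply pruning twice before moving on to other methods. Note that a second round with the same `amount` prunes that fraction of the weights that are still unpruned, not of all weights.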

**Q.** How does the performance vary in this iteration? 

In [None]:
weights_to_prune = get_weights_to_prune(retrained_model)
for module, name in weights_to_prune:
    prune.random_unstructured(
        ## YOUR CODE HERE: read the documentation above
    )


print("Iteration #2: \n")

# before finetuning
pruned_test_acc = compute_test_accuracy(retrained_model)
print(f"Pruned test acc: {100*pruned_test_acc: 2.3f}%")
pp_model_summary(retrained_model)

# after finetuning
retrained_model, _ = train(retrained_model, save_with_pruned_postfix=f"random_unstructured_{PRUNE_RATE}_#2", n_epochs=50, print_every=10)
pruned_test_acc = compute_test_accuracy(retrained_model)
print(f"Pruned test acc: {100*pruned_test_acc: 2.3f}%")
pp_model_summary(retrained_model)
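
The compounding effect of repeated pruning can be checked on a toy layer. The following sketch assumes a 10x10 `nn.Linear` layer, and `weight_sparsity` is a hypothetical helper. Two rounds at `amount=0.3` prune 30 weights and then 30% of the remaining 70, so the layer ends at 51% sparsity rather than 60%.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(10, 10)  # toy layer with 100 weights

def weight_sparsity(module):
    """Hypothetical helper: fraction of zeros in module.weight_mask."""
    mask = module.weight_mask
    return float((mask == 0).sum()) / mask.numel()

prune.random_unstructured(layer, name="weight", amount=0.3)
s1 = weight_sparsity(layer)  # 30 of 100 weights pruned

# Second round: 30% of the 70 remaining weights (21 more) are pruned
prune.random_unstructured(layer, name="weight", amount=0.3)
s2 = weight_sparsity(layer)

# The pruning history is kept in a single PruningContainer hook
hook = next(iter(layer._forward_pre_hooks.values()))
print(s1, s2, type(hook).__name__)
```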


### L1 Unstructured (local)

To prune the weights with the lowest magnitude (i.e., the smallest absolute value, which is the L1 norm of a single entry), use `prune.l1_unstructured`. Read more in the [documentation here](https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.l1_unstructured.html). 
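
As a quick sanity check of which entries L1 unstructured pruning removes, the sketch below uses a toy layer with hand-picked weights (the values are arbitrary). With `amount=0.5`, the three entries with the smallest absolute values (0.05, 0.1, and 0.3) are masked out.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(3, 2, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.9, -0.1, 0.5],
                                     [-0.05, 0.7, -0.3]]))

# Remove the 50% of entries (3 of 6) with the smallest absolute value
prune.l1_unstructured(layer, name="weight", amount=0.5)
print(layer.weight_mask)
print(layer.weight)
```

`amount` may also be an integer, in which case it is interpreted as the absolute number of entries to prune.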

### L1 Structured (local)

`prune.ln_structured` prunes entire slices of a tensor, ranked by their L`n` norm, along a user-specified dimension. `nn.Conv2d` weights have shape `[out_channels, in_channels, kernel_height, kernel_width]` (e.g., `[128, 3, 3, 3]` for the convolution in `model.conv1`). To eliminate entire filters, i.e., output channels in dimension 0, the function expects the corresponding `dim` argument, while `n` selects the norm. [Read more about the function here](https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.ln_structured.html).
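
The following sketch applies filter-level L1 structured pruning to a standalone toy `nn.Conv2d` with arbitrary sizes and seed. It checks that whole `[in_channels, kH, kW]` slices are zeroed and that the removed filters are those with the smallest L1 norms. Keep in mind that the pruned filters remain in the dense weight tensor; physically shrinking the layer would require rebuilding it.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3)   # weight shape: [8, 3, 3, 3]

# Remove 50% of the output filters (dim=0), ranked by their L1 norm (n=1)
prune.ln_structured(conv, name="weight", amount=0.5, n=1, dim=0)

# 1.0 if any entry of the filter survives, 0.0 if the whole filter was pruned
filter_kept = conv.weight_mask.flatten(start_dim=1).amax(dim=1)
# L1 norm of each original filter, i.e. the ranking criterion
norms = conv.weight_orig.detach().abs().sum(dim=(1, 2, 3))
print(filter_kept)
print(norms)
```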

### L1 Unstructured (global)

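Finally, `prune.global_unstructured` applies the prune rate globally across the specified modules. The specified parameters are pooled together, and the given fraction of the lowest-scoring weights is removed across all of them, so individual layers can end up with different sparsity levels. [Read its documentation here](https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.global_unstructured.html).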
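
The difference from local pruning shows up clearly when layers have different weight scales. In the sketch below, the second layer of a hypothetical two-layer model is scaled up by 10 purely for illustration. Global L1 pruning at 50% then removes exactly 250 of the 500 weights in total, but takes far fewer of them from the large-magnitude layer.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(20, 20), nn.Linear(20, 5))  # 400 + 100 weights
with torch.no_grad():
    model[1].weight.mul_(10.0)  # make the second layer's weights much larger in magnitude

params = [(model[0], "weight"), (model[1], "weight")]
# Prune the 50% smallest-magnitude weights across both layers pooled together
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.5)

sparsities = []
for module, _ in params:
    mask = module.weight_mask
    sparsities.append(float((mask == 0).sum()) / mask.numel())

total_pruned = sum(int((m.weight_mask == 0).sum()) for m, _ in params)
print(sparsities, total_pruned)  # per-layer sparsity differs; total is 250 of 500
```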

In [None]:
model = load_unpruned_model()
unpruned_test_acc = compute_test_accuracy(model)

PRUNE_RATES = [0.3, 0.5, 0.7]
N_RETRAINING_EPOCHS = 15
PRINT_EVERY=N_RETRAINING_EPOCHS//3

# random unstructured pruning 
print("\nRandom Unstructured (local)\n")
random_unstructured_performance = []
for prune_rate in PRUNE_RATES:
    model = load_unpruned_model()
    weights_to_prune = get_weights_to_prune(model)
    for module, name in weights_to_prune:
        prune.random_unstructured(
            ## YOUR CODE HERE: read the documentation above
        )
    
    retrained_model, _ = train(model, save_with_pruned_postfix=f"random_unstructured_{prune_rate}", n_epochs=N_RETRAINING_EPOCHS, print_every=PRINT_EVERY)
    sparsity = compute_sparsity(retrained_model)
    test_acc = compute_test_accuracy(retrained_model)
    
    random_unstructured_performance.append((sparsity, test_acc))


# l1 unstructured pruning (local)
print("\nL1 Unstructured (local)\n")
l1_unstructured_performance = []
for prune_rate in PRUNE_RATES:
    model = load_unpruned_model()
    weights_to_prune = get_weights_to_prune(model)
    for module, name in weights_to_prune:
        prune.l1_unstructured(
            ## YOUR CODE HERE: read the documentation above
        )
    
    retrained_model, _ = train(model, save_with_pruned_postfix=f"l1_unstructured_{prune_rate}", n_epochs=N_RETRAINING_EPOCHS, print_every=PRINT_EVERY)
    sparsity = compute_sparsity(retrained_model)
    test_acc = compute_test_accuracy(retrained_model)
    
    l1_unstructured_performance.append((sparsity, test_acc))

# l1 unstructured pruning (global)
print("\nL1 Unstructured (global)\n")
l1_unstructured_global_performance = []
for prune_rate in PRUNE_RATES:
    model = load_unpruned_model()
    weights_to_prune = get_weights_to_prune(model)
    prune.global_unstructured(
        ## YOUR CODE HERE: read the documentation above
    )
    
    retrained_model, _ = train(model, save_with_pruned_postfix=f"l1_unstructured_global_{prune_rate}", n_epochs=N_RETRAINING_EPOCHS, print_every=PRINT_EVERY)
    sparsity = compute_sparsity(retrained_model)
    test_acc = compute_test_accuracy(retrained_model)
    
    l1_unstructured_global_performance.append((sparsity, test_acc))


# l1 structured pruning (local) 
print("\nL1 structured\n")
l1_structured_performance = []
for prune_rate in PRUNE_RATES:
    model = load_unpruned_model()
    weights_to_prune = get_weights_to_prune(model)
    for module, name in weights_to_prune:
        prune.ln_structured(
            ## YOUR CODE HERE: read the documentation above
        )

    
    retrained_model, _ = train(model, save_with_pruned_postfix=f"l1_structured_{prune_rate}", n_epochs=N_RETRAINING_EPOCHS, print_every=PRINT_EVERY)
    sparsity = compute_sparsity(retrained_model)
    test_acc = compute_test_accuracy(retrained_model)
    
    l1_structured_performance.append((sparsity, test_acc))




In [None]:
# plot various performances

fig, axs = plt.subplots(ncols=1, nrows=1, figsize=(12,6), dpi=100)

# unpruned model 
axs.hlines(y=unpruned_test_acc, xmin=0, xmax=1, linestyles="--", colors='#E8384F', linewidth=2)

# random unstructured
x,y = zip(*random_unstructured_performance)
axs.plot(x, y, color="#208EA3", linestyle=":", marker="o", label="random unstructured", linewidth=2)

# l1 unstructured (local)
x,y = zip(*l1_unstructured_performance)
axs.plot(x, y, color="#A4C61A", linestyle=":", marker="o", label="l1 unstructured", linewidth=2)


# l1 unstructured (global)
x,y = zip(*l1_unstructured_global_performance)
axs.plot(x, y, color="#8D9F9B", linestyle=":", marker="o", label="l1 unstructured (global)", linewidth=2)


# l1 structured
x,y = zip(*l1_structured_performance)
axs.plot(x, y, color="#37A862", linestyle=":", marker="o", label="l1 structured", linewidth=2)

axs.set_ylabel("Test accuracy", fontsize=20)
axs.set_xlabel("Sparsity", fontsize=20)

# tick size
axs.tick_params(axis="both", labelsize=15)  # Tick.label was removed in recent Matplotlib versions

axs.set_xlim(0,1)
axs.grid(True, linestyle=':')

# legend

legend = []
legend.append(Line2D([0,1], [1,0], color="#E8384F", label="Unpruned model", linewidth=5))
legend.append(Line2D([0,1], [1,0], color="#208EA3", label="Random Unstructured", linewidth=5))
legend.append(Line2D([0,1], [1,0], color="#A4C61A", label="L1 Unstructured", linewidth=5))
legend.append(Line2D([0,1], [1,0], color="#8D9F9B", label="L1 Unstructured (global)", linewidth=5))
legend.append(Line2D([0,1], [1,0], color="#37A862", label="L1 Structured", linewidth=5))
lgd = fig.legend(handles=legend, ncol=1, fontsize=15, loc="center right", fancybox=True, bbox_to_anchor=(1.0, 0.5, 0.2, 0))


_ = fig.suptitle("Sparsity-Accuracy Tradeoff", fontsize=20)


### Saving a pruned model 

To make pruning permanent, call `prune.remove(module, name)`. It folds the mask into the weights, deletes `name_orig` and `name_mask`, and turns `name` back into a regular parameter, which removes the masks from `model.state_dict()`. Read more about it [here](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#remove-pruning-re-parametrization).
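
A minimal sketch of `prune.remove` on a toy layer shows the change in `state_dict` keys. After removal, the zeros are still stored in a dense tensor, so the layer takes as much memory as an unpruned one. Actual savings would need a sparse format (e.g., `Tensor.to_sparse()`) or structurally smaller layers; neither is shown here.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.5)
keys_before = sorted(layer.state_dict())  # bias, weight_mask, weight_orig

# Make the pruning permanent: weight_orig * weight_mask becomes the plain `weight` Parameter
prune.remove(layer, "weight")
keys_after = sorted(layer.state_dict())   # bias, weight

n_zero = int((layer.weight == 0).sum())   # pruned entries remain as explicit zeros
print(keys_before, keys_after, n_zero, layer.weight.numel())
```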

### Defining a custom pruning method 

Applying a pruning method boils down to computing a mask. A custom method can therefore be defined by subclassing `prune.BasePruningMethod` and implementing its `compute_mask` method. See [here](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#extending-torch-nn-utils-prune-with-custom-pruning-functions) to learn how to extend `torch.nn.utils.prune` with custom pruning functions. 
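
As an illustration, here is a sketch of a hypothetical magnitude-threshold method. `ThresholdPruning` and `threshold_unstructured` are names invented for this example and do not exist in `torch.nn.utils.prune`. The subclass sets `PRUNING_TYPE` and implements `compute_mask`, while the inherited `apply` classmethod takes care of the `name_orig` / `name_mask` re-parametrization.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class ThresholdPruning(prune.BasePruningMethod):
    """Hypothetical example: prune every entry whose magnitude is below `threshold`."""
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # Start from the existing mask so that earlier pruning is preserved
        mask = default_mask.clone()
        mask[t.abs() < self.threshold] = 0
        return mask

def threshold_unstructured(module, name, threshold):
    """Convenience wrapper in the style of the built-in functional API."""
    ThresholdPruning.apply(module, name, threshold=threshold)
    return module

layer = nn.Linear(3, 1, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.02, -0.5, 0.04]]))
threshold_unstructured(layer, "weight", threshold=0.05)
print(layer.weight)  # only the -0.5 entry survives
```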