<a href="https://colab.research.google.com/github/zeyuanyin/ml801/blob/main/lab_2/lab_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Foundations and Advanced Topics in Machine Learning Lab ML801b -- Deep Learning (2)


Lab Goal: Develop a compressed, efficient model with fewer parameters and FLOPs.

- Pruning
    - Fine-grained Pruning
    - Channel-level Pruning


<img src="./prune-1.png" alt="Image" style="width:50%;">

### Part 1: Evaluate a pretrained neural network

In [22]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

Load and normalize the CIFAR10 test dataset using torchvision

In [4]:
print("==> Preparing data..")

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform_test
)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2
)

==> Preparing data..
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:16<00:00, 10311984.52it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


Build the model

In [18]:
print("==> Building model..")

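model = torchvision.models.get_model("resnet18", num_classes=10)
# CIFAR-10 images are 32x32, so replace the ImageNet stem (7x7 stride-2 conv + max-pool)
# with a 3x3 stride-1 conv and drop the max-pool to keep spatial resolution.
model.conv1 = nn.Conv2d(
    3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
)
model.maxpool = nn.Identity()
model = model.cuda()  # move to the GPU after the stem is replaced so all layers share a device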

print(model)

==> Building model..
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): Identity()
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3,

Load the pretrained weights into the model

In [73]:
model_weights_url = "https://github.com/zeyuanyin/ml801/releases/download/lab2/cifar10_resnet18_ckpt.pth"
state_dict = torch.hub.load_state_dict_from_url(model_weights_url)['state_dict']
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} # unwrap the module prefix (caused by DataParallel)

model.load_state_dict(state_dict)

Downloading: "https://github.com/zeyuanyin/ml801/releases/download/lab2/cifar10_resnet18_ckpt.pth" to /home/zeyuan/.cache/torch/hub/checkpoints/cifar10_resnet18_ckpt.pth
100%|██████████| 42.7M/42.7M [00:02<00:00, 20.2MB/s]


<All keys matched successfully>

In [29]:
def evaluate(model):
    model.eval().cuda()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in tqdm(testloader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    print("Accuracy: {:.2f}%({}/{})".format(100.0 * correct / total, correct, total))

    return 100.0 * correct / total

acc = evaluate(model)

print("==> Evaluate: non-pruned model's accuracy = {:.2f}%".format(acc))

100%|██████████| 100/100 [00:00<00:00, 113.05it/s]

Accuracy: 94.92%(9492/10000)
==> Evaluate: non-pruned model's accuracy = 94.92%
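Since the lab goal is stated in terms of parameter count, it can help to record the size of the dense model before pruning. The helper below is a minimal sketch: the name `count_parameters` is an assumption of this example, not part of the lab code, and it is demonstrated on a standalone copy of the 3x3 stem convolution so that it runs without the checkpoint or a GPU.

```python
import torch.nn as nn


def count_parameters(module: nn.Module) -> int:
    """Return the total number of learnable parameters in `module`."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Toy check on a copy of the stem convolution: 3 * 64 * 3 * 3 weights, no bias.
stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
num_params = count_parameters(stem)
print(f"stem conv parameters: {num_params}")
```

Calling `count_parameters(model)` on the ResNet-18 above would give a baseline to compare pruned variants against.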





### Part 2: Fine-grained Pruning

<img src="./prune-2.png" alt="Image" style="width:50%;">

In [62]:
def fine_grained_pruning(model, pruning_ratio):
    # step 1: collect all the weights
    weight_list = []
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            weight_list.append(module.weight.data.view(-1).abs().clone())


    # step 2: get the threshold according to the global ranking
    weight_concat = torch.cat(weight_list, dim=0)
    sorted_weight, _ = torch.sort(weight_concat)
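    # clamp the index so that pruning_ratio = 1.0 does not index past the end
    thre_index = min(int(len(sorted_weight) * pruning_ratio), len(sorted_weight) - 1)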
    thre = sorted_weight[thre_index]

    print("==> global threshold: {:.4f}".format(thre))

    # step 3: set the weight to zero according to the threshold
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            weight_copy = module.weight.data.abs().clone()
            mask = (weight_copy > thre).float().cuda()
            # if mask is one, then the corresponding weight will be retained
            # if mask is zero, then the corresponding weight will be pruned
            module.weight.data.mul_(mask)

    return model
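To make the thresholding step concrete, here is a small worked example on a hand-written tensor of ten weights (the values are made up for illustration). It follows the same three steps as the function above: rank by magnitude, pick the threshold at `int(N * pruning_ratio)`, and keep only weights strictly larger than it.

```python
import torch

# Ten illustrative weights (hypothetical values, not taken from the model).
weights = torch.tensor([0.5, -0.1, 0.3, -0.8, 0.05, 0.9, -0.2, 0.4, -0.6, 0.7])
pruning_ratio = 0.5

# Rank magnitudes and read the threshold at the requested quantile.
sorted_abs, _ = torch.sort(weights.abs())
thre_index = int(len(sorted_abs) * pruning_ratio)  # index 5
thre = sorted_abs[thre_index]                      # magnitude 0.5

# Strict comparison: the weight equal to the threshold is pruned as well.
mask = (weights.abs() > thre).float()
pruned = weights * mask
num_kept = int(mask.sum().item())
print(f"threshold = {thre.item():.2f}, kept {num_kept} of {len(weights)} weights")
```

Because the mask uses `>` rather than `>=`, the weight sitting exactly at the threshold is removed too, so the realized sparsity can be slightly higher than the requested ratio.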

In [64]:
model.load_state_dict(state_dict)

pruning_ratio = 0.9
model = fine_grained_pruning(model, pruning_ratio)
acc = evaluate(model)
print(f"==> Evaluate: {pruning_ratio*100}% pruning model's accuracy = {acc:.2f}%")

==> global threshold: 0.0108


100%|██████████| 100/100 [00:00<00:00, 118.48it/s]

Accuracy: 44.71%(4471/10000)
==> Evaluate: 90.0% pruning model's accuracy = 44.71%
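One way to confirm that the pruning call did what was requested is to measure the fraction of convolution weights that are exactly zero. The sketch below assumes a helper named `conv_sparsity` (hypothetical, not part of the lab code) and checks it on a toy two-layer model rather than on the ResNet-18.

```python
import torch
import torch.nn as nn


def conv_sparsity(model: nn.Module) -> float:
    """Fraction of Conv2d weights in `model` that are exactly zero."""
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            weight = module.weight.data
            zeros += int((weight == 0).sum().item())
            total += weight.numel()
    return zeros / total if total > 0 else 0.0


# Toy model: first conv fully zeroed (216 weights), second conv all ones (576 weights).
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, bias=False),
    nn.Conv2d(8, 8, kernel_size=3, bias=False),
)
toy[0].weight.data.zero_()
toy[1].weight.data.fill_(1.0)
sparsity = conv_sparsity(toy)
print(f"conv sparsity: {sparsity:.3f}")
```

Note that the zeroed weights are still stored in dense tensors, so this number reports sparsity only; turning it into memory or FLOP savings would require sparse storage or sparse kernels.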





### Part 3: Channel-level Pruning

<img src="./prune-3.png" alt="Image" style="width:50%;">

In [65]:
def fine_channel_level_pruning(model, pruning_ratio):
    # step 1: batchnorm's gamma is the weights we want to prune
    gamma_weight_list = []
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            gamma_weight_list.append(module.weight.data.abs().clone().cuda())

    # step 2: get the threshold according to the global ranking
    gamma_weight_concat = torch.cat(gamma_weight_list, dim=0)
    sorted_gamma_weight, _ = torch.sort(gamma_weight_concat)
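    # clamp the index so that pruning_ratio = 1.0 does not index past the end
    thre_index = min(
        int(len(sorted_gamma_weight) * pruning_ratio), len(sorted_gamma_weight) - 1
    )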
    thre = sorted_gamma_weight[thre_index]

    print("==> global threshold: {:.4f}".format(thre))

    # step 3: set the weight to zero according to the threshold
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            weight_copy = module.weight.data.abs().clone()
            mask = (weight_copy > thre).float().cuda()
            # if mask is one, then the corresponding weight will be retained
            # if mask is zero, then the corresponding weight will be pruned
            module.weight.data.mul_(mask)
            module.bias.data.mul_(mask)

    return model
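The function above masks both the BatchNorm scale (`weight`, i.e. gamma) and shift (`bias`, i.e. beta). The short check below, on a standalone `BatchNorm2d` layer with made-up values, illustrates why: the layer computes `gamma * x_hat + beta` per channel, so a channel only becomes constantly zero when both terms are zeroed.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4).eval()  # eval mode: normalize with the running statistics
bn.bias.data.fill_(0.3)        # pretend the layer has learned a nonzero shift

# Prune channel 1 by zeroing both its scale (gamma) and its shift (beta).
mask = torch.tensor([1.0, 0.0, 1.0, 1.0])
bn.weight.data.mul_(mask)
bn.bias.data.mul_(mask)

x = torch.randn(2, 4, 8, 8)
with torch.no_grad():
    y = bn(x)

pruned_channel_is_zero = bool((y[:, 1] == 0).all())
print("channel 1 is all zeros:", pruned_channel_is_zero)
```

If only gamma were masked, channel 1 would output the constant 0.3 instead of zero, so it would still influence the following layers.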

In [72]:
model.load_state_dict(state_dict)
pruning_ratio = 0.4
model = fine_channel_level_pruning(model, pruning_ratio)
acc = evaluate(model)

==> global threshold: 0.0869


100%|██████████| 100/100 [00:00<00:00, 108.25it/s]

Accuracy: 80.80%(8080/10000)
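Because the threshold is global, some layers may lose many more channels than others. A per-layer report makes this visible. The helper name `bn_channel_report` is an assumption of this sketch, and it is shown on a toy model instead of the ResNet-18.

```python
import torch
import torch.nn as nn


def bn_channel_report(model: nn.Module):
    """Return (layer name, kept channels, total channels) for every BatchNorm2d."""
    report = []
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            kept = int((module.weight.data != 0).sum().item())
            report.append((name, kept, module.num_features))
    return report


toy = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.BatchNorm2d(4),
    nn.Conv2d(4, 6, kernel_size=3),
    nn.BatchNorm2d(6),
)
# Zero the gamma of 2 of the 4 channels in the first BatchNorm layer.
toy[1].weight.data.mul_(torch.tensor([1.0, 0.0, 0.0, 1.0]))

report = bn_channel_report(toy)
for name, kept, total in report:
    print(f"{name}: kept {kept}/{total} channels")
```

Running `bn_channel_report(model)` after channel-level pruning would show whether any layer has been reduced to very few channels, which is one plausible source of a large accuracy drop.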





### Practice in Class (not graded)

Now it's your turn. 

- Try different pruning ratios and see how the evaluation accuracy changes. You can plot the **accuracy vs. pruning ratio curve**.
- Try different models on different datasets, such as CIFAR-100. (A codebase is available at https://github.com/kuangliu/pytorch-cifar.) Some other pretrained models are provided here:
    -  ResNet-50 on CIFAR-10: https://github.com/zeyuanyin/ml801/releases/download/lab2/cifar10_resnet50_ckpt.pth
    -  ResNet-18 on CIFAR-100: https://github.com/zeyuanyin/ml801/releases/download/lab2/cifar100_resnet18_ckpt.pth
    -  ResNet-50 on CIFAR-100: https://github.com/zeyuanyin/ml801/releases/download/lab2/cifar100_resnet50_ckpt.pth
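For the first exercise, one possible structure for the sweep is sketched below. The function names `sweep_pruning_ratios` and `plot_curve` are assumptions of this example. The sweep takes three callables so that it stays independent of the model: one that restores the dense weights, one that prunes in place, and one that returns the accuracy.

```python
def sweep_pruning_ratios(reset_fn, prune_fn, eval_fn, ratios):
    """Return [(ratio, accuracy), ...], restoring the dense weights before each run."""
    results = []
    for ratio in ratios:
        reset_fn()       # reload the unpruned weights
        prune_fn(ratio)  # prune the model in place
        results.append((ratio, eval_fn()))
    return results


def plot_curve(results, title="accuracy vs. pruning ratio"):
    """Plot a sweep; matplotlib is imported lazily so the sweep has no extra dependency."""
    import matplotlib.pyplot as plt

    ratios, accs = zip(*results)
    plt.plot(ratios, accs, marker="o")
    plt.xlabel("pruning ratio")
    plt.ylabel("accuracy (%)")
    plt.title(title)
    plt.show()
```

With the names used in this notebook, a call could look like `sweep_pruning_ratios(lambda: model.load_state_dict(state_dict), lambda r: fine_grained_pruning(model, r), lambda: evaluate(model), [0.1, 0.3, 0.5, 0.7, 0.9])`, followed by `plot_curve(...)` on the returned list. Reloading the weights before every ratio matters because both pruning functions modify the model in place.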