# Instructions

Run this notebook to:
* Load a UNet pretrained on the Carvana dataset from the "pretrainedmodel" folder (`./pretrainedmodel/MODEL.pth`).
* Use this pretrained model to perform "Filter Pruning via Geometric Median" (FPGM) with `pruning_unet.py`. The pruning is done iteratively while the model is fine-tuned (the recorded run logged 40 epochs). At this stage the pruned filters are only zeroed out, not removed. The pruned model at this stage is saved as "unet_carvana_pruned_net.pth" in the present working directory.
* Finally, architecture modifications physically remove the zeroed filters, and the final pruned model is saved as "unet_carvana_arch_pruned_net.pth" in the present working directory.

# Selecting device

In [2]:
#%pip install torch 
import torch 
import torch.nn as nn

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
    print("GPU Available")
else:
    device = "cpu"

GPU Available


In [13]:
# Run FPGM pruning on the pretrained Carvana UNet, restricted to GPU 0.
# The variable must prefix the command itself: a separate `! CUDA_VISIBLE_DEVICES=0`
# line runs in its own throw-away subshell, so the setting is gone before `python` starts.
# NOTE(review): the logged arguments of the recorded run show 'epochs': 40 although
# --epochs 5 is passed here; confirm which setting produced the saved checkpoint.
!CUDA_VISIBLE_DEVICES=0 python ./fpgmdata/testing/pruning_unet.py ./fpgmdata/testing/data/carvana --dataset CARVANA --arch UNet --save_path ./logs/unet_pretrain/prune_precfg_epoch40_varience1 --rate_norm 1 --rate_dist 0.2 --use_pretrain --pretrain_path ./pretrainedmodel/MODEL.pth --use_state_dict --lr 0.001 --epochs 5 --use_precfg

save path : ./logs/unet_pretrain/prune_precfg_epoch40_varience1
{'arch': 'UNet', 'batch_size': 1, 'cuda': True, 'data_path': './fpgmdata/testing/data/carvana', 'dataset': 'CARVANA', 'depth': 16, 'dist_type': 'l2', 'epoch_prune': 1, 'epochs': 40, 'evaluate': False, 'layer_begin': 1, 'layer_end': 1, 'layer_inter': 1, 'log_interval': 100, 'lr': 0.001, 'momentum': 0.9, 'no_cuda': False, 'pretrain_path': './pretrainedmodel/MODEL.pth', 'rate_dist': 0.2, 'rate_norm': 1.0, 'resume': '', 'save_path': './logs/unet_pretrain/prune_precfg_epoch40_varience1', 'seed': 1, 'start_epoch': 0, 'test_batch_size': 1, 'use_precfg': True, 'use_pretrain': True, 'use_state_dict': True, 'weight_decay': 0.0001}
Random Seed: 1
python version : 3.8.5 (default, Sep  4 2020, 07:30:14)  [GCC 7.3.0]
torch  version : 1.12.0+cu102
cudnn  version : 7605
Norm Pruning Rate: 1.0
Distance Pruning Rate: 0.2
Layer Begin: 1
Layer End: 1
Layer Inter: 1
Epoch prune: 1
use pretrain: True
Pretrain path: ./pretrainedmodel/MODEL.pth
D

In [3]:
import torch
import sys
sys.path.append("./fpgmdata/testing")
import models

unpruned_model = models.__dict__['UNet'](n_channels=3, n_classes=2)

# MODEL.pth is loaded directly as a state dict. It may also carry a 'mask_values'
# entry (presumably not one of the model's keys, which the strict load_state_dict()
# would reject), so remove it first; [0, 1] is used when it is absent.
checkpoint = torch.load("./pretrainedmodel/MODEL.pth", map_location=device)
mask_values = checkpoint.pop('mask_values', [0, 1])
unpruned_model.load_state_dict(checkpoint)
unpruned_model.to(device)

# Report the parameters of every Conv2d layer only (ConvTranspose2d, BatchNorm and
# other layer types are not counted), labelled with each layer's qualified name.
total = 0
print('Trainable Conv2d parameters:')
for module_name, module in unpruned_model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        for param_name, param in module.named_parameters():
            if param.requires_grad:
                print(f'{module_name}.{param_name}', '\t', param.numel())
                total += param.numel()
print()
print('Total', '\t', total)

Trainable parameters:
weight 	 1728
weight 	 36864
weight 	 73728
weight 	 147456
weight 	 294912
weight 	 589824
weight 	 1179648
weight 	 2359296
weight 	 4718592
weight 	 9437184
weight 	 4718592
weight 	 2359296
weight 	 1179648
weight 	 589824
weight 	 294912
weight 	 147456
weight 	 73728
weight 	 36864
weight 	 128
bias 	 2

Total 	 28239682


In [4]:
def _layer_cfg_entry(layer):
    """Map one module to its cfg entry, or None when the module is not recorded.

    Conv2d / ConvTranspose2d -> output channels, MaxPool2d -> 'M',
    Linear -> output features; every other module type is skipped.
    """
    if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
        return layer.out_channels
    if isinstance(layer, nn.MaxPool2d):
        return 'M'
    if isinstance(layer, nn.Linear):
        return layer.out_features
    return None


# Walk the modules in registration order and keep only the recorded entries.
cfg = [entry for entry in map(_layer_cfg_entry, unpruned_model.modules()) if entry is not None]

print("Model Configuration (cfg):", cfg)

Model Configuration (cfg): [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 1024, 1024, 512, 512, 512, 256, 256, 256, 128, 128, 128, 64, 64, 64, 2]


# General function to evaluate a model (mean Dice score on the validation loader)

In [5]:
import numpy as np
import torch.nn.functional as F
import sys
sys.path.append("./fpgmdata/testing")
from dice_score import multiclass_dice_coeff, dice_coeff, dice_loss

def test_model(model, loader=None):
    """Evaluate a segmentation model and print its mean foreground Dice score.

    For each batch, the prediction (argmax over dim 1 of the model output) and the
    ground-truth mask are one-hot encoded with ``model.n_classes`` classes, and
    ``multiclass_dice_coeff`` is computed on every class except channel 0
    (background). The per-batch scores are averaged over the batches seen.

    Args:
        model: network exposing an ``n_classes`` attribute; assumed to return
            logits shaped (N, n_classes, H, W) — TODO confirm for other models.
        loader: iterable of dicts with 'image' and 'mask' tensors (masks assumed
            to be (N, H, W) class indices). Defaults to the notebook-level
            ``testloader``.

    Returns:
        float: the mean Dice score (0.0 when the loader yields no batches).

    Side effects:
        Moves ``model`` to ``cuda:0`` when CUDA is available, and restores the
        model's previous train/eval mode before returning (even on error).
    """
    if loader is None:
        loader = testloader

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Move the model once up front instead of on every batch.
    model.to(device)
    torch.cuda.empty_cache()

    was_training = model.training
    model.eval()
    dice_total = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for batch in loader:
                image = batch['image'].to(device=device, dtype=torch.float32,
                                          memory_format=torch.channels_last)
                mask_true = batch['mask'].to(device=device, dtype=torch.long)

                mask_pred = model(image)

                # One-hot encode both masks and move the class axis to dim 1 (N, C, H, W).
                mask_true = F.one_hot(mask_true, model.n_classes).permute(0, 3, 1, 2).float()
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), model.n_classes).permute(0, 3, 1, 2).float()
                # Dice over the foreground classes only (channel 0 is background).
                dice_total += float(multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:],
                                                          reduce_batch_first=False))
                num_batches += 1
    finally:
        model.train(was_training)

    mean_dice = dice_total / max(num_batches, 1)
    print(f'Mean Dice score (foreground) over {num_batches} validation batches: {mean_dice:.4f}')
    return mean_dice

# Loading the Carvana dataset and building the train/validation data loaders


In [6]:
import torchvision
import torchvision.transforms as transforms

In [7]:
%pip install tqdm
sys.path.append("./fpgmdata/testing")
from data_loading import CarvanaDataset
import os

# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
#                                  transforms.Pad(4),
#                                  transforms.RandomCrop(32),
#                                  transforms.RandomHorizontalFlip(),
#                                  transforms.ToTensor(),
#                                  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
#                              ]),
#                                         download=True)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
#                                           shuffle=True, num_workers=2)

# testset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
#             ]),
#                                        download=True)
# testloader = torch.utils.data.DataLoader(testset,
#                                          batch_size=32,
#                                          shuffle=False, num_workers=2)

# classes = ('plane', 'car', 'bird', 'cat',
#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

dataset = CarvanaDataset("./fpgmdata/testing/data/carvana"+'/imgs', "./fpgmdata/testing/data/carvana"+'/masks', scale=1)
dataset=torch.utils.data.Subset(dataset, range(0,100))
# 2. Split into train / validation partitions
val_percent = 0.2
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = torch.utils.data.random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

# 3. Create data loaders
loader_args = dict(batch_size=1, num_workers=os.cpu_count(), pin_memory=True)

trainloader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_args)

testloader = torch.utils.data.DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

Note: you may need to restart the kernel to use updated packages.


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5088/5088 [03:10<00:00, 26.75it/s]


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm2d and ReLU.

    Layout of ``self.double_conv`` (the indices determine the state_dict keys):
    0 Conv2d(in -> mid), 1 BatchNorm2d(mid), 2 ReLU,
    3 Conv2d(mid -> out), 4 BatchNorm2d(out), 5 ReLU.
    padding=1 keeps the spatial size unchanged; the convolutions have no bias
    because each is immediately followed by an affine BatchNorm2d.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        mid_channels: channels between the two convolutions; any falsy value
            (None or 0) means ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downsample with 2x2 max pooling, then apply a DoubleConv.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Index 0 = pooling, index 1 = DoubleConv (keeps state_dict keys stable).
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        out = self.maxpool_conv(x)
        return out


class Up(nn.Module):
    """Upsample a decoder feature map, concatenate the skip tensor, then apply DoubleConv.

    Args:
        in_channels: channels entering the block (before upsampling).
        out_channels: channels produced by the block.
        bilinear: if True, upsample with parameter-free bilinear interpolation and
            let DoubleConv reduce channels (mid = in_channels // 2); otherwise use a
            learned 2x2 transposed convolution producing in_channels // 2 channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Assumes NCHW tensors: pad H and W of the upsampled map to match the skip tensor x2.
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        upsampled = F.pad(upsampled, [pad_w // 2, pad_w - pad_w // 2,
                                      pad_h // 2, pad_h - pad_h // 2])
        return self.conv(torch.cat([x2, upsampled], dim=1))


class OutConv(nn.Module):
    """1x1 convolution (with bias) mapping feature channels to output channels.

    In UNet this produces one logit map per class.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits


class UNet(nn.Module):
    """U-Net with four downsampling stages, four upsampling stages and skip connections.

    Args:
        n_channels: channels of the input image.
        n_classes: channels of the output logits (one per class).
        bilinear: use bilinear upsampling in the decoder instead of transposed
            convolutions; the bottleneck and first three decoder stages then
            produce half as many channels.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # Activation checkpointing is opt-in; see use_checkpointing().
        self.checkpointing = False

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # Bilinear upsampling does not halve channels the way the transposed conv
        # does, so halve the stage widths to keep each concatenation equal to the
        # next Up block's in_channels.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def _run(self, module, *inputs):
        """Call ``module(*inputs)``, through activation checkpointing when enabled.

        Checkpointing only applies while training with autograd enabled; without
        a backward pass there are no activations worth recomputing.
        """
        # getattr: instances pickled before this attribute existed lack it.
        if getattr(self, 'checkpointing', False) and self.training and torch.is_grad_enabled():
            from torch.utils.checkpoint import checkpoint
            # Non-reentrant mode (PyTorch >= 1.11) lets gradients reach the
            # parameters of the first stage even though the network input does
            # not require grad; the reentrant mode would leave them without grads.
            return checkpoint(module, *inputs, use_reentrant=False)
        return module(*inputs)

    def forward(self, x):
        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)
        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        logits = self._run(self.outc, x)
        return logits

    def use_checkpointing(self):
        """Enable activation checkpointing for every stage (less memory, more compute).

        ``torch.utils.checkpoint`` is a module, not a callable, so stages cannot be
        wrapped as ``torch.utils.checkpoint(stage)``; instead forward() routes each
        stage through ``torch.utils.checkpoint.checkpoint``. Submodules are not
        replaced, so state_dict keys are unchanged.
        """
        self.checkpointing = True


# Evaluating the unpruned model (mean Dice score on the validation set)

In [9]:
import gc

# Collect unreachable Python objects first so their CUDA tensors are freed,
# then hand the now-unused cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

test_model(unpruned_model)

AttributeError: 'int' object has no attribute 'float'

In [None]:
import gc

# Drop the last reference first; emptying the CUDA cache before `del` cannot
# release memory still owned by the model's tensors.
del unpruned_model
gc.collect()
torch.cuda.empty_cache()

# Loading the pruned (only zeroed out) model

In [None]:
# Checkpoint written by the pruning script above; its 'state_dict' entry holds the
# FPGM-pruned weights (pruned filters are only zeroed out at this stage).
PRUNED_CHECKPOINT = "./logs/unet_pretrain/prune_precfg_epoch40_varience1/checkpoint.pth.tar"

pruned_model = UNet(n_channels=3, n_classes=2)
# map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
pruned_checkpoint = torch.load(PRUNED_CHECKPOINT, map_location=device)
pruned_model.load_state_dict(pruned_checkpoint['state_dict'])
pruned_model = pruned_model.to(device)

# Saving the pruned (only zeroed out) model

In [None]:
# Pickle the whole module (not only its state_dict) so it can be reloaded without
# rebuilding the architecture; unpickling still requires the UNet class to be defined.
pruned_net_path = './unet_carvana_pruned_net.pth'
torch.save(pruned_model, pruned_net_path)

# Let's test the accuracy of the pruned (only zeroed out) model

In [10]:
import gc
torch.cuda.empty_cache()
gc.collect()

test_model(pruned_model)

NameError: name 'pruned_model' is not defined

# Changing the architecture

In [None]:
!pip install torch-pruning
import torch_pruning as tp
    
for name, module in pruned_model.named_modules():
    if isinstance(module, torch.nn.Conv2d): #Iterating over all the conv2d layers of the model
        channel_indices = [] #Stores indices of the channels to prune within this conv layer
        t = module.weight.clone().detach()
        t = t.reshape(t.shape[0], -1)
        z = torch.all(t == 0, dim=1)
        z = z.tolist()
        
        for i, flag in enumerate(z):
            if(flag):
                channel_indices.append(i)

        if(channel_indices == []):
            continue
        
        # 1. build dependency graph for vgg
        DG = tp.DependencyGraph().build_dependency(pruned_model, example_inputs=torch.randn(1,3,32,32).to(device))

        # 2. Specify the to-be-pruned channels. Here we prune those channels indexed by idxs.
        group = DG.get_pruning_group(module, tp.prune_conv_out_channels, idxs=channel_indices)
        #print(group)

        # 3. prune all grouped layers that are coupled with the conv layer (included).
        if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
            group.prune()
    
# 4. Save & Load
pruned_model.zero_grad() # We don't want to store gradient information
torch.save(pruned_model, './unet_carvana_arch_pruned_net.pth') # without .state_dict

# Let's test the accuracy of the pruned model after the architecture modifications

In [None]:
# Evaluate the model after the zeroed filters have been physically removed.
test_model(pruned_model)

# Reload check for the architecture-pruned model

In [None]:
# Reload the file written by the architecture-pruning cell (not the old VGG/CIFAR file).
# torch.load unpickles arbitrary objects, so only load files you produced yourself.
reloaded_model = torch.load('./unet_carvana_arch_pruned_net.pth', map_location=device)

In [None]:
# Evaluate the model reloaded from disk; compare with the score printed for pruned_model.
test_model(reloaded_model)