In [1]:
#imports
from heapq import nsmallest
from operator import itemgetter
from PIL import Image
from torch.autograd import Variable
from torchvision import models
import argparse
import cv2
import glob
import ipdb
import numpy as np
import os
import sys
import time
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

In [2]:
#model and pruner class
class ModifiedVGG16Model(torch.nn.Module):
    """VGG16 feature extractor (pre-trained, frozen) topped with a new 2-class classifier."""

    def __init__(self):
        super().__init__()

        # Borrow the convolutional stack of a pre-trained VGG16 and freeze it,
        # so that initially only the classifier below receives gradient updates.
        backbone = models.vgg16(pretrained=True)
        self.features = backbone.features
        for frozen_param in self.features.parameters():
            frozen_param.requires_grad = False

        # The last conv stage yields 512 maps of 7x7 (= 25088 values per image),
        # which is the input width of the first fully connected layer.
        head_layers = [
            nn.Dropout(),
            nn.Linear(25088, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 2),  # final classifier distinguishes 2 categories
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the feature stack, flatten per sample, and return the class scores."""
        feature_maps = self.features(x)  # e.g. [32, 512, 7, 7] for a batch of 32
        flattened = feature_maps.view(feature_maps.size(0), -1)  # e.g. [32, 25088]
        return self.classifier(flattened)  # e.g. [32, 2]

class FilterPrunner:
    """Scores the conv filters of ``model.features`` and proposes which ones to prune.

    The score of a filter is the mean of ``activation * gradient`` over the batch
    and spatial dimensions, accumulated over every backward pass that follows a
    call to :meth:`forward`. Typical use:

    1. ``reset()``
    2. for each batch: ``forward(x)`` then ``loss.backward()``
    3. ``normalize_ranks_per_layer()``
    4. ``get_prunning_plan(n)``

    ``model`` must expose ``features`` (an ordered container of layers) and
    ``classifier`` (a callable taking the flattened feature maps).
    """

    def __init__(self, model):
        self.model = model
        self.reset()

    def reset(self):
        """Forget every accumulated filter score."""
        self.filter_ranks = {}

    def forward(self, x):
        """Forward ``x`` through the model while recording every Conv2d output.

        A gradient hook is attached to each conv output so that the following
        backward pass feeds :meth:`compute_rank`. Returns the classifier output.
        """
        self.activations = []
        self.gradients = []
        self.grad_index = 0  # number of hooks already fired in the current backward pass
        self.activation_to_layer = {}

        activation_index = 0
        # First pass x through the feature head, layer by layer.
        for layer, (name, module) in enumerate(self.model.features._modules.items()):
            x = module(x)
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                # When the gradient of x is computed, compute_rank is called with it.
                x.register_hook(self.compute_rank)
                self.activations.append(x)
                self.activation_to_layer[activation_index] = layer
                activation_index += 1
        # Then pass the flattened feature maps through the classifier head.
        return self.model.classifier(x.view(x.size(0), -1))

    def compute_rank(self, grad):
        """Gradient hook: add this batch's score for one conv layer.

        The layer is identified purely by call order: hooks are assumed to fire
        in the reverse of the order in which :meth:`forward` recorded the
        activations, so the n-th call maps to the n-th activation from the end.
        """
        activation_index = len(self.activations) - self.grad_index - 1
        activation = self.activations[activation_index]

        # Per-filter mean of activation * gradient over batch, height and width.
        taylor = (activation * grad).mean(dim=(0, 2, 3)).data

        if activation_index not in self.filter_ranks:
            # Allocate the accumulator on the same device as the scores instead of
            # forcing CUDA, so ranking also works on CPU-only machines.
            self.filter_ranks[activation_index] = torch.zeros(
                activation.size(1), device=taylor.device)

        self.filter_ranks[activation_index] += taylor
        self.grad_index += 1

    def lowest_ranking_filters(self, num):
        """Return the ``num`` lowest scoring filters as ``(layer, filter, score)`` tuples.

        ``layer`` is the index of the conv layer inside ``model.features`` and
        ``score`` is the 0-d tensor stored in ``filter_ranks``.
        """
        candidates = []
        for i in sorted(self.filter_ranks.keys()):
            layer = self.activation_to_layer[i]
            ranks = self.filter_ranks[i]
            for j in range(ranks.size(0)):
                candidates.append((layer, j, ranks[j]))

        return nsmallest(num, candidates, itemgetter(2))

    def normalize_ranks_per_layer(self):
        """Replace each layer's scores by their absolute values scaled to unit L2 norm.

        ``filter_ranks`` maps the conv index (0 = earliest conv layer) to a 1-D
        tensor with one score per filter of that layer. For VGG16 the keys are
        0..12 with (64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512)
        filters respectively. The normalized scores are stored on the CPU.
        """
        for i in self.filter_ranks:
            v = torch.abs(self.filter_ranks[i])
            norm = torch.sqrt(torch.sum(v * v))
            # A layer whose scores are all zero has zero norm; dividing would turn
            # every score into NaN and make the later ranking meaningless.
            if norm > 0:
                v = v / norm
            self.filter_ranks[i] = v.cpu()

    def get_prunning_plan(self, num_filters_to_prune):
        """Return ``[(layer_index, filter_index), ...]`` for sequential pruning.

        The filters are meant to be removed one at a time, in the returned order.
        Removing a filter shifts the index of every later filter of the same layer
        down by one, so within each layer the indices are sorted and the i-th one
        is reduced by i. Example for one layer: [143, 147, 358] -> [143, 146, 356].
        """
        # Lowest ranked filters, e.g. [(21, 389, tensor(2.0786e-06)), (26, 229, ...), ...]
        filters_to_prune_per_layer = {}
        for (l, f, _) in self.lowest_ranking_filters(num_filters_to_prune):
            filters_to_prune_per_layer.setdefault(l, []).append(f)

        filters_to_prune = []
        for l, filters in filters_to_prune_per_layer.items():
            for already_removed, f in enumerate(sorted(filters)):
                filters_to_prune.append((l, f - already_removed))
        # e.g. [(28, 135), (28, 155), ..., (26, 88), (21, 37), ...]
        return filters_to_prune

class PrunningFineTuner_VGG16:
    """Fine-tunes a model and iteratively prunes its lowest ranked conv filters.

    Args:
        train_path: directory handed to ``loader`` to build the training DataLoader.
        test_path: directory handed to ``test_loader`` to build the test DataLoader.
        model: network exposing ``features`` and ``classifier`` sub-modules.

    Batches are moved to whatever device the model's parameters live on, so the
    caller decides the device by moving the model before or after construction.
    """

    def __init__(self, train_path, test_path, model):
        self.train_data_loader = loader(train_path)
        self.test_data_loader = test_loader(test_path)

        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss()
        self.prunner = FilterPrunner(self.model)
        self.model.train()

    def _device(self):
        """Return the device of the model's first parameter."""
        return next(self.model.parameters()).device

    def test(self):
        """Print the accuracy of ``self.model`` on the test loader.

        The model is switched to eval mode for the measurement and back to
        train mode afterwards.
        """
        self.model.eval()
        device = self._device()
        correct = 0
        total = 0

        # No gradients are needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            for batch, label in self.test_data_loader:
                # Use the instance's model, not a notebook-level global.
                output = self.model(batch.to(device))
                pred = output.max(1)[1]
                correct += pred.cpu().eq(label).sum().item()
                total += label.size(0)

        print("Accuracy :", float(correct) / total)

        self.model.train()

    def train(self, optimizer = None, epoches=10):
        """Train for ``epoches`` epochs, printing the test accuracy after each one.

        When ``optimizer`` is None, an SGD optimizer over the classifier
        parameters of ``self.model`` only is used (lr=0.0001, momentum=0.9).
        """
        if optimizer is None:
            optimizer = optim.SGD(self.model.classifier.parameters(), lr=0.0001, momentum=0.9)

        for i in range(epoches):
            print("Epoch: ", i)
            self.train_epoch(optimizer)
            self.test()
        print("Finished fine tuning.")

    def train_batch(self, optimizer, batch, label, rank_filters):
        """Process one batch.

        With ``rank_filters`` true, the batch goes through the pruner so that the
        backward pass accumulates filter scores; no optimizer step is taken and
        ``optimizer`` may be None. Otherwise a regular training step is performed.
        """
        device = self._device()
        batch = batch.to(device)
        label = label.to(device)

        self.model.zero_grad()

        if rank_filters:
            output = self.prunner.forward(batch)
            self.criterion(output, label).backward()
        else:
            self.criterion(self.model(batch), label).backward()
            optimizer.step()

    def train_epoch(self, optimizer = None, rank_filters = False):
        """Run :meth:`train_batch` once over every batch of the training loader."""
        for batch, label in self.train_data_loader:
            self.train_batch(optimizer, batch, label, rank_filters)

    def get_candidates_to_prune(self, num_filters_to_prune):
        """Rank all filters over one pass of the training data and return the pruning plan."""
        self.prunner.reset()
        # One forward and backward pass per batch, used only to rank the filters.
        self.train_epoch(rank_filters = True)
        self.prunner.normalize_ranks_per_layer()
        return self.prunner.get_prunning_plan(num_filters_to_prune)

    def total_num_filters(self):
        """Return the summed ``out_channels`` of every Conv2d in ``model.features``."""
        filters = 0
        for name, module in self.model.features._modules.items():
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                filters = filters + module.out_channels
        return filters

    def prune(self,num_filters_to_prune_per_iteration=512,percentage_to_prune=67):
        """Prune filters in rounds, fine-tuning for 10 epochs after every round.

        The number of rounds is
        ``int(int(total_filters / num_filters_to_prune_per_iteration) * percentage_to_prune / 100)``,
        e.g. 4224 filters, 512 per round and 67 percent give ``int(8 * 0.67) = 5``.
        A final 15-epoch fine-tuning always follows, even when that number is 0.
        """
        self.test()  # accuracy before pruning
        self.model.train()

        # Make sure all the layers are trainable.
        for param in self.model.features.parameters():
            param.requires_grad = True
        number_of_filters = self.total_num_filters()
        iterations = int(float(number_of_filters) / num_filters_to_prune_per_iteration)

        iterations = int(iterations * percentage_to_prune * 0.01)

        print("Number of prunning iterations to remove "+ str(percentage_to_prune) +"% filters : ", iterations)

        # Stays None when no round runs, so the final fine-tuning can build its own.
        optimizer = None
        for _ in range(iterations):
            print("Ranking filters.. ")
            # List of (layer_index, filter_index), e.g. [(28, 66), (28, 100), (26, 67), ...]
            prune_targets = self.get_candidates_to_prune(num_filters_to_prune_per_iteration)

            # Per-layer counts, only used for the progress message below,
            # e.g. {17: 7, 28: 8, 26: 6, 21: 5, 19: 1, 24: 2, 14: 2, 7: 1}
            layers_prunned = {}
            for layer_index, filter_index in prune_targets:
                layers_prunned[layer_index] = layers_prunned.get(layer_index, 0) + 1
            print("Layer number : number of filters in that layer that will be prunned", layers_prunned)

            print("Prunning filters.. ")
            device = self._device()  # remember where the model lived before pruning
            model = self.model.cpu()
            for layer_index, filter_index in prune_targets:
                model = prune_vgg16_conv_layer(
                    model, layer_index, filter_index, use_cuda=(device.type == "cuda"))

            self.model = model.to(device)

            # This figure is the share of filters still present, so label it accordingly.
            message = str(100*float(self.total_num_filters()) / number_of_filters) + "%"
            print("Filters remaining", str(message))
            self.test()
            print("Fine tuning to recover from prunning iteration.")
            optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
            self.train(optimizer, epoches = 10)

        print("Finished. Going to fine tune the model a bit more")
        if optimizer is None:
            # No pruning round ran; use the same settings a round would have used.
            optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        self.train(optimizer, epoches=15)
        #torch.save({'state_dict': model.state_dict()}, 'model_prunned.pt')

In [3]:
#helper functions
        
def replace_layers(model, i, indexes, layers):
    """Return the substitute for position ``i`` if one is listed, else ``model[i]``.

    ``indexes`` and ``layers`` are parallel sequences: ``layers[k]`` is the
    replacement for position ``indexes[k]``.
    """
    return layers[indexes.index(i)] if i in indexes else model[i]

def _without_slice(tensor, start, length, dim):
    """Return a new tensor: ``tensor`` minus ``length`` entries along ``dim`` starting at ``start``."""
    size = tensor.size(dim)
    keep = torch.cat((torch.arange(0, start), torch.arange(start + length, size)))
    return tensor.index_select(dim, keep.to(tensor.device))


def _rebuilt_conv(conv, in_channels, out_channels):
    """Create a fresh Conv2d with ``conv``'s hyper-parameters but the given channel counts."""
    return torch.nn.Conv2d(in_channels=in_channels,
                           out_channels=out_channels,
                           kernel_size=conv.kernel_size,
                           stride=conv.stride,
                           padding=conv.padding,
                           dilation=conv.dilation,
                           groups=conv.groups,
                           bias=(conv.bias is not None))


def prune_vgg16_conv_layer(model, layer_index, filter_index, use_cuda=False):
    """Remove one output filter from a conv layer of ``model.features``.

    The layer consuming that filter is shrunk to match: the next Conv2d found
    after ``layer_index`` loses the corresponding input channel, or, when the
    pruned layer is the last conv, the first Linear of ``model.classifier``
    loses the input columns fed by that filter.

    Args:
        model: object with ``features`` and ``classifier`` layer containers;
            both attributes are replaced by new ``torch.nn.Sequential`` objects.
        layer_index: position of the Conv2d to prune inside ``model.features``.
        filter_index: index of the output filter to drop.
        use_cuda: when true, the rebuilt layers are placed on CUDA; otherwise
            they stay on the device of the layer being pruned.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: ``filter_index`` is not a valid filter of the layer.
        RuntimeError: the last conv is pruned but the classifier has no Linear.

    NOTE(review): only Conv2d/Linear consumers are adjusted; a per-channel layer
    such as BatchNorm between two convs would not be resized.
    """
    feature_layers = list(model.features)
    conv = feature_layers[layer_index]
    if not 0 <= filter_index < conv.out_channels:
        raise ValueError("filter_index %d is out of range for a layer with %d filters"
                         % (filter_index, conv.out_channels))
    device = torch.device("cuda") if use_cuda else conv.weight.device

    # Look for the next conv layer, skipping whatever sits in between
    # (conv-relu-conv, conv-relu-maxpool-conv, ...).
    next_index = None
    for candidate in range(layer_index + 1, len(feature_layers)):
        if isinstance(feature_layers[candidate], torch.nn.modules.conv.Conv2d):
            next_index = candidate
            break

    # The pruned layer: same configuration, one output channel fewer.
    new_conv = _rebuilt_conv(conv, conv.in_channels, conv.out_channels - 1)
    new_conv.weight.data = _without_slice(conv.weight.data, filter_index, 1, 0).to(device)
    if conv.bias is not None:
        new_conv.bias.data = _without_slice(conv.bias.data, filter_index, 1, 0).to(device)
    feature_layers[layer_index] = new_conv

    if next_index is not None:
        # The following conv loses the matching input channel; its bias is unchanged.
        next_conv = feature_layers[next_index]
        next_new_conv = _rebuilt_conv(next_conv, next_conv.in_channels - 1, next_conv.out_channels)
        next_new_conv.weight.data = _without_slice(next_conv.weight.data, filter_index, 1, 1).to(device)
        if next_conv.bias is not None:
            next_new_conv.bias.data = next_conv.bias.data.clone().to(device)
        feature_layers[next_index] = next_new_conv

        model.features = torch.nn.Sequential(*feature_layers)  # stitch the network back together
    else:
        # Pruning the last conv layer affects the first linear layer of the classifier.
        classifier_layers = list(model.classifier)
        linear_index = None
        for position, module in enumerate(classifier_layers):
            if isinstance(module, torch.nn.Linear):
                linear_index = position
                break
        if linear_index is None:
            # Raised before anything on the model is replaced.
            raise RuntimeError("No linear laye found in classifier")
        old_linear_layer = classifier_layers[linear_index]

        # Each conv filter feeds this many consecutive inputs of the flattened vector.
        params_per_input_channel = old_linear_layer.in_features // conv.out_channels

        new_linear_layer = torch.nn.Linear(
            old_linear_layer.in_features - params_per_input_channel,
            old_linear_layer.out_features,
            bias=(old_linear_layer.bias is not None))
        new_linear_layer.weight.data = _without_slice(
            old_linear_layer.weight.data,
            filter_index * params_per_input_channel,
            params_per_input_channel,
            1).to(device)
        if old_linear_layer.bias is not None:
            new_linear_layer.bias.data = old_linear_layer.bias.data.clone().to(device)
        classifier_layers[linear_index] = new_linear_layer

        model.features = torch.nn.Sequential(*feature_layers)
        model.classifier = torch.nn.Sequential(*classifier_layers)

    return model

def loader(path, batch_size=32, num_workers=4, pin_memory=True):
    """Build the shuffled training DataLoader over the image folder at ``path``.

    Every image is resized to 256, randomly cropped to 224, randomly flipped
    horizontally, converted to a tensor and normalized per channel.
    """
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = datasets.ImageFolder(path, train_transform)
    return data.DataLoader(train_dataset,
                           batch_size=batch_size,
                           shuffle=True,
                           num_workers=num_workers,
                           pin_memory=pin_memory)

def test_loader(path, batch_size=32, num_workers=4, pin_memory=True):
    """Build the unshuffled evaluation DataLoader over the image folder at ``path``.

    Every image is resized to 256, center-cropped to 224, converted to a tensor
    and normalized per channel; no random augmentation is applied.
    """
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    eval_dataset = datasets.ImageFolder(path, eval_transform)
    return data.DataLoader(eval_dataset,
                           batch_size=batch_size,
                           shuffle=False,
                           num_workers=num_workers,
                           pin_memory=pin_memory)

In [4]:
# Build the model on the GPU, then train it for two epochs with the fine-tuner's
# default optimizer, printing the test accuracy after each epoch.
model = ModifiedVGG16Model().cuda()
initial_training_obj = PrunningFineTuner_VGG16("train", "test", model)
initial_training_obj.train(epoches=2)

Epoch:  0
Accuracy : 0.9625
Epoch:  1
Accuracy : 0.97
Finished fine tuning.


In [5]:
# Print the test-set accuracy of the model trained in the previous cell.
initial_training_obj.test()

Accuracy : 0.97


In [6]:
# Persist only the weights (state_dict) of the trained, not yet pruned model.
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs('checkpoint_models', exist_ok=True)
torch.save({'state_dict': model.state_dict()}, 'checkpoint_models/trained_model_state.pt') #save

## prune

In [8]:
# Rebuild the network on the GPU and restore the weights saved after the initial training.
model = ModifiedVGG16Model().cuda()
checkpoint = torch.load('checkpoint_models/trained_model_state.pt') #load
model.load_state_dict(checkpoint['state_dict'])

# Check the restored accuracy, then prune and fine-tune.
pruner_obj = PrunningFineTuner_VGG16("train", "test", model)
pruner_obj.test()
pruner_obj.prune(percentage_to_prune=20)


Accuracy : 0.97
Accuracy : 0.97
Number of prunning iterations to remove 20% filters :  1
Ranking filters.. 
Layer number : number of filters in that layer that will be prunned {14: 11, 28: 216, 21: 38, 17: 35, 26: 88, 24: 48, 19: 34, 2: 2, 12: 11, 5: 5, 10: 15, 0: 3, 7: 6}
Prunning filters.. 
Filters prunned 87.87878787878788%
Accuracy : 0.9625
Fine tuning to recover from prunning iteration.
Epoch:  0
Accuracy : 0.97625
Epoch:  1
Accuracy : 0.98
Epoch:  2
Accuracy : 0.98375
Epoch:  3
Accuracy : 0.97625
Epoch:  4
Accuracy : 0.98
Epoch:  5
Accuracy : 0.9775
Epoch:  6
Accuracy : 0.97875
Epoch:  7
Accuracy : 0.96
Epoch:  8
Accuracy : 0.985
Epoch:  9
Accuracy : 0.98375
Finished fine tuning.
Finished. Going to fine tune the model a bit more
Epoch:  0
Accuracy : 0.98875
Epoch:  1
Accuracy : 0.98125
Epoch:  2
Accuracy : 0.98125
Epoch:  3
Accuracy : 0.98625
Epoch:  4
Accuracy : 0.98375
Epoch:  5
Accuracy : 0.9825
Epoch:  6
Accuracy : 0.98375
Epoch:  7
Accuracy : 0.9825
Epoch:  8
Accuracy : 0.98

In [9]:
torch.save(model, 'checkpoint_models/pruned_model.pt') 
# Save the entire model object rather than a state_dict: pruning changed the
# layer sizes, so this custom architecture exists only in the running kernel
# and would be lost once the kernel stops.
# Getting it back would require pruning again, which is not deterministic.
# The warning printed below can be ignored.

  "type " + obj.__name__ + ". It won't be checked "


In [10]:
# Reload the pruned model object saved above and move it to the GPU.
pruned_model=torch.load('checkpoint_models/pruned_model.pt') 
model = pruned_model.cuda()
# Wrap it in a fresh fine-tuner to print its test-set accuracy.
pruner_obj = PrunningFineTuner_VGG16("train", "test", model)
pruner_obj.test()


Accuracy : 0.98375


### helpers
# Leftover ipdb.set_trace() breakpoint removed: it halts "Restart & Run All"
# waiting for debugger input. Use the %debug magic for post-mortem inspection.