In [112]:
import timm
from torchsummary import summary
import numpy as np
from datetime import datetime 

import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import collections
from collections import defaultdict
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import ipdb
import timm
from torchvision.datasets import CIFAR10
import torch_pruning as tp
import torchvision.models as models
import time

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

import brevitas.nn as qnn


import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
torch.manual_seed(191009)




# check device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class config:
    lr = 1e-4
    n_classes = 10
    epochs = 2
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batch_size = 64

In [107]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


train_dataset = CIFAR10(root='data/', download=True, transform=transform_train)
valid_dataset = CIFAR10(root='data/',  download=True,train=False, transform=transform_test)

# define the data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=config.batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=config.batch_size, shuffle=False)




Files already downloaded and verified
Files already downloaded and verified


In [128]:


def training(model, train_loader, valid_loader):
    
    optim = torch.optim.Adam(model.parameters(), lr=config.lr)
    criterion = nn.CrossEntropyLoss()
    model = model.to(config.device)
    
    model.train()
    validation_acc = 0
    ########### Train  ###############
    for ep in range(config.epochs):
        running_loss = 0
        running_acc = 0
        for batch_idx, data in enumerate(train_loader):
            optim.zero_grad()
            image, label = data
            image, label = image.to(config.device), label.to(config.device)
            out = model(image)
            loss = criterion(out, label)
            acc = torch.argmax(out, 1) - label
            running_acc+= len(acc[acc==0])
            running_loss+= loss.item() * label.size(0)
            loss.backward()
            optim.step()
        epoch_loss = running_loss/ len(train_loader.dataset)
        train_epoch_acc = running_acc/ len(train_loader.dataset)
        
        
        ########### Validate #############


        print(f'{datetime.now().time().replace(microsecond=0)} --- '
                  f'Epoch: {ep+1}\t'
                  f'Train loss: {epoch_loss:.4f}\t'
                  f'Train accuracy: {100 * train_epoch_acc:.2f}\t')
        val_acc = validate(model, valid_loader)
        if(validation_acc< val_acc):
            validation_acc = val_acc
            torch.save(model.state_dict(), './weights/dl-x25'+str(validation_acc)+'.pth')
            
    print("Final Validation Accuracy:", validation_acc, "\n \n")
    return model

def validate(model, valid_loader):
    val_running_acc = 0
    model.eval()
    for batch_idx, data in enumerate(valid_loader):

        image, label = data
        image, label = image.to(config.device), label.to(config.device)
        out = model(image)
        acc = torch.argmax(out, 1) - label
        val_running_acc+= len(acc[acc==0])
    val_epoch_acc = val_running_acc/ len(valid_loader.dataset)


    print(f'{datetime.now().time().replace(microsecond=0)} --- '
              f'Valid accuracy: {100 * val_epoch_acc:.2f}')
    return val_epoch_acc
        
            


In [152]:
number_conv_layers = 0


def universal_layer_identifier(identification, module, module_name):
    if(identification == None):
        if isinstance(module, nn.Conv2d):
            return True
    else:
        if (identification in module_name):
            return True
def universal_get_layer_id_pruning(model, pruning_type, pruning_percent, identification=None): #identification = none
    if(pruning_type == 'L1'):
        strategy  = tp.strategy.L1Strategy()
    elif(pruning_type == 'L2'):
        strategy  = tp.strategy.L2Strategy()
    channels_pruned = []
    
    def find_instance(obj):
        if isinstance(obj, nn.Conv2d):
            pruning_idx = strategy(obj.weight, amount = pruning_percent)
            channels_pruned.append(pruning_idx)
            return None
        elif isinstance(obj, list):
            for internal_obj in obj:
                find_instance(internal_obj)
        elif (hasattr(obj, '__class__')):
            for internal_obj in obj.children():
                find_instance(internal_obj)
        elif isinstance(obj, OrderedDict):
            for key, value in obj.items():
                find_instance(value)

    find_instance(model)

    channels_pruned = np.asarray(channels_pruned, dtype=object)
    return channels_pruned



def universal_filter_pruning(model, input_shape, channels_pruned, identification=None):
    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs= torch.randn(input_shape).to(config.device))

    layer_id = 0

    def find_instance(obj):
        if isinstance(obj, nn.Conv2d):
            global number_conv_layers
            number_conv_layers+=obj.out_channels
            pruning_plan = DG.get_pruning_plan(obj, tp.prune_conv_out_channel, idxs=channels_pruned[layer_id])
            pruning_plan.exec()
            return None
        elif isinstance(obj, list):
            for internal_obj in obj:
                find_instance(internal_obj)
        elif (hasattr(obj, '__class__')):
            for internal_obj in obj.children():
                find_instance(internal_obj)
        elif isinstance(obj, OrderedDict):
            for key, value in obj.items():
                find_instance(value)

    find_instance(model)
    return model








In [156]:
from thop import profile, clever_format


def pruner(model,experiment_name, config, input_dims, pruning_stratergy, pruning_percent,  train_loader, valid_loader):
    original_model = model
    input = torch.randn((config.batch_size, )+ input_dims).to(config.device)
    
    macs_original, params_original = profile(original_model, inputs=(input, ))
#     macs_original, params_original = clever_format([macs_original, params_original], "%.3f")
    
    torch.save(original_model.state_dict(), './weights/'+experiment_name+'.pth')
    print("\n \n Original Validation Accuracy: \n \n")
    validate(model, valid_loader)
    
    channels_pruned = universal_get_layer_id_pruning(model, pruning_stratergy, pruning_percent)
    print("\n \n ################################# Post Purning ################################# \n \n ")
    print("Original Conv Layers in the Model:", number_conv_layers ,"\n Number of Layers Selected:", len(channels_pruned), "\n Number of Filters Pruned:",sum([len(x) for x in channels_pruned]))
    pruned_model = universal_filter_pruning(model, (config.batch_size,)+input_dims, channels_pruned).to(config.device)
    pruned_model = training(pruned_model, train_loader, valid_loader)
    torch.save(pruned_model, './weights/'+experiment_name+'_pruned_model.pth')
    torch.save(pruned_model.state_dict(), './weights/'+experiment_name+'_pruned.pth')
    print("\n \n Pruned Validation Accuracy: \n \n")
    validate(pruned_model, valid_loader)
    
    print("\n \n ################################# MAC's and Parameters Comparison #################################")

    
    macs_pruned, params_pruned = profile(pruned_model, inputs=(input, ))

    print("\n \n Original Model MAC's and Params:",macs_original, params_original)
    print("Pruned Model MAC's and Params:",macs_pruned, params_pruned )




In [158]:

resnet = timm.create_model('resnet18', num_classes=10).to("cuda")
  # Pretrained Model
resnet.load_state_dict(torch.load('/home/beast/Downloads/resnet18.pth'))


pruner(resnet,"resnet18", config,(3,32,32), "L2", 0.06,  train_loader, valid_loader)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.

 
 Original Validation Accuracy: 
 

20:32:28 --- Valid accuracy: 72.75

 
 ################################# Post Purning ################################# 
 
 
Original Conv Layers in the Model: 4764 
 Number of Layers Selected: 20 
 Number of Filters Pruned: 275
20:32:48 --- Epoch: 1	Train loss: 0.9083	Train accuracy: 68.01	
20:32:50 --- Valid accuracy: 72.65
20:33:11 --- Epoch: 2	Train loss: 0.7704	Train accuracy: 72.79	

In [6]:
                              ######### Original Model  #########
# original_model = timm.create_model('resnet18', num_classes=10).to("cuda")
# original_model.load_state_dict(torch.load('./weights/resnet18.pth'), strict=False)
# validate(original_model, valid_loader)
# summary(original_model,(3,32,32))



In [7]:
                                ######### Pruned Model  #########
# pruned_model = torch.load("./weights/resnet18_pruned_model.pth")
# pruned_model.load_state_dict(torch.load('./weights/resnet18_pruned.pth'), strict=False)
# validate(pruned_model, valid_loader)
# summary(pruned_model,(3,32,32))
    

In [140]:
pruned_model = torch.load("./weights/resnet18_pruned_model.pth").to("cuda")
pruned_model.eval()



ResNet(
  (conv1): Conv2d(3, 55, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(55, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(55, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(61, 55, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(55, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(55, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(61, eps=1e-05, m

In [141]:
                              ######## Pick the layers to fuse ########
layer_list = []
layer_fuse_list= []
flag =-1

for name, layer in pruned_model.named_modules():
    
    if(isinstance(layer, nn.ReLU)):
            
        layer_fuse_list.append(name)
        layer_list.append(layer_fuse_list)
        layer_fuse_list = []

    if ((len(layer_fuse_list)<2) and (isinstance(layer, nn.Conv2d) or isinstance(layer, nn.BatchNorm2d))):
        
        layer_fuse_list.append(name)
        
        

optimizer = torch.optim.SGD(qat_model.parameters(), lr = 0.0001)
pruned_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
pruned_model_prepared = torch.quantization.prepare_qat(pruned_model.train())

In [142]:
pruned_model_prepared = training(pruned_model_prepared, train_loader, valid_loader)

14:26:37 --- Epoch: 1	Train loss: 0.7816	Train accuracy: 72.40	
14:26:40 --- Valid accuracy: 74.32
14:27:03 --- Epoch: 2	Train loss: 0.7209	Train accuracy: 74.48	
14:27:06 --- Valid accuracy: 75.23
Final Validation Accuracy: 0.7523 
 



In [145]:


# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
pruned_model_prepared = pruned_model_prepared.to("cpu")
pruned_model_prepared.eval()



ResNet(
  (conv1): Conv2d(
    3, 55, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([0.0019, 0.0009, 0.0012, 0.0014, 0.0014, 0.0015, 0.0015, 0.0017, 0.0011,
              0.0014, 0.0012, 0.0011, 0.0014, 0.0008, 0.0017, 0.0014, 0.0012, 0.0011,
              0.0016, 0.0015, 0.0012, 0.0017, 0.0014, 0.0008, 0.0011, 0.0014, 0.0016,
              0.0011, 0.0016, 0.0017, 0.0013, 0.0015, 0.0018, 0.0017, 0.0013, 0.0010,
              0.0014, 0.0017, 0.0015, 0.0011, 0.0015, 0.0016, 0.0015, 0.0018, 0.0017,
              0.0015, 0.0013, 0.0020, 0.0011, 0.0014, 0.0015, 0.0013, 0.0012, 0.0011,
              0.0014]), zero_point=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0], dtype=torch.int32), dty

In [146]:
pruned_model_int8 = torch.quantization.convert(pruned_model_prepared)

In [150]:
torch.save(pruned_model_int8.state_dict(), './weights/pruned_model_int8-weights.pth')
torch.save(pruned_model_int8, './weights/pruned_model_int8.pth')

In [98]:

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)
        