In [13]:
import timm
import torch
import torch.nn as nn
from collections import OrderedDict
import re
import numpy as np
from torchsummary import summary
from collections import defaultdict, OrderedDict
import ipdb
import torch_pruning as tp

In [2]:




class VGG(nn.Module):
    ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']

    def __init__(self) -> None:
        super().__init__()

        layers = []
        counts = defaultdict(int)

        def add(name: str, layer: nn.Module) -> None:
            layers.append((f"{name}{counts[name]}", layer))
            counts[name] += 1

        in_channels = 3
        for x in self.ARCH:
            if x != 'M':
                # conv-bn-relu
                add("conv", nn.Conv2d(in_channels, x, 3, padding=1, bias=False))
                add("bn", nn.BatchNorm2d(x))
                add("relu", nn.ReLU(True))
                in_channels = x
            else:
                # maxpool
                add("pool", nn.MaxPool2d(2))

        self.backbone = nn.Sequential(OrderedDict(layers))
        self.classifier = nn.Linear(512, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        # backbone: [N, 3, 32, 32] => [N, 512, 2, 2]
        x = self.backbone(x)

        # avgpool: [N, 512, 2, 2] => [N, 512]
        x = x.mean([2, 3])

        # classifier: [N, 512] => [N, 10]
        x = self.classifier(x)
        return x


fusing_layers = [
    'Conv2d',
    'BatchNorm2d',
    'ReLU',
    'Linear',
    'BatchNorm1d',
]

import copy








def summary_string_fixed(model, all_layers, input_size, model_name=None, batch_size=-1, dtypes=None):
    if dtypes is None:
        dtypes = [torch.FloatTensor] * len(input_size)
        
    def name_fixer(names):

        return_list = []
        for string in names:
            matches = re.finditer(r'\.\[(\d+)\]', string)
            pop_list = [m.start(0) for m in matches]
            pop_list.sort(reverse=True)
            if len(pop_list) > 0:
                string = list(string)
                for pop_id in pop_list:
                    string.pop(pop_id)
                string = ''.join(string)
            return_list.append(string)
        return return_list

    def register_hook(module, module_idx):
        def hook(module, input, output):
            nonlocal module_idx
            m_key = all_layers[module_idx][0]
            m_key = model_name + "." + m_key

            try:
                eval(m_key)
            except:
                m_key = name_fixer([m_key])[0]

            summary[m_key] = OrderedDict()
            summary[m_key]["type"] = str(type(module)).split('.')[-1][:-2]
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size

            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
                summary[m_key]["weight_shape"] = module.weight.shape
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        if (
                not isinstance(module, nn.Sequential)
                and not isinstance(module, nn.ModuleList)
        ):
            hooks.append(module.register_forward_hook(hook))

    if isinstance(input_size, tuple):
        input_size = [input_size]

    x = [torch.rand(2, *in_size).type(dtype)
         for in_size, dtype in zip(input_size, dtypes)]

    summary = OrderedDict()
    hooks = []

    for module_idx, (layer_name, module) in enumerate(all_layers):
        register_hook(module, module_idx)

    model(*x)

    for h in hooks:
        h.remove()

    return summary








import copy


@torch.no_grad()
def prune_cwp(model, pruning_ratio_list):
    
    
    def get_importance(layer, sparsity):

        def get_input_channel_importance(weight):
            importances = []
            for i_c in range(weight.shape[1]):
                channel_weight = weight.detach()[:, i_c]
                importance = torch.norm(channel_weight)
                importances.append(importance.view(1))
            return torch.cat(importances)

        sorted_indices = torch.argsort(get_input_channel_importance(layer.weight), descending=True)
        n_keep = int(round(len(sorted_indices) * (1.0 - sparsity)))
        indices_to_keep = sorted_indices[:n_keep]
        return indices_to_keep
    
    pruned_model = copy.deepcopy(model)
    

    def get_layer_name(obj):
        if isinstance(obj, list):
            layer_list = []
            for internal_layer in obj:
                layer_list.append(eval(internal_layer.replace('model', 'pruned_model')))
            return layer_list
        else:
            nonlocal pruned_model
            return eval(obj.replace('model', 'pruned_model'))

    for list_ind in range(len(possible_indices_ranges)):
        sparsity = pruning_ratio_list[list_ind]
        layer_list = np.asarray(possible_indices_ranges[list_ind])

        prev_conv = get_layer_name(layer_list[0, 0])
        prev_bn = get_layer_name(layer_list[0, 1])
        next_convs = [c for c in get_layer_name(list(layer_list[1:, 0]))]
        next_bns = [b for b in get_layer_name(list(layer_list[1:-1, 1]))]  # Avoid last 0

        if (len(next_bns) == 0):
            iter_layers = zip([prev_conv, prev_bn], [next_convs, []])
        else:
            iter_layers = zip([prev_conv, prev_bn], [next_convs, next_bns])

        importance_list_indices = get_importance(layer=next_convs[-1], sparsity=sparsity)

        def prune_bn(layer, importance_list_indices):

            layer.weight.set_(layer.weight.detach()[importance_list_indices])
            layer.bias.set_(layer.bias.detach()[importance_list_indices])
            layer.running_mean.set_(layer.running_mean.detach()[importance_list_indices])
            layer.running_var.set_(layer.running_var.detach()[importance_list_indices])

        for prev_layer, next_layers in iter_layers:
            if(prev_layer != 0): #No BatchNorm Used:
                
                if (str(type(prev_layer)).split('.')[-1][:-2] == 'BatchNorm2d'):  # BatchNorm2d
                    prune_bn(prev_layer, importance_list_indices)
                else:
                    prev_layer.weight.set_(prev_conv.weight.detach()[importance_list_indices, :])
                    if prev_layer.bias is not None:
                                bias_shape = prev_layer.weight.shape[0]
                                prev_layer.bias = nn.Parameter(prev_layer.bias[:bias_shape])
                    

            if (len(next_layers) != 0):
                for next_layer in next_layers:
                    if (str(type(next_layer)).split('.')[-1][:-2] == 'BatchNorm2d'):  # BatchNorm2d
                        prune_bn(next_layer, importance_list_indices)
                    else:
                        if (next_layer.weight.shape[1] == 1):
                            

                            next_layer.weight.set_(next_layer.weight.detach()[importance_list_indices, :])
                            number_of_channels = len(importance_list_indices)
                            next_layer.groups = number_of_channels
                           
                            if next_layer.bias is not None:
                                bias_shape = next_layer.weight.shape[0]
                                next_layer.bias = nn.Parameter(next_layer.bias[:bias_shape])
     
                              
                        else:
                            
                            next_layer.weight.set_(next_layer.weight.detach()[:, importance_list_indices])
#                             next_layer.groups = len(importance_list_indices)
                            if next_layer.bias is not None:
                                bias_shape = next_layer.weight.shape[0]
                                next_layer.bias = nn.Parameter(next_layer.bias[:bias_shape])
                
#                             if(next_layer.bias!=None):
#                                 next_layer.bias.set_(next_layer.bias[:next_layer.weight.shape[1]])
#                                 print(len(next_layer.bias))

    return pruned_model, model


def layer_mapping(model):
    
    def get_all_layers(model, parent_name=''):
        layers = []

        def reformat_layer_name(str_data):
            try:
                split_data = str_data.split('.')
                for ind in range(len(split_data)):
                    data = split_data[ind]
                    if (data.isdigit()):
                        split_data[ind] = "[" + data + "]"
                final_string = '.'.join(split_data)

                iters_a = re.finditer(r'[a-zA-Z]\.\[', final_string)
                indices = [m.start(0) + 1 for m in iters_a]
                iters = re.finditer(r'\]\.\[', final_string)
                indices.extend([m.start(0) + 1 for m in iters])

                final_string = list(final_string)
                final_string = [final_string[i] for i in range(len(final_string)) if i not in indices]

                str_data = ''.join(final_string)

            except:
                pass

            return str_data

        for name, module in model.named_children():
            full_name = f"{parent_name}.{name}" if parent_name else name
            test_name = "model." + full_name
            try:
                eval(test_name)
                layers.append((full_name, module))
            except:
                layers.append((reformat_layer_name(full_name), module))
            if isinstance(module, nn.Module):
                layers.extend(get_all_layers(module, parent_name=full_name))
        return layers
    all_layers = get_all_layers(model)
    model_summary = summary_string_fixed(model, all_layers, (3, 64, 64), model_name='model')  # , device="cuda")

    name_type_shape = []
    for key in model_summary.keys():
        data = model_summary[key]
        if ("weight_shape" in data.keys()):
            name_type_shape.append([key, data['type'], data['weight_shape'][0]])
        #     else:
    #         name_type_shape.append([key, data['type'], 0 ])
    name_type_shape = np.asarray(name_type_shape)

    name_list = name_type_shape[:, 0]

    r_name_list = np.asarray(name_list)
    random_picks = np.random.randint(0, len(r_name_list), 10)
    test_name_list = r_name_list[random_picks]
    eval_hit = False
    for layer in test_name_list:
        try:
            eval(layer)

        except:
            eval_hit = True
            break
    if (eval_hit):
        fixed_name_list = name_fixer(r_name_list)
        name_type_shape[:, 0] = fixed_name_list

    layer_types = name_type_shape[:, 1]
    layer_shapes = name_type_shape[:, 2]
    mapped_layers = {'model_layer': [], 'Conv2d_BatchNorm2d_ReLU': [], 'Conv2d_BatchNorm2d': [], 'Linear_ReLU': [],
                     'Linear_BatchNorm1d': []}

    def detect_sequences(lst):
        i = 0
        while i < len(lst):

            if i + 2 < len(lst) and [l for l in lst[i: i + 3]] == [
                fusing_layers[0],
                fusing_layers[1],
                fusing_layers[2],
            ]:
                test_layer = layer_shapes[i: i + 2]
                if (np.all(test_layer == test_layer[0])):
                    mapped_layers['Conv2d_BatchNorm2d_ReLU'].append(
                        np.take(name_list, [i for i in range(i, i + 3)]).tolist()
                    )
                    i += 3

            elif i + 1 < len(lst) and [l for l in lst[i: i + 2]] == [
                fusing_layers[0],
                fusing_layers[1],
            ]:
                test_layer = layer_shapes[i: i + 2]
                if (np.all(test_layer == test_layer[0])):
                    mapped_layers['Conv2d_BatchNorm2d'].append(
                        np.take(name_list, [i for i in range(i, i + 2)]).tolist()
                    )
                    i += 2
            # if i + 1 < len(lst) and [ type(l) for l in lst[i:i+2]] == [fusing_layers[0], fusing_layers[2]]:
            #     detected_sequences.append(np.take(name_list,[i for i in range(i,i+2)]).tolist())
            #     i += 2
            # elif i + 1 < len(lst) and [ type(l) for l in lst[i:i+2]] == [fusing_layers[1], fusing_layers[2]]:
            #     detected_sequences.append(np.take(name_list,[i for i in range(i,i+2)]).tolist())
            #     i += 2
            elif i + 1 < len(lst) and [l for l in lst[i: i + 2]] == [
                fusing_layers[3],
                fusing_layers[2],
            ]:
                mapped_layers['Linear_ReLU'].append(
                    np.take(name_list, [i for i in range(i, i + 2)]).tolist()
                )
                i += 2
            elif i + 1 < len(lst) and [l for l in lst[i: i + 2]] == [
                fusing_layers[3],
                fusing_layers[4],
            ]:
                mapped_layers['Linear_BatchNorm1d'].append(
                    np.take(name_list, [i for i in range(i, i + 2)]).tolist()
                )
                i += 2
            else:
                i += 1

    detect_sequences(layer_types)

    for keys, value in mapped_layers.items():
        mapped_layers[keys] = np.asarray(mapped_layers[keys])

    mapped_layers['name_type_shape'] = name_type_shape
    # self.mapped_layers = mapped_layers

    # CWP
    keys_to_lookout = ['Conv2d_BatchNorm2d_ReLU', 'Conv2d_BatchNorm2d']
    pruning_layer_of_interest, qat_layer_of_interest = [], []

    # CWP or QAT Fusion Layers
    for keys in keys_to_lookout:
        data = mapped_layers[keys]
        if (len(data) != 0):
            qat_layer_of_interest.append(data)
    mapped_layers['qat_layers'] = np.asarray(qat_layer_of_interest)

    return mapped_layers



# GMP
#         layer_of_interest=mapped_layers['name_type_shape'][:,0] # all layers with weights
#         Check for all with weights
# Wanda

# def string_fixer(name_list):
#     for ind in range(len(name_list)):
#         modified_string = re.sub(r'\.(\[)', r'\1', name_list[ind])
#         name_list[ind] = modified_string



In [3]:



def cwp_possible_layers(layer_name_list):
    possible_indices = []
    idx = 0
    
    while idx < len(layer_name_list):
        current_value = layer_name_list[idx]
        layer_shape = eval(current_value).weight.shape
        curr_merge_list = []
        curr_merge_list.append([current_value, 0])
        hit_catch = False
        for internal_idx in range(idx + 1, len(layer_name_list) - 1):
            new_layer = layer_name_list[internal_idx]
            new_layer_shape = eval(new_layer).weight.shape
            if len(new_layer_shape) == 4:
                curr_merge_list.append([new_layer, 0])
                if layer_shape[0] == new_layer_shape[1]:
                    hit_catch = True
                    break
            elif len(new_layer_shape) == 1:
                curr_merge_list[len(curr_merge_list) - 1][1] = new_layer
        possible_indices.append(curr_merge_list)
        if hit_catch == True:
            idx = internal_idx
        else:
            idx += 1
    return possible_indices




    

In [121]:

# #load the pretrained model

# resnet18 = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
# densenet = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=False)

# super_net_name = "ofa_supernet_mbv3_w10" 

vgg = VGG()
mobilenet_v2 = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
mobilenet_v3 = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v3_small', pretrained=True)
# # super_net = torch.hub.load('mit-han-lab/once-for-all', super_net_name, pretrained=True)
# model_list = [ mobilenet_v3 ]

Using cache found in /Users/sathya/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/sathya/.cache/torch/hub/pytorch_vision_v0.10.0


In [122]:
pruning_ratio = 0.95

model = copy.deepcopy(mobilenet_v3)

for param in model.parameters():
    param.requires_grad = True


mapped_layers = layer_mapping(model)
name_list = mapped_layers['name_type_shape'][:, 0]

possible_indices_ranges = cwp_possible_layers(name_list)
possible_indices_ranges = [lst for lst in possible_indices_ranges if len(lst) > 1]
    
pruning_layer_length = len(possible_indices_ranges)
pruning_ratio_list = (pruning_layer_length) * [pruning_ratio]


In [123]:
name_list

array(['model.features[0][0]', 'model.features[0][1]',
       'model.features[1].block[0][0]', 'model.features[1].block[0][1]',
       'model.features[1].block[1].fc1', 'model.features[1].block[1].fc2',
       'model.features[1].block[2][0]', 'model.features[1].block[2][1]',
       'model.features[2].block[0][0]', 'model.features[2].block[0][1]',
       'model.features[2].block[1][0]', 'model.features[2].block[1][1]',
       'model.features[2].block[2][0]', 'model.features[2].block[2][1]',
       'model.features[3].block[0][0]', 'model.features[3].block[0][1]',
       'model.features[3].block[1][0]', 'model.features[3].block[1][1]',
       'model.features[3].block[2][0]', 'model.features[3].block[2][1]',
       'model.features[4].block[0][0]', 'model.features[4].block[0][1]',
       'model.features[4].block[1][0]', 'model.features[4].block[1][1]',
       'model.features[4].block[2].fc1', 'model.features[4].block[2].fc2',
       'model.features[4].block[3][0]', 'model.features[4].block[

In [124]:
conv_layers = []
for ind in range(len(name_list)):
    layer_name= name_list[ind]
    layer= eval(layer_name)
    if(isinstance(layer, torch.nn.modules.conv.Conv2d)):
        conv_layers.append(layer_name)
        

In [125]:
conv_layers

['model.features[0][0]',
 'model.features[1].block[0][0]',
 'model.features[1].block[1].fc1',
 'model.features[1].block[1].fc2',
 'model.features[1].block[2][0]',
 'model.features[2].block[0][0]',
 'model.features[2].block[1][0]',
 'model.features[2].block[2][0]',
 'model.features[3].block[0][0]',
 'model.features[3].block[1][0]',
 'model.features[3].block[2][0]',
 'model.features[4].block[0][0]',
 'model.features[4].block[1][0]',
 'model.features[4].block[2].fc1',
 'model.features[4].block[2].fc2',
 'model.features[4].block[3][0]',
 'model.features[5].block[0][0]',
 'model.features[5].block[1][0]',
 'model.features[5].block[2].fc1',
 'model.features[5].block[2].fc2',
 'model.features[5].block[3][0]',
 'model.features[6].block[0][0]',
 'model.features[6].block[1][0]',
 'model.features[6].block[2].fc1',
 'model.features[6].block[2].fc2',
 'model.features[6].block[3][0]',
 'model.features[7].block[0][0]',
 'model.features[7].block[1][0]',
 'model.features[7].block[2].fc1',
 'model.featur

In [126]:
summary(model,(3,32,32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             432
       BatchNorm2d-2           [-1, 16, 16, 16]              32
         Hardswish-3           [-1, 16, 16, 16]               0
            Conv2d-4             [-1, 16, 8, 8]             144
       BatchNorm2d-5             [-1, 16, 8, 8]              32
              ReLU-6             [-1, 16, 8, 8]               0
 AdaptiveAvgPool2d-7             [-1, 16, 1, 1]               0
            Conv2d-8              [-1, 8, 1, 1]             136
              ReLU-9              [-1, 8, 1, 1]               0
           Conv2d-10             [-1, 16, 1, 1]             144
      Hardsigmoid-11             [-1, 16, 1, 1]               0
SqueezeExcitation-12             [-1, 16, 8, 8]               0
           Conv2d-13             [-1, 16, 8, 8]             256
      BatchNorm2d-14             [-1, 1

In [127]:
eval('model.features[11].block[2].fc1').weight.shape, eval('model.features[11].block[2].fc2').weight.shape, eval('model.features[11].block[3][0]').weight.shape

(torch.Size([144, 576, 1, 1]),
 torch.Size([576, 144, 1, 1]),
 torch.Size([96, 576, 1, 1]))

In [128]:
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
pruning_idxs =[2, 6, 9,0,13,16,15,20,23,7,32]

pruning_group = DG.get_pruning_group( eval('model.features[11].block[2].fc2'), tp.prune_conv_out_channels, idxs=pruning_idxs )

# 3. prune all grouped layer that is coupled with model.conv1
if DG.check_pruning_group(pruning_group):
    pruning_group.prune()

In [129]:
pruned_model = copy.deepcopy(model)

In [130]:
torch.save(pruned_model.state_dict(),'pruned.pth')

In [131]:
eval('pruned_model.features[11].block[2].fc1').weight.shape, eval('pruned_model.features[11].block[2].fc2').weight.shape, eval('pruned_model.features[11].block[3][0]').weight.shape

(torch.Size([144, 565, 1, 1]),
 torch.Size([565, 144, 1, 1]),
 torch.Size([96, 565, 1, 1]))

In [132]:
summary(pruned_model,(3,32,32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             432
       BatchNorm2d-2           [-1, 16, 16, 16]              32
         Hardswish-3           [-1, 16, 16, 16]               0
            Conv2d-4             [-1, 16, 8, 8]             144
       BatchNorm2d-5             [-1, 16, 8, 8]              32
              ReLU-6             [-1, 16, 8, 8]               0
 AdaptiveAvgPool2d-7             [-1, 16, 1, 1]               0
            Conv2d-8              [-1, 8, 1, 1]             136
              ReLU-9              [-1, 8, 1, 1]               0
           Conv2d-10             [-1, 16, 1, 1]             144
      Hardsigmoid-11             [-1, 16, 1, 1]               0
SqueezeExcitation-12             [-1, 16, 8, 8]               0
           Conv2d-13             [-1, 16, 8, 8]             256
      BatchNorm2d-14             [-1, 1

In [59]:
# 1. build dependency graph for resnet18
model = copy.deepcopy(mobilenet_v3)

correct = []
example_inputs = torch.randn(1,3,32,32)
for layer_ind in range(len(conv_layers)):
    layer = conv_layers[layer_ind]
    DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

    # 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
    pruning_idxs = pruning_idxs=[2, 6, 9]
    try:
        print(layer_ind)
        pruning_group = DG.get_pruning_group( layer, tp.prune_conv_out_channels, idxs=pruning_idxs )

        # 3. prune all grouped layer that is coupled with model.conv1
        if DG.check_pruning_group(pruning_group):
            pruning_group.prune()
        correct.append(layer_ind)
        
    except:
        print("L:",layer_ind)

0
L: 0
1
L: 1
2
L: 2
3
L: 3
4
L: 4
5
L: 5
6
L: 6
7
L: 7
8
L: 8
9
L: 9
10
L: 10
11
L: 11
12
L: 12
13
L: 13
14
L: 14
15
L: 15
16
L: 16
17
L: 17
18
L: 18
19
L: 19
20
L: 20
21
L: 21
22
L: 22
23
L: 23
24
L: 24
25
L: 25
26
L: 26
27
L: 27
28
L: 28
29
L: 29
30
L: 30
31
L: 31
32
L: 32
33
L: 33
34
L: 34
35
L: 35
36
L: 36
37
L: 37
38
L: 38
39
L: 39
40
L: 40
41
L: 41
42
L: 42
43
L: 43
44
L: 44
45
L: 45
46
L: 46
47
L: 47
48
L: 48
49
L: 49
50
L: 50
51
L: 51
