In [1]:
import timm
import torch
import torch.nn as nn
from collections import OrderedDict
import re
import numpy as np
from torchsummary import summary
import torchvision
from collections import defaultdict, OrderedDict
import ipdb
import torch_pruning as tp

In [2]:
def layer_mapping(model):
    """Build a (name, type, weight shape) table for the weighted layers of ``model``.

    The module tree is walked recursively, one forward pass on a random
    ``(3, 64, 64)`` input is traced through ``summary_string_fixed`` and only
    the layers for which a ``weight_shape`` was recorded are kept.

    NOTE: this needs the ``summary_string_fixed`` variant that returns one
    dict per layer (with ``type`` / ``weight_shape`` entries). A later cell of
    this notebook redefines ``summary_string_fixed`` so that it returns plain
    type strings; with that variant in scope this function fails on
    ``data.keys()``.

    Args:
        model: the ``nn.Module`` to inspect.

    Returns:
        A NumPy object array of shape ``(n_layers, 3)``. Column 0 is the
        accessor string of the layer (prefixed with ``model.``), column 1 its
        class name and column 2 a one-element list holding the weight shape.
        The array has shape ``(0, 3)`` when no layer exposes a weight.
    """

    def name_fixer(names):
        """Drop the '.' in front of every '[<digits>]' index of each name."""
        # Previously this helper was only defined inside summary_string_fixed,
        # so calling it from here raised NameError.
        return [re.sub(r'\.(\[\d+\])', r'\1', name) for name in names]

    def reformat_layer_name(dotted_name):
        """Turn 'features.1.block' into 'features[1].block'."""
        # Built component by component so that numeric children are indexed
        # correctly whatever character the parent name ends with
        # (e.g. 'layer1.0' -> 'layer1[0]').
        parts = []
        for part in dotted_name.split('.'):
            if part.isdigit():
                parts.append(f"[{part}]")
            else:
                parts.append(f".{part}" if parts else part)
        return ''.join(parts)

    def get_all_layers(module, parent_name=''):
        """Depth-first list of (accessor name, sub-module) pairs."""
        layers = []
        for name, child in module.named_children():
            full_name = f"{parent_name}.{name}" if parent_name else name
            layers.append((reformat_layer_name(full_name), child))
            layers.extend(get_all_layers(child, parent_name=full_name))
        return layers

    all_layers = get_all_layers(model)
    model_summary = summary_string_fixed(model, all_layers, (3, 64, 64), model_name='model')

    rows = [
        (key, data['type'], [data['weight_shape']])
        for key, data in model_summary.items()
        if "weight_shape" in data.keys()
    ]

    # Fill an explicit object array: the rows mix strings with a nested list,
    # which recent NumPy versions refuse to convert implicitly, and an empty
    # table must still be two-dimensional for the column indexing below.
    name_type_shape = np.empty((len(rows), 3), dtype=object)
    for row_idx, row in enumerate(rows):
        for col_idx, value in enumerate(row):
            name_type_shape[row_idx, col_idx] = value

    # Normalise every name instead of eval()-probing a random sample of ten:
    # the fixer leaves already valid accessor strings untouched, so this is
    # deterministic and cannot miss a malformed name.
    name_type_shape[:, 0] = name_fixer(name_type_shape[:, 0])

    return name_type_shape








def summary_string_fixed(model, all_layers, input_size, model_name=None, batch_size=-1, dtypes=None):
    """Run one forward pass through ``model`` and record details of each hooked layer.

    A forward hook is attached to every module of ``all_layers`` that is not
    an ``nn.Sequential`` or ``nn.ModuleList``; the model is then called once
    on random inputs with a batch of two samples.

    Args:
        model: module to trace; it is called as ``model(*x)``.
        all_layers: sequence of ``(name, module)`` pairs.
        input_size: shape of one input without the batch dimension (a tuple),
            or a list of such shapes for models taking several inputs.
        model_name: prefix joined to each layer name with a '.'; when None
            the bare layer name is used as key.
        batch_size: value written into the batch position of recorded shapes.
        dtypes: tensor types applied to the random inputs, one per input;
            defaults to ``torch.FloatTensor``.

    Returns:
        An ``OrderedDict`` keyed by layer accessor string, in the order the
        hooks fired. Each value is an ``OrderedDict`` with the entries
        ``type``, ``input_shape``, ``output_shape`` and ``nb_params`` plus,
        for modules owning a weight, ``trainable`` and ``weight_shape``.

    Raises:
        Whatever the forward pass of ``model`` raises; the hooks are removed
        from the modules in that case as well.
    """
    # Normalising first keeps the default dtype list aligned with the number
    # of inputs rather than with the number of dimensions of a single input.
    if isinstance(input_size, tuple):
        input_size = [input_size]
    if dtypes is None:
        dtypes = [torch.FloatTensor] * len(input_size)

    prefix = f"{model_name}." if model_name else ""

    def name_fixer(names):
        """Drop the '.' in front of every '[<digits>]' index of each name."""
        return [re.sub(r'\.(\[\d+\])', r'\1', name) for name in names]

    layer_info = OrderedDict()
    hooks = []

    def register_hook(module, module_idx):
        layer_name = all_layers[module_idx][0]

        def hook(module, inputs, output):
            # The fixer is a no-op on names that are already valid accessor
            # strings, so it is applied directly instead of first probing the
            # name with eval() against whatever global `model` happens to exist.
            m_key = name_fixer([prefix + layer_name])[0]

            entry = OrderedDict()
            entry["type"] = type(module).__name__
            entry["input_shape"] = list(inputs[0].size())
            entry["input_shape"][0] = batch_size

            if isinstance(output, (list, tuple)):
                entry["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                entry["output_shape"] = list(output.size())
                entry["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                entry["trainable"] = module.weight.requires_grad
                entry["weight_shape"] = module.weight.shape
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            entry["nb_params"] = params

            layer_info[m_key] = entry

        # Pure containers only forward to their children; hooking them would
        # add entries without information of their own.
        if not isinstance(module, (nn.Sequential, nn.ModuleList)):
            hooks.append(module.register_forward_hook(hook))

    x = [torch.rand(2, *in_size).type(dtype)
         for in_size, dtype in zip(input_size, dtypes)]

    for module_idx, (_, module) in enumerate(all_layers):
        register_hook(module, module_idx)

    try:
        model(*x)
    finally:
        # Always detach the hooks, otherwise a failing forward pass would
        # leave them registered on the caller's model.
        for handle in hooks:
            handle.remove()

    return layer_info




In [3]:

class VGG(nn.Module):
    """Small VGG-style classifier: conv/bn/relu stages, mean pooling, linear head.

    ``ARCH`` lists the backbone stage by stage: an integer is the number of
    output channels of a 3x3 conv-bn-relu triple, ``'M'`` is a 2x2 max-pool.
    Backbone children are named ``conv<i>``, ``bn<i>``, ``relu<i>`` and
    ``pool<i>``, each kind numbered independently from zero.
    """

    ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']

    def __init__(self) -> None:
        super().__init__()

        named_layers = OrderedDict()
        seen = defaultdict(int)

        def register(kind: str, layer: nn.Module) -> None:
            # Suffix each layer with the running count of its kind.
            named_layers[f"{kind}{seen[kind]}"] = layer
            seen[kind] += 1

        channels = 3
        for spec in self.ARCH:
            if spec == 'M':
                register("pool", nn.MaxPool2d(2))
                continue
            register("conv", nn.Conv2d(channels, spec, 3, padding=1, bias=False))
            register("bn", nn.BatchNorm2d(spec))
            register("relu", nn.ReLU(True))
            channels = spec

        self.backbone = nn.Sequential(named_layers)
        self.classifier = nn.Linear(512, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch to 10 class scores per sample."""
        feature_maps = self.backbone(x)
        # Average over the two trailing (spatial) dimensions.
        pooled = feature_maps.mean([2, 3])
        return self.classifier(pooled)

In [4]:
# model = VGG()
# Module-level scratch values. NOTE(review): `get_all_layers` below takes its
# own `parent_name` argument and builds its own local `layers` list, so these
# two globals look unused within this notebook; confirm before removing them.
parent_name = ''
layers= []

def summary_string_fixed(model, all_layers, input_size, model_name=None, batch_size=-1, dtypes=None):
    """Run one forward pass through ``model`` and record the class name of each hooked layer.

    A forward hook is attached to every module of ``all_layers`` except
    ``nn.Sequential``, ``nn.ModuleList``, torchvision's ``SqueezeExcitation``
    and MobileNetV3 ``InvertedResidual`` wrappers; the model is then called
    once on random inputs with a batch of two samples.

    Args:
        model: module to trace; it is called as ``model(*x)``.
        all_layers: sequence of ``(name, module)`` pairs.
        input_size: shape of one input without the batch dimension (a tuple),
            or a list of such shapes for models taking several inputs.
        model_name: prefix joined to each layer name with a '.'; when None
            the bare layer name is used as key.
        batch_size: accepted for signature compatibility; not used here.
        dtypes: tensor types applied to the random inputs, one per input;
            defaults to ``torch.FloatTensor``.

    Returns:
        An ``OrderedDict`` mapping each layer accessor string to the class
        name of the module (e.g. ``'Conv2d'``), in the order the hooks fired.

    Raises:
        Whatever the forward pass of ``model`` raises; the hooks are removed
        from the modules in that case as well.
    """
    # Normalising first keeps the default dtype list aligned with the number
    # of inputs rather than with the number of dimensions of a single input.
    if isinstance(input_size, tuple):
        input_size = [input_size]
    if dtypes is None:
        dtypes = [torch.FloatTensor] * len(input_size)

    prefix = f"{model_name}." if model_name else ""

    def name_fixer(names):
        """Drop the '.' in front of every '[<digits>]' index of each name."""
        return [re.sub(r'\.(\[\d+\])', r'\1', name) for name in names]

    layer_types = OrderedDict()
    hooks = []

    def register_hook(module, module_idx):
        layer_name = all_layers[module_idx][0]

        def hook(module, inputs, output):
            # The fixer is a no-op on names that are already valid accessor
            # strings, so it is applied directly instead of first probing the
            # name with eval() against whatever global `model` happens to exist.
            m_key = name_fixer([prefix + layer_name])[0]
            layer_types[m_key] = type(module).__name__

        # Containers and composite blocks are skipped; only their children
        # are recorded.
        if (
                not isinstance(module, nn.Sequential)
                and not isinstance(module, nn.ModuleList)
                and not isinstance(module, torchvision.ops.misc.SqueezeExcitation)
                and not isinstance(module, torchvision.models.mobilenetv3.InvertedResidual)
        ):
            hooks.append(module.register_forward_hook(hook))

    x = [torch.rand(2, *in_size).type(dtype)
         for in_size, dtype in zip(input_size, dtypes)]

    for module_idx, (_, module) in enumerate(all_layers):
        register_hook(module, module_idx)

    try:
        model(*x)
    finally:
        # Always detach the hooks, otherwise a failing forward pass would
        # leave them registered on the caller's model.
        for handle in hooks:
            handle.remove()

    return layer_types


def get_all_layers(model, parent_name=''):
    """List every descendant module of ``model`` with a Python accessor string.

    Children are visited depth-first in ``named_children()`` order. Purely
    numeric child names become index brackets, so the dotted path
    ``features.1.block`` is reported as ``features[1].block``.

    Args:
        model: the ``nn.Module`` whose children are enumerated.
        parent_name: dotted path of ``model`` relative to the root module;
            leave empty when calling on the root.

    Returns:
        A list of ``(accessor_name, module)`` tuples; empty when ``model`` has
        no children.
    """

    def reformat_layer_name(dotted_name):
        """Turn 'features.1.block' into 'features[1].block'."""
        # Built component by component rather than by regex surgery so that a
        # numeric child is indexed correctly whatever character its parent
        # name ends with (e.g. 'layer1.0' -> 'layer1[0]').
        parts = []
        for part in dotted_name.split('.'):
            if part.isdigit():
                parts.append(f"[{part}]")
            else:
                parts.append(f".{part}" if parts else part)
        return ''.join(parts)

    layers = []
    for name, module in model.named_children():
        full_name = f"{parent_name}.{name}" if parent_name else name
        # No eval() probe: a dotted path that evaluates cleanly contains no
        # numeric component, and such a path is returned unchanged by the
        # reformatter anyway.
        layers.append((reformat_layer_name(full_name), module))
        layers.extend(get_all_layers(module, parent_name=full_name))
    return layers


In [5]:
# Instantiate the three networks used in this notebook: the VGG class defined
# above (freshly constructed, no weights loaded) and two MobileNet variants
# fetched through torch.hub from the torchvision v0.10.0 tag with
# `pretrained=True`.
vgg = VGG()
mobilenet_v2 = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
mobilenet_v3 = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v3_small', pretrained=True)

Using cache found in /Users/sathya/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/sathya/.cache/torch/hub/pytorch_vision_v0.10.0


In [6]:
def get_importance(layer, sparsity):
    """Pick the lowest-norm slices of ``layer.weight`` along dim 0 for removal.

    Every slice ``weight[k]`` is scored by its norm, the slices are ranked
    from largest to smallest score, and the tail of that ranking is returned.

    Args:
        layer: object exposing a ``weight`` tensor; elsewhere in this notebook
            it is a ``Conv2d`` whose output channels are pruned.
        sparsity: fraction of slices to drop; ``round(n * (1 - sparsity))``
            slices are kept.

    Returns:
        1-D index tensor of the slices to remove, ordered from the highest to
        the lowest norm among the removed ones.
    """

    def slice_norms(weight):
        detached = weight.detach()
        scores = [
            torch.norm(detached[idx, :]).view(1)
            for idx in range(weight.shape[0])
        ]
        return torch.cat(scores)

    ranking = torch.argsort(slice_norms(layer.weight), descending=True)
    keep_count = int(round(len(ranking) * (1.0 - sparsity)))
    return ranking[keep_count:]


import copy

# Trace a copy of the model once to collect the accessor strings
# ("model.features[0][0]", ...) of all Conv2d layers.
model = copy.deepcopy(mobilenet_v3)
all_layers = get_all_layers(model)
model_summary = summary_string_fixed(model, all_layers, (3, 64, 64), model_name='model')
conv_layer = [key for key, value in model_summary.items() if value == 'Conv2d']

# Prune each conv layer on its own, always starting from a fresh copy, and
# remember the layers after whose pruning the forward pass no longer runs.
grouped = []

for i in range(len(conv_layer)):
    model = copy.deepcopy(mobilenet_v3)

    # The accessor strings refer to the global name `model`, so this eval
    # resolves against the copy created just above.
    layer = eval(conv_layer[i])
    indices_to_remove = get_importance(layer, sparsity=0.10)

    example_inputs = torch.randn(1,3,32,32)

    DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

    pruning_group = DG.get_pruning_group( layer, tp.prune_conv_out_channels, idxs=indices_to_remove )

    if DG.check_pruning_group(pruning_group):
        pruning_group.prune()
    try:
        model(example_inputs)
    except Exception:
        # Only real forward-pass failures mark the layer; a bare `except`
        # would also swallow KeyboardInterrupt and make this loop
        # impossible to stop.
        grouped.append(i)

In [7]:
# Indices into `conv_layer` for which the forward pass raised after pruning
# that layer alone.
grouped

[0,
 1,
 3,
 5,
 6,
 8,
 9,
 11,
 12,
 14,
 16,
 17,
 19,
 21,
 22,
 24,
 26,
 27,
 29,
 31,
 32,
 34,
 36,
 37,
 39,
 41,
 42,
 44,
 46,
 47,
 49]

In [8]:
# Indices into `conv_layer` that are not listed in `grouped`, i.e. the layers
# whose individual pruning left the forward pass runnable.
working = []
for i in range(len(conv_layer)):
    if i in grouped:
        continue
    working.append(i)

In [30]:
# Display the module tree of the loaded MobileNetV3-small model.
mobilenet_v3

MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), 

In [53]:
# Create the example input before its first use: the dependency graph below
# needs it, and relying on a value left over from an earlier cell breaks a
# fresh "Restart & Run All".
example_inputs = torch.randn(1,3,32,32)

model = copy.deepcopy(mobilenet_v3)
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)
summary(model,(3,32,32))

# Prune 80% of the output channels of every conv layer listed in `working`,
# one layer after the other, on the same model copy.
for i, conv_idx in enumerate(working):
    print("I:",i)

    # The accessor strings refer to the global name `model` assigned above.
    layer = eval(conv_layer[conv_idx])
    indices_to_remove = get_importance(layer, sparsity=0.80)
    print(conv_layer[conv_idx],layer, indices_to_remove)

    pruning_group = DG.get_pruning_group(layer, tp.prune_conv_out_channels, idxs=indices_to_remove )

    if DG.check_pruning_group(pruning_group):
        pruning_group.prune()
    print("++++++++")

# `pruned_model` is a second name for the same object as `model`, not a copy.
pruned_model = model

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             432
       BatchNorm2d-2           [-1, 16, 16, 16]              32
         Hardswish-3           [-1, 16, 16, 16]               0
            Conv2d-4             [-1, 16, 8, 8]             144
       BatchNorm2d-5             [-1, 16, 8, 8]              32
              ReLU-6             [-1, 16, 8, 8]               0
 AdaptiveAvgPool2d-7             [-1, 16, 1, 1]               0
            Conv2d-8              [-1, 8, 1, 1]             136
              ReLU-9              [-1, 8, 1, 1]               0
           Conv2d-10             [-1, 16, 1, 1]             144
      Hardsigmoid-11             [-1, 16, 1, 1]               0
SqueezeExcitation-12             [-1, 16, 8, 8]               0
           Conv2d-13             [-1, 16, 8, 8]             256
      BatchNorm2d-14             [-1, 1

In [43]:
# Save the parameters of `model` as a state dict.
torch.save(model.state_dict(),'orig.pth')
# Save the whole model object. Fixed: this line referenced the undefined name
# `mode`, which raised NameError.
# NOTE(review): both calls write to 'orig.pth', so the state dict saved above
# is overwritten by this file; confirm whether two file names were intended.
torch.save(model,'orig.pth')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             432
       BatchNorm2d-2           [-1, 16, 16, 16]              32
         Hardswish-3           [-1, 16, 16, 16]               0
            Conv2d-4             [-1, 16, 8, 8]             144
       BatchNorm2d-5             [-1, 16, 8, 8]              32
              ReLU-6             [-1, 16, 8, 8]               0
 AdaptiveAvgPool2d-7             [-1, 16, 1, 1]               0
            Conv2d-8              [-1, 8, 1, 1]             136
              ReLU-9              [-1, 8, 1, 1]               0
           Conv2d-10             [-1, 16, 1, 1]             144
      Hardsigmoid-11             [-1, 16, 1, 1]               0
SqueezeExcitation-12             [-1, 16, 8, 8]               0
           Conv2d-13             [-1, 16, 8, 8]             256
      BatchNorm2d-14             [-1, 1

In [44]:
# Print the torchsummary layer table of the pruned model for a 3x32x32 input.
summary(pruned_model,(3,32,32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             432
       BatchNorm2d-2           [-1, 16, 16, 16]              32
         Hardswish-3           [-1, 16, 16, 16]               0
            Conv2d-4             [-1, 16, 8, 8]             144
       BatchNorm2d-5             [-1, 16, 8, 8]              32
              ReLU-6             [-1, 16, 8, 8]               0
 AdaptiveAvgPool2d-7             [-1, 16, 1, 1]               0
            Conv2d-8              [-1, 8, 1, 1]             136
              ReLU-9              [-1, 8, 1, 1]               0
           Conv2d-10             [-1, 16, 1, 1]             144
      Hardsigmoid-11             [-1, 16, 1, 1]               0
SqueezeExcitation-12             [-1, 16, 8, 8]               0
           Conv2d-13             [-1, 16, 8, 8]             256
      BatchNorm2d-14             [-1, 1