In [1]:
import copy
import re
from collections import OrderedDict

import numpy as np
import timm
import torch
import torch.nn as nn

In [2]:
# Show the timm model names matching the wildcard pattern '*mobilenet*'
# (the bare expression is the cell's displayed output).
timm.list_models('*mobilenet*')

['mobilenetv2_035',
 'mobilenetv2_050',
 'mobilenetv2_075',
 'mobilenetv2_100',
 'mobilenetv2_110d',
 'mobilenetv2_120d',
 'mobilenetv2_140',
 'mobilenetv3_large_075',
 'mobilenetv3_large_100',
 'mobilenetv3_rw',
 'mobilenetv3_small_050',
 'mobilenetv3_small_075',
 'mobilenetv3_small_100',
 'tf_mobilenetv3_large_075',
 'tf_mobilenetv3_large_100',
 'tf_mobilenetv3_large_minimal_100',
 'tf_mobilenetv3_small_075',
 'tf_mobilenetv3_small_100',
 'tf_mobilenetv3_small_minimal_100']

In [3]:
# Instantiate timm's 'mobilenetv3_small_100' with num_classes=10.
# NOTE(review): no `pretrained` flag is passed — presumably random init; confirm.
model = timm.create_model('mobilenetv3_small_100', num_classes=10)

In [4]:
# Layer type names that take part in fusion patterns.
# detect_sequences() looks these up by position (0..4), so the order matters.
fusing_layers = ['Conv2d', 'BatchNorm2d', 'ReLU', 'Linear', 'BatchNorm1d']


def detect_sequences(lst):
    """Scan a sequence of layer type names for fusable runs and record them.

    The function reads the module-level ``fusing_layers`` (type names, looked
    up by position), ``name_list`` (layer names aligned with ``lst``) and
    ``layer_shapes`` (per-layer shape values aligned with ``lst``), and appends
    the names of every detected run to the module-level ``mapped_layers`` dict
    under one of these keys:

    * ``'Conv2d_BatchNorm2d_ReLU'``
    * ``'Conv2d_BatchNorm2d'``
    * ``'Linear_ReLU'``
    * ``'Linear_BatchNorm1d'``

    Patterns are tried in that order at each position; the first one whose
    type names match decides the outcome. The two Conv2d patterns additionally
    require the shape values of the first two layers (conv and batch-norm) to
    be equal. Conv2d+ReLU and BatchNorm2d+ReLU pairs are not detected.

    Args:
        lst: Sequence (list or 1-D array) of layer type names.

    Returns:
        None. Results are appended to ``mapped_layers`` in place.
    """
    conv2d = fusing_layers[0]
    bn2d = fusing_layers[1]
    relu = fusing_layers[2]
    linear = fusing_layers[3]
    bn1d = fusing_layers[4]

    # (mapped_layers key, type-name pattern, whether conv/bn shapes must agree)
    patterns = (
        ('Conv2d_BatchNorm2d_ReLU', [conv2d, bn2d, relu], True),
        ('Conv2d_BatchNorm2d', [conv2d, bn2d], True),
        ('Linear_ReLU', [linear, relu], False),
        ('Linear_BatchNorm1d', [linear, bn1d], False),
    )

    total = len(lst)
    i = 0
    while i < total:
        # Default: nothing fused here, move on by one layer. Previously a
        # type match with mismatching shapes left ``i`` unchanged, which made
        # this loop spin forever.
        step = 1
        for key, pattern, needs_shape_match in patterns:
            width = len(pattern)
            if i + width > total or list(lst[i:i + width]) != pattern:
                continue
            if needs_shape_match:
                # Only the conv and batch-norm entries are compared, for both
                # the two- and the three-layer pattern.
                pair = np.asarray(layer_shapes[i:i + 2])
                if not np.all(pair == pair[0]):
                    # First type match wins; later patterns are not tried.
                    break
            mapped_layers[key].append(
                np.take(name_list, list(range(i, i + width))).tolist()
            )
            step = width
            break
        i += step


def get_all_layers(model, parent_name=''):
    """Collect every descendant module of ``model`` in pre-order.

    Args:
        model: Module whose ``named_children()`` tree is walked.
        parent_name: Dotted prefix prepended to every child name; empty for
            the root.

    Returns:
        A list of ``(name, module)`` tuples, where ``name`` is the dotted path
        passed through ``reformat_layer_name``. Each child is listed before
        its own descendants.
    """
    def _walk(node, prefix):
        for child_name, child in node.named_children():
            dotted = f"{prefix}.{child_name}" if prefix else child_name
            yield reformat_layer_name(dotted), child
            if isinstance(child, nn.Module):
                # Descend using the raw dotted path, not the reformatted one.
                yield from _walk(child, dotted)

    return list(_walk(model, parent_name))


def reformat_layer_name(str_data):
    """Rewrite a dotted module path so numeric segments become subscripts.

    Every purely numeric segment is turned into ``[N]`` and attached directly
    to the preceding segment; all other segments stay dot-separated.

    Examples:
        ``"blocks.0.1.conv"`` -> ``"blocks[0][1].conv"``
        ``"layer1.0.conv1"``  -> ``"layer1[0].conv1"``
        ``"0.conv"``          -> ``"[0].conv"``

    Args:
        str_data: Dotted layer name.

    Returns:
        The reformatted name. Input that is not a string is returned
        unchanged (best-effort, no exception is raised).
    """
    if not isinstance(str_data, str):
        return str_data

    pieces = []
    for segment in str_data.split('.'):
        if segment.isdigit():
            # Subscripts attach without a separating dot.
            pieces.append(f"[{segment}]")
        elif pieces:
            pieces.append(f".{segment}")
        else:
            pieces.append(segment)
    return ''.join(pieces)


def summary_string_fixed(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
    """Run one dummy forward pass and record per-layer shape information.

    A forward hook is attached to every module returned by
    ``get_all_layers(model)`` except ``nn.Sequential`` and ``nn.ModuleList``
    containers. The hooks are always removed again, even if the forward pass
    raises.

    Args:
        model: Module to inspect. It is called as ``model(*inputs)``; its
            train/eval mode is not changed here.
        input_size: Shape of one input without the batch dimension (a tuple),
            or a list of such tuples for multi-input models.
        batch_size: Value written into position 0 of the recorded input and
            (non-sequence) output shapes.
        device: Not used by this function; kept so existing callers that pass
            it keep working.
        dtypes: One tensor type per input, applied with ``Tensor.type``.
            Defaults to ``torch.FloatTensor`` for every input.

    Returns:
        ``OrderedDict`` keyed by layer name (in the order the hooks fired).
        Each entry holds ``type``, ``input_shape``, ``output_shape`` and
        ``nb_params``, plus ``trainable`` and ``weight_shape`` for modules
        that have a ``weight`` tensor.
    """
    # A bare tuple means a single input.
    if isinstance(input_size, tuple):
        input_size = [input_size]
    if dtypes is None:
        dtypes = [torch.FloatTensor] * len(input_size)

    summary = OrderedDict()
    hooks = []

    def make_hook(layer_name):
        def hook(module, inputs, output):
            entry = OrderedDict()
            summary[layer_name] = entry
            entry["type"] = type(module).__name__
            entry["input_shape"] = list(inputs[0].size())
            entry["input_shape"][0] = batch_size

            if isinstance(output, (list, tuple)):
                entry["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                entry["output_shape"] = list(output.size())
                entry["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                entry["trainable"] = module.weight.requires_grad
                entry["weight_shape"] = module.weight.shape
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            entry["nb_params"] = params

        return hook

    # Dummy batch of two samples per input (the original note says this is
    # for batchnorm).
    x = [torch.rand(2, *in_size).type(dtype)
         for in_size, dtype in zip(input_size, dtypes)]

    # Hook the layers of the model that was actually passed in, rather than
    # whatever a module-level ``all_layers`` happens to contain.
    # Names from get_all_layers are already reformatted.
    for layer_name, module in get_all_layers(model):
        if isinstance(module, (nn.Sequential, nn.ModuleList)):
            continue
        hooks.append(module.register_forward_hook(make_hook(layer_name)))

    try:
        model(*x)
    finally:
        # Never leave hooks behind on the model, even on failure.
        for h in hooks:
            h.remove()

    return summary


In [11]:
# Work on a private copy of the model created above. The original line read
# `copy.deepcopy(self.model)`, a leftover from class code: `self` does not
# exist at notebook top level.
model = copy.deepcopy(model)
all_layers = get_all_layers(model)
model_summary = summary_string_fixed(model, (3, 64, 64))  # , device="cuda")

# One row per layer that owns a weight tensor: [name, type, weight_shape[0]].
# numpy stores the mixed rows as strings; reshape keeps the array 2-D even
# when no layer has weights, so the column slicing below cannot fail.
name_type_shape = np.asarray(
    [
        [layer_name, info['type'], info['weight_shape'][0]]
        for layer_name, info in model_summary.items()
        if "weight_shape" in info
    ],
    dtype=str,
).reshape(-1, 3)

name_list = name_type_shape[:, 0]
layer_types = name_type_shape[:, 1]
# Treat 'BatchNormAct2d' (presumably timm's BN+activation layer — confirm) as a
# plain BatchNorm2d for pattern matching. `layer_types` is a view, so this
# also rewrites column 1 of `name_type_shape`.
layer_types[layer_types == 'BatchNormAct2d'] = 'BatchNorm2d'
layer_shapes = name_type_shape[:, 2]
mapped_layers = {
    'model_layer': [],
    'Conv2d_BatchNorm2d_ReLU': [],
    'Conv2d_BatchNorm2d': [],
    'Linear_ReLU': [],
    'Linear_BatchNorm1d': [],
}

# Fills `mapped_layers` in place using the globals defined just above.
detect_sequences(layer_types)

for pattern_key in mapped_layers:
    mapped_layers[pattern_key] = np.asarray(mapped_layers[pattern_key])

mapped_layers['model_layer'] = name_type_shape[:, :2]
# self.mapped_layers = mapped_layers



In [6]:
#Folding
#CWP

In [None]:




# # iterate through conv layers
# for i_conv in range(len(all_convs) - 1):
#     # each channel sorting index, we need to apply it to:
#     # - the output dimension of the previous conv
#     # - the previous BN layer
#     # - the input dimension of the next conv (we compute importance here)
#     prev_conv = getattr(model,all_convs[i_conv])
#     prev_bn = getattr(model,all_bns[i_conv])
#     next_conv = getattr(model,all_convs[i_conv + 1])
#     # note that we always compute the importance according to input channels
#     importance = self.get_input_channel_importance(next_conv.weight)
#     # sorting from large to small
#     sort_idx = torch.argsort(importance, descending=True)

#     # apply to previous conv and its following bn
#     prev_conv.weight.copy_(
#         torch.index_select(prev_conv.weight.detach(), 0, sort_idx)
#     )
#     for tensor_name in ["weight", "bias", "running_mean", "running_var"]:
#         tensor_to_apply = getattr(prev_bn, tensor_name)
#         tensor_to_apply.copy_(
#             torch.index_select(tensor_to_apply.detach(), 0, sort_idx)
#         )

#     # apply to the next conv input (hint: one line of code)
#     ##################### YOUR CODE STARTS HERE #####################
#     next_conv.weight.copy_(
#         torch.index_select(next_conv.weight.detach(), 1, sort_idx)
#     )

# return model