In [1]:
"""
Remove the unnecessary filters from a CNN pruned using main.py, then compare the converted net with the previous.
While main.py allows to determine which filters can be removed, those remain in the network's architecture. The
current script creates a new, lightweight architecture from the result of main.py.

NOTE: The arguments passed to this script are parsed in main.py (i.e. a dataset choice must be made).
"""

import time
from collections import OrderedDict

import torch
import torch.nn as nn

# import VITALabAI.project.bar.model.resnet as resnet
# from VITALabAI.project.bar.model.layers import SparseConvConfig
# # from VITALabAI.project.bar.main import ClassificationTraining


class IdentityModule(nn.Module):
    def forward(self, x):
        return x


def convert_resnet(net, insert_identity_modules=False):
    """Convert a ResNetCifar module (in place)

    Returns
    -------
        net: the mutated net
    """
    
    net.conv1, net.bn1 = convert_conv_bn(net.conv1, net.bn1, torch.ones(3).byte(), (cfg_dict['bn1']>0))
    in_gates = torch.ones(net.conv1.out_channels).byte()

    clean_res = True
    net.layer1, in_gates = convert_layer(net.layer1, in_gates, insert_identity_modules, clean_res, layer_name = 'layer1')
    net.layer2, in_gates = convert_layer(net.layer2, in_gates, insert_identity_modules, clean_res, layer_name = 'layer2')
    net.layer3, in_gates = convert_layer(net.layer3, in_gates, insert_identity_modules, clean_res, layer_name = 'layer3')
    net.layer4, in_gates = convert_layer(net.layer4, in_gates, insert_identity_modules, clean_res, layer_name = 'layer4')


    if clean_res:
        net.fc = convert_fc_head(net.fc, in_gates)
    else:
        net.fc = resnet.InwardPrunedLinear(convert_fc_head(net.fc, in_gates), mask2i(in_gates))

    return net


def convert_layer(layer_module, in_gates, insert_identity_modules, clean_res, layer_name =None):
    """Convert a ResnetCifar layer (in place)

    Parameters
    ----------
        layer_module: a nn.Sequential
        in_gates: mask

    Returns
    -------
        layer_module: mutated layer_module
        in_gates: ajusted mask
    """

    previous_layer_gates = in_gates

    new_blocks = []
    for block_num, block in enumerate(layer_module):
        new_block, in_gates = convert_block(block, in_gates, block_name = layer_name+'.'+str(block_num))
        if new_block is None:
            if insert_identity_modules:
                new_blocks.append(IdentityModule())
        else:
            new_blocks.append(new_block)

    # Remove unused residual features
    if clean_res:
        print()
        cur_layer_gates = in_gates
        for block in new_blocks:
            if isinstance(block, IdentityModule):
                continue
            clean_block(block, previous_layer_gates, cur_layer_gates)  # in-place

    layer_module = nn.Sequential(*new_blocks)
    return layer_module, in_gates


def clean_block(mixed_block, previous_layer_alivef, cur_layer_alivef):
    """Remove unused res features (operates in-place)"""

    def clean_indices(idx, alive_mask=cur_layer_alivef):
        mask = i2mask(idx, alive_mask)
        mask = mask[mask2i(alive_mask)]
        return mask2i(mask)

    if mixed_block.f_res is None:
        mixed_block.in_idx = clean_indices(mixed_block.in_idx)
    else:
        mixed_block.in_idx = clean_indices(mixed_block.in_idx, alive_mask=previous_layer_alivef)
        mixed_block.res_size = cur_layer_alivef.sum().item()
        print('DOWNS ----- Res size: ', mixed_block.res_size)
        mixed_block.res_idx = clean_indices(mixed_block.res_idx)
        print('Res:  ', len(mixed_block.res_idx))
    mixed_block.delta_idx = clean_indices(mixed_block.delta_idx)

    print('In:   ', len(mixed_block.in_idx))
    print('Delta:', len(mixed_block.delta_idx))


def convert_block(block_module, in_gates, block_name = None):
    """Convert a Basic Resnet block (in place)

    Parameters
    ----------
        block_module: a BasicBlock
        in_gates: received mask

    Returns
    -------
        block_module: mutated block
        in_gates: out_gates of this block (in_gates for next block)
    """

#     assert not hasattr(block_module, 'conv3')  # must be basic block

    b1_gates = (cfg_dict[f'{block_name}.bn1']>0)   # get_gates(block_module.bn1)
    b2_gates = (cfg_dict[f'{block_name}.bn2']>0)
    b3_gates = (cfg_dict[f'{block_name}.bn3']>0)
   

    delta_branch_is_pruned = b1_gates.sum().item() == 0 or b2_gates.sum().item() == 0 or b3_gates.sum().item() == 0
    
    # Delta branch
    if not delta_branch_is_pruned:
        block_module.conv1, block_module.bn1 = convert_conv_bn(block_module.conv1, block_module.bn1, in_gates, b1_gates)
        block_module.conv2, block_module.bn2 = convert_conv_bn(block_module.conv2, block_module.bn2, b1_gates, b2_gates)
        block_module.conv3, block_module.bn3 = convert_conv_bn(block_module.conv3, block_module.bn3, b2_gates, b3_gates)


    if block_module.downsample is not None:
        ds_gates = (cfg_dict[f'{block_name}.downsample.1']>0) # get_gates(block_module.downsample[1])
        ds_conv, ds_bn = convert_conv_bn(block_module.downsample[0], block_module.downsample[1], in_gates, ds_gates)
        ds_module = nn.Sequential(ds_conv, ds_bn)

        if delta_branch_is_pruned:
            mixed_block = MixedBlock(f_delta=None, delta_idx=None,
                                            f_res=ds_module,
                                            in_idx=mask2i(in_gates),
                                            res_idx=mask2i(ds_gates),
                                            res_size=len(b3_gates))
        else:
            block_module.downsample = ds_module
            mixed_block = MixedBlock.from_bottleneck(block_module,
                                                       delta_idx=mask2i(b3_gates),
                                                       in_idx=mask2i(in_gates),
                                                       res_idx=mask2i(ds_gates),
                                                       res_size=len(b3_gates))
        in_gates = elementwise_or(ds_gates, b3_gates)
    else:
        if delta_branch_is_pruned:
            mixed_block = None
        else:
            mixed_block = MixedBlock.from_bottleneck(block_module,
                                                       delta_idx=mask2i(b3_gates),
                                                       in_idx=mask2i(in_gates))
        in_gates = elementwise_or(in_gates, b3_gates)

    return mixed_block, in_gates


def convert_conv_bn(conv_module, bn_module, in_gates, out_gates):
    in_indices = mask2i(in_gates)  # indices of kept features
    out_indices = mask2i(out_gates)

    # Keep the good ones
    new_conv_w = conv_module.weight.data[out_indices][:, in_indices]

    new_conv = make_conv(new_conv_w, from_module=conv_module)
    new_bn = convert_bn(bn_module, out_indices)

    new_conv.out_idx = out_indices
    
    return new_conv, new_bn


def convert_fc_head(fc_module, in_gates):
    """Convert a the final FC module of the net

    Parameters
    ----------
        fc_module: a nn.Linear with weight tensor of size (out_f, in_f)
        in_gates: binary vector or list of size in_f

    Returns
    -------
        fc_module: mutated module
    """

    in_indices = mask2i(in_gates)
    new_weight_tensor = fc_module.weight.data[:, in_indices]
    return make_fc(new_weight_tensor, from_module=fc_module)


def convert_bn(bn_module, out_indices):
#     z = bn_module.get_gates(stochastic=False)
    new_weight = bn_module.weight.data[out_indices] # * z[out_indices]
    new_bias = bn_module.bias.data[out_indices] # * z[out_indices]

    new_bn_module = nn.BatchNorm2d(len(new_weight))
    new_bn_module.weight.data.copy_(new_weight)
    new_bn_module.bias.data.copy_(new_bias)
    new_bn_module.running_mean.copy_(bn_module.running_mean[out_indices])
    new_bn_module.running_var.copy_(bn_module.running_var[out_indices])

    new_bn_module.out_idx = out_indices

    return new_bn_module


def make_bn(bn_module, kept_indices):
    new_bn_module = nn.BatchNorm2d(len(kept_indices))
    new_bn_module.weight.data.copy_(bn_module.weight.data[kept_indices])
    new_bn_module.bias.data.copy_(bn_module.bias.data[kept_indices])
    new_bn_module.running_mean.copy_(bn_module.running_mean[kept_indices])
    new_bn_module.running_var.copy_(bn_module.running_var[kept_indices])

    if hasattr(bn_module, 'out_idx'):
        new_bn_module.out_idx = bn_module.out_idx[kept_indices]
    else:
        new_bn_module.out_idx = kept_indices

    return new_bn_module


def make_conv(weight_tensor, from_module):
    # NOTE: No bias

    # New weight size
    in_channels = weight_tensor.size(1)
    out_channels = weight_tensor.size(0)

    # Other params
    kernel_size = from_module.kernel_size
    stride = from_module.stride
    padding = from_module.padding

    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
    conv.weight.data.copy_(weight_tensor)
    return conv


def make_fc(weight_tensor, from_module):
    in_features = weight_tensor.size(1)
    out_features = weight_tensor.size(0)
    fc = nn.Linear(in_features, out_features)
    fc.weight.data.copy_(weight_tensor)
    fc.bias.data.copy_(from_module.bias.data)
    return fc

def elementwise_or(a, b):
    return (a + b) > 0


def mask2i(mask):
#     assert mask.dtype == torch.uint8
    return mask.nonzero().view(-1)  # Note: do not use .squeeze() because single item becomes a scalar instead of 1-vec


def i2mask(i, from_tensor):
    x = torch.zeros_like(from_tensor)
    x[i] = 1
    return x


In [2]:
class MixedBlock(nn.Module):
    def __init__(self, f_delta, delta_idx, in_idx, f_res=None, res_idx=None, res_size=None):
        super(MixedBlock, self).__init__()
        self.f_delta = f_delta
        self.delta_idx = delta_idx
        self.in_idx = in_idx
        self.f_res = f_res
        self.res_idx = res_idx
        self.res_size = res_size
        self.activ = nn.ReLU(inplace=True)

        self.res_scatter_idx = None
        self.delta_scatter_idx = None

        if f_delta is None:
            self.forward = self.forward_without_delta
        else:
            self.forward = self.forward_with_delta

    def scatter_features(self, idx, src, final_size):
        if self.res_scatter_idx is None or self.res_scatter_idx.size(0) != src.size(0):
            scatter_idx = idx.new_empty(*src.size())
            scatter_idx.copy_(idx[None, :, None, None])
            self.res_scatter_idx = scatter_idx

        x = torch.zeros(src.size(0), final_size, src.size(2), src.size(3)).to(src)
        x.scatter_(dim=1, index=self.res_scatter_idx, src=src)
        return x

    def scatter_add_features(self, dst, idx, src):
        if self.delta_scatter_idx is None or self.delta_scatter_idx.size(0) != src.size(0):
            scatter_idx = idx.new_empty(*src.size())
            scatter_idx.copy_(idx[None, :, None, None])
            self.delta_scatter_idx = scatter_idx

        dst.scatter_add_(dim=1, index=self.delta_scatter_idx, src=src)

    def forward_with_delta(self, x):
        # x: (B, C, H, W)

        if self.f_res is None:
            x_alive = x.index_select(dim=1, index=self.in_idx)
            delta = self.f_delta.forward(x_alive)  # 3x3 conv
        else:
            delta = self.f_delta.forward(x)  # 3x3 conv

            res = self.f_res.forward(x)  # 1x1 conv
            x = self.scatter_features(self.res_idx, res, self.res_size)

        self.scatter_add_features(x, self.delta_idx, delta)
        
        return self.activ(x)

    def forward_without_delta(self, x):
        res = self.f_res.forward(x)  # 1x1 conv
        x = self.scatter_features(self.res_idx, res, self.res_size)

        return self.activ(x)

    @staticmethod
    def from_basic(block, delta_idx, in_idx, res_idx=None, res_size=None):
        f_delta = nn.Sequential(
            block.conv1,
            block.bn1,
            block.activ,  # nn.ReLU(inplace=True)
            block.conv2,
            block.bn2
        )
        return MixedBlock(f_delta, delta_idx, in_idx, block.downsample, res_idx, res_size)
    
    @staticmethod
    def from_bottleneck(block, delta_idx, in_idx, res_idx=None, res_size=None):
        f_delta = nn.Sequential(
            block.conv1,
            block.bn1,
            block.relu,  # nn.ReLU(inplace=True)
            block.conv2,
            block.bn2,
            block.relu,
            block.conv3,
            block.bn3
        )
        return MixedBlock(f_delta, delta_idx, in_idx, block.downsample, res_idx, res_size)

In [3]:
# model

In [4]:
# list(model.parameters())

In [5]:
# %cd /workspace/pytracking

In [6]:
# model = resnet50_single_mask()

In [7]:
%cd /workspace/pytracking
from ltr.models.backbone.resnet import resnet50
from ltr.models.backbone.resnet_child import resnet50_child
from ltr.models.backbone.resnet_shared_mask_wo_bn import resnet50_mask_wo_bn 

/workspace/pytracking


In [8]:
model = resnet50_mask_wo_bn()

In [9]:
from ltr.admin.loading import torch_load_legacy
ckpt = torch_load_legacy('/workspace/tracking_datasets/saved_ckpts/ltr/dimp/sparse/dimp50_mask_wo_bn/DiMPnet_ep0049.pth.tar')['net']

In [10]:
from collections import OrderedDict
new_state = OrderedDict()

for key, value in ckpt.items():
    key = key[18:] # remove `module.`
    new_state[key] = value
model.load_state_dict(new_state, strict = False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['initializer.filter_conv.weight', 'initializer.filter_conv.bias', 'optimizer.log_step_length', 'optimizer.filter_reg', 'optimizer.label_map_predictor.weight', 'optimizer.target_mask_predictor.0.weight', 'optimizer.spatial_weight_predictor.weight', '_extractor.0.weight', '_1r.0.weight', '_1r.0.bias', '_1r.1.weight', '_1r.1.bias', '_1r.1.running_mean', '_1r.1.running_var', '_1r.1.num_batches_tracked', '_1t.0.weight', '_1t.0.bias', '_1t.1.weight', '_1t.1.bias', '_1t.1.running_mean', '_1t.1.running_var', '_1t.1.num_batches_tracked', '_2t.0.weight', '_2t.0.bias', '_2t.1.weight', '_2t.1.bias', '_2t.1.running_mean', '_2t.1.running_var', '_2t.1.num_batches_tracked', 'r.0.weight', 'r.0.bias', 'r.1.weight', 'r.1.bias', 'r.1.running_mean', 'r.1.running_var', 'r.1.num_batches_tracked', '3r.0.weight', '3r.0.bias', '3r.1.weight', '3r.1.bias', '3r.1.running_mean', '3r.1.running_var', '3r.1.num_batches_tracked', '4r.0.weight', '4r.0.bias', '4r.1.weig

In [11]:
# ckpt.keys()

In [12]:
import math
import torch
import torch.nn as nn
from collections import OrderedDict
import torch.utils.model_zoo as model_zoo
from torchvision.models.resnet import model_urls


import numpy as np

import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Base class for backbone networks. Handles freezing layers etc.
    args:
        frozen_layers  -  Name of layers to freeze. Either list of strings, 'none' or 'all'. Default: 'none'.
    """
    def __init__(self, frozen_layers=()):
        super().__init__()

        if isinstance(frozen_layers, str):
            if frozen_layers.lower() == 'none':
                frozen_layers = ()
            elif frozen_layers.lower() != 'all':
                raise ValueError('Unknown option for frozen layers: \"{}\". Should be \"all\", \"none\" or list of layer names.'.format(frozen_layers))

        self.frozen_layers = frozen_layers
        self._is_frozen_nograd = False


    def train(self, mode=True):
        super().train(mode)
        if mode == True:
            self._set_frozen_to_eval()
        if not self._is_frozen_nograd:
            self._set_frozen_to_nograd()
            self._is_frozen_nograd = True
        return self


    def _set_frozen_to_eval(self):
        if isinstance(self.frozen_layers, str) and self.frozen_layers.lower() == 'all':
            self.eval()
        else:
            for layer in self.frozen_layers:
                getattr(self, layer).eval()


    def _set_frozen_to_nograd(self):
        if isinstance(self.frozen_layers, str) and self.frozen_layers.lower() == 'all':
            for p in self.parameters():
                p.requires_grad_(False)
        else:
            for layer in self.frozen_layers:
                for p in getattr(self, layer).parameters():
                    p.requires_grad_(False)

class scaler(nn.Module):
    def __init__(self,num_features):
        super(scaler, self).__init__()
        self.weight = nn.parameter.Parameter(torch.empty(num_features)).reshape(1,num_features,1,1).cuda()
        self.weight.retain_grad()
    def forward(self,x):
        out = self.weight * x
        return out
    

class channel_selection(nn.Module):
    """
    Select channels from the output of BatchNorm2d layer. It should be put directly after BatchNorm2d layer.
    The output shape of this layer is determined by the number of 1 in `self.indexes`.
    """
    def __init__(self, num_channels):
        """
        Initialize the `indexes` with all one vector with the length same as the number of channels.
        During pruning, the places in `indexes` which correpond to the channels to be pruned will be set to 0.
        """
        super(channel_selection, self).__init__()
        self.indexes = nn.Parameter(torch.ones(num_channels))

    def forward(self, input_tensor):
        """
        Parameter
        ---------
        input_tensor: (N,C,H,W). It should be the output of BatchNorm2d layer.
        """
        selected_index = np.squeeze(np.argwhere(self.indexes.data.cpu().numpy()))
        if selected_index.size == 1:
            selected_index = np.resize(selected_index, (1,)) 
        output = input_tensor[:, :, :, :]
        return output





def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, bias=False, dilation=dilation)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, use_bn=True):
        super(BasicBlock, self).__init__()
        self.use_bn = use_bn
        self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)

        if use_bn:
            self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, dilation=dilation)

        if use_bn:
            self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)

        if self.use_bn:
            out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)

        if self.use_bn:
            out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes,cfg,stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(cfg[0],cfg[1], kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(cfg[1])
        self.conv2 = nn.Conv2d(cfg[1],cfg[2], kernel_size=3, stride=stride,
                               padding=dilation, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(cfg[2])
        self.conv3 = nn.Conv2d(cfg[2],cfg[3], kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(cfg[3])
        self.relu = nn.ReLU(inplace=True)
#         self.mask = scaler(cfg[3])
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        
        out = self.conv3(out)
        out = self.bn3(out)
        
        if self.downsample is not None:
            residual = self.downsample(x)
        
#         out = self.mask(out)
#         residual = self.mask(residual)
        
        out += residual
        out = self.relu(out)

        return out


class ResNet(Backbone):
    """ ResNet network module. Allows extracting specific feature blocks."""
    def __init__(self, block, layers, output_layers,cfg = None , num_classes=1000, inplanes=64, dilation_factor=1, frozen_layers=()):
        self.inplanes = inplanes
        super(ResNet, self).__init__(frozen_layers=frozen_layers)
        self.output_layers = output_layers
        self.conv1 = nn.Conv2d(3, inplanes , kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(inplanes)

        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        stride = [1 + (dilation_factor < l) for l in (8, 4, 2)]
        self.layer1 = self._make_layer(block, inplanes, layers[0], dilation=max(dilation_factor//8, 1), cfg = cfg[0:(3*layers[0])+1])
        self.layer2 = self._make_layer(block, inplanes*2, layers[1], stride=stride[0], dilation=max(dilation_factor//4, 1), cfg = cfg[3*layers[0]:3*layers[1]+3*layers[0]+1])
        self.layer3 = self._make_layer(block, inplanes*4, layers[2], stride=stride[1], dilation=max(dilation_factor//2, 1), cfg = cfg[3*layers[1]+3*layers[0]:3*layers[1]+3*layers[0]+3*layers[2]+1])
        self.layer4 = self._make_layer(block, inplanes*8, layers[3], stride=stride[2], dilation=dilation_factor, cfg = cfg[3*layers[1]+3*layers[0]+3*layers[2]:3*layers[1]+3*layers[0]+3*layers[2]+3*layers[3]+1])

        out_feature_strides = {'conv1': 4, 'layer1': 4, 'layer2': 4*stride[0], 'layer3': 4*stride[0]*stride[1],
                               'layer4': 4*stride[0]*stride[1]*stride[2]}

        # TODO better way?
        if isinstance(self.layer1[0], BasicBlock):
            out_feature_channels = {'conv1': inplanes, 'layer1': inplanes, 'layer2': inplanes*2, 'layer3': inplanes*4,
                               'layer4': inplanes*8}
        elif isinstance(self.layer1[0], Bottleneck):
            base_num_channels = 4 * inplanes
            out_feature_channels = {'conv1': inplanes, 'layer1': base_num_channels, 'layer2': base_num_channels * 2,
                                    'layer3': base_num_channels * 4, 'layer4': base_num_channels * 8}
        else:
            raise Exception('block not supported')

        self._out_feature_strides = out_feature_strides
        self._out_feature_channels = out_feature_channels

        # self.avgpool = nn.AvgPool2d(7, stride=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(cfg[-1], num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def out_feature_strides(self, layer=None):
        if layer is None:
            return self._out_feature_strides
        else:
            return self._out_feature_strides[layer]

    def out_feature_channels(self, layer=None):
        if layer is None:
            return self._out_feature_channels
        else:
            return self._out_feature_channels[layer]

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, cfg=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(cfg[0],cfg[3],
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(cfg[3]),
#                 channel_selection(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, cfg[0:4], stride, downsample, dilation=dilation))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
#             print(3*i,(3*(i+1))+1)
#             print(cfg[3*i:(3*(i+1))+1],(3*(i+1))+1 - (3*i))
#             print(cfg[9])
            layers.append(block(self.inplanes, planes, cfg[3*i:(3*(i+1))+1]))

        return nn.Sequential(*layers)

    def _add_output_and_check(self, name, x, outputs, output_layers):
        if name in output_layers:
            outputs[name] = x
        return len(output_layers) == len(outputs)

    def forward(self, x, output_layers=None):
        """ Forward pass with input x. The output_layers specify the feature blocks which must be returned """
        outputs = OrderedDict()

        if output_layers is None:
            output_layers = self.output_layers

        x = self.conv1(x)
        x = self.bn1(x) ####### select daal dena
        x = self.relu(x)

        if self._add_output_and_check('conv1', x, outputs, output_layers):
            return outputs

        x = self.maxpool(x)

        x = self.layer1(x)

        if self._add_output_and_check('layer1', x, outputs, output_layers):
            return outputs

        x = self.layer2(x)

        if self._add_output_and_check('layer2', x, outputs, output_layers):
            return outputs

        x = self.layer3(x)

        if self._add_output_and_check('layer3', x, outputs, output_layers):
            return outputs

        x = self.layer4(x)

        if self._add_output_and_check('layer4', x, outputs, output_layers):
            return outputs

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        if self._add_output_and_check('fc', x, outputs, output_layers):
            return outputs

        if len(output_layers) == 1 and output_layers[0] == 'default':
            return x

        raise ValueError('output_layer is wrong.')


def resnet_baby(output_layers=None, pretrained=False, inplanes=16, **kwargs):
    """Constructs a ResNet-18 model.
    """

    if output_layers is None:
        output_layers = ['default']
    else:
        for l in output_layers:
            if l not in ['conv1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']:
                raise ValueError('Unknown layer: {}'.format(l))

    model = ResNet(BasicBlock, [2, 2, 2, 2], output_layers, inplanes=inplanes, **kwargs)

    if pretrained:
        raise NotImplementedError
    return model


def resnet18(output_layers=None, pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    """

    if output_layers is None:
        output_layers = ['default']
    else:
        for l in output_layers:
            if l not in ['conv1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']:
                raise ValueError('Unknown layer: {}'.format(l))

    model = ResNet(BasicBlock, [2, 2, 2, 2], output_layers, **kwargs)

    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet50_child(output_layers=None, pretrained=False,cfg = None,**kwargs):
    """Constructs a ResNet-50 model.
    """

    if output_layers is None:
        output_layers = ['default']
    else:
        for l in output_layers:
            if l not in ['conv1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']:
                raise ValueError('Unknown layer: {}'.format(l))
#     ckpt = torch.load('/workspace/tracking_datasets/pruned_ckpts/dimp50_correct/dimp50_correct.pth.tar')
#     cfg = ckpt['cfg']
    model = ResNet(Bottleneck, [3, 4, 6, 3],output_layers,cfg=cfg,**kwargs)
#     if pretrained:
# #         model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], progress = False), strict = False )
#           model.load_state_dict(ckpt['state_dict'])
#           print('pruned checkpoint loaded')
    return model

def resnet101_child(output_layers=None, pretrained=False,cfg = None,**kwargs):
    """Constructs a ResNet-50 model.
    """

    if output_layers is None:
        output_layers = ['default']
    else:
        for l in output_layers:
            if l not in ['conv1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']:
                raise ValueError('Unknown layer: {}'.format(l))
#     ckpt = torch.load('/workspace/tracking_datasets/pruned_ckpts/dimp50_correct/dimp50_correct.pth.tar')
#     cfg = ckpt['cfg']
    model = ResNet(Bottleneck,[3, 4, 23, 3], output_layers,cfg=cfg,**kwargs)
#     if pretrained:
# #         model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], progress = False), strict = False )
#           model.load_state_dict(ckpt['state_dict'])
#           print('pruned checkpoint loaded')
    return model




In [13]:
# import numpy as np
# np.sum((model.layer1[0].bn3.weight).cpu().numpy()*1)
# (model.layer1[0].mask.weight.detach().cpu().numpy()<0.5).sum()

In [14]:
# for (idx, m) in model.named_modules():
#     if isinstance(m, nn.BatchNorm2d):
#         print(m.weight.data.shape[0])
#     if idx.split('.')[-1] == 'mask':
#         print(m.weight.data.shape)

In [15]:
total1 = 0
total2 = 0
total3 = 0
for idx, m in model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        if idx.split('.')[0] == 'layer1':
            if idx.split('.')[-1]!='bn3':
                total1 += m.weight.data.shape[0]
        if idx.split('.')[0] == 'layer2':
            if idx.split('.')[-1]!='bn3':
                total2 += m.weight.data.shape[0]
        if idx.split('.')[0] == 'layer3':
            if idx.split('.')[-1]!='bn3':
                total3 += m.weight.data.shape[0]

bn1 = torch.zeros(total1)
bn2 = torch.zeros(total2)
bn3 = torch.zeros(total3)
index1 = 0
index2 = 0
index3 = 0
for (idx, m) in model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        if idx.split('.')[0] == 'layer1':
            if idx.split('.')[-1]!='bn3':
                size = m.weight.data.shape[0]
#                 print(size)
                bn1[index1:(index1+size)] = m.weight.data.abs().clone()
                index1 += size
        if idx.split('.')[0] == 'layer2':
            if idx.split('.')[-1]!='bn3':
                size = m.weight.data.shape[0]
                bn2[index2:(index2+size)] = m.weight.data.abs().clone()
                index2 += size
        if idx.split('.')[0] == 'layer3':
            if idx.split('.')[-1]!='bn3':
                size = m.weight.data.shape[0]
                bn3[index3:(index3+size)] = m.weight.data.abs().clone()
                index3 += size
                
#     if idx.split('.')[-1] == 'mask':
#         if idx.split('.')[0] == 'layer1':
#             size = m.weight.data.shape[1]
#             bn1[index1:(index1+size)] = m.weight.data.abs().clone()[0,:,0,0]
#             index1 += size
#         if idx.split('.')[0] == 'layer2':
#             size = m.weight.data.shape[1]
#             bn2[index2:(index2+size)] = m.weight.data.abs().clone()[0,:,0,0]
#             index2 += size
#         if idx.split('.')[0] == 'layer3':
#             size = m.weight.data.shape[1]
#             bn3[index3:(index3+size)] = m.weight.data.abs().clone()[0,:,0,0]
#             index3 += size

            
y1, i = torch.sort(bn1)
thre_index1 = int(total1 * 0.5)
thre1 = y1[thre_index1]

y2, i = torch.sort(bn2)
thre_index2 = int(total2 * 0.5)
thre2 = y2[thre_index2]

y3, i = torch.sort(bn3)
thre_index3 = int(total3 * 0.5)
thre3 = y3[thre_index3]

In [16]:
newmodel = resnet50()


pruned = 0
cfg = []
cfg_mask = []
modules = list(model.named_modules())

old_modules = list(model.modules())
new_modules = list(newmodel.modules())
cfg_dict = {}

for k, (idx, m) in enumerate(modules):
    
    if isinstance(m, nn.BatchNorm2d) :
        if idx.split('.')[0] == 'layer1':
            if idx.split('.')[-1]!='bn3':
                weight_copy = m.weight.data.abs().clone()
                mask = weight_copy.gt(thre1).float()
    #             m.wmul_eight.data = weight_copy*mask

                pruned = pruned + mask.shape[0] - torch.sum(mask)
                m.weight.data.mul_(mask)
                m.bias.data.mul_(mask)
    #             if not isinstance(modules[k-2],nn.Sequential):
    #             print(idx.split('.')[1])
                if not idx.split('.')[2]=='downsample':
                    cfg.append(int(torch.sum(mask)))
                    print(torch.sum(mask))
                else:
                    print('yes')
                    print(torch.sum(mask),'**')
        #             print(modules[k-2])
                cfg_mask.append(mask.clone())
                cfg_dict[idx] = mask
                print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
        elif idx.split('.')[0] == 'layer2':
            if idx.split('.')[-1]!='bn3':
                weight_copy = m.weight.data.abs().clone()
                mask = weight_copy.gt(thre2).float()
                pruned = pruned + mask.shape[0] - torch.sum(mask)
                m.weight.data.mul_(mask)
                m.bias.data.mul_(mask)
    #             if not isinstance(modules[k-2],nn.Sequential):
                if not idx.split('.')[2]=='downsample':
                    cfg.append(int(torch.sum(mask)))
                    print(torch.sum(mask))
                else:
                    print('yes')
                    print(torch.sum(mask),'**')
        #             print(modules[k-2])
                cfg_mask.append(mask.clone())
                cfg_dict[idx] = mask
                print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
        elif idx.split('.')[0] == 'layer3':
            if idx.split('.')[-1]!='bn3':
                weight_copy = m.weight.data.abs().clone()
                mask = weight_copy.gt(thre3).float()
                pruned = pruned + mask.shape[0] - torch.sum(mask)
                m.weight.data.mul_(mask)
                m.bias.data.mul_(mask)
                if not idx.split('.')[2]=='downsample':
                    cfg.append(int(torch.sum(mask)))
                    print(torch.sum(mask))
                else:
                    print('yes')
                    print(torch.sum(mask),'**')
        #             print(modules[k-2])
                cfg_mask.append(mask.clone())
                cfg_dict[idx] = mask
                print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
            
        else:
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(0.0).float()
#             pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            try:
                if not idx.split('.')[2]=='downsample':
                    cfg.append(int(torch.sum(mask)))
                    print(torch.sum(mask))
                else:
                    print('yes')
                    print(torch.sum(mask),'**')
            except:
                cfg.append(int(torch.sum(mask)))
                print(torch.sum(mask))
    #             print(modules[k-2])
            cfg_mask.append(mask.clone())
            cfg_dict[idx] = mask
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
    if idx.split('.')[-1]=='mask':
        if idx.split('.')[0]!='layer4':
            weight_copy = m.weight.data.abs().clone()
            y1, i1 = torch.sort(weight_copy[0,:,0,0])
            thre_index1 = int(len(y1) * 0.5)
            thre4 = y1[thre_index1]
            print(thre4)
#             print
            mask = weight_copy.gt(thre4).float()
            pruned = pruned + mask.shape[1] - torch.sum(mask)
            m.weight.data.mul_(mask)
            cfg.append(int(torch.sum(mask)))
            cfg_mask.append(mask.clone())
            cfg_dict[idx] = mask[0,:,0,0]
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[1], int(torch.sum(mask))))
        else:
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(-1.0).float()
#             pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            #m.bias.data.mul_(mask)
            cfg_dict[idx] = mask[0,:,0,0]
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/(total1+total2+total3)

tensor(64.)
layer index: 2 	 total channel: 64 	 remaining channel: 64
tensor(27.)
layer index: 8 	 total channel: 64 	 remaining channel: 27
tensor(27.)
layer index: 10 	 total channel: 64 	 remaining channel: 27
tensor(0.4979, device='cuda:0')
layer index: 14 	 total channel: 256 	 remaining channel: 127
tensor(24.)
layer index: 19 	 total channel: 64 	 remaining channel: 24
tensor(29.)
layer index: 21 	 total channel: 64 	 remaining channel: 29
tensor(0.5002, device='cuda:0')
layer index: 25 	 total channel: 256 	 remaining channel: 127
tensor(28.)
layer index: 28 	 total channel: 64 	 remaining channel: 28
tensor(56.)
layer index: 30 	 total channel: 64 	 remaining channel: 56
tensor(0.4976, device='cuda:0')
layer index: 34 	 total channel: 256 	 remaining channel: 127
tensor(102.)
layer index: 38 	 total channel: 128 	 remaining channel: 102
tensor(105.)
layer index: 40 	 total channel: 128 	 remaining channel: 105
tensor(0.5003, device='cuda:0')
layer index: 44 	 total channel: 5

In [55]:
newmodel

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 27, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(27, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(27, 27, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(27, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(27, 127, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(127, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 127, kernel_size=(1, 1), stride=(1, 

In [18]:
#cfg.pop(1)

In [19]:
# cfg_dict.keys()

In [20]:
# idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
# idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))

In [21]:
# cfg_dict.keys()

In [22]:
list_keys = list(cfg_dict.keys())
list_keys[0]

'bn1'

In [23]:
#list(newmodel.named_modules())

In [24]:
cfg_dict.keys()

dict_keys(['bn1', 'layer1.0.bn1', 'layer1.0.bn2', 'layer1.0.mask', 'layer1.1.bn1', 'layer1.1.bn2', 'layer1.1.mask', 'layer1.2.bn1', 'layer1.2.bn2', 'layer1.2.mask', 'layer2.0.bn1', 'layer2.0.bn2', 'layer2.0.mask', 'layer2.1.bn1', 'layer2.1.bn2', 'layer2.1.mask', 'layer2.2.bn1', 'layer2.2.bn2', 'layer2.2.mask', 'layer2.3.bn1', 'layer2.3.bn2', 'layer2.3.mask', 'layer3.0.bn1', 'layer3.0.bn2', 'layer3.0.mask', 'layer3.1.bn1', 'layer3.1.bn2', 'layer3.1.mask', 'layer3.2.bn1', 'layer3.2.bn2', 'layer3.2.mask', 'layer3.3.bn1', 'layer3.3.bn2', 'layer3.3.mask', 'layer3.4.bn1', 'layer3.4.bn2', 'layer3.4.mask', 'layer3.5.bn1', 'layer3.5.bn2', 'layer3.5.mask', 'layer4.0.bn1', 'layer4.0.bn2', 'layer4.0.bn3', 'layer4.0.mask', 'layer4.1.bn1', 'layer4.1.bn2', 'layer4.1.bn3', 'layer4.1.mask', 'layer4.2.bn1', 'layer4.2.bn2', 'layer4.2.bn3', 'layer4.2.mask'])

In [25]:
cfg_dict_mask_bn = []
cfg_dict_mask = []
cfg_dict_mask.append(cfg_dict['bn1'].numpy())
for i in cfg_dict.keys():
    #print(i)
    if i.split('.')[-1]=='mask' and i.split('.')[1]=='0':
        print(i)
        cfg_dict_mask.append(cfg_dict[i].cpu().numpy())
    if i.split('.')[-1]=='mask':
        cfg_dict_mask_bn.append(cfg_dict[i].cpu().numpy())
#((cfg_dict_mask[1]>0)*1).sum()
print(len(cfg_dict_mask),len(cfg_dict_mask_bn))

layer1.0.mask
layer2.0.mask
layer3.0.mask
layer4.0.mask
5 16


In [26]:
cfg.pop(1)

'M'

In [27]:
list_keys = list(cfg_dict.keys())

In [28]:
import os
newmodel = resnet50_child(cfg=cfg)
old_modules = list(model.named_modules())
new_modules = list(newmodel.named_modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
count = 0
count_conv = -1 
first_bn = False
k1 = 0
k2 = 0
while k2!=len(new_modules):
    idx_name1,m0 = old_modules[k1]
    idx_name2,m1 = new_modules[k2]
    if idx_name1!=idx_name2:
        if idx_name1.split('.')[-1]=='mask':
            k1+=1
        else:
            k2+=1
        print('unmatched','************',idx_name1,idx_name2)
        count+=1
        
    else:
        k1+=1
        k2+=1
        print('matched','************',idx_name1,idx_name2)
        if isinstance(m0, nn.BatchNorm2d):
            if idx_name1.split('.')[-1]!='bn3':
                start_mask = cfg_dict[idx_name1]
                idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
                w1 = m0.weight.data[idx0.tolist()].clone()
                m1.weight.data = w1.clone()

                w2 = m0.bias.data[idx0.tolist()].clone()
                m1.bias.data = w2.clone()

                #m1.running_mean = m0.running_mean.clone()
                #m1.running_var = m0.running_var.clone()
        elif isinstance(m0, nn.Conv2d):
            if idx_name1.split('.')[0]!='layer4':
                if len(idx_name1.split('.'))<=2:

                    start_mask = cfg_dict[list_keys[count_conv]]

                    end_mask = cfg_dict[list_keys[count_conv+1]]

                    if count_conv==-1:
                        start_mask = torch.ones(3)

                    idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
                    idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
                    print(idx0.shape,idx1.shape)
                    w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
                    w1 = w1[idx1.tolist(), :, :, :].clone()
                    m1.weight.data = w1.clone()
                    count_conv+=1

                else:
                    if idx_name1.split('.')[2]!='downsample':
                        start_mask = cfg_dict[list_keys[count_conv]]

                        end_mask = cfg_dict[list_keys[count_conv+1]]

                        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
                        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
                        print(idx0.shape,idx1.shape)
                        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
                        w1 = w1[idx1.tolist(), :, :, :].clone()
                        m1.weight.data = w1.clone()
                        count_conv+=1
            else:
                continue

matched ************  
matched ************ conv1 conv1
(3,) (64,)
matched ************ bn1 bn1
matched ************ relu relu
matched ************ maxpool maxpool
matched ************ layer1 layer1
matched ************ layer1.0 layer1.0
matched ************ layer1.0.conv1 layer1.0.conv1
(64,) (27,)
matched ************ layer1.0.bn1 layer1.0.bn1
matched ************ layer1.0.conv2 layer1.0.conv2
(27,) (27,)
matched ************ layer1.0.bn2 layer1.0.bn2
matched ************ layer1.0.conv3 layer1.0.conv3
(27,) (127,)
matched ************ layer1.0.bn3 layer1.0.bn3
matched ************ layer1.0.relu layer1.0.relu
unmatched ************ layer1.0.mask layer1.0.downsample
matched ************ layer1.0.downsample layer1.0.downsample
matched ************ layer1.0.downsample.0 layer1.0.downsample.0
unmatched ************ layer1.1 layer1.0.downsample.1
matched ************ layer1.1 layer1.1
matched ************ layer1.1.conv1 layer1.1.conv1
(127,) (24,)
matched ************ layer1.1.bn1 layer1.1

In [29]:
cfg_dict_mask

[array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32),
 array([0., 1., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 1., 0., 0.,
        1., 1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 0.,
        0., 0., 0., 1., 0., 1., 0., 1., 1., 1., 1., 0., 0., 1., 0., 0., 0.,
        1., 0., 1., 1., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0., 1., 1., 1.,
        1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1.,
        1., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 1.,
        1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0.,
        1., 1., 1., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0.,
        0.,

In [30]:
cfg_dict_mask_bn

[array([0., 1., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 1., 0., 0.,
        1., 1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 0.,
        0., 0., 0., 1., 0., 1., 0., 1., 1., 1., 1., 0., 0., 1., 0., 0., 0.,
        1., 0., 1., 1., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0., 1., 1., 1.,
        1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1.,
        1., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 1.,
        1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0.,
        1., 1., 1., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0.,
        0., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 0., 0.,
        0., 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 1., 0., 1.,
        0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1.,
        0., 

In [36]:
k1 = 0
k2 = 0
count = 0 #conv
count_db = 1  #downsample
count_bn = 0   #batchnorm
while k2!=len(new_modules):
    idx_name1,m0 = old_modules[k1]
    idx_name2,m1 = new_modules[k2]
    if idx_name1!=idx_name2:
        if idx_name1.split('.')[-1]=='mask':
            k1+=1
        else:
            if idx_name2.split('.')[-1]=='1' and idx_name2.split('.')[-2]=='downsample':
                print('yes')
                idx_name1,m0 = old_modules[k1-3]
                start_mask = cfg_dict_mask[count_db]
                idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
#                 print(m0.weight.shape)
#                 print(idx0)
                
                w1 = m0.weight.data[:,idx0.tolist(),:,:].clone()
                m1.weight.data = w1[0,:,0,0].clone()
                
                count_db+=1
            k2+=1
                
        print('unmatched','************',idx_name1,idx_name2)
        #count+=1
            
    else:
        k1+=1
        k2+=1
        print('matched','************',idx_name1,idx_name2)
        if isinstance(m0,nn.BatchNorm2d):
            if idx_name1.split('.')[-1]=='bn3':
                start_mask = cfg_dict_mask_bn[count_bn]
                idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
                idx_name1,m0 = old_modules[k1+1]
                print('bn3',m0.weight.shape)
                w1 = m0.weight.data[:,idx0.tolist(),:,:].clone()
                m1.weight.data = w1[0,:,0,0].clone()
                
                count_bn+=1
                
        if isinstance(m0,nn.Conv2d):
            if len(idx_name1.split('.'))>2:
                if idx_name1.split('.')[2]=='downsample':
                    start_mask = cfg_dict_mask[count]
                    end_mask = cfg_dict_mask[count+1]
                    idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
                    idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
                    print(m1.weight.data.shape,idx0.shape,idx1.shape)
                    print(idx0.shape,idx1.shape)
                    w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
                    w1 = w1[idx1.tolist(), :, :, :].clone()
                    m1.weight.data = w1.clone()
                    count+=1

matched ************  
matched ************ conv1 conv1
matched ************ bn1 bn1
matched ************ relu relu
matched ************ maxpool maxpool
matched ************ layer1 layer1
matched ************ layer1.0 layer1.0
matched ************ layer1.0.conv1 layer1.0.conv1
matched ************ layer1.0.bn1 layer1.0.bn1
matched ************ layer1.0.conv2 layer1.0.conv2
matched ************ layer1.0.bn2 layer1.0.bn2
matched ************ layer1.0.conv3 layer1.0.conv3
matched ************ layer1.0.bn3 layer1.0.bn3
bn3 torch.Size([1, 256, 1, 1])
matched ************ layer1.0.relu layer1.0.relu
unmatched ************ layer1.0.mask layer1.0.downsample
matched ************ layer1.0.downsample layer1.0.downsample
matched ************ layer1.0.downsample.0 layer1.0.downsample.0
torch.Size([127, 64, 1, 1]) (64,) (127,)
(64,) (127,)
yes
unmatched ************ layer1.0.mask layer1.0.downsample.1
matched ************ layer1.1 layer1.1
matched ************ layer1.1.conv1 layer1.1.conv1
matched *

In [None]:

k1 = 0
k2 = 0
count = 0 #conv
count_db = 1  #downsample
count_bn = 0   #batchnorm
while k2!=len(new_modules):
    idx_name1,m0 = old_modules[k1]
    idx_name2,m1 = new_modules[k2]
    if idx_name1!=idx_name2:
        if idx_name1.split('.')[-1]=='mask':
            k1+=1
        else:
#             if idx_name2.split('.')[-1]=='1':
#                 print('yes')
#                 idx_name_db,m0_db = old_modules[k1-3]
#                 print(idx_name_db)
#                 start_mask_db = cfg_dict_mask[count_db]
#                 idx0_db = np.squeeze(np.argwhere(np.asarray(start_mask_db.cpu().numpy())))
#                 w1_db = m0_db.weight.data[idx0_db.tolist()].clone()
#                 m1.weight.data = w1_db.clone()
                
#                 count_db+=1
            k2+=1
                
        print('unmatched','************',idx_name1,idx_name2)
        #count+=1
            
    else:
        k1+=1
        k2+=1
        print('matched','************',idx_name1,idx_name2)

In [52]:
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()},f'/workspace/tracking_datasets/pruned_ckpts/dimp50_shared_mask_wo_bn/pruned_50p.pth')

In [54]:
x = resnet50_child(cfg=cfg)
x.load_state_dict(torch.load('/workspace/tracking_datasets/pruned_ckpts/dimp50_shared_mask_wo_bn/pruned_50p.pth')['state_dict'])

<All keys matched successfully>

In [39]:
inp = torch.rand((2,3,224,224))
x(inp).shape

torch.Size([2, 1000])

In [60]:
ckpt.keys()

odict_keys(['feature_extractor.conv1.weight', 'feature_extractor.bn1.weight', 'feature_extractor.bn1.bias', 'feature_extractor.bn1.running_mean', 'feature_extractor.bn1.running_var', 'feature_extractor.bn1.num_batches_tracked', 'feature_extractor.layer1.0.conv1.weight', 'feature_extractor.layer1.0.bn1.weight', 'feature_extractor.layer1.0.bn1.bias', 'feature_extractor.layer1.0.bn1.running_mean', 'feature_extractor.layer1.0.bn1.running_var', 'feature_extractor.layer1.0.bn1.num_batches_tracked', 'feature_extractor.layer1.0.conv2.weight', 'feature_extractor.layer1.0.bn2.weight', 'feature_extractor.layer1.0.bn2.bias', 'feature_extractor.layer1.0.bn2.running_mean', 'feature_extractor.layer1.0.bn2.running_var', 'feature_extractor.layer1.0.bn2.num_batches_tracked', 'feature_extractor.layer1.0.conv3.weight', 'feature_extractor.layer1.0.bn3.weight', 'feature_extractor.layer1.0.bn3.bias', 'feature_extractor.layer1.0.bn3.running_mean', 'feature_extractor.layer1.0.bn3.running_var', 'feature_extract

In [83]:
ckpt['feature_extractor.layer1.0.conv3.weight']

tensor([[[[-4.0119e-09]],

         [[ 2.0645e-02]],

         [[ 5.3493e-02]],

         ...,

         [[ 1.0813e-08]],

         [[-1.4123e-02]],

         [[-4.2867e-02]]],


        [[[ 5.7415e-09]],

         [[ 1.0545e-02]],

         [[ 7.7597e-03]],

         ...,

         [[ 2.3281e-09]],

         [[ 1.5247e-02]],

         [[ 4.9836e-03]]],


        [[[-2.2345e-09]],

         [[ 7.8737e-03]],

         [[-1.9548e-03]],

         ...,

         [[-3.2814e-09]],

         [[ 3.7478e-03]],

         [[-1.3946e-02]]],


        ...,


        [[[ 6.1314e-09]],

         [[-1.8629e-03]],

         [[-1.2791e-02]],

         ...,

         [[-1.2589e-08]],

         [[-1.6471e-03]],

         [[ 5.1506e-03]]],


        [[[-7.4988e-09]],

         [[-1.2011e-02]],

         [[ 1.0461e-01]],

         ...,

         [[ 3.6372e-09]],

         [[-3.8683e-02]],

         [[ 6.3428e-03]]],


        [[[ 1.4321e-09]],

         [[-2.6719e-03]],

         [[-2.0555e-02]],

         

In [84]:
x.layer1[0].conv3.weight

Parameter containing:
tensor([[[[ 7.7597e-03]],

         [[-3.7214e-03]],

         [[ 1.1199e-02]],

         ...,

         [[ 4.3117e-03]],

         [[ 8.0621e-03]],

         [[ 4.9836e-03]]],


        [[[-1.9548e-03]],

         [[ 9.0840e-03]],

         [[-4.5280e-03]],

         ...,

         [[ 3.0070e-03]],

         [[ 1.1128e-02]],

         [[-1.3946e-02]]],


        [[[ 9.0324e-02]],

         [[ 7.1718e-02]],

         [[ 2.8351e-02]],

         ...,

         [[-1.1181e-01]],

         [[-1.2754e-02]],

         [[-1.7638e-02]]],


        ...,


        [[[-1.2791e-02]],

         [[-8.6347e-03]],

         [[ 1.1124e-02]],

         ...,

         [[-1.0229e-04]],

         [[-3.3172e-03]],

         [[ 5.1506e-03]]],


        [[[ 1.0461e-01]],

         [[ 9.2669e-02]],

         [[-2.7954e-02]],

         ...,

         [[-6.8950e-03]],

         [[-1.3025e-02]],

         [[ 6.3428e-03]]],


        [[[-2.0555e-02]],

         [[-5.6770e-02]],

         [[ 5.

In [None]:
k1 = 0
k2 = 0
count = 0 
while k2!=len(new_modules):
    idx1,m0 = old_modules[k1]
    idx2,m1 = new_modules[k2]
    if idx1!=idx2:
        if idx1.split('.')[-1]=='mask':
            k1+=1
        else:
            k2+=1
        print('unmatched','************',idx1,idx2)
        count+=1
        
    else:
        k1+=1
        k2+=1
        print('matched','************',idx1,idx2)
        
print(count)

In [None]:
import os
newmodel = resnet50(cfg=cfg)
old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0
first_bn = False

for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm2d):
        if first_bn==False:
            first_bn=True
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
            continue
        
        
 
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))

        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()