In [41]:
"""
Remove the unnecessary filters from a CNN pruned using main.py, then compare the converted net with the previous.
While main.py allows to determine which filters can be removed, those remain in the network's architecture. The
current script creates a new, lightweight architecture from the result of main.py.

NOTE: The arguments passed to this script are parsed in main.py (i.e. a dataset choice must be made).
"""

import time
from collections import OrderedDict

import torch
import torch.nn as nn

# import VITALabAI.project.bar.model.resnet as resnet
# from VITALabAI.project.bar.model.layers import SparseConvConfig
# # from VITALabAI.project.bar.main import ClassificationTraining


class IdentityModule(nn.Module):
    def forward(self, x):
        return x


def convert_resnet(net, insert_identity_modules=False):
    """Convert a ResNetCifar module (in place)

    Returns
    -------
        net: the mutated net
    """
    
    net.conv1, net.bn1 = convert_conv_bn(net.conv1, net.bn1, torch.ones(3).byte(), (cfg_dict['bn1']>0))
    in_gates = torch.ones(net.conv1.out_channels).byte()

    clean_res = True
    net.layer1, in_gates = convert_layer(net.layer1, in_gates, insert_identity_modules, clean_res, layer_name = 'layer1')
    net.layer2, in_gates = convert_layer(net.layer2, in_gates, insert_identity_modules, clean_res, layer_name = 'layer2')
    net.layer3, in_gates = convert_layer(net.layer3, in_gates, insert_identity_modules, clean_res, layer_name = 'layer3')
    net.layer4, in_gates = convert_layer(net.layer4, in_gates, insert_identity_modules, clean_res, layer_name = 'layer4')


    if clean_res:
        net.fc = convert_fc_head(net.fc, in_gates)
    else:
        net.fc = resnet.InwardPrunedLinear(convert_fc_head(net.fc, in_gates), mask2i(in_gates))

    return net


def convert_layer(layer_module, in_gates, insert_identity_modules, clean_res, layer_name =None):
    """Convert a ResnetCifar layer (in place)

    Parameters
    ----------
        layer_module: a nn.Sequential
        in_gates: mask

    Returns
    -------
        layer_module: mutated layer_module
        in_gates: ajusted mask
    """

    previous_layer_gates = in_gates

    new_blocks = []
    for block_num, block in enumerate(layer_module):
        new_block, in_gates = convert_block(block, in_gates, block_name = layer_name+'.'+str(block_num))
        if new_block is None:
            if insert_identity_modules:
                new_blocks.append(IdentityModule())
        else:
            new_blocks.append(new_block)

    # Remove unused residual features
    if clean_res:
        print()
        cur_layer_gates = in_gates
        for block in new_blocks:
            if isinstance(block, IdentityModule):
                continue
            clean_block(block, previous_layer_gates, cur_layer_gates)  # in-place

    layer_module = nn.Sequential(*new_blocks)
    return layer_module, in_gates


def clean_block(mixed_block, previous_layer_alivef, cur_layer_alivef):
    """Remove unused res features (operates in-place)"""

    def clean_indices(idx, alive_mask=cur_layer_alivef):
        mask = i2mask(idx, alive_mask)
        mask = mask[mask2i(alive_mask)]
        return mask2i(mask)

    if mixed_block.f_res is None:
        mixed_block.in_idx = clean_indices(mixed_block.in_idx)
    else:
        mixed_block.in_idx = clean_indices(mixed_block.in_idx, alive_mask=previous_layer_alivef)
        mixed_block.res_size = cur_layer_alivef.sum().item()
        print('DOWNS ----- Res size: ', mixed_block.res_size)
        mixed_block.res_idx = clean_indices(mixed_block.res_idx)
        print('Res:  ', len(mixed_block.res_idx))
    mixed_block.delta_idx = clean_indices(mixed_block.delta_idx)

    print('In:   ', len(mixed_block.in_idx))
    print('Delta:', len(mixed_block.delta_idx))


def convert_block(block_module, in_gates, block_name = None):
    """Convert a Basic Resnet block (in place)

    Parameters
    ----------
        block_module: a BasicBlock
        in_gates: received mask

    Returns
    -------
        block_module: mutated block
        in_gates: out_gates of this block (in_gates for next block)
    """

#     assert not hasattr(block_module, 'conv3')  # must be basic block

    b1_gates = (cfg_dict[f'{block_name}.bn1']>0)   # get_gates(block_module.bn1)
    b2_gates = (cfg_dict[f'{block_name}.bn2']>0)
    b3_gates = (cfg_dict[f'{block_name}.bn3']>0)
   

    delta_branch_is_pruned = b1_gates.sum().item() == 0 or b2_gates.sum().item() == 0 or b3_gates.sum().item() == 0
    
    # Delta branch
    if not delta_branch_is_pruned:
        block_module.conv1, block_module.bn1 = convert_conv_bn(block_module.conv1, block_module.bn1, in_gates, b1_gates)
        block_module.conv2, block_module.bn2 = convert_conv_bn(block_module.conv2, block_module.bn2, b1_gates, b2_gates)
        block_module.conv3, block_module.bn3 = convert_conv_bn(block_module.conv3, block_module.bn3, b2_gates, b3_gates)


    if block_module.downsample is not None:
        ds_gates = (cfg_dict[f'{block_name}.downsample.1']>0) # get_gates(block_module.downsample[1])
        ds_conv, ds_bn = convert_conv_bn(block_module.downsample[0], block_module.downsample[1], in_gates, ds_gates)
        ds_module = nn.Sequential(ds_conv, ds_bn)

        if delta_branch_is_pruned:
            mixed_block = MixedBlock(f_delta=None, delta_idx=None,
                                            f_res=ds_module,
                                            in_idx=mask2i(in_gates),
                                            res_idx=mask2i(ds_gates),
                                            res_size=len(b3_gates))
        else:
            block_module.downsample = ds_module
            mixed_block = MixedBlock.from_bottleneck(block_module,
                                                       delta_idx=mask2i(b3_gates),
                                                       in_idx=mask2i(in_gates),
                                                       res_idx=mask2i(ds_gates),
                                                       res_size=len(b3_gates))
        in_gates = elementwise_or(ds_gates, b3_gates)
    else:
        if delta_branch_is_pruned:
            mixed_block = None
        else:
            mixed_block = MixedBlock.from_bottleneck(block_module,
                                                       delta_idx=mask2i(b3_gates),
                                                       in_idx=mask2i(in_gates))
        in_gates = elementwise_or(in_gates, b3_gates)

    return mixed_block, in_gates


def convert_conv_bn(conv_module, bn_module, in_gates, out_gates):
    in_indices = mask2i(in_gates)  # indices of kept features
    out_indices = mask2i(out_gates)

    # Keep the good ones
    new_conv_w = conv_module.weight.data[out_indices][:, in_indices]

    new_conv = make_conv(new_conv_w, from_module=conv_module)
    new_bn = convert_bn(bn_module, out_indices)

    new_conv.out_idx = out_indices
    
    return new_conv, new_bn


def convert_fc_head(fc_module, in_gates):
    """Convert a the final FC module of the net

    Parameters
    ----------
        fc_module: a nn.Linear with weight tensor of size (out_f, in_f)
        in_gates: binary vector or list of size in_f

    Returns
    -------
        fc_module: mutated module
    """

    in_indices = mask2i(in_gates)
    new_weight_tensor = fc_module.weight.data[:, in_indices]
    return make_fc(new_weight_tensor, from_module=fc_module)


def convert_bn(bn_module, out_indices):
#     z = bn_module.get_gates(stochastic=False)
    new_weight = bn_module.weight.data[out_indices] # * z[out_indices]
    new_bias = bn_module.bias.data[out_indices] # * z[out_indices]

    new_bn_module = nn.BatchNorm2d(len(new_weight))
    new_bn_module.weight.data.copy_(new_weight)
    new_bn_module.bias.data.copy_(new_bias)
    new_bn_module.running_mean.copy_(bn_module.running_mean[out_indices])
    new_bn_module.running_var.copy_(bn_module.running_var[out_indices])

    new_bn_module.out_idx = out_indices

    return new_bn_module


def make_bn(bn_module, kept_indices):
    new_bn_module = nn.BatchNorm2d(len(kept_indices))
    new_bn_module.weight.data.copy_(bn_module.weight.data[kept_indices])
    new_bn_module.bias.data.copy_(bn_module.bias.data[kept_indices])
    new_bn_module.running_mean.copy_(bn_module.running_mean[kept_indices])
    new_bn_module.running_var.copy_(bn_module.running_var[kept_indices])

    if hasattr(bn_module, 'out_idx'):
        new_bn_module.out_idx = bn_module.out_idx[kept_indices]
    else:
        new_bn_module.out_idx = kept_indices

    return new_bn_module


def make_conv(weight_tensor, from_module):
    # NOTE: No bias

    # New weight size
    in_channels = weight_tensor.size(1)
    out_channels = weight_tensor.size(0)

    # Other params
    kernel_size = from_module.kernel_size
    stride = from_module.stride
    padding = from_module.padding

    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
    conv.weight.data.copy_(weight_tensor)
    return conv


def make_fc(weight_tensor, from_module):
    in_features = weight_tensor.size(1)
    out_features = weight_tensor.size(0)
    fc = nn.Linear(in_features, out_features)
    fc.weight.data.copy_(weight_tensor)
    fc.bias.data.copy_(from_module.bias.data)
    return fc

def elementwise_or(a, b):
    return (a + b) > 0


def mask2i(mask):
#     assert mask.dtype == torch.uint8
    return mask.nonzero().view(-1)  # Note: do not use .squeeze() because single item becomes a scalar instead of 1-vec


def i2mask(i, from_tensor):
    x = torch.zeros_like(from_tensor)
    x[i] = 1
    return x


In [42]:
class MixedBlock(nn.Module):
    def __init__(self, f_delta, delta_idx, in_idx, f_res=None, res_idx=None, res_size=None):
        super(MixedBlock, self).__init__()
        self.f_delta = f_delta
        self.delta_idx = delta_idx
        self.in_idx = in_idx
        self.f_res = f_res
        self.res_idx = res_idx
        self.res_size = res_size
        self.activ = nn.ReLU(inplace=True)

        self.res_scatter_idx = None
        self.delta_scatter_idx = None

        if f_delta is None:
            self.forward = self.forward_without_delta
        else:
            self.forward = self.forward_with_delta

    def scatter_features(self, idx, src, final_size):
        if self.res_scatter_idx is None or self.res_scatter_idx.size(0) != src.size(0):
            scatter_idx = idx.new_empty(*src.size())
            scatter_idx.copy_(idx[None, :, None, None])
            self.res_scatter_idx = scatter_idx

        x = torch.zeros(src.size(0), final_size, src.size(2), src.size(3)).to(src)
        x.scatter_(dim=1, index=self.res_scatter_idx, src=src)
        return x

    def scatter_add_features(self, dst, idx, src):
        if self.delta_scatter_idx is None or self.delta_scatter_idx.size(0) != src.size(0):
            scatter_idx = idx.new_empty(*src.size())
            scatter_idx.copy_(idx[None, :, None, None])
            self.delta_scatter_idx = scatter_idx

        dst.scatter_add_(dim=1, index=self.delta_scatter_idx, src=src)

    def forward_with_delta(self, x):
        # x: (B, C, H, W)

        if self.f_res is None:
            x_alive = x.index_select(dim=1, index=self.in_idx)
            delta = self.f_delta.forward(x_alive)  # 3x3 conv
        else:
            delta = self.f_delta.forward(x)  # 3x3 conv

            res = self.f_res.forward(x)  # 1x1 conv
            x = self.scatter_features(self.res_idx, res, self.res_size)

        self.scatter_add_features(x, self.delta_idx, delta)
        
        return self.activ(x)

    def forward_without_delta(self, x):
        res = self.f_res.forward(x)  # 1x1 conv
        x = self.scatter_features(self.res_idx, res, self.res_size)

        return self.activ(x)

    @staticmethod
    def from_basic(block, delta_idx, in_idx, res_idx=None, res_size=None):
        f_delta = nn.Sequential(
            block.conv1,
            block.bn1,
            block.activ,  # nn.ReLU(inplace=True)
            block.conv2,
            block.bn2
        )
        return MixedBlock(f_delta, delta_idx, in_idx, block.downsample, res_idx, res_size)
    
    @staticmethod
    def from_bottleneck(block, delta_idx, in_idx, res_idx=None, res_size=None):
        f_delta = nn.Sequential(
            block.conv1,
            block.bn1,
            block.relu,  # nn.ReLU(inplace=True)
            block.conv2,
            block.bn2,
            block.relu,
            block.conv3,
            block.bn3
        )
        return MixedBlock(f_delta, delta_idx, in_idx, block.downsample, res_idx, res_size)

In [43]:
# model

In [44]:
# list(model.parameters())

In [45]:
%cd /workspace/pytracking

/workspace/pytracking


In [46]:
model = resnet50_single_mask()

In [80]:
%cd /workspace/pytracking
from ltr.models.backbone.resnet import resnet50
from ltr.models.backbone.resnet_single_mask import resnet50_single_mask 

/workspace/pytracking


In [81]:
from ltr.admin.loading import torch_load_legacy
ckpt = torch_load_legacy('/workspace/tracking_datasets/saved_ckpts/ltr/dimp/sparse/dimp50_single_mask/DiMPnet_ep0038.pth.tar')['net']

In [82]:
from collections import OrderedDict
new_state = OrderedDict()

for key, value in ckpt.items():
    key = key[18:] # remove `module.`
    new_state[key] = value
model.load_state_dict(new_state, strict = False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['initializer.filter_conv.weight', 'initializer.filter_conv.bias', 'optimizer.log_step_length', 'optimizer.filter_reg', 'optimizer.label_map_predictor.weight', 'optimizer.target_mask_predictor.0.weight', 'optimizer.spatial_weight_predictor.weight', '_extractor.0.weight', '_1r.0.weight', '_1r.0.bias', '_1r.1.weight', '_1r.1.bias', '_1r.1.running_mean', '_1r.1.running_var', '_1r.1.num_batches_tracked', '_1t.0.weight', '_1t.0.bias', '_1t.1.weight', '_1t.1.bias', '_1t.1.running_mean', '_1t.1.running_var', '_1t.1.num_batches_tracked', '_2t.0.weight', '_2t.0.bias', '_2t.1.weight', '_2t.1.bias', '_2t.1.running_mean', '_2t.1.running_var', '_2t.1.num_batches_tracked', 'r.0.weight', 'r.0.bias', 'r.1.weight', 'r.1.bias', 'r.1.running_mean', 'r.1.running_var', 'r.1.num_batches_tracked', '3r.0.weight', '3r.0.bias', '3r.1.weight', '3r.1.bias', '3r.1.running_mean', '3r.1.running_var', '3r.1.num_batches_tracked', '4r.0.weight', '4r.0.bias', '4r.1.weig

In [84]:
import numpy as np
np.sum((model.layer1[0].mask.weight>0.5).cpu().numpy()*1)

107

In [65]:
new_state['layer1.0.mask.weight']

tensor([[[[ 4.9439e-01]],

         [[ 5.2811e-01]],

         [[ 5.0189e-01]],

         [[ 5.0036e-01]],

         [[ 4.2392e-01]],

         [[ 4.8955e-01]],

         [[ 4.4509e-01]],

         [[ 5.1082e-01]],

         [[ 8.0642e-05]],

         [[-6.3515e-04]],

         [[ 4.1840e-01]],

         [[ 4.9996e-01]],

         [[ 5.0042e-01]],

         [[ 5.1869e-01]],

         [[ 5.0854e-01]],

         [[ 4.9412e-01]],

         [[ 3.1977e-01]],

         [[ 5.1078e-01]],

         [[ 5.0340e-01]],

         [[-3.0990e-04]],

         [[ 3.7836e-01]],

         [[ 5.0641e-01]],

         [[-4.7648e-06]],

         [[ 1.8034e-05]],

         [[ 5.0227e-01]],

         [[ 9.6952e-05]],

         [[ 5.1169e-01]],

         [[ 5.1010e-01]],

         [[ 5.2279e-01]],

         [[ 5.0660e-01]],

         [[ 5.1267e-01]],

         [[ 5.1161e-01]],

         [[ 5.0470e-01]],

         [[ 4.8641e-01]],

         [[ 2.0751e-04]],

         [[-4.3775e-04]],

         [[-1.1354e-03]],

 

In [22]:
total1 = 0
total2 = 0
total3 = 0
for idx, m in model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        if idx.split('.')[0] == 'layer1':
            total1 += m.weight.data.shape[0]
        if idx.split('.')[0] == 'layer2':
            total2 += m.weight.data.shape[0]
        if idx.split('.')[0] == 'layer3':
            total3 += m.weight.data.shape[0]

bn1 = torch.zeros(total1)
bn2 = torch.zeros(total2)
bn3 = torch.zeros(total3)
index1 = 0
index2 = 0
index3 = 0
for (idx, m) in model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        if idx.split('.')[0] == 'layer1':
            size = m.weight.data.shape[0]
            bn1[index1:(index1+size)] = m.weight.data.abs().clone()
            index1 += size
        if idx.split('.')[0] == 'layer2':
            size = m.weight.data.shape[0]
            bn2[index2:(index2+size)] = m.weight.data.abs().clone()
            index2 += size
        if idx.split('.')[0] == 'layer3':
            size = m.weight.data.shape[0]
            bn3[index3:(index3+size)] = m.weight.data.abs().clone()
            index3 += size
            
y1, i = torch.sort(bn1)
thre_index1 = int(total1 * 0.5)
thre1 = y1[thre_index1]

y2, i = torch.sort(bn2)
thre_index2 = int(total2 * 0.5)
thre2 = y2[thre_index2]

y3, i = torch.sort(bn3)
thre_index3 = int(total3 * 0.5)
thre3 = y3[thre_index3]

In [23]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (mask): scaler()
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_s

In [None]:
# modules = list(model.named_modules())
# for k, (idx, m) in enumerate(modules):
#     print(idx)

In [None]:
newmodel = resnet50()


pruned = 0
cfg = []
cfg_mask = []
modules = list(model.named_modules())

old_modules = list(model.modules())
new_modules = list(newmodel.modules())
cfg_dict = {}

for k, (idx, m) in enumerate(modules):
    
    
    
    if isinstance(m, nn.BatchNorm2d) :
        if idx.split('.')[0] == 'layer1':
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(thre1).float()
#             m.wmul_eight.data = weight_copy*mask
            
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
#             if not isinstance(modules[k-2],nn.Sequential):
#             print(idx.split('.')[1])
            if not idx.split('.')[2]=='downsample':
                cfg.append(int(torch.sum(mask)))
                print(torch.sum(mask))
            else:
                print('yes')
                print(torch.sum(mask),'**')
    #             print(modules[k-2])
            cfg_mask.append(mask.clone())
            cfg_dict[idx] = mask
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
        elif idx.split('.')[0] == 'layer2':
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(thre2).float()
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
#             if not isinstance(modules[k-2],nn.Sequential):
            if not idx.split('.')[2]=='downsample':
                cfg.append(int(torch.sum(mask)))
                print(torch.sum(mask))
            else:
                print('yes')
                print(torch.sum(mask),'**')
    #             print(modules[k-2])
            cfg_mask.append(mask.clone())
            cfg_dict[idx] = mask
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
        elif idx.split('.')[0] == 'layer3':
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(thre3).float()
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            if not idx.split('.')[2]=='downsample':
                cfg.append(int(torch.sum(mask)))
                print(torch.sum(mask))
            else:
                print('yes')
                print(torch.sum(mask),'**')
    #             print(modules[k-2])
            cfg_mask.append(mask.clone())
            cfg_dict[idx] = mask
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
            
        else:
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(0.0).float()
#             pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            try:
                if not idx.split('.')[2]=='downsample':
                    cfg.append(int(torch.sum(mask)))
                    print(torch.sum(mask))
                else:
                    print('yes')
                    print(torch.sum(mask),'**')
            except:
                cfg.append(int(torch.sum(mask)))
                print(torch.sum(mask))
    #             print(modules[k-2])
            cfg_mask.append(mask.clone())
            cfg_dict[idx] = mask
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/(total1+total2+total3)

In [None]:
cfg

In [None]:
# cfg_dict

In [None]:
import pickle
with open('/workspace/tracking_datasets/cfg_dict_resnet_child/cfg_dict_resnet50_layerwise_budget_50.json', 'wb') as fp:
    pickle.dump(cfg_dict, fp)

In [None]:
import pickle
with open('/workspace/tracking_datasets/cfg_dict_resnet_child/cfg_dict_resnet50_layerwise_budget_50.json', 'rb') as fp:
    data = pickle.load(fp)

In [None]:
data

In [None]:
out = convert_resnet(model)

In [None]:
out.state_dict()['layer1.0.f_delta.1.weight']

In [None]:
model

In [None]:
out.state_dict()['layer1.0.f_delta.0.weight']

In [None]:
out.state_dict().keys()

In [None]:
torch.save(out.state_dict(),'/workspace/tracking_datasets/pruned_ckpts/dimp50_bar/correct_layerwise_pruned_50p.pth.tar')

In [None]:
# out.load_state_dict(torch.load('/workspace/tracking_datasets/pruned_ckpts/dimp50_bar/layerwise_pruned_50p.pth.tar'))

In [None]:
# ckpt = torch.load('/workspace/tracking_datasets/pruned_ckpts/dimp50_bar/layerwise_pruned_50p.pth')
# ckpt['layer1.0.f_delta.1.weight']

In [None]:
ckpt = torch.load('/workspace/tracking_datasets/pruned_ckpts/dimp50_bar/correct_layerwise_pruned_50p.pth.tar')
ckpt['layer1.0.f_delta.1.weight']

In [None]:
input1 = torch.rand((2,3,224,224))
output = out(input1)
print(output.shape)

In [None]:
import os
old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0
first_bn = False

for layer_id in range(len(old_modules)):
    m = old_modules[layer_id]
    m1 = new_modules[layer_id]
    
     if isinstance(m, nn.BatchNorm2d) :
        if idx.split('.')[0] == 'layer1':
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(thre1).float()
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
#             if not isinstance(modules[k-2],nn.Sequential):
#             print(idx.split('.')[1])
            if not idx.split('.')[2]=='downsample':
                cfg.append(int(torch.sum(mask)))
                print(torch.sum(mask))
            else:
                print('yes')
                print(torch.sum(mask),'**')
    #             print(modules[k-2])
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
        elif idx.split('.')[0] == 'layer2':
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(thre2).float()
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
#             if not isinstance(modules[k-2],nn.Sequential):
            if not idx.split('.')[2]=='downsample':
                cfg.append(int(torch.sum(mask)))
                print(torch.sum(mask))
            else:
                print('yes')
                print(torch.sum(mask),'**')
    #             print(modules[k-2])
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
        elif idx.split('.')[0] == 'layer3':
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(thre3).float()
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            if not idx.split('.')[2]=='downsample':
                cfg.append(int(torch.sum(mask)))
                print(torch.sum(mask))
            else:
                print('yes')
                print(torch.sum(mask),'**')
    #             print(modules[k-2])
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
            
        else:
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(0.0).float()
#             pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            try:
                if not idx.split('.')[2]=='downsample':
                    cfg.append(int(torch.sum(mask)))
                    print(torch.sum(mask))
                else:
                    print('yes')
                    print(torch.sum(mask),'**')
            except:
                cfg.append(int(torch.sum(mask)))
                print(torch.sum(mask))
    #             print(modules[k-2])
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

    


In [None]:
%cd /workspace/pytracking
from ltr.admin.loading import torch_load_legacy
ckpt1 = torch_load_legacy('/workspace/pytracking/pytracking/networks/dimp50.pth')['net']

In [None]:
ckpt1['bb_regressor.conv3_1r.0.weight'].shape

In [None]:
import pickle
with open('/workspace/tracking_datasets/cfg_dict_resnet_child/cfg_dict_resnet50_layerwise_budget_50.json', 'rb') as fp:
    cfg_dict = pickle.load(fp)

In [None]:
cfg_dict['layer2.3.bn3']

In [None]:
import numpy as np

In [None]:
out = np.logical_or(cfg_dict['layer2.0.downsample.1'],cfg_dict['layer2.0.bn3'])
out = np.logical_or(out,cfg_dict['layer2.1.bn3'])
out = np.logical_or(out,cfg_dict['layer2.2.bn3'])
out = np.logical_or(out,cfg_dict['layer2.3.bn3'])
out

In [None]:
out1 = np.logical_or(cfg_dict['layer3.0.downsample.1'],cfg_dict['layer3.0.bn3'])
out1 = np.logical_or(out1,cfg_dict['layer3.1.bn3'])
out1 = np.logical_or(out1,cfg_dict['layer3.2.bn3'])
out1 = np.logical_or(out1,cfg_dict['layer3.3.bn3'])
out1 = np.logical_or(out1,cfg_dict['layer3.4.bn3'])
out1 = np.logical_or(out1,cfg_dict['layer3.5.bn3'])
out1

In [None]:
out1.sum()

In [None]:
out.sum()

In [None]:
idx1 = np.where(out>0)[0]
idx2 = np.where(out1>0)[0]

In [None]:
ckpt1['bb_regressor.conv3_1r.0.weight'] = ckpt1['bb_regressor.conv3_1r.0.weight'][:,idx1,:,:]
ckpt1['bb_regressor.conv3_1t.0.weight'] = ckpt1['bb_regressor.conv3_1t.0.weight'][:,idx1,:,:]

ckpt1['bb_regressor.conv4_1r.0.weight'] = ckpt1['bb_regressor.conv4_1r.0.weight'][:,idx2,:,:]
ckpt1['bb_regressor.conv4_1t.0.weight'] = ckpt1['bb_regressor.conv4_1t.0.weight'][:,idx2,:,:]

ckpt1['classifier.feature_extractor.0.weight'] = ckpt1['classifier.feature_extractor.0.weight'][:,idx2,:,:]

In [None]:
ckpt1['classifier.feature_extractor.0.weight'].shape

In [None]:
ckpt1['bb_regressor.conv4_1t.0.weight'].shape

In [None]:
torch.save(ckpt1,'/workspace/pytracking/pytracking/networks/dimp50_bar.pth')

In [None]:
# ckpt1 = ckpt1['bb_regressor.conv3_1r.0.weight'][:,idx1,:,:]