# Purpose
This notebook adapts the pretrained completion-network weights scraped from the web to the architecture of my implementation.

# Import

In [17]:
import torch
from glic.networks.completion_network import CompletionNetwork
from collections import OrderedDict
import pandas as pd
import torch.nn as nn

# Compare the architectures

In [2]:
# Instantiate my implementation of the completion network (CN) architecture and tabulate the
# shape of every state-dict entry, so it can be compared side by side with the scraped weights.
my_cn = CompletionNetwork()
my_cn_state_dict = my_cn.state_dict()
my_cn_shapes = OrderedDict(
    (entry_name, entry.shape) for entry_name, entry in my_cn_state_dict.items()
)
my_cn_shapes_serie = pd.Series(
    my_cn_shapes.values(),
    index=my_cn_shapes.keys(),
)

In [3]:
# # Optional: re-download the original Torch7 (.t7) weights, i.e. the file that was later
# # converted to the .pth loaded below. Requires the third-party `wget` package.
# trained_cn_weights_address = "http://iizuka.cs.tsukuba.ac.jp/data/completionnet_places2.t7"
# import wget
# wget.download(trained_cn_weights_address,"../logs/scrapped_weights/trained_cn_weights.t7")

The Torch7 (`.t7`) weights were converted to a PyTorch state dict (`.pth`) using https://github.com/clcarwin/convert_torch_to_pytorch.  
*The GitHub project had to be forked to add support for `SpatialDilatedConvolution`.*

In [4]:
# load the converted weights
# map_location="cpu": keep the scraped tensors on the CPU, like the state dict of the freshly
# built CN (created on the default device), whatever device they were saved from, so later
# arithmetic between the two never mixes devices.
# weights_only=True: the file is expected to hold only a dict of tensors, so refuse to
# unpickle arbitrary objects from this downloaded-and-converted file.
scrapped_weights = torch.load(
    "../logs/scrapped_weights/trained_cn_weights.pth",
    map_location="cpu",
    weights_only=True,
)
scrapped_weights_shapes = OrderedDict((k, v.shape) for k, v in scrapped_weights.items())
scrapped_weights_shapes_serie = pd.Series(scrapped_weights_shapes.values(), index=scrapped_weights_shapes.keys())

In [5]:
# Shape of every entry in my CN's state dict (pandas truncates the middle of the listing).
my_cn_shapes_serie

cn_net.0.0.weight                  (64, 3, 5, 5)
cn_net.0.1.weight                          (64,)
cn_net.0.1.bias                            (64,)
cn_net.0.1.running_mean                    (64,)
cn_net.0.1.running_var                     (64,)
                                       ...      
cn_net.15.1.bias                           (32,)
cn_net.15.1.running_mean                   (32,)
cn_net.15.1.running_var                    (32,)
cn_net.15.1.num_batches_tracked               ()
cn_net.16.0.weight                 (3, 32, 3, 3)
Length: 97, dtype: object

In [6]:
# Shape of every entry in the scraped weights. Unlike mine, the scraped convolutions carry
# biases (e.g. '0.bias', '48.bias'), hence more entries than my state dict.
scrapped_weights_shapes_serie

0.weight                  (64, 4, 5, 5)
0.bias                            (64,)
1.weight                          (64,)
1.bias                            (64,)
1.running_mean                    (64,)
                              ...      
46.running_mean                   (32,)
46.running_var                    (32,)
46.num_batches_tracked               ()
48.weight                 (3, 32, 3, 3)
48.bias                            (3,)
Length: 114, dtype: object

# Layer by layer

## First layer

In [13]:
# First conv + BN block of my CN: a 3-input-channel conv weight without bias, then the 5 BN entries.
my_cn_shapes_serie.head(6)

cn_net.0.0.weight                 (64, 3, 5, 5)
cn_net.0.1.weight                         (64,)
cn_net.0.1.bias                           (64,)
cn_net.0.1.running_mean                   (64,)
cn_net.0.1.running_var                    (64,)
cn_net.0.1.num_batches_tracked               ()
dtype: object

In [39]:
# First block of the scraped weights: a 4-input-channel conv weight and its bias, then the 5 BN entries.
scrapped_weights_shapes_serie.head(7)

0.weight                 (64, 4, 5, 5)
0.bias                           (64,)
1.weight                         (64,)
1.bias                           (64,)
1.running_mean                   (64,)
1.running_var                    (64,)
1.num_batches_tracked               ()
dtype: object

In [31]:
# get the alpha bias
# Contribution of the scraped first conv's 4th (alpha) input channel when that channel is all
# ones and the RGB channels are zero. My network has no such input channel, so this constant is
# folded into the first layer's bias instead. The conv bias is disabled here because the scraped
# conv bias ('0.bias') is folded in separately when the first layer is set.
conv2d = nn.Conv2d(4, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
with torch.no_grad():
    # Copy into `.weight`: assigning to `conv2d.parameters` would only shadow the method and
    # leave the randomly initialised weights in place.
    conv2d.weight.copy_(scrapped_weights['0.weight'])
    rgb_channels = torch.zeros((3, 256, 256))
    alpha_channel = torch.ones((1, 256, 256))
    alpha_image = torch.concatenate((rgb_channels, alpha_channel))
    # Spatial mean over a 256x256 image. Because of zero padding, the 2-pixel border sees only part
    # of the kernel, so this differs slightly from the interior value weight[:, 3].sum(dim=(1, 2)).
    alpha_bias = torch.mean(conv2d(alpha_image), dim=(1, 2))

In [38]:
# set first layer
# The scraped first conv takes 4 input channels (RGB + alpha) while mine takes 3: keep only the
# RGB kernels and fold the alpha channel's contribution (`alpha_bias`, computed for an all-ones
# alpha channel) into the bias. NOTE(review): this is only equivalent when the alpha channel is
# that constant all-ones image — confirm this is the intended input convention.
#
# The scraped conv has a bias while mine feeds straight into BatchNorm, so the conv bias must be
# folded into the BN bias *through the BN scale* (exact when BN uses its running statistics):
#     beta' = beta + gamma * (conv_bias + alpha_bias) / sqrt(running_var + eps)
# Adding it to beta unscaled is only correct when gamma / sqrt(running_var + eps) == 1.
# The BN running statistics themselves are copied unchanged in the "batch norm layer" cell.
first_bn_eps = my_cn.get_submodule('cn_net.0.1').eps
first_bn_scale = scrapped_weights['1.weight'] / torch.sqrt(scrapped_weights['1.running_var'] + first_bn_eps)
first_conv_bias = scrapped_weights['0.bias'] + alpha_bias.detach()
my_cn_state_dict['cn_net.0.0.weight'] = scrapped_weights['0.weight'][:, :3, :, :]
my_cn_state_dict['cn_net.0.1.weight'] = scrapped_weights['1.weight']
my_cn_state_dict['cn_net.0.1.bias'] = scrapped_weights['1.bias'] + first_bn_scale * first_conv_bias

In [41]:
# batch norm layer
# Copy running_mean / running_var / num_batches_tracked of the first BatchNorm. In the scraped
# table these sit one position later than in mine, because the scraped conv bias ('0.bias')
# precedes them; `shift` accounts for that offset.
shift = 1
for idx, my_key in enumerate(my_cn_shapes_serie.index[3:6], start=3):
    scraped_key = scrapped_weights_shapes_serie.index[idx + shift]
    my_cn_state_dict[my_key] = scrapped_weights[scraped_key]

## Other layers

In [51]:
def convert_layer(my_cn_layers_names, my_cn_state_dict, scrapped_weights_layers_names, scrapped_weights, eps=1e-5):
    """Transfer one conv + batch-norm block from the scraped weights into my CN state dict.

    The scraped block is a biased convolution followed by batch norm, while my block's
    convolution has no bias (one fewer state-dict entry). The conv bias ``b`` is therefore
    folded into the batch-norm bias through the batch-norm scale. This is exact when batch
    norm uses its running statistics (eval mode)::

        beta' = beta + gamma * b / sqrt(running_var + eps)

    Simply adding ``b`` to ``beta`` is only correct when ``gamma / sqrt(running_var + eps) == 1``.
    Each transferred key pair is printed as ``"<scraped key> -> <my key>"``.

    Args:
        my_cn_layers_names (sequence of str): the 6 keys of the block in my CN, ordered as
            conv weight, BN weight, BN bias, BN running_mean, BN running_var,
            BN num_batches_tracked.
        my_cn_state_dict (OrderedDict): state dict of my CN; updated in place.
        scrapped_weights_layers_names (sequence of str): the 7 keys of the block in the scraped
            weights, ordered as conv weight, conv bias, BN weight, BN bias, BN running_mean,
            BN running_var, BN num_batches_tracked.
        scrapped_weights (OrderedDict): scraped weights.
        eps (float): epsilon of my batch-norm layer, used when folding the conv bias.
            Defaults to 1e-5, the ``nn.BatchNorm2d`` default.

    Raises:
        AssertionError: if the key sequences do not have 6 and 7 entries respectively.
        ValueError: if a scraped tensor's shape differs from the target entry's shape.
    """
    my_names = list(my_cn_layers_names)
    scraped_names = list(scrapped_weights_layers_names)
    assert len(my_names) == 6
    assert len(scraped_names) == 7

    conv_w, conv_b, bn_w, bn_b, bn_mean, bn_var, bn_count = (
        scrapped_weights[name].detach() for name in scraped_names
    )
    folded_bn_bias = bn_b + bn_w * conv_b / torch.sqrt(bn_var + eps)

    # (new value, indices of the scraped keys it is built from -- used for the printed trace)
    converted = [
        (conv_w, [0]),
        (bn_w, [2]),
        (folded_bn_bias, [1, 3]),
        (bn_mean, [4]),
        (bn_var, [5]),
        (bn_count, [6]),
    ]
    for layer, (value, source_idx) in zip(my_names, converted):
        target = my_cn_state_dict[layer]
        # explicit check: broadcasting could otherwise silently accept a mismatched shape
        if value.shape != target.shape:
            raise ValueError(
                f"shape mismatch for {layer}: expected {tuple(target.shape)}, got {tuple(value.shape)}"
            )
        # keep the target's dtype/device, e.g. num_batches_tracked stays an integer tensor
        my_cn_state_dict[layer] = value.clone().to(device=target.device, dtype=target.dtype)
        for other_idx in source_idx:
            print(f"{scraped_names[other_idx]} -> {layer}")

In [71]:
# other layers
# Convert conv + BN blocks 1 .. len(my_cn_shapes_serie) // 6 - 1. Block 0 was handled above and
# the final conv is left for the "Last layer" section. Each of my blocks spans 6 state-dict
# entries, while each scraped block spans 7 because the scraped conv also carries a bias.
n_blocks = len(my_cn_shapes_serie) // 6
for idx in range(1, n_blocks):
    my_block_keys = my_cn_shapes_serie.index[6 * idx:6 * (idx + 1)]
    scraped_block_keys = scrapped_weights_shapes_serie.index[7 * idx:7 * (idx + 1)]
    convert_layer(my_block_keys, my_cn_state_dict, scraped_block_keys, scrapped_weights)
    print("\n")

21.weight -> cn_net.7.0.weight
22.weight -> cn_net.7.1.weight
21.bias -> cn_net.7.1.bias
22.bias -> cn_net.7.1.bias
22.running_mean -> cn_net.7.1.running_mean
22.running_var -> cn_net.7.1.running_var
22.num_batches_tracked -> cn_net.7.1.num_batches_tracked


24.weight -> cn_net.8.0.weight
25.weight -> cn_net.8.1.weight
24.bias -> cn_net.8.1.bias
25.bias -> cn_net.8.1.bias
25.running_mean -> cn_net.8.1.running_mean
25.running_var -> cn_net.8.1.running_var
25.num_batches_tracked -> cn_net.8.1.num_batches_tracked


27.weight -> cn_net.9.0.weight
28.weight -> cn_net.9.1.weight
27.bias -> cn_net.9.1.bias
28.bias -> cn_net.9.1.bias
28.running_mean -> cn_net.9.1.running_mean
28.running_var -> cn_net.9.1.running_var
28.num_batches_tracked -> cn_net.9.1.num_batches_tracked


30.weight -> cn_net.10.0.weight
31.weight -> cn_net.10.1.weight
30.bias -> cn_net.10.1.bias
31.bias -> cn_net.10.1.bias
31.running_mean -> cn_net.10.1.running_mean
31.running_var -> cn_net.10.1.running_var
31.num_batches_tr

# Last layer

In [72]:
# Tail of my CN: the last BN entries, then the final conv weight (my final conv has no bias).
my_cn_shapes_serie.tail(6)

cn_net.15.1.weight                         (32,)
cn_net.15.1.bias                           (32,)
cn_net.15.1.running_mean                   (32,)
cn_net.15.1.running_var                    (32,)
cn_net.15.1.num_batches_tracked               ()
cn_net.16.0.weight                 (3, 32, 3, 3)
dtype: object

In [73]:
# Tail of the scraped weights: the final conv carries a bias ('48.bias') that my final conv lacks.
scrapped_weights_shapes_serie.tail(7)

46.weight                         (32,)
46.bias                           (32,)
46.running_mean                   (32,)
46.running_var                    (32,)
46.num_batches_tracked               ()
48.weight                 (3, 32, 3, 3)
48.bias                            (3,)
dtype: object