# VGG16 Net Surgery
VGG16 Transfer Learning After 3-to-4-Channel Input Conversion

## Background
The weakly supervised segmentation techniques presented in the "Simple Does It" paper use a backbone convnet (either a DeepLab or a VGG16 network) **pre-trained on ImageNet**. This pre-trained network takes RGB images as input (W x H x 3). Remember that the weakly supervised version is trained using **4-channel inputs: RGB + a binary mask with a filled bounding box of the object instance**. Therefore, we need to **perform net surgery and create a 4-channel input version** of the VGG16 net, initialized with the 3-channel parameter values **except** for the additional convolutional filters (we use Gaussian initialization for them).
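
To make the 4-channel input concrete, here is a minimal NumPy sketch of how an RGB image and a filled bounding-box mask could be stacked. The helper name `make_four_channel_input`, the `(top, left, bottom, right)` box format, and the 0/255 mask scale are assumptions made for illustration; they are not taken from the paper or from this repo.

```python
import numpy as np

def make_four_channel_input(rgb, bbox):
    """Stack an RGB image with a binary mask whose bounding box is filled.

    rgb:  uint8 array of shape (H, W, 3)
    bbox: hypothetical (top, left, bottom, right) tuple; bottom/right are exclusive
    """
    height, width = rgb.shape[:2]
    top, left, bottom, right = bbox
    mask = np.zeros((height, width, 1), dtype=rgb.dtype)
    # Assumption: the mask shares the 0-255 range of the RGB channels.
    mask[top:bottom, left:right, 0] = 255
    # Append the mask as the fourth channel.
    return np.concatenate([rgb, mask], axis=-1)

rgb = np.zeros((6, 8, 3), dtype=np.uint8)
x = make_four_channel_input(rgb, bbox=(1, 2, 4, 6))
print(x.shape)  # (6, 8, 4)
```

The first convolution of the network then needs one extra input-channel filter per output feature map, which is what the surgery below adds.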

Here's what a typical VGG16 convnet looks like:

![](img/vgg16.png)

In the picture above, we're modifying the first block on the left.

In [2]:
"""
net_surgery.ipynb

VGG16 Transfer Learning After 3-to-4-Channel Input Conversion

Written by Phil Ferriere

Licensed under the MIT License (see LICENSE for details)

Based on:
  - https://github.com/minhnhat93/tf_object_detection_multi_channels/blob/master/edit_checkpoint.py
    Written by SNhat M. Nguyen
    Unknown code license
"""
from tensorflow.python import pywrap_tensorflow
import numpy as np
import tensorflow as tf



## Configuration

In [3]:
num_input_channels = 4 # AStream uses 4-channel inputs
init_method = 'gaussian' # ['gaussian'|'spread_average'|'zeros']
input_path = 'models/vgg_16_3chan/vgg_16_3chan.ckpt' # copy of checkpoint in http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
output_path = 'models/vgg_16_4chan/vgg_16_4chan.ckpt'

## Surgery

Here are the VGG16 stage 1 parameters we'll want to modify:
```
(dlwin36tfwss) Phil@SERVERP E:\repos\tf-wss\tfwss\tools
$  python -m inspect_checkpoint --file_name=../models/vgg_16_3chan.ckpt | grep -i conv1_1
vgg_16/conv1/conv1_1/weights (DT_FLOAT) [3,3,3,64]
vgg_16/conv1/conv1_1/biases (DT_FLOAT) [64]
```
First, let's find the correct tensor:

In [4]:
print('Loading checkpoint...')
reader = pywrap_tensorflow.NewCheckpointReader(input_path)
print('...done loading checkpoint.')

var_to_shape_map = reader.get_variable_to_shape_map()
var_to_edit_name = 'vgg_16/conv1/conv1_1/weights'

for key in sorted(var_to_shape_map):
    if key != var_to_edit_name:
        var = tf.Variable(reader.get_tensor(key), name=key, dtype=tf.float32)
    else:
        var_to_edit = reader.get_tensor(var_to_edit_name)
        print('Tensor {} of shape {} located.'.format(var_to_edit_name, var_to_edit.shape))

Loading checkpoint...
...done loading checkpoint.
Tensor vgg_16/conv1/conv1_1/weights of shape (3, 3, 3, 64) located.


Now, let's edit the tensor and initialize it according to the chosen init method:

In [5]:
sess = tf.Session()
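if init_method != 'gaussian':
    # Only Gaussian initialization is implemented; stop here instead of silently falling through.
    raise NotImplementedError('Unimplemented initialization method: {}'.format(init_method))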
new_channels_shape = list(var_to_edit.shape)
new_channels_shape[2] = num_input_channels - 3
gaussian_var = tf.random_normal(shape=new_channels_shape, stddev=0.001).eval(session=sess)
new_var = np.concatenate([var_to_edit, gaussian_var], axis=2)
new_var = tf.Variable(new_var, name=var_to_edit_name, dtype=tf.float32)
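
The configuration cell lists `'spread_average'` and `'zeros'` as init methods, but only `'gaussian'` is handled here. As a rough sketch of how all three might look on a plain NumPy kernel, consider the hypothetical helper below. The function name `extend_input_channels` is made up for illustration, and the interpretation of `'spread_average'` (each new filter is the mean of the existing input-channel filters) is an assumption, not something this notebook defines.

```python
import numpy as np

def extend_input_channels(weights, num_input_channels=4,
                          init_method='gaussian', stddev=0.001, seed=None):
    """Append extra input-channel filters to a [kh, kw, in, out] conv kernel."""
    kh, kw, in_ch, out_ch = weights.shape
    extra = num_input_channels - in_ch
    if extra <= 0:
        raise ValueError('num_input_channels must exceed the current input depth')
    new_shape = (kh, kw, extra, out_ch)

    if init_method == 'gaussian':
        # Same idea as the cell above: small zero-mean random values.
        rng = np.random.RandomState(seed)
        new_filters = rng.normal(0.0, stddev, size=new_shape)
    elif init_method == 'zeros':
        new_filters = np.zeros(new_shape)
    elif init_method == 'spread_average':
        # Assumed meaning: copy the per-position mean over existing input channels.
        mean = weights.mean(axis=2, keepdims=True)
        new_filters = np.repeat(mean, extra, axis=2)
    else:
        raise ValueError('Unknown init_method: {}'.format(init_method))

    # Keep the pre-trained filters first so channels 0-2 still match RGB.
    return np.concatenate([weights, new_filters.astype(weights.dtype)], axis=2)

w3 = np.ones((3, 3, 3, 64), dtype=np.float32)  # stand-in for conv1_1 weights
w4 = extend_input_channels(w3, init_method='zeros')
print(w4.shape)  # (3, 3, 4, 64)
```

Whichever method is chosen, the pre-trained filters are left untouched, so the network initially responds to the RGB channels the same way it did before the surgery (exactly so with `'zeros'`, approximately so with small Gaussian values).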

Finally, let's update the network parameters and save the updated model to disk:

In [6]:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(sess, output_path)

'models/vgg_16_4chan/vgg_16_4chan.ckpt'

## Verification
Verify the result of this surgery by looking at the output of the following commands:
```
$ python -m inspect_checkpoint --file_name=../models/vgg_16_3chan.ckpt --tensor_name=vgg_16/conv1/conv1_1/weights > vgg_16_3chan-conv1_1-weights.txt
$ python -m inspect_checkpoint --file_name=../models/vgg_16_4chan.ckpt --tensor_name=vgg_16/conv1/conv1_1/weights > vgg_16_4chan-conv1_1-weights.txt
```
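You should see the following values in the first filter. The first three input channels must match the 3-channel checkpoint exactly; the fourth channel is drawn at random (stddev 0.001, no fixed seed), so your values there will differ: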
```
# 3-channel VGG16
3,3,3,0
[[[ 0.4800154   0.55037946  0.42947057]
  [ 0.4085474   0.44007453  0.373467  ]
  [-0.06514555 -0.08138704 -0.06136011]]

 [[ 0.31047726  0.34573907  0.27476987]
  [ 0.05020237  0.04063221  0.03868078]
  [-0.40338343 -0.45350131 -0.36722335]]

 [[-0.05087169 -0.05863491 -0.05746817]
  [-0.28522751 -0.33066967 -0.26224968]
  [-0.41851634 -0.4850302  -0.35009676]]]
  
# 4-channel VGG16
3,3,4,0
[[[  4.80015397e-01   5.50379455e-01   4.29470569e-01   1.13388560e-04]
  [  4.08547401e-01   4.40074533e-01   3.73466998e-01   7.61439209e-04]
  [ -6.51455522e-02  -8.13870355e-02  -6.13601133e-02   4.74345696e-04]]

 [[  3.10477257e-01   3.45739067e-01   2.74769872e-01   4.11637186e-04]
  [  5.02023660e-02   4.06322069e-02   3.86807770e-02   1.38304755e-03]
  [ -4.03383434e-01  -4.53501314e-01  -3.67223352e-01   1.28411280e-03]]

 [[ -5.08716851e-02  -5.86349145e-02  -5.74681684e-02  -6.34787197e-04]
  [ -2.85227507e-01  -3.30669671e-01  -2.62249678e-01  -1.77454809e-03]
  [ -4.18516338e-01  -4.85030204e-01  -3.50096762e-01   2.10441509e-03]]]
  
```
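
Comparing the two dumps by eye works, but the same check can be sketched in a few lines of NumPy. The helper name `check_surgery` and the `max_abs_new` threshold are hypothetical; the arrays below are random stand-ins, and in practice they would be the two `conv1_1` weight tensors read from the 3-channel and 4-channel checkpoints (for instance with `reader.get_tensor(...)` as in the cells above).

```python
import numpy as np

def check_surgery(old_weights, new_weights, max_abs_new=0.01):
    """Return True if new_weights extends old_weights with small extra filters."""
    old_in = old_weights.shape[2]
    # Kernel size and output depth must be unchanged.
    same_shape = (new_weights.shape[:2] == old_weights.shape[:2]
                  and new_weights.shape[3] == old_weights.shape[3])
    # The pre-trained filters must be copied verbatim.
    copied = np.array_equal(new_weights[:, :, :old_in, :], old_weights)
    # The added filters should be close to zero (assumed threshold: 10 x stddev).
    small = bool(np.all(np.abs(new_weights[:, :, old_in:, :]) < max_abs_new))
    return same_shape and copied and small

# Stand-in tensors with the shapes reported by inspect_checkpoint.
rng = np.random.RandomState(0)
old = rng.normal(size=(3, 3, 3, 64)).astype(np.float32)
extra = rng.normal(0.0, 0.001, size=(3, 3, 1, 64)).astype(np.float32)
new = np.concatenate([old, extra], axis=2)
print(check_surgery(old, new))
```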