# 0. Import

In [14]:
import keras
from keras.models import load_model
import numpy as np 
import os
import skimage.io as io
import skimage.transform as trans
import numpy as np
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as keras

# 1. Model and Weight Analysis

## 1.1 Define Original U-Net 

In [90]:
def unet(pretrained_weights=None, input_size=(512, 512, 3)):
    """Build and compile the full-width U-Net.

    The contracting path has four stages of two 3x3 convolutions
    (64, 128, 256, 512 filters), each followed by 2x2 max pooling; dropout
    (0.5) is applied after the fourth stage and after the 1024-filter
    bottleneck. The expanding path upsamples by 2, applies a 2x2 convolution,
    concatenates the matching encoder output along axis 3 and applies two 3x3
    convolutions. The head is a 1x1 sigmoid convolution with 3 output channels.

    Args:
        pretrained_weights: Optional path to a weights file. When truthy, the
            weights are loaded into the model after it has been compiled.
        input_size: Shape of one input sample, without the batch axis.

    Returns:
        The model, compiled with Adam (lr=1e-4), binary cross-entropy loss
        and the accuracy metric.
    """

    def conv(filters, kernel_size=3):
        # Every hidden convolution shares activation, padding and initializer.
        return Conv2D(filters, kernel_size, activation='relu', padding='same',
                      kernel_initializer='he_normal')

    inputs = Input(input_size)

    # Contracting path.
    conv1 = conv(64)(inputs)
    conv1 = conv(64)(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = conv(128)(pool1)
    conv2 = conv(128)(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = conv(256)(pool2)
    conv3 = conv(256)(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = conv(512)(pool3)
    conv4 = conv(512)(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck.
    conv5 = conv(1024)(pool4)
    conv5 = conv(1024)(conv5)
    drop5 = Dropout(0.5)(conv5)

    # Expanding path: upsample, 2x2 conv, concatenate with the encoder output.
    up6 = conv(512, 2)(UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6], axis=3)
    conv6 = conv(512)(merge6)
    conv6 = conv(512)(conv6)

    up7 = conv(256, 2)(UpSampling2D(size=(2, 2))(conv6))
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = conv(256)(merge7)
    conv7 = conv(256)(conv7)

    up8 = conv(128, 2)(UpSampling2D(size=(2, 2))(conv7))
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = conv(128)(merge8)
    conv8 = conv(128)(conv8)

    up9 = conv(64, 2)(UpSampling2D(size=(2, 2))(conv8))
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = conv(64)(merge9)
    conv9 = conv(64)(conv9)
    conv10 = Conv2D(3, 1, activation='sigmoid')(conv9)

    # `inputs` / `outputs` are the functional-API keyword names; the
    # singular `input` / `output` spellings are legacy aliases.
    model = Model(inputs=inputs, outputs=conv10)
    model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy'])

    # Honour the parameter: it was previously accepted but silently ignored.
    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model

In [4]:
# Build the full-width U-Net and restore its trained weights.
WEIGHTS_PATH = 'unet_wseg_waymo.hdf5'

model = unet()
model.load_weights(WEIGHTS_PATH)
# Loading weights does not change the architecture table, so the summary is
# printed once instead of before and after the load.
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 512, 512, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 512, 512, 64) 1792        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 512, 512, 64) 36928       conv2d_1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 256, 256, 64) 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (

In [None]:
# Collect every weight array of the full model.
weights = model.get_weights()
print(len(weights))
# Report one shape per array instead of dumping the raw values of the whole
# network into the cell output.
for position, array in enumerate(weights):
    print(position, array.shape)

## 1.2 Define Small U-Net (half the number of channels)

In [53]:
def unet_pruning(input_size=(512, 512, 3), base_filters=32):
    """Build the narrow ("pruned") U-Net; the model is returned uncompiled.

    The topology is the same encoder / bottleneck / decoder layout as `unet`,
    with stage widths of 1x, 2x, 4x, 8x and 16x `base_filters`. With the
    default of 32 this gives 32, 64, 128, 256 and 512 filters. The head is a
    1x1 sigmoid convolution with 3 output channels.

    Args:
        input_size: Shape of one input sample, without the batch axis.
        base_filters: Number of filters in the first stage; every deeper
            stage doubles it.

    Returns:
        The uncompiled model.
    """

    def conv(filters, kernel_size=3):
        # Every hidden convolution shares activation, padding and initializer.
        return Conv2D(filters, kernel_size, activation='relu', padding='same',
                      kernel_initializer='he_normal')

    width1 = base_filters
    width2 = base_filters * 2
    width3 = base_filters * 4
    width4 = base_filters * 8
    width5 = base_filters * 16

    inputs = Input(input_size)

    # Contracting path.
    conv1 = conv(width1)(inputs)
    conv1 = conv(width1)(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = conv(width2)(pool1)
    conv2 = conv(width2)(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = conv(width3)(pool2)
    conv3 = conv(width3)(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = conv(width4)(pool3)
    conv4 = conv(width4)(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck.
    conv5 = conv(width5)(pool4)
    conv5 = conv(width5)(conv5)
    drop5 = Dropout(0.5)(conv5)

    # Expanding path: upsample, 2x2 conv, concatenate with the encoder output.
    up6 = conv(width4, 2)(UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6], axis=3)
    conv6 = conv(width4)(merge6)
    conv6 = conv(width4)(conv6)

    up7 = conv(width3, 2)(UpSampling2D(size=(2, 2))(conv6))
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = conv(width3)(merge7)
    conv7 = conv(width3)(conv7)

    up8 = conv(width2, 2)(UpSampling2D(size=(2, 2))(conv7))
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = conv(width2)(merge8)
    conv8 = conv(width2)(conv8)

    up9 = conv(width1, 2)(UpSampling2D(size=(2, 2))(conv8))
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = conv(width1)(merge9)
    conv9 = conv(width1)(conv9)
    conv10 = Conv2D(3, 1, activation='sigmoid')(conv9)

    # `inputs` / `outputs` are the functional-API keyword names; the
    # singular `input` / `output` spellings are legacy aliases.
    model = Model(inputs=inputs, outputs=conv10)
    return model

In [54]:
# Instantiate the narrow U-Net with its default input size and print its
# layer table.
unet_pruning_model = unet_pruning()
unet_pruning_model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 512, 512, 3)  0                                            
__________________________________________________________________________________________________
conv2d_47 (Conv2D)              (None, 512, 512, 32) 896         input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_48 (Conv2D)              (None, 512, 512, 32) 9248        conv2d_47[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_9 (MaxPooling2D)  (None, 256, 256, 32) 0           conv2d_48[0][0]                  
__________________________________________________________________________________________________
conv2d_49 



## 1.3 Display Model Info: Name, Weight, Shape

In [26]:
layers = model.layers
p_layers = unet_pruning_model.layers

print("original:", len(layers), "pruning:", len(p_layers))

# The listing pairs layers by position, so it is only meaningful when both
# models have the same number of layers.
assert len(layers) == len(p_layers), "layer counts differ; positional pairing is invalid"

for layer, layer_p in zip(layers, p_layers):
    print(layer.name, "-", layer_p.name)

original: 38 pruning: 38
input_1 - input_2
conv2d_1 - conv2d_24
conv2d_2 - conv2d_25
max_pooling2d_1 - max_pooling2d_5
conv2d_3 - conv2d_26
conv2d_4 - conv2d_27
max_pooling2d_2 - max_pooling2d_6
conv2d_5 - conv2d_28
conv2d_6 - conv2d_29
max_pooling2d_3 - max_pooling2d_7
conv2d_7 - conv2d_30
conv2d_8 - conv2d_31
dropout_1 - dropout_3
max_pooling2d_4 - max_pooling2d_8
conv2d_9 - conv2d_32
conv2d_10 - conv2d_33
dropout_2 - dropout_4
up_sampling2d_1 - up_sampling2d_5
conv2d_11 - conv2d_34
concatenate_1 - concatenate_5
conv2d_12 - conv2d_35
conv2d_13 - conv2d_36
up_sampling2d_2 - up_sampling2d_6
conv2d_14 - conv2d_37
concatenate_2 - concatenate_6
conv2d_15 - conv2d_38
conv2d_16 - conv2d_39
up_sampling2d_3 - up_sampling2d_7
conv2d_17 - conv2d_40
concatenate_3 - concatenate_7
conv2d_18 - conv2d_41
conv2d_19 - conv2d_42
up_sampling2d_4 - up_sampling2d_8
conv2d_20 - conv2d_43
concatenate_4 - concatenate_8
conv2d_21 - conv2d_44
conv2d_22 - conv2d_45
conv2d_23 - conv2d_46


In [32]:
# Pick the pruned model's first convolution by type rather than by its
# auto-generated name: the numeric suffix of names such as "conv2d_24" changes
# every time a model is rebuilt in the same kernel.
conv_layer_p = next(
    layer for layer in unet_pruning_model.layers if isinstance(layer, Conv2D)
)
conv_layer_p_weights = conv_layer_p.get_weights()
print(conv_layer_p_weights)

[array([[[[ 2.19514556e-02,  2.42551953e-01,  3.99107710e-02,
           2.20315848e-02,  2.14625567e-01,  2.47779921e-01,
          -1.93289958e-03,  3.68081570e-01, -3.63029569e-01,
          -6.71134293e-02,  4.01774228e-01, -2.14680463e-01,
          -1.25934556e-01,  3.03576570e-02,  1.72079578e-01,
          -4.83537763e-01,  1.39848322e-01, -4.48408835e-02,
          -1.67904899e-01, -3.73050600e-01,  7.01229051e-02,
           3.67138803e-01,  1.93945110e-01, -3.48909423e-02,
          -1.33234207e-02,  2.98654437e-02,  3.57851535e-01,
           1.31390497e-01,  3.66153270e-01,  4.92509594e-03,
          -1.32643864e-01,  3.72226745e-01],
         [-1.00713976e-01,  2.87748307e-01, -6.20141067e-03,
          -4.43248540e-01,  1.50314972e-01, -3.70585829e-01,
          -2.97284812e-01,  3.76137882e-01, -1.87056199e-01,
          -2.13085994e-01, -4.14029211e-01, -3.33959222e-01,
          -4.71831262e-01,  5.77195212e-02, -2.25581333e-01,
           6.42331690e-02, -2.00320870e

In [36]:
# Show the shape of the first two weight arrays of the pruned layer, one per line.
for part_index in (0, 1):
    print(conv_layer_p_weights[part_index].shape)

(3, 3, 3, 32)
(32,)


In [37]:
# Pick the original model's first convolution by type rather than by its
# auto-generated name: "conv2d_1" only exists when `unet()` was the first
# model built in the kernel.
conv_layer = next(layer for layer in model.layers if isinstance(layer, Conv2D))
conv_layer_weights = conv_layer.get_weights()
print(conv_layer_weights)

[array([[[[ 7.48673603e-02, -3.59741688e-01,  1.05753183e-01, ...,
          -1.81313306e-01, -5.14598489e-01, -6.02530837e-02],
         [-6.65770471e-03,  3.13159287e-01, -2.71185905e-01, ...,
           5.93785197e-02, -4.15093929e-01,  4.46947932e-01],
         [-5.88004589e-02,  4.52397615e-01, -3.62931043e-01, ...,
           2.60685217e-02, -1.88034624e-01,  1.02178119e-01]],

        [[ 1.86629251e-01,  3.17432284e-01, -2.13954568e-01, ...,
           1.19607277e-01, -1.68098152e-01, -1.31710187e-01],
         [ 5.89104332e-02, -2.37554640e-01, -2.63397664e-01, ...,
          -2.77946554e-02, -5.70635833e-02, -1.11129723e-01],
         [-9.77829844e-02, -1.64739996e-01, -1.29499316e-01, ...,
           5.92834204e-02,  1.44627422e-01, -2.25849196e-01]],

        [[ 7.69469664e-02, -2.72346050e-01, -4.96726297e-02, ...,
          -4.02847141e-01,  2.04787269e-01, -5.33641353e-02],
         [ 1.31859863e-02, -8.56065005e-02,  1.65181216e-02, ...,
          -2.09566593e-01, -3.423

In [38]:
# Show the shape of the first two weight arrays of the original layer, one per line.
for part_index in (0, 1):
    print(conv_layer_weights[part_index].shape)

(3, 3, 3, 64)
(64,)


In [41]:
# Keep as many leading filters as the pruned layer actually has, instead of
# hard-coding the count; the bias array holds one entry per filter.
num_kept_filters = conv_layer_p_weights[1].shape[0]
obtained_weight = conv_layer_weights[0][:, :, :, :num_kept_filters]
obtained_bias = conv_layer_weights[1][:num_kept_filters]

In [44]:
# Show the shapes of the sliced kernel and bias, one per line.
for sliced in (obtained_weight, obtained_bias):
    print(sliced.shape)

(3, 3, 3, 32)
(32,)


In [46]:
# Overwrite the pruned layer's weights with the sliced arrays, passed as a
# two-element list: obtained_weight first, then obtained_bias.
conv_layer_p.set_weights([obtained_weight, obtained_bias])

In [47]:
conv_layer_p_weights_new = conv_layer_p.get_weights()
# Check the copy programmatically; comparing printed floats by eye is
# error-prone and the printed arrays are truncated.
np.testing.assert_array_equal(conv_layer_p_weights_new[0], obtained_weight)
np.testing.assert_array_equal(conv_layer_p_weights_new[1], obtained_bias)
print(conv_layer_p_weights_new)

[array([[[[ 7.48673603e-02, -3.59741688e-01,  1.05753183e-01,
           1.88295171e-01,  1.36951849e-01,  2.66627222e-01,
          -9.99576598e-03,  3.87260497e-01,  3.74591142e-01,
          -4.82921414e-02,  2.60315567e-01, -1.25506446e-01,
          -8.50475952e-02,  4.83646654e-02, -2.39891306e-01,
          -2.34437943e-01,  2.61913627e-01,  3.65245819e-01,
          -9.31872353e-02, -3.62636335e-02, -7.10283220e-02,
           5.05518913e-01, -6.07066453e-02, -1.92208499e-01,
          -5.32920897e-01,  7.92290550e-03, -3.63518298e-01,
           2.49590859e-01, -6.86031673e-03, -4.69866604e-01,
          -1.42213836e-01, -3.24607879e-01],
         [-6.65770471e-03,  3.13159287e-01, -2.71185905e-01,
          -1.80579945e-01,  3.90678972e-01,  4.55240816e-01,
          -4.05971482e-02, -1.57608390e-01, -5.64294934e-01,
           2.57314354e-01, -1.18829533e-01,  1.89240083e-01,
           1.30838454e-01,  1.83516979e-01,  2.17196777e-01,
          -1.94007441e-01, -8.68860334e

# 2. Load Weights from Original Model to Pruned Model

In [74]:
def _kept_input_channels(orig_producer, prun_producer, orig_in, prun_in):
    """Return the original kernel's input-channel indices kept by the pruned kernel.

    The pruned model keeps the leading filters of every convolution. For a
    convolution fed by a single layer, the kept input channels are therefore
    the leading `prun_in` channels. For a convolution fed by a Concatenate
    layer, the kept channels are the leading channels of EACH concatenated
    branch; a plain `0:prun_in` slice would take everything from the first
    branch and nothing from the others.

    Args:
        orig_producer: Layer listed directly before the convolution in the
            original model, or None.
        prun_producer: Layer listed directly before the convolution in the
            pruned model, or None.
        orig_in: Number of input channels of the original kernel.
        prun_in: Number of input channels of the pruned kernel.

    Returns:
        1-D integer array of `prun_in` indices into the original kernel's
        input-channel axis.
    """
    if not isinstance(orig_producer, Concatenate):
        return np.arange(prun_in)

    # For a multi-input layer, `input_shape` is a list with one shape per branch.
    orig_widths = [shape[-1] for shape in orig_producer.input_shape]
    prun_widths = [shape[-1] for shape in prun_producer.input_shape]

    # Guard the assumption that the preceding layer really feeds this kernel.
    assert sum(orig_widths) == orig_in, "Concatenate width does not match original kernel"
    assert sum(prun_widths) == prun_in, "Concatenate width does not match pruned kernel"
    assert all(p <= o for p, o in zip(prun_widths, orig_widths)), "pruned branch wider than original"

    # Offset of each branch inside the original concatenated tensor.
    branch_starts = np.cumsum([0] + orig_widths[:-1])
    return np.concatenate(
        [np.arange(start, start + width) for start, width in zip(branch_starts, prun_widths)]
    )


layers = model.layers
p_layers = unet_pruning_model.layers

print("original:", len(layers), "pruning:", len(p_layers))

# Layers are paired by position, so both models must have the same length.
assert len(layers) == len(p_layers), "layer counts differ; positional pairing is invalid"

for i, (layer, layer_p) in enumerate(zip(layers, p_layers)):
    if not isinstance(layer, Conv2D):
        continue
    assert isinstance(layer_p, Conv2D), "pruned layer at this position is not a Conv2D"

    print(i)
    print(layer.name, "-", layer_p.name)

    orig_kernel, orig_bias = layer.get_weights()
    print(orig_kernel.shape, orig_bias.shape)

    prun_kernel, prun_bias = layer_p.get_weights()
    print(prun_kernel.shape, prun_bias.shape)

    target_num_filters = prun_kernel.shape[3]
    each_kernel_channels = prun_kernel.shape[2]

    # In both models each Concatenate is listed directly before the
    # convolution that consumes it, so the previous list entry tells whether
    # this kernel sees a concatenated input.
    kept_in = _kept_input_channels(
        layers[i - 1] if i else None,
        p_layers[i - 1] if i else None,
        orig_kernel.shape[2],
        each_kernel_channels,
    )

    obtained_weight = np.take(orig_kernel, kept_in, axis=2)[..., :target_num_filters]
    obtained_bias = orig_bias[:target_num_filters]

    print(obtained_weight.shape, obtained_bias.shape)

    layer_p.set_weights([obtained_weight, obtained_bias])

original: 38 pruning: 38
1
conv2d_1 - conv2d_47
(3, 3, 3, 64) (64,)
(3, 3, 3, 32) (32,)
(3, 3, 3, 32) (32,)
2
conv2d_2 - conv2d_48
(3, 3, 64, 64) (64,)
(3, 3, 32, 32) (32,)
(3, 3, 32, 32) (32,)
4
conv2d_3 - conv2d_49
(3, 3, 64, 128) (128,)
(3, 3, 32, 64) (64,)
(3, 3, 32, 64) (64,)
5
conv2d_4 - conv2d_50
(3, 3, 128, 128) (128,)
(3, 3, 64, 64) (64,)
(3, 3, 64, 64) (64,)
7
conv2d_5 - conv2d_51
(3, 3, 128, 256) (256,)
(3, 3, 64, 128) (128,)
(3, 3, 64, 128) (128,)
8
conv2d_6 - conv2d_52
(3, 3, 256, 256) (256,)
(3, 3, 128, 128) (128,)
(3, 3, 128, 128) (128,)
10
conv2d_7 - conv2d_53
(3, 3, 256, 512) (512,)
(3, 3, 128, 256) (256,)
(3, 3, 128, 256) (256,)
11
conv2d_8 - conv2d_54
(3, 3, 512, 512) (512,)
(3, 3, 256, 256) (256,)
(3, 3, 256, 256) (256,)
14
conv2d_9 - conv2d_55
(3, 3, 512, 1024) (1024,)
(3, 3, 256, 512) (512,)
(3, 3, 256, 512) (512,)
15
conv2d_10 - conv2d_56
(3, 3, 1024, 1024) (1024,)
(3, 3, 512, 512) (512,)
(3, 3, 512, 512) (512,)
18
conv2d_11 - conv2d_57
(2, 2, 1024, 512) (512,)
(

In [80]:
# Print the pruned model's layer table again, after the weight-transfer cell.
unet_pruning_model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 512, 512, 3)  0                                            
__________________________________________________________________________________________________
conv2d_47 (Conv2D)              (None, 512, 512, 32) 896         input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_48 (Conv2D)              (None, 512, 512, 32) 9248        conv2d_47[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_9 (MaxPooling2D)  (None, 256, 256, 32) 0           conv2d_48[0][0]                  
__________________________________________________________________________________________________
conv2d_49 

In [81]:
# Inspect the pruned model's final (output) convolution. It is addressed by
# position because auto-generated names such as "conv2d_69" depend on how many
# layers the kernel has created so far. The name is bound once, to the weights
# list, instead of first to the layer and then to its weights.
conv_layer_1 = unet_pruning_model.layers[-1].get_weights()
print(conv_layer_1)

[array([[[[-1.63575873e-01, -1.35104153e-02, -2.86828578e-01],
         [ 7.19326138e-02,  1.65638924e-01,  9.77255180e-02],
         [-2.72788137e-01, -8.16098694e-03,  3.96588236e-01],
         [-2.50786364e-01, -2.81809215e-02,  6.26358613e-02],
         [-2.91038305e-01, -3.92640740e-01, -6.72045648e-02],
         [-6.05054259e-01,  1.00323379e-01, -4.17220712e-01],
         [-5.36465108e-01, -1.38890088e-01,  1.36491477e-01],
         [-2.95083672e-01, -6.42698780e-02,  1.36345476e-01],
         [-6.96028024e-02, -6.53160438e-02, -3.08791459e-01],
         [-2.26070985e-01,  1.62487209e-01, -2.42966190e-01],
         [-3.85908365e-01, -1.75240830e-01, -2.40916125e-02],
         [-5.76454222e-01, -3.02572072e-01, -9.02878046e-02],
         [-3.95108350e-02, -3.98582481e-02, -4.14033532e-02],
         [-7.65015423e-01, -3.68647903e-01,  1.12017915e-02],
         [-1.71528518e-01, -3.00563610e-04, -3.93085815e-02],
         [ 1.40113086e-01,  2.01103285e-01, -1.71003193e-01],
       

In [82]:
# Print the original model's layer table for comparison with the pruned model.
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 512, 512, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 512, 512, 64) 1792        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 512, 512, 64) 36928       conv2d_1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 256, 256, 64) 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (

In [83]:
# Inspect the original model's final (output) convolution. It is addressed by
# position because the auto-generated name "conv2d_23" is only valid when
# `unet()` was the first model built in the kernel. The name is bound once, to
# the weights list, instead of first to the layer and then to its weights.
conv_layer_2 = model.layers[-1].get_weights()
print(conv_layer_2)

[array([[[[-1.63575873e-01, -1.35104153e-02, -2.86828578e-01],
         [ 7.19326138e-02,  1.65638924e-01,  9.77255180e-02],
         [-2.72788137e-01, -8.16098694e-03,  3.96588236e-01],
         [-2.50786364e-01, -2.81809215e-02,  6.26358613e-02],
         [-2.91038305e-01, -3.92640740e-01, -6.72045648e-02],
         [-6.05054259e-01,  1.00323379e-01, -4.17220712e-01],
         [-5.36465108e-01, -1.38890088e-01,  1.36491477e-01],
         [-2.95083672e-01, -6.42698780e-02,  1.36345476e-01],
         [-6.96028024e-02, -6.53160438e-02, -3.08791459e-01],
         [-2.26070985e-01,  1.62487209e-01, -2.42966190e-01],
         [-3.85908365e-01, -1.75240830e-01, -2.40916125e-02],
         [-5.76454222e-01, -3.02572072e-01, -9.02878046e-02],
         [-3.95108350e-02, -3.98582481e-02, -4.14033532e-02],
         [-7.65015423e-01, -3.68647903e-01,  1.12017915e-02],
         [-1.71528518e-01, -3.00563610e-04, -3.93085815e-02],
         [ 1.40113086e-01,  2.01103285e-01, -1.71003193e-01],
       