# Introduction
The goal here is to export the weights of the Keras model in a PyTorch-like state dictionary format.

# Imports
Import libraries and write settings here.

## General packages

In [1]:
import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model, Model, Sequential
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
from keras.layers.core import Dropout, Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers import Input, Dropout, Activation, Conv2D, MaxPooling2D, UpSampling2D, Lambda, BatchNormalization
from keras.callbacks import Callback, EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from keras import callbacks, initializers, layers, models, optimizers
from keras import backend as K

from pathlib import Path
import numpy as np

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
Using TensorFlow backend.


## Custom modules 

In [2]:
import sys
%load_ext autoreload

%autoreload 2

repo_path = Path("..")

sys.path.append("..")
import visualization_utils

# Keras model


In [3]:
# MODEL
# File name of the saved Keras model; the commented line is an alternative file.
model_name = "c-ResUnet.h5"
# model_name = "c-ResUnet_noWM.h5"
model_path = f"{repo_path}/model_results/{model_name}"

In [4]:
def mean_iou(y_true, y_pred):
    """Average TensorFlow's mean-IoU metric over a sweep of binarisation thresholds.

    For every threshold ``t`` produced by ``np.arange(0.2, 0.8, 0.05)`` the
    prediction is binarised as ``y_pred > t`` and scored against ``y_true``
    with ``tf.metrics.mean_iou`` using 2 classes.

    Args:
        y_true: The expected/desired output mask.
        y_pred: The actual/predicted mask.
    Returns:
        The mean (``K.mean`` along axis 0) of the stacked per-threshold scores.
    """
    prec = []
    # NOTE(review): with a float step, np.arange may or may not emit a value
    # close to the 0.8 end point -- confirm the intended number of thresholds.
    for t in np.arange(0.2, 0.8, 0.05):
        # Binarise the prediction at the current threshold.
        y_pred_ = tf.to_int32(y_pred > t)
        score, up_opt = tf.metrics.mean_iou(y_true, y_pred_, 2)
        # Initialise TensorFlow local variables in the Keras session
        # (presumably this covers the metric's internal accumulators -- confirm).
        K.get_session().run(tf.local_variables_initializer())
        # Re-create the score under a control dependency on the update op, so
        # evaluating the returned tensor also runs `up_opt`.
        with tf.control_dependencies([up_opt]):
            score = tf.identity(score)
        prec.append(score)
    return K.mean(K.stack(prec), axis=0)

# dice loss


def dice_coef(y_true, y_pred):
    """Compute the smoothed 'Dice' coefficient between two masks.

    Both inputs are flattened, and the result is
    ``(2 * sum(true * pred) + 1) / (sum(true) + sum(pred) + 1)``.

    Args:
        y_true: The expected/desired output mask.
        y_pred: The actual/predicted mask.
    Returns:
        The Dice coefficient between the expected and actual outputs. Values
        closer to 1 are considered 'better'.
    """
    smoothing = 1.
    truth_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)
    overlap = K.sum(truth_flat * pred_flat)
    numerator = 2. * overlap + smoothing
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + smoothing
    return numerator / denominator


def dice_coef_loss(y_true, y_pred):
    """Loss defined as the negated 'Dice' coefficient.

    Args:
        y_true: The expected/desired output mask.
        y_pred: The actual/predicted mask.
    Returns:
        The negative of ``dice_coef(y_true, y_pred)``, so a higher Dice
        coefficient gives a lower loss.
    """
    coefficient = dice_coef(y_true, y_pred)
    return -coefficient


def create_weighted_binary_crossentropy(zero_weight, one_weight):
    """Build a binary cross-entropy loss with separate weights per target value.

    Args:
        zero_weight: Multiplier applied to the ``(1 - y_true)`` part of the weighting.
        one_weight: Multiplier applied to the ``y_true`` part of the weighting.
    Returns:
        A function ``weighted_binary_crossentropy(y_true, y_pred)`` that returns
        the mean of the element-wise weighted binary cross-entropy.
    """

    def weighted_binary_crossentropy(y_true, y_pred):
        """Return the mean binary cross-entropy after element-wise weighting."""
        elementwise_ce = K.binary_crossentropy(y_true, y_pred)

        # Each element is weighted by a blend of the two weights driven by y_true.
        elementwise_weight = y_true * one_weight + (1. - y_true) * zero_weight

        return K.mean(elementwise_weight * elementwise_ce)

    return weighted_binary_crossentropy

## Load model

In [5]:
# Loss instance registered under the name 'weighted_binary_crossentropy' below.
WeightedLoss = create_weighted_binary_crossentropy(zero_weight=1, one_weight=1.5)

# Load the saved model, mapping the custom metric/loss names to the functions
# defined in this notebook; compile=False is passed through to load_model.
model = load_model(
    model_path,
    custom_objects={
        'mean_iou': mean_iou,
        'dice_coef': dice_coef,
        'weighted_binary_crossentropy': WeightedLoss,
    },
    compile=False,
)










In [6]:
# Print the layer-by-layer summary table of the loaded model.
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None, None, 3 0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, None, None, 1 4           input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 1 4           conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None, 1 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (

# Export weights

In [7]:
def get_named_weights(model, shapes=False):
    """Return PyTorch's `state_dict`-like dictionary of model weights.

    Weight names are collected layer by layer from ``layer.weights`` and paired
    positionally with the arrays returned by ``model.get_weights()``.

    Args:
        model: Object exposing ``layers`` (each with a ``weights`` list whose
            items have a ``name``) and a ``get_weights()`` method.
        shapes: When truthy, also fill the second dictionary with each
            weight's ``shape``; otherwise that dictionary is left empty.
    Returns:
        A tuple ``(state_dict, state_dict_shapes)`` of ``OrderedDict`` objects
        keyed by weight name.
    """
    from collections import OrderedDict

    weight_names = []
    for layer in model.layers:
        for variable in layer.weights:
            weight_names.append(variable.name)
    weight_values = model.get_weights()

    # zip pairs names and values by position and stops at the shorter sequence.
    named_weights = OrderedDict(zip(weight_names, weight_values))

    if shapes:
        named_shapes = OrderedDict(
            (weight_name, value.shape) for weight_name, value in named_weights.items()
        )
    else:
        named_shapes = OrderedDict()

    return named_weights, named_shapes


# Collect the weights together with their shapes.
state_dict, state_dict_shapes = get_named_weights(model, shapes=True)

In [8]:
# w = state_dict['conv2d_1/kernel:0']
# for el in w.flatten():
#     print(repr(el))

In [13]:
# Preview only the first four (name, shape) pairs.
idx = 0
for idx, (weight_name, weight_shape) in enumerate(state_dict_shapes.items(), start=1):
    print(weight_name, weight_shape)
    if idx == 4:
        break

conv2d_1/kernel:0 (1, 1, 3, 1)
conv2d_1/bias:0 (1,)
batch_normalization_1/gamma:0 (1,)
batch_normalization_1/beta:0 (1,)


## Check weights 

In [10]:
# Show up to 20 digits of precision when printing arrays.
np.set_printoptions(precision=20)

# Layer prefix to inspect; the commented lines are alternatives.
# w_name = 'conv2d_1/kernel:0'
w_name = 'batch_normalization_1'  # /beta:0'
# w_name = 'conv2d_2'

is_conv_layer = 'conv' in w_name
is_batchnorm_layer = 'batch' in w_name

if is_conv_layer:
    kernel_values = state_dict[f"{w_name}/kernel:0"]
    bias_values = state_dict[f"{w_name}/bias:0"]
    print('weight:\n', kernel_values,
          '\nbias:\n', bias_values)

elif is_batchnorm_layer:
    gamma_values = state_dict[f"{w_name}/gamma:0"]
    beta_values = state_dict[f"{w_name}/beta:0"]
    mean_values = state_dict[f"{w_name}/moving_mean:0"]
    variance_values = state_dict[f"{w_name}/moving_variance:0"]
    print('weight:\n', gamma_values,
          '\nbias:\n', beta_values,
          '\nmean:\n', mean_values,
          '\nvariance:\n', variance_values)
# state_dict[w_name]

weight:
 [0.17005211] 
bias:
 [-0.03283687] 
mean:
 [0.2896597] 
variance:
 [7.10894e-05]


## Export 

In [11]:
import pickle


def save_state_dict(d, path):
    """Pickle `d` to the file at `path`, printing the destination first.

    Args:
        d: The object to serialise (here, a dictionary of weights or shapes).
        path: Destination file path; the file is opened in binary write mode.
    """
    print('Saving at:', path)
    with open(path, 'wb') as sink:
        pickle.dump(d, sink)


def load_state_dict(path):
    """Read back an object that was pickled to `path`.

    Args:
        path: File path to open in binary read mode.
    Returns:
        The unpickled object.
    """
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    with open(path, 'rb') as source:
        return pickle.load(source)


save_folder = f'{repo_path}/model_results/'
# Path.stem drops only the final extension, so a model name containing extra
# dots is not truncated at its first dot (as `split('.')[0]` would do).
model_stem = Path(model_name).stem

# state_dict
outpath = save_folder + f"{model_stem}_state_dict.pkl"
save_state_dict(state_dict, outpath)
state_dict1 = load_state_dict(outpath)

# state_dict_shapes
outpath = save_folder + f"{model_stem}_state_dict_shapes.pkl"
save_state_dict(state_dict_shapes, outpath)
state_dict_shapes1 = load_state_dict(outpath)

# Verify the round trip for every entry, not only the few printed below.
assert list(state_dict1.keys()) == list(state_dict.keys()), \
    "reloaded state_dict has different weight names"
assert list(state_dict_shapes1.items()) == list(state_dict_shapes.items()), \
    "reloaded state_dict_shapes differs from the saved one"
for weight_name in state_dict:
    np.testing.assert_array_equal(state_dict1[weight_name], state_dict[weight_name])

# Print the first five entries side by side for a visual check.
saved_vs_loaded = zip(state_dict.keys(), state_dict_shapes.values(),
                      state_dict1.keys(), state_dict_shapes1.values())
for idx, (old_w, old_s, new_w, new_s) in enumerate(saved_vs_loaded, start=1):
    print('saved dict:\n', 'layer:\t', old_w, '\tshape:\t', old_s)
    print('loaded dict:\n', 'layer:\t', new_w, '\tshape:\t', new_s)
    print('comparison:\nold:', state_dict[old_w], '\nnew:', state_dict1[new_w])
    print('\n\n')
    if idx > 4:
        break

Saving at: ../model_results/c-ResUnet_state_dict.pkl
Saving at: ../model_results/c-ResUnet_state_dict_shapes.pkl
saved dict:
 layer:	 conv2d_1/kernel:0 	shape:	 (1, 1, 3, 1)
loaded dict:
 layer:	 conv2d_1/kernel:0 	shape:	 (1, 1, 3, 1)
comparison:
old: [[[[ 0.0028932514]
   [-0.16237044  ]
   [-0.0057095936]]]] 
new: [[[[ 0.0028932514]
   [-0.16237044  ]
   [-0.0057095936]]]]



saved dict:
 layer:	 conv2d_1/bias:0 	shape:	 (1,)
loaded dict:
 layer:	 conv2d_1/bias:0 	shape:	 (1,)
comparison:
old: [0.2927297] 
new: [0.2927297]



saved dict:
 layer:	 batch_normalization_1/gamma:0 	shape:	 (1,)
loaded dict:
 layer:	 batch_normalization_1/gamma:0 	shape:	 (1,)
comparison:
old: [0.17005211] 
new: [0.17005211]



saved dict:
 layer:	 batch_normalization_1/beta:0 	shape:	 (1,)
loaded dict:
 layer:	 batch_normalization_1/beta:0 	shape:	 (1,)
comparison:
old: [-0.03283687] 
new: [-0.03283687]



saved dict:
 layer:	 batch_normalization_1/moving_mean:0 	shape:	 (1,)
loaded dict:
 layer:	 batch_

In [12]:
# Display the most recently assigned output path (the shapes pickle).
outpath

'../model_results/c-ResUnet_state_dict_shapes.pkl'

<div class="alert alert-block alert-info">
    
The Keras weights are now saved as pickle files (the state dictionary and its shapes) so that they can be loaded from PyTorch later.
    
</div>