# Save and Restore Model

In [1]:
import os
import numpy as np
import tensorflow as tf
from util import suppress_tf_warning
# Silence TF warning noise first, then report which TF version is running
suppress_tf_warning()
print("TF:[{}]".format(tf.__version__))

TF:[1.15.0]


### Helper functions: create, save, and restore a model

In [2]:
def create_tf_model(xdim=10,ydim=2,hdims=(256,256),actv=tf.nn.relu,out_actv=None):
    """
    Build an MLP in the default TF graph and open a session for it.

    Args:
        xdim (int): input dimension; the input placeholder has shape (None, xdim).
        ydim (int): output dimension (width of the last dense layer).
        hdims (sequence of int): hidden-layer widths (list or tuple).
        actv: activation used by the hidden layers.
        out_actv: activation of the output layer (None = linear).

    Returns:
        (model, sess) where `model` is a dict with
            'x_ph'      - float32 input placeholder,
            'y'         - output tensor of the last dense layer,
            'main_vars' - global variables created under the 'main' scope,
        and `sess` is a new tf.Session (GPU allow_growth enabled).
        The variables are NOT initialized here; the caller must run an initializer.
    """
    def mlp(x,hdims,actv,out_actv):
        # Every kernel uses a truncated-normal initializer (stddev 0.1);
        # bias_initializer is left at tf.layers.dense's default.
        ki = tf.truncated_normal_initializer(stddev=0.1)
        for hdim in hdims[:-1]:
            x = tf.layers.dense(x,units=hdim,activation=actv,kernel_initializer=ki)
        return tf.layers.dense(x,units=hdims[-1],
                               activation=out_actv,kernel_initializer=ki)
    def placeholder(dim=None):
        return tf.placeholder(dtype=tf.float32,shape=(None,dim) if dim else (None,))
    def get_vars(scope):
        # Match the scope as a name PREFIX. A substring test ('main' in name)
        # would also collect unrelated variables such as 'domain/...'.
        prefix = scope + '/'
        return [v for v in tf.compat.v1.global_variables() if v.name.startswith(prefix)]
    # Have own session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # Placeholder
    x_ph = placeholder(xdim)
    # Model: hidden layers followed by a ydim-wide output layer
    with tf.variable_scope('main'):
        y = mlp(x_ph,hdims=list(hdims)+[ydim],actv=actv,out_actv=out_actv)
    # Params
    main_vars = get_vars('main')
    model = {'x_ph':x_ph,'y':y,'main_vars':main_vars}
    return model, sess

def save_tf_model(npz_path,M,VERBOSE=True):
    """
    Save the values of M's 'main' variables to a NumPy .npz archive.

    Each array is stored under its TF variable name (e.g. 'main/dense/kernel:0'),
    which is the key restore_tf_model() looks up when loading.

    Args:
        npz_path (str): destination path; missing parent folders are created.
        M: model object exposing `M.model['main_vars']` (list of variables) and
           `M.sess` (session used to read their values), e.g. a ModelClass.
        VERBOSE (bool): if True, print index, name and shape of every variable.
    """
    tf_vars = M.model['main_vars']
    # Fetch every value in a single session call instead of one run per variable
    var_vals = M.sess.run(tf_vars)
    data2save = {}
    for v_idx,(tf_var,var_val) in enumerate(zip(tf_vars,var_vals)):
        data2save[tf_var.name] = var_val
        if VERBOSE:
            print ("[%02d]  var_name:[%s]  var_shape:%s"%
                (v_idx,tf_var.name,var_val.shape,))
    # Create folder if not exist. dirname() is '' for a bare filename, meaning
    # the current directory: nothing to create (os.makedirs('') would raise).
    dir_name = os.path.dirname(npz_path)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)
        print ("[%s] created."%(dir_name))
    # Save npz
    np.savez(npz_path,**data2save)
    print ("[%s] saved."%(npz_path))
    
def restore_tf_model(npz_path,M,VERBOSE=True):
    """
    Restore M's 'main' variables from an .npz archive written by save_tf_model().

    Args:
        npz_path (str): path to the .npz file.
        M: model exposing `M.model['main_vars']` (variables whose `.name` are the
           archive keys) and `M.set_weights(list_of_arrays)`, e.g. a ModelClass.
        VERBOSE (bool): currently unused; kept for symmetry with save_tf_model().

    Raises:
        KeyError: if the archive lacks any of M's variable names. This is raised
            before anything is assigned to M.
    """
    tf_vars = M.model['main_vars']
    # np.load on an .npz keeps the zip file open; `with` guarantees it is closed.
    with np.load(npz_path) as npz:
        print ("[%s] loaded."%(npz_path))
        missing = [v.name for v in tf_vars if v.name not in npz.files]
        if missing:
            raise KeyError("[%s] is missing variables: %s"%(npz_path,missing))
        # Read the arrays while the file is open, in M.model['main_vars'] order
        var_vals = [npz[v.name] for v in tf_vars]
    # Assign the loaded weights to the model
    M.set_weights(var_vals)

### Model Class

In [3]:
class ModelClass(object):
    """
    TF Model: an MLP built by create_tf_model() plus the session that owns it.

    Attributes:
        seed (int): seed applied to TF's graph-level RNG and NumPy's global RNG.
        xdim, ydim (int): input / output dimensions.
        model (dict): {'x_ph', 'y', 'main_vars'} as returned by create_tf_model().
        sess: session in which the model's variables live.
        FIRST_SET_FLAG (bool): True until set_weights() has built its assign ops.
    """
    def __init__(self,xdim=10,ydim=2,hdims=[256,256],actv=tf.nn.relu,out_actv=None,
                 seed=0):
        self.seed = seed
        self.xdim,self.ydim = xdim,ydim
        # Seed BEFORE building the graph. tf.set_random_seed() only affects random
        # ops created after the call; seeding after create_tf_model() left the
        # weight initializers unseeded, so the initial weights were not reproducible.
        tf.set_random_seed(self.seed)
        np.random.seed(self.seed)
        self.model,self.sess = create_tf_model(
            xdim=xdim,ydim=ydim,hdims=hdims,actv=actv,out_actv=out_actv)
        # Initialize model
        self.sess.run(tf.global_variables_initializer())
        # The assign ops used by set_weights() are built lazily on its first call
        self.FIRST_SET_FLAG = True

    def get_weights(self):
        """Return the current values of model['main_vars'] as a list of arrays."""
        return self.sess.run(self.model['main_vars'])

    def set_weights(self,weight_vals):
        """
        Assign `weight_vals` to model['main_vars'].

        The arrays are matched to the variables by position, so the order must be
        the same. The placeholder/assign ops are created once, on the first call,
        and reused afterwards, so repeated calls do not keep growing the graph.

        Raises:
            ValueError: if the number of arrays differs from the number of
                variables. The check runs before anything is assigned, so the
                model is never left half-updated.
        """
        main_vars = self.model['main_vars']
        if len(weight_vals) != len(main_vars):
            raise ValueError("set_weights() expects %d arrays, got %d"%
                             (len(main_vars),len(weight_vals)))
        if self.FIRST_SET_FLAG:
            self.FIRST_SET_FLAG = False
            self.assign_placeholders = []
            self.assign_ops = []
            for var in main_vars:
                ph = tf.placeholder(var.dtype, shape=var.get_shape())
                self.assign_placeholders.append(ph)
                self.assign_ops.append(var.assign(ph))
        # Run every assignment in a single session call
        self.sess.run(self.assign_ops,
                      feed_dict=dict(zip(self.assign_placeholders,weight_vals)))

### Create Model

In [4]:
# Start from an empty default graph so re-running this cell leaves no variables
# from earlier runs, then build the model (explicit values equal the defaults).
tf.reset_default_graph()
M = ModelClass(xdim=10, ydim=2, seed=0)

### Get the model's output for a random input

In [5]:
# Fixed random input, reused below so the three forward passes are comparable
np.random.seed(1)
x_rand = np.random.rand(M.xdim)
y_out_1 = M.sess.run(
    M.model['y'],
    feed_dict={M.model['x_ph']: x_rand[np.newaxis, :]},  # batch of one row
)
print(y_out_1)

[[-0.00717899 -0.04945676]]


### Save model

In [6]:
# Destination of the weights archive (save_tf_model creates missing folders)
npz_path = '../data/net/toy_model/net.npz'
save_tf_model(npz_path=npz_path, M=M, VERBOSE=True)

[00]  var_name:[main/dense/kernel:0]  var_shape:(10, 256)
[01]  var_name:[main/dense/bias:0]  var_shape:(256,)
[02]  var_name:[main/dense_1/kernel:0]  var_shape:(256, 256)
[03]  var_name:[main/dense_1/bias:0]  var_shape:(256,)
[04]  var_name:[main/dense_2/kernel:0]  var_shape:(256, 2)
[05]  var_name:[main/dense_2/bias:0]  var_shape:(2,)
[../data/net/toy_model/net.npz] saved.


### Re-initialize the weights and get the output

In [7]:
# Overwrite all weights with a fresh initialization; the output is expected to change
M.sess.run(tf.global_variables_initializer())
y_out_2 = M.sess.run(
    M.model['y'],
    feed_dict={M.model['x_ph']: x_rand[np.newaxis, :]},
)
print(y_out_2)

[[ 0.18526612 -0.08723448]]


### Restore the model and get the output

In [8]:
# Load the weights saved in the previous section back into M and rerun the same input
restore_tf_model(npz_path=npz_path, M=M, VERBOSE=True)
y_out_3 = M.sess.run(
    M.model['y'],
    feed_dict={M.model['x_ph']: x_rand[np.newaxis, :]},
)
print(y_out_3)

[../data/net/toy_model/net.npz] loaded.
[[-0.00717899 -0.04945676]]


### `y_out_1` and `y_out_3` should be identical (both come from the saved weights), while `y_out_2` should differ

In [9]:
print ('y_out_1:',y_out_1)
print ('y_out_2:',y_out_2)
print ('y_out_3:',y_out_3)
# Enforce the claim instead of relying on eyeballing the printout:
# restoring the saved weights must reproduce the original output...
np.testing.assert_allclose(y_out_3, y_out_1, rtol=1e-6, atol=1e-7)
# ...and re-initialization must actually have changed it (otherwise the demo proves nothing)
assert not np.allclose(y_out_2, y_out_1), "re-initialization did not change the output"

y_out_1: [[-0.00717899 -0.04945676]]
y_out_2: [[ 0.18526612 -0.08723448]]
y_out_3: [[-0.00717899 -0.04945676]]
