# Save and Restore Weights

In [1]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from util import placeholders,get_mnist,suppress_tf_warning,mlp,gpu_sess,get_vars
%matplotlib inline  
%config InlineBackend.figure_format='retina'
# Silence TensorFlow warnings (project helper from `util`) before anything is built
suppress_tf_warning()
# Report the TensorFlow version this notebook runs against
print ("TF version:[%s]." % tf.__version__)

TF version:[1.15.0].


### Dataset

In [2]:
# Load the train / test splits through the project helper `get_mnist`
x_train,y_train,x_test,y_test = get_mnist()
# Number of samples per split (first axis of the input arrays)
n_train = x_train.shape[0]
n_test = x_test.shape[0]
# Per-sample input and label dimensions (second axis of the training arrays)
x_dim = x_train.shape[1]
y_dim = y_train.shape[1]
print ("n_train:[%d], n_test:[%d], x_dim:[%d], y_dim:[%d]"%
       (n_train,n_test,x_dim,y_dim))

n_train:[60000], n_test:[10000], x_dim:[784], y_dim:[10]


### ConvNet Class

In [3]:


class ConvNetClsClass(object):
    """
    CNN for classification (TensorFlow 1.x graph mode)

    The constructor builds the whole graph: input placeholders, a stack of
    conv -> max-pool -> (optional) batch-norm -> ReLU stages, dense layers made
    by the project helper `mlp`, and the loss / optimizer / accuracy ops.
    The instance then offers helpers to run a training step (`update`),
    measure accuracy (`get_accr`) and save / restore weights as an npz file.

    NOTE(review): the variable scope is fixed to 'main' and does not depend on
    `name`; building several instances in one graph is presumably not intended
    - confirm against `get_vars`.
    """
    def __init__(self,name='CNN',x_dim=784,y_dim=10,img_dim=[28,28,1],
                 filter_sizes=[32,32],kernel_sizes=[3,3],h_dims=[128],
                 USE_BN=True,USE_DROPOUT=True):
        """
        Store the configuration and build the model and the training graph

        Args:
            name: label used in log messages only
            x_dim: length of one flattened input vector
            y_dim: number of classes (length of one label vector)
            img_dim: [height,width,channel] the flat input is reshaped to
            filter_sizes: number of filters of each conv layer
            kernel_sizes: kernel size of each conv layer (zipped with filter_sizes)
            h_dims: hidden sizes handed to `mlp` (y_dim is appended as output size)
            USE_BN: add batch normalization after each max-pooling
            USE_DROPOUT: forwarded to `mlp`
        """
        self.name = name
        self.x_dim = x_dim
        self.y_dim = y_dim

        # Copy the sequence arguments into fresh lists: the shared default lists
        # are never aliased by an instance, and tuples are accepted as well
        # (list concatenation is used in build_model).
        self.img_dim = list(img_dim)
        self.filter_sizes = list(filter_sizes)
        self.kernel_sizes = list(kernel_sizes)
        self.h_dims = list(h_dims)

        self.USE_BN = USE_BN
        self.USE_DROPOUT = USE_DROPOUT
        self.FIRST_SET_FLAG = True # assign ops are built lazily on first restore()

        self.build_model()
        self.build_graph()
        print("[%s] instantiated."%(self.name))

    def build_model(self):
        """
        Build model

        Creates the placeholders `ph_x` [None,x_dim], `ph_y` [None,y_dim] and
        the boolean `ph_is_train`, then the network under variable scope 'main'.
        The final output is stored in `y_hat` (used as logits by build_graph) and
        the variables reported by `get_vars('main')` in `main_vars`.
        """
        self.ph_x = tf.placeholder(dtype=tf.float32,shape=[None,self.x_dim],name='x')
        self.ph_y = tf.placeholder(dtype=tf.float32,shape=[None,self.y_dim],name='y')
        self.ph_is_train = tf.placeholder(tf.bool,name='is_train')

        net = tf.reshape(self.ph_x,shape=[-1]+self.img_dim) # flat vector -> image

        with tf.variable_scope('main'):
            # Conv layers: conv -> 2x2 max-pool -> (BN) -> ReLU
            for (filter_size,kernel_size) in zip(self.filter_sizes,self.kernel_sizes):
                net = tf.layers.conv2d(inputs=net,
                                       filters=filter_size,kernel_size=kernel_size,
                                       padding='same',activation=None)
                net = tf.layers.max_pooling2d(inputs=net,pool_size=2,strides=2)
                if self.USE_BN:
                    net = tf.layers.batch_normalization(net,training=self.ph_is_train)
                net = tf.nn.relu(net)

            # Dense layers (hidden sizes followed by the y_dim-sized output)
            net = tf.layers.flatten(net)
            net = mlp(net,h_dims=self.h_dims+[self.y_dim],actv=tf.nn.relu,out_actv=None,
                      USE_DROPOUT=self.USE_DROPOUT,ph_is_training=self.ph_is_train)
        self.y_hat = net
        self.main_vars = get_vars('main')

    def build_graph(self):
        """
        Build graph

        Defines the softmax cross-entropy cost, an Adam training op grouped with
        the UPDATE_OPS collection (batch-norm moving statistics), and the
        accuracy op.
        """
        self.costs = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=self.ph_y,logits=self.y_hat)
        self.cost = tf.reduce_mean(self.costs)
        self.update_ops = tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS) # BN
        self.optm = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(self.cost)
        # Running `optm` also runs the batch-norm update ops
        self.optm = tf.group([self.optm,self.update_ops])
        self.corr = tf.equal(tf.argmax(self.y_hat,1),tf.argmax(self.ph_y,1)) # [N]
        self.accr = tf.reduce_mean(tf.cast(self.corr, "float")) # [1]

    def update(self,sess,x_batch,y_batch):
        """
        Update model

        Runs one optimization step on the given batch with is_train=True and
        returns the cost value of that batch.
        """
        feeds = {self.ph_x:x_batch,self.ph_y:y_batch,self.ph_is_train:True}
        cost_val,_ = sess.run([self.cost,self.optm],feed_dict=feeds)
        return cost_val

    def get_accr(self,sess,x,y,batch_size=256):
        """
        Get accuracy

        Evaluates `accr` batch by batch with is_train=False and returns the
        sample-weighted average over all rows of x / y.
        """
        n = x.shape[0] # number of data
        # Builtin int(): the `np.int` alias was removed in NumPy 1.24
        n_batches = int(np.ceil(n/batch_size))
        accr_val_sum = 0.0
        for it in range(n_batches):
            x_batch = x[it*batch_size:(it+1)*batch_size,:]
            y_batch = y[it*batch_size:(it+1)*batch_size,:]
            feeds = {self.ph_x:x_batch,self.ph_y:y_batch,self.ph_is_train:False}
            accr_val = sess.run(self.accr,feed_dict=feeds)
            # Weight by the batch size since the last batch may be smaller
            accr_val_sum += accr_val*x_batch.shape[0]
        accr_val_avg = accr_val_sum/n # average out accuracy
        return accr_val_avg

    def save(self,npz_path,sess,VERBOSE=False):
        """
        Save model

        Evaluates every variable in `main_vars` and stores the values in an npz
        file keyed by the variable name. The parent folder is created when it
        does not exist. Note that np.savez appends '.npz' to the file name if
        npz_path lacks that extension.
        """
        # Accumulate weights
        data2save = dict()
        for v_idx,tf_var in enumerate(self.main_vars):
            var_name,var_val = tf_var.name,sess.run(tf_var)
            data2save[var_name] = var_val
            if VERBOSE:
                print ("[%02d]  var_name:[%s]  var_shape:%s"%
                    (v_idx,var_name,var_val.shape,))
        # Create folder if not exist; dirname is '' for a bare file name, in
        # which case there is nothing to create (os.makedirs('') would raise)
        dir_name = os.path.dirname(npz_path)
        if dir_name and not os.path.exists(dir_name):
            os.makedirs(dir_name)
            print ("[%s] created."%(dir_name))
        # Save npz
        np.savez(npz_path,**data2save)
        print ("[%s] saved."%(npz_path))

    def restore(self,npz_path,sess):
        """
        Restore model

        Loads the npz file written by save() and assigns the stored values to
        the variables in `main_vars`. The assign ops and their placeholders are
        created once, on the first call, and reused afterwards so that repeated
        restores do not keep adding ops to the graph.

        Raises:
            KeyError: if the file lacks an entry for one of the variables
        """
        # Load npz; the context manager closes the archive once values are read
        with np.load(npz_path) as npz:
            print ("[%s] loaded."%(npz_path))
            # Get stored values in the order of main_vars
            missing = [v.name for v in self.main_vars if v.name not in npz.files]
            if missing:
                raise KeyError("[%s] has no entry for %s"%(npz_path,missing))
            var_vals = [npz[tf_var.name] for tf_var in self.main_vars]
        # Build assign ops on first use
        if self.FIRST_SET_FLAG:
            assign_placeholders,assign_ops = [],[]
            for weight_tf_var in self.main_vars:
                assign_placeholder = tf.placeholder(
                    weight_tf_var.dtype,shape=weight_tf_var.get_shape())
                assign_placeholders.append(assign_placeholder)
                assign_ops.append(weight_tf_var.assign(assign_placeholder))
            self.assign_placeholders = assign_placeholders
            self.assign_ops = assign_ops
            # Flip the flag only after every op exists, so a failure above
            # does not leave the instance half-initialized
            self.FIRST_SET_FLAG = False
        # Set weights with a single session call
        if self.assign_ops:
            sess.run(self.assign_ops,
                     feed_dict=dict(zip(self.assign_placeholders,var_vals)))

print ("Ready.")

Ready.


### Instantiate and Loop

In [4]:
# Clear the default graph before the model is built
tf.reset_default_graph()
sess = gpu_sess() # open session
# Build the classifier; the arguments below equal the class defaults
C = ConvNetClsClass(
    name='CNN',
    x_dim=784,
    y_dim=10,
    img_dim=[28,28,1],
    filter_sizes=[32,32],
    kernel_sizes=[3,3],
    h_dims=[128],
    USE_BN=True,
    USE_DROPOUT=True)

[CNN] instantiated.


In [5]:
sess.run(tf.global_variables_initializer()) # Initialize variables
max_epoch,batch_size,print_every = 10,128,1
# Builtin int() instead of `np.int`, which was removed in NumPy 1.24
max_iter = int(np.ceil(n_train/batch_size)) # number of iterations per epoch
for epoch in range(max_epoch):
    # Visit the training set in a new random order every epoch
    p_idx = np.random.permutation(n_train)
    for it in range(max_iter):
        # The last batch of an epoch may hold fewer than batch_size samples
        b_idx = p_idx[batch_size*(it):batch_size*(it+1)]
        x_batch,y_batch = x_train[b_idx,:],y_train[b_idx,:]
        C.update(sess,x_batch,y_batch)
    # Report accuracies every `print_every` epochs and after the final epoch
    if ((epoch%print_every)==0) or (epoch==(max_epoch-1)):
        train_accr_val = C.get_accr(sess,x_train,y_train)
        test_accr_val = C.get_accr(sess,x_test,y_test)
        print ("epoch:[%d/%d] train_accuracy:[%.3f] test_accuracy:[%.3f]"%
               (epoch,max_epoch,train_accr_val,test_accr_val))
print ("Done.")

epoch:[0/10] train_accuracy:[0.971] test_accuracy:[0.972]
epoch:[1/10] train_accuracy:[0.986] test_accuracy:[0.984]
epoch:[2/10] train_accuracy:[0.990] test_accuracy:[0.987]
epoch:[3/10] train_accuracy:[0.992] test_accuracy:[0.989]
epoch:[4/10] train_accuracy:[0.994] test_accuracy:[0.991]
epoch:[5/10] train_accuracy:[0.995] test_accuracy:[0.992]
epoch:[6/10] train_accuracy:[0.995] test_accuracy:[0.990]
epoch:[7/10] train_accuracy:[0.997] test_accuracy:[0.992]
epoch:[8/10] train_accuracy:[0.997] test_accuracy:[0.990]
epoch:[9/10] train_accuracy:[0.997] test_accuracy:[0.991]
Done.


### Save

In [6]:
# Write the current values of the model variables to an npz file
C.save(sess=sess,
       npz_path='../net/cnn_mnist/net.npz')

[../net/cnn_mnist/net.npz] saved.


### Re-init weights and evaluate

In [7]:
# Run the initializer again, which discards the trained values,
# then measure test accuracy with the re-initialized weights
init_op = tf.global_variables_initializer()
sess.run(init_op)
test_accr_val = C.get_accr(sess=sess,x=x_test,y=y_test)
print ("test_accr_val:[%.3f]"%(test_accr_val))

test_accr_val:[0.083]


### Restore

In [8]:
# Load the saved values back into the model variables
C.restore(sess=sess,
          npz_path='../net/cnn_mnist/net.npz')

[../net/cnn_mnist/net.npz] loaded.


### Evaluate

In [9]:
# Test accuracy after the restore; compare with the value printed after training
test_accr_val = C.get_accr(sess=sess,x=x_test,y=y_test)
print ("test_accr_val:[%.3f]"%(test_accr_val))

test_accr_val:[0.991]
