In [8]:
import pickle
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import random

def convpool(X,W,b):
    """Apply one convolution + pooling stage.

    X is convolved with filter W (stride 1, SAME padding), bias b is added,
    ELU is applied, and the result is max-pooled with a 2x2 window and
    stride 2 (SAME padding). Returns the pooled tensor.
    """
    activated=tf.nn.elu(
        tf.nn.bias_add(
            tf.nn.conv2d(X,W,strides=[1,1,1,1],padding="SAME"),
            b))
    return tf.nn.max_pool(activated,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

def init_filter(shape,poolsz):
    """Draw random initial values for a convolution filter.

    Args:
        shape: filter shape; the last entry is the number of output feature
            maps and the leading entries (all but the last two) are the
            kernel dimensions.
        poolsz: pooling window size, e.g. (2,2).

    Returns:
        float32 array of the given shape, standard-normal samples divided by
        sqrt(fan_in + fan_out).
    """
    fan_in=np.prod(shape[:-1])
    # The pool area must divide the *product* of the kernel dimensions once;
    # dividing each kernel dimension before taking the product (as the
    # previous expression did) shrinks fan_out by an extra factor of the
    # pool area for 2-D kernels.
    fan_out=shape[-1]*np.prod(shape[:-2])/float(np.prod(poolsz))
    w=np.random.randn(*shape)/np.sqrt(fan_in+fan_out)
    return w.astype(np.float32)

from functools import partial

# Variance-scaling initializer that averages fan-in and fan-out (mode="FAN_AVG").
# NOTE(review): the name `he_init` suggests He initialization, which the TF docs
# describe as mode="FAN_IN" — confirm which scheme is intended.
he_init = tf.contrib.layers.variance_scaling_initializer(mode="FAN_AVG")

# L1 penalty weight for the dense kernels; a scale of 0 disables the regularizer.
scale=0.0

# Dense-layer factory with this notebook's defaults: ELU activation,
# L1 kernel regularizer and the initializer above.
my_dense=partial(
    tf.layers.dense,
    activation=tf.nn.elu,
    kernel_regularizer=tf.contrib.layers.l1_regularizer(scale),
    kernel_initializer=he_init,
)



INFO:tensorflow:Scale of 0 disables regularizer.


In [26]:
# Checkpoint paths: the first is restored into the conv filters below,
# the second is the save path for this network.
cae_checkpoints_savepath="model_checkpoints_large/CAE_04302018_2layers_xsmall.ckpt"
checkpoints_savepath="model_checkpoints_large/DQN_cae_05092018_v1.ckpt"

# pooling window passed to init_filter
pool_sz=(2,2)

# widths of the three hidden dense layers and of the output layer
n_hidden1,n_hidden2,n_hidden3=100,200,40
n_outputs=2

# start from an empty default graph so the cell can be re-run
tf.reset_default_graph()

# cnn_pool layer 1: 4x4 kernels, 3 input channels -> 10 feature maps
W1_shape=(4,4,3,10)
W1_init=init_filter(W1_shape,pool_sz)
b1_init=np.zeros((W1_shape[-1],),np.float32)

# cnn_pool layer 2: 4x4 kernels, 10 input channels -> 3 feature maps
W2_shape=(4,4,10,3)
W2_init=init_filter(W2_shape,pool_sz)
b2_init=np.zeros((W2_shape[-1],),np.float32)

# network input: two 128x128x3 frames per example
X=tf.placeholder(dtype=tf.float32,shape=(None,2,128,128,3),name="X")

def _encode_frame(frames,frame_idx,layers):
    """Slice frame `frame_idx` out of `frames` and run it through the conv/pool stack.

    `layers` is a sequence of (filter, bias) pairs applied in order with
    convpool. Returns a list: the sliced 128x128x3 frame followed by the
    output of each layer.
    """
    frame=tf.reshape(tf.slice(frames,[0,frame_idx,0,0,0],[-1,1,-1,-1,-1]),[-1,128,128,3])
    outputs=[frame]
    for filt,bias in layers:
        outputs.append(convpool(outputs[-1],filt,bias))
    return outputs

with tf.name_scope("cnn"):
    # Non-trainable filters and biases, shared by both frames.
    with tf.device("/cpu:0"):
        W1=tf.Variable(W1_init.astype(np.float32),trainable=False,name='W1')
        b1=tf.Variable(b1_init.astype(np.float32),trainable=False,name='b1')
        W2=tf.Variable(W2_init.astype(np.float32),trainable=False,name='W2')
        b2=tf.Variable(b2_init.astype(np.float32),trainable=False,name='b2')

    with tf.device("/cpu:0"):
        conv_layers=[(W1,b1),(W2,b2)]
        # first frame
        X1,Z11,Z12=_encode_frame(X,0,conv_layers)
        # second frame
        X2,Z21,Z22=_encode_frame(X,1,conv_layers)

with tf.name_scope("cnn_output"):
    with tf.device("/cpu:0"):
        # feature difference: second frame minus first frame
        Z_diff=Z22-Z12

with tf.name_scope("dense_layers"):
    with tf.device("/cpu:0"):
        # tf.layers.dense only transforms the last axis, so the rank-4 feature
        # difference must be flattened to (batch, features) first; otherwise
        # q_values keeps the spatial axes and cannot match the
        # (batch, n_outputs) target.
        # NOTE: this changes the dense kernel shapes, so checkpoints of the
        # whole network saved from the un-flattened graph will not restore.
        n_features=int(np.prod(Z_diff.get_shape().as_list()[1:]))
        Z_flat=tf.reshape(Z_diff,[-1,n_features])

        # fully connected layers (ELU activations via my_dense)
        hidden1=my_dense(Z_flat,n_hidden1)
        hidden2=my_dense(hidden1,n_hidden2)
        hidden3=my_dense(hidden2,n_hidden3)

        # Linear output layer: an ELU here would bound every Q-value below
        # at -1, so the default activation is overridden.
        q_values=my_dense(hidden3,n_outputs,activation=None)
        
# Placeholder for the target Q-values, one column per output.
with tf.name_scope("target_q"), tf.device("/cpu:0"):
    q_target=tf.placeholder(dtype=tf.float32,shape=(None,n_outputs))

# Mean-squared error between q_values and q_target, minimised with Adam;
# the learning rate is a scalar placeholder so it can be fed per step.
with tf.name_scope("training_op"), tf.device("/cpu:0"):
    learning_rate=tf.placeholder(dtype=tf.float32,shape=())
    mse_loss=tf.reduce_mean(
        tf.squared_difference(q_values,q_target))
    optimizer=tf.train.AdamOptimizer(learning_rate)
    training_op=optimizer.minimize(mse_loss)
    # created after minimize() so the optimizer's own variables are included
    init=tf.global_variables_initializer()
        

with tf.name_scope("saver"):
    # Maps checkpoint keys to the variables of this graph that should receive
    # them. NOTE(review): the keys are presumably the auto-generated variable
    # names used when the CAE checkpoint was written — verify against it.
    var_list={
        'cnn/Variable':W1,
        'cnn/Variable_1':b1,
        'cnn/Variable_2':W2,
        'cnn/Variable_3':b2,
    }
    # restores only the four conv variables listed above
    saver_cae_restore=tf.train.Saver(var_list=var_list)
    # default saver, built without a var_list
    saver_whole=tf.train.Saver()
    

training the network

If you want to restore a subset of the variables and/or the variables in the checkpoint have different names, you can pass a dictionary as the var_list argument. By default, each variable in a checkpoint is associated with a key, which is the value of its tf.Variable.name property. If the name is different in the target graph (e.g. because you added a scope prefix), you can specify a dictionary that maps string keys (in the checkpoint file) to tf.Variable objects (in the target graph).

In [27]:
# Rebuilt here so this cell can be re-run without re-running the graph cell.
init=tf.global_variables_initializer()
train_mse_save=[]
test_mse_save=[]
with tf.Session() as sess:
    # Initialise everything first: the restore below only covers the four
    # conv variables, so all other variables need their initial values.
    init.run()
    try:
        saver_cae_restore.restore(sess, cae_checkpoints_savepath)
    except Exception as err:
        # Best-effort restore: keep the freshly initialised filters, but
        # report the cause. Catching Exception rather than using a bare
        # except lets KeyboardInterrupt/SystemExit propagate.
        print("restoring error, will start over!")
        print(repr(err))
    

INFO:tensorflow:Restoring parameters from model_checkpoints_large/CAE_04302018_2layers_xsmall.ckpt
