Based on the TRPO paper (Schulman et al., "Trust Region Policy Optimization"): https://arxiv.org/pdf/1502.05477.pdf

Supplementary reference implementation: https://github.com/wojzaremba/trpo/blob/master/main.py

Primary reference implementation: https://github.com/tensorflow/models/blob/master/pcl_rl/trust_region.py

Sketch of a proof that the second-order approximation of the KL divergence is given by the Fisher information matrix (an alternative proof uses a Taylor expansion directly): https://stats.stackexchange.com/questions/51185/connection-between-fisher-metric-and-the-relative-entropy

Short reference: https://roosephu.github.io/2016/11/19/TRPO/

In [None]:
import gym
import gym_ple
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Flappy Bird environment; the id is presumably registered by the gym_ple import above.
env = gym.make(id='FlappyBird-v0')

In [None]:
# Start an episode, then keep only the observation from a single step with action 0.
env.reset()
ob = env.step(0)[0]

# Inspect the frame: its shape, then the image itself.
print(ob.shape)
plt.imshow(ob)
plt.show()

## Computing matrix-vector products with the Fisher information matrix

In [None]:
# Toy graph for testing second-order gradient products: one conv layer, max
# pooling, and a single sigmoid output regressed onto the constant 1.
tf.reset_default_graph()

input_layer = tf.placeholder(shape=[1, 512, 288, 3], dtype=tf.int32)
true_pred = tf.constant(1, shape=[1, 1], dtype=tf.float32)

conv1_layer = tf.layers.conv2d(tf.cast(input_layer, tf.float32), filters=8, kernel_size=[5, 5],
                               padding="same", use_bias=False, activation=tf.nn.relu, name="conv_filters")
pool1_layer = tf.layers.max_pooling2d(conv1_layer, pool_size=[16, 8], strides=[16, 8])

flatten_layer = tf.contrib.layers.flatten(pool1_layer)
pred_layer = tf.sigmoid(tf.layers.dense(flatten_layer, 1, use_bias=False, name="dense_weights"))

# Keyword arguments put target and prediction in the documented roles.
# The squared error is symmetric, so the loss value is the same either way.
loss = tf.losses.mean_squared_error(labels=true_pred, predictions=pred_layer)

# Displayed as the cell output: the conv kernel and dense kernel.
tf.trainable_variables()

In [None]:
def get_padded_gradients(loss, var_list):
    """Return d(loss)/d(v) for every v in var_list, with zeros in place of None.

    tf.gradients yields None for variables that `loss` does not depend on.
    Replacing those with zero tensors keeps exactly one gradient per variable,
    so the results can be flattened and concatenated in var_list order.

    Args:
        loss: tensor to differentiate (tf.gradients differentiates the sum of
            its elements).
        var_list: variables/tensors to differentiate with respect to.

    Returns:
        List with one tensor per entry of var_list, each shaped like that entry.
    """
    grads = tf.gradients(loss, var_list)
    # zeros_like keeps each variable's dtype; tf.zeros(v.shape) was always float32.
    return [g if g is not None else tf.zeros_like(v)
            for g, v in zip(grads, var_list)]

def get_flattened_gradients(loss, var_list):
    """Return the gradients of `loss` w.r.t. var_list as a single 1-D tensor.

    Each per-variable gradient (zero-filled by get_padded_gradients when
    unconnected) is reshaped into a vector, and the vectors are concatenated
    in var_list order.
    """
    flat_pieces = [tf.reshape(grad, [-1])
                   for grad in get_padded_gradients(loss, var_list)]
    return tf.concat(flat_pieces, axis=0)

In [None]:
flat_grad = get_flattened_gradients(loss, tf.trainable_variables())
flat_vars = tf.concat([tf.reshape(v, [-1]) for v in tf.trainable_variables()], 0)

# All-ones probe vector with one entry per scalar parameter (9816 for the
# graph above). It is derived from flat_grad instead of hard-coding the count.
target_v = tf.ones_like(flat_grad)

# Despite the name, this is the Hessian of the MSE `loss` times target_v
# (the gradient of <grad(loss), target_v>). It tests the double-gradient trick
# used later for Fisher-vector products.
grad_vector_product = tf.reduce_sum(flat_grad * target_v)
fisher_vector_product = get_flattened_gradients(grad_vector_product, tf.trainable_variables())

with tf.Session() as sess:
    feed_dict = {input_layer: np.expand_dims(ob, axis=0)}
    sess.run(tf.global_variables_initializer())
    # pred_layer depends on the input placeholder, so it must be fed as well.
    print(sess.run(pred_layer, feed_dict=feed_dict))
    fisher_test, var_test, grad_test = sess.run([fisher_vector_product, flat_vars, flat_grad], feed_dict=feed_dict)

    print(fisher_test)
    print(var_test)
    print(grad_test)

## Defining the RL agent class

In [None]:
def get_padded_gradients(loss, var_list):
    """Return d(loss)/d(v) for every v in var_list, with zeros in place of None.

    tf.gradients yields None for variables that `loss` does not depend on.
    Replacing those with zero tensors keeps exactly one gradient per variable,
    so the results can be flattened and concatenated in var_list order.

    Args:
        loss: tensor to differentiate (tf.gradients differentiates the sum of
            its elements).
        var_list: variables/tensors to differentiate with respect to.

    Returns:
        List with one tensor per entry of var_list, each shaped like that entry.
    """
    grads = tf.gradients(loss, var_list)
    # zeros_like keeps each variable's dtype; tf.zeros(v.shape) was always float32.
    return [g if g is not None else tf.zeros_like(v)
            for g, v in zip(grads, var_list)]

def get_flattened_gradients(loss, var_list):
    """Return the gradients of `loss` w.r.t. var_list as a single 1-D tensor.

    Each per-variable gradient (zero-filled by get_padded_gradients when
    unconnected) is reshaped into a vector, and the vectors are concatenated
    in var_list order.
    """
    flat_pieces = [tf.reshape(grad, [-1])
                   for grad in get_padded_gradients(loss, var_list)]
    return tf.concat(flat_pieces, axis=0)

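The matrix $A$ from the TRPO paper (the Fisher information matrix averaged over states) can be computed as $A = \mathbb{E}_{s}\left[ -\sum_{a} \pi_{\theta_{old}}(a \mid s)\, \nabla^2_{\theta} \log \pi_{\theta}(a \mid s) \Big|_{\theta = \theta_{old}} \right]$, where the outer probabilities $\pi_{\theta_{old}}$ are held fixed while differentiating.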

In [None]:
class RL_Agent:
    """Softmax policy network over raw frames that owns its own tf.Session.

    Architecture, built inside ``tf.variable_scope(model_name)``: 5x5 conv with
    8 filters, ReLU and no bias -> 16x8 max pooling -> flatten -> dense layer
    with 2 outputs and no bias -> softmax / log-softmax over the 2 actions.
    """

    def __init__(self, model_name):
        """Build the policy graph and initialize its variables.

        Args:
            model_name: variable-scope name for this agent's layers. It is also
                used by ``model_variables`` to select this agent's variables.
        """
        with tf.variable_scope(model_name) as scope:
            self.model_name = model_name
            # Full variable-scope name (including any enclosing scopes), used
            # as an exact prefix when selecting this agent's variables.
            self._scope_name = scope.name
            self.session = tf.Session()

            # Batch of 512x288x3 frames (presumably the FlappyBird-v0
            # observation shape). Fed values are converted to float32.
            self.input_layer = tf.placeholder(shape=[None, 512, 288, 3], dtype=tf.float32)
            self.conv1_layer = tf.layers.conv2d(self.input_layer,
                                                filters=8, kernel_size=[5, 5],
                                                padding="same", use_bias=False,
                                                activation=tf.nn.relu, name="conv_weights"
                                               )

            self.pool1_layer = tf.layers.max_pooling2d(self.conv1_layer, pool_size=[16, 8], strides=[16, 8])
            self.flatten_layer = tf.contrib.layers.flatten(self.pool1_layer)
            self.dense_layer = tf.layers.dense(self.flatten_layer, 2, use_bias=False, name="dense_weights")

            self.prob_layer = tf.nn.softmax(self.dense_layer)
            self.log_prob_layer = tf.nn.log_softmax(self.dense_layer)

            self.session.run(tf.global_variables_initializer())

        # The ops for fisher_vector_product are built on first use and reused
        # afterwards, so repeated calls do not keep growing the graph.
        self._fvp_vector = None
        self._fvp_op = None

    def model_variables(self):
        """Return the trainable variables created under this agent's variable scope."""
        # Match on the exact "scope/" prefix rather than a substring, so an
        # agent named "model" does not also pick up "my_model" or "model_2"
        # variables. Read from the agent's own graph, not the current default.
        prefix = self._scope_name + '/'
        trainable = self.session.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        return [v for v in trainable if v.name.startswith(prefix)]

    def predict(self, x):
        """Return the softmax action probabilities for a batch of frames `x`."""
        return self.session.run(self.prob_layer, feed_dict={self.input_layer: x})

    def _build_fisher_vector_product(self):
        """Create the placeholder and ops used by ``fisher_vector_product`` (called once)."""
        with self.session.graph.as_default():
            variables = self.model_variables()
            # Shape left unspecified so any input that broadcasts against the
            # flat gradient is still accepted.
            self._fvp_vector = tf.placeholder(dtype=tf.float32)
            # sum_a p_old(a|s) * log p(a|s) for each state, with p_old held
            # fixed by stop_gradient. It is averaged over the batch, which
            # estimates the expectation over states in the TRPO A matrix.
            expected_log_prob = tf.reduce_mean(
                tf.reduce_sum(tf.stop_gradient(self.prob_layer) * self.log_prob_layer, 1))
            log_prob_grad = get_flattened_gradients(expected_log_prob, variables)
            # Differentiating <grad, v> once more gives (Hessian) @ v without
            # materializing the Hessian.
            grad_vector_product = tf.reduce_sum(log_prob_grad * self._fvp_vector)
            self._fvp_op = - get_flattened_gradients(grad_vector_product, variables)

    def fisher_vector_product(self, x, vector):
        """Return A @ vector, where A is the state-averaged Fisher matrix over the batch `x`.

        A = mean_s[ -sum_a p_old(a|s) * d^2/dtheta^2 log p(a|s) ], evaluated at
        the current parameters (which play the role of theta_old).

        Args:
            x: batch of frames fed to ``input_layer``.
            vector: flat vector with one entry per scalar parameter of
                ``model_variables()``, concatenated in that order. It may be a
                NumPy array/list or a ``tf.Tensor`` from this agent's graph.

        Returns:
            Flat NumPy array with one entry per model parameter.
        """
        if self._fvp_op is None:
            self._build_fisher_vector_product()
        if isinstance(vector, tf.Tensor):
            # A feed value cannot be a tf.Tensor, so evaluate it first.
            vector = self.session.run(vector)
        return self.session.run(self._fvp_op,
                                feed_dict={self.input_layer: x, self._fvp_vector: vector})

    def close(self):
        """Release the tf.Session owned by this agent."""
        self.session.close()
            

In [None]:
tf.reset_default_graph()

flappy_model = RL_Agent("test_model")
# Add a batch dimension of 1 to the single frame.
obs_batch = np.expand_dims(ob, axis=0)
print(flappy_model.predict(obs_batch))
print(flappy_model.model_variables())

# All-ones probe with one entry per scalar parameter (9216 * 2 + 600 for the
# current architecture). The count comes from the variables, not a literal.
n_params = sum(int(np.prod(v.get_shape().as_list()))
               for v in flappy_model.model_variables())
print(flappy_model.fisher_vector_product(obs_batch, np.ones(n_params, dtype=np.float32)))

## Playground

In [None]:
# Tiny quadratic form z = y x y^T (a 1x1 tensor) whose second-order gradient
# products are easy to check by hand.
x = tf.Variable([[1, 1], [1, 2]], dtype=tf.float32)
y = tf.Variable([[2, 1]], dtype=tf.float32)
z = tf.matmul(tf.matmul(y, x), y, transpose_b=True)

# One entry per scalar parameter: the 4 entries of x, then the 2 of y.
target_v = tf.constant(list(range(1, 7)), dtype=tf.float32)

def get_padded_gradients(loss, var_list):
    """Return d(loss)/d(v) for every v in var_list, with zeros in place of None.

    tf.gradients yields None for variables that `loss` does not depend on.
    Replacing those with zero tensors keeps exactly one gradient per variable,
    so the results can be flattened and concatenated in var_list order.

    Args:
        loss: tensor to differentiate (tf.gradients differentiates the sum of
            its elements).
        var_list: variables/tensors to differentiate with respect to.

    Returns:
        List with one tensor per entry of var_list, each shaped like that entry.
    """
    grads = tf.gradients(loss, var_list)
    # zeros_like keeps each variable's dtype; tf.zeros(v.shape) was always float32.
    return [g if g is not None else tf.zeros_like(v)
            for g, v in zip(grads, var_list)]

def get_flattened_gradients(loss, var_list):
    """Return the gradients of `loss` w.r.t. var_list as a single 1-D tensor.

    Each per-variable gradient (zero-filled by get_padded_gradients when
    unconnected) is reshaped into a vector, and the vectors are concatenated
    in var_list order.
    """
    flat_pieces = [tf.reshape(grad, [-1])
                   for grad in get_padded_gradients(loss, var_list)]
    return tf.concat(flat_pieces, axis=0)

# Flat gradient of z w.r.t. (x, y), and the matching flat parameter vector.
flat_grad = get_flattened_gradients(z, [x, y])
flat_vars = tf.concat([tf.reshape(x, [-1]), tf.reshape(y, [-1])], 0)
print('%s %s' % (flat_grad, flat_vars))

# Differentiating <grad z, target_v> again gives the Hessian of z times target_v.
grad_vector_product = tf.reduce_sum(flat_grad * target_v)
fisher_vector_product = get_flattened_gradients(grad_vector_product, [x, y])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    cur_test = sess.run(fisher_vector_product)
    print(cur_test)