Based on the paper: https://arxiv.org/pdf/1502.05477.pdf

Some help: https://github.com/wojzaremba/trpo/blob/master/main.py

Main help: https://github.com/tensorflow/models/blob/master/pcl_rl/trust_region.py

Sketch of a proof of the KL expression via the Fisher information matrix (another proof simply uses a Taylor expansion): https://stats.stackexchange.com/questions/51185/connection-between-fisher-metric-and-the-relative-entropy

Short reference: https://roosephu.github.io/2016/11/19/TRPO/

In [None]:
import gym
import gym_ple
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import imageio

In [None]:
%matplotlib inline

In [None]:
# Instantiate the Flappy Bird environment by its gym id.
# NOTE(review): the 'FlappyBird-v0' id is presumably registered by the
# `gym_ple` import — confirm.
env = gym.make('FlappyBird-v0')

In [None]:
# Reset the environment and keep the first observation it hands back.
ob = env.reset()
# Use the function-call form of print: with a single argument it prints the
# same text on Python 2 and Python 3, whereas the bare `print ob.shape`
# statement is a SyntaxError on Python 3.
print(ob.shape)
# Render the observation as an image (presumably an RGB game frame — confirm
# against the shape printed above).
plt.imshow(ob)
plt.show()

## Calibrating the environment

The original Flappy Bird environment may be too difficult to attack with plain TRPO. The main reason is that TRPO assumes a Markovian setting, whereas in the actual game you need to track the history of clicks.

To make it easier, I create an environment on top of the original one: it retains the original mechanics but applies some reward hacking on top (every step yields a constant reward of 1.0) to, hopefully, speed up learning.

In [None]:
class Custom_Flappy_Env:
    """Wrapper around a gym-style environment that flattens the reward.

    Observations, the ``done`` flag and the ``info`` payload coming from the
    wrapped environment are passed through untouched; only the reward is
    replaced by a constant 1.0 on every step.
    """

    def __init__(self, proper_gym_env):
        # The wrapped environment; reset/step are delegated to it.
        self.gym_env = proper_gym_env

    def reset(self):
        """Reset the wrapped environment and return whatever it returns."""
        first_observation = self.gym_env.reset()
        return first_observation

    def step(self, action):
        """Forward ``action`` to the wrapped environment.

        Returns the tuple ``(observation, 1.0, done, info)``: the wrapped
        environment's own reward is discarded and replaced by 1.0.
        """
        # The underlying step must yield exactly four values; the second one
        # (the native reward) is intentionally ignored.
        observation, _, done, info = self.gym_env.step(action)
        constant_reward = 1.0
        return (observation, constant_reward, done, info)

## Making use of generic agent / environment scripts

In [None]:
import sys
sys.path.append("..")
from rl_agent import RL_Agent
from rl_learner import TRPO_Learner

In [None]:
class Flappy_Agent(RL_Agent):
    """Convolutional policy network plugged into the generic RL_Agent base.

    This subclass only supplies the concrete model used to choose an action.
    The contract with the base class / learners is implicit:
      1) ``__init__`` creates all of its variables inside the ``model_name``
         variable scope;
      2) the instance exposes ``self.session``, ``self.prob_layer`` and
         ``self.log_prob_layer``.
    Everything else needed by the PG and TRPO learners lives in RL_Agent.
    """

    def __init__(self, model_name):
        RL_Agent.__init__(self, model_name)

        # Settings shared by the three conv / pool stages.
        conv_args = dict(kernel_size=5, strides=2, activation=tf.nn.relu)
        pool_args = dict(pool_size=3, strides=2)

        with tf.variable_scope(model_name):
            self.session = tf.Session()

            # Batch of 512x288 frames with 3 channels (batch size left open).
            self.input_layer = tf.placeholder(dtype=tf.float32, shape=[None, 512, 288, 3])

            # Stage 1: 8 filters.
            self.conv_1 = tf.layers.conv2d(self.input_layer, filters=8, **conv_args)
            self.pool_1 = tf.layers.max_pooling2d(self.conv_1, **pool_args)

            # Stage 2: 16 filters.
            self.conv_2 = tf.layers.conv2d(self.pool_1, filters=16, **conv_args)
            self.pool_2 = tf.layers.max_pooling2d(self.conv_2, **pool_args)

            # Stage 3: 32 filters.
            self.conv_3 = tf.layers.conv2d(self.pool_2, filters=32, **conv_args)
            self.pool_3 = tf.layers.max_pooling2d(self.conv_3, **pool_args)

            # Fully connected head ending in two logits (one per action).
            self.flat = tf.contrib.layers.flatten(self.pool_3)
            self.dense_1 = tf.layers.dense(self.flat, units=25, activation=tf.nn.relu)
            self.dense_2 = tf.layers.dense(self.dense_1, units=2)

            # Clamp the softmax output into [0.0001, 0.9999] so the log taken
            # right below never receives an exact 0.
            softmax_probs = tf.nn.softmax(self.dense_2)
            capped_above = tf.minimum(softmax_probs, 0.9999)
            self.prob_layer = tf.maximum(capped_above, 0.0001)
            self.log_prob_layer = tf.log(self.prob_layer)

            # Initialise the variables in this agent's own session.
            self.session.run(tf.global_variables_initializer())

In [None]:
# Clear the default TensorFlow graph before the agent builds its network.
tf.reset_default_graph()

# Build the two collaborators first (agent, then wrapped environment) and
# hand them to the learner. The hyper-parameters are passed through as-is;
# their exact meaning is defined by TRPO_Learner in rl_learner.
flappy_agent = Flappy_Agent("2018_01_20_flappy_model")
wrapped_env = Custom_Flappy_Env(env)
trpo = TRPO_Learner(
    rl_agent=flappy_agent,
    game_env=wrapped_env,
    discount=0.99,
    batch_size=25,
    frame_cap=100,
    trpo_delta=0.02,
    line_search_option="max"
)

# Run ten learner steps (presumably one TRPO update each — see rl_learner).
num_learner_steps = 10
for i in range(num_learner_steps):
    trpo.step()

## Visualizing a played Flappy Bird game

In [None]:
import os

from IPython.display import HTML

# Play one game with the current agent; only the first returned value (the
# sequence of frames) is needed here.
states, _, _ = trpo.play_single_game()

# Make sure the output directory exists: writing the GIF into a missing
# directory would otherwise fail on a fresh checkout.
output_dir = "simulations"
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

# File name encodes the model name and how many games the learner has played.
gif_location = output_dir + "/" + trpo.agent.model_name + "_after_" + str(trpo.played_games) + "_games.gif"
imageio.mimsave(gif_location, states)

# Keep this expression last so the notebook renders the GIF as cell output.
HTML('<img src="' + gif_location + '">')