Added pong demo
Misc:
   - added ConvLayer, SeqLayer, LambdaLayer
   - added ability to save and restore controllers (ignores experience, as it can be huge)
siemanko committed Jul 2, 2016
1 parent 44b9527 commit 9f5c8a0
Showing 3 changed files with 289 additions and 9 deletions.
152 changes: 152 additions & 0 deletions notebooks/pong.py
@@ -0,0 +1,152 @@
import gym
import numpy as np
import tensorflow as tf
import time
import os

from itertools import count

from tf_rl.models import Layer, LambdaLayer, ConvLayer, SeqLayer
from tf_rl.controller.discrete_deepq import DiscreteDeepQ

# CRAZY VARIABLES
REAL_TIME = False
RENDER = True

MODEL_SAVE_DIR = "./pong_model/"
MODEL_SAVE_EVERY_S = 60

# SENSIBLE VARIABLES
FPS = 60
MAX_FRAMES = 1000
IMAGE_SHAPE = (210, 160, 3)
OBS_SHAPE = (210, 160, 6)
NUM_ACTIONS = 6


def make_model():
"""Create a tensorflow convnet that takes image as input
and outputs a predicted discounted score for every action"""

with tf.variable_scope('convnet'):
convnet = SeqLayer([
ConvLayer(3, 3, 6, 32, stride=(1,1), scope='conv1'), # out.shape = (B, 210, 160, 32)
LambdaLayer(tf.nn.sigmoid),
ConvLayer(2, 2, 32, 64, stride=(2,2), scope='conv2'), # out.shape = (B, 105, 80, 64)
LambdaLayer(tf.nn.sigmoid),
ConvLayer(3, 3, 64, 64, stride=(1,1), scope='conv3'), # out.shape = (B, 105, 80, 64)
LambdaLayer(tf.nn.sigmoid),
ConvLayer(2, 2, 64, 128, stride=(2,2), scope='conv4'), # out.shape = (B, 53, 40, 128)
LambdaLayer(tf.nn.sigmoid),
ConvLayer(3, 3, 128, 128, stride=(1,1), scope='conv5'), # out.shape = (B, 53, 40, 128)
LambdaLayer(tf.nn.sigmoid),
ConvLayer(2, 2, 128, 256, stride=(2,2), scope='conv6'), # out.shape = (B, 27, 20, 256)
LambdaLayer(tf.nn.sigmoid),
LambdaLayer(lambda x: tf.reshape(x, [-1, 27 * 20 * 256])), # out.shape = (B, 27 * 20 * 256)
Layer(27 * 20 * 256, 6, scope='proj_actions') # out.shape = (B, 6)
], scope='convnet')
return convnet


def make_controller():
"""createa deepq controller"""
session = tf.Session()

model = make_model()

optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

return DiscreteDeepQ(OBS_SHAPE,
NUM_ACTIONS,
model,
optimizer,
session,
random_action_probability=0.1,
minibatch_size=8,
discount_rate=0.99,
exploration_period=500000,
max_experience=10000,
target_network_update_rate=0.01,
store_every_nth=4,
train_every_nth=4)

# TODO(szymon): apparently both DeepMind and Karpathy
# normalize their frames to 80x80 grayscale.
# Should we do this?
def normalize_frame(o):
"""Change from uint in range (0, 255) to float in range (0,1)"""
return o.astype(np.float32) / 255.0

def main():
env = gym.make('Pong-v0')
controller = make_controller()

# Load existing model.
if os.path.exists(MODEL_SAVE_DIR):
print "loading model... ",
controller.restore(MODEL_SAVE_DIR)
print 'done.'
last_model_save = time.time()

# For every game
for game_no in count():
# Reset simulator
frame_tm1 = normalize_frame(env.reset())
frame_t, _, _, _ = env.step(env.action_space.sample())
frame_t = normalize_frame(frame_t)

rewards = []
for _ in range(MAX_FRAMES):
start_time = time.time()

# The observation consists of the last two frames;
# this is important so that velocity can be inferred.
observation_t = np.concatenate([frame_tm1, frame_t], 2)

# pick an action according to Q-function learned so far.
action = controller.action(observation_t)

if RENDER: env.render()

# advance simulator
frame_tp1, reward, done, info = env.step(action)
frame_tp1 = normalize_frame(frame_tp1)
if done: break

observation_tp1 = np.concatenate([frame_t, frame_tp1], 2)

# store transitions
controller.store(observation_t, action, reward, observation_tp1)
# run a single iteration of SGD
controller.training_step()


frame_tm1, frame_t = frame_t, frame_tp1
rewards.append(reward)

# If real-time visualization is requested, throttle to the target FPS.
if REAL_TIME:
time_passed = time.time() - start_time
time_left = 1.0 / FPS - time_passed

if time_left > 0:
time.sleep(time_left)

# save model if time since last save is greater than
# MODEL_SAVE_EVERY_S
if time.time() - last_model_save >= MODEL_SAVE_EVERY_S:
if not os.path.exists(MODEL_SAVE_DIR):
os.makedirs(MODEL_SAVE_DIR)
controller.save(MODEL_SAVE_DIR, debug=True)
last_model_save = time.time()

# Count scores. This relies on the specific reward values
# assigned by OpenAI Gym and might break in the future.
points_lost = rewards.count(-1.0)
points_won = rewards.count(1.0)
print "Game no {} is over. Points lost: {}, points won: {}".format(
game_no, points_lost, points_won)

if __name__ == '__main__':
main()
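
For reference, the per-layer shape comments in make_model follow from the 'SAME' padding used by ConvLayer: each spatial dimension becomes the ceiling of size / stride. A quick sanity check of the chain (plain Python, no TensorFlow needed; the layer names match the scopes above):

import math

def same_pad_out(size, stride):
    """Spatial output size of a 'SAME'-padded convolution."""
    return int(math.ceil(float(size) / stride))

h, w = 210, 160
for name, stride in [('conv1', 1), ('conv2', 2), ('conv3', 1),
                     ('conv4', 2), ('conv5', 1), ('conv6', 2)]:
    h, w = same_pad_out(h, stride), same_pad_out(w, stride)
    print name, (h, w)
# conv2 -> (105, 80), conv4 -> (53, 40), conv6 -> (27, 20),
# so the final flatten covers 27 * 20 * 256 features.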

71 changes: 62 additions & 9 deletions tf_rl/controller/discrete_deepq.py
@@ -1,11 +1,14 @@
import numpy as np
import random
import tensorflow as tf
import os
import pickle
import time

from collections import deque

class DiscreteDeepQ(object):
def __init__(self, observation_size,
def __init__(self, observation_shape,
num_actions,
observation_to_actions,
optimizer,
@@ -26,7 +29,7 @@ def __init__(self, observation_size,
Parameters
-------
observation_size : int
observation_shape : tuple of ints
shape of a single observation
num_actions : int
number of actions that the model can execute
@@ -35,7 +38,7 @@ def __init__(self, observation_size,
that can take in observation vector or a batch
and returns scores (of unbounded values) for each
action for each observation.
input shape: [batch_size, observation_size]
input shape: [batch_size] + observation_shape
output shape: [batch_size, num_actions]
optimizer: tf.train.Optimizer
optimizer for prediction error
@@ -76,7 +79,7 @@ def __init__(self, observation_size,
writer to log metrics
"""
# memorize arguments
self.observation_size = observation_size
self.observation_shape = observation_shape
self.num_actions = num_actions

self.q_network = observation_to_actions
@@ -105,6 +108,11 @@ def __init__(self, observation_size,

self.create_variables()

self.s.run(tf.initialize_all_variables())
self.s.run(self.target_network_update)

self.saver = tf.train.Saver()

def linear_annealing(self, n, total, p_initial, p_final):
"""Linear annealing between p_initial and p_final
over total steps - computes value at step n"""
@@ -113,19 +121,23 @@ def linear_annealing(self, n, total, p_initial, p_final):
else:
return p_initial - (n * (p_initial - p_final)) / (total)


def observation_batch_shape(self, batch_size):
return tuple([batch_size] + list(self.observation_shape))

def create_variables(self):
self.target_q_network = self.q_network.copy(scope="target_network")

# FOR REGULAR ACTION SCORE COMPUTATION
with tf.name_scope("taking_action"):
self.observation = tf.placeholder(tf.float32, (None, self.observation_size), name="observation")
self.observation = tf.placeholder(tf.float32, self.observation_batch_shape(None), name="observation")
self.action_scores = tf.identity(self.q_network(self.observation), name="action_scores")
tf.histogram_summary("action_scores", self.action_scores)
self.predicted_actions = tf.argmax(self.action_scores, dimension=1, name="predicted_actions")

with tf.name_scope("estimating_future_rewards"):
# FOR PREDICTING TARGET FUTURE REWARDS
self.next_observation = tf.placeholder(tf.float32, (None, self.observation_size), name="next_observation")
self.next_observation = tf.placeholder(tf.float32, self.observation_batch_shape(None), name="next_observation")
self.next_observation_mask = tf.placeholder(tf.float32, (None,), name="next_observation_mask")
self.next_action_scores = tf.stop_gradient(self.target_q_network(self.next_observation))
tf.histogram_summary("target_action_scores", self.next_action_scores)
@@ -165,10 +177,11 @@ def create_variables(self):
self.summarize = tf.merge_all_summaries()
self.no_op1 = tf.no_op()


def action(self, observation):
"""Given observation returns the action that should be chosen using
DeepQ learning strategy. Does not backprop."""
assert len(observation.shape) == 1, \
assert observation.shape == self.observation_shape, \
"Action is performed based on single observation."

self.actions_executed_so_far += 1
@@ -208,8 +221,8 @@ def training_step(self):
samples = [self.experience[i] for i in samples]

# batch states
states = np.empty((len(samples), self.observation_size))
newstates = np.empty((len(samples), self.observation_size))
states = np.empty(self.observation_batch_shape(len(samples)))
newstates = np.empty(self.observation_batch_shape(len(samples)))
action_mask = np.zeros((len(samples), self.num_actions))

newstates_mask = np.empty((len(samples),))
@@ -251,3 +264,43 @@ def training_step(self):
self.iteration += 1

self.number_of_times_train_called += 1

def save(self, save_dir, debug=False):
STATE_FILE = os.path.join(save_dir, 'deepq_state')
MODEL_FILE = os.path.join(save_dir, 'model')

# deepq state
state = {
'actions_executed_so_far': self.actions_executed_so_far,
'iteration': self.iteration,
'number_of_times_store_called': self.number_of_times_store_called,
'number_of_times_train_called': self.number_of_times_train_called,
}

if debug:
print 'Saving model... ',

saving_started = time.time()

self.saver.save(self.s, MODEL_FILE)
with open(STATE_FILE, "wb") as f:
pickle.dump(state, f)

if debug:
    print 'done in {} s'.format(time.time() - saving_started)

def restore(self, save_dir, debug=False):
# deepq state
STATE_FILE = os.path.join(save_dir, 'deepq_state')
MODEL_FILE = os.path.join(save_dir, 'model')

with open(STATE_FILE, "rb") as f:
state = pickle.load(f)
self.saver.restore(self.s, MODEL_FILE)

self.actions_executed_so_far = state['actions_executed_so_far']
self.iteration = state['iteration']
self.number_of_times_store_called = state['number_of_times_store_called']
self.number_of_times_train_called = state['number_of_times_train_called']
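
For context, here is a minimal usage sketch of the new save/restore methods (the directory name is illustrative, and make_controller refers to the helper defined in notebooks/pong.py above):

import os
from pong import make_controller  # as defined in notebooks/pong.py

SAVE_DIR = "./pong_model/"  # illustrative path

controller = make_controller()

# Restore previously saved weights and counters, if any exist.
if os.path.exists(SAVE_DIR):
    controller.restore(SAVE_DIR)

# ... interact with the environment and call training_step() ...

# save() writes the TensorFlow variables plus a small pickled dict of counters;
# the experience replay buffer is deliberately not persisted, since it can be huge.
if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)
controller.save(SAVE_DIR, debug=True)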



75 changes: 75 additions & 0 deletions tf_rl/models.py
@@ -87,3 +87,78 @@ def copy(self, scope=None):
given_layers = [self.input_layer.copy()] + [layer.copy() for layer in self.layers]
return MLP(self.input_sizes, self.hiddens, nonlinearities, scope=scope,
given_layers=given_layers)


class ConvLayer(object):
def __init__(self, filter_H, filter_W,
in_C, out_C,
stride=(1,1),
scope="Convolution"):
self.filter_H, self.filter_W = filter_H, filter_W
self.in_C, self.out_C = in_C, out_C
self.stride = stride
self.scope = scope

with tf.variable_scope(self.scope):
input_size = filter_H * filter_W * in_C
W_initializer = tf.random_uniform_initializer(
-1.0 / math.sqrt(input_size),
1.0 / math.sqrt(input_size))
self.W = tf.get_variable('W',
(filter_H, filter_W, in_C, out_C),
initializer=W_initializer)
self.b = tf.get_variable('b',
(out_C,),
initializer=tf.constant_initializer(0))

def __call__(self, X):
with tf.variable_scope(self.scope):
return tf.nn.conv2d(X, self.W,
strides=[1] + list(self.stride) + [1],
padding='SAME') + self.b

def variables(self):
return [self.W, self.b]

def copy(self, scope=None):
scope = scope or self.scope + "_copy"

with tf.variable_scope(scope) as sc:
for v in self.variables():
tf.get_variable(base_name(v), v.get_shape(),
initializer=lambda x,dtype=tf.float32: v.initialized_value())
sc.reuse_variables()
return ConvLayer(self.filter_H, self.filter_W, self.in_C, self.out_C, self.stride, scope=sc)

class SeqLayer(object):
def __init__(self, layers, scope='seq_layer'):
self.scope = scope
self.layers = layers

def __call__(self, x):
for l in self.layers:
x = l(x)
return x

def variables(self):
return sum([l.variables() for l in self.layers], [])

def copy(self, scope=None):
scope = scope or self.scope + "_copy"
with tf.variable_scope(scope):
copied_layers = [layer.copy() for layer in self.layers]
return SeqLayer(copied_layers, scope=scope)


class LambdaLayer(object):
def __init__(self, f):
self.f = f

def __call__(self, x):
return self.f(x)

def variables(self):
return []

def copy(self):
return LambdaLayer(self.f)
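
Below is a small sketch (shapes and scope names are illustrative) of how the new layer classes compose, and of how copy() builds a second network with its own variables, which is what DiscreteDeepQ relies on when creating its target network:

import tensorflow as tf
from tf_rl.models import ConvLayer, SeqLayer, LambdaLayer

# A tiny convnet assembled from the new layer classes.
net = SeqLayer([
    ConvLayer(3, 3, 3, 16, stride=(2, 2), scope='conv1'),
    LambdaLayer(tf.nn.sigmoid),
], scope='small_net')

x = tf.placeholder(tf.float32, (None, 84, 84, 3))
q = net(x)                # uses net's variables
target_net = net.copy()   # fresh variables, same architecture
q_target = target_net(x)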
