From 032b595941b5d3e302bb9acb7160e930b9ef4001 Mon Sep 17 00:00:00 2001
From: Błażej O
Date: Mon, 7 May 2018 20:32:00 +0200
Subject: [PATCH] Add more Atari games. First we add Freeway, but more can be
 added easily.

---
 tensor2tensor/data_generators/gym.py         | 71 +++++++++++++++++---
 tensor2tensor/rl/envs/simulated_batch_env.py | 41 ++++-------
 tensor2tensor/rl/envs/utils.py               | 12 +---
 tensor2tensor/rl/model_rl_experiment.py      |  5 +-
 4 files changed, 80 insertions(+), 49 deletions(-)

diff --git a/tensor2tensor/data_generators/gym.py b/tensor2tensor/data_generators/gym.py
index 9ea820233..1d85d0ac3 100644
--- a/tensor2tensor/data_generators/gym.py
+++ b/tensor2tensor/data_generators/gym.py
@@ -42,7 +42,7 @@
 flags = tf.flags
 FLAGS = flags.FLAGS
 
-flags.DEFINE_string("agent_policy_path", "", "File with model for pong")
+flags.DEFINE_string("agent_policy_path", "", "File with model for agent")
 
 
 class GymDiscreteProblem(video_utils.VideoProblem):
@@ -99,6 +99,14 @@ def env(self):
   def num_actions(self):
     return self.env.action_space.n
 
+  @property
+  def frame_height(self):
+    return self.env.observation_space.shape[0]
+
+  @property
+  def frame_width(self):
+    return self.env.observation_space.shape[1]
+
   @property
   def num_rewards(self):
     raise NotImplementedError()
@@ -150,14 +158,6 @@ class GymPongRandom5k(GymDiscreteProblem):
   def env_name(self):
     return "PongDeterministic-v4"
 
-  @property
-  def frame_height(self):
-    return 210
-
-  @property
-  def frame_width(self):
-    return 160
-
   @property
   def min_reward(self):
     return -1
@@ -179,9 +179,38 @@ class GymPongRandom50k(GymPongRandom5k):
   def num_steps(self):
     return 50000
 
 
+@registry.register_problem
+class GymFreewayRandom5k(GymDiscreteProblem):
+  """Freeway game, random actions."""
+
+  @property
+  def env_name(self):
+    return "FreewayDeterministic-v4"
+
+  @property
+  def min_reward(self):
+    return 0
+
+  @property
+  def num_rewards(self):
+    return 2
+
+  @property
+  def num_steps(self):
+    return 5000
+
+
+@registry.register_problem
+class GymFreewayRandom50k(GymFreewayRandom5k):
+  """Freeway game, random actions."""
+
+  @property
+  def num_steps(self):
+    return 50000
+
 @registry.register_problem
-class GymDiscreteProblemWithAgent(GymPongRandom5k):
+class GymDiscreteProblemWithAgent(GymDiscreteProblem):
   """Gym environment with discrete actions and rewards and an agent."""
 
   def __init__(self, *args, **kwargs):
@@ -190,7 +219,7 @@ def __init__(self, *args, **kwargs):
     self.debug_dump_frames_path = "debug_frames_env"
 
     # defaults
-    self.environment_spec = lambda: gym.make("PongDeterministic-v4")
+    self.environment_spec = lambda: gym.make(self.env_name)
     self.in_graph_wrappers = []
     self.collect_hparams = rl.atari_base()
     self.settable_num_steps = 20000
@@ -286,3 +315,23 @@ def restore_networks(self, sess):
     ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
     ckpt = ckpts.model_checkpoint_path
     env_model_loader.restore(sess, ckpt)
+
+
+@registry.register_problem
+class GymSimulatedDiscreteProblemWithAgentOnPong(GymSimulatedDiscreteProblemWithAgent, GymPongRandom5k):
+  pass
+
+
+@registry.register_problem
+class GymDiscreteProblemWithAgentOnPong(GymDiscreteProblemWithAgent, GymPongRandom5k):
+  pass
+
+
+@registry.register_problem
+class GymSimulatedDiscreteProblemWithAgentOnFreeway(GymSimulatedDiscreteProblemWithAgent, GymFreewayRandom5k):
+  pass
+
+
+@registry.register_problem
+class GymDiscreteProblemWithAgentOnFreeway(GymDiscreteProblemWithAgent, GymFreewayRandom5k):
+  pass
diff --git a/tensor2tensor/rl/envs/simulated_batch_env.py b/tensor2tensor/rl/envs/simulated_batch_env.py
index 9a229b424..aecf2be20 100644
--- a/tensor2tensor/rl/envs/simulated_batch_env.py
+++ b/tensor2tensor/rl/envs/simulated_batch_env.py
@@ -32,15 +32,13 @@
 from tensor2tensor.utils import trainer_lib
 
 import tensorflow as tf
+import numpy as np
 
 
 flags = tf.flags
 FLAGS = flags.FLAGS
 
-flags.DEFINE_string("frames_path", "", "Path to the first frames.")
-
-
 class SimulatedBatchEnv(InGraphBatchEnv):
   """Batch of environments inside the TensorFlow graph.
 
@@ -49,42 +47,31 @@ class SimulatedBatchEnv(InGraphBatchEnv):
   flags are held in according variables.
   """
 
-  def __init__(self, length, observ_shape, observ_dtype, action_shape,
-               action_dtype):
+  def __init__(self, environment_lambda, length):
     """Batch of environments inside the TensorFlow graph."""
     self.length = length
+    initialization_env = environment_lambda()
     hparams = trainer_lib.create_hparams(
         FLAGS.hparams_set, problem_name=FLAGS.problem, data_dir="UNUSED")
     hparams.force_full_predict = True
     self._model = registry.model(FLAGS.model)(
         hparams, tf.estimator.ModeKeys.PREDICT)
-    self.action_shape = action_shape
-    self.action_dtype = action_dtype
-
-    with open(os.path.join(FLAGS.frames_path, "frame1.png"), "rb") as f:
-      png_frame_1_raw = f.read()
+    self.action_space = initialization_env.action_space
+    self.action_shape = list(initialization_env.action_space.shape)
+    self.action_dtype = tf.int32
 
-    with open(os.path.join(FLAGS.frames_path, "frame2.png"), "rb") as f:
-      png_frame_2_raw = f.read()
+    obs_1 = initialization_env.reset()
+    obs_2 = initialization_env.step(0)[0]
 
-    self.frame_1 = tf.expand_dims(tf.cast(tf.image.decode_png(png_frame_1_raw),
-                                          tf.float32), 0)
-    self.frame_2 = tf.expand_dims(tf.cast(tf.image.decode_png(png_frame_2_raw),
-                                          tf.float32), 0)
+    self.frame_1 = tf.expand_dims(tf.cast(obs_1, tf.float32), 0)
+    self.frame_2 = tf.expand_dims(tf.cast(obs_2, tf.float32), 0)
 
-    shape = (self.length,) + observ_shape
-    self._observ = tf.Variable(tf.zeros(shape, observ_dtype), trainable=False)
-    self._prev_observ = tf.Variable(tf.zeros(shape, observ_dtype),
+    shape = (self.length,) + initialization_env.observation_space.shape
+    # TODO(blazej0) - make more generic - make higher number of previous observations possible.
+    self._observ = tf.Variable(tf.zeros(shape, tf.float32), trainable=False)
+    self._prev_observ = tf.Variable(tf.zeros(shape, tf.float32),
                                     trainable=False)
-    self._starting_observ = tf.Variable(tf.zeros(shape, observ_dtype),
-                                        trainable=False)
-
-    observ_dtype = tf.int64
-
-  @property
-  def action_space(self):
-    return gym.make("PongNoFrameskip-v4").action_space
 
   def __len__(self):
     """Number of combined environments."""
diff --git a/tensor2tensor/rl/envs/utils.py b/tensor2tensor/rl/envs/utils.py
index 26e12eab7..08e2d07fc 100644
--- a/tensor2tensor/rl/envs/utils.py
+++ b/tensor2tensor/rl/envs/utils.py
@@ -287,7 +287,7 @@ def batch_env_factory(environment_lambda, hparams, num_agents, xvfb=False):
       hparams, "in_graph_wrappers") else []
 
   if hparams.simulated_environment:
-    cur_batch_env = define_simulated_batch_env(num_agents)
+    cur_batch_env = define_simulated_batch_env(environment_lambda, num_agents)
   else:
     cur_batch_env = define_batch_env(environment_lambda, num_agents, xvfb=xvfb)
   for w in wrappers:
@@ -306,12 +306,6 @@ def define_batch_env(constructor, num_agents, xvfb=False):
   return env
 
 
-def define_simulated_batch_env(num_agents):
-  # TODO(blazej0): the parameters should be infered.
-  observ_shape = (210, 160, 3)
-  observ_dtype = tf.float32
-  action_shape = []
-  action_dtype = tf.int32
-  cur_batch_env = simulated_batch_env.SimulatedBatchEnv(
-      num_agents, observ_shape, observ_dtype, action_shape, action_dtype)
+def define_simulated_batch_env(environment_lambda, num_agents):
+  cur_batch_env = simulated_batch_env.SimulatedBatchEnv(environment_lambda, num_agents)
   return cur_batch_env
diff --git a/tensor2tensor/rl/model_rl_experiment.py b/tensor2tensor/rl/model_rl_experiment.py
index 4fd89022e..bfc5075d3 100644
--- a/tensor2tensor/rl/model_rl_experiment.py
+++ b/tensor2tensor/rl/model_rl_experiment.py
@@ -51,7 +51,7 @@ def train(hparams, output_dir):
     time_delta = time.time() - start_time
     print(line+"Step {}.1. - generate data from policy. "
           "Time: {}".format(iloop, str(datetime.timedelta(seconds=time_delta))))
-    FLAGS.problem = "gym_discrete_problem_with_agent"
+    FLAGS.problem = "gym_discrete_problem_with_agent_on_{}".format(hparams.game)
     FLAGS.agent_policy_path = last_model
     gym_problem = registry.problem(FLAGS.problem)
     gym_problem.settable_num_steps = hparams.true_env_generator_num_steps
@@ -76,7 +76,7 @@ def train(hparams, output_dir):
     print(line+"Step {}.3. - evalue env model. "
           "Time: {}".format(iloop, str(datetime.timedelta(seconds=time_delta))))
     gym_simulated_problem = registry.problem(
-        "gym_simulated_discrete_problem_with_agent")
+        "gym_simulated_discrete_problem_with_agent_on_{}".format(hparams.game))
     sim_steps = hparams.simulated_env_generator_num_steps
     gym_simulated_problem.settable_num_steps = sim_steps
     gym_simulated_problem.generate_data(iter_data_dir, tmp_dir)
@@ -115,6 +115,7 @@ def main(_):
      simulated_env_generator_num_steps=300,
      ppo_epochs_num=200,
      ppo_epoch_length=300,
+     game="pong",
   )
   train(hparams, FLAGS.output_dir)
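
The Freeway classes in this patch illustrate the intended pattern: one per-game random-action problem plus thin mix-in subclasses that pair it with the agent and simulated-environment problems, looked up by name in model_rl_experiment.py via hparams.game. As a rough sketch of how a further game could be wired in (this is not part of the patch; the class names, the Gym environment id, and the reward bounds are placeholder assumptions that would need to match the real game), one would add to tensor2tensor/data_generators/gym.py:

# Hypothetical example, not part of this patch. "NewGame" is a placeholder;
# env_name, min_reward and num_rewards must be set for the actual Atari game.
@registry.register_problem
class GymNewGameRandom5k(GymDiscreteProblem):
  """NewGame, random actions."""

  @property
  def env_name(self):
    return "NewGameDeterministic-v4"  # placeholder Gym env id

  @property
  def min_reward(self):
    return 0  # assumed smallest per-step reward

  @property
  def num_rewards(self):
    return 2  # assumed number of distinct per-step reward values

  @property
  def num_steps(self):
    return 5000


@registry.register_problem
class GymDiscreteProblemWithAgentOnNewGame(GymDiscreteProblemWithAgent,
                                           GymNewGameRandom5k):
  pass


@registry.register_problem
class GymSimulatedDiscreteProblemWithAgentOnNewGame(GymSimulatedDiscreteProblemWithAgent,
                                                    GymNewGameRandom5k):
  pass

With those registered, setting game="new_game" in the hparams of model_rl_experiment.py would make the "gym_discrete_problem_with_agent_on_{}" and "gym_simulated_discrete_problem_with_agent_on_{}" lookups resolve to the new problems.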