71 changes: 60 additions & 11 deletions tensor2tensor/data_generators/gym.py
@@ -42,7 +42,7 @@
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("agent_policy_path", "", "File with model for pong")
flags.DEFINE_string("agent_policy_path", "", "File with model for agent")


class GymDiscreteProblem(video_utils.VideoProblem):
@@ -99,6 +99,14 @@ def env(self):
  def num_actions(self):
    return self.env.action_space.n

  @property
  def frame_height(self):
    return self.env.observation_space.shape[0]

  @property
  def frame_width(self):
    return self.env.observation_space.shape[1]

  @property
  def num_rewards(self):
    raise NotImplementedError()
@@ -150,14 +158,6 @@ class GymPongRandom5k(GymDiscreteProblem):
  def env_name(self):
    return "PongDeterministic-v4"

  @property
  def frame_height(self):
    return 210

  @property
  def frame_width(self):
    return 160

  @property
  def min_reward(self):
    return -1
@@ -179,9 +179,38 @@ class GymPongRandom50k(GymPongRandom5k):
  def num_steps(self):
    return 50000

@registry.register_problem
class GymFreewayRandom5k(GymDiscreteProblem):
  """Freeway game, random actions."""

  @property
  def env_name(self):
    return "FreewayDeterministic-v4"

  @property
  def min_reward(self):
    return 0

  @property
  def num_rewards(self):
    return 2

  @property
  def num_steps(self):
    return 5000


@registry.register_problem
class GymFreewayRandom50k(GymFreewayRandom5k):
  """Freeway game, random actions."""

  @property
  def num_steps(self):
    return 50000


@registry.register_problem
class GymDiscreteProblemWithAgent(GymPongRandom5k):
class GymDiscreteProblemWithAgent(GymDiscreteProblem):
  """Gym environment with discrete actions and rewards and an agent."""

  def __init__(self, *args, **kwargs):
@@ -190,7 +219,7 @@ def __init__(self, *args, **kwargs):
    self.debug_dump_frames_path = "debug_frames_env"

    # defaults
    self.environment_spec = lambda: gym.make("PongDeterministic-v4")
    self.environment_spec = lambda: gym.make(self.env_name)
    self.in_graph_wrappers = []
    self.collect_hparams = rl.atari_base()
    self.settable_num_steps = 20000
@@ -286,3 +315,23 @@ def restore_networks(self, sess):
    ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
    ckpt = ckpts.model_checkpoint_path
    env_model_loader.restore(sess, ckpt)


@registry.register_problem
class GymSimulatedDiscreteProblemWithAgentOnPong(GymSimulatedDiscreteProblemWithAgent, GymPongRandom5k):
  pass


@registry.register_problem
class GymDiscreteProblemWithAgentOnPong(GymDiscreteProblemWithAgent, GymPongRandom5k):
  pass


@registry.register_problem
class GymSimulatedDiscreteProblemWithAgentOnFreeway(GymSimulatedDiscreteProblemWithAgent, GymFreewayRandom5k):
  pass


@registry.register_problem
class GymDiscreteProblemWithAgentOnFreeway(GymDiscreteProblemWithAgent, GymFreewayRandom5k):
  pass
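To make the class layout above easier to follow, here is a minimal, self-contained sketch of the mixin composition, assuming hypothetical stand-ins for gym and the T2T registry (FakeSpace, FakeEnv and the simplified class bodies are illustrative, not part of the patch): each per-game class contributes env_name and reward settings, while the shared base now reads frame sizes from the environment's observation space instead of hard-coding them.

```python
# Illustrative sketch only: gym and tensor2tensor's registry are replaced by
# hypothetical stand-ins so the composition pattern can be run in isolation.

class FakeSpace(object):
  def __init__(self, shape):
    self.shape = shape


class FakeEnv(object):
  # Atari-sized frames, like an ALE environment would report.
  observation_space = FakeSpace((210, 160, 3))


class GymDiscreteProblem(object):
  """Base problem: frame sizes are read from the env, not hard-coded."""

  @property
  def env(self):
    return FakeEnv()

  @property
  def frame_height(self):
    return self.env.observation_space.shape[0]

  @property
  def frame_width(self):
    return self.env.observation_space.shape[1]


class GymFreewayRandom5k(GymDiscreteProblem):
  """Per-game class: contributes env_name, reward range and step count."""
  env_name = "FreewayDeterministic-v4"


class GymDiscreteProblemWithAgent(GymDiscreteProblem):
  """Agent class: contributes the data-collection machinery (elided here)."""


class GymDiscreteProblemWithAgentOnFreeway(GymDiscreteProblemWithAgent,
                                           GymFreewayRandom5k):
  pass


problem = GymDiscreteProblemWithAgentOnFreeway()
print(problem.env_name)      # "FreewayDeterministic-v4", from the game class
print(problem.frame_height)  # 210, from the env's observation space
```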
41 changes: 14 additions & 27 deletions tensor2tensor/rl/envs/simulated_batch_env.py
@@ -32,15 +32,13 @@
from tensor2tensor.utils import trainer_lib

import tensorflow as tf
import numpy as np


flags = tf.flags
FLAGS = flags.FLAGS


flags.DEFINE_string("frames_path", "", "Path to the first frames.")


class SimulatedBatchEnv(InGraphBatchEnv):
"""Batch of environments inside the TensorFlow graph.

@@ -49,42 +47,31 @@ class SimulatedBatchEnv(InGraphBatchEnv):
  flags are held in according variables.
  """

  def __init__(self, length, observ_shape, observ_dtype, action_shape,
               action_dtype):
  def __init__(self, environment_lambda, length):
    """Batch of environments inside the TensorFlow graph."""
    self.length = length
    initalization_env = environment_lambda()
    hparams = trainer_lib.create_hparams(
        FLAGS.hparams_set, problem_name=FLAGS.problem, data_dir="UNUSED")
    hparams.force_full_predict = True
    self._model = registry.model(FLAGS.model)(
        hparams, tf.estimator.ModeKeys.PREDICT)

    self.action_shape = action_shape
    self.action_dtype = action_dtype

    with open(os.path.join(FLAGS.frames_path, "frame1.png"), "rb") as f:
      png_frame_1_raw = f.read()
    self.action_space = initalization_env.action_space
    self.action_shape = list(initalization_env.action_space.shape)
    self.action_dtype = tf.int32

    with open(os.path.join(FLAGS.frames_path, "frame2.png"), "rb") as f:
      png_frame_2_raw = f.read()
    obs_1 = initalization_env.reset()
    obs_2 = initalization_env.step(0)[0]

    self.frame_1 = tf.expand_dims(tf.cast(tf.image.decode_png(png_frame_1_raw),
                                          tf.float32), 0)
    self.frame_2 = tf.expand_dims(tf.cast(tf.image.decode_png(png_frame_2_raw),
                                          tf.float32), 0)
    self.frame_1 = tf.expand_dims(tf.cast(obs_1, tf.float32), 0)
    self.frame_2 = tf.expand_dims(tf.cast(obs_2, tf.float32), 0)

    shape = (self.length,) + observ_shape
    self._observ = tf.Variable(tf.zeros(shape, observ_dtype), trainable=False)
    self._prev_observ = tf.Variable(tf.zeros(shape, observ_dtype),
    shape = (self.length,) + initalization_env.observation_space.shape
    # TODO(blazej0) - make more generic - make higher number of previous observations possible.
    self._observ = tf.Variable(tf.zeros(shape, tf.float32), trainable=False)
    self._prev_observ = tf.Variable(tf.zeros(shape, tf.float32),
                                    trainable=False)
    self._starting_observ = tf.Variable(tf.zeros(shape, observ_dtype),
                                        trainable=False)

    observ_dtype = tf.int64

  @property
  def action_space(self):
    return gym.make("PongNoFrameskip-v4").action_space

  def __len__(self):
    """Number of combined environments."""
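The constructor change above can be summarised with a rough, NumPy-only sketch (FakeAtariEnv and SimulatedBatchEnvSketch are hypothetical stand-ins; the real class also builds the world model and TensorFlow variables): the caller passes a lambda rather than shapes and dtypes, the environment is instantiated once, and the action space, observation shape and the two seed frames are inferred from it instead of being read from FLAGS.frames_path PNGs.

```python
import numpy as np


class FakeSpace(object):
  def __init__(self, shape, n=None):
    self.shape = shape
    self.n = n


class FakeAtariEnv(object):
  """Stand-in for gym.make(env_name); only what the sketch needs."""
  observation_space = FakeSpace((210, 160, 3))
  action_space = FakeSpace((), n=6)

  def reset(self):
    return np.zeros(self.observation_space.shape, dtype=np.uint8)

  def step(self, action):
    obs = np.zeros(self.observation_space.shape, dtype=np.uint8)
    return obs, 0.0, False, {}


class SimulatedBatchEnvSketch(object):
  """Shape/dtype/frame inference as in the new __init__, minus the TF model."""

  def __init__(self, environment_lambda, length):
    self.length = length
    init_env = environment_lambda()          # build a real env exactly once
    self.action_space = init_env.action_space
    self.action_shape = list(init_env.action_space.shape)

    # Starting frames now come from the env itself instead of PNG files.
    obs_1 = init_env.reset()
    obs_2 = init_env.step(0)[0]
    self.frame_1 = obs_1[None, ...].astype(np.float32)
    self.frame_2 = obs_2[None, ...].astype(np.float32)

    self.observ_shape = (length,) + init_env.observation_space.shape


env = SimulatedBatchEnvSketch(lambda: FakeAtariEnv(), length=16)
print(env.observ_shape)   # (16, 210, 160, 3)
```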
12 changes: 3 additions & 9 deletions tensor2tensor/rl/envs/utils.py
@@ -287,7 +287,7 @@ def batch_env_factory(environment_lambda, hparams, num_agents, xvfb=False):
      hparams, "in_graph_wrappers") else []

  if hparams.simulated_environment:
    cur_batch_env = define_simulated_batch_env(num_agents)
    cur_batch_env = define_simulated_batch_env(environment_lambda, num_agents)
  else:
    cur_batch_env = define_batch_env(environment_lambda, num_agents, xvfb=xvfb)
  for w in wrappers:
@@ -306,12 +306,6 @@ def define_batch_env(constructor, num_agents, xvfb=False):
    return env


def define_simulated_batch_env(num_agents):
  # TODO(blazej0): the parameters should be infered.
  observ_shape = (210, 160, 3)
  observ_dtype = tf.float32
  action_shape = []
  action_dtype = tf.int32
  cur_batch_env = simulated_batch_env.SimulatedBatchEnv(
      num_agents, observ_shape, observ_dtype, action_shape, action_dtype)
def define_simulated_batch_env(environment_lambda, num_agents):
  cur_batch_env = simulated_batch_env.SimulatedBatchEnv(environment_lambda, num_agents)
  return cur_batch_env
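For context, a sketch of the resulting control flow in batch_env_factory, assuming simplified stand-in classes: both branches now receive the environment lambda, so utils.py no longer hard-codes observation or action shapes for the simulated case.

```python
# Simplified, hypothetical stand-ins for the two batch-env flavours; only the
# constructor signatures matter for this sketch.
class SimulatedBatchEnv(object):
  def __init__(self, environment_lambda, length):
    self.env = environment_lambda()
    self.length = length


class BatchEnv(object):
  def __init__(self, environment_lambda, num_agents):
    self.envs = [environment_lambda() for _ in range(num_agents)]


def batch_env_factory(environment_lambda, simulated, num_agents):
  # Both branches forward the lambda; the simulated env infers shapes itself.
  if simulated:
    return SimulatedBatchEnv(environment_lambda, num_agents)
  return BatchEnv(environment_lambda, num_agents)
```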
5 changes: 3 additions & 2 deletions tensor2tensor/rl/model_rl_experiment.py
@@ -51,7 +51,7 @@ def train(hparams, output_dir):
    time_delta = time.time() - start_time
    print(line+"Step {}.1. - generate data from policy. "
          "Time: {}".format(iloop, str(datetime.timedelta(seconds=time_delta))))
    FLAGS.problem = "gym_discrete_problem_with_agent"
    FLAGS.problem = "gym_discrete_problem_with_agent_on_{}".format(hparams.game)
    FLAGS.agent_policy_path = last_model
    gym_problem = registry.problem(FLAGS.problem)
    gym_problem.settable_num_steps = hparams.true_env_generator_num_steps
@@ -76,7 +76,7 @@ def train(hparams, output_dir):
    print(line+"Step {}.3. - evalue env model. "
          "Time: {}".format(iloop, str(datetime.timedelta(seconds=time_delta))))
    gym_simulated_problem = registry.problem(
        "gym_simulated_discrete_problem_with_agent")
        "gym_simulated_discrete_problem_with_agent_on_{}".format(hparams.game))
    sim_steps = hparams.simulated_env_generator_num_steps
    gym_simulated_problem.settable_num_steps = sim_steps
    gym_simulated_problem.generate_data(iter_data_dir, tmp_dir)
@@ -115,6 +115,7 @@ def main(_):
      simulated_env_generator_num_steps=300,
      ppo_epochs_num=200,
      ppo_epoch_length=300,
      game="pong",
  )
  train(hparams, FLAGS.output_dir)

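Finally, a small sketch of how the new game hparam selects problems, assuming the usual T2T convention that registered class names map to snake_case problem names; the helper below is illustrative, not part of the patch.

```python
def problem_names_for_game(game):
  """Derive the real and simulated problem names used by the training loop."""
  real_env_problem = "gym_discrete_problem_with_agent_on_{}".format(game)
  simulated_problem = (
      "gym_simulated_discrete_problem_with_agent_on_{}".format(game))
  return real_env_problem, simulated_problem


print(problem_names_for_game("pong"))
# ('gym_discrete_problem_with_agent_on_pong',
#  'gym_simulated_discrete_problem_with_agent_on_pong')
print(problem_names_for_game("freeway"))
```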