This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Model-Based RL: Player #1330

Merged
merged 24 commits
Jan 8, 2019
Commits
79fa882
SimulatedEnv with gym-like interface.
konradczechowski Dec 12, 2018
7b77890
Initial Player
konradczechowski Dec 19, 2018
43aaa1d
Player: Add reward header, keyboard reset, CLI, move to player.py
konradczechowski Dec 19, 2018
1b39563
Player: add WAIT mode, a few CLI options.
konradczechowski Dec 20, 2018
ecaffcd
Introduce Policy Inferencer
konradczechowski Dec 21, 2018
f65876e
Recording videos for ppo and player, some refactoring.
konradczechowski Dec 27, 2018
0ac71c4
Player refactor. Add real env recording with PPO agent.
konradczechowski Dec 28, 2018
1c9510d
Extend CLI documentation. Remove some imports.
konradczechowski Dec 28, 2018
fbb5cf5
Pylint
konradczechowski Dec 28, 2018
36efbd3
Extend documentation.
konradczechowski Dec 28, 2018
2bb98ed
Move gym.utils.play to global imports.
konradczechowski Dec 28, 2018
8cf6dd5
Correct dopamine import.
konradczechowski Dec 28, 2018
3da2c6a
Remove SimulatedEnv (unnecessary wrapper for FlatBatchEnv<SimulatedBa…
konradczechowski Jan 3, 2019
b7b7621
Replace join_and_check with os.path.join.
konradczechowski Jan 3, 2019
76242aa
Move generation of initial_frame_chooser function to rl_utils.
konradczechowski Jan 3, 2019
7905d3f
Move make_simulated_env_fn from trainer_model_based.py to rl.py
konradczechowski Jan 3, 2019
fd8ddb9
Remove trainer_model_based imports, clean up player and record_ppo FL…
konradczechowski Jan 4, 2019
ef902eb
Move setup_env and load_t2t_gym_env to T2TGymEnv.
konradczechowski Jan 4, 2019
1c9cdcf
Correct relative imports.
konradczechowski Jan 4, 2019
302ef71
Custom policy world_model and data paths for player.
konradczechowski Jan 7, 2019
ead812e
Enable BatchGymEnv to load directly from checkpoint file.
konradczechowski Jan 7, 2019
eef62c6
Small fix record_ppo.
konradczechowski Jan 7, 2019
5614fd8
Remove unused record_ppo.py.
konradczechowski Jan 7, 2019
dd1a4b5
Merge branch 'master' into player_kc
lukaszkaiser Jan 7, 2019
55 changes: 54 additions & 1 deletion tensor2tensor/data_generators/gym_env.py
@@ -21,7 +21,10 @@

import collections
import itertools
import os
import random
import re

from gym.spaces import Box
import numpy as np

@@ -185,7 +188,10 @@ def current_epoch_rollouts(self, split=None, minimal_rollout_frames=0):
    if not rollouts_by_split:
      if split is not None:
        raise ValueError(
            "generate_data() should first be called in the current epoch"
            "Data is not split into train/dev/test. If the data was created "
            "by environment interaction (NOT loaded from disk), you should "
            "call generate_data() first. Note that generate_data() will "
            "write to disk and can corrupt your experiment data."
        )
      else:
        rollouts = self._current_epoch_rollouts
@@ -636,6 +642,53 @@ def base_env_name(self):
  def num_channels(self):
    return self.observation_space.shape[2]

  @staticmethod
  def infer_last_epoch_num(data_dir):
    """Infer highest epoch number from file names in data_dir."""
    names = os.listdir(data_dir)
    epochs_str = [re.findall(pattern=r".*\.(-?\d+)$", string=name)
                  for name in names]
    epochs_str = sum(epochs_str, [])
    return max([int(epoch_str) for epoch_str in epochs_str])

  @staticmethod
  def setup_env_from_hparams(hparams, batch_size, max_num_noops):
    """Construct a T2TGymEnv for the Atari game specified in hparams."""
    game_mode = "NoFrameskip-v4"
    camel_game_name = misc_utils.snakecase_to_camelcase(hparams.game)
    camel_game_name += game_mode
    env_name = camel_game_name

    env = T2TGymEnv(base_env_name=env_name,
                    batch_size=batch_size,
                    grayscale=hparams.grayscale,
                    resize_width_factor=hparams.resize_width_factor,
                    resize_height_factor=hparams.resize_height_factor,
                    rl_env_max_episode_steps=hparams.rl_env_max_episode_steps,
                    max_num_noops=max_num_noops, maxskip_envs=True)
    return env

  @staticmethod
  def setup_and_load_epoch(hparams, data_dir, which_epoch_data=None):
    """Load T2TGymEnv with data from one epoch.

    Args:
      hparams: hparams of the model-based RL setup.
      data_dir: directory with recorded environment data (rollouts).
      which_epoch_data: data from which epoch to load.
    """
    t2t_env = T2TGymEnv.setup_env_from_hparams(
        hparams, batch_size=hparams.real_batch_size,
        max_num_noops=hparams.max_num_noops
    )
    # Load data.
    if which_epoch_data is not None:
      if which_epoch_data == "last":
        which_epoch_data = T2TGymEnv.infer_last_epoch_num(data_dir)
      assert isinstance(which_epoch_data, int), \
        "{}".format(type(which_epoch_data))
      t2t_env.start_new_epoch(which_epoch_data, data_dir)
    else:
      t2t_env.start_new_epoch(-999)
    return t2t_env

  def _derive_observation_space(self, orig_observ_space):
    height, width, channels = orig_observ_space.shape
    if self.grayscale:
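The helpers added above are what the player uses to rebuild the real environment and reload recorded data: setup_and_load_epoch constructs a T2TGymEnv from hparams and, when which_epoch_data="last", calls infer_last_epoch_num, which scans data_dir for file names ending in ".<epoch_number>" and takes the maximum. A minimal usage sketch (the data_dir path is hypothetical; hparams is assumed to provide game, real_batch_size, max_num_noops and the grayscale/resize fields consumed by setup_env_from_hparams):

# Sketch only: rebuild the real env and load rollouts from the newest epoch.
env = T2TGymEnv.setup_and_load_epoch(
    hparams, data_dir="/tmp/t2t_gym_data", which_epoch_data="last")
obs = env.reset()  # batched observations, batch_size == hparams.real_batch_size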
22 changes: 22 additions & 0 deletions tensor2tensor/models/research/rl.py
@@ -203,6 +203,28 @@ def env_fn(in_graph):
  return env_fn


def make_simulated_env_fn_from_hparams(
    real_env, hparams, batch_size, initial_frame_chooser, model_dir,
    sim_video_dir=None):
  """Creates a simulated env_fn."""
  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  if hparams.wm_policy_param_sharing:
    model_hparams.optimizer_zero_grads = True
  return make_simulated_env_fn(
      reward_range=real_env.reward_range,
      observation_space=real_env.observation_space,
      action_space=real_env.action_space,
      frame_stack_size=hparams.frame_stack_size,
      frame_height=real_env.frame_height, frame_width=real_env.frame_width,
      initial_frame_chooser=initial_frame_chooser, batch_size=batch_size,
      model_name=hparams.generative_model,
      model_hparams=model_hparams,
      model_dir=model_dir,
      intrinsic_reward_scale=hparams.intrinsic_reward_scale,
      sim_video_dir=sim_video_dir,
  )


def get_policy(observations, hparams, action_space):
"""Get a policy network.

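make_simulated_env_fn_from_hparams packages the world-model settings from hparams so that callers such as the player only need the real env and a checkpoint location; the env_fn it returns takes an in_graph flag, as the surrounding context shows. A hedged sketch of out-of-graph use, where real_env, initial_frame_chooser and world_model_dir are placeholders prepared by the caller:

# Sketch only: build a batch-size-1 simulated env backed by the world model.
env_fn = make_simulated_env_fn_from_hparams(
    real_env, hparams, batch_size=1,
    initial_frame_chooser=initial_frame_chooser,
    model_dir=world_model_dir)
sim_env = env_fn(in_graph=False)
flat_env = FlatBatchEnv(sim_env)  # plain gym-style interface (see simulated_batch_gym_env.py below)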
9 changes: 6 additions & 3 deletions tensor2tensor/rl/envs/simulated_batch_env.py
@@ -161,9 +161,12 @@ def initialize(self, sess):
    model_loader = tf.train.Saver(
        var_list=tf.global_variables(scope="next_frame*")  # pylint:disable=unexpected-keyword-arg
    )
    trainer_lib.restore_checkpoint(
        self._model_dir, saver=model_loader, sess=sess, must_restore=True
    )
    if os.path.isdir(self._model_dir):
      trainer_lib.restore_checkpoint(
          self._model_dir, saver=model_loader, sess=sess, must_restore=True
      )
    else:
      model_loader.restore(sess=sess, save_path=self._model_dir)

  def __str__(self):
    return "SimulatedEnv"
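With this change, _model_dir can be either a checkpoint directory, in which case trainer_lib.restore_checkpoint restores the latest checkpoint inside it, or a path to a specific checkpoint prefix, which is handed directly to Saver.restore. Both forms below would now be accepted (paths are hypothetical):

model_dir = "/tmp/world_model"                    # directory: latest checkpoint inside it is restored
model_dir = "/tmp/world_model/model.ckpt-250000"  # checkpoint prefix: restored directly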
3 changes: 1 addition & 2 deletions tensor2tensor/rl/envs/simulated_batch_gym_env.py
@@ -26,8 +26,7 @@


class FlatBatchEnv(Env):
  """TODO(konradczechowski): Add doc-string."""

  """Gym environment interface for batched environments (with batch size = 1)."""
  def __init__(self, batch_env):
    if batch_env.batch_size != 1:
      raise ValueError("Number of environments in batch must be equal to one")
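For intuition, a minimal sketch of what a flattening wrapper like this does, assuming the wrapped batch env exposes batch_size, action_space and observation_space, and that its step returns batched (observations, rewards, dones); this is an illustration, not the implementation in this file:

from gym import Env

class FlatBatchEnvSketch(Env):
  """Illustrative only: strips the batch dimension from a batch_size == 1 env."""

  def __init__(self, batch_env):
    if batch_env.batch_size != 1:
      raise ValueError("Number of environments in batch must be equal to one")
    self.batch_env = batch_env
    self.action_space = batch_env.action_space
    self.observation_space = batch_env.observation_space

  def reset(self):
    return self.batch_env.reset()[0]

  def step(self, action):
    # The wrapped env consumes and produces batches; unwrap index 0.
    obs, rewards, dones = self.batch_env.step([action])
    return obs[0], rewards[0], dones[0], {}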