Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rllib] Fix LSTM regression on truncated sequences and add regression test #2898

Merged
merged 7 commits into from
Sep 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/ray/rllib/env/atari_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@


def is_atari(env):
    """Return True if ``env`` appears to be an image-based Atari (ALE) env.

    Environments whose observation space is at most 2-dimensional (e.g. a
    flat RAM vector such as Pong-ram) are never treated as Atari image
    environments, even if they are backed by the ALE.
    """
    obs_shape = getattr(env.observation_space, "shape", None)
    if obs_shape is not None and len(obs_shape) <= 2:
        return False
    # Image-shaped (or shapeless) space: fall back to probing for the ALE.
    return hasattr(env, "unwrapped") and hasattr(env.unwrapped, "ale")


Expand Down
179 changes: 179 additions & 0 deletions python/ray/rllib/examples/cartpole_lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""Partially observed variant of the CartPole gym environment.

https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py

We delete the velocity component of the state, so that it can only be solved
by a LSTM policy."""

import argparse
import math
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np

# Command-line options for the example training script below.
parser = argparse.ArgumentParser()
# Mean episode reward at which the tune experiment stops.
parser.add_argument("--stop", type=int, default=200)


class CartPoleStatelessEnv(gym.Env):
    """Partially observed CartPole: velocities are hidden from the agent.

    The dynamics are identical to gym's classic CartPole, but ``step`` and
    ``reset`` return only the cart position and pole angle (state components
    0 and 2). Without the velocity terms a memoryless policy cannot solve
    the task, so this environment exercises recurrent (LSTM) policies.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 60
    }

    def __init__(self):
        # Physical constants of the cart-pole system.
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = (self.masspole + self.masscart)
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = (self.masspole * self.length)
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates

        # Angle at which to fail the episode
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # Observation bounds cover only the two OBSERVED components
        # (cart x and pole theta); velocities are deliberately omitted.
        high = np.array([
            self.x_threshold * 2,
            self.theta_threshold_radians * 2,
        ])

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(-high, high)

        self.seed()
        self.viewer = None
        # Full internal state tuple (x, x_dot, theta, theta_dot); only a
        # projection of it is ever returned to the agent.
        self.state = None

        # Counts extra steps taken after the episode has ended.
        self.steps_beyond_done = None

    def seed(self, seed=None):
        """Seed the environment RNG; returns the list of seeds used."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, action):
        """Advance one timestep by Euler integration of the pole dynamics.

        Args:
            action: 0 (push left) or 1 (push right).

        Returns:
            Tuple of (obs, reward, done, info) where obs is the 2-element
            array [cart position, pole angle] — velocities are hidden.
        """
        assert self.action_space.contains(
            action), "%r (%s) invalid" % (action, type(action))
        state = self.state
        x, x_dot, theta, theta_dot = state
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)
        # Standard cart-pole equations of motion (as in gym's CartPole).
        temp = (force + self.polemass_length * theta_dot * theta_dot * sintheta
                ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length *
            (4.0 / 3.0 - self.masspole * costheta * costheta / self.total_mass)
        )
        xacc = (temp -
                self.polemass_length * thetaacc * costheta / self.total_mass)
        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc
        self.state = (x, x_dot, theta, theta_dot)
        # Episode ends when the cart leaves the track or the pole tips
        # past the failure angle.
        done = (x < -self.x_threshold or x > self.x_threshold
                or theta < -self.theta_threshold_radians
                or theta > self.theta_threshold_radians)
        done = bool(done)

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            # Stepping past termination yields zero reward.
            self.steps_beyond_done += 1
            reward = 0.0

        # Observation hides velocities: expose only x and theta.
        rv = np.r_[self.state[0], self.state[2]]
        return rv, reward, done, {}

    def reset(self):
        """Reset to a small random state; return the partial observation."""
        self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4, ))
        self.steps_beyond_done = None

        # Same projection as step(): only position and angle are observed.
        rv = np.r_[self.state[0], self.state[2]]
        return rv

    def render(self, mode='human'):
        """Draw the cart and pole using gym's classic_control renderer.

        Returns an RGB array when mode == 'rgb_array', otherwise renders
        to a viewer window. Returns None if called before reset().
        """
        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        carty = 100  # TOP OF CART
        polewidth = 10.0
        polelen = scale * 1.0
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            # Lazily build the viewer and its geometry on first render.
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)
            l, r, t, b = (-cartwidth / 2, cartwidth / 2, cartheight / 2,
                          -cartheight / 2)
            axleoffset = cartheight / 4.0
            cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)
            l, r, t, b = (-polewidth / 2, polewidth / 2,
                          polelen - polewidth / 2, -polewidth / 2)
            pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            pole.set_color(.8, .6, .4)
            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)
            self.axle = rendering.make_circle(polewidth / 2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(.5, .5, .8)
            self.viewer.add_geom(self.axle)
            self.track = rendering.Line((0, carty), (screen_width, carty))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

        if self.state is None:
            return None

        x = self.state
        cartx = x[0] * scale + screen_width / 2.0  # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-x[2])

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def close(self):
        """Close the render window, if one was opened."""
        if self.viewer:
            self.viewer.close()


if __name__ == "__main__":
    # Imported here so merely importing this module (e.g. for the env
    # class) does not require ray to be installed.
    import ray
    from ray import tune

    args = parser.parse_args()

    tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv())

    ray.init()
    # Train policy-gradient with an LSTM model until the mean episode
    # reward reaches the --stop threshold.
    experiment_spec = {
        "env": "cartpole_stateless",
        "run": "PG",
        "stop": {"episode_reward_mean": args.stop},
        "config": {"model": {"use_lstm": True}},
    }
    tune.run_experiments({"test": experiment_spec})
2 changes: 2 additions & 0 deletions python/ray/rllib/models/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,11 @@ def _build_layers(self, inputs, num_outputs, options):
self.state_in = [c_in, h_in]

# Setup LSTM outputs
state_in = rnn.LSTMStateTuple(c_in, h_in)
lstm_out, lstm_state = tf.nn.dynamic_rnn(
lstm,
last_layer,
initial_state=state_in,
sequence_length=self.seq_lens,
time_major=False,
dtype=tf.float32)
Expand Down
4 changes: 2 additions & 2 deletions python/ray/tune/visual_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _parse_results(res_path):
for line in f:
pass
res_dict = _flatten_dict(json.loads(line.strip()))
except Exception as e:
except Exception:
logger.exception("Importing %s failed...Perhaps empty?" % res_path)
return res_dict

Expand All @@ -45,7 +45,7 @@ def _parse_configs(cfg_path):
try:
with open(cfg_path) as f:
cfg_dict = _flatten_dict(json.load(f))
except Exception as e:
except Exception:
logger.exception("Config parsing failed.")
return cfg_dict

Expand Down
10 changes: 10 additions & 0 deletions test/jenkins_tests/run_multi_node_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
--stop '{"training_iteration": 2}' \
--config '{"num_workers": 2}'

docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env Pong-ram-v4 \
--run A3C \
--stop '{"training_iteration": 2}' \
--config '{"num_workers": 2}'

docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env PongDeterministic-v0 \
Expand Down Expand Up @@ -293,6 +300,9 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/examples/multiagent_two_trainers.py --num-iters=2

docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/examples/cartpole_lstm.py --stop=75

# No Xray for PyTorch
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
Expand Down