[RLlib] Fix RNN learning for tf-eager/tf2.x. (#11720)
sven1977 committed Nov 2, 2020
1 parent bfc4f95 commit 54d85a6
Showing 6 changed files with 107 additions and 73 deletions.
26 changes: 23 additions & 3 deletions rllib/BUILD
@@ -151,6 +151,17 @@ py_test(
args = ["--torch", "--yaml-dir=tuned_examples/ddpg"]
)

# DDPPO
py_test(
name = "run_regression_tests_cartpole_ddppo_torch",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_torch", "learning_tests_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = glob(["tuned_examples/ppo/cartpole-ddppo.yaml"]),
args = ["--yaml-dir=tuned_examples/ppo", "--torch"]
)

# DQN/Simple-Q
py_test(
name = "run_regression_tests_cartpole_dqn_tf",
@@ -1555,7 +1566,7 @@ py_test(
tags = ["examples", "examples_C"],
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--torch", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
args = ["--as-test", "--framework=torch", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
)

py_test(
@@ -1564,7 +1575,16 @@
tags = ["examples", "examples_C"],
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
args = ["--as-test", "--framework=tf", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
)

py_test(
name = "examples/cartpole_lstm_ppo_tf2",
main = "examples/cartpole_lstm.py",
tags = ["examples", "examples_C"],
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--framework=tf2", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
)

py_test(
@@ -1573,7 +1593,7 @@
tags = ["examples", "examples_C"],
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--torch", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
args = ["--as-test", "--framework=torch", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
)

py_test(
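
For reference, the new examples/cartpole_lstm_ppo_tf2 target simply runs the example script with the flags listed under args. A rough local equivalent outside Bazel (path assumed relative to the repo root) could look like this:

import subprocess
import sys

# Mirror the Bazel target's `args` for a quick local smoke test.
subprocess.run(
    [
        sys.executable, "rllib/examples/cartpole_lstm.py",
        "--as-test", "--framework=tf2", "--run=PPO",
        "--stop-reward=40", "--num-cpus=4",
    ],
    check=True)
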
7 changes: 5 additions & 2 deletions rllib/examples/cartpole_lstm.py
@@ -7,7 +7,8 @@
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--torch", action="store_true")
parser.add_argument(
"--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--use-prev-action-reward", action="store_true")
parser.add_argument("--stop-iters", type=int, default=200)
@@ -45,7 +46,9 @@
"use_lstm": True,
"lstm_use_prev_action_reward": args.use_prev_action_reward,
},
"framework": "torch" if args.torch else "tf",
"framework": args.framework,
# Run with tracing enabled for tfe/tf2.
"eager_tracing": args.framework in ["tfe", "tf2"],
})

stop = {
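
Taken together, the two hunks above wire the CLI flag straight into the trainer config. A condensed sketch of the relevant part (the environment is simplified to plain CartPole-v0 here; the real script sets up its own stateless CartPole variant and stop criteria):

import argparse

import ray
from ray import tune

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument(
    "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
args = parser.parse_args()

config = {
    "env": "CartPole-v0",  # simplified stand-in for the example's env
    "model": {"use_lstm": True},
    "framework": args.framework,
    # Tracing only applies to the eager TF frameworks.
    "eager_tracing": args.framework in ["tfe", "tf2"],
}

ray.init()
tune.run(args.run, config=config, stop={"episode_reward_mean": 40})
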
1 change: 1 addition & 0 deletions rllib/examples/eager_execution.py
@@ -71,6 +71,7 @@ def compute_penalty(actions, rewards):
"model": {
"custom_model": "eager_model"
},
# Alternatively, use "tf2" here for enforcing TF version 2.x.
"framework": "tfe",
}
stop = {
62 changes: 43 additions & 19 deletions rllib/policy/eager_tf_policy.py
@@ -10,6 +10,7 @@
from ray.util.debug import log_once
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override
@@ -96,15 +97,16 @@ def __init__(self, *args, **kwargs):
self._traced_apply_gradients = None
super(TracedEagerPolicy, self).__init__(*args, **kwargs)

@override(Policy)
@override(eager_policy_cls)
@convert_eager_inputs
@convert_eager_outputs
def learn_on_batch(self, samples):
def _learn_on_batch_eager(self, samples):

if self._traced_learn_on_batch is None:
self._traced_learn_on_batch = tf.function(
super(TracedEagerPolicy, self).learn_on_batch,
autograph=False)
super(TracedEagerPolicy, self)._learn_on_batch_eager,
autograph=False,
experimental_relax_shapes=True)

return self._traced_learn_on_batch(samples)
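
Each traced method in this class now also passes experimental_relax_shapes=True, so differently padded RNN batches reuse a relaxed trace instead of forcing a retrace for every new input shape. A standalone illustration of that tf.function argument (plain TensorFlow, not RLlib code):

import tensorflow as tf

@tf.function(autograph=False, experimental_relax_shapes=True)
def batch_sum(x):
    # With relaxed shapes, TF generalizes the traced signature to unknown
    # dimensions instead of retracing for each new (padded) batch length.
    return tf.reduce_sum(x)

print(batch_sum(tf.ones([4, 3])))
print(batch_sum(tf.ones([7, 3])))  # reuses the relaxed concrete function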

@@ -130,21 +132,23 @@ def compute_actions(self,
if self._traced_compute_actions is None:
self._traced_compute_actions = tf.function(
super(TracedEagerPolicy, self).compute_actions,
autograph=False)
autograph=False,
experimental_relax_shapes=True)

return self._traced_compute_actions(
obs_batch, state_batches, prev_action_batch, prev_reward_batch,
info_batch, episodes, explore, timestep, **kwargs)

@override(Policy)
@override(eager_policy_cls)
@convert_eager_inputs
@convert_eager_outputs
def compute_gradients(self, samples):
def _compute_gradients_eager(self, samples):

if self._traced_compute_gradients is None:
self._traced_compute_gradients = tf.function(
super(TracedEagerPolicy, self).compute_gradients,
autograph=False)
autograph=False,
experimental_relax_shapes=True)

return self._traced_compute_gradients(samples)

@@ -156,7 +160,8 @@ def apply_gradients(self, grads):
if self._traced_apply_gradients is None:
self._traced_apply_gradients = tf.function(
super(TracedEagerPolicy, self).apply_gradients,
autograph=False)
autograph=False,
experimental_relax_shapes=True)

return self._traced_apply_gradients(grads)

@@ -208,6 +213,12 @@ def __init__(self, observation_space, action_space, config):
self._loss_initialized = False
self._sess = None

self._loss = loss_fn
self.batch_divisibility_req = get_batch_divisibility_req(self) if \
callable(get_batch_divisibility_req) else \
(get_batch_divisibility_req or 1)
self._max_seq_len = config["model"]["max_seq_len"]

if get_default_config:
config = dict(get_default_config(), **config)

@@ -287,18 +298,36 @@ def postprocess_trajectory(self,
return sample_batch

@override(Policy)
def learn_on_batch(self, samples):
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
samples,
shuffle=False,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self.batch_divisibility_req)
return self._learn_on_batch_eager(samples)

@convert_eager_inputs
@convert_eager_outputs
def learn_on_batch(self, samples):
def _learn_on_batch_eager(self, samples):
with tf.variable_creator_scope(_disallow_var_creation):
grads_and_vars, stats = self._compute_gradients(samples)
self._apply_gradients(grads_and_vars)
return stats

@override(Policy)
def compute_gradients(self, samples):
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
samples,
shuffle=False,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self.batch_divisibility_req)
return self._compute_gradients_eager(samples)

@convert_eager_inputs
@convert_eager_outputs
def compute_gradients(self, samples):
def _compute_gradients_eager(self, samples):
with tf.variable_creator_scope(_disallow_var_creation):
grads_and_vars, stats = self._compute_gradients(samples)
grads = [g for g, v in grads_and_vars]
@@ -396,7 +425,8 @@ def compute_actions(self,
extra_fetches.update(extra_action_fetches_fn(self))

# Update our global timestep by the batch size.
self.global_timestep += len(obs_batch)
self.global_timestep += len(obs_batch) if \
isinstance(obs_batch, (tuple, list)) else obs_batch.shape[0]

return actions, state_out, extra_fetches

@@ -554,14 +584,8 @@ def _compute_gradients(self, samples):
state_in.append(samples["state_in_{}".format(i)])
self._state_in = state_in

self._seq_lens = None
if len(state_in) > 0:
self._seq_lens = tf.ones(
samples[SampleBatch.CUR_OBS].shape[0], dtype=tf.int32)
samples["seq_lens"] = self._seq_lens

model_out, _ = self.model(samples, self._state_in,
self._seq_lens)
samples.get("seq_lens"))
loss = loss_fn(self, self.model, self.dist_class, samples)

variables = self.model.trainable_variables()
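
The net effect in this file: the public learn_on_batch and compute_gradients entry points now pad the batch into same-size sequences (adding seq_lens) before handing it to the _*_eager variants, which may in turn be wrapped in tf.function by TracedEagerPolicy. A self-contained toy sketch of that pad-outside/trace-inside pattern (class name and the "stats" are made up for illustration):

import tensorflow as tf

class TracedLearnerSketch:
    """Toy stand-in for the pattern used by the eager TF policy."""

    def __init__(self):
        self._traced = None

    def learn_on_batch(self, samples):
        # 1) Non-TF preprocessing happens here (in RLlib:
        #    pad_batch_to_sequences_of_same_size on the numpy batch).
        samples = {k: tf.convert_to_tensor(v) for k, v in samples.items()}
        # 2) Lazily wrap the eager implementation in tf.function on first use.
        if self._traced is None:
            self._traced = tf.function(
                self._learn_on_batch_eager,
                autograph=False,
                experimental_relax_shapes=True)
        return self._traced(samples)

    def _learn_on_batch_eager(self, samples):
        # Stand-in "stats": mean over the (padded) observation column.
        return {"mean_obs": tf.reduce_mean(samples["obs"])}

sketch = TracedLearnerSketch()
print(sketch.learn_on_batch({"obs": [[1.0, 2.0], [3.0, 4.0]]}))
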
9 changes: 6 additions & 3 deletions rllib/policy/rnn_sequencing.py
@@ -228,20 +228,23 @@ def chop_into_sequences(episode_ids,
if seq_len:
seq_lens.append(seq_len)
assert sum(seq_lens) == len(unique_ids)
seq_lens = np.array(seq_lens)
seq_lens = np.array(seq_lens, dtype=np.int32)

# Dynamically shrink max len as needed to optimize memory usage
if dynamic_max:
max_seq_len = max(seq_lens) + _extra_padding

feature_sequences = []
for f in feature_columns:
f = np.array(f)
# Save unnecessary copy.
if not isinstance(f, np.ndarray):
f = np.array(f)
length = len(seq_lens) * max_seq_len
if f.dtype == np.object or f.dtype.type is np.str_:
f_pad = [None] * length
else:
f_pad = np.zeros((length, ) + np.shape(f)[1:])
# Make sure type doesn't change.
f_pad = np.zeros((length, ) + np.shape(f)[1:], dtype=f.dtype)
seq_base = 0
i = 0
for len_ in seq_lens:
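
The rnn_sequencing changes avoid an unnecessary copy for columns that already arrive as numpy arrays and keep the padded buffer in the feature's own dtype instead of numpy's float64 default. The dtype part in isolation, as a small numpy sketch:

import numpy as np

f = np.array([[1, 2], [3, 4]], dtype=np.float32)  # e.g. one feature column
length = 6  # num_sequences * max_seq_len after padding

# Old behavior: np.zeros defaults to float64, silently upcasting the column.
f_pad_old = np.zeros((length, ) + np.shape(f)[1:])
# New behavior: the padded buffer keeps the incoming dtype.
f_pad_new = np.zeros((length, ) + np.shape(f)[1:], dtype=f.dtype)

print(f_pad_old.dtype, f_pad_new.dtype)  # float64 vs. float32
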
75 changes: 29 additions & 46 deletions rllib/policy/torch_policy.py
@@ -332,6 +332,22 @@ def compute_log_likelihoods(
@DeveloperAPI
def learn_on_batch(
self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
# Compute gradients (will calculate all losses and `backward()`
# them to get the grads).
grads, fetches = self.compute_gradients(postprocessed_batch)

# Step the optimizers.
for i, opt in enumerate(self._optimizers):
opt.step()

if self.model:
fetches["model"] = self.model.metrics()
return fetches

@override(Policy)
@DeveloperAPI
def compute_gradients(self,
postprocessed_batch: SampleBatch) -> ModelGradients:
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
postprocessed_batch,
@@ -341,8 +357,6 @@ def learn_on_batch(
)

train_batch = self._lazy_tensor_dict(postprocessed_batch)

# Calculate the actual policy loss.
loss_out = force_list(
self._loss(self, self.model, self.dist_class, train_batch))

@@ -358,26 +372,30 @@

assert len(loss_out) == len(self._optimizers)

# assert not any(torch.isnan(l) for l in loss_out)
fetches = self.extra_compute_grad_fetches()

# Loop through all optimizers.
grad_info = {"allreduce_latency": 0.0}

all_grads = []
for i, opt in enumerate(self._optimizers):
# Erase gradients in all vars of this optimizer.
opt.zero_grad()
# Recompute gradients of loss over all variables.
loss_out[i].backward(retain_graph=(i < len(self._optimizers) - 1))
grad_info.update(self.extra_grad_process(opt, loss_out[i]))

if self.distributed_world_size:
grads = []
for param_group in opt.param_groups:
for p in param_group["params"]:
if p.grad is not None:
grads.append(p.grad)
grads = []
# Note that return values are just references;
# Calling zero_grad would modify the values.
for param_group in opt.param_groups:
for p in param_group["params"]:
if p.grad is not None:
grads.append(p.grad)
all_grads.append(p.grad.data.cpu().numpy())
else:
all_grads.append(None)

if self.distributed_world_size:
start = time.time()
if torch.cuda.is_available():
# Sadly, allreduce_coalesced does not work with CUDA yet.
@@ -395,45 +413,10 @@ def learn_on_batch(

grad_info["allreduce_latency"] += time.time() - start

# Step the optimizers.
for i, opt in enumerate(self._optimizers):
opt.step()

grad_info["allreduce_latency"] /= len(self._optimizers)
grad_info.update(self.extra_grad_info(train_batch))
if self.model:
grad_info["model"] = self.model.metrics()
return dict(fetches, **{LEARNER_STATS_KEY: grad_info})

@override(Policy)
@DeveloperAPI
def compute_gradients(self,
postprocessed_batch: SampleBatch) -> ModelGradients:
train_batch = self._lazy_tensor_dict(postprocessed_batch)
loss_out = force_list(
self._loss(self, self.model, self.dist_class, train_batch))
assert len(loss_out) == len(self._optimizers)
fetches = self.extra_compute_grad_fetches()

grad_process_info = {}
grads = []
for i, opt in enumerate(self._optimizers):
opt.zero_grad()
loss_out[i].backward()
grad_process_info = self.extra_grad_process(opt, loss_out[i])

# Note that return values are just references;
# calling zero_grad will modify the values
for param_group in opt.param_groups:
for p in param_group["params"]:
if p.grad is not None:
grads.append(p.grad.data.cpu().numpy())
else:
grads.append(None)

grad_info = self.extra_grad_info(train_batch)
grad_info.update(grad_process_info)
return grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})

@override(Policy)
@DeveloperAPI
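
On the torch side, the refactor removes the duplicated loss/backward code: learn_on_batch is now just compute_gradients (which also handles RNN padding and the optional allreduce) followed by an optimizer step. A condensed, self-contained sketch of that control flow, with a single parameter and a made-up loss standing in for the policy internals:

import torch

class TorchLearnerSketch:
    """Toy analogue of TorchPolicy.learn_on_batch after this refactor."""

    def __init__(self):
        self.w = torch.nn.Parameter(torch.zeros(2))
        self._optimizers = [torch.optim.SGD([self.w], lr=0.1)]

    def compute_gradients(self, batch):
        # In RLlib this is where padding, the loss fn and backward() live.
        obs = torch.as_tensor(batch["obs"])
        loss = ((obs @ self.w - 1.0) ** 2).mean()
        for opt in self._optimizers:
            opt.zero_grad()
        loss.backward()
        grads = [p.grad.data.cpu().numpy() if p.grad is not None else None
                 for p in [self.w]]
        return grads, {"learner_stats": {"loss": float(loss)}}

    def learn_on_batch(self, batch):
        # Gradients are computed once; each optimizer then just steps.
        _, fetches = self.compute_gradients(batch)
        for opt in self._optimizers:
            opt.step()
        return fetches

sketch = TorchLearnerSketch()
print(sketch.learn_on_batch({"obs": [[1.0, 2.0], [3.0, 4.0]]}))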
