[RLlib] Fix RNN learning for tf-eager/tf2.x. #11720

Merged
26 changes: 23 additions & 3 deletions rllib/BUILD
@@ -151,6 +151,17 @@ py_test(
args = ["--torch", "--yaml-dir=tuned_examples/ddpg"]
)

# DDPPO
py_test(
name = "run_regression_tests_cartpole_ddppo_torch",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_torch", "learning_tests_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = glob(["tuned_examples/ppo/cartpole-ddppo.yaml"]),
args = ["--yaml-dir=tuned_examples/ppo", "--torch"]
)

# DQN/Simple-Q
py_test(
name = "run_regression_tests_cartpole_dqn_tf",
@@ -1555,7 +1566,7 @@ py_test(
tags = ["examples", "examples_C"],
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--torch", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
args = ["--as-test", "--framework=torch", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
)

py_test(
@@ -1564,7 +1575,16 @@ py_test(
tags = ["examples", "examples_C"],
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
args = ["--as-test", "--framework=tf", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
)

py_test(
name = "examples/cartpole_lstm_ppo_tf2",
main = "examples/cartpole_lstm.py",
tags = ["examples", "examples_C"],
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--framework=tf2", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
)

py_test(
@@ -1573,7 +1593,7 @@ py_test(
tags = ["examples", "examples_C"],
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--torch", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
args = ["--as-test", "--framework=torch", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
)

py_test(
7 changes: 5 additions & 2 deletions rllib/examples/cartpole_lstm.py
@@ -7,7 +7,8 @@
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--torch", action="store_true")
parser.add_argument(
"--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--use-prev-action-reward", action="store_true")
parser.add_argument("--stop-iters", type=int, default=200)
@@ -45,7 +46,9 @@
"use_lstm": True,
"lstm_use_prev_action_reward": args.use_prev_action_reward,
},
"framework": "torch" if args.torch else "tf",
"framework": args.framework,
# Run with tracing enabled for tfe/tf2.
"eager_tracing": args.framework in ["tfe", "tf2"],
})

stop = {
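For reference, a minimal sketch (not part of the diff) of the configuration this example builds when invoked with --framework=tf2; it assumes ray[rllib] and TensorFlow 2.x are installed, and uses "CartPole-v0" as a stand-in for the example's actual environment:

import ray
from ray import tune

config = {
    # Illustrative environment; the real example wires up its own env.
    "env": "CartPole-v0",
    "model": {
        "use_lstm": True,
        "lstm_use_prev_action_reward": False,
    },
    "framework": "tf2",
    # Mirrors the example: tracing is enabled whenever framework is tfe/tf2.
    "eager_tracing": True,
}

if __name__ == "__main__":
    ray.init()
    tune.run("PPO", config=config, stop={"episode_reward_mean": 40})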
1 change: 1 addition & 0 deletions rllib/examples/eager_execution.py
@@ -71,6 +71,7 @@ def compute_penalty(actions, rewards):
"model": {
"custom_model": "eager_model"
},
# Alternatively, use "tf2" here for enforcing TF version 2.x.
"framework": "tfe",
}
stop = {
62 changes: 43 additions & 19 deletions rllib/policy/eager_tf_policy.py
@@ -10,6 +10,7 @@
from ray.util.debug import log_once
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override
@@ -96,15 +97,16 @@ def __init__(self, *args, **kwargs):
self._traced_apply_gradients = None
super(TracedEagerPolicy, self).__init__(*args, **kwargs)

@override(Policy)
@override(eager_policy_cls)
@convert_eager_inputs
@convert_eager_outputs
def learn_on_batch(self, samples):
def _learn_on_batch_eager(self, samples):

if self._traced_learn_on_batch is None:
self._traced_learn_on_batch = tf.function(
super(TracedEagerPolicy, self).learn_on_batch,
autograph=False)
super(TracedEagerPolicy, self)._learn_on_batch_eager,
autograph=False,
experimental_relax_shapes=True)

return self._traced_learn_on_batch(samples)

@@ -130,21 +132,23 @@ def compute_actions(self,
if self._traced_compute_actions is None:
self._traced_compute_actions = tf.function(
super(TracedEagerPolicy, self).compute_actions,
autograph=False)
autograph=False,
experimental_relax_shapes=True)

return self._traced_compute_actions(
obs_batch, state_batches, prev_action_batch, prev_reward_batch,
info_batch, episodes, explore, timestep, **kwargs)

@override(Policy)
@override(eager_policy_cls)
@convert_eager_inputs
@convert_eager_outputs
def compute_gradients(self, samples):
def _compute_gradients_eager(self, samples):

if self._traced_compute_gradients is None:
self._traced_compute_gradients = tf.function(
super(TracedEagerPolicy, self).compute_gradients,
autograph=False)
autograph=False,
experimental_relax_shapes=True)

return self._traced_compute_gradients(samples)

@@ -156,7 +160,8 @@ def apply_gradients(self, grads):
if self._traced_apply_gradients is None:
self._traced_apply_gradients = tf.function(
super(TracedEagerPolicy, self).apply_gradients,
autograph=False)
autograph=False,
experimental_relax_shapes=True)

return self._traced_apply_gradients(grads)

@@ -208,6 +213,12 @@ def __init__(self, observation_space, action_space, config):
self._loss_initialized = False
self._sess = None

self._loss = loss_fn
self.batch_divisibility_req = get_batch_divisibility_req(self) if \
callable(get_batch_divisibility_req) else \
(get_batch_divisibility_req or 1)
self._max_seq_len = config["model"]["max_seq_len"]

if get_default_config:
config = dict(get_default_config(), **config)

@@ -287,18 +298,36 @@ def postprocess_trajectory(self,
return sample_batch

@override(Policy)
def learn_on_batch(self, samples):
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
samples,
shuffle=False,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self.batch_divisibility_req)
return self._learn_on_batch_eager(samples)

@convert_eager_inputs
@convert_eager_outputs
def learn_on_batch(self, samples):
def _learn_on_batch_eager(self, samples):
with tf.variable_creator_scope(_disallow_var_creation):
grads_and_vars, stats = self._compute_gradients(samples)
self._apply_gradients(grads_and_vars)
return stats

@override(Policy)
def compute_gradients(self, samples):
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
samples,
shuffle=False,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self.batch_divisibility_req)
return self._compute_gradients_eager(samples)

@convert_eager_inputs
@convert_eager_outputs
def compute_gradients(self, samples):
def _compute_gradients_eager(self, samples):
with tf.variable_creator_scope(_disallow_var_creation):
grads_and_vars, stats = self._compute_gradients(samples)
grads = [g for g, v in grads_and_vars]
@@ -396,7 +425,8 @@ def compute_actions(self,
extra_fetches.update(extra_action_fetches_fn(self))

# Update our global timestep by the batch size.
self.global_timestep += len(obs_batch)
self.global_timestep += len(obs_batch) if \
isinstance(obs_batch, (tuple, list)) else obs_batch.shape[0]

return actions, state_out, extra_fetches

@@ -554,14 +584,8 @@ def _compute_gradients(self, samples):
state_in.append(samples["state_in_{}".format(i)])
self._state_in = state_in

self._seq_lens = None
if len(state_in) > 0:
self._seq_lens = tf.ones(
samples[SampleBatch.CUR_OBS].shape[0], dtype=tf.int32)
samples["seq_lens"] = self._seq_lens

model_out, _ = self.model(samples, self._state_in,
self._seq_lens)
samples.get("seq_lens"))
loss = loss_fn(self, self.model, self.dist_class, samples)

variables = self.model.trainable_variables()
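The tracing pattern used by the traced policy methods above can be shown with a small standalone sketch; TracedHelper and _eager_fn are hypothetical stand-ins, not RLlib code (assumes TensorFlow 2.x). Each eager helper is wrapped in tf.function exactly once, with experimental_relax_shapes=True so that batches whose shapes vary (e.g. due to RNN padding) do not trigger a full retrace on every call:

import tensorflow as tf


class TracedHelper:
    """Hypothetical stand-in for a traced eager-policy method."""

    def __init__(self):
        self._traced_fn = None

    def _eager_fn(self, batch):
        # Stand-in for the real eager learn_on_batch/compute_gradients body.
        return tf.reduce_mean(batch)

    def __call__(self, batch):
        # Trace lazily on first use, exactly once.
        if self._traced_fn is None:
            self._traced_fn = tf.function(
                self._eager_fn,
                autograph=False,
                experimental_relax_shapes=True)
        return self._traced_fn(batch)


helper = TracedHelper()
helper(tf.ones([4, 3]))   # First call triggers tracing.
helper(tf.ones([16, 3]))  # Relaxed shapes: a new batch size avoids retracing.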
9 changes: 6 additions & 3 deletions rllib/policy/rnn_sequencing.py
@@ -228,20 +228,23 @@ def chop_into_sequences(episode_ids,
if seq_len:
seq_lens.append(seq_len)
assert sum(seq_lens) == len(unique_ids)
seq_lens = np.array(seq_lens)
seq_lens = np.array(seq_lens, dtype=np.int32)

# Dynamically shrink max len as needed to optimize memory usage
if dynamic_max:
max_seq_len = max(seq_lens) + _extra_padding

feature_sequences = []
for f in feature_columns:
f = np.array(f)
# Save unnecessary copy.
if not isinstance(f, np.ndarray):
f = np.array(f)
length = len(seq_lens) * max_seq_len
if f.dtype == np.object or f.dtype.type is np.str_:
f_pad = [None] * length
else:
f_pad = np.zeros((length, ) + np.shape(f)[1:])
# Make sure type doesn't change.
f_pad = np.zeros((length, ) + np.shape(f)[1:], dtype=f.dtype)
seq_base = 0
i = 0
for len_ in seq_lens:
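A small illustration (not part of the PR) of why the padding buffer above now carries dtype=f.dtype: np.zeros defaults to float64, so integer or boolean feature columns would otherwise silently change type during padding:

import numpy as np

f = np.array([[1, 2], [3, 4]], dtype=np.int32)  # e.g. an int32 feature column
length = 4                                      # len(seq_lens) * max_seq_len

padded_old = np.zeros((length, ) + np.shape(f)[1:])                 # float64
padded_new = np.zeros((length, ) + np.shape(f)[1:], dtype=f.dtype)  # int32

print(padded_old.dtype, padded_new.dtype)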
75 changes: 29 additions & 46 deletions rllib/policy/torch_policy.py
@@ -332,6 +332,22 @@ def compute_log_likelihoods(
@DeveloperAPI
def learn_on_batch(
self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
# Compute gradients (will calculate all losses and `backward()`
# them to get the grads).
grads, fetches = self.compute_gradients(postprocessed_batch)

# Step the optimizers.
for i, opt in enumerate(self._optimizers):
opt.step()

if self.model:
fetches["model"] = self.model.metrics()
return fetches

@override(Policy)
@DeveloperAPI
def compute_gradients(self,
postprocessed_batch: SampleBatch) -> ModelGradients:
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
postprocessed_batch,
@@ -341,8 +357,6 @@ def learn_on_batch(
)

train_batch = self._lazy_tensor_dict(postprocessed_batch)

# Calculate the actual policy loss.
loss_out = force_list(
self._loss(self, self.model, self.dist_class, train_batch))

@@ -358,26 +372,30 @@

assert len(loss_out) == len(self._optimizers)

# assert not any(torch.isnan(l) for l in loss_out)
fetches = self.extra_compute_grad_fetches()

# Loop through all optimizers.
grad_info = {"allreduce_latency": 0.0}

all_grads = []
for i, opt in enumerate(self._optimizers):
# Erase gradients in all vars of this optimizer.
opt.zero_grad()
# Recompute gradients of loss over all variables.
loss_out[i].backward(retain_graph=(i < len(self._optimizers) - 1))
grad_info.update(self.extra_grad_process(opt, loss_out[i]))

if self.distributed_world_size:
grads = []
for param_group in opt.param_groups:
for p in param_group["params"]:
if p.grad is not None:
grads.append(p.grad)
grads = []
# Note that return values are just references;
# Calling zero_grad would modify the values.
for param_group in opt.param_groups:
for p in param_group["params"]:
if p.grad is not None:
grads.append(p.grad)
all_grads.append(p.grad.data.cpu().numpy())
else:
all_grads.append(None)

if self.distributed_world_size:
start = time.time()
if torch.cuda.is_available():
# Sadly, allreduce_coalesced does not work with CUDA yet.
@@ -395,45 +413,10 @@

grad_info["allreduce_latency"] += time.time() - start

# Step the optimizers.
for i, opt in enumerate(self._optimizers):
opt.step()

grad_info["allreduce_latency"] /= len(self._optimizers)
grad_info.update(self.extra_grad_info(train_batch))
if self.model:
grad_info["model"] = self.model.metrics()
return dict(fetches, **{LEARNER_STATS_KEY: grad_info})

@override(Policy)
@DeveloperAPI
def compute_gradients(self,
postprocessed_batch: SampleBatch) -> ModelGradients:
train_batch = self._lazy_tensor_dict(postprocessed_batch)
loss_out = force_list(
self._loss(self, self.model, self.dist_class, train_batch))
assert len(loss_out) == len(self._optimizers)
fetches = self.extra_compute_grad_fetches()

grad_process_info = {}
grads = []
for i, opt in enumerate(self._optimizers):
opt.zero_grad()
loss_out[i].backward()
grad_process_info = self.extra_grad_process(opt, loss_out[i])

# Note that return values are just references;
# calling zero_grad will modify the values
for param_group in opt.param_groups:
for p in param_group["params"]:
if p.grad is not None:
grads.append(p.grad.data.cpu().numpy())
else:
grads.append(None)

grad_info = self.extra_grad_info(train_batch)
grad_info.update(grad_process_info)
return grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})

@override(Policy)
@DeveloperAPI
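The control flow established by the refactor above can be summarized with a toy policy; ToyPolicy, its loss, and the batch layout are illustrative only, not the actual TorchPolicy (assumes PyTorch is installed). compute_gradients runs backward() per optimizer and collects the raw gradients, while learn_on_batch reuses it and then steps each optimizer:

import torch


class ToyPolicy:
    def __init__(self):
        self.model = torch.nn.Linear(4, 2)
        self._optimizers = [torch.optim.SGD(self.model.parameters(), lr=0.01)]

    def _loss(self, batch):
        # Stand-in for the real per-optimizer policy loss.
        return self.model(batch["obs"]).pow(2).mean()

    def compute_gradients(self, batch):
        loss_out = [self._loss(batch)]
        assert len(loss_out) == len(self._optimizers)
        all_grads = []
        for i, opt in enumerate(self._optimizers):
            opt.zero_grad()
            loss_out[i].backward()
            # Gradients stay on the parameters; numpy copies go to the caller.
            for group in opt.param_groups:
                for p in group["params"]:
                    all_grads.append(
                        None if p.grad is None else p.grad.data.cpu().numpy())
        return all_grads, {}

    def learn_on_batch(self, batch):
        # Compute gradients (backward() leaves them on the parameters) ...
        grads, fetches = self.compute_gradients(batch)
        # ... then step every optimizer.
        for opt in self._optimizers:
            opt.step()
        return fetches


policy = ToyPolicy()
policy.learn_on_batch({"obs": torch.randn(8, 4)})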