[RLlib] Minor fixes (torch GPU bugs + some cleanup). #11609
Conversation
@@ -331,7 +332,8 @@ def compute_actions(
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
-       self.global_timestep += fetched[0].shape[0]
+       self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
This may be a tensor (eager tracing) or a list.
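A minimal sketch of the distinction (a hypothetical helper, not RLlib's actual code): under eager tracing the batch arrives as a tensor/ndarray carrying `.shape[0]`, otherwise as a plain Python list where only `len()` applies.

```python
import numpy as np

def batch_size(obs_batch):
    # Hypothetical helper mirroring the fix: a plain list in normal
    # execution, a tensor/ndarray under eager tracing.
    if isinstance(obs_batch, list):
        return len(obs_batch)
    return obs_batch.shape[0]

print(batch_size([{"obs": 1}, {"obs": 2}, {"obs": 3}]))  # list -> 3
print(batch_size(np.zeros((4, 2))))                      # array -> 4
```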
@@ -28,7 +28,7 @@ def test_multi_agent_pendulum(self)
            "env": "multi_agent_pendulum",
            "stop": {
                "timesteps_total": 500000,
-               "episode_reward_mean": -300.0,
+               "episode_reward_mean": -400.0,
Make some tests run a little faster.
@@ -187,39 +187,35 @@ def test_batch_ids(self)
    def test_global_vars_update(self):
        # Allow for Unittest run.
        ray.init(num_cpus=5, ignore_reinit_error=True)
-       for fw in framework_iterator(frameworks=()):
+       for fw in framework_iterator(frameworks=("tf2", "tf")):
This was completely deactivated! frameworks=() meant there were no frameworks to iterate over, so the test body never ran.
        agent.stop()

    def test_no_step_on_init(self):
        register_env("fail", lambda _: FailOnStepEnv())
-       for fw in framework_iterator(frameworks=()):
+       for fw in framework_iterator():
same
@@ -80,7 +80,8 @@ def __init__(self,
            behaviour_policy_logits=behaviour_logits,
            target_policy_logits=target_logits,
            actions=tf.unstack(actions, axis=2),
-           discounts=tf.cast(~dones, tf.float32) * discount,
+           discounts=tf.cast(~tf.cast(dones, tf.bool), tf.float32) *
This would fail if dones are floats (0.0 or 1.0).
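To illustrate (using NumPy as a stand-in for TF, since `~` has analogous semantics): bitwise invert is undefined on float dtypes, so `~dones` throws when `dones` arrives as 0.0/1.0 floats; casting to bool before inverting makes it safe. The discount factor 0.99 below is an arbitrary example value.

```python
import numpy as np

dones = np.array([0.0, 1.0, 0.0], dtype=np.float32)  # float-typed dones

# Bitwise invert is undefined on float arrays (same failure mode as in TF):
try:
    _ = ~dones
    invert_ok = True
except TypeError:
    invert_ok = False
print("invert on floats works:", invert_ok)  # False

# The fix: cast to bool first, invert, then cast back to float.
discounts = (~dones.astype(np.bool_)).astype(np.float32) * 0.99
print(discounts)  # zero discount exactly where done == 1.0
```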
        if sess:
            expected_logp = sess.run(expected_logp)
        elif fw == "torch":
            expected_logp = expected_logp.detach().cpu().numpy()
            adv = adv.detach().cpu().numpy()
failed on GPU
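For context, a minimal illustration of the failure mode (a standalone sketch, assuming PyTorch is installed): `.numpy()` raises on a CUDA tensor, and also on any tensor that requires grad, so `.detach().cpu()` must come first. The chain is a cheap no-op on CPU-only runs, so it is safe to apply unconditionally.

```python
import torch

t = torch.tensor([1.0, 2.0], requires_grad=True)
if torch.cuda.is_available():
    t = t.to("cuda")  # becomes a device tensor when a GPU is present

# t.numpy() would raise here (CUDA tensor / requires_grad), so detach
# and move to host memory first; this also works on CPU-only machines.
arr = t.detach().cpu().numpy()
print(arr.tolist())  # [1.0, 2.0]
```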
rllib/offline/json_reader.py (outdated)
@@ -150,6 +150,7 @@ def _from_json(batch: str) -> SampleBatchType:

    if data_type == "SampleBatch":
        for k, v in data.items():
+           print("Trying to unpack {}: {}".format(k, v))  # TODO
Please remove this debug print prior to merging.
Of course. This was added while trying to catch the failing test_marwil/bc.py problem, which is related to a different compression format on the observations and will go into a different PR.
Minor fixes: torch GPU bugs + some cleanup.

Why are these changes needed?

Related issue number

Checks
- I've run scripts/format.sh to lint the changes in this PR.