[RLlib] PPO torch memory leak and unnecessary torch.Tensor creation and gc'ing. #7238
Conversation
…torch_memory_leak # Conflicts: # rllib/agents/ppo/tests/test_ppo.py
…torch_memory_leak
…torch_memory_leak # Conflicts: # rllib/agents/ppo/ppo_torch_policy.py
…torch_memory_leak # Conflicts: # rllib/agents/ppo/ppo_torch_policy.py
Test FAILed.
rllib/agents/a3c/a3c_torch_policy.py (Outdated)

      sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
      policy.config["use_gae"], policy.config["use_critic"])
    + with torch.no_grad():
Could we actually move this into the postprocess method defined in torch_policy_template? That way it will work automatically for all torch policies and we don't need to clutter the individual ones.
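For illustration, a minimal sketch of what that could look like in the template's postprocessing hook. The postprocess_fn attribute and the hook signature here are assumptions in the spirit of RLlib's build_torch_policy, not the exact code in this PR:

    import torch

    def postprocess(policy, sample_batch, other_agent_batches=None,
                    episode=None):
        # If the algorithm registered no postprocessor, pass the batch through.
        if policy.postprocess_fn is None:
            return sample_batch
        # Postprocessing (e.g. GAE via compute_advantages) only reads the
        # model; running it under no_grad() keeps autograd from retaining
        # the value-function graph and leaking memory.
        with torch.no_grad():
            return policy.postprocess_fn(
                policy, sample_batch, other_agent_batches, episode)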
Sure, will do.
rllib/agents/ppo/ppo_torch_policy.py (Outdated)

    - "total_loss": policy.loss_obj.loss.cpu().detach().numpy(),
    - "policy_loss": policy.loss_obj.mean_policy_loss.cpu().detach().numpy(),
    - "vf_loss": policy.loss_obj.mean_vf_loss.cpu().detach().numpy(),
    + "total_loss": policy.loss_obj.loss.item(),
Similarly, could we automatically apply .item() to the dict values returned in the common torch policy class?
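A sketch of such a conversion helper; convert_stats is a hypothetical name, and the real template may structure this differently:

    import torch

    def convert_stats(stats):
        # Replace every torch.Tensor in the stats dict with a plain Python
        # scalar (or NumPy array), so logging holds no references to the
        # autograd graph or to GPU memory.
        converted = {}
        for key, value in stats.items():
            if isinstance(value, torch.Tensor):
                converted[key] = (value.item() if value.numel() == 1
                                  else value.detach().cpu().numpy())
            else:
                converted[key] = value
        return converted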
All done. ... Waiting for tests.
The main question here is whether we can automatically insert these conversions in the template, to avoid having to do this for each algo (which could be brittle).
…torch_memory_leak
Test FAILed.
Test FAILed.
@ericl Everything is handled by the template now, which also does the numpy/item conversion AND takes care of the no_grad. Individual TorchPolicies don't have to worry about this anymore.
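As a rough illustration of how the template could centralize both pieces; names like stats_fn and extra_grad_info follow RLlib's policy-template conventions but are assumptions here, not verbatim from this PR:

    import torch

    def extra_grad_info(policy, train_batch, stats_fn):
        # Run the per-algorithm stats_fn, then convert any scalar tensors
        # with .item(), so individual policies can return raw tensors
        # without leaking graph references into the metrics dict.
        stats = stats_fn(policy, train_batch) if stats_fn else {}
        return {
            key: (value.item()
                  if isinstance(value, torch.Tensor) and value.numel() == 1
                  else value)
            for key, value in stats.items()
        }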
Test FAILed.
Test PASSed.
Test FAILed.
Test FAILed.
Test FAILed.
Looks great!
rllib/agents/a3c/a3c_torch_policy.py (Outdated)

    - completed = sample_batch[SampleBatch.DONES][-1]
    - if completed:
    + if sample_batch[SampleBatch.DONES][-1]:
Prefer to use intermediate variables to clarify the value of a long expression when possible.
Possibly related to the changes?
I'll make custom_tf_policy size=medium and see whether that fixes it. I've changed back to the intermediate var.
Test FAILed.
@ericl Tests all pass. Please merge.
PPO torch has a memory leak due to a missing torch.no_grad() around compute_advantages.

PPO torch also produces lots of intermediary Tensors (which are then garbage collected) when run on CPU or GPU. This is due to the reporting code creating volatile CPU Tensors and then numpy'ing them.
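A minimal sketch of the fix pattern, assuming the compute_advantages helper and config keys shown in the a3c diff above; the helper is passed in as an argument to keep the sketch self-contained:

    import torch

    def postprocess_with_gae(policy, sample_batch, last_r, compute_advantages):
        # Without no_grad(), compute_advantages' value-function forward
        # passes stay attached to the autograd graph and are never freed,
        # which is the leak described above.
        with torch.no_grad():
            return compute_advantages(
                sample_batch, last_r,
                policy.config["gamma"], policy.config["lambda"],
                policy.config["use_gae"], policy.config["use_critic"])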
Closes #6962
I've run scripts/format.sh to lint the changes in this PR.