Skip to content
Permalink
Browse files

[rllib] Implement learn_on_batch() in torch policy graph

  • Loading branch information...
ericl committed May 13, 2019
1 parent f3b8b90 commit 69352e3302d1a8eeff594f4ee73c858a874df2b4
Showing with 18 additions and 0 deletions.
  1. +18 −0 python/ray/rllib/evaluation/torch_policy_graph.py
@@ -85,6 +85,24 @@ def compute_actions(self,
[h.cpu().numpy() for h in state],
self.extra_action_out(model_out))

@override(PolicyGraph)
def learn_on_batch(self, postprocessed_batch):
with self.lock:
loss_in = []
for key in self._loss_inputs:
loss_in.append(
torch.from_numpy(postprocessed_batch[key]).to(self.device))
loss_out = self._loss(self._model, *loss_in)
self._optimizer.zero_grad()
loss_out.backward()

grad_process_info = self.extra_grad_process()
self._optimizer.step()

grad_info = self.extra_grad_info()
grad_info.update(grad_process_info)
return {LEARNER_STATS_KEY: grad_info}

@override(PolicyGraph)
def compute_gradients(self, postprocessed_batch):
with self.lock:

0 comments on commit 69352e3

Please sign in to comment.
You can’t perform that action at this time.