Skip to content

Commit

Permalink
[RLlib] Fix no gradient clipping happening in QMix. (#25656)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArturNiederfahrenhorst committed Jun 10, 2022
1 parent 730df43 commit c364592
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion rllib/algorithms/qmix/qmix_policy.py
@@ -447,6 +447,7 @@ def to_batches(arr, dtype):
# Optimise
self.rmsprop_optimizer.zero_grad()
loss_out.backward()
+ grad_norm_info = apply_grad_clipping(self, self.rmsprop_optimizer, loss_out)
self.rmsprop_optimizer.step()

mask_elems = mask.sum().item()
@@ -456,7 +457,8 @@ def to_batches(arr, dtype):
"q_taken_mean": (chosen_action_qvals * mask).sum().item() / mask_elems,
"target_mean": (targets * mask).sum().item() / mask_elems,
}
- stats.update(apply_grad_clipping(self, self.rmsprop_optimizer, loss_out))
+ stats.update(grad_norm_info)

return {LEARNER_STATS_KEY: stats}

@override(TorchPolicy)
Expand Down

0 comments on commit c364592

Please sign in to comment.