diff --git a/rllib/algorithms/qmix/qmix_policy.py b/rllib/algorithms/qmix/qmix_policy.py index 20855a799bb1f..32d7dca25254f 100644 --- a/rllib/algorithms/qmix/qmix_policy.py +++ b/rllib/algorithms/qmix/qmix_policy.py @@ -447,6 +447,7 @@ def to_batches(arr, dtype): # Optimise self.rmsprop_optimizer.zero_grad() loss_out.backward() + grad_norm_info = apply_grad_clipping(self, self.rmsprop_optimizer, loss_out) self.rmsprop_optimizer.step() mask_elems = mask.sum().item() @@ -456,7 +457,8 @@ def to_batches(arr, dtype): "q_taken_mean": (chosen_action_qvals * mask).sum().item() / mask_elems, "target_mean": (targets * mask).sum().item() / mask_elems, } - stats.update(apply_grad_clipping(self, self.rmsprop_optimizer, loss_out)) + stats.update(grad_norm_info) + return {LEARNER_STATS_KEY: stats} @override(TorchPolicy)