diff --git a/obp/policy/offline_continuous.py b/obp/policy/offline_continuous.py index 8ce6a69..0324293 100644 --- a/obp/policy/offline_continuous.py +++ b/obp/policy/offline_continuous.py @@ -526,7 +526,7 @@ def _estimate_policy_gradient( reward: torch.Tensor, pscore: torch.Tensor, action_by_current_policy: torch.Tensor, - ) -> float: + ) -> torch.Tensor: """Estimate the policy gradient. Parameters