Skip to content

Commit 983cb75

Browse files
committed
Pass along bijector into pi_forward's calculation
Fixes CarRacing-v0 and other users of StateDependentNoise
1 parent 9ba0ab5 commit 983cb75

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

rl_algo_impls/shared/actor/state_dependent_noise.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def forward(
172172
not action_masks
173173
), f"{self.__class__.__name__} does not support action_masks"
174174
pi = self._distribution(obs)
175-
return pi_forward(pi, actions)
175+
return pi_forward(pi, actions, self.bijector)
176176

177177
def sample_weights(self, batch_size: int = 1) -> None:
178178
std = self._get_std()
@@ -187,13 +187,13 @@ def action_shape(self) -> Tuple[int, ...]:
187187

188188

189189
def pi_forward(
190-
distribution: Distribution, actions: Optional[torch.Tensor] = None
190+
distribution: Distribution,
191+
actions: Optional[torch.Tensor] = None,
192+
bijector: Optional[TanhBijector] = None,
191193
) -> PiForward:
192194
logp_a = None
193195
entropy = None
194196
if actions is not None:
195197
logp_a = distribution.log_prob(actions)
196-
entropy = (
197-
-logp_a if self.bijector else sum_independent_dims(distribution.entropy())
198-
)
198+
entropy = -logp_a if bijector else sum_independent_dims(distribution.entropy())
199199
return PiForward(distribution, logp_a, entropy)

0 commit comments

Comments
 (0)