File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed
rl_algo_impls/shared/actor Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -172,7 +172,7 @@ def forward(
172172 not action_masks
173173 ), f"{ self .__class__ .__name__ } does not support action_masks"
174174 pi = self ._distribution (obs )
175- return pi_forward (pi , actions )
175+ return pi_forward (pi , actions , self . bijector )
176176
177177 def sample_weights (self , batch_size : int = 1 ) -> None :
178178 std = self ._get_std ()
@@ -187,13 +187,13 @@ def action_shape(self) -> Tuple[int, ...]:
187187
188188
189189def pi_forward (
190- distribution : Distribution , actions : Optional [torch .Tensor ] = None
190+ distribution : Distribution ,
191+ actions : Optional [torch .Tensor ] = None ,
192+ bijector : Optional [TanhBijector ] = None ,
191193) -> PiForward :
192194 logp_a = None
193195 entropy = None
194196 if actions is not None :
195197 logp_a = distribution .log_prob (actions )
196- entropy = (
197- - logp_a if self .bijector else sum_independent_dims (distribution .entropy ())
198- )
198+ entropy = - logp_a if bijector else sum_independent_dims (distribution .entropy ())
199199 return PiForward (distribution , logp_a , entropy )
You can’t perform that action at this time.
0 commit comments