Skip to content

Commit

Permalink
Merge branch 'master' into CI/add-publish-workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
MischaPanch committed Mar 2, 2024
2 parents 5108c82 + 1aee41f commit eb99da3
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 20 deletions.
2 changes: 0 additions & 2 deletions test/continuous/test_sac_with_il.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.utils.tensorboard import SummaryWriter

Expand Down Expand Up @@ -58,7 +57,6 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]


@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
# if you want to use python vector env, please refer to other test scripts
# train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed)
Expand Down
2 changes: 1 addition & 1 deletion tianshou/policy/modelfree/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def forward( # type: ignore
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits)
if self.deterministic_eval and not self.training:
act = logits.argmax(axis=-1)
act = dist.mode
else:
act = dist.sample()
return Batch(logits=logits, act=act, state=hidden, dist=dist)
Expand Down
15 changes: 1 addition & 14 deletions tianshou/policy/modelfree/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,19 +158,6 @@ def process_fn(
batch: BatchWithReturnsProtocol
return batch

def _get_deterministic_action(self, logits: torch.Tensor) -> torch.Tensor:
if self.action_type == "discrete":
return logits.argmax(-1)
if self.action_type == "continuous":
# assume that the mode of the distribution is the first element
# of the actor's output (the "logits")
return logits[0]
raise RuntimeError(
f"Unknown action type: {self.action_type}. "
f"This should not happen and might be a bug."
f"Supported action types are: 'discrete' and 'continuous'.",
)

def forward(
self,
batch: ObsBatchProtocol,
Expand Down Expand Up @@ -198,7 +185,7 @@ def forward(

# in this case, the dist is unused!
if self.deterministic_eval and not self.training:
act = self._get_deterministic_action(logits)
act = dist.mode
else:
act = dist.sample()
result = Batch(logits=logits, act=act, state=hidden, dist=dist)
Expand Down
5 changes: 4 additions & 1 deletion tianshou/policy/modelfree/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ def forward( # type: ignore
loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
loc, scale = loc_scale
dist = Independent(Normal(loc, scale), 1)
act = loc if self.deterministic_eval and not self.training else dist.rsample()
if self.deterministic_eval and not self.training:
act = dist.mode
else:
act = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
Expand Down
4 changes: 2 additions & 2 deletions tianshou/policy/modelfree/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
This is useful when solving "hard exploration" problems.
"default" is equivalent to GaussianNoise(sigma=0.1).
:param deterministic_eval: whether to use deterministic action
(mean of Gaussian policy) in evaluation mode instead of stochastic
(mode of Gaussian policy) in evaluation mode instead of stochastic
action sampled by the policy. Does not affect training.
:param action_scaling: whether to map actions from range [-1, 1]
to range[action_spaces.low, action_spaces.high].
Expand Down Expand Up @@ -177,7 +177,7 @@ def forward( # type: ignore
assert isinstance(logits, tuple)
dist = Independent(Normal(*logits), 1)
if self.deterministic_eval and not self.training:
act = logits[0]
act = dist.mode
else:
act = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1)
Expand Down

0 comments on commit eb99da3

Please sign in to comment.