From 94a44b5a252c432e3c47577fa46ed49c230fcce3 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 4 Jan 2023 21:12:52 -0500 Subject: [PATCH] Hotfix for #331 (#342) --- cleanrl/rpo_continuous_action.py | 4 ++-- tests/test_mujoco.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/cleanrl/rpo_continuous_action.py b/cleanrl/rpo_continuous_action.py index 86736b337..e339017a6 100644 --- a/cleanrl/rpo_continuous_action.py +++ b/cleanrl/rpo_continuous_action.py @@ -135,12 +135,12 @@ def get_action_and_value(self, x, action=None): probs = Normal(action_mean, action_std) if action is None: action = probs.sample() - else: # new to RPO + else: # new to RPO # sample again to add stochasticity to the policy z = torch.FloatTensor(action_mean.shape).uniform_(-self.rpo_alpha, self.rpo_alpha) action_mean = action_mean + z probs = Normal(action_mean, action_std) - + return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x) diff --git a/tests/test_mujoco.py b/tests/test_mujoco.py index b346e08b6..cc3acb9f6 100644 --- a/tests/test_mujoco.py +++ b/tests/test_mujoco.py @@ -15,3 +15,13 @@ def test_mujoco(): shell=True, check=True, ) + subprocess.run( + "python cleanrl/rpo_continuous_action.py --env-id Hopper-v4 --num-envs 1 --num-steps 64 --total-timesteps 128", + shell=True, + check=True, + ) + subprocess.run( + "python cleanrl/rpo_continuous_action.py --env-id dm_control/cartpole-balance-v0 --num-envs 1 --num-steps 64 --total-timesteps 128", + shell=True, + check=True, + )