Merged
test/test_helpers.py: 78 changes (62 additions, 16 deletions)
@@ -241,7 +241,10 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration):
@pytest.mark.parametrize("gsde", [(), ("gSDE=True",)])
@pytest.mark.parametrize("shared_mapping", [(), ("shared_mapping=True",)])
@pytest.mark.parametrize("exploration", ["random", "mode"])
def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
@pytest.mark.parametrize("action_space", ["discrete", "continuous"])
def test_ppo_maker(
device, from_pixels, shared_mapping, gsde, exploration, action_space
):
if not gsde and exploration != "random":
pytest.skip("no need to test this setting")
flags = list(from_pixels + shared_mapping + gsde)
@@ -262,11 +265,17 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
# if gsde and from_pixels:
# pytest.skip("gsde and from_pixels are incompatible")

env_maker = (
ContinuousActionConvMockEnvNumpy
if from_pixels
else ContinuousActionVecMockEnv
)
if from_pixels:
if action_space == "continuous":
env_maker = ContinuousActionConvMockEnvNumpy
else:
env_maker = DiscreteActionConvMockEnvNumpy
else:
if action_space == "continuous":
env_maker = ContinuousActionVecMockEnv
else:
env_maker = DiscreteActionVecMockEnv

env_maker = transformed_env_constructor(
cfg, use_env_creator=False, custom_env_maker=env_maker
)
@@ -284,6 +293,18 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
)
return

if action_space == "discrete" and cfg.gSDE:
with pytest.raises(
RuntimeError,
match="cannot use gSDE with discrete actions",
):
actor_value = make_ppo_model(
proof_environment,
device=device,
cfg=cfg,
)
return

actor_value = make_ppo_model(
proof_environment,
device=device,
@@ -296,9 +317,11 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
"pixels_orig" if len(from_pixels) else "observation_orig",
"action",
"sample_log_prob",
"loc",
"scale",
]
if action_space == "continuous":
expected_keys += ["loc", "scale"]
else:
expected_keys += ["logits"]
if shared_mapping:
expected_keys += ["hidden"]
if len(gsde):
@@ -365,7 +388,10 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
@pytest.mark.parametrize("gsde", [(), ("gSDE=True",)])
@pytest.mark.parametrize("shared_mapping", [(), ("shared_mapping=True",)])
@pytest.mark.parametrize("exploration", ["random", "mode"])
def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration):
@pytest.mark.parametrize("action_space", ["discrete", "continuous"])
def test_a2c_maker(
device, from_pixels, shared_mapping, gsde, exploration, action_space
):
A2CModelConfig.advantage_in_loss = False
if not gsde and exploration != "random":
pytest.skip("no need to test this setting")
@@ -389,11 +415,17 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration):
# if gsde and from_pixels:
# pytest.skip("gsde and from_pixels are incompatible")

env_maker = (
ContinuousActionConvMockEnvNumpy
if from_pixels
else ContinuousActionVecMockEnv
)
if from_pixels:
if action_space == "continuous":
env_maker = ContinuousActionConvMockEnvNumpy
else:
env_maker = DiscreteActionConvMockEnvNumpy
else:
if action_space == "continuous":
env_maker = ContinuousActionVecMockEnv
else:
env_maker = DiscreteActionVecMockEnv

env_maker = transformed_env_constructor(
cfg, use_env_creator=False, custom_env_maker=env_maker
)
@@ -411,6 +443,18 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration):
)
return

if action_space == "discrete" and cfg.gSDE:
with pytest.raises(
RuntimeError,
match="cannot use gSDE with discrete actions",
):
actor_value = make_a2c_model(
proof_environment,
device=device,
cfg=cfg,
)
return

actor_value = make_a2c_model(
proof_environment,
device=device,
@@ -423,9 +467,11 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration):
"pixels_orig" if len(from_pixels) else "observation_orig",
"action",
"sample_log_prob",
"loc",
"scale",
]
if action_space == "continuous":
expected_keys += ["loc", "scale"]
else:
expected_keys += ["logits"]
if shared_mapping:
expected_keys += ["hidden"]
if len(gsde):
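Both tests now assert a "logits" entry in place of "loc"/"scale" when the action space is discrete, because a categorical policy head feeds raw per-action scores straight into its distribution. A minimal illustration of that convention, using torch.distributions rather than the torchrl wrapper (names and shapes are illustrative only):

import torch
from torch.distributions import OneHotCategorical

logits = torch.randn(4)           # one unnormalized score per discrete action
dist = OneHotCategorical(logits=logits)
action = dist.sample()            # one-hot action tensor of shape (4,)
log_prob = dist.log_prob(action)  # log-probability of the sampled action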
torchrl/trainers/helpers/models.py: 70 changes (41 additions, 29 deletions)
@@ -501,6 +501,7 @@ def make_a2c_model(
out_keys = ["action"]

if action_spec.domain == "continuous":
dist_in_keys = ["loc", "scale"]
out_features = (2 - cfg.gSDE) * action_spec.shape[-1]
if cfg.distribution == "tanh_normal":
policy_distribution_kwargs = {
@@ -520,6 +521,7 @@ def make_a2c_model(
out_features = action_spec.shape[-1]
policy_distribution_kwargs = {}
policy_distribution_class = OneHotCategorical
dist_in_keys = ["logits"]
else:
raise NotImplementedError(
f"actions with domain {action_spec.domain} are not supported"
@@ -560,20 +562,22 @@ def make_a2c_model(
num_cells=[64],
out_features=out_features,
)

shared_out_keys = ["hidden"]
if not cfg.gSDE:
actor_net = NormalParamWrapper(
policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}"
)
in_keys = ["hidden"]
if action_spec.domain == "continuous":
policy_net = NormalParamWrapper(
policy_net,
scale_mapping=f"biased_softplus_{cfg.default_policy_scale}",
)
actor_module = SafeModule(
actor_net, in_keys=in_keys, out_keys=["loc", "scale"]
policy_net, in_keys=shared_out_keys, out_keys=dist_in_keys
)
else:
in_keys = ["hidden"]
gSDE_state_key = "hidden"
actor_module = SafeModule(
policy_net,
in_keys=in_keys,
in_keys=shared_out_keys,
out_keys=["action"], # will be overwritten
)

@@ -601,7 +605,7 @@ def make_a2c_model(
policy_operator = ProbabilisticActor(
spec=CompositeSpec(action=action_spec),
module=actor_module,
dist_in_keys=["loc", "scale"],
dist_in_keys=dist_in_keys,
default_interaction_mode="random",
distribution_class=policy_distribution_class,
distribution_kwargs=policy_distribution_kwargs,
@@ -611,7 +615,7 @@ def make_a2c_model(
num_cells=[64],
out_features=1,
)
value_operator = ValueOperator(value_net, in_keys=["hidden"])
value_operator = ValueOperator(value_net, in_keys=shared_out_keys)
actor_value = ActorValueOperator(
common_operator=common_operator,
policy_operator=policy_operator,
@@ -637,11 +641,13 @@ def make_a2c_model(
)

if not cfg.gSDE:
actor_net = NormalParamWrapper(
policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}"
)
if action_spec.domain == "continuous":
policy_net = NormalParamWrapper(
policy_net,
scale_mapping=f"biased_softplus_{cfg.default_policy_scale}",
)
actor_module = SafeModule(
actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"]
policy_net, in_keys=in_keys_actor, out_keys=dist_in_keys
)
else:
in_keys = in_keys_actor
@@ -676,7 +682,7 @@ def make_a2c_model(
policy_po = ProbabilisticActor(
actor_module,
spec=action_spec,
dist_in_keys=["loc", "scale"],
dist_in_keys=dist_in_keys,
distribution_class=policy_distribution_class,
distribution_kwargs=policy_distribution_kwargs,
return_log_prob=True,
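In the continuous branch, NormalParamWrapper is what turns a single network head into the loc/scale pair consumed by the distribution. A self-contained sketch of that behavior (assumes the NormalParamWrapper API as used in this diff; sizes are illustrative):

import torch
from torch import nn
from torchrl.modules import NormalParamWrapper

# With gSDE off, out_features = 2 * action_dim: the wrapper splits the output,
# half becoming "loc", half mapped to a positive "scale" via a biased softplus.
action_dim = 3
net = nn.Linear(8, 2 * action_dim)
wrapped = NormalParamWrapper(net, scale_mapping="biased_softplus_1.0")
loc, scale = wrapped(torch.randn(8))
assert loc.shape == (action_dim,) and (scale > 0).all()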
@@ -790,6 +796,7 @@ def make_ppo_model(
out_keys = ["action"]

if action_spec.domain == "continuous":
dist_in_keys = ["loc", "scale"]
out_features = (2 - cfg.gSDE) * action_spec.shape[-1]
if cfg.distribution == "tanh_normal":
policy_distribution_kwargs = {
@@ -809,6 +816,7 @@ def make_ppo_model(
out_features = action_spec.shape[-1]
policy_distribution_kwargs = {}
policy_distribution_class = OneHotCategorical
dist_in_keys = ["logits"]
else:
raise NotImplementedError(
f"actions with domain {action_spec.domain} are not supported"
@@ -849,20 +857,22 @@ def make_ppo_model(
num_cells=[200],
out_features=out_features,
)

shared_out_keys = ["hidden"]
if not cfg.gSDE:
actor_net = NormalParamWrapper(
policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}"
)
in_keys = ["hidden"]
if action_spec.domain == "continuous":
policy_net = NormalParamWrapper(
policy_net,
scale_mapping=f"biased_softplus_{cfg.default_policy_scale}",
)
actor_module = SafeModule(
actor_net, in_keys=in_keys, out_keys=["loc", "scale"]
policy_net, in_keys=shared_out_keys, out_keys=dist_in_keys
)
else:
in_keys = ["hidden"]
gSDE_state_key = "hidden"
actor_module = SafeModule(
policy_net,
in_keys=in_keys,
in_keys=shared_out_keys,
out_keys=["action"], # will be overwritten
)

@@ -882,15 +892,15 @@ def make_ppo_model(
actor_module,
SafeModule(
LazygSDEModule(transform=transform),
in_keys=["action", gSDE_state_key, "_eps_gSDE"],
in_keys=["action", gSDE_state_key, "_eps_gSD"],
out_keys=["loc", "scale", "action", "_eps_gSDE"],
),
)

policy_operator = ProbabilisticActor(
spec=CompositeSpec(action=action_spec),
module=actor_module,
dist_in_keys=["loc", "scale"],
dist_in_keys=dist_in_keys,
default_interaction_mode="random",
distribution_class=policy_distribution_class,
distribution_kwargs=policy_distribution_kwargs,
@@ -900,7 +910,7 @@ def make_ppo_model(
num_cells=[200],
out_features=1,
)
value_operator = ValueOperator(value_net, in_keys=["hidden"])
value_operator = ValueOperator(value_net, in_keys=shared_out_keys)
actor_value = ActorValueOperator(
common_operator=common_operator,
policy_operator=policy_operator,
@@ -926,11 +936,13 @@ def make_ppo_model(
)

if not cfg.gSDE:
actor_net = NormalParamWrapper(
policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}"
)
if action_spec.domain == "continuous":
policy_net = NormalParamWrapper(
policy_net,
scale_mapping=f"biased_softplus_{cfg.default_policy_scale}",
)
actor_module = SafeModule(
actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"]
policy_net, in_keys=in_keys_actor, out_keys=dist_in_keys
)
else:
in_keys = in_keys_actor
@@ -965,7 +977,7 @@ def make_ppo_model(
policy_po = ProbabilisticActor(
actor_module,
spec=action_spec,
dist_in_keys=["loc", "scale"],
dist_in_keys=dist_in_keys,
distribution_class=policy_distribution_class,
distribution_kwargs=policy_distribution_kwargs,
return_log_prob=True,
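Condensed, the discrete path added here wires a plain logits head into ProbabilisticActor with a single distribution key. A minimal standalone sketch under the APIs used above (the observation key and layer sizes are illustrative placeholders, not taken from the helpers):

import torch
from torch import nn
from torchrl.modules import OneHotCategorical, ProbabilisticActor, SafeModule

n_actions = 4
policy_net = nn.Linear(8, n_actions)  # one score per discrete action
actor_module = SafeModule(
    policy_net, in_keys=["observation"], out_keys=["logits"]
)
policy = ProbabilisticActor(
    actor_module,
    dist_in_keys=["logits"],  # discrete: a single "logits" key
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)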