diff --git a/test/test_helpers.py b/test/test_helpers.py
index 77807effac6..35d84c53e07 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -241,7 +241,10 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration):
 @pytest.mark.parametrize("gsde", [(), ("gSDE=True",)])
 @pytest.mark.parametrize("shared_mapping", [(), ("shared_mapping=True",)])
 @pytest.mark.parametrize("exploration", ["random", "mode"])
-def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
+@pytest.mark.parametrize("action_space", ["discrete", "continuous"])
+def test_ppo_maker(
+    device, from_pixels, shared_mapping, gsde, exploration, action_space
+):
     if not gsde and exploration != "random":
         pytest.skip("no need to test this setting")
     flags = list(from_pixels + shared_mapping + gsde)
@@ -262,11 +265,17 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):

     # if gsde and from_pixels:
     #     pytest.skip("gsde and from_pixels are incompatible")
-    env_maker = (
-        ContinuousActionConvMockEnvNumpy
-        if from_pixels
-        else ContinuousActionVecMockEnv
-    )
+    if from_pixels:
+        if action_space == "continuous":
+            env_maker = ContinuousActionConvMockEnvNumpy
+        else:
+            env_maker = DiscreteActionConvMockEnvNumpy
+    else:
+        if action_space == "continuous":
+            env_maker = ContinuousActionVecMockEnv
+        else:
+            env_maker = DiscreteActionVecMockEnv
+
     env_maker = transformed_env_constructor(
         cfg, use_env_creator=False, custom_env_maker=env_maker
     )
@@ -284,6 +293,18 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
         )
         return

+    if action_space == "discrete" and cfg.gSDE:
+        with pytest.raises(
+            RuntimeError,
+            match="cannot use gSDE with discrete actions",
+        ):
+            actor_value = make_ppo_model(
+                proof_environment,
+                device=device,
+                cfg=cfg,
+            )
+        return
+
     actor_value = make_ppo_model(
         proof_environment,
         device=device,
@@ -296,9 +317,11 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
         "pixels_orig" if len(from_pixels) else "observation_orig",
         "action",
         "sample_log_prob",
-        "loc",
-        "scale",
     ]
+    if action_space == "continuous":
+        expected_keys += ["loc", "scale"]
+    else:
+        expected_keys += ["logits"]
     if shared_mapping:
         expected_keys += ["hidden"]
     if len(gsde):
@@ -365,7 +388,10 @@
 @pytest.mark.parametrize("gsde", [(), ("gSDE=True",)])
 @pytest.mark.parametrize("shared_mapping", [(), ("shared_mapping=True",)])
 @pytest.mark.parametrize("exploration", ["random", "mode"])
-def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration):
+@pytest.mark.parametrize("action_space", ["discrete", "continuous"])
+def test_a2c_maker(
+    device, from_pixels, shared_mapping, gsde, exploration, action_space
+):
     A2CModelConfig.advantage_in_loss = False
     if not gsde and exploration != "random":
         pytest.skip("no need to test this setting")
@@ -389,11 +415,17 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration):

     # if gsde and from_pixels:
     #     pytest.skip("gsde and from_pixels are incompatible")
-    env_maker = (
-        ContinuousActionConvMockEnvNumpy
-        if from_pixels
-        else ContinuousActionVecMockEnv
-    )
+    if from_pixels:
+        if action_space == "continuous":
+            env_maker = ContinuousActionConvMockEnvNumpy
+        else:
+            env_maker = DiscreteActionConvMockEnvNumpy
+    else:
+        if action_space == "continuous":
+            env_maker = ContinuousActionVecMockEnv
+        else:
+            env_maker = DiscreteActionVecMockEnv
+
     env_maker = transformed_env_constructor(
         cfg, use_env_creator=False, custom_env_maker=env_maker
     )
@@ -411,6 +443,18 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration):
         )
         return

+    if action_space == "discrete" and cfg.gSDE:
+        with pytest.raises(
+            RuntimeError,
+            match="cannot use gSDE with discrete actions",
+        ):
+            actor_value = make_a2c_model(
+                proof_environment,
+                device=device,
+                cfg=cfg,
+            )
+        return
+
     actor_value = make_a2c_model(
         proof_environment,
         device=device,
@@ -423,9 +467,11 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration):
         "pixels_orig" if len(from_pixels) else "observation_orig",
         "action",
         "sample_log_prob",
-        "loc",
-        "scale",
     ]
+    if action_space == "continuous":
+        expected_keys += ["loc", "scale"]
+    else:
+        expected_keys += ["logits"]
     if shared_mapping:
         expected_keys += ["hidden"]
     if len(gsde):
diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py
index 24742d62ee0..e9e57eadb45 100644
--- a/torchrl/trainers/helpers/models.py
+++ b/torchrl/trainers/helpers/models.py
@@ -501,6 +501,7 @@ def make_a2c_model(
     out_keys = ["action"]

     if action_spec.domain == "continuous":
+        dist_in_keys = ["loc", "scale"]
         out_features = (2 - cfg.gSDE) * action_spec.shape[-1]
         if cfg.distribution == "tanh_normal":
             policy_distribution_kwargs = {
@@ -520,6 +521,7 @@ def make_a2c_model(
         out_features = action_spec.shape[-1]
         policy_distribution_kwargs = {}
         policy_distribution_class = OneHotCategorical
+        dist_in_keys = ["logits"]
     else:
         raise NotImplementedError(
             f"actions with domain {action_spec.domain} are not supported"
@@ -560,20 +562,22 @@ def make_a2c_model(
             num_cells=[64],
             out_features=out_features,
         )
+
+        shared_out_keys = ["hidden"]
         if not cfg.gSDE:
-            actor_net = NormalParamWrapper(
-                policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}"
-            )
-            in_keys = ["hidden"]
+            if action_spec.domain == "continuous":
+                policy_net = NormalParamWrapper(
+                    policy_net,
+                    scale_mapping=f"biased_softplus_{cfg.default_policy_scale}",
+                )
             actor_module = SafeModule(
-                actor_net, in_keys=in_keys, out_keys=["loc", "scale"]
+                policy_net, in_keys=shared_out_keys, out_keys=dist_in_keys
             )
         else:
-            in_keys = ["hidden"]
             gSDE_state_key = "hidden"
             actor_module = SafeModule(
                 policy_net,
-                in_keys=in_keys,
+                in_keys=shared_out_keys,
                 out_keys=["action"],  # will be overwritten
             )

@@ -601,7 +605,7 @@ def make_a2c_model(
         policy_operator = ProbabilisticActor(
             spec=CompositeSpec(action=action_spec),
             module=actor_module,
-            dist_in_keys=["loc", "scale"],
+            dist_in_keys=dist_in_keys,
             default_interaction_mode="random",
             distribution_class=policy_distribution_class,
             distribution_kwargs=policy_distribution_kwargs,
@@ -611,7 +615,7 @@ def make_a2c_model(
             num_cells=[64],
             out_features=1,
         )
-        value_operator = ValueOperator(value_net, in_keys=["hidden"])
+        value_operator = ValueOperator(value_net, in_keys=shared_out_keys)
         actor_value = ActorValueOperator(
             common_operator=common_operator,
             policy_operator=policy_operator,
@@ -637,11 +641,13 @@ def make_a2c_model(
         )

         if not cfg.gSDE:
-            actor_net = NormalParamWrapper(
-                policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}"
-            )
+            if action_spec.domain == "continuous":
+                policy_net = NormalParamWrapper(
+                    policy_net,
+                    scale_mapping=f"biased_softplus_{cfg.default_policy_scale}",
+                )
             actor_module = SafeModule(
-                actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"]
+                policy_net, in_keys=in_keys_actor, out_keys=dist_in_keys
             )
         else:
             in_keys = in_keys_actor
@@ -676,7 +682,7 @@ def make_a2c_model(
         policy_po = ProbabilisticActor(
             actor_module,
             spec=action_spec,
-            dist_in_keys=["loc", "scale"],
+            dist_in_keys=dist_in_keys,
             distribution_class=policy_distribution_class,
             distribution_kwargs=policy_distribution_kwargs,
             return_log_prob=True,
@@ -790,6 +796,7 @@ def make_ppo_model(
     out_keys = ["action"]

     if action_spec.domain == "continuous":
+        dist_in_keys = ["loc", "scale"]
         out_features = (2 - cfg.gSDE) * action_spec.shape[-1]
         if cfg.distribution == "tanh_normal":
             policy_distribution_kwargs = {
@@ -809,6 +816,7 @@ def make_ppo_model(
         out_features = action_spec.shape[-1]
         policy_distribution_kwargs = {}
         policy_distribution_class = OneHotCategorical
+        dist_in_keys = ["logits"]
     else:
         raise NotImplementedError(
             f"actions with domain {action_spec.domain} are not supported"
@@ -849,20 +857,22 @@ def make_ppo_model(
             num_cells=[200],
             out_features=out_features,
         )
+
+        shared_out_keys = ["hidden"]
         if not cfg.gSDE:
-            actor_net = NormalParamWrapper(
-                policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}"
-            )
-            in_keys = ["hidden"]
+            if action_spec.domain == "continuous":
+                policy_net = NormalParamWrapper(
+                    policy_net,
+                    scale_mapping=f"biased_softplus_{cfg.default_policy_scale}",
+                )
             actor_module = SafeModule(
-                actor_net, in_keys=in_keys, out_keys=["loc", "scale"]
+                policy_net, in_keys=shared_out_keys, out_keys=dist_in_keys
             )
         else:
-            in_keys = ["hidden"]
             gSDE_state_key = "hidden"
             actor_module = SafeModule(
                 policy_net,
-                in_keys=in_keys,
+                in_keys=shared_out_keys,
                 out_keys=["action"],  # will be overwritten
             )

@@ -890,7 +900,7 @@ def make_ppo_model(
         policy_operator = ProbabilisticActor(
             spec=CompositeSpec(action=action_spec),
             module=actor_module,
-            dist_in_keys=["loc", "scale"],
+            dist_in_keys=dist_in_keys,
             default_interaction_mode="random",
             distribution_class=policy_distribution_class,
             distribution_kwargs=policy_distribution_kwargs,
@@ -900,7 +910,7 @@ def make_ppo_model(
             num_cells=[200],
             out_features=1,
         )
-        value_operator = ValueOperator(value_net, in_keys=["hidden"])
+        value_operator = ValueOperator(value_net, in_keys=shared_out_keys)
         actor_value = ActorValueOperator(
             common_operator=common_operator,
             policy_operator=policy_operator,
@@ -926,11 +936,13 @@ def make_ppo_model(
         )

         if not cfg.gSDE:
-            actor_net = NormalParamWrapper(
-                policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}"
-            )
+            if action_spec.domain == "continuous":
+                policy_net = NormalParamWrapper(
+                    policy_net,
+                    scale_mapping=f"biased_softplus_{cfg.default_policy_scale}",
+                )
             actor_module = SafeModule(
-                actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"]
+                policy_net, in_keys=in_keys_actor, out_keys=dist_in_keys
             )
         else:
             in_keys = in_keys_actor
@@ -965,7 +977,7 @@ def make_ppo_model(
         policy_po = ProbabilisticActor(
             actor_module,
             spec=action_spec,
-            dist_in_keys=["loc", "scale"],
+            dist_in_keys=dist_in_keys,
             distribution_class=policy_distribution_class,
             distribution_kwargs=policy_distribution_kwargs,
             return_log_prob=True,
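
Note: the patch threads a single `dist_in_keys` list through `make_a2c_model` and `make_ppo_model`, so the probabilistic actor reads `["loc", "scale"]` for continuous actions and `["logits"]` for discrete ones, with `NormalParamWrapper` applied only in the continuous branch. Below is a minimal, self-contained sketch of that branching in plain PyTorch; the torchrl wrappers are elided, and `make_policy_head`, `HIDDEN_DIM`, and `ACTION_DIM` are illustrative names, not from the patch.

    # Reviewer sketch, not part of the patch: plain-PyTorch illustration of
    # the dist_in_keys branching above. Hypothetical names for example only.
    import torch
    from torch import nn
    from torch.distributions import Normal, OneHotCategorical

    HIDDEN_DIM, ACTION_DIM = 64, 4

    def make_policy_head(domain):
        # Mirrors the patch: continuous -> ["loc", "scale"], discrete -> ["logits"].
        if domain == "continuous":
            # Two output chunks, split into location and scale downstream.
            return nn.Linear(HIDDEN_DIM, 2 * ACTION_DIM), ["loc", "scale"], Normal
        if domain == "discrete":
            # One logits vector, consumed directly by a categorical distribution.
            return nn.Linear(HIDDEN_DIM, ACTION_DIM), ["logits"], OneHotCategorical
        raise NotImplementedError(f"actions with domain {domain} are not supported")

    hidden = torch.randn(8, HIDDEN_DIM)
    for domain in ("continuous", "discrete"):
        net, dist_in_keys, dist_cls = make_policy_head(domain)
        out = net(hidden)
        if dist_in_keys == ["loc", "scale"]:
            loc, scale = out.chunk(2, dim=-1)
            dist = dist_cls(loc, scale.exp())  # stand-in for the softplus mapping
        else:
            dist = dist_cls(logits=out)
        print(domain, dist_in_keys, dist.sample().shape)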