From b99c373dacb710c8e3a86c36bef6012707626b51 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 29 Nov 2022 09:00:37 +0100 Subject: [PATCH 01/14] extend ppo make model method --- torchrl/trainers/helpers/models.py | 106 ++++++++++++++++++++--------- 1 file changed, 72 insertions(+), 34 deletions(-) diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 24742d62ee0..9d4c01fc9bd 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -492,6 +492,7 @@ def make_a2c_model( specs = proof_environment.specs # TODO: use env.sepcs action_spec = specs["action_spec"] + # Define input observation format for actor and critic if in_keys_actor is None and proof_environment.from_pixels: in_keys_actor = ["pixels"] in_keys_critic = ["pixels"] @@ -501,6 +502,7 @@ def make_a2c_model( out_keys = ["action"] if action_spec.domain == "continuous": + dist_in_keys = ["loc", "scale"] out_features = (2 - cfg.gSDE) * action_spec.shape[-1] if cfg.distribution == "tanh_normal": policy_distribution_kwargs = { @@ -520,6 +522,7 @@ def make_a2c_model( out_features = action_spec.shape[-1] policy_distribution_kwargs = {} policy_distribution_class = OneHotCategorical + dist_in_keys = ["logits"] else: raise NotImplementedError( f"actions with domain {action_spec.domain} are not supported" @@ -560,16 +563,17 @@ def make_a2c_model( num_cells=[64], out_features=out_features, ) + + in_keys = ["hidden"] if not cfg.gSDE: - actor_net = NormalParamWrapper( - policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" - ) - in_keys = ["hidden"] + if action_spec.domain == "continuous": + actor_net = NormalParamWrapper( + policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" + ) actor_module = SafeModule( - actor_net, in_keys=in_keys, out_keys=["loc", "scale"] + actor_net, in_keys=in_keys, out_keys=dist_in_keys ) else: - in_keys = ["hidden"] gSDE_state_key = "hidden" actor_module = SafeModule( policy_net, @@ 
-601,7 +605,7 @@ def make_a2c_model( policy_operator = ProbabilisticActor( spec=CompositeSpec(action=action_spec), module=actor_module, - dist_in_keys=["loc", "scale"], + dist_in_keys=dist_in_keys, default_interaction_mode="random", distribution_class=policy_distribution_class, distribution_kwargs=policy_distribution_kwargs, @@ -611,7 +615,7 @@ def make_a2c_model( num_cells=[64], out_features=1, ) - value_operator = ValueOperator(value_net, in_keys=["hidden"]) + value_operator = ValueOperator(value_net, in_keys=in_keys) actor_value = ActorValueOperator( common_operator=common_operator, policy_operator=policy_operator, @@ -637,11 +641,12 @@ def make_a2c_model( ) if not cfg.gSDE: - actor_net = NormalParamWrapper( - policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" - ) + if action_spec.domain == "continuous": + actor_net = NormalParamWrapper( + policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" + ) actor_module = SafeModule( - actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"] + actor_net, in_keys=in_keys_actor, out_keys=dist_in_keys ) else: in_keys = in_keys_actor @@ -676,7 +681,7 @@ def make_a2c_model( policy_po = ProbabilisticActor( actor_module, spec=action_spec, - dist_in_keys=["loc", "scale"], + dist_in_keys=dist_in_keys, distribution_class=policy_distribution_class, distribution_kwargs=policy_distribution_kwargs, return_log_prob=True, @@ -781,6 +786,7 @@ def make_ppo_model( specs = proof_environment.specs # TODO: use env.sepcs action_spec = specs["action_spec"] + # Define input observation format for actor and critic if in_keys_actor is None and proof_environment.from_pixels: in_keys_actor = ["pixels"] in_keys_critic = ["pixels"] @@ -790,6 +796,7 @@ def make_ppo_model( out_keys = ["action"] if action_spec.domain == "continuous": + dist_in_keys = ["loc", "scale"] out_features = (2 - cfg.gSDE) * action_spec.shape[-1] if cfg.distribution == "tanh_normal": policy_distribution_kwargs = { @@ -809,6 +816,7 @@ def 
make_ppo_model( out_features = action_spec.shape[-1] policy_distribution_kwargs = {} policy_distribution_class = OneHotCategorical + dist_in_keys = ["logits"] else: raise NotImplementedError( f"actions with domain {action_spec.domain} are not supported" @@ -833,7 +841,7 @@ def make_ppo_model( ) common_module = MLP( num_cells=[ - 400, + 64, ], out_features=hidden_features, activate_last_layer=True, @@ -846,19 +854,20 @@ def make_ppo_model( ) policy_net = MLP( - num_cells=[200], + num_cells=[64], out_features=out_features, ) + + in_keys = ["hidden"] if not cfg.gSDE: - actor_net = NormalParamWrapper( - policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" - ) - in_keys = ["hidden"] + if action_spec.domain == "continuous": + actor_net = NormalParamWrapper( + policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" + ) actor_module = SafeModule( - actor_net, in_keys=in_keys, out_keys=["loc", "scale"] + actor_net, in_keys=in_keys, out_keys=dist_in_keys ) else: - in_keys = ["hidden"] gSDE_state_key = "hidden" actor_module = SafeModule( policy_net, @@ -882,7 +891,7 @@ def make_ppo_model( actor_module, SafeModule( LazygSDEModule(transform=transform), - in_keys=["action", gSDE_state_key, "_eps_gSDE"], + in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], ), ) @@ -890,17 +899,17 @@ def make_ppo_model( policy_operator = ProbabilisticActor( spec=CompositeSpec(action=action_spec), module=actor_module, - dist_in_keys=["loc", "scale"], + dist_in_keys=dist_in_keys, default_interaction_mode="random", distribution_class=policy_distribution_class, distribution_kwargs=policy_distribution_kwargs, return_log_prob=True, ) value_net = MLP( - num_cells=[200], + num_cells=[64], out_features=1, ) - value_operator = ValueOperator(value_net, in_keys=["hidden"]) + value_operator = ValueOperator(value_net, in_keys=in_keys) actor_value = ActorValueOperator( common_operator=common_operator,
policy_operator=policy_operator, @@ -914,23 +923,24 @@ def make_ppo_model( if cfg.lstm: policy_net = LSTMNet( out_features=out_features, - lstm_kwargs={"input_size": 256, "hidden_size": 256}, - mlp_kwargs={"num_cells": [256, 256], "out_features": 256}, + lstm_kwargs={"input_size": 64, "hidden_size": 64}, + mlp_kwargs={"num_cells": [64, 64], "out_features": 64}, ) in_keys_actor += ["hidden0", "hidden1"] out_keys += ["hidden0", "hidden1", ("next", "hidden0"), ("next", "hidden1")] else: policy_net = MLP( - num_cells=[400, 300], + num_cells=[64, 64], out_features=out_features, ) if not cfg.gSDE: - actor_net = NormalParamWrapper( - policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" - ) + if action_spec.domain == "continuous": + actor_net = NormalParamWrapper( + policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" + ) actor_module = SafeModule( - actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"] + actor_net, in_keys=in_keys_actor, out_keys=dist_in_keys ) else: in_keys = in_keys_actor @@ -965,7 +975,7 @@ def make_ppo_model( policy_po = ProbabilisticActor( actor_module, spec=action_spec, - dist_in_keys=["loc", "scale"], + dist_in_keys=dist_in_keys, distribution_class=policy_distribution_class, distribution_kwargs=policy_distribution_kwargs, return_log_prob=True, @@ -973,7 +983,7 @@ def make_ppo_model( ) value_net = MLP( - num_cells=[400, 300], + num_cells=[64, 64], out_features=1, ) value_po = ValueOperator( @@ -1823,6 +1833,34 @@ class PPOModelConfig: # if True, uses an LSTM for the policy. shared_mapping: bool = False # if True, the first layers of the actor-critic are shared. 
+ shared_convnet_kwargs: dict = dict( + bias_last_layer=True, + depth=None, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + # if the shared network is a CNN, these parameters will be used + shared_mlp_kwargs: dict = dict( + num_cells=[400], + out_features=300, + ) + # if the shared network is a MLP, these parameters will be used + policy_mlp_kwargs: dict = dict( + num_cells=[256, 448], + out_features=1, # will be overwritten + ) + # parameters used for the MLP actor network + value_mlp_kwargs: dict = dict( + num_cells=[256, 448], + out_features=1, # will be overwritten + ) + # parameters used for the MLP value network + lstm_kwargs: dict = dict( + input_size=256, + hidden_size=256, + ) + # if lstm is True, these parameters will be used for the network @dataclass From 0042695cef23c3963271da784399e5371b39e5ab Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 29 Nov 2022 09:06:25 +0100 Subject: [PATCH 02/14] ppoconf --- torchrl/trainers/helpers/models.py | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 9d4c01fc9bd..c777c68d101 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -1833,35 +1833,7 @@ class PPOModelConfig: # if True, uses an LSTM for the policy. shared_mapping: bool = False # if True, the first layers of the actor-critic are shared. 
- shared_convnet_kwargs: dict = dict( - bias_last_layer=True, - depth=None, - num_cells=[32, 64, 64], - kernel_sizes=[8, 4, 3], - strides=[4, 2, 1], - ) - # if the shared network is a CNN, these parameters will be used - shared_mlp_kwargs: dict = dict( - num_cells=[400], - out_features=300, - ) - # if the shared network is a MLP, these parameters will be used - policy_mlp_kwargs: dict = dict( - num_cells=[256, 448], - out_features=1, # will be overwritten - ) - # parameters used for the MLP actor network - value_mlp_kwargs: dict = dict( - num_cells=[256, 448], - out_features=1, # will be overwritten - ) - # parameters used for the MLP value network - lstm_kwargs: dict = dict( - input_size=256, - hidden_size=256, - ) - # if lstm is True, these parameters will be used for the network - + @dataclass class A2CModelConfig: From db7c0cdd23599342c69c8944b963c493b16f5ff4 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 29 Nov 2022 09:21:34 +0100 Subject: [PATCH 03/14] fix --- torchrl/trainers/helpers/models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index c777c68d101..89a91c79678 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -861,11 +861,11 @@ def make_ppo_model( in_keys = ["hidden"] if not cfg.gSDE: if action_spec.domain == "continuous": - actor_net = NormalParamWrapper( + policy_net = NormalParamWrapper( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) actor_module = SafeModule( - actor_net, in_keys=in_keys, out_keys=dist_in_keys + policy_net, in_keys=in_keys, out_keys=dist_in_keys ) else: gSDE_state_key = "hidden" @@ -936,11 +936,11 @@ def make_ppo_model( if not cfg.gSDE: if action_spec.domain == "continuous": - actor_net = NormalParamWrapper( + policy_net = NormalParamWrapper( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) actor_module = SafeModule( - actor_net, 
in_keys=in_keys_actor, out_keys=dist_in_keys + policy_net, in_keys=in_keys_actor, out_keys=dist_in_keys ) else: in_keys = in_keys_actor @@ -1833,7 +1833,7 @@ class PPOModelConfig: # if True, uses an LSTM for the policy. shared_mapping: bool = False # if True, the first layers of the actor-critic are shared. - + @dataclass class A2CModelConfig: From b08124f17e944f77a7415a0e427fec4325fca833 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 29 Nov 2022 09:29:34 +0100 Subject: [PATCH 04/14] minor naming change --- torchrl/trainers/helpers/models.py | 32 ++++++++++++++---------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 89a91c79678..b74ff069127 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -492,7 +492,6 @@ def make_a2c_model( specs = proof_environment.specs # TODO: use env.sepcs action_spec = specs["action_spec"] - # Define input observation format for actor and critic if in_keys_actor is None and proof_environment.from_pixels: in_keys_actor = ["pixels"] in_keys_critic = ["pixels"] @@ -564,20 +563,20 @@ def make_a2c_model( out_features=out_features, ) - in_keys = ["hidden"] + shared_out_keys = ["hidden"] if not cfg.gSDE: if action_spec.domain == "continuous": actor_net = NormalParamWrapper( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) actor_module = SafeModule( - actor_net, in_keys=in_keys, out_keys=dist_in_keys + actor_net, in_keys=shared_out_keys, out_keys=dist_in_keys ) else: gSDE_state_key = "hidden" actor_module = SafeModule( policy_net, - in_keys=in_keys, + in_keys=shared_out_keys, out_keys=["action"], # will be overwritten ) @@ -615,7 +614,7 @@ def make_a2c_model( num_cells=[64], out_features=1, ) - value_operator = ValueOperator(value_net, in_keys=in_keys) + value_operator = ValueOperator(value_net, in_keys=shared_out_keys) actor_value = ActorValueOperator( 
common_operator=common_operator, policy_operator=policy_operator, @@ -786,7 +785,6 @@ def make_ppo_model( specs = proof_environment.specs # TODO: use env.sepcs action_spec = specs["action_spec"] - # Define input observation format for actor and critic if in_keys_actor is None and proof_environment.from_pixels: in_keys_actor = ["pixels"] in_keys_critic = ["pixels"] @@ -841,7 +839,7 @@ def make_ppo_model( ) common_module = MLP( num_cells=[ - 64, + 400, ], out_features=hidden_features, activate_last_layer=True, @@ -854,24 +852,24 @@ def make_ppo_model( ) policy_net = MLP( - num_cells=[64], + num_cells=[200], out_features=out_features, ) - in_keys = ["hidden"] + shared_out_keys = ["hidden"] if not cfg.gSDE: if action_spec.domain == "continuous": policy_net = NormalParamWrapper( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) actor_module = SafeModule( - policy_net, in_keys=in_keys, out_keys=dist_in_keys + policy_net, in_keys=shared_out_keys, out_keys=dist_in_keys ) else: gSDE_state_key = "hidden" actor_module = SafeModule( policy_net, - in_keys=in_keys, + in_keys=shared_out_keys, out_keys=["action"], # will be overwritten ) @@ -906,10 +904,10 @@ def make_ppo_model( return_log_prob=True, ) value_net = MLP( - num_cells=[64], + num_cells=[200], out_features=1, ) - value_operator = ValueOperator(value_net, in_keys=in_keys) + value_operator = ValueOperator(value_net, in_keys=shared_out_keys) actor_value = ActorValueOperator( common_operator=common_operator, policy_operator=policy_operator, @@ -923,14 +921,14 @@ def make_ppo_model( if cfg.lstm: policy_net = LSTMNet( out_features=out_features, - lstm_kwargs={"input_size": 64, "hidden_size": 64}, - mlp_kwargs={"num_cells": [64, 64], "out_features": 64}, + lstm_kwargs={"input_size": 256, "hidden_size": 256}, + mlp_kwargs={"num_cells": [256, 256], "out_features": 256}, ) in_keys_actor += ["hidden0", "hidden1"] out_keys += ["hidden0", "hidden1", ("next", "hidden0"), ("next", "hidden1")] else: policy_net 
= MLP( - num_cells=[64, 64], + num_cells=[400, 300], out_features=out_features, ) @@ -983,7 +981,7 @@ def make_ppo_model( ) value_net = MLP( - num_cells=[64, 64], + num_cells=[400, 300], out_features=1, ) value_po = ValueOperator( From c9106c27e4947a1e82c7f2f9e4eb31972e75606e Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 29 Nov 2022 09:42:39 +0100 Subject: [PATCH 05/14] tests --- test/test_helpers.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index 77807effac6..0c5a07d9f82 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -241,7 +241,8 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration): @pytest.mark.parametrize("gsde", [(), ("gSDE=True",)]) @pytest.mark.parametrize("shared_mapping", [(), ("shared_mapping=True",)]) @pytest.mark.parametrize("exploration", ["random", "mode"]) -def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration): +@pytest.mark.parametrize("action_space", ["discrete", "continuous"]) +def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration, action_space): if not gsde and exploration != "random": pytest.skip("no need to test this setting") flags = list(from_pixels + shared_mapping + gsde) @@ -262,11 +263,17 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration): # if gsde and from_pixels: # pytest.skip("gsde and from_pixels are incompatible") - env_maker = ( - ContinuousActionConvMockEnvNumpy - if from_pixels - else ContinuousActionVecMockEnv - ) + if from_pixels: + if action_space == "continuous": + env_maker = (ContinuousActionConvMockEnvNumpy) + else: + env_maker = (DiscreteActionConvMockEnvNumpy) + else: + if action_space == "continuous": + env_maker = (ContinuousActionVecMockEnv) + else: + env_maker = (DiscreteActionVecMockEnv) + env_maker = transformed_env_constructor( cfg, use_env_creator=False, custom_env_maker=env_maker ) @@ -365,7 +372,8 @@ def 
test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration): @pytest.mark.parametrize("gsde", [(), ("gSDE=True",)]) @pytest.mark.parametrize("shared_mapping", [(), ("shared_mapping=True",)]) @pytest.mark.parametrize("exploration", ["random", "mode"]) -def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration): +@pytest.mark.parametrize("action_space", ["discrete", "continuous"]) +def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration, action_space): A2CModelConfig.advantage_in_loss = False if not gsde and exploration != "random": pytest.skip("no need to test this setting") @@ -389,11 +397,17 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration): # if gsde and from_pixels: # pytest.skip("gsde and from_pixels are incompatible") - env_maker = ( - ContinuousActionConvMockEnvNumpy - if from_pixels - else ContinuousActionVecMockEnv - ) + if from_pixels: + if action_space == "continuous": + env_maker = (ContinuousActionConvMockEnvNumpy) + else: + env_maker = (DiscreteActionConvMockEnvNumpy) + else: + if action_space == "continuous": + env_maker = (ContinuousActionVecMockEnv) + else: + env_maker = (DiscreteActionVecMockEnv) + env_maker = transformed_env_constructor( cfg, use_env_creator=False, custom_env_maker=env_maker ) From 5cfe92430509d455d8d63f21d2b9ab67f4fb2fe5 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 29 Nov 2022 09:52:50 +0100 Subject: [PATCH 06/14] fix in variable --- test/test_helpers.py | 24 ++++++++++++++++++++++++ torchrl/trainers/helpers/models.py | 8 ++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index 0c5a07d9f82..57c44a14f6c 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -291,6 +291,18 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration, actio ) return + if action_space == "discrete" and cfg.gSDE: + with pytest.raises( + RuntimeError, + match="cannot use gSDE with 
discrete actions", + ): + actor_value = make_ppo_model( + proof_environment, + device=device, + cfg=cfg, + ) + return + actor_value = make_ppo_model( proof_environment, device=device, @@ -425,6 +437,18 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration, actio ) return + if action_space == "discrete" and cfg.gSDE: + with pytest.raises( + RuntimeError, + match="cannot use gSDE with discrete actions", + ): + actor_value = make_a2c_model( + proof_environment, + device=device, + cfg=cfg, + ) + return + actor_value = make_a2c_model( proof_environment, device=device, diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index b74ff069127..a8c1268dfa5 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -566,11 +566,11 @@ def make_a2c_model( shared_out_keys = ["hidden"] if not cfg.gSDE: if action_spec.domain == "continuous": - actor_net = NormalParamWrapper( + policy_net = NormalParamWrapper( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) actor_module = SafeModule( - actor_net, in_keys=shared_out_keys, out_keys=dist_in_keys + policy_net, in_keys=shared_out_keys, out_keys=dist_in_keys ) else: gSDE_state_key = "hidden" @@ -641,11 +641,11 @@ def make_a2c_model( if not cfg.gSDE: if action_spec.domain == "continuous": - actor_net = NormalParamWrapper( + policy_net = NormalParamWrapper( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) actor_module = SafeModule( - actor_net, in_keys=in_keys_actor, out_keys=dist_in_keys + policy_net, in_keys=in_keys_actor, out_keys=dist_in_keys ) else: in_keys = in_keys_actor From 5872109d69c8f87cfd9614370f06eb2127cdae18 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 29 Nov 2022 10:04:15 +0100 Subject: [PATCH 07/14] added tests --- test/test_helpers.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index
57c44a14f6c..d457d16d4b7 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -315,9 +315,11 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration, actio "pixels_orig" if len(from_pixels) else "observation_orig", "action", "sample_log_prob", - "loc", - "scale", ] + if action_space == "continuous": + expected_keys += ["loc", "scale"] + else: + expected_keys += ["logits"] if shared_mapping: expected_keys += ["hidden"] if len(gsde): @@ -461,9 +463,11 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration, actio "pixels_orig" if len(from_pixels) else "observation_orig", "action", "sample_log_prob", - "loc", - "scale", ] + if action_space == "continuous": + expected_keys += ["loc", "scale"] + else: + expected_keys += ["logits"] if shared_mapping: expected_keys += ["hidden"] if len(gsde): From 8ad1b6f351bb01ac89ddf177a0cc29f951fe93bf Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 29 Nov 2022 10:40:14 +0100 Subject: [PATCH 08/14] format --- test/test_helpers.py | 24 ++++++++++++++---------- torchrl/trainers/helpers/models.py | 12 ++++++++---- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index d457d16d4b7..35d84c53e07 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -242,7 +242,9 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration): @pytest.mark.parametrize("shared_mapping", [(), ("shared_mapping=True",)]) @pytest.mark.parametrize("exploration", ["random", "mode"]) @pytest.mark.parametrize("action_space", ["discrete", "continuous"]) -def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration, action_space): +def test_ppo_maker( + device, from_pixels, shared_mapping, gsde, exploration, action_space +): if not gsde and exploration != "random": pytest.skip("no need to test this setting") flags = list(from_pixels + shared_mapping + gsde) @@ -265,14 +267,14 @@ def test_ppo_maker(device, from_pixels, shared_mapping, 
gsde, exploration, actio if from_pixels: if action_space == "continuous": - env_maker = (ContinuousActionConvMockEnvNumpy) + env_maker = ContinuousActionConvMockEnvNumpy else: - env_maker = (DiscreteActionConvMockEnvNumpy) + env_maker = DiscreteActionConvMockEnvNumpy else: if action_space == "continuous": - env_maker = (ContinuousActionVecMockEnv) + env_maker = ContinuousActionVecMockEnv else: - env_maker = (DiscreteActionVecMockEnv) + env_maker = DiscreteActionVecMockEnv env_maker = transformed_env_constructor( cfg, use_env_creator=False, custom_env_maker=env_maker @@ -387,7 +389,9 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration, actio @pytest.mark.parametrize("shared_mapping", [(), ("shared_mapping=True",)]) @pytest.mark.parametrize("exploration", ["random", "mode"]) @pytest.mark.parametrize("action_space", ["discrete", "continuous"]) -def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration, action_space): +def test_a2c_maker( + device, from_pixels, shared_mapping, gsde, exploration, action_space +): A2CModelConfig.advantage_in_loss = False if not gsde and exploration != "random": pytest.skip("no need to test this setting") @@ -413,14 +417,14 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration, actio if from_pixels: if action_space == "continuous": - env_maker = (ContinuousActionConvMockEnvNumpy) + env_maker = ContinuousActionConvMockEnvNumpy else: - env_maker = (DiscreteActionConvMockEnvNumpy) + env_maker = DiscreteActionConvMockEnvNumpy else: if action_space == "continuous": - env_maker = (ContinuousActionVecMockEnv) + env_maker = ContinuousActionVecMockEnv else: - env_maker = (DiscreteActionVecMockEnv) + env_maker = DiscreteActionVecMockEnv env_maker = transformed_env_constructor( cfg, use_env_creator=False, custom_env_maker=env_maker diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index a8c1268dfa5..e9e57eadb45 100644 --- 
a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -567,7 +567,8 @@ def make_a2c_model( if not cfg.gSDE: if action_spec.domain == "continuous": policy_net = NormalParamWrapper( - policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" + policy_net, + scale_mapping=f"biased_softplus_{cfg.default_policy_scale}", ) actor_module = SafeModule( policy_net, in_keys=shared_out_keys, out_keys=dist_in_keys @@ -642,7 +643,8 @@ def make_a2c_model( if not cfg.gSDE: if action_spec.domain == "continuous": policy_net = NormalParamWrapper( - policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" + policy_net, + scale_mapping=f"biased_softplus_{cfg.default_policy_scale}", ) actor_module = SafeModule( policy_net, in_keys=in_keys_actor, out_keys=dist_in_keys @@ -860,7 +862,8 @@ def make_ppo_model( if not cfg.gSDE: if action_spec.domain == "continuous": policy_net = NormalParamWrapper( - policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" + policy_net, + scale_mapping=f"biased_softplus_{cfg.default_policy_scale}", ) actor_module = SafeModule( policy_net, in_keys=shared_out_keys, out_keys=dist_in_keys @@ -935,7 +938,8 @@ def make_ppo_model( if not cfg.gSDE: if action_spec.domain == "continuous": policy_net = NormalParamWrapper( - policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" + policy_net, + scale_mapping=f"biased_softplus_{cfg.default_policy_scale}", ) actor_module = SafeModule( policy_net, in_keys=in_keys_actor, out_keys=dist_in_keys From 13b8bcc76ba91b068a2af621160cbb9702408cb0 Mon Sep 17 00:00:00 2001 From: albert bou Date: Fri, 2 Dec 2022 07:41:18 +0100 Subject: [PATCH 09/14] redo tests --- test/test_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index 35d84c53e07..610429f5387 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -244,7 +244,7 @@ def test_ddpg_maker(device, from_pixels, gsde, 
exploration): @pytest.mark.parametrize("action_space", ["discrete", "continuous"]) def test_ppo_maker( device, from_pixels, shared_mapping, gsde, exploration, action_space -): +): if not gsde and exploration != "random": pytest.skip("no need to test this setting") flags = list(from_pixels + shared_mapping + gsde) From f88c5d0b8f3852970f4bcb67204ad8a24640905f Mon Sep 17 00:00:00 2001 From: albert bou Date: Fri, 2 Dec 2022 07:50:47 +0100 Subject: [PATCH 10/14] format --- test/test_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index 610429f5387..35d84c53e07 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -244,7 +244,7 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration): @pytest.mark.parametrize("action_space", ["discrete", "continuous"]) def test_ppo_maker( device, from_pixels, shared_mapping, gsde, exploration, action_space -): +): if not gsde and exploration != "random": pytest.skip("no need to test this setting") flags = list(from_pixels + shared_mapping + gsde) From f936f492e18ed632de6043d7bcb23eb3bb38d8d2 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Wed, 30 Nov 2022 16:17:37 +0000 Subject: [PATCH 11/14] [Feature] Support for in-place functionalization (#714) --- .../linux_libs/scripts_gym/install.sh | 4 + .../linux_libs/scripts_gym/run_test.sh | 2 + .../linux_libs/scripts_habitat/install.sh | 2 +- .../linux_libs/scripts_jumanji/install.sh | 5 +- .../linux_olddeps/scripts_gym_0_13/install.sh | 4 + test/test_cost.py | 834 +++++++++--------- test/test_functorch.py | 323 ------- test/test_libs.py | 12 +- test/test_modules.py | 32 - test/test_tensordictmodules.py | 678 ++++---------- torchrl/data/replay_buffers/storages.py | 4 +- torchrl/modules/__init__.py | 7 - torchrl/modules/distributions/continuous.py | 2 +- torchrl/modules/tensordict_module/actors.py | 185 ++-- torchrl/modules/tensordict_module/common.py | 114 ++- .../modules/tensordict_module/exploration.py 
| 13 +- .../tensordict_module/probabilistic.py | 60 +- torchrl/modules/tensordict_module/sequence.py | 63 +- torchrl/modules/utils/__init__.py | 84 ++ torchrl/objectives/a2c.py | 8 +- torchrl/objectives/common.py | 514 ++++------- torchrl/objectives/ddpg.py | 34 +- torchrl/objectives/deprecated.py | 49 +- torchrl/objectives/dqn.py | 8 +- torchrl/objectives/ppo.py | 11 +- torchrl/objectives/redq.py | 125 ++- torchrl/objectives/reinforce.py | 5 - torchrl/objectives/sac.py | 83 +- torchrl/objectives/utils.py | 78 +- torchrl/objectives/value/advantages.py | 48 +- 30 files changed, 1328 insertions(+), 2063 deletions(-) delete mode 100644 test/test_functorch.py diff --git a/.circleci/unittest/linux_libs/scripts_gym/install.sh b/.circleci/unittest/linux_libs/scripts_gym/install.sh index 9e3739fe2b2..0cdee0320c1 100755 --- a/.circleci/unittest/linux_libs/scripts_gym/install.sh +++ b/.circleci/unittest/linux_libs/scripts_gym/install.sh @@ -44,5 +44,9 @@ fi # install tensordict pip install git+https://github.com/pytorch-labs/tensordict +# smoke test +python -c "import tensordict" + printf "* Installing torchrl\n" python setup.py develop +python -c "import torchrl" diff --git a/.circleci/unittest/linux_libs/scripts_gym/run_test.sh b/.circleci/unittest/linux_libs/scripts_gym/run_test.sh index 3e151539d92..4e08dec58e1 100755 --- a/.circleci/unittest/linux_libs/scripts_gym/run_test.sh +++ b/.circleci/unittest/linux_libs/scripts_gym/run_test.sh @@ -5,6 +5,8 @@ set -e eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env +yum makecache && yum install libglvnd-devel mesa-libGL mesa-libGL-devel mesa-libEGL mesa-libEGL-devel glfw mesa-libOSMesa-devel glew glew-devel egl-utils freeglut xorg-x11-server-Xvfb -y + export PYTORCH_TEST_WITH_SLOW='1' python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" diff --git a/.circleci/unittest/linux_libs/scripts_habitat/install.sh b/.circleci/unittest/linux_libs/scripts_habitat/install.sh index 
af2f78de49f..e5833cd1356 100755 --- a/.circleci/unittest/linux_libs/scripts_habitat/install.sh +++ b/.circleci/unittest/linux_libs/scripts_habitat/install.sh @@ -41,7 +41,7 @@ fi pip install git+https://github.com/pytorch-labs/tensordict # smoke test -python -c "import functorch" +python -c "import functorch;import tensordict" printf "* Installing torchrl\n" pip3 install -e . diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/install.sh b/.circleci/unittest/linux_libs/scripts_jumanji/install.sh index c0f97977649..767070f2b25 100755 --- a/.circleci/unittest/linux_libs/scripts_jumanji/install.sh +++ b/.circleci/unittest/linux_libs/scripts_jumanji/install.sh @@ -35,8 +35,11 @@ else pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall fi +# install tensordict +pip install git+https://github.com/pytorch-labs/tensordict + # smoke test -python -c "import functorch" +python -c "import functorch;import tensordict" printf "* Installing torchrl\n" pip3 install -e . 
diff --git a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh index 9e3739fe2b2..0cdee0320c1 100755 --- a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh +++ b/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh @@ -44,5 +44,9 @@ fi # install tensordict pip install git+https://github.com/pytorch-labs/tensordict +# smoke test +python -c "import tensordict" + printf "* Installing torchrl\n" python setup.py develop +python -c "import torchrl" diff --git a/test/test_cost.py b/test/test_cost.py index 40fdd8919f6..e1edaac8f4b 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6,17 +6,14 @@ import argparse from copy import deepcopy -from tensordict.nn.functional_modules import FunctionalModuleWithBuffers - _has_functorch = True -FUNCTORCH_ERR = "" try: - import functorch + import functorch as ft # noqa - make_functional_with_buffers = functorch.make_functional_with_buffers + make_functional_with_buffers = ft.make_functional_with_buffers + FUNCTORCH_ERR = "" except ImportError as err: _has_functorch = False - make_functional_with_buffers = FunctionalModuleWithBuffers._create_from FUNCTORCH_ERR = str(err) import numpy as np @@ -24,9 +21,10 @@ import torch from _utils_internal import dtype_fixture, get_available_devices # noqa from mocking_classes import ContinuousActionConvMockEnv +from tensordict.nn import get_functional # from torchrl.data.postprocs.utils import expand_as_right -from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase +from tensordict.tensordict import assert_allclose_td, TensorDict from tensordict.utils import expand_as_right from torch import autograd, nn from torchrl.data import ( @@ -65,6 +63,7 @@ ProbabilisticActor, ValueOperator, ) +from torchrl.modules.utils import Buffer from torchrl.objectives import ( A2CLoss, ClipPPOLoss, @@ -270,56 +269,12 @@ def _create_seq_mock_data_dqn( ) return td - @pytest.mark.skipif( - not 
_has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" - ) @pytest.mark.parametrize("delay_value", (False, True)) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize( "action_spec_type", ("nd_bounded", "one_hot", "categorical") ) - @pytest.mark.parametrize("is_nn_module", (False, True)) - def test_dqn(self, delay_value, device, action_spec_type, is_nn_module): - torch.manual_seed(self.seed) - actor = self._create_mock_actor( - action_spec_type=action_spec_type, device=device, is_nn_module=is_nn_module - ) - td = self._create_mock_data_dqn( - action_spec_type=action_spec_type, device=device - ) - loss_fn = DQNLoss(actor, gamma=0.9, loss_function="l2", delay_value=delay_value) - with _check_td_steady(td): - loss = loss_fn(td) - assert loss_fn.priority_key in td.keys() - - sum([item for _, item in loss.items()]).backward() - assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 - - # Check param update effect on targets - target_value = [p.clone() for p in loss_fn.target_value_network_params] - for p in loss_fn.parameters(): - p.data += torch.randn_like(p) - target_value2 = [p.clone() for p in loss_fn.target_value_network_params] - if loss_fn.delay_value: - assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2)) - else: - assert not any( - (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) - ) - - # check that policy is updated after parameter update - parameters = [p.clone() for p in actor.parameters()] - for p in loss_fn.parameters(): - p.data += torch.randn_like(p) - assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) - - @pytest.mark.skipif(_has_functorch, reason="functorch installed") - @pytest.mark.parametrize("delay_value", (False, True)) - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize( - "action_spec_type", ("nd_bounded", "one_hot", "categorical") - ) - def test_dqn_nofunctorch(self, 
delay_value, device, action_spec_type): + def test_dqn(self, delay_value, device, action_spec_type): torch.manual_seed(self.seed) actor = self._create_mock_actor( action_spec_type=action_spec_type, device=device @@ -351,9 +306,6 @@ def test_dqn_nofunctorch(self, delay_value, device, action_spec_type): p.data += torch.randn_like(p) assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) - @pytest.mark.skipif( - not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" - ) @pytest.mark.parametrize("n", range(4)) @pytest.mark.parametrize("delay_value", (False, True)) @pytest.mark.parametrize("device", get_available_devices()) @@ -395,68 +347,6 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): sum([item for _, item in loss_ms.items()]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 - # Check param update effect on targets - target_value = [p.clone() for p in loss_fn.target_value_network_params] - for p in loss_fn.parameters(): - p.data += torch.randn_like(p) - target_value2 = [p.clone() for p in loss_fn.target_value_network_params] - if loss_fn.delay_value: - assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2)) - else: - assert not any( - (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) - ) - - # check that policy is updated after parameter update - parameters = [p.clone() for p in actor.parameters()] - for p in loss_fn.parameters(): - p.data += torch.randn_like(p) - assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) - - @pytest.mark.skipif(_has_functorch, reason="functorch installed") - @pytest.mark.parametrize("n", range(4)) - @pytest.mark.parametrize("delay_value", (False, True)) - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize( - "action_spec_type", ("nd_bounded", "one_hot", "categorical") - ) - def test_dqn_batcher_nofunctorch( - self, n, 
delay_value, device, action_spec_type, gamma=0.9 - ): - torch.manual_seed(self.seed) - actor = self._create_mock_actor( - action_spec_type=action_spec_type, device=device - ) - - td = self._create_seq_mock_data_dqn( - action_spec_type=action_spec_type, device=device - ) - loss_fn = DQNLoss( - actor, gamma=gamma, loss_function="l2", delay_value=delay_value - ) - - ms = MultiStep(gamma=gamma, n_steps_max=n).to(device) - ms_td = ms(td.clone()) - - with _check_td_steady(ms_td): - loss_ms = loss_fn(ms_td) - assert loss_fn.priority_key in ms_td.keys() - - with torch.no_grad(): - loss = loss_fn(td) - if n == 0: - assert_allclose_td(td, ms_td.select(*list(td.keys()))) - _loss = sum([item for _, item in loss.items()]) - _loss_ms = sum([item for _, item in loss_ms.items()]) - assert ( - abs(_loss - _loss_ms) < 1e-3 - ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" - else: - with pytest.raises(AssertionError): - assert_allclose_td(loss, loss_ms) - sum([item for _, item in loss_ms.items()]).backward() - assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 - # Check param update effect on targets target_value = loss_fn.target_value_network_params.clone() for p in loss_fn.parameters(): @@ -473,62 +363,13 @@ def test_dqn_batcher_nofunctorch( p.data += torch.randn_like(p) assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) - @pytest.mark.skipif( - not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" - ) @pytest.mark.parametrize("atoms", range(4, 10)) @pytest.mark.parametrize("delay_value", (False, True)) @pytest.mark.parametrize("device", get_devices()) @pytest.mark.parametrize( "action_spec_type", ("mult_one_hot", "one_hot", "categorical") ) - @pytest.mark.parametrize("is_nn_module", (False, True)) def test_distributional_dqn( - self, atoms, delay_value, device, action_spec_type, is_nn_module, gamma=0.9 - ): - torch.manual_seed(self.seed) - actor = self._create_mock_distributional_actor( - 
action_spec_type=action_spec_type, atoms=atoms, is_nn_module=is_nn_module - ).to(device) - - td = self._create_mock_data_dqn( - action_spec_type=action_spec_type, atoms=atoms - ).to(device) - loss_fn = DistributionalDQNLoss(actor, gamma=gamma, delay_value=delay_value) - - with _check_td_steady(td): - loss = loss_fn(td) - assert loss_fn.priority_key in td.keys() - - sum([item for _, item in loss.items()]).backward() - assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 - - # Check param update effect on targets - target_value = [p.clone() for p in loss_fn.target_value_network_params] - for p in loss_fn.parameters(): - p.data += torch.randn_like(p) - target_value2 = [p.clone() for p in loss_fn.target_value_network_params] - if loss_fn.delay_value: - assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2)) - else: - assert not any( - (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) - ) - - # check that policy is updated after parameter update - parameters = [p.clone() for p in actor.parameters()] - for p in loss_fn.parameters(): - p.data += torch.randn_like(p) - assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) - - @pytest.mark.skipif(_has_functorch, reason="functorch installed") - @pytest.mark.parametrize("atoms", range(4, 10)) - @pytest.mark.parametrize("delay_value", (False, True)) - @pytest.mark.parametrize("device", get_devices()) - @pytest.mark.parametrize( - "action_spec_type", ("mult_one_hot", "one_hot", "categorical") - ) - def test_distributional_dqn_nofunctorch( self, atoms, delay_value, device, action_spec_type, gamma=0.9 ): torch.manual_seed(self.seed) @@ -556,7 +397,10 @@ def test_distributional_dqn_nofunctorch( if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) else: - assert not (target_value == target_value2).any() + for key, val in target_value.flatten_keys(",").items(): + if key in ("support",): + continue + assert not (val == 
target_value2[tuple(key.split(","))]).any(), key # check that policy is updated after parameter update parameters = [p.clone() for p in actor.parameters()] @@ -678,6 +522,14 @@ def test_ddpg(self, delay_actor, delay_value, device): with _check_td_steady(td): loss = loss_fn(td) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values(True, True) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) # check that losses are independent for k in loss.keys(): if not k.startswith("loss"): @@ -686,20 +538,20 @@ def test_ddpg(self, delay_actor, delay_value, device): if k == "loss_actor": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for p in loss_fn.value_network_params.values(True, True) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values(True, True) ) elif k == "loss_value": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values(True, True) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for p in loss_fn.value_network_params.values(True, True) ) else: raise NotImplementedError(k) @@ -712,12 +564,18 @@ def test_ddpg(self, delay_actor, delay_value, device): assert p.grad.norm() > 0.0 # Check param update effect on targets - target_actor = [p.clone() for p in loss_fn.target_actor_network_params] - target_value = [p.clone() for p in loss_fn.target_value_network_params] - for p in loss_fn.parameters(): + target_actor = [p.clone() for p in loss_fn.target_actor_network_params.values()] + target_value = [p.clone() for p in loss_fn.target_value_network_params.values()] + _i = -1 + for _i, p in enumerate(loss_fn.parameters()): p.data += torch.randn_like(p) - target_actor2 = [p.clone() for p in 
loss_fn.target_actor_network_params] - target_value2 = [p.clone() for p in loss_fn.target_value_network_params] + assert _i >= 0 + target_actor2 = [ + p.clone() for p in loss_fn.target_actor_network_params.values() + ] + target_value2 = [ + p.clone() for p in loss_fn.target_value_network_params.values() + ] if loss_fn.delay_actor: assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) else: @@ -930,54 +788,78 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device): if k == "loss_actor": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) elif k == "loss_value": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) ) elif k == "loss_qvalue": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for 
p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) elif k == "loss_alpha": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) else: raise NotImplementedError(k) @@ -1063,14 +945,44 @@ def test_sac_batcher( assert p.grad.norm() > 0.0, f"parameter {name} has null gradient" # Check param update effect on targets - target_actor = [p.clone() for p in loss_fn.target_actor_network_params] - target_qvalue = [p.clone() for p in loss_fn.target_qvalue_network_params] - target_value = [p.clone() for p in loss_fn.target_value_network_params] + target_actor = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_qvalue = [ + p.clone() + for p in loss_fn.target_qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_value = [ + p.clone() + for p in loss_fn.target_value_network_params.values( + include_nested=True, leaves_only=True + ) + ] for p in loss_fn.parameters(): p.data += torch.randn_like(p) - target_actor2 = [p.clone() for p in loss_fn.target_actor_network_params] - target_qvalue2 = [p.clone() for p in loss_fn.target_qvalue_network_params] - target_value2 = [p.clone() for p in loss_fn.target_value_network_params] + target_actor2 = [ + p.clone() + 
for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_qvalue2 = [ + p.clone() + for p in loss_fn.target_qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_value2 = [ + p.clone() + for p in loss_fn.target_value_network_params.values( + include_nested=True, leaves_only=True + ) + ] if loss_fn.delay_actor: assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) else: @@ -1264,29 +1176,41 @@ def test_redq(self, delay_qvalue, num_qvalue, device): if k == "loss_actor": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) elif k == "loss_qvalue": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) elif k == "loss_alpha": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) else: raise NotImplementedError(k) @@ -1339,29 +1263,30 @@ def test_redq_shared(self, delay_qvalue, num_qvalue, device): if k == "loss_actor": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._qvalue_network_params + 
for p in loss_fn.qvalue_network_params.values(True, True) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._actor_network_params + for p in loss_fn.actor_network_params.values(True, True) ) elif k == "loss_qvalue": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._actor_network_params + for p in loss_fn.actor_network_params.values(True, True) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._qvalue_network_params + for p in loss_fn.qvalue_network_params.values(True, True) + if isinstance(p, nn.Parameter) ) elif k == "loss_alpha": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._actor_network_params + for p in loss_fn.actor_network_params.values(True, True) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._qvalue_network_params + for p in loss_fn.qvalue_network_params.values(True, True) ) else: raise NotImplementedError(k) @@ -1385,13 +1310,19 @@ def test_redq_shared(self, delay_qvalue, num_qvalue, device): p.data *= 0 counter = 0 - for p in loss_fn.qvalue_network_params: + for key, p in loss_fn.qvalue_network_params.items(True, True): + if not isinstance(key, tuple): + key = (key,) if not isinstance(p, nn.Parameter): counter += 1 - assert (p == loss_fn._param_maps[p]).all() + key = "_sep_".join(["qvalue_network", *key]) + mapped_param = next( + (k for k, val in loss_fn._param_maps.items() if val == key) + ) + assert (p == getattr(loss_fn, mapped_param)).all() assert (p == 0).all() - assert counter == len(loss_fn._actor_network_params) - assert counter == len(loss_fn.actor_network_params) + assert counter == len(loss_fn._actor_network_params.keys(True, True)) + assert counter == len(loss_fn.actor_network_params.keys(True, True)) # check that params of the original actor are those of the loss_fn for p in actor.parameters(): @@ -1497,12 +1428,20 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9): assert p.grad.norm() 
> 0.0, f"parameter {name} has null gradient" # Check param update effect on targets - target_actor = [p.clone() for p in loss_fn.target_actor_network_params] - target_qvalue = [p.clone() for p in loss_fn.target_qvalue_network_params] + target_actor = loss_fn.target_actor_network_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue = loss_fn.target_qvalue_network_params.clone().values( + include_nested=True, leaves_only=True + ) for p in loss_fn.parameters(): p.data += torch.randn_like(p) - target_actor2 = [p.clone() for p in loss_fn.target_actor_network_params] - target_qvalue2 = [p.clone() for p in loss_fn.target_qvalue_network_params] + target_actor2 = loss_fn.target_actor_network_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue2 = loss_fn.target_qvalue_network_params.clone().values( + include_nested=True, leaves_only=True + ) if loss_fn.delay_actor: assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) else: @@ -1554,6 +1493,29 @@ def _create_mock_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): ) return value.to(device) + def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = NdBoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + base_layer = nn.Linear(obs_dim, 5) + net = NormalParamWrapper( + nn.Sequential(base_layer, nn.Linear(5, 2 * action_dim)) + ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + actor = ProbabilisticActor( + module=module, + distribution_class=TanhNormal, + dist_in_keys=["loc", "scale"], + spec=CompositeSpec(action=action_spec, loc=None, scale=None), + ) + module = nn.Sequential(base_layer, nn.Linear(5, 1)) + value = ValueOperator( + module=module, + in_keys=["observation"], + ) + return actor.to(device), value.to(device) + def _create_mock_distributional_actor( self, batch=2, obs_dim=3, action_dim=4, atoms=0, vmin=1, 
vmax=5 ): @@ -1657,26 +1619,105 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage): loss_critic.backward(retain_graph=True) # check that grads are independent and non null named_parameters = loss_fn.named_parameters() + counter = 0 + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + counter += 1 + assert "actor" not in name + assert "critic" in name + if p.grad is None: + assert "actor" in name + assert "critic" not in name + assert counter == 2 + + value.zero_grad() + loss_objective.backward() + counter = 0 + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + counter += 1 + assert "actor" in name + assert "critic" not in name + if p.grad is None: + assert "actor" not in name + assert "critic" in name + assert counter == 2 + actor.zero_grad() + + @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) + @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) + @pytest.mark.parametrize("device", get_available_devices()) + def test_ppo_shared(self, loss_class, device, advantage): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_ppo(device=device) + + actor, value = self._create_mock_actor_value(device=device) + if advantage == "gae": + advantage = GAE( + gamma=0.9, + lmbda=0.9, + value_network=value, + gradient_mode=False, + ) + elif advantage == "td": + advantage = TDEstimate( + gamma=0.9, + value_network=value, + gradient_mode=False, + ) + elif advantage == "td_lambda": + advantage = TDLambdaEstimate( + gamma=0.9, + lmbda=0.9, + value_network=value, + gradient_mode=False, + ) + else: + raise NotImplementedError + loss_fn = loss_class( + actor, + value, + gamma=0.9, + loss_critic_type="l2", + advantage_module=advantage, + ) + + loss = loss_fn(td) + loss_critic = loss["loss_critic"] + loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) + 
loss_critic.backward(retain_graph=True) + # check that grads are independent and non null + named_parameters = loss_fn.named_parameters() + counter = 0 for name, p in named_parameters: if p.grad is not None and p.grad.norm() > 0.0: + counter += 1 assert "actor" not in name assert "critic" in name if p.grad is None: assert "actor" in name assert "critic" not in name + assert counter == 2 value.zero_grad() loss_objective.backward() named_parameters = loss_fn.named_parameters() + counter = 0 for name, p in named_parameters: if p.grad is not None and p.grad.norm() > 0.0: + counter += 1 assert "actor" in name assert "critic" not in name if p.grad is None: assert "actor" not in name assert "critic" in name actor.zero_grad() + assert counter == 4 + @pytest.mark.skipif( + not _has_functorch, reason=f"functorch not found, {FUNCTORCH_ERR}" + ) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) @@ -1707,58 +1748,41 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): ) floss_fn, params, buffers = make_functional_with_buffers(loss_fn) - + # fill params with zero + for p in params: + p.data.zero_() + # assert len(list(floss_fn.parameters())) == 0 loss = floss_fn(params, buffers, td) loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) loss_critic.backward(retain_graph=True) # check that grads are independent and non null named_parameters = loss_fn.named_parameters() - if _has_functorch: - for (name, _), p in zip(named_parameters, params): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" not in name - assert "critic" in name - if p.grad is None: - assert "actor" in name - assert "critic" not in name - else: - for key, p in params.flatten_keys(".").items(): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" not in key - assert 
"value" in key or "critic" in key - if p.grad is None: - assert "actor" in key - assert "value" not in key and "critic" not in key - - if _has_functorch: - for param in params: - param.grad = None - else: - for param in params.flatten_keys(".").values(): - param.grad = None + for (name, _), p in zip(named_parameters, params): + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" not in name + assert "critic" in name + if p.grad is None: + assert "actor" in name + assert "critic" not in name + + for param in params: + param.grad = None loss_objective.backward() named_parameters = loss_fn.named_parameters() - if _has_functorch: - for (name, _), p in zip(named_parameters, params): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" in name - assert "critic" not in name - if p.grad is None: - assert "actor" not in name - assert "critic" in name - for param in params: - param.grad = None - else: - for key, p in params.flatten_keys(".").items(): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" in key - assert "value" not in key and "critic" not in key - if p.grad is None: - assert "actor" not in key - assert "value" in key or "critic" in key - for param in params.flatten_keys(".").values(): - param.grad = None + + for (name, other_p), p in zip(named_parameters, params): + assert other_p.shape == p.shape + assert other_p.dtype == p.dtype + assert other_p.device == p.device + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "critic" not in name + if p.grad is None: + assert "actor" not in name + assert "critic" in name + for param in params: + param.grad = None class TestA2C: @@ -1858,7 +1882,7 @@ def test_a2c(self, device, gradient_mode, advantage): RuntimeError, match="tensordict stored action require grad.", ): - loss = loss_fn._log_probs(td) + _ = loss_fn._log_probs(td) td["action"].requires_grad = False # Check error is raised when advantage_diff_key present and does not required grad @@ 
-1900,6 +1924,9 @@ def test_a2c(self, device, gradient_mode, advantage): # test reset loss_fn.reset() + @pytest.mark.skipif( + not _has_functorch, reason=f"functorch not found, {FUNCTORCH_ERR}" + ) @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) @pytest.mark.parametrize("device", get_available_devices()) @@ -1936,51 +1963,27 @@ def test_a2c_diff(self, device, gradient_mode, advantage): loss_critic.backward(retain_graph=True) # check that grads are independent and non null named_parameters = loss_fn.named_parameters() - if _has_functorch: - for (name, _), p in zip(named_parameters, params): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" not in name - assert "critic" in name - if p.grad is None: - assert "actor" in name - assert "critic" not in name - else: - for key, p in params.flatten_keys(".").items(): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" not in key - assert "value" in key or "critic" in key - if p.grad is None: - assert "actor" in key - assert "value" not in key and "critic" not in key - - if _has_functorch: - for param in params: - param.grad = None - else: - for param in params.flatten_keys(".").values(): - param.grad = None + for (name, _), p in zip(named_parameters, params): + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" not in name + assert "critic" in name + if p.grad is None: + assert "actor" in name + assert "critic" not in name + + for param in params: + param.grad = None loss_objective.backward() named_parameters = loss_fn.named_parameters() - if _has_functorch: - for (name, _), p in zip(named_parameters, params): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" in name - assert "critic" not in name - if p.grad is None: - assert "actor" not in name - assert "critic" in name - for param in params: - param.grad = None - else: - for key, p in params.flatten_keys(".").items(): - if p.grad is not None 
and p.grad.norm() > 0.0: - assert "actor" in key - assert "value" not in key and "critic" not in key - if p.grad is None: - assert "actor" not in key - assert "value" in key or "critic" in key - for param in params.flatten_keys(".").values(): - param.grad = None + for (name, _), p in zip(named_parameters, params): + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "critic" not in name + if p.grad is None: + assert "actor" not in name + assert "critic" in name + for param in params: + param.grad = None class TestReinforce: @@ -2008,20 +2011,20 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value): advantage_module = GAE( gamma=gamma, lmbda=0.9, - value_network=value_net.make_functional_with_buffers(clone=True)[0], + value_network=get_functional(value_net), gradient_mode=gradient_mode, ) elif advantage == "td": advantage_module = TDEstimate( gamma=gamma, - value_network=value_net.make_functional_with_buffers(clone=True)[0], + value_network=get_functional(value_net), gradient_mode=gradient_mode, ) elif advantage == "td_lambda": advantage_module = TDLambdaEstimate( gamma=0.9, lmbda=0.9, - value_network=value_net.make_functional_with_buffers(clone=True)[0], + value_network=get_functional(value_net), gradient_mode=gradient_mode, ) else: @@ -2496,21 +2499,28 @@ def test_hold_out(): @pytest.mark.parametrize("mode", ["hard", "soft"]) @pytest.mark.parametrize("value_network_update_interval", [100, 1000]) @pytest.mark.parametrize("device", get_available_devices()) -def test_updater(mode, value_network_update_interval, device): +@pytest.mark.parametrize( + "dtype", + [ + torch.float64, + torch.float32, + ], +) +def test_updater(mode, value_network_update_interval, device, dtype): torch.manual_seed(100) class custom_module_error(nn.Module): def __init__(self): super().__init__() - self._target_params = [torch.randn(3, 4)] - self._target_error_params = [torch.randn(3, 4)] + self.target_params = [torch.randn(3, 4)] + 
self.target_error_params = [torch.randn(3, 4)] self.params = nn.ParameterList( [nn.Parameter(torch.randn(3, 4, requires_grad=True))] ) module = custom_module_error().to(device) with pytest.raises( - RuntimeError, match="Your module seems to have a _target tensor list " + RuntimeError, match="Your module seems to have a target tensor list " ): if mode == "hard": upd = HardUpdate(module, value_network_update_interval) @@ -2524,21 +2534,18 @@ def __init__(self): self.convert_to_functional(module1, "module1", create_target_params=True) module2 = torch.nn.BatchNorm2d(10).eval() self.module2 = module2 - if _has_functorch: - iterator_params = self.target_module1_params - iterator_buffers = self.target_module1_buffers - else: - iterator_params = self.target_module1_params.values() - iterator_buffers = self.target_module1_buffers.values() + iterator_params = self.target_module1_params.values( + include_nested=True, leaves_only=True + ) for target in iterator_params: - target.data.normal_() - for target in iterator_buffers: if target.dtype is not torch.int64: target.data.normal_() else: target.data += 10 - module = custom_module().to(device) + module = custom_module().to(device).to(dtype) + _ = module.module1_params + _ = module.target_module1_params if mode == "hard": upd = HardUpdate( module, value_network_update_interval=value_network_update_interval @@ -2546,130 +2553,79 @@ def __init__(self): elif mode == "soft": upd = SoftUpdate(module, 1 - 1 / value_network_update_interval) upd.init_() - for _, v in upd._targets.items(): - if isinstance(v, TensorDictBase): - for _v in v.values(): - if _v.dtype is not torch.int64: - _v.copy_(torch.randn_like(_v)) - else: - _v += 10 + for _, _v in upd._targets.items(True, True): + if _v.dtype is not torch.int64: + _v.copy_(torch.randn_like(_v)) else: - for _v in v: - if _v.dtype is not torch.int64: - _v.copy_(torch.randn_like(_v)) - else: - _v += 10 + _v += 10 # total dist - if _has_functorch: - d0 = sum( - [ - (target_val[0] - 
val[0]).norm().item() - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ) - ] - ) - else: - d0 = 0.0 - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ): - for key in target_val.keys(): - if target_val[key].dtype == torch.long: - continue - d0 += (target_val[key] - val[key]).norm().item() + d0 = 0.0 + for (key, source_val) in upd._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target_val = upd._targets[key] + assert target_val.dtype is source_val.dtype, key + assert target_val.device == source_val.device, key + if target_val.dtype == torch.long: + continue + d0 += (target_val - source_val).norm().item() assert d0 > 0 if mode == "hard": for i in range(value_network_update_interval + 1): # test that no update is occuring until value_network_update_interval - if _has_functorch: - d1 = sum( - [ - (target_val[0] - val[0]).norm().item() - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ) - ] - ) - else: - d1 = 0.0 - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ): - for key in target_val.keys(): - if target_val[key].dtype == torch.long: - continue - d1 += (target_val[key] - val[key]).norm().item() + d1 = 0.0 + for (key, source_val) in upd._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target_val = upd._targets[key] + if target_val.dtype == torch.long: + continue + d1 += (target_val - source_val).norm().item() assert d1 == d0, i assert upd.counter == i upd.step() assert upd.counter == 0 # test that a new update has occured - if _has_functorch: - d1 = sum( - [ - (target_val[0] - val[0]).norm().item() - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ) - ] - ) - else: - d1 = 0.0 - for (_, target_val), (_, val) in zip( - upd._targets.items(), 
upd._sources.items() - ): - for key in target_val.keys(): - if target_val[key].dtype == torch.long: - continue - d1 += (target_val[key] - val[key]).norm().item() + d1 = 0.0 + for (key, source_val) in upd._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target_val = upd._targets[key] + if target_val.dtype == torch.long: + continue + d1 += (target_val - source_val).norm().item() assert d1 < d0 elif mode == "soft": upd.step() - if _has_functorch: - d1 = sum( - [ - (target_val[0] - val[0]).norm().item() - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ) - ] - ) - else: - d1 = 0.0 - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ): - for key in target_val.keys(): - if target_val[key].dtype == torch.long: - continue - d1 += (target_val[key] - val[key]).norm().item() + d1 = 0.0 + for (key, source_val) in upd._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target_val = upd._targets[key] + if target_val.dtype == torch.long: + continue + d1 += (target_val - source_val).norm().item() assert d1 < d0 upd.init_() upd.step() - if _has_functorch: - d2 = sum( - [ - (target_val[0] - val[0]).norm().item() - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ) - ] - ) - else: - d2 = 0.0 - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ): - for key in target_val.keys(): - if target_val[key].dtype == torch.long: - continue - d2 += (target_val[key] - val[key]).norm().item() + d2 = 0.0 + for (key, source_val) in upd._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target_val = upd._targets[key] + if target_val.dtype == torch.long: + continue + d2 += (target_val - source_val).norm().item() assert d2 < 1e-6 @@ -2953,12 +2909,22 @@ def 
__init__(self, actor_network, qvalue_network): p.data += torch.randn_like(p) assert len(list(loss.parameters())) == 6 - assert len(loss.actor_network_params) == 4 - assert len(loss.qvalue_network_params) == 4 - for p in loss.actor_network_params: - assert isinstance(p, nn.Parameter) - assert (loss.qvalue_network_params[0] == loss.actor_network_params[0]).all() - assert (loss.qvalue_network_params[1] == loss.actor_network_params[1]).all() + assert ( + len(loss.actor_network_params.keys(include_nested=True, leaves_only=True)) == 4 + ) + assert ( + len(loss.qvalue_network_params.keys(include_nested=True, leaves_only=True)) == 4 + ) + for p in loss.actor_network_params.values(include_nested=True, leaves_only=True): + assert isinstance(p, nn.Parameter) or isinstance(p, Buffer) + for i, (key, value) in enumerate( + loss.qvalue_network_params.items(include_nested=True, leaves_only=True) + ): + p1 = value + p2 = loss.actor_network_params[key] + assert (p1 == p2).all() + if i == 1: + break # map module if dest == "double": @@ -2970,16 +2936,18 @@ def __init__(self, actor_network, qvalue_network): else: loss = loss.to(dest) - for p in loss.actor_network_params: + for p in loss.actor_network_params.values(include_nested=True, leaves_only=True): assert isinstance(p, nn.Parameter) assert p.dtype is expected_dtype assert p.device == torch.device(expected_device) - assert loss.qvalue_network_params[0].dtype is expected_dtype - assert loss.qvalue_network_params[1].dtype is expected_dtype - assert loss.qvalue_network_params[0].device == torch.device(expected_device) - assert loss.qvalue_network_params[1].device == torch.device(expected_device) - assert (loss.qvalue_network_params[0] == loss.actor_network_params[0]).all() - assert (loss.qvalue_network_params[1] == loss.actor_network_params[1]).all() + for i, (key, qvalparam) in enumerate( + loss.qvalue_network_params.items(include_nested=True, leaves_only=True) + ): + assert qvalparam.dtype is expected_dtype, (key, qvalparam) + 
assert qvalparam.device == torch.device(expected_device), key + assert (qvalparam == loss.actor_network_params[key]).all(), key + if i == 1: + break if __name__ == "__main__": diff --git a/test/test_functorch.py b/test/test_functorch.py deleted file mode 100644 index 7b043968afb..00000000000 --- a/test/test_functorch.py +++ /dev/null @@ -1,323 +0,0 @@ -import argparse - -import pytest -import torch - -try: - from functorch import vmap - - _has_functorch = True -except ImportError: - _has_functorch = False -from tensordict import TensorDict -from tensordict.nn.functional_modules import ( - FunctionalModule, - FunctionalModuleWithBuffers, -) -from torch import nn -from torchrl.modules import SafeModule, SafeSequential - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) -@pytest.mark.parametrize( - "moduletype,batch_params", - [ - ["linear", False], - ["bn1", True], - ["linear", True], - ], -) -def test_vmap_patch(moduletype, batch_params): - if moduletype == "linear": - module = nn.Linear(3, 4) - elif moduletype == "bn1": - module = nn.BatchNorm1d(3) - else: - raise NotImplementedError - if moduletype == "linear": - fmodule, params = FunctionalModule._create_from(module) - x = torch.randn(10, 1, 3) - if batch_params: - params = params.expand(10, *params.batch_size) - y = vmap(fmodule, (0, 0))(params, x) - else: - y = vmap(fmodule, (None, 0))(params, x) - assert y.shape == torch.Size([10, 1, 4]) - elif moduletype == "bn1": - fmodule, params, buffers = FunctionalModuleWithBuffers._create_from(module) - x = torch.randn(10, 2, 3) - if batch_params: - params = params.expand(10, *params.batch_size).contiguous() - buffers = buffers.expand(10, *buffers.batch_size).contiguous() - y = vmap(fmodule, (0, 0, 0))(params, buffers, x) - else: - raise NotImplementedError - assert y.shape == torch.Size([10, 2, 3]) - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) 
-@pytest.mark.parametrize( - "moduletype,batch_params", - [ - ["linear", False], - ["bn1", True], - ["linear", True], - ], -) -def test_vmap_tdmodule(moduletype, batch_params): - if moduletype == "linear": - module = nn.Linear(3, 4) - elif moduletype == "bn1": - module = nn.BatchNorm1d(3) - else: - raise NotImplementedError - if moduletype == "linear": - fmodule, params = FunctionalModule._create_from(module) - tdmodule = SafeModule(fmodule, in_keys=["x"], out_keys=["y"]) - x = torch.randn(10, 1, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size) - tdmodule(td, params=params, vmap=(0, 0)) - else: - tdmodule(td, params=params, vmap=(None, 0)) - y = td["y"] - assert y.shape == torch.Size([10, 1, 4]) - elif moduletype == "bn1": - fmodule, params, buffers = FunctionalModuleWithBuffers._create_from(module) - tdmodule = SafeModule(fmodule, in_keys=["x"], out_keys=["y"]) - x = torch.randn(10, 2, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size).contiguous() - buffers = buffers.expand(10, *buffers.batch_size).contiguous() - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - raise NotImplementedError - y = td["y"] - assert y.shape == torch.Size([10, 2, 3]) - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) -@pytest.mark.parametrize( - "moduletype,batch_params", - [ - ["linear", False], - ["bn1", True], - ["linear", True], - ], -) -def test_vmap_tdmodule_nativebuilt(moduletype, batch_params): - if moduletype == "linear": - module = nn.Linear(3, 4) - elif moduletype == "bn1": - module = nn.BatchNorm1d(3) - else: - raise NotImplementedError - if moduletype == "linear": - tdmodule = SafeModule(module, in_keys=["x"], out_keys=["y"]) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) - x = torch.randn(10, 1, 3) - td = TensorDict({"x": x}, [10]) - if 
batch_params: - params = params.expand(10, *params.batch_size) - buffers = buffers.expand(10, *buffers.batch_size) - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - tdmodule(td, params=params, buffers=buffers, vmap=(None, None, 0)) - y = td["y"] - assert y.shape == torch.Size([10, 1, 4]) - elif moduletype == "bn1": - tdmodule = SafeModule(module, in_keys=["x"], out_keys=["y"]) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) - x = torch.randn(10, 2, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size).contiguous() - buffers = buffers.expand(10, *buffers.batch_size).contiguous() - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - raise NotImplementedError - y = td["y"] - assert y.shape == torch.Size([10, 2, 3]) - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) -@pytest.mark.parametrize( - "moduletype,batch_params", - [ - ["linear", False], - ["bn1", True], - ["linear", True], - ], -) -def test_vmap_tdsequence(moduletype, batch_params): - if moduletype == "linear": - module1 = nn.Linear(3, 4) - fmodule1, params1 = FunctionalModule._create_from(module1) - module2 = nn.Linear(4, 5) - fmodule2, params2 = FunctionalModule._create_from(module2) - elif moduletype == "bn1": - module1 = nn.BatchNorm1d(3) - fmodule1, params1, buffers1 = FunctionalModuleWithBuffers._create_from(module1) - module2 = nn.BatchNorm1d(3) - fmodule2, params2, buffers2 = FunctionalModuleWithBuffers._create_from(module2) - else: - raise NotImplementedError - if moduletype == "linear": - tdmodule1 = SafeModule(fmodule1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = SafeModule(fmodule2, in_keys=["y"], out_keys=["z"]) - params = TensorDict({"0": params1, "1": params2}, []) - tdmodule = SafeSequential(tdmodule1, tdmodule2) - assert {"0", "1"} == set(params.keys()) - x = torch.randn(10, 1, 3) - td = TensorDict({"x": 
x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size) - tdmodule(td, params=params, vmap=(0, 0)) - else: - tdmodule(td, params=params, vmap=(None, 0)) - z = td["z"] - assert z.shape == torch.Size([10, 1, 5]) - elif moduletype == "bn1": - tdmodule1 = SafeModule(fmodule1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = SafeModule(fmodule2, in_keys=["y"], out_keys=["z"]) - params = TensorDict({"0": params1, "1": params2}, []) - buffers = TensorDict({"0": buffers1, "1": buffers2}, []) - tdmodule = SafeSequential(tdmodule1, tdmodule2) - assert {"0", "1"} == set(params.keys()) - assert {"0", "1"} == set(buffers.keys()) - x = torch.randn(10, 2, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size).contiguous() - buffers = buffers.expand(10, *buffers.batch_size).contiguous() - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - raise NotImplementedError - z = td["z"] - assert z.shape == torch.Size([10, 2, 3]) - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) -@pytest.mark.parametrize( - "moduletype,batch_params", - [ - ["linear", False], - ["bn1", True], - ["linear", True], - ], -) -def test_vmap_tdsequence_nativebuilt(moduletype, batch_params): - if moduletype == "linear": - module1 = nn.Linear(3, 4) - module2 = nn.Linear(4, 5) - elif moduletype == "bn1": - module1 = nn.BatchNorm1d(3) - module2 = nn.BatchNorm1d(3) - else: - raise NotImplementedError - if moduletype == "linear": - tdmodule1 = SafeModule(module1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = SafeModule(module2, in_keys=["y"], out_keys=["z"]) - tdmodule = SafeSequential(tdmodule1, tdmodule2) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) - assert {"0", "1"} == set(params.keys()) - x = torch.randn(10, 1, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size) - buffers = 
buffers.expand(10, *buffers.batch_size) - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - tdmodule(td, params=params, buffers=buffers, vmap=(None, None, 0)) - z = td["z"] - assert z.shape == torch.Size([10, 1, 5]) - elif moduletype == "bn1": - tdmodule1 = SafeModule(module1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = SafeModule(module2, in_keys=["y"], out_keys=["z"]) - tdmodule = SafeSequential(tdmodule1, tdmodule2) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) - assert {"0", "1"} == set(params.keys()) - assert {"0", "1"} == set(buffers.keys()) - x = torch.randn(10, 2, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size).contiguous() - buffers = buffers.expand(10, *buffers.batch_size).contiguous() - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - raise NotImplementedError - z = td["z"] - assert z.shape == torch.Size([10, 2, 3]) - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) -class TestNativeFunctorch: - def test_vamp_basic(self): - class MyModule(torch.nn.Module): - def forward(self, tensordict): - a = tensordict["a"] - return TensorDict( - {"a": a}, tensordict.batch_size, device=tensordict.device - ) - - tensordict = TensorDict({"a": torch.randn(3)}, []).expand(4) - out = vmap(MyModule(), (0,))(tensordict) - assert out.shape == torch.Size([4]) - assert out["a"].shape == torch.Size([4, 3]) - - def test_vamp_composed(self): - class MyModule(torch.nn.Module): - def forward(self, tensordict, tensor): - a = tensordict["a"] - return ( - TensorDict( - {"a": a}, tensordict.batch_size, device=tensordict.device - ), - tensor, - ) - - tensor = torch.randn(3) - tensordict = TensorDict({"a": torch.randn(3, 1)}, [3]).expand(4, 3) - out = vmap(MyModule(), (0, None))(tensordict, tensor) - - assert out[0].shape == torch.Size([4, 3]) - assert out[1].shape == torch.Size([4, 3]) - 
assert out[0]["a"].shape == torch.Size([4, 3, 1]) - - def test_vamp_composed_flipped(self): - class MyModule(torch.nn.Module): - def forward(self, tensordict, tensor): - a = tensordict["a"] - return ( - TensorDict( - {"a": a}, tensordict.batch_size, device=tensordict.device - ), - tensor, - ) - - tensor = torch.randn(3).expand(4, 3) - tensordict = TensorDict({"a": torch.randn(3, 1)}, [3]) - out = vmap(MyModule(), (None, 0))(tensordict, tensor) - - assert out[0].shape == torch.Size([4, 3]) - assert out[1].shape == torch.Size([4, 3]) - assert out[0]["a"].shape == torch.Size([4, 3, 1]) - - -if __name__ == "__main__": - args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_libs.py b/test/test_libs.py index 06e09a0521b..4e4b4b811b2 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -272,15 +272,15 @@ def test_td_creation_from_spec(env_lib, env_args, env_kwargs): ) env = env_lib(*env_args, **env_kwargs) td = env.rollout(max_steps=5) - td0 = td[0].flatten_keys(".") + td0 = td[0] fake_td = env.fake_tensordict() - fake_td = fake_td.flatten_keys(".") - td = td.flatten_keys(".") - assert set(fake_td.keys()) == set(td.keys()) - for key in fake_td.keys(): + assert set(fake_td.keys(include_nested=True, leaves_only=True)) == set( + td.keys(include_nested=True, leaves_only=True) + ) + for key in fake_td.keys(include_nested=True, leaves_only=True): assert fake_td.get(key).shape == td.get(key)[0].shape - for key in fake_td.keys(): + for key in fake_td.keys(include_nested=True, leaves_only=True): assert fake_td.get(key).shape == td0.get(key).shape assert fake_td.get(key).dtype == td0.get(key).dtype assert fake_td.get(key).device == td0.get(key).device diff --git a/test/test_modules.py b/test/test_modules.py index 3a83f48c18c..2f37daab5a8 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -11,10 +11,6 @@ from mocking_classes import MockBatchedUnLockedEnv from 
packaging import version from tensordict import TensorDict -from tensordict.nn.functional_modules import ( - FunctionalModule, - FunctionalModuleWithBuffers, -) from torch import nn from torchrl.data.tensor_specs import ( DiscreteTensorSpec, @@ -441,34 +437,6 @@ def test_lstm_net_nobatch(device, out_features, hidden_size): torch.testing.assert_close(tds_vec["hidden1_out"][-1], tds_loop["hidden1_out"][-1]) -class TestFunctionalModules: - def test_func_seq(self): - module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3)) - fmodule, params = FunctionalModule._create_from(module) - x = torch.randn(3) - assert (fmodule(params, x) == module(x)).all() - - def test_func_bn(self): - module = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4)) - module.eval() - fmodule, params, buffers = FunctionalModuleWithBuffers._create_from(module) - x = torch.randn(10, 3) - assert (fmodule(params, buffers, x) == module(x)).all() - - def test_func_transformer(self): - torch.manual_seed(10) - batch = ( - (10,) - if version.parse(torch.__version__) >= version.parse("1.11") - else (1, 10) - ) - module = nn.Transformer(128) - module.eval() - fmodule, params, buffers = FunctionalModuleWithBuffers._create_from(module) - x = torch.randn(*batch, 128) - torch.testing.assert_close(fmodule(params, buffers, x, x), module(x, x)) - - class TestPlanner: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("batch_size", [3, 5]) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 60f513ae213..ef6238777e2 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -7,22 +7,8 @@ import pytest import torch -from tensordict.tensordict import TensorDictBase - -_has_functorch = False -try: - from functorch import make_functional, make_functional_with_buffers - - _has_functorch = True -except ImportError: - from tensordict.nn.functional_modules import ( - FunctionalModule, - FunctionalModuleWithBuffers, - ) - - make_functional 
= FunctionalModule._create_from - make_functional_with_buffers = FunctionalModuleWithBuffers._create_from from tensordict import TensorDict +from tensordict.nn.functional_modules import make_functional from torch import nn from torchrl.data.tensor_specs import ( CompositeSpec, @@ -38,6 +24,14 @@ from torchrl.modules.tensordict_module.probabilistic import SafeProbabilisticModule from torchrl.modules.tensordict_module.sequence import SafeSequential +_has_functorch = False +try: + from functorch import vmap + + _has_functorch = True +except ImportError: + pass + class TestTDModule: def test_multiple_output(self): @@ -243,7 +237,7 @@ def test_functional(self, safe, spec_type): net = nn.Linear(3, 4 * param_multiplier) - fnet, params = make_functional(net) + params = make_functional(net) if spec_type is None: spec = None @@ -260,7 +254,7 @@ def test_functional(self, safe, spec_type): ): tensordict_module = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], out_keys=["out"], safe=safe, @@ -269,7 +263,7 @@ def test_functional(self, safe, spec_type): else: tensordict_module = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], out_keys=["out"], safe=safe, @@ -292,75 +286,11 @@ def test_functional_probabilistic(self, safe, spec_type): torch.manual_seed(0) param_multiplier = 2 - net = nn.Linear(3, 4 * param_multiplier) - in_keys = ["in"] - net = NormalParamWrapper(net) - fnet, params = make_functional(net) - tdnet = SafeModule( - module=fnet, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] - ) + net = NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)) + params = make_functional(net) - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - spec = ( - CompositeSpec(out=spec, loc=None, scale=None) if spec is not None else None - ) - - kwargs = {"distribution_class": 
TanhNormal} - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tensordict_module = SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - return - else: - tensordict_module = SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_probabilistic_laterconstruct(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net = nn.Linear(3, 4 * param_multiplier) - in_keys = ["in"] - net = NormalParamWrapper(net) tdnet = SafeModule( - module=net, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] + module=net, spec=None, in_keys=["in"], out_keys=["loc", "scale"] ) if spec_type is None: @@ -401,13 +331,9 @@ def test_functional_probabilistic_laterconstruct(self, safe, spec_type): safe=safe, **kwargs, ) - tensordict_module, ( - params, - buffers, - ) = tensordict_module.make_functional_with_buffers() td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td = tensordict_module(td, params=params, buffers=buffers) + tensordict_module(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -424,8 +350,7 @@ def test_functional_with_buffer(self, safe, spec_type): param_multiplier = 1 net = 
nn.BatchNorm1d(32 * param_multiplier) - - fnet, params, buffers = make_functional_with_buffers(net) + params = make_functional(net) if spec_type is None: spec = None @@ -442,7 +367,7 @@ def test_functional_with_buffer(self, safe, spec_type): ): tdmodule = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], out_keys=["out"], safe=safe, @@ -451,14 +376,14 @@ def test_functional_with_buffer(self, safe, spec_type): else: tdmodule = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], out_keys=["out"], safe=safe, ) td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params, buffers=buffers) + tdmodule(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 32]) @@ -474,12 +399,10 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): torch.manual_seed(0) param_multiplier = 2 - net = nn.BatchNorm1d(32 * param_multiplier) - in_keys = ["in"] - net = NormalParamWrapper(net) - fnet, params, buffers = make_functional_with_buffers(net) + net = NormalParamWrapper(nn.BatchNorm1d(32 * param_multiplier)) + params = make_functional(net) tdnet = SafeModule( - module=fnet, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] + module=net, spec=None, in_keys=["in"], out_keys=["loc", "scale"] ) if spec_type is None: @@ -522,71 +445,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): ) td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params, buffers=buffers) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 32]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - 
def test_functional_with_buffer_probabilistic_laterconstruct(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net = nn.BatchNorm1d(32 * param_multiplier) - in_keys = ["in"] - net = NormalParamWrapper(net) - tdnet = SafeModule( - module=net, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 32) - elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(32) - else: - raise NotImplementedError - spec = ( - CompositeSpec(out=spec, loc=None, scale=None) if spec is not None else None - ) - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - return - else: - tdmodule = SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers() - - td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params, buffers=buffers) + tdmodule(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 32]) @@ -607,8 +466,6 @@ def test_vmap(self, safe, spec_type): net = nn.Linear(3, 4 * param_multiplier) - fnet, params = make_functional(net) - if spec_type is None: spec = None elif spec_type == "bounded": @@ -624,7 +481,7 @@ def test_vmap(self, safe, spec_type): ): tdmodule = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], out_keys=["out"], safe=safe, @@ -633,27 +490,25 @@ def test_vmap(self, safe, spec_type): else: tdmodule = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], 
out_keys=["out"], safe=safe, ) + params = make_functional(tdmodule) + # vmap = True - params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] + params = params.expand(10) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = tdmodule(td, params=params, vmap=True) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, None) - td_out = tdmodule(td, params=params, vmap=(0, None)) + if safe and spec_type == "bounded": + with pytest.raises( + RuntimeError, match="vmap cannot be used with safe=True" + ): + td_out = vmap(tdmodule, (None, 0))(td, params) + return + else: + td_out = vmap(tdmodule, (None, 0))(td, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -664,8 +519,9 @@ def test_vmap(self, safe, spec_type): assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() # vmap = (0, 0) - td_repeat = td.expand(10, *td.batch_size).clone() - td_out = tdmodule(td_repeat, params=params, vmap=(0, 0)) + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + td_repeat = td.expand(10, *td.batch_size) + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -684,12 +540,9 @@ def test_vmap_probabilistic(self, safe, spec_type): torch.manual_seed(0) param_multiplier = 2 - net = nn.Linear(3, 4 * param_multiplier) - net = NormalParamWrapper(net) - in_keys = ["in"] - fnet, params = make_functional(net) + net = NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)) tdnet = SafeModule( - module=fnet, spec=None, in_keys=in_keys, 
out_keys=["loc", "scale"] + module=net, spec=None, in_keys=["in"], out_keys=["loc", "scale"] ) if spec_type is None: @@ -731,113 +584,19 @@ def test_vmap_probabilistic(self, safe, spec_type): **kwargs, ) + params = make_functional(tdmodule) + # vmap = True - params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] + params = params.expand(10) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = tdmodule(td, params=params, vmap=True) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, None) - td_out = tdmodule(td, params=params, vmap=(0, None)) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td_repeat = td.expand(10, *td.batch_size).clone() - td_out = tdmodule(td_repeat, params=params, vmap=(0, 0)) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def 
test_vmap_probabilistic_laterconstruct(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net = nn.Linear(3, 4 * param_multiplier) - net = NormalParamWrapper(net) - in_keys = ["in"] - tdnet = SafeModule( - module=net, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - spec = ( - CompositeSpec(out=spec, loc=None, scale=None) if spec is not None else None - ) - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: + if safe and spec_type == "bounded": with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", + RuntimeError, match="vmap cannot be used with safe=True" ): - tdmodule = SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) + td_out = vmap(tdmodule, (None, 0))(td, params) return else: - tdmodule = SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers() - - # vmap = True - params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = tdmodule(td, params=params, buffers=buffers, vmap=True) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0, None) - td_out = tdmodule(td, params=params, 
buffers=buffers, vmap=(0, 0, None)) + td_out = vmap(tdmodule, (None, 0))(td, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -847,9 +606,10 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): elif safe and spec_type == "bounded": assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - # vmap = (0, 0, 0) - td_repeat = td.expand(10, *td.batch_size).clone() - td_out = tdmodule(td_repeat, params=params, buffers=buffers, vmap=(0, 0, 0)) + # vmap = (0, 0) + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + td_repeat = td.expand(10, *td.batch_size) + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -1046,14 +806,6 @@ def test_functional(self, safe, spec_type): dummy_net = nn.Linear(4, 4) net2 = nn.Linear(4, 4 * param_multiplier) - fnet1, params1 = make_functional(net1) - fdummy_net, _ = make_functional(dummy_net) - fnet2, params2 = make_functional(net2) - if isinstance(params1, TensorDictBase): - params = TensorDict({"0": params1, "1": params2}, []) - else: - params = list(params1) + list(params2) - if spec_type is None: spec = None elif spec_type == "bounded": @@ -1065,17 +817,17 @@ def test_functional(self, safe, spec_type): pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) dummy_tdmodule = SafeModule( - fdummy_net, + dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeModule( - fnet2, + net2, spec=spec, in_keys=["hidden"], out_keys=["out"], @@ -1083,14 +835,18 @@ def test_functional(self, safe, spec_type): ) tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + params = make_functional(tdmodule) + 
assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1098,7 +854,7 @@ def test_functional(self, safe, spec_type): assert tdmodule[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params=params) + tdmodule(td, params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -1122,14 +878,7 @@ def test_functional_probabilistic(self, safe, spec_type): net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) - fnet1, params1 = make_functional(net1) - fdummy_net, _ = make_functional(dummy_net) - fnet2, params2 = make_functional(net2) - fnet2 = SafeModule(module=fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) - if isinstance(params1, TensorDictBase): - params = TensorDict({"0": params1, "1": params2}, []) - else: - params = list(params1) + list(params2) + net2 = SafeModule(module=net2, in_keys=["hidden"], out_keys=["loc", "scale"]) if spec_type is None: spec = None @@ -1149,17 +898,17 @@ def test_functional_probabilistic(self, safe, spec_type): pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) dummy_tdmodule = SafeModule( - fdummy_net, + dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeProbabilisticModule( - fnet2, + net2, spec=spec, dist_in_keys=["loc", "scale"], sample_out_key=["out"], @@ -1168,14 +917,18 @@ def test_functional_probabilistic(self, safe, spec_type): ) tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + params = make_functional(tdmodule, 
funs_to_decorate=["forward", "get_dist"]) + assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1212,19 +965,6 @@ def test_functional_with_buffer( nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) ) - fnet1, params1, buffers1 = make_functional_with_buffers(net1) - fdummy_net, _, _ = make_functional_with_buffers(dummy_net) - fnet2, params2, buffers2 = make_functional_with_buffers(net2) - - if isinstance(params1, TensorDictBase): - params = TensorDict({"0": params1, "1": params2}, []) - else: - params = list(params1) + list(params2) - if isinstance(buffers1, TensorDictBase): - buffers = TensorDict({"0": buffers1, "1": buffers2}, []) - else: - buffers = list(buffers1) + list(buffers2) - if spec_type is None: spec = None elif spec_type == "bounded": @@ -1236,17 +976,17 @@ def test_functional_with_buffer( pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) dummy_tdmodule = SafeModule( - fdummy_net, + dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeModule( - fnet2, + net2, spec=spec, in_keys=["hidden"], out_keys=["out"], @@ -1254,14 +994,18 @@ def test_functional_with_buffer( ) tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + params = make_functional(tdmodule) + assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] + del params["module", "2"] 
assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1269,10 +1013,10 @@ def test_functional_with_buffer( assert tdmodule[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params, buffers=buffers) + tdmodule(td, params=params) with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td, params=params, buffers=buffers) + dist, *_ = tdmodule.get_dist(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 7]) @@ -1299,22 +1043,7 @@ def test_functional_with_buffer_probabilistic( nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) ) net2 = NormalParamWrapper(net2) - - fnet1, params1, buffers1 = make_functional_with_buffers(net1) - fdummy_net, _, _ = make_functional_with_buffers(dummy_net) - # fnet2, params2, buffers2 = make_functional_with_buffers(net2) - # fnet2 = SafeModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) net2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) - fnet2, (params2, buffers2) = net2.make_functional_with_buffers() - - if isinstance(params1, TensorDictBase): - params = TensorDict({"0": params1, "1": params2}, []) - else: - params = list(params1) + list(params2) - if isinstance(buffers1, TensorDictBase): - buffers = TensorDict({"0": buffers1, "1": buffers2}, []) - else: - buffers = list(buffers1) + list(buffers2) if spec_type is None: spec = None @@ -1334,17 +1063,17 @@ def test_functional_with_buffer_probabilistic( pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) dummy_tdmodule = SafeModule( - fdummy_net, + dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeProbabilisticModule( - fnet2, + net2, spec=spec, dist_in_keys=["loc", "scale"], 
sample_out_key=["out"], @@ -1353,14 +1082,18 @@ def test_functional_with_buffer_probabilistic( ) tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + params = make_functional(tdmodule, ["forward", "get_dist"]) + assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1368,73 +1101,9 @@ def test_functional_with_buffer_probabilistic( assert tdmodule[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params, buffers=buffers) - - dist, *_ = tdmodule.get_dist(td, params=params, buffers=buffers) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 7]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer_probabilistic_laterconstruct( - self, - safe, - spec_type, - ): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - net2 = nn.Sequential( - nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) - ) - net2 = NormalParamWrapper(net2) - net2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 7) - elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(7) - else: - raise NotImplementedError - spec = ( - 
CompositeSpec(out=spec, loc=None, scale=None) if spec is not None else None - ) - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - tdmodule2 = SafeProbabilisticModule( - net2, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - tdmodule = SafeSequential(tdmodule1, tdmodule2) - - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers() - - td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params, buffers=buffers) + tdmodule(td, params=params) - dist, *_ = tdmodule.get_dist(td, params=params, buffers=buffers) + dist, *_ = tdmodule.get_dist(td, params=params) assert dist.rsample().shape[: td.ndimension()] == td.shape assert td.shape == torch.Size([3]) @@ -1459,11 +1128,6 @@ def test_vmap(self, safe, spec_type): dummy_net = nn.Linear(4, 4) net2 = nn.Linear(4, 4 * param_multiplier) - fnet1, params1 = make_functional(net1) - fdummy_net, _ = make_functional(dummy_net) - fnet2, params2 = make_functional(net2) - params = params1 + params2 - if spec_type is None: spec = None elif spec_type == "bounded": @@ -1475,21 +1139,21 @@ def test_vmap(self, safe, spec_type): pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False, ) dummy_tdmodule = SafeModule( - fdummy_net, + dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeModule( - fnet2, + net2, spec=spec, in_keys=["hidden"], out_keys=["out"], @@ -1497,14 +1161,18 @@ def test_vmap(self, safe, spec_type): ) tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + params = make_functional(tdmodule) + assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 + 
params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1512,20 +1180,17 @@ def test_vmap(self, safe, spec_type): assert tdmodule[1] is tdmodule2 # vmap = True - params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] + params = params.expand(10) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = tdmodule(td, params=params, vmap=True) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + if safe and spec_type == "bounded": + with pytest.raises( + RuntimeError, match="vmap cannot be used with safe=True" + ): + td_out = vmap(tdmodule, (None, 0))(td, params) + return + else: + td_out = vmap(tdmodule, (None, 0))(td, params) - # vmap = (0, None) - td_out = tdmodule(td, params=params, vmap=(0, None)) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -1536,9 +1201,10 @@ def test_vmap(self, safe, spec_type): assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() # vmap = (0, 0) - td_repeat = td.expand(10, *td.batch_size).clone() - td_out = tdmodule(td_repeat, params=params, vmap=(0, 0)) - assert td_out is not td + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + td_repeat = td.expand(10, *td.batch_size) + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) + assert td_out is not td_repeat assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) # test bounds @@ -1557,14 +1223,10 @@ def test_vmap_probabilistic(self, 
safe, spec_type): param_multiplier = 2 net1 = nn.Linear(3, 4) - fnet1, params1 = make_functional(net1) net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) - fnet2, params2 = make_functional(net2) - fnet2 = SafeModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) - - params = params1 + params2 + net2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) if spec_type is None: spec = None @@ -1584,14 +1246,14 @@ def test_vmap_probabilistic(self, safe, spec_type): pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeProbabilisticModule( - fnet2, + net2, spec=spec, sample_out_key=["out"], dist_in_keys=["loc", "scale"], @@ -1600,21 +1262,19 @@ def test_vmap_probabilistic(self, safe, spec_type): ) tdmodule = SafeSequential(tdmodule1, tdmodule2) + params = make_functional(tdmodule) + # vmap = True - params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] + params = params.expand(10) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = tdmodule(td, params=params, vmap=True) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, None) - td_out = tdmodule(td, params=params, vmap=(0, None)) + if safe and spec_type == "bounded": + with pytest.raises( + RuntimeError, match="vmap cannot be used with safe=True" + ): + td_out = vmap(tdmodule, (None, 0))(td, params) + return + else: + td_out = vmap(tdmodule, (None, 0))(td, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -1625,9 +1285,10 
@@ def test_vmap_probabilistic(self, safe, spec_type): assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() # vmap = (0, 0) - td_repeat = td.expand(10, *td.batch_size).clone() - td_out = tdmodule(td_repeat, params=params, vmap=(0, 0)) - assert td_out is not td + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + td_repeat = td.expand(10, *td.batch_size) + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) + assert td_out is not td_repeat assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) # test bounds @@ -1636,6 +1297,45 @@ def test_vmap_probabilistic(self, safe, spec_type): elif safe and spec_type == "bounded": assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + @pytest.mark.parametrize("functional", [True, False]) + def test_submodule_sequence(self, functional): + td_module_1 = SafeModule( + nn.Linear(3, 2), + in_keys=["in"], + out_keys=["hidden"], + ) + td_module_2 = SafeModule( + nn.Linear(2, 4), + in_keys=["hidden"], + out_keys=["out"], + ) + td_module = SafeSequential(td_module_1, td_module_2) + + if functional: + td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) + sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) + params = make_functional(sub_seq_1) + sub_seq_1(td_1, params=params) + assert "hidden" in td_1.keys() + assert "out" not in td_1.keys() + td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) + sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) + params = make_functional(sub_seq_2) + sub_seq_2(td_2, params=params) + assert "out" in td_2.keys() + assert td_2.get("out").shape == torch.Size([5, 4]) + else: + td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) + sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) + sub_seq_1(td_1) + assert "hidden" in td_1.keys() + assert "out" not in td_1.keys() + td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) + sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) + sub_seq_2(td_2) + 
assert "out" in td_2.keys() + assert td_2.get("out").shape == torch.Size([5, 4]) + @pytest.mark.parametrize("stack", [True, False]) @pytest.mark.parametrize("functional", [True, False]) def test_sequential_partial(self, stack, functional): @@ -1643,29 +1343,14 @@ def test_sequential_partial(self, stack, functional): param_multiplier = 2 net1 = nn.Linear(3, 4) - if functional: - fnet1, params1 = make_functional(net1) - else: - params1 = None - fnet1 = net1 net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) - if functional: - fnet2, params2 = make_functional(net2) - else: - fnet2 = net2 - params2 = None - fnet2 = SafeModule(fnet2, in_keys=["b"], out_keys=["loc", "scale"]) + net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) net3 = nn.Linear(4, 4 * param_multiplier) net3 = NormalParamWrapper(net3) - if functional: - fnet3, params3 = make_functional(net3) - else: - fnet3 = net3 - params3 = None - fnet3 = SafeModule(fnet3, in_keys=["c"], out_keys=["loc", "scale"]) + net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) spec = NdBoundedTensorSpec(-0.1, 0.1, 4) spec = CompositeSpec(out=spec, loc=None, scale=None) @@ -1673,14 +1358,14 @@ def test_sequential_partial(self, stack, functional): kwargs = {"distribution_class": TanhNormal} tdmodule1 = SafeModule( - fnet1, + net1, spec=None, in_keys=["a"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeProbabilisticModule( - fnet2, + net2, spec=spec, sample_out_key=["out"], dist_in_keys=["loc", "scale"], @@ -1688,7 +1373,7 @@ def test_sequential_partial(self, stack, functional): **kwargs, ) tdmodule3 = SafeProbabilisticModule( - fnet3, + net3, spec=spec, sample_out_key=["out"], dist_in_keys=["loc", "scale"], @@ -1699,6 +1384,11 @@ def test_sequential_partial(self, stack, functional): tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True ) + if functional: + params = make_functional(tdmodule) + else: + params = None + if stack: td = torch.stack( [ @@ -1708,16 +1398,6 @@ def 
test_sequential_partial(self, stack, functional): 0, ) if functional: - if _has_functorch: - params = params1 + params2 + params3 - else: - params = TensorDict( - { - str(i): params - for i, params in enumerate((params1, params2, params3)) - }, - [], - ) tdmodule(td, params=params) else: tdmodule(td) @@ -1732,16 +1412,6 @@ def test_sequential_partial(self, stack, functional): else: td = TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []) if functional: - if _has_functorch: - params = params1 + params2 + params3 - else: - params = TensorDict( - { - str(i): params - for i, params in enumerate((params1, params2, params3)) - }, - [], - ) tdmodule(td, params=params) else: tdmodule(td) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index a05b53cf7c5..1bba001c1c1 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -370,7 +370,9 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: .memmap_(prefix=self.scratch_dir) .to(self.device) ) - for key, tensor in sorted(out.flatten_keys(".").items()): + for key, tensor in sorted( + out.items(include_nested=True, leaves_only=True), key=str + ): filesize = os.path.getsize(tensor.filename) / 1024 / 1024 print( f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})." 
diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 8c6ca5d8593..5d42fbe753f 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -13,13 +13,6 @@ TanhNormal, TruncatedNormal, ) - -# from .functional_modules import ( -# FunctionalModule, -# FunctionalModuleWithBuffers, -# extract_weights, -# extract_buffers, -# ) from .models import ( ConvNet, DdpgCnnActor, diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index da136f67208..f0644174be0 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -95,7 +95,7 @@ def _call(self, x: torch.Tensor) -> torch.Tensor: def _inverse(self, y: torch.Tensor) -> torch.Tensor: eps = torch.finfo(y.dtype).eps - y.data.clamp_(-1 + eps, 1 - eps) + y = y.clamp(-1 + eps, 1 - eps) x = super()._inverse(y) return x diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index dba80fc67a5..f9661faa90e 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -76,25 +76,24 @@ class ProbabilisticActor(SafeProbabilisticModule): automatically translated into :obj:`spec = CompositeSpec(action=spec)` Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torchrl.data import NdBoundedTensorSpec - >>> from torchrl.modules import Actor, TanhNormal, NormalParamWrapper + >>> from torchrl.modules import ProbabilisticActor, NormalParamWrapper, SafeModule, TanhNormal >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) >>> action_spec = NdBoundedTensorSpec(shape=torch.Size([4]), ... minimum=-1, maximum=1) >>> module = NormalParamWrapper(torch.nn.Linear(4, 8)) - >>> fmodule, params, buffers = functorch.make_functional_with_buffers( - ... 
module) - >>> tensordict_module = SafeModule(fmodule, in_keys=["observation"], out_keys=["loc", "scale"]) + >>> params = make_functional(module) + >>> tensordict_module = SafeModule(module, in_keys=["observation"], out_keys=["loc", "scale"]) >>> td_module = ProbabilisticActor( ... module=tensordict_module, ... spec=action_spec, ... dist_in_keys=["loc", "scale"], ... distribution_class=TanhNormal, ... ) - >>> td = td_module(td, params=params, buffers=buffers) + >>> td = td_module(td, params=params) >>> td TensorDict( fields={ @@ -143,9 +142,9 @@ class ValueOperator(SafeModule): key is part of the in_keys list). Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import NdUnboundedContinuousTensorSpec >>> from torchrl.modules import ValueOperator @@ -157,20 +156,20 @@ class ValueOperator(SafeModule): ... def forward(self, obs, action): ... return self.linear(torch.cat([obs, action], -1)) >>> module = CustomModule() - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) >>> td_module = ValueOperator( - ... in_keys=["observation", "action"], - ... module=fmodule, - ... ) - >>> td_module(td, params=params, buffers=buffers) + ... in_keys=["observation", "action"], module=module + ... ) + >>> params = make_functional(td_module) + >>> td_module(td, params=params) >>> print(td) TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([3, 2]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) """ @@ -210,28 +209,29 @@ class QValueHook: action component. 
Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import OneHotDiscreteTensorSpec >>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> module = nn.Linear(4, 4) - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> params = make_functional(module) >>> hook = QValueHook("one_hot") - >>> _ = fmodule.register_forward_hook(hook) + >>> module.register_forward_hook(hook) >>> action_spec = OneHotDiscreteTensorSpec(4) - >>> qvalue_actor = Actor(module=fmodule, spec=action_spec, out_keys=["action", "action_value"]) - >>> _ = qvalue_actor(td, params=params, buffers=buffers) + >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"]) + >>> qvalue_actor(td, params=params) >>> print(td) TensorDict( - fields={observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([5, 4]), dtype=torch.int64), - action_value: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, - shared=False, + action_value: Tensor(torch.Size([5, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, batch_size=torch.Size([5]), - device=cpu) + device=None, + is_shared=False) """ @@ -326,9 +326,9 @@ class DistributionalQValueHook(QValueHook): action component. Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import OneHotDiscreteTensorSpec >>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor @@ -343,20 +343,21 @@ class DistributionalQValueHook(QValueHook): ... return self.linear(x).view(-1, nbins, 4).log_softmax(-2) ... 
>>> module = CustomDistributionalQval() - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> params = make_functional(module) >>> action_spec = OneHotDiscreteTensorSpec(4) >>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins)) - >>> _ = fmodule.register_forward_hook(hook) - >>> qvalue_actor = Actor(module=fmodule, spec=action_spec, out_keys=["action", "action_value"]) - >>> _ = qvalue_actor(td, params=params, buffers=buffers) + >>> module.register_forward_hook(hook) + >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"]) + >>> qvalue_actor(td, params=params) >>> print(td) TensorDict( - fields={observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([5, 4]), dtype=torch.int64), - action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32)}, - shared=False, + action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, batch_size=torch.Size([5]), - device=cpu) + device=None, + is_shared=False) """ @@ -438,27 +439,27 @@ class QValueActor(Actor): This class hooks the module such that it returns a one-hot encoding of the argmax value. 
Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import OneHotDiscreteTensorSpec >>> from torchrl.modules.tensordict_module.actors import QValueActor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> module = nn.Linear(4, 4) - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> params= make_functional(module) >>> action_spec = OneHotDiscreteTensorSpec(4) - >>> qvalue_actor = QValueActor(module=fmodule, spec=action_spec) - >>> _ = qvalue_actor(td, params=params, buffers=buffers) + >>> qvalue_actor = QValueActor(module=module, spec=action_spec) + >>> qvalue_actor(td, params=params) >>> print(td) TensorDict( fields={ - observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), action: Tensor(torch.Size([5, 4]), dtype=torch.int64), action_value: Tensor(torch.Size([5, 4]), dtype=torch.float32), - chosen_action_value: Tensor(torch.Size([5, 1]), dtype=torch.float32)}, + chosen_action_value: Tensor(torch.Size([5, 1]), dtype=torch.float32), + observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, batch_size=torch.Size([5]), - device=cpu, + device=None, is_shared=False) """ @@ -480,7 +481,6 @@ class DistributionalQValueActor(QValueActor): This class hooks the module such that it returns a one-hot encoding of the argmax value on its support. 
Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict >>> from torch import nn @@ -491,15 +491,15 @@ class DistributionalQValueActor(QValueActor): >>> module = MLP(out_features=(nbins, 4), depth=2) >>> action_spec = OneHotDiscreteTensorSpec(4) >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) - >>> _ = qvalue_actor(td) + >>> qvalue_actor(td) >>> print(td) TensorDict( fields={ - observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), action: Tensor(torch.Size([5, 4]), dtype=torch.int64), - action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32)}, + action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, batch_size=torch.Size([5]), - device=cpu, + device=None, is_shared=False) """ @@ -566,7 +566,7 @@ class ActorValueOperator(SafeSequential): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.modules.tensordict_module import ProbabilisticActor + >>> from torchrl.modules import ProbabilisticActor, SafeModule >>> from torchrl.data import NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamWrapper >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) @@ -601,34 +601,40 @@ class ActorValueOperator(SafeSequential): >>> td_clone = td_module(td.clone()) >>> print(td_clone) TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), - hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32), 
state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) >>> td_clone = td_module.get_policy_operator()(td.clone()) >>> print(td_clone) # no value TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), - hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([3, 4]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu) - + device=None, + is_shared=False) >>> td_clone = td_module.get_value_operator()(td.clone()) >>> print(td_clone) # no action TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) """ @@ -692,7 +698,7 @@ class ActorCriticOperator(ActorValueOperator): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.modules.tensordict_module import ProbabilisticActor + >>> from torchrl.modules import ProbabilisticActor, SafeModule >>> from torchrl.data import NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamWrapper, MLP >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) @@ -734,7 +740,7 @@ class ActorCriticOperator(ActorValueOperator): scale: Tensor(torch.Size([3, 4]), 
dtype=torch.float32), state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu, + device=None, is_shared=False) >>> td_clone = td_module.get_policy_operator()(td.clone()) >>> print(td_clone) # no value @@ -747,7 +753,7 @@ class ActorCriticOperator(ActorValueOperator): sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu, + device=None, is_shared=False) >>> td_clone = td_module.get_critic_operator()(td.clone()) >>> print(td_clone) # no action @@ -761,7 +767,7 @@ class ActorCriticOperator(ActorValueOperator): scale: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu, + device=None, is_shared=False) """ @@ -823,13 +829,24 @@ class ActorCriticWrapper(SafeSequential): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.data import NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec - >>> from torchrl.modules.tensordict_module.deprec import ProbabilisticActor_deprecated - >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticWrapper - >>> spec_action = NdBoundedTensorSpec(-1, 1, torch.Size([8])) - >>> module_action = torch.nn.Linear(4, 8) - >>> td_module_action = ProbabilisticActor_deprecated( - ... module=module_action, - ... spec=spec_action, + >>> from torchrl.modules import ( + ActorCriticWrapper, + ProbabilisticActor, + NormalParamWrapper, + SafeModule, + TanhNormal, + ValueOperator, + ) + >>> action_spec = NdBoundedTensorSpec(-1, 1, torch.Size([8])) + >>> action_module = SafeModule( + NormalParamWrapper(torch.nn.Linear(4, 8)), + in_keys=["observation"], + out_keys=["loc", "scale"], + ) + >>> td_module_action = ProbabilisticActor( + ... module=action_module, + ... spec=action_spec, + ... dist_in_keys=["loc", "scale"], ... distribution_class=TanhNormal, ... 
return_log_prob=True, ... ) @@ -843,31 +860,37 @@ class ActorCriticWrapper(SafeSequential): >>> td_clone = td_module(td.clone()) >>> print(td_clone) TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([3, 4]), dtype=torch.float32), + loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) >>> td_clone = td_module.get_policy_operator()(td.clone()) >>> print(td_clone) # no value TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([3, 4]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, + loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu) - + device=None, + is_shared=False) >>> td_clone = td_module.get_value_operator()(td.clone()) >>> print(td_clone) # no action TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) """ diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index c092197eb7c..d666031d028 100644 --- a/torchrl/modules/tensordict_module/common.py +++ 
b/torchrl/modules/tensordict_module/common.py @@ -6,6 +6,7 @@ from __future__ import annotations import inspect +import re import warnings from typing import Iterable, Optional, Type, Union @@ -24,10 +25,13 @@ "functional programming should work, but functionality and performance " "may be affected. Consider installing functorch and/or upgrating pytorch." ) - from tensordict.nn.functional_modules import ( - FunctionalModule, - FunctionalModuleWithBuffers, - ) + + class FunctionalModule: # noqa: D101 + pass + + class FunctionalModuleWithBuffers: # noqa: D101 + pass + from tensordict.nn import TensorDictModule from tensordict.tensordict import TensorDictBase @@ -51,33 +55,44 @@ def _check_all_str(list_of_str, first_level=True): def _forward_hook_safe_action(module, tensordict_in, tensordict_out): - spec = module.spec - if len(module.out_keys) > 1 and not isinstance(spec, CompositeSpec): - raise RuntimeError( - "safe SafeModules with multiple out_keys require a CompositeSpec with matching keys. Got " - f"keys {module.out_keys}." - ) - elif not isinstance(spec, CompositeSpec): - out_key = module.out_keys[0] - keys = [out_key] - values = [spec] - else: - keys = list(spec.keys()) - values = [spec[key] for key in keys] - for _spec, _key in zip(values, keys): - if _spec is None: - continue - if not _spec.is_in(tensordict_out.get(_key)): - try: - tensordict_out.set_( - _key, - _spec.project(tensordict_out.get(_key)), - ) - except RuntimeError: - tensordict_out.set( - _key, - _spec.project(tensordict_out.get(_key)), - ) + try: + spec = module.spec + if len(module.out_keys) > 1 and not isinstance(spec, CompositeSpec): + raise RuntimeError( + "safe SafeModules with multiple out_keys require a CompositeSpec with matching keys. Got " + f"keys {module.out_keys}." 
+ ) + elif not isinstance(spec, CompositeSpec): + out_key = module.out_keys[0] + keys = [out_key] + values = [spec] + else: + keys = list(spec.keys()) + values = [spec[key] for key in keys] + for _spec, _key in zip(values, keys): + if _spec is None: + continue + if not _spec.is_in(tensordict_out.get(_key)): + try: + tensordict_out.set_( + _key, + _spec.project(tensordict_out.get(_key)), + ) + except RuntimeError: + tensordict_out.set( + _key, + _spec.project(tensordict_out.get(_key)), + ) + except RuntimeError as err: + if re.search( + "attempting to use a Tensor in some data-dependent control flow", str(err) + ): + # "_is_stateless" in module.__dict__ and module._is_stateless: + raise RuntimeError( + f"vmap cannot be used with safe=True, consider turning the safe mode off. (original error message: {err})" + ) + else: + raise err class SafeModule(TensorDictModule): @@ -103,34 +118,35 @@ class SafeModule(TensorDictModule): case, the 'params' (and 'buffers') keyword argument must be specified: Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torchrl.data import NdUnboundedContinuousTensorSpec >>> from torchrl.modules import SafeModule >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) >>> spec = NdUnboundedContinuousTensorSpec(8) >>> module = torch.nn.GRUCell(4, 8) - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) >>> td_fmodule = SafeModule( - ... module=fmodule, + ... module=module, ... spec=spec, ... in_keys=["input", "hidden"], ... out_keys=["output"], ... 
) - >>> td_functional = td_fmodule(td.clone(), params=params, buffers=buffers) + >>> params = make_functional(td_fmodule) + >>> td_functional = td_fmodule(td.clone(), params=params) >>> print(td_functional) TensorDict( - fields={input: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), + input: Tensor(torch.Size([3, 4]), dtype=torch.float32), output: Tensor(torch.Size([3, 8]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) In the stateful case: >>> td_module = SafeModule( - ... module=module, + ... module=torch.nn.GRUCell(4, 8), ... spec=spec, ... in_keys=["input", "hidden"], ... out_keys=["output"], @@ -138,27 +154,29 @@ class SafeModule(TensorDictModule): >>> td_stateful = td_module(td.clone()) >>> print(td_stateful) TensorDict( - fields={input: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), + input: Tensor(torch.Size([3, 4]), dtype=torch.float32), output: Tensor(torch.Size([3, 8]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) One can use a vmap operator to call the functional module. In this case the tensordict is expanded to match the batch size (i.e. 
the tensordict isn't modified in-place anymore): >>> # Model ensemble using vmap - >>> params_repeat = tuple(param.expand(4, *param.shape).contiguous().normal_() for param in params) - >>> buffers_repeat = tuple(param.expand(4, *param.shape).contiguous().normal_() for param in buffers) - >>> td_vmap = td_fmodule(td.clone(), params=params_repeat, buffers=buffers_repeat, vmap=True) + >>> from functorch import vmap + >>> params_repeat = params.expand(4, *params.shape) + >>> td_vmap = vmap(td_fmodule, (None, 0))(td.clone(), params_repeat) >>> print(td_vmap) TensorDict( - fields={input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), + fields={ hidden: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32), + input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([4, 3]), - device=cpu) + device=None, + is_shared=False) """ diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index fe3aac62df9..2dd65bb339d 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -65,7 +65,7 @@ class EGreedyWrapper(TensorDictModuleWrapper): [ 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000], - [ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=) + [ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=) """ @@ -285,10 +285,19 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): >>> torch.manual_seed(0) >>> spec = NdBoundedTensorSpec(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) - >>> policy = Actor(spec, module=module) + >>> policy = Actor(module=module, spec=spec) >>> explorative_policy = OrnsteinUhlenbeckProcessWrapper(policy) >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) >>> print(explorative_policy(td)) + TensorDict( + fields={ + _ou_prev_noise: 
Tensor(torch.Size([10, 4]), dtype=torch.float32), + _ou_steps: Tensor(torch.Size([10, 1]), dtype=torch.int64), + action: Tensor(torch.Size([10, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)}, + batch_size=torch.Size([10]), + device=None, + is_shared=False) """ def __init__( diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 3061d1017fa..022977f378a 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -77,53 +77,57 @@ class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for the Normal distribut Default is 1000 Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import SafeProbabilisticModule, TanhNormal, NormalParamWrapper + >>> from tensordict.nn.functional_modules import make_functional + >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec + >>> from torchrl.modules import ( + NormalParamWrapper, + SafeModule, + SafeProbabilisticModule, + TanhNormal, + ) >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) - >>> spec = NdUnboundedContinuousTensorSpec(4) + >>> spec = CompositeSpec(action=NdUnboundedContinuousTensorSpec(4), loc=None, scale=None) >>> net = NormalParamWrapper(torch.nn.GRUCell(4, 8)) - >>> fnet, params, buffers = functorch.make_functional_with_buffers(net) - >>> module = SafeModule(fnet, in_keys=["input", "hidden"], out_keys=["loc", "scale"]) + >>> module = SafeModule(net, in_keys=["input", "hidden"], out_keys=["loc", "scale"]) >>> td_module = SafeProbabilisticModule( - ... module=module, - ... spec=spec, - ... dist_in_keys=["loc", "scale"], - ... sample_out_key=["action"], - ... distribution_class=TanhNormal, - ... return_log_prob=True, - ... 
) - >>> _ = td_module(td, params=params, buffers=buffers) + ... module=module, + ... spec=spec, + ... dist_in_keys=["loc", "scale"], + ... sample_out_key=["action"], + ... distribution_class=TanhNormal, + ... return_log_prob=True, + ... ) + >>> params = make_functional(td_module) + >>> td_module(td, params=params) >>> print(td) TensorDict( fields={ - input: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([3, 4]), dtype=torch.float32), hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), + input: Tensor(torch.Size([3, 4]), dtype=torch.float32), loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), - scale: Tensor(torch.Size([3, 4]), dtype=torch.float32), - action: Tensor(torch.Size([3, 4]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu, + device=None, is_shared=False) - >>> # In the vmap case, the tensordict is again expended to match the batch: - >>> params = tuple(p.expand(4, *p.shape).contiguous().normal_() for p in params) - >>> buffers = tuple(b.expand(4, *b.shape).contiguous().normal_() for p in buffers) - >>> td_vmap = td_module(td, params=params, buffers=buffers, vmap=True) + >>> from functorch import vmap + >>> params = params.expand(4, *params.shape) + >>> td_vmap = vmap(td_module, (None, 0))(td, params) >>> print(td_vmap) TensorDict( fields={ - input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), hidden: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32), + input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), loc: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), - scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), - action: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([4, 3, 1]), 
dtype=torch.float32)}, + sample_log_prob: Tensor(torch.Size([4, 3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32)}, batch_size=torch.Size([4, 3]), - device=cpu, + device=None, is_shared=False) """ diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index bbc3323630f..5e4886b70b2 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -31,77 +31,74 @@ class SafeSequential(TensorDictSequential, SafeModule): TensorDictSequence supports functional, modular and vmap coding: Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import TanhNormal, SafeSequential, NormalParamWrapper + >>> from tensordict.nn.functional_modules import make_functional + >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec + >>> from torchrl.modules import TanhNormal, SafeSequential, SafeModule, NormalParamWrapper >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) - >>> spec1 = NdUnboundedContinuousTensorSpec(4) + >>> spec1 = CompositeSpec(hidden=NdUnboundedContinuousTensorSpec(4), loc=None, scale=None) >>> net1 = NormalParamWrapper(torch.nn.Linear(4, 8)) - >>> fnet1, params1, buffers1 = functorch.make_functional_with_buffers(net1) - >>> fmodule1 = SafeModule(fnet1, in_keys=["input"], out_keys=["loc", "scale"]) + >>> module1 = SafeModule(net1, in_keys=["input"], out_keys=["loc", "scale"]) >>> td_module1 = SafeProbabilisticModule( - ... module=fmodule1, - ... spec=spec1, - ... dist_in_keys=["loc", "scale"], - ... sample_out_key=["hidden"], - ... distribution_class=TanhNormal, - ... return_log_prob=True, - ... ) + ... module=module1, + ... spec=spec1, + ... dist_in_keys=["loc", "scale"], + ... 
sample_out_key=["hidden"], + ... distribution_class=TanhNormal, + ... return_log_prob=True, + ... ) >>> spec2 = NdUnboundedContinuousTensorSpec(8) >>> module2 = torch.nn.Linear(4, 8) - >>> fmodule2, params2, buffers2 = functorch.make_functional_with_buffers(module2) >>> td_module2 = SafeModule( - ... module=fmodule2, + ... module=module2, ... spec=spec2, ... in_keys=["hidden"], ... out_keys=["output"], ... ) >>> td_module = SafeSequential(td_module1, td_module2) - >>> params = params1 + params2 - >>> buffers = buffers1 + buffers2 - >>> _ = td_module(td, params=params, buffers=buffers) + >>> params = make_functional(td_module) + >>> td_module(td, params=params) >>> print(td) TensorDict( fields={ + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), input: Tensor(torch.Size([3, 4]), dtype=torch.float32), loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), - scale: Tensor(torch.Size([3, 4]), dtype=torch.float32), - hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + output: Tensor(torch.Size([3, 8]), dtype=torch.float32), sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), - output: Tensor(torch.Size([3, 8]), dtype=torch.float32)}, + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu, + device=None, is_shared=False) - >>> # The module spec aggregates all the input specs: >>> print(td_module.spec) CompositeSpec( hidden: NdUnboundedContinuousTensorSpec( - shape=torch.Size([4]),space=None,device=cpu,dtype=torch.float32,domain=continuous), + shape=torch.Size([4]), space=None, device=cpu, dtype=torch.float32, domain=continuous), + loc: None, + scale: None, output: NdUnboundedContinuousTensorSpec( - shape=torch.Size([8]),space=None,device=cpu,dtype=torch.float32,domain=continuous)) + shape=torch.Size([8]), space=None, device=cpu, dtype=torch.float32, domain=continuous)) In the vmap case: - >>> params = tuple(p.expand(4, *p.shape).contiguous().normal_() for p in params) - >>> buffers = 
tuple(b.expand(4, *b.shape).contiguous().normal_() for p in buffers) - >>> td_vmap = td_module(td, params=params, buffers=buffers, vmap=True) + >>> from functorch import vmap + >>> params = params.expand(4, *params.shape) + >>> td_vmap = vmap(td_module, (None, 0))(td, params) >>> print(td_vmap) TensorDict( fields={ + hidden: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), loc: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), - scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), - hidden: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), + output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32), sample_log_prob: Tensor(torch.Size([4, 3, 1]), dtype=torch.float32), - output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32)}, + scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32)}, batch_size=torch.Size([4, 3]), - device=cpu, + device=None, is_shared=False) - """ module: nn.ModuleList diff --git a/torchrl/modules/utils/__init__.py b/torchrl/modules/utils/__init__.py index 4af16165f7c..ef430b85391 100644 --- a/torchrl/modules/utils/__init__.py +++ b/torchrl/modules/utils/__init__.py @@ -3,4 +3,88 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import OrderedDict + +import torch +from packaging import version + +if version.parse(torch.__version__) >= version.parse("1.12.0"): + from torch.nn.parameter import _disabled_torch_function_impl, _ParameterMeta +else: + from torch.nn.parameter import _disabled_torch_function_impl + + # Metaclass to combine _TensorMeta and the instance check override for Parameter. + class _ParameterMeta(torch._C._TensorMeta): + # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag. 
+ def __instancecheck__(self, instance): + return super().__instancecheck__(instance) or ( + isinstance(instance, torch.Tensor) + and getattr(instance, "_is_param", False) + ) + + from .mappings import biased_softplus, inv_softplus, mappings + + +class Buffer(torch.Tensor, metaclass=_ParameterMeta): + r"""A kind of Tensor that is to be considered a module buffer. + + Parameters are :class:`~torch.Tensor` subclasses, that have a + very special property when used with :class:`Module` s - when they're + assigned as Module attributes they are automatically added to the list of + its parameters, and will appear e.g. in :meth:`~Module.parameters` iterator. + Assigning a Tensor doesn't have such effect. This is because one might + want to cache some temporary state, like last hidden state of the RNN, in + the model. If there was no such class as :class:`Parameter`, these + temporaries would get registered too. + + Args: + data (Tensor): parameter tensor. + requires_grad (bool, optional): if the parameter requires gradient. See + :ref:`locally-disable-grad-doc` for more details. Default: `False` + """ + + def __new__(cls, data=None, requires_grad=False): + if data is None: + data = torch.empty(0) + if type(data) is torch.Tensor or type(data) is Buffer: + # For ease of BC maintenance, keep this path for standard Tensor. + # Eventually (tm), we should change the behavior for standard Tensor to match. + return torch.Tensor._make_subclass(cls, data, requires_grad) + + # Path for custom tensors: set a flag on the instance to indicate parameter-ness. + t = data.detach().requires_grad_(requires_grad) + if type(t) is not type(data): + raise RuntimeError( + f"Creating a Parameter from an instance of type {type(data).__name__} " + "requires that detach() returns an instance of the same type, but return " + f"type {type(t).__name__} was found instead. To use the type as a " + "Parameter, please correct the detach() semantics defined by " + "its __torch_dispatch__() implementation."
+ ) + t._is_param = True + return t + + # Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types + # are still considered that custom tensor type and these methods will not be called for them. + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + result = type(self)( + self.data.clone(memory_format=torch.preserve_format), self.requires_grad + ) + memo[id(self)] = result + return result + + def __repr__(self): + return "Buffer containing:\n" + super(Buffer, self).__repr__() + + def __reduce_ex__(self, proto): + # See Note [Don't serialize hooks] + return ( + torch._utils._rebuild_parameter, + (self.data, self.requires_grad, OrderedDict()), + ) + + __torch_function__ = _disabled_torch_function_impl diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index af20007b26a..e6b6d0852a6 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -54,7 +54,9 @@ def __init__( advantage_module: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, ): super().__init__() - self.convert_to_functional(actor, "actor") + self.convert_to_functional( + actor, "actor", funs_to_decorate=["forward", "get_dist"] + ) self.convert_to_functional(critic, "critic", compare_against=self.actor_params) self.advantage_key = advantage_key self.advantage_diff_key = advantage_diff_key @@ -93,7 +95,8 @@ def _log_probs( tensordict_clone = tensordict.select(*self.actor.in_keys).clone() dist, *_ = self.actor.get_dist( - tensordict_clone, params=self.actor_params, buffers=self.actor_buffers + tensordict_clone, + params=self.actor_params, ) log_prob = dist.log_prob(action) log_prob = log_prob.unsqueeze(-1) @@ -117,7 +120,6 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: value = self.critic( tensordict_select, params=self.critic_params, - buffers=self.critic_buffers, ).get("state_value") value_target = advantage + value.detach() loss_value = distance_loss( diff --git 
a/torchrl/objectives/common.py b/torchrl/objectives/common.py index a7c90521a5a..312be720433 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -5,17 +5,27 @@ from __future__ import annotations +import itertools +from copy import deepcopy from typing import Iterator, List, Optional, Tuple, Union import torch -from tensordict.nn.functional_modules import FunctionalModuleWithBuffers + +from tensordict.nn import make_functional, repopulate_module + +from tensordict.tensordict import TensorDictBase +from torch import nn, Tensor +from torch.nn import Parameter + +from torchrl.modules import SafeModule +from torchrl.modules.utils import Buffer _has_functorch = False try: - import functorch - from functorch._src.make_functional import _swap_state + import functorch as ft # noqa _has_functorch = True + FUNCTORCH_ERR = "" except ImportError: print( "failed to import functorch. TorchRL's features that do not require " @@ -24,12 +34,6 @@ ) FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." -from tensordict.tensordict import TensorDict, TensorDictBase -from torch import nn, Tensor -from torch.nn import Parameter - -from torchrl.modules import SafeModule - class LossModule(nn.Module): """A parent class for RL losses. @@ -44,6 +48,7 @@ class LossModule(nn.Module): def __init__(self): super().__init__() self._param_maps = {} + # self.register_forward_pre_hook(_parameters_to_tensordict) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """It is designed to read an input TensorDict and return another tensordict with loss keys named "loss*". 
@@ -69,376 +74,208 @@ def convert_to_functional( expand_dim: Optional[int] = None, create_target_params: bool = False, compare_against: Optional[List[Parameter]] = None, + funs_to_decorate=None, ) -> None: - if _has_functorch: - return self._convert_to_functional_functorch( - module, - module_name, - expand_dim, - create_target_params, - compare_against, - ) - else: - return self._convert_to_functional_native( - module, - module_name, - expand_dim, - create_target_params, - compare_against, - ) - - def _convert_to_functional_functorch( - self, - module: SafeModule, - module_name: str, - expand_dim: Optional[int] = None, - create_target_params: bool = False, - compare_against: Optional[List[Parameter]] = None, - ) -> None: + if funs_to_decorate is None: + funs_to_decorate = ["forward"] # To make it robust to device casting, we must register list of # tensors as lazy calls to `getattr(self, name_of_tensor)`. # Otherwise, casting the module to a device will keep old references # to uncast tensors - - network_orig = module - if hasattr(module, "make_functional_with_buffers"): - functional_module, ( - _, - module_buffers, - ) = module.make_functional_with_buffers(clone=True) - else: - ( - functional_module, - module_params, - module_buffers, - ) = functorch.make_functional_with_buffers(module) - - for _ in functional_module.parameters(): - # Erase meta params - none_state = [None for _ in module_params + module_buffers] - if hasattr(functional_module, "all_names_map"): - # functorch >= 0.2.0 - _swap_state( - functional_module.stateless_model, - functional_module.all_names_map, - none_state, - ) - else: - # functorch < 0.2.0 - _swap_state( - functional_module.stateless_model, - functional_module.split_names, - none_state, - ) - break - del module_params - - param_name = module_name + "_params" - - # we keep the original parameters and not the copy returned by functorch - params = network_orig.parameters() - - # unless we need to expand them, in that case we'll delete the 
weights to make sure that the user does not - # run anything with them expecting them to be updated - params = list(params) - module_buffers = list(module_buffers) - + try: + buffer_names = next(itertools.islice(zip(*module.named_buffers()), 1)) + except StopIteration: + buffer_names = () + params = make_functional(module, funs_to_decorate=funs_to_decorate) + functional_module = deepcopy(module) + repopulate_module(module, params) + + params_and_buffers = params + # we transform the buffers in params to make sure they follow the device + # as tensor = nn.Parameter(tensor) keeps its identity when moved to another device + + def create_buffers(tensor): + + if isinstance(tensor, torch.Tensor) and not isinstance( + tensor, (Buffer, nn.Parameter) + ): + return Buffer(tensor, requires_grad=tensor.requires_grad) + return tensor + + # separate params and buffers + params_and_buffers = params_and_buffers.apply(create_buffers) + for key in params_and_buffers.keys(True): + if "_sep_" in key: + raise KeyError( + f"The key {key} contains the '_sep_' pattern which is prohibited. Consider renaming the parameter / buffer." + ) + params_and_buffers_flat = params_and_buffers.flatten_keys("_sep_") + buffers = params_and_buffers_flat.select(*buffer_names) + params = params_and_buffers_flat.exclude(*buffer_names) + + if expand_dim and not _has_functorch: + raise ImportError( + "expanding params is only possible when functorch is installed," + "as this feature requires calls to the vmap operator." + ) if expand_dim: + # Expands the dims of params and buffers. + # If the param already exists in the module, we return a simple expansion of the + # original one. Otherwise, we expand and resample it. + # For buffers, a cloned expansion (or equivalently a repeat) is returned.
if compare_against is not None: compare_against = set(compare_against) else: compare_against = set() - for i, p in enumerate(params): - if p in compare_against: - # expanded parameters are 'detached': the parameter will not - # be trained to minimize loss involving this network. - p_out = p.data.expand(expand_dim, *p.shape) + + def _compare_and_expand(param): + + if param in compare_against: + expanded_param = param.data.expand(expand_dim, *param.shape) # the expanded parameter must be sent to device when to() # is called: - self._param_maps[p_out] = p + return expanded_param else: - p_out = p.repeat(expand_dim, *[1 for _ in p.shape]) + p_out = param.repeat(expand_dim, *[1 for _ in param.shape]) p_out = nn.Parameter( p_out.uniform_( p_out.min().item(), p_out.max().item() ).requires_grad_() ) - params[i] = p_out + return p_out - for i, b in enumerate(module_buffers): - b = b.expand(expand_dim, *b.shape).clone() - module_buffers[i] = b - - # # delete weights of original model as they do not correspond to the optimized weights - # network_orig.to("meta") - - params_list = params - set_params = set(self.parameters()) - setattr( - self, - "_" + param_name, - nn.ParameterList( - [ - p - for p in params_list - if isinstance(p, nn.Parameter) and p not in set_params - ] - ), - ) - setattr(self, param_name, params) - - buffer_name = module_name + "_buffers" - # we register each buffer independently - for i, p in enumerate(module_buffers): - _name = module_name + f"_buffer_{i}" - self.register_buffer(_name, p) - # replace buffer by its name - module_buffers[i] = _name - setattr( - self.__class__, - buffer_name, - property(lambda _self: [getattr(_self, _name) for _name in module_buffers]), - ) - - # we set the functional module - setattr(self, module_name, functional_module) - - name_params_target = "_target_" + param_name - name_buffers_target = "_target_" + buffer_name - if create_target_params: - target_params = [p.detach().clone() for p in getattr(self, param_name)] - for 
i, p in enumerate(target_params): - name = "_".join([name_params_target, str(i)]) - self.register_buffer(name, p) - target_params[i] = name - setattr( - self.__class__, - name_params_target, - property( - lambda _self: [getattr(_self, _name) for _name in target_params] - ), + params_udpated = params.apply( + _compare_and_expand, batch_size=[expand_dim, *params.shape] ) - target_buffers = [p.detach().clone() for p in getattr(self, buffer_name)] - for i, p in enumerate(target_buffers): - name = "_".join([name_buffers_target, str(i)]) - self.register_buffer(name, p) - target_buffers[i] = name - setattr( - self.__class__, - name_buffers_target, - property( - lambda _self: [getattr(_self, _name) for _name in target_buffers] - ), + params = params_udpated + buffers = buffers.apply( + lambda buffer: Buffer(buffer.expand(expand_dim, *buffer.shape).clone()), + batch_size=[expand_dim, *buffers.shape], ) - else: - setattr(self.__class__, name_params_target, None) - setattr(self.__class__, name_buffers_target, None) + params_and_buffers.update(params.unflatten_keys("_sep_")) + params_and_buffers.update(buffers.unflatten_keys("_sep_")) + params_and_buffers.batch_size = params.batch_size - setattr( - self.__class__, - name_params_target[1:], - property(lambda _self: self._target_param_getter(module_name)), - ) - setattr( - self.__class__, - name_buffers_target[1:], - property(lambda _self: self._target_buffer_getter(module_name)), - ) - - def _convert_to_functional_native( - self, - module: SafeModule, - module_name: str, - expand_dim: Optional[int] = None, - create_target_params: bool = False, - compare_against: Optional[List[Parameter]] = None, - ) -> None: - # To make it robust to device casting, we must register list of - # tensors as lazy calls to `getattr(self, name_of_tensor)`. 
- # Otherwise, casting the module to a device will keep old references - # to uncast tensors - - network_orig = module - if hasattr(module, "make_functional_with_buffers"): - functional_module, ( - params, - module_buffers, - ) = module.make_functional_with_buffers(clone=True) - else: - ( - functional_module, - params, - module_buffers, - ) = FunctionalModuleWithBuffers._create_from(module) + # self.params_to_map = params_to_map param_name = module_name + "_params" - # params must be retrieved directly because make_functional will copy the content - params_vals = TensorDict( - {name: value for name, value in network_orig.named_parameters()}, [] - ) - # rename params_vals keys to match params: otherwise we'll have to deal with - # module.module.param or such names. We assume that there is a constant prefix - # and that, when sorted, all keys will match. We could check that the values - # do match too. - keys1 = sorted(params.flatten_keys(".").keys()) - keys2 = sorted(params_vals.keys()) - for key1, key2 in zip(keys1, keys2): - params_vals.rename_key(key2, key1) - params = params_vals.unflatten_keys(".") + prev_set_params = set(self.parameters()) - if expand_dim: - raise ImportError( - "expanding params is only possible when functorch is installed," - "as this feature requires calls to the vmap operator." 
- ) + # register parameters and buffers + for key, parameter in params.items(): + if parameter not in prev_set_params: + setattr(self, "_sep_".join([module_name, key]), parameter) + else: + for _param_name, p in self.named_parameters(): + if parameter is p: + break + else: + raise RuntimeError("parameter not found") + setattr(self, "_sep_".join([module_name, key]), _param_name) + prev_set_buffers = set(self.buffers()) + for key, buffer in buffers.items(): + if buffer not in prev_set_buffers: + self.register_buffer("_sep_".join([module_name, key]), buffer) + else: + for _buffer_name, b in self.named_buffers(): + if buffer is b: + break + else: + raise RuntimeError("buffer not found") + setattr(self, "_sep_".join([module_name, key]), _buffer_name) - params_list = list(params.flatten_keys(".").values()) - set_params = set(self.parameters()) - setattr( - self, - "_" + param_name, - nn.ParameterList( - [ - p - for p in params_list - if isinstance(p, nn.Parameter) and p not in set_params - ] - ), - ) - setattr(self, param_name, params) - - buffer_name = module_name + "_buffers" - buffers_iter = list(module_buffers.flatten_keys(".").items()) - module_buffers_list = [] - for i, (key, value) in enumerate(sorted(buffers_iter)): - _name = module_name + f"_buffer_{i}" - self.register_buffer(_name, value) - # replace buffer by its name - module_buffers_list.append((_name, key)) + setattr(self, "_" + param_name, params_and_buffers) setattr( self.__class__, - buffer_name, - property( - lambda _self: TensorDict( - { - key: getattr(_self, _name) - for (_name, key) in module_buffers_list - }, - [], - device=self.device, - ).unflatten_keys(".") - ), + param_name, + property(lambda _self=self: _self._param_getter(module_name)), ) - # we set the functional module + # set the functional module setattr(self, module_name, functional_module) - name_params_target = "_target_" + param_name - name_buffers_target = "_target_" + buffer_name + # creates a map nn.Parameter name -> expanded 
parameter name + for key, value in params.items(True, True): + if not isinstance(key, tuple): + key = (key,) + if not isinstance(value, nn.Parameter): + # find the param name + for name, param in self.named_parameters(): + if param.data.data_ptr() == value.data_ptr() and param is not value: + self._param_maps[name] = "_sep_".join([module_name, *key]) + break + else: + raise RuntimeError("did not find matching param.") + + name_params_target = "_target_" + module_name if create_target_params: - target_params = getattr(self, param_name).detach().clone() - target_params_items = sorted(target_params.flatten_keys(".").items()) + target_params = params_and_buffers.detach().clone() + target_params_items = target_params.items(True, True) target_params_list = [] - for i, (key, val) in enumerate(target_params_items): - name = "_".join([name_params_target, str(i)]) - self.register_buffer(name, val) + for (key, val) in target_params_items: + if not isinstance(key, tuple): + key = (key,) + name = "_sep_".join([name_params_target, *key]) + self.register_buffer(name, Buffer(val)) target_params_list.append((name, key)) - setattr( - self.__class__, - name_params_target, - property( - lambda _self: TensorDict( - { - key: getattr(_self, _name) - for (_name, key) in target_params_list - }, - [], - device=self.device, - ).unflatten_keys(".") - ), - ) - - target_buffers = getattr(self, buffer_name).detach().clone() - target_buffers_items = sorted(target_buffers.flatten_keys(".").items()) - target_buffers_list = [] - for i, (key, val) in enumerate(target_buffers_items): - name = "_".join([name_buffers_target, str(i)]) - self.register_buffer(name, val) - target_buffers_list.append((name, key)) - setattr( - self.__class__, - name_buffers_target, - property( - lambda _self: TensorDict( - { - key: getattr(_self, _name) - for (_name, key) in target_buffers_list - }, - [], - device=self.device, - ).unflatten_keys(".") - ), - ) - + setattr(self, name_params_target + "_params", target_params) 
else: - setattr(self.__class__, name_params_target, None) - setattr(self.__class__, name_buffers_target, None) - - setattr( - self.__class__, - name_params_target[1:], - property(lambda _self: self._target_param_getter(module_name)), - ) + setattr(self, name_params_target + "_params", None) setattr( self.__class__, - name_buffers_target[1:], - property(lambda _self: self._target_buffer_getter(module_name)), + name_params_target[1:] + "_params", + property(lambda _self=self: _self._target_param_getter(module_name)), ) - def _target_param_getter(self, network_name): - target_name = "_target_" + network_name + "_params" + def _param_getter(self, network_name): + name = "_" + network_name + "_params" param_name = network_name + "_params" - if hasattr(self, target_name): - target_params = getattr(self, target_name) - if target_params is not None: - if isinstance(target_params, TensorDictBase): - return target_params - return tuple(target_params) + if name in self.__dict__: + params = getattr(self, name) + if params is not None: + # get targets and update + for key in params.keys(True, True): + if not isinstance(key, tuple): + key = (key,) + value_to_set = getattr(self, "_sep_".join([network_name, *key])) + if isinstance(value_to_set, str): + value_to_set = getattr(self, value_to_set).detach() + params.set(key, value_to_set) + return params else: params = getattr(self, param_name) - if isinstance(params, TensorDictBase): - return params.detach() - else: - # detach params as a surrogate for targets - return tuple(p.detach() for p in params) + return params.detach() else: raise RuntimeError( - f"{self.__class__.__name__} does not have the target param {target_name}" + f"{self.__class__.__name__} does not have the target param {name}" ) - def _target_buffer_getter(self, network_name): - target_name = "_target_" + network_name + "_buffers" - buffer_name = network_name + "_buffers" - if hasattr(self, target_name): - target_buffers = getattr(self, target_name) - if 
target_buffers is not None: - if isinstance(target_buffers, TensorDictBase): - return target_buffers - return tuple(target_buffers) + def _target_param_getter(self, network_name): + target_name = "_target_" + network_name + "_params" + param_name = network_name + "_params" + if target_name in self.__dict__: + target_params = getattr(self, target_name) + if target_params is not None: + # get targets and update + for key in target_params.keys(True, True): + if not isinstance(key, tuple): + key = (key,) + value_to_set = getattr( + self, "_sep_".join(["_target_" + network_name, *key]) + ) + target_params.set(key, value_to_set) + return target_params else: - buffers = getattr(self, buffer_name) - if isinstance(buffers, TensorDictBase): - return buffers.detach() - else: - return tuple(p.detach() for p in buffers) + params = getattr(self, param_name) + return params.detach() else: raise RuntimeError( - f"{self.__class__.__name__} does not have the target buffer {target_name}" + f"{self.__class__.__name__} does not have the target param {target_name}" ) def _networks(self) -> Iterator[nn.Module]: @@ -455,7 +292,7 @@ def device(self) -> torch.device: def register_buffer( self, name: str, tensor: Optional[Tensor], persistent: bool = True ) -> None: - tensor = tensor.to(self.device) + # tensor = tensor.to(self.device) return super().register_buffer(name, tensor, persistent) def parameters(self, recurse: bool = True) -> Iterator[Parameter]: @@ -476,16 +313,23 @@ def reset(self) -> None: def to(self, *args, **kwargs): # get the names of the parameters to map out = super().to(*args, **kwargs) - lists_of_params = { - name: value - for name, value in self.__dict__.items() - if name.endswith("_params") and (type(value) is list) - } - for _, list_of_params in lists_of_params.items(): - for i, param in enumerate(list_of_params): - # we replace the param by the expanded form if needs be - if param in self._param_maps: - list_of_params[i] = self._param_maps[param].data.expand_as(param) 
+ for origin, target in self._param_maps.items(): + origin_value = getattr(self, origin) + target_value = getattr(self, target) + setattr(self, target, origin_value.expand_as(target_value)) + + # lists_of_params = { + # name: value + # for name, value in self.__dict__.items() + # if name.endswith("_params") and isinstance(value, TensorDictBase) + # } + # for list_of_params in lists_of_params.values(): + # for key, param in list(list_of_params.items(True)): + # if isinstance(param, TensorDictBase): + # continue + # # we replace the param by the expanded form if needs be + # if param in self._param_maps: + # list_of_params[key] = self._param_maps[param].data.expand_as(param) return out def cuda(self, device: Optional[Union[int, device]] = None) -> LossModule: diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 5692f109739..42999226605 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -5,9 +5,12 @@ from __future__ import annotations +from copy import deepcopy + from typing import Tuple import torch +from tensordict.nn import make_functional, repopulate_module from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.modules import SafeModule @@ -46,6 +49,13 @@ def __init__( super().__init__() self.delay_actor = delay_actor self.delay_value = delay_value + + actor_critic = ActorCriticWrapper(actor_network, value_network) + params = make_functional(actor_critic) + self.actor_critic = deepcopy(actor_critic) + repopulate_module(actor_network, params["module", "0"]) + repopulate_module(value_network, params["module", "1"]) + self.convert_to_functional( actor_network, "actor_network", @@ -57,6 +67,8 @@ def __init__( create_target_params=self.delay_value, compare_against=list(actor_network.parameters()), ) + self.actor_critic.module[0] = self.actor_network + self.actor_critic.module[1] = self.value_network self.actor_in_keys = actor_network.in_keys @@ -116,11 +128,11 @@ def _loss_actor( td_copy = 
self.actor_network( td_copy, params=self.actor_network_params, - buffers=self.actor_network_buffers, ) with hold_out_params(self.value_network_params) as params: td_copy = self.value_network( - td_copy, params=params, buffers=self.value_network_buffers + td_copy, + params=params, ) return -td_copy.get("state_action_value") @@ -133,16 +145,19 @@ def _loss_value( self.value_network( td_copy, params=self.value_network_params, - buffers=self.value_network_buffers, ) pred_val = td_copy.get("state_action_value").squeeze(-1) - actor_critic = ActorCriticWrapper(self.actor_network, self.value_network) - target_params = list(self.target_actor_network_params) + list( - self.target_value_network_params - ) - target_buffers = list(self.target_actor_network_buffers) + list( - self.target_value_network_buffers + actor_critic = self.actor_critic + target_params = TensorDict( + { + "module": { + "0": self.target_actor_network_params, + "1": self.target_value_network_params, + } + }, + batch_size=self.target_actor_network_params.batch_size, + device=self.target_actor_network_params.device, ) with set_exploration_mode("mode"): target_value = next_state_value( @@ -150,7 +165,6 @@ def _loss_value( actor_critic, gamma=self.gamma, params=target_params, - buffers=target_buffers, ) # td_error = pred_val - target_value diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 555525161a4..b80cf2854ff 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -4,10 +4,10 @@ import numpy as np import torch + from tensordict import TensorDict from tensordict.tensordict import TensorDictBase from torch import Tensor - from torchrl.envs.utils import set_exploration_mode, step_mdp from torchrl.modules import SafeModule from torchrl.objectives import ( @@ -17,6 +17,15 @@ ) from torchrl.objectives.common import LossModule +try: + from functorch import vmap + + FUNCTORCH_ERR = "" + _has_functorch = True +except ImportError as err: + 
FUNCTORCH_ERR = str(err) + _has_functorch = False + class REDQLoss_deprecated(LossModule): """REDQ Loss module. @@ -66,6 +75,10 @@ def __init__( delay_qvalue: bool = True, gSDE: bool = False, ): + if not _has_functorch: + raise ImportError( + f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" + ) super().__init__() self.convert_to_functional( actor_network, @@ -163,18 +176,12 @@ def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: self.actor_network( tensordict_clone, params=self.actor_network_params, - buffers=self.actor_network_buffers, ) with hold_out_params(self.qvalue_network_params) as params: - tensordict_expand = self.qvalue_network( + tensordict_expand = vmap(self.qvalue_network, (None, 0))( tensordict_clone.select(*self.qvalue_network.in_keys), - tensordict_out=TensorDict( - {}, [self.num_qvalue_nets, *tensordict_clone.shape] - ), - params=params, - buffers=self.qvalue_network_buffers, - vmap=True, + params, ) state_action_value = tensordict_expand.get("state_action_value").squeeze(-1) loss_actor = -( @@ -193,12 +200,7 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: : self.sub_sample_len ].sort()[0] with torch.no_grad(): - selected_q_params = [ - p[selected_models_idx] for p in self.target_qvalue_network_params - ] - selected_q_buffers = [ - b[selected_models_idx] for b in self.target_qvalue_network_buffers - ] + selected_q_params = self.target_qvalue_network_params[selected_models_idx] next_td = step_mdp(tensordict).select( *self.actor_network.in_keys @@ -208,17 +210,13 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: with set_exploration_mode("random"): self.actor_network( next_td, - params=list(self.target_actor_network_params), - buffers=self.target_actor_network_buffers, + params=self.target_actor_network_params, ) sample_log_prob = next_td.get("sample_log_prob") # get q-values - next_td = self.qvalue_network( + next_td = vmap(self.qvalue_network, (None, 0))( next_td, - 
tensordict_out=TensorDict({}, [self.sub_sample_len, *next_td.shape]), - params=selected_q_params, - buffers=selected_q_buffers, - vmap=True, + selected_q_params, ) state_value = ( next_td.get("state_action_value") - self.alpha * sample_log_prob @@ -231,12 +229,9 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: gamma=self.gamma, pred_next_val=state_value, ) - tensordict_expand = self.qvalue_network( + tensordict_expand = vmap(self.qvalue_network, (None, 0))( tensordict.select(*self.qvalue_network.in_keys), - tensordict_out=TensorDict({}, [self.num_qvalue_nets, *tensordict.shape]), - params=list(self.qvalue_network_params), - buffers=self.qvalue_network_buffers, - vmap=True, + self.qvalue_network_params, ) pred_val = tensordict_expand.get("state_action_value").squeeze(-1) td_error = abs(pred_val - target_value) diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 99b82f1404a..444616ac364 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -91,10 +91,10 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: td_copy = tensordict.clone() if td_copy.device != tensordict.device: raise RuntimeError(f"{tensordict} and {td_copy} have different devices") + assert hasattr(self.value_network, "_is_stateless") self.value_network( td_copy, params=self.value_network_params, - buffers=self.value_network_buffers, ) action = tensordict.get("action") @@ -112,7 +112,6 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: self.value_network, gamma=self.gamma, params=self.target_value_network_params, - buffers=self.target_value_network_buffers, next_val_key="chosen_action_value", ) priority_tensor = (pred_val_index - target_value).pow(2) @@ -201,7 +200,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: "tensordict as input" ) batch_size = tensordict.batch_size[0] - support = self.value_network.support + support = self.value_network_params["support"] atoms = support.numel() Vmin = 
support.min().item() Vmax = support.max().item() @@ -220,7 +219,6 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: self.value_network( td_clone, params=self.value_network_params, - buffers=self.value_network_buffers, ) # Log probabilities log p(s_t, ·; θonline) action_log_softmax = td_clone.get("action_value") @@ -237,7 +235,6 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: self.value_network( next_td, params=self.value_network_params, - buffers=self.value_network_buffers, ) # Probabilities p(s_t+n, ·; θonline) next_td_action = next_td.get("action") @@ -249,7 +246,6 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: self.value_network( next_td, params=self.target_value_network_params, - buffers=self.target_value_network_buffers, ) # Probabilities p(s_t+n, ·; θtarget) pns = next_td.get("action_value").exp() # Double-Q probabilities diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 2926e2c667a..130df12fd3d 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -65,7 +65,9 @@ def __init__( advantage_module: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, ): super().__init__() - self.convert_to_functional(actor, "actor") + self.convert_to_functional( + actor, "actor", funs_to_decorate=["forward", "get_dist"] + ) # we want to make sure there are no duplicates in the params: the # params of critic must be refs to actor if they're shared self.convert_to_functional(critic, "critic", compare_against=self.actor_params) @@ -106,7 +108,8 @@ def _log_weight( tensordict_clone = tensordict.select(*self.actor.in_keys).clone() dist, *_ = self.actor.get_dist( - tensordict_clone, params=self.actor_params, buffers=self.actor_buffers + tensordict_clone, + params=self.actor_params, ) log_prob = dist.log_prob(action) log_prob = log_prob.unsqueeze(-1) @@ -136,7 +139,6 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: value = self.critic( 
tensordict_select, params=self.critic_params, - buffers=self.critic_buffers, ).get("state_value") value_target = advantage + value.detach() loss_value = distance_loss( @@ -363,7 +365,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: previous_dist = self.actor.build_dist_from_params(tensordict_clone) current_dist, *_ = self.actor.get_dist( - tensordict_clone, params=self.actor_params, buffers=self.actor_buffers + tensordict_clone, + params=self.actor_params, ) try: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 70b3d6a3e8d..d8b28bc677b 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -9,18 +9,28 @@ import numpy as np import torch + +from tensordict.nn import TensorDictSequential from tensordict.tensordict import TensorDict, TensorDictBase from torch import Tensor from torchrl.envs.utils import set_exploration_mode, step_mdp from torchrl.modules import SafeModule -from torchrl.objectives.common import _has_functorch, LossModule +from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( distance_loss, - hold_out_params, next_state_value as get_next_state_value, ) +try: + from functorch import vmap + + FUNCTORCH_ERR = "" + _has_functorch = True +except ImportError as err: + FUNCTORCH_ERR = str(err) + _has_functorch = False + class REDQLoss(LossModule): """REDQ Loss module. 
@@ -75,13 +85,16 @@ def __init__( gSDE: bool = False, ): if not _has_functorch: - raise ImportError("REDQ requires functorch to be installed.") + raise ImportError( + f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" + ) super().__init__() self.convert_to_functional( actor_network, "actor_network", create_target_params=self.delay_actor, + funs_to_decorate=["forward", "get_dist_params"], ) # let's make sure that actor_network has `return_log_prob` set to True @@ -152,25 +165,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: selected_models_idx = torch.randperm(self.num_qvalue_nets)[ : self.sub_sample_len ].sort()[0] - selected_q_params = [ - p[selected_models_idx] for p in self.target_qvalue_network_params - ] - selected_q_buffers = [ - b[selected_models_idx] for b in self.target_qvalue_network_buffers - ] - - actor_params = [ - torch.stack([p1, p2], 0) - for p1, p2 in zip( - self.actor_network_params, self.target_actor_network_params - ) - ] - actor_buffers = [ - torch.stack([p1, p2], 0) - for p1, p2 in zip( - self.actor_network_buffers, self.target_actor_network_buffers - ) - ] + selected_q_params = self.target_qvalue_network_params[selected_models_idx] + + actor_params = torch.stack( + [self.actor_network_params, self.target_actor_network_params], 0 + ) tensordict_actor_grad = tensordict_select.select( *obs_keys @@ -187,60 +186,58 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "_eps_gSDE", torch.zeros(tensordict_actor.shape, device=tensordict_actor.device), ) - tensordict_actor = self.actor_network( + # vmap doesn't support sampling, so we take it out from the vmap + td_params = vmap(self.actor_network.get_dist_params)( tensordict_actor, - params=actor_params, - buffers=actor_buffers, - vmap=(0, 0, 0), + actor_params, + ) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = self.actor_network[-1].sample_out_key[0] + tensordict_actor_dist = 
self.actor_network[-1].build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.sample_out_key[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict_actor[sample_key] = tensordict_actor_dist.rsample() + tensordict_actor["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict_actor[sample_key] ) # repeat tensordict_actor to match the qvalue size + _actor_loss_td = ( + tensordict_actor[0] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) + ) # for actor loss + _qval_td = tensordict_select.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets, + *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, + ) # for qvalue loss + _next_val_td = ( + tensordict_actor[1] + .select(*self.qvalue_network.in_keys) + .expand(self.sub_sample_len, *tensordict_actor[1].batch_size) + ) # for next value estimation tensordict_qval = torch.cat( [ - tensordict_actor[0] - .select(*self.qvalue_network.in_keys) - .expand( - self.num_qvalue_nets, *tensordict_actor[0].batch_size - ), # for actor loss - tensordict_actor[1] - .select(*self.qvalue_network.in_keys) - .expand( - self.sub_sample_len, *tensordict_actor[1].batch_size - ), # for next value estimation - tensordict_select.select(*self.qvalue_network.in_keys).expand( - self.num_qvalue_nets, - *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, - ), # for qvalue loss + _actor_loss_td, + _next_val_td, + _qval_td, ], 0, ) # cat params - q_params_detach = hold_out_params(self.qvalue_network_params).params - qvalue_params = [ - torch.cat([p1, p2, p3], 0) - for p1, p2, p3 in zip( - q_params_detach, selected_q_params, self.qvalue_network_params - ) - ] - qvalue_buffers = [ - torch.cat([p1, p2, p3], 0) - for p1, p2, p3 in zip( - self.qvalue_network_buffers, - selected_q_buffers, - self.qvalue_network_buffers, - ) - ] - tensordict_qval = self.qvalue_network( + 
q_params_detach = self.qvalue_network_params.detach() + qvalue_params = torch.cat( + [q_params_detach, selected_q_params, self.qvalue_network_params], 0 + ) + tensordict_qval = vmap(self.qvalue_network)( tensordict_qval, - tensordict_out=TensorDict({}, tensordict_qval.shape), - params=qvalue_params, - buffers=qvalue_buffers, - vmap=( - 0, - 0, - 0, - ), # TensorDict vmap will take care of expanding the tuple as needed + qvalue_params, ) state_action_value = tensordict_qval.get("state_action_value").squeeze(-1) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 294f79c50ec..719710de5b6 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -67,9 +67,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = self.advantage_module( tensordict, params=self.critic_params, - buffers=self.critic_buffers, target_params=self.target_critic_params, - target_buffers=self.target_critic_buffers, ) advantage = tensordict.get(self.advantage_key) @@ -77,7 +75,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = self.actor_network( tensordict, params=self.actor_network_params, - buffers=self.actor_network_buffers, ) log_prob = tensordict.get("sample_log_prob") @@ -108,14 +105,12 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: next_value = self.critic( next_td, params=self.critic_params, - buffers=self.critic_buffers, ).get("state_value") value_target = reward + next_value * self.gamma tensordict_select = tensordict.select(*self.critic.in_keys).clone() value = self.critic( tensordict_select, params=self.critic_params, - buffers=self.critic_buffers, ).get("state_value") loss_value = distance_loss( diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index bfc5e088813..79ed9cee11c 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -9,6 +9,7 @@ import numpy as np import torch +from tensordict.nn import 
make_functional from tensordict.tensordict import TensorDict, TensorDictBase from torch import Tensor @@ -19,6 +20,15 @@ from ..envs.utils import set_exploration_mode from .common import LossModule +try: + from functorch import vmap + + _has_functorch = True + err = "" +except ImportError as err: + _has_functorch = False + FUNCTORCH_ERROR = str(err) + class SACLoss(LossModule): """TorchRL implementation of the SAC loss. @@ -83,6 +93,10 @@ def __init__( delay_qvalue: bool = False, delay_value: bool = False, ) -> None: + if not _has_functorch: + raise ImportError( + f"Failed to import functorch with error message:\n{FUNCTORCH_ERROR}" + ) super().__init__() # Actor @@ -91,6 +105,10 @@ def __init__( actor_network, "actor_network", create_target_params=self.delay_actor, + funs_to_decorate=[ + "forward", + "get_dist", + ], ) # Value @@ -150,14 +168,12 @@ def __init__( self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) + self.actor_critic = ActorCriticWrapper(self.actor_network, self.value_network) + make_functional(self.actor_critic) @property def device(self) -> torch.device: - for p in self.actor_network_params: - return p.device - for p in self.qvalue_network_params: - return p.device - for p in self.value_network_params: + for p in self.parameters(): return p.device raise RuntimeError( "At least one of the networks of SACLoss must have trainable " "parameters." 
@@ -198,8 +214,7 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: with set_exploration_mode("random"): dist = self.actor_network.get_dist( tensordict, - params=list(self.actor_network_params), - buffers=list(self.actor_network_buffers), + params=self.actor_network_params, )[0] a_reparm = dist.rsample() # if not self.actor_network.spec.is_in(a_reparm): @@ -208,11 +223,8 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: td_q = tensordict.select(*self.qvalue_network.in_keys) td_q.set("action", a_reparm) - td_q = self.qvalue_network( - td_q, - params=list(self.target_qvalue_network_params), - buffers=list(self.qvalue_network_buffers), - vmap=True, + td_q = vmap(self.qvalue_network, (None, 0))( + td_q, self.target_qvalue_network_params ) min_q_logprob = td_q.get("state_action_value").min(0)[0].squeeze(-1) @@ -226,12 +238,16 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: return self._alpha * log_prob - min_q_logprob def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: - actor_critic = ActorCriticWrapper(self.actor_network, self.value_network) - params = list(self.target_actor_network_params) + list( - self.target_value_network_params - ) - buffers = list(self.target_actor_network_buffers) + list( - self.target_value_network_buffers + actor_critic = self.actor_critic + params = TensorDict( + { + "module": { + "0": self.target_actor_network_params, + "1": self.target_value_network_params, + } + }, + [], + _run_checks=False, ) with set_exploration_mode("mode"): target_value = next_state_value( @@ -240,7 +256,6 @@ def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: gamma=self.gamma, next_val_key="state_value", params=params, - buffers=buffers, ) # value loss @@ -260,16 +275,8 @@ def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: target_chunks = torch.stack(target_value.chunk(self.num_qvalue_nets, dim=0), 0) # if vmap=True, it is assumed that the input 
tensordict must be cast to the param shape - tensordict_chunks = qvalue_network( - tensordict_chunks, - params=list(self.qvalue_network_params), - buffers=list(self.qvalue_network_buffers), - vmap=( - 0, - 0, - 0, - 0, - ), + tensordict_chunks = vmap(qvalue_network)( + tensordict_chunks, self.qvalue_network_params ) pred_val = tensordict_chunks.get("state_action_value").squeeze(-1) loss_value = distance_loss( @@ -284,18 +291,14 @@ def _loss_value(self, tensordict: TensorDictBase) -> Tensor: td_copy = tensordict.select(*self.value_network.in_keys).detach() self.value_network( td_copy, - params=list(self.value_network_params), - buffers=list(self.value_network_buffers), + params=self.value_network_params, ) pred_val = td_copy.get("state_value").squeeze(-1) - action_dist = self.actor_network.get_dist( + action_dist, *_ = self.actor_network.get_dist( td_copy, - params=list(self.target_actor_network_params), - buffers=list(self.target_actor_network_buffers), - )[ - 0 - ] # resample an action + params=self.target_actor_network_params, + ) # resample an action action = action_dist.rsample() # if not self.actor_network.spec.is_in(action): # action.data.copy_(self.actor_network.spec.project(action.data)) @@ -303,11 +306,9 @@ def _loss_value(self, tensordict: TensorDictBase) -> Tensor: td_copy.set("action", action, inplace=False) qval_net = self.qvalue_network - td_copy = qval_net( + td_copy = vmap(qval_net, (None, 0))( td_copy, - params=list(self.target_qvalue_network_params), - buffers=list(self.target_qvalue_network_buffers), - vmap=True, + self.target_qvalue_network_params, ) min_qval = td_copy.get("state_action_value").squeeze(-1).min(0)[0] diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 4f2da57c93a..ec93e4135c1 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -4,11 +4,10 @@ # LICENSE file in the root directory of this source tree. 
import functools -from collections import OrderedDict from typing import Iterable, Optional, Union import torch -from tensordict.tensordict import TensorDictBase +from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn, Tensor from torch.nn import functional as F @@ -97,7 +96,7 @@ def __init__( # for properties for name in loss_module.__class__.__dict__: if ( - name.startswith("_target_") + name.startswith("target_") and (name.endswith("params") or name.endswith("buffers")) and (getattr(loss_module, name) is not None) ): @@ -106,12 +105,12 @@ def __init__( # for regular lists: raise an exception for name in loss_module.__dict__: if ( - name.startswith("_target_") + name.startswith("target_") and (name.endswith("params") or name.endswith("buffers")) and (getattr(loss_module, name) is not None) ): raise RuntimeError( - "Your module seems to have a _target tensor list contained " + "Your module seems to have a target tensor list contained " "in a non-dynamic structure (such as a list). If the " "module is cast onto a device, the reference to these " "tensors will be lost." @@ -119,10 +118,10 @@ def __init__( if len(_target_names) == 0: raise RuntimeError( - "Did not found any target parameters or buffers in the loss module." + "Did not find any target parameters or buffers in the loss module." 
) - _source_names = ["".join(name.split("_target_")) for name in _target_names] + _source_names = ["".join(name.split("target_")) for name in _target_names] for _source in _source_names: try: @@ -140,28 +139,28 @@ def __init__( @property def _targets(self): - return OrderedDict( - {name: getattr(self.loss_module, name) for name in self._target_names} + return TensorDict( + {name: getattr(self.loss_module, name) for name in self._target_names}, + [], ) @property def _sources(self): - return OrderedDict( - {name: getattr(self.loss_module, name) for name in self._source_names} + return TensorDict( + {name: getattr(self.loss_module, name) for name in self._source_names}, + [], ) def init_(self) -> None: - for source, target in zip(self._sources.values(), self._targets.values()): - if isinstance(source, TensorDictBase) and not source.is_empty(): - # native functional modules - source = list(zip(*sorted(source.items())))[1] - target = list(zip(*sorted(target.items())))[1] - elif isinstance(source, TensorDictBase) and source.is_empty(): - continue - for p_source, p_target in zip(source, target): - if p_target.requires_grad: - raise RuntimeError("the target parameter is part of a graph.") - p_target.data.copy_(p_source.data) + for key, source in self._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target = self._targets[key] + # for p_source, p_target in zip(source, target): + if target.requires_grad: + raise RuntimeError("the target parameter is part of a graph.") + target.data.copy_(source.data) self.initialized = True def step(self) -> None: @@ -170,29 +169,25 @@ def step(self) -> None: f"{self.__class__.__name__} must be " f"initialized (`{self.__class__.__name__}.init_()`) before calling step()" ) - - for source, target in zip(self._sources.values(), self._targets.values()): - if isinstance(source, TensorDictBase) and not source.is_empty(): - # native functional modules - source = 
list(zip(*sorted(source.items())))[1] - target = list(zip(*sorted(target.items())))[1] - elif isinstance(source, TensorDictBase) and source.is_empty(): - continue - for p_source, p_target in zip(source, target): - if p_target.requires_grad: - raise RuntimeError("the target parameter is part of a graph.") - if p_source.is_leaf: - self._step(p_source, p_target) - else: - p_target.copy_(p_source) + for key, source in self._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target = self._targets[key] + if target.requires_grad: + raise RuntimeError("the target parameter is part of a graph.") + if target.is_leaf: + self._step(source, target) + else: + target.copy_(source) def _step(self, p_source: Tensor, p_target: Tensor) -> None: raise NotImplementedError def __repr__(self) -> str: string = ( - f"{self.__class__.__name__}(sources={list(self._sources)}, targets=" - f"{list(self._targets)})" + f"{self.__class__.__name__}(sources={self._sources}, targets=" + f"{self._targets})" ) return string @@ -281,7 +276,10 @@ class hold_out_params(_context_manager): """Context manager to hold a list of parameters out of a computational graph.""" def __init__(self, params: Iterable[Tensor]) -> None: - self.params = tuple(p.detach() for p in params) + if isinstance(params, TensorDictBase): + self.params = params.detach() + else: + self.params = tuple(p.detach() for p in params) def __enter__(self) -> None: return self.params diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 6ee6ef3503b..4f558feecef 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -46,20 +46,24 @@ def __init__( super().__init__() self.register_buffer("gamma", torch.tensor(gamma)) self.value_network = value_network - self.is_functional = value_network.is_functional self.average_rewards = average_rewards self.gradient_mode = gradient_mode self.value_key = 
value_key + @property + def is_functional(self): + return ( + "_is_stateless" in self.value_network.__dict__ + and self.value_network.__dict__["_is_stateless"] + ) + def forward( self, tensordict: TensorDictBase, *unused_args, params: Optional[List[Tensor]] = None, - buffers: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, - target_buffers: Optional[List[Tensor]] = None, ) -> TensorDictBase: """Computes the GAE given the data in tensordict. @@ -93,8 +97,6 @@ def forward( ) if params is not None: kwargs["params"] = params - if buffers is not None: - kwargs["buffers"] = buffers self.value_network(tensordict, **kwargs) value = tensordict.get(self.value_key) @@ -107,10 +109,6 @@ def forward( kwargs["params"] = target_params elif "params" in kwargs: kwargs["params"] = [param.detach() for param in kwargs["params"]] - if target_buffers is not None: - kwargs["buffers"] = target_buffers - elif "buffers" in kwargs: - kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]] self.value_network(step_td, **kwargs) next_value = step_td.get(self.value_key) @@ -154,21 +152,25 @@ def __init__( self.register_buffer("gamma", torch.tensor(gamma)) self.register_buffer("lmbda", torch.tensor(lmbda)) self.value_network = value_network - self.is_functional = value_network.is_functional self.vectorized = vectorized self.average_rewards = average_rewards self.gradient_mode = gradient_mode self.value_key = value_key + @property + def is_functional(self): + return ( + "_is_stateless" in self.value_network.__dict__ + and self.value_network.__dict__["_is_stateless"] + ) + def forward( self, tensordict: TensorDictBase, *unused_args, params: Optional[List[Tensor]] = None, - buffers: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, - target_buffers: Optional[List[Tensor]] = None, ) -> TensorDictBase: """Computes the GAE given the data in tensordict. 
@@ -204,8 +206,6 @@ def forward( ) if params is not None: kwargs["params"] = params - if buffers is not None: - kwargs["buffers"] = buffers self.value_network(tensordict, **kwargs) value = tensordict.get(self.value_key) @@ -218,10 +218,6 @@ def forward( kwargs["params"] = target_params elif "params" in kwargs: kwargs["params"] = [param.detach() for param in kwargs["params"]] - if target_buffers is not None: - kwargs["buffers"] = target_buffers - elif "buffers" in kwargs: - kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]] self.value_network(step_td, **kwargs) next_value = step_td.get(self.value_key) @@ -270,19 +266,23 @@ def __init__( self.register_buffer("gamma", torch.tensor(gamma)) self.register_buffer("lmbda", torch.tensor(lmbda)) self.value_network = value_network - self.is_functional = value_network.is_functional self.average_rewards = average_rewards self.gradient_mode = gradient_mode + @property + def is_functional(self): + return ( + "_is_stateless" in self.value_network.__dict__ + and self.value_network.__dict__["_is_stateless"] + ) + def forward( self, tensordict: TensorDictBase, *unused_args, params: Optional[List[Tensor]] = None, - buffers: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, - target_buffers: Optional[List[Tensor]] = None, ) -> TensorDictBase: """Computes the GAE given the data in tensordict. 
@@ -316,8 +316,6 @@ def forward( ) if params is not None: kwargs["params"] = params - if buffers is not None: - kwargs["buffers"] = buffers self.value_network(tensordict, **kwargs) value = tensordict.get("state_value") @@ -330,10 +328,6 @@ def forward( kwargs["params"] = target_params elif "params" in kwargs: kwargs["params"] = [param.detach() for param in kwargs["params"]] - if target_buffers is not None: - kwargs["buffers"] = target_buffers - elif "buffers" in kwargs: - kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]] self.value_network(step_td, **kwargs) next_value = step_td.get("state_value") done = tensordict.get("done") From 5cd52c2cfa8d85a684c86ca33040cc3b7d314d57 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 30 Nov 2022 16:01:00 -0600 Subject: [PATCH 12/14] [Minor] ubuntu-20.04 for documentation build --- .github/workflows/docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 6bf8219de5c..3ccd4dc4d18 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -7,7 +7,7 @@ on: workflow_dispatch: jobs: build_docs_job: - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 strategy: matrix: include: From f5579033f68f2e7ccb288fbd9b3c3c61458d0e3d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Nov 2022 18:47:01 -0600 Subject: [PATCH 13/14] [BugFix] Fix TorchRL demo tutorial (#721) --- test/test_collector.py | 5 +- torchrl/collectors/collectors.py | 40 ++++++------ tutorials/sphinx-tutorials/coding_ddpg.py | 8 ++- tutorials/sphinx-tutorials/coding_dqn.py | 63 ++++++++---------- tutorials/sphinx-tutorials/multi_task.py | 9 ++- .../sphinx-tutorials/tensordict_module.py | 64 ++++++++++--------- .../sphinx-tutorials/tensordict_tutorial.py | 7 ++ tutorials/sphinx-tutorials/torch_envs.py | 5 ++ tutorials/sphinx-tutorials/torchrl_demo.py | 62 +++++++++++------- 9 files changed, 148 insertions(+), 115 deletions(-) diff --git 
a/test/test_collector.py b/test/test_collector.py index 4b8b70d8444..769ef221ae6 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -17,6 +17,7 @@ DiscreteActionVecPolicy, MockSerialEnv, ) +from tensordict.nn import TensorDictModule from tensordict.tensordict import assert_allclose_td, TensorDict from torch import nn from torchrl._utils import seed_generator @@ -980,12 +981,12 @@ def test_auto_wrap_modules(self, collector_class, multiple_outputs, env_maker): if collector_class is not SyncDataCollector: assert all( - isinstance(p, SafeModule) for p in collector._policy_dict.values() + isinstance(p, TensorDictModule) for p in collector._policy_dict.values() ) assert all(p.out_keys == out_keys for p in collector._policy_dict.values()) assert all(p.module is policy for p in collector._policy_dict.values()) else: - assert isinstance(collector.policy, SafeModule) + assert isinstance(collector.policy, TensorDictModule) assert collector.policy.out_keys == out_keys assert collector.policy.module is policy diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 93a117821de..8170e7723a4 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -17,6 +17,7 @@ import numpy as np import torch import torch.nn as nn +from tensordict.nn import TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from torch import multiprocessing as mp from torch.utils.data import IterableDataset @@ -29,7 +30,6 @@ from ..data.utils import CloudpickleWrapper, DEVICE_TYPING from ..envs.common import EnvBase from ..envs.vec_env import _BatchedEnv -from ..modules.tensordict_module import SafeModule, SafeProbabilisticModule from .utils import split_trajectories _TIMEOUT = 1.0 @@ -84,21 +84,21 @@ def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: def _policy_is_tensordict_compatible(policy: nn.Module): sig = inspect.signature(policy.forward) - if isinstance(policy, SafeModule) or ( + if 
isinstance(policy, TensorDictModule) or ( len(sig.parameters) == 1 and hasattr(policy, "in_keys") and hasattr(policy, "out_keys") ): - # if the policy is a SafeModule or takes a single argument and defines + # if the policy is a TensorDictModule or takes a single argument and defines # in_keys and out_keys then we assume it can already deal with TensorDict input # to forward and we return True return True elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"): - # if it's not a SafeModule, and in_keys and out_keys are not defined then + # if it's not a TensorDictModule, and in_keys and out_keys are not defined then # we assume no TensorDict compatibility and will try to wrap it. return False - # if in_keys or out_keys were defined but policy is not a SafeModule or + # if in_keys or out_keys were defined but policy is not a TensorDictModule or # accepts multiple arguments then it's likely the user is trying to do something # that will have undetermined behaviour, we raise an error raise TypeError( @@ -107,7 +107,7 @@ def _policy_is_tensordict_compatible(policy: nn.Module): "should take a single argument of type TensorDict to policy.forward and define " "both in_keys and out_keys. Alternatively, policy.forward can accept " "arbitrarily many tensor inputs and leave in_keys and out_keys undefined and " - "TorchRL will attempt to automatically wrap the policy with a SafeModule." + "TorchRL will attempt to automatically wrap the policy with a TensorDictModule." 
) @@ -116,13 +116,13 @@ def _get_policy_and_device( self, policy: Optional[ Union[ - SafeProbabilisticModule, + TensorDictModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, device: Optional[DEVICE_TYPING] = None, observation_spec: TensorSpec = None, - ) -> Tuple[SafeProbabilisticModule, torch.device, Union[None, Callable[[], dict]]]: + ) -> Tuple[TensorDictModule, torch.device, Union[None, Callable[[], dict]]]: """Util method to get a policy and its device given the collector __init__ inputs. From a policy and a device, assigns the self.device attribute to @@ -133,7 +133,7 @@ def _get_policy_and_device( create_env_fn (Callable or list of callables): an env creator function (or a list of creators) create_env_kwargs (dictionary): kwargs for the env creator - policy (SafeProbabilisticModule, optional): a policy to be used + policy (TensorDictModule, optional): a policy to be used device (int, str or torch.device, optional): device where to place the policy observation_spec (TensorSpec, optional): spec of the observations @@ -161,13 +161,13 @@ def _get_policy_and_device( # callables should be supported as policies. if not _policy_is_tensordict_compatible(policy): # policy is a nn.Module that doesn't operate on tensordicts directly - # so we attempt to auto-wrap policy with SafeModule + # so we attempt to auto-wrap policy with TensorDictModule if observation_spec is None: raise ValueError( "Unable to read observation_spec from the environment. This is " "required to check compatibility of the environment and policy " "since the policy is a nn.Module that operates on tensors " - "rather than a SafeModule or a nn.Module that accepts a " + "rather than a TensorDictModule or a nn.Module that accepts a " "TensorDict as input and defines in_keys and out_keys." 
) sig = inspect.signature(policy.forward) @@ -181,18 +181,18 @@ def _get_policy_and_device( if isinstance(output, tuple): out_keys.extend(f"output{i+1}" for i in range(len(output) - 1)) - policy = SafeModule( + policy = TensorDictModule( policy, in_keys=list(sig.parameters), out_keys=out_keys ) else: raise TypeError( "Arguments to policy.forward are incompatible with entries in " "env.observation_spec. If you want TorchRL to automatically " - "wrap your policy with a SafeModule then the arguments " + "wrap your policy with a TensorDictModule then the arguments " "to policy.forward must correspond one-to-one with entries in " "env.observation_spec that are prefixed with 'next_'. For more " "complex behaviour and more control you can consider writing " - "your own SafeModule." + "your own TensorDictModule." ) try: @@ -305,7 +305,7 @@ def __init__( ], # noqa: F821 policy: Optional[ Union[ - SafeProbabilisticModule, + TensorDictModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, @@ -517,7 +517,7 @@ def iterator(self) -> Iterator[TensorDictBase]: def _cast_to_policy(self, td: TensorDictBase) -> TensorDictBase: policy_device = self.device if hasattr(self.policy, "in_keys"): - # some keys may be absent -- SafeModule is resilient to missing keys + # some keys may be absent -- TensorDictModule is resilient to missing keys td = td.select(*self.policy.in_keys, strict=False) if self._td_policy is None: self._td_policy = td.to(policy_device) @@ -717,7 +717,7 @@ class _MultiDataCollector(_DataCollector): Args: create_env_fn (list of Callabled): list of Callables, each returning an instance of EnvBase - policy (Callable, optional): Instance of SafeProbabilisticModule class. + policy (Callable, optional): Instance of TensorDictModule class. Must accept TensorDictBase object as input. total_frames (int): lower bound of the total number of frames returned by the collector. 
In parallel settings, the actual number of frames may well be greater than this as the closing signals are sent to the @@ -776,7 +776,7 @@ def __init__( create_env_fn: Sequence[Callable[[], EnvBase]], policy: Optional[ Union[ - SafeProbabilisticModule, + TensorDictModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, @@ -1303,7 +1303,7 @@ class aSyncDataCollector(MultiaSyncDataCollector): Args: create_env_fn (Callabled): Callable returning an instance of EnvBase - policy (Callable, optional): Instance of SafeProbabilisticModule class. + policy (Callable, optional): Instance of TensorDictModule class. Must accept TensorDictBase object as input. total_frames (int): lower bound of the total number of frames returned by the collector. In parallel settings, the actual number of @@ -1358,7 +1358,7 @@ def __init__( create_env_fn: Callable[[], EnvBase], policy: Optional[ Union[ - SafeProbabilisticModule, + TensorDictModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index e12af352f43..67cdaab8e4b 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -32,6 +32,12 @@ # Make all the necessary imports for training +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore + from copy import deepcopy from typing import Optional @@ -40,6 +46,7 @@ import torch.cuda import tqdm from matplotlib import pyplot as plt +from tensordict.nn import TensorDictModule from torch import nn, optim from torchrl.collectors import MultiaSyncDataCollector from torchrl.data import ( @@ -64,7 +71,6 @@ MLP, OrnsteinUhlenbeckProcessWrapper, ProbabilisticActor, - TensorDictModule, ValueOperator, ) from torchrl.modules.distributions.continuous import TanhDelta diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 
ef21e18f4eb..78003cbd9b2 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -42,11 +42,19 @@ # to provide a high-level illustration of TorchRL features in the context # of this algorithm. +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore + import torch import tqdm +from functorch import vmap from IPython import display from matplotlib import pyplot as plt from tensordict import TensorDict +from tensordict.nn import get_functional from torch import nn from torchrl.collectors import MultiaSyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer @@ -251,18 +259,16 @@ def make_model(): print("Q-value network results:", tensordict) # make functional - factor, (_, buffers) = actor.make_functional_with_buffers(clone=True, native=True) - # making functional creates a copy of the params, which we don't want (i.e. we want the parameters from `actor` to match those in the params object), - # hence we create the params object in a second step - params = TensorDict({k: v for k, v in net.named_parameters()}, []).unflatten_keys( - "." - ) + # here's an explicit way of creating the parameters and buffer tensordict. 
+ # Alternatively, we could have used `params = make_functional(actor)` from + # tensordict.nn + params = TensorDict({k: v for k, v in actor.named_parameters()}, []) + buffers = TensorDict({k: v for k, v in actor.named_buffers()}, []) + params = params.update(buffers).unflatten_keys(".") # creates a nested TensorDict + factor = get_functional(actor) # creating the target parameters is fairly easy with tensordict: - params_target, buffers_target = ( - params.to_tensordict().detach(), - buffers.to_tensordict().detach(), - ) + (params_target,) = (params.to_tensordict().detach(),) # we wrap our actor in an EGreedyWrapper for data collection actor_explore = EGreedyWrapper( @@ -272,7 +278,7 @@ def make_model(): eps_end=eps_greedy_val_env, ) - return factor, actor, actor_explore, params, buffers, params_target, buffers_target + return factor, actor, actor_explore, params, params_target ############################################################################### @@ -286,14 +292,10 @@ def make_model(): actor, actor_explore, params, - buffers, params_target, - buffers_target, ) = make_model() params_flat = params.flatten_keys(".") -buffers_flat = buffers.flatten_keys(".") params_target_flat = params_target.flatten_keys(".") -buffers_target_flat = buffers_target.flatten_keys(".") ############################################################################### # Regular DQN @@ -393,7 +395,7 @@ def make_model(): # Compute action value (of the action actually taken) at time t sampled_data_out = sampled_data.select(*actor.in_keys) - sampled_data_out = factor(sampled_data_out, params=params, buffers=buffers) + sampled_data_out = factor(sampled_data_out, params=params) action_value = sampled_data_out["action_value"] action_value = (action_value * action.to(action_value.dtype)).sum(-1) with torch.no_grad(): @@ -402,7 +404,6 @@ def make_model(): next_value = factor( tdstep.select(*actor.in_keys), params=params_target, - buffers=buffers_target, )["chosen_action_value"].squeeze(-1) 
exp_value = reward + gamma * next_value * (1 - done) assert exp_value.shape == action_value.shape @@ -420,9 +421,6 @@ def make_model(): for (key, p1) in params_flat.items(): p2 = params_target_flat[key] params_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data) - for (key, p1) in buffers_flat.items(): - p2 = buffers_target_flat[key] - buffers_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data) pbar.set_description( f"error: {error: 4.4f}, value: {action_value.mean(): 4.4f}" @@ -513,7 +511,7 @@ def make_model(): "grad_vals": grad_vals, "traj_lengths_training": traj_lengths, "traj_count": traj_count, - "weights": (params, buffers), + "weights": (params,), }, "saved_results_td0.pt", ) @@ -548,14 +546,10 @@ def make_model(): actor, actor_explore, params, - buffers, params_target, - buffers_target, ) = make_model() params_flat = params.flatten_keys(".") -buffers_flat = buffers.flatten_keys(".") params_target_flat = params_target.flatten_keys(".") -buffers_target_flat = buffers_target.flatten_keys(".") ############################################################################### @@ -632,19 +626,15 @@ def make_model(): action = sampled_data["action"].clone() sampled_data_out = sampled_data.select(*actor.in_keys) - sampled_data_out = factor( - sampled_data_out, params=params, buffers=buffers, vmap=(None, None, 0) - ) + sampled_data_out = vmap(factor, (0, None))(sampled_data_out, params) action_value = sampled_data_out["action_value"] action_value = (action_value * action.to(action_value.dtype)).sum(-1, True) with torch.no_grad(): tdstep = step_mdp(sampled_data) - next_value = factor( - tdstep.select(*actor.in_keys), - params=params_target, - buffers=buffers_target, - vmap=(None, None, 0), - )["chosen_action_value"] + next_value = vmap(factor, (0, None))( + tdstep.select(*actor.in_keys), params + ) + next_value = next_value["chosen_action_value"] error = vec_td_lambda_advantage_estimate( gamma, lmbda, @@ -671,9 +661,6 @@ def make_model(): for (key, p1) in 
params_flat.items(): p2 = params_target_flat[key] params_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data) - for (key, p1) in buffers_flat.items(): - p2 = buffers_target_flat[key] - buffers_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data) pbar.set_description( f"error: {error: 4.4f}, value: {action_value.mean(): 4.4f}" @@ -765,7 +752,7 @@ def make_model(): "grad_vals": grad_vals, "traj_lengths_training": traj_lengths, "traj_count": traj_count, - "weights": (params, buffers), + "weights": (params,), }, "saved_results_tdlambda.pt", ) diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index 59aa6c76d68..9896d616fbc 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -9,14 +9,21 @@ # can compute actions in diverse settings using a distinct set of weights. # You will also be able to execute diverse environments in parallel. +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore + import torch +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn ############################################################################## from torchrl.envs import CatTensors, Compose, DoubleToFloat, ParallelEnv, TransformedEnv from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.modules import MLP, TensorDictModule, TensorDictSequential +from torchrl.modules import MLP ############################################################################### # We design two environments, one humanoid that must complete the stand task diff --git a/tutorials/sphinx-tutorials/tensordict_module.py b/tutorials/sphinx-tutorials/tensordict_module.py index 648b92909ac..c4102b35208 100644 --- a/tutorials/sphinx-tutorials/tensordict_module.py +++ b/tutorials/sphinx-tutorials/tensordict_module.py @@ -6,7 +6,7 @@ """ 
############################################################################## # For a convenient usage of the ``TensorDict`` class with ``nn.Module``, -# TorchRL provides an interface between the two named ``TensorDictModule``. +# :obj:`tensordict` provides an interface between the two named ``TensorDictModule``. # The ``TensorDictModule`` class is an ``nn.Module`` that takes a # ``TensorDict`` as input when called. # It is up to the user to define the keys to be read as input and output. @@ -14,10 +14,16 @@ # TensorDictModule by examples # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore + import torch import torch.nn as nn from tensordict import TensorDict -from torchrl.modules import TensorDictModule, TensorDictSequential +from tensordict.nn import TensorDictModule, TensorDictSequential ############################################################################### # Example 1: Simple usage @@ -143,10 +149,10 @@ def forward(self, x): ############################################################################### # Example 5: Compatibility with functorch # ----------------------------------------- -# ``TensorDictModule`` comes with its own ``make_functional_with_buffers`` -# method to make it functional (you should not be using -# ``functorch.make_functional_with_buffers(tensordictmodule)``, that will -# not work in general). +# tensordict.nn is compatible with functorch. It also comes with its own functional +# utilities. 
Let us have a look: + +import functorch tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5]) @@ -155,29 +161,39 @@ def forward(self, x): in_keys=["a"], out_keys=["output_1", "output_2"], ) -func, (params, buffers) = splitlinear.make_functional_with_buffers() -func(tensordict, params=params, buffers=buffers) +func, params, buffers = functorch.make_functional_with_buffers(splitlinear) +print(func(params, buffers, tensordict)) + +############################################################################### +# This can be used with the vmap operator. For example, we use 3 replicas of the +# params and buffers and execute a vectorized map over these for a single batch +# of data: + +params_expand = [p.expand(3, *p.shape) for p in params] +buffers_expand = [p.expand(3, *p.shape) for p in buffers] +print(functorch.vmap(func, (0, 0, None))(params_expand, buffers_expand, tensordict)) ############################################################################### -# We can also use the ``vmap`` operator, here's an example of -# model ensembling with it: +# We can also use the native :obj:`get_functional()` function from tensordict.nn, +# which modifies the module to make it accept the parameters as regular inputs: + +from tensordict.nn import make_functional tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5]) num_models = 10 model = TensorDictModule(nn.Linear(3, 4), in_keys=["a"], out_keys=["output"]) -fmodel, (params, buffers) = model.make_functional_with_buffers() -params = [torch.randn(num_models, *p.shape, device=p.device) for p in params] -buffers = [torch.randn(num_models, *b.shape, device=b.device) for b in buffers] -result_td = fmodel(tensordict, params=params, buffers=buffers, vmap=True) +params = make_functional(model) +# we stack two groups of parameters to show the vmap usage: +params = torch.stack([params, params.apply(lambda x: torch.zeros_like(x))], 0) +result_td = functorch.vmap(model, (None, 0))(tensordict, params) print("the 
output tensordict shape is: ", result_td.shape) +from tensordict.nn import ProbabilisticTensorDictModule + ############################################################################### # Do's and don't with TensorDictModule # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Don't use ``nn.Module`` wrappers with ``TensorDictModule`` componants. -# This would break some of ``TensorDictModule`` features such as ``functorch`` -# compatibility. # # Don't use ``nn.Sequence``, similar to ``nn.Module``, it would break features # such as ``functorch`` compatibility. Do use ``TensorDictSequential`` instead. @@ -189,14 +205,6 @@ def forward(self, x): # # tensordict_out = module(tensordict) # don't! # -# Don't use ``make_functional_with_buffers`` from ``functorch`` directly but -# use ``TensorDictModule.make_functional_with_buffers`` instead. -# -# TensorDictModule for RL -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# TorchRL provides a few RL-specific ``TensorDictModule`` instances that -# serves domain-specific needs. -# # ``ProbabilisticTensorDictModule`` # ---------------------------------- # ``ProbabilisticTensorDictModule`` is a special case of a ``TensorDictModule`` @@ -214,11 +222,7 @@ def forward(self, x): # One can find the parameters in the output tensordict as well as the log # probability if needed. -from torchrl.modules import ( - NormalParamWrapper, - ProbabilisticTensorDictModule, - TanhNormal, -) +from torchrl.modules import NormalParamWrapper, TanhNormal td = TensorDict( {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, diff --git a/tutorials/sphinx-tutorials/tensordict_tutorial.py b/tutorials/sphinx-tutorials/tensordict_tutorial.py index e856f754f9f..ad50c6a3d1f 100644 --- a/tutorials/sphinx-tutorials/tensordict_tutorial.py +++ b/tutorials/sphinx-tutorials/tensordict_tutorial.py @@ -84,6 +84,12 @@ # However to achieve this you would need to write a complicated collate # function that make sure that every modality is aggregated properly. 
+# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore + def collate_dict_fn(dict_list): final_dict = {} @@ -123,6 +129,7 @@ def collate_dict_fn(dict_list): # TensorDict structure # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + import torch ############################################################################### diff --git a/tutorials/sphinx-tutorials/torch_envs.py b/tutorials/sphinx-tutorials/torch_envs.py index 27512b85e20..5f4f70172b2 100644 --- a/tutorials/sphinx-tutorials/torch_envs.py +++ b/tutorials/sphinx-tutorials/torch_envs.py @@ -25,6 +25,11 @@ # will pass the arguments and keyword arguments to the root library builder. # # With gym, it means that building an environment is as easy as: +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore import torch from matplotlib import pyplot as plt diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 0dd2fbb3236..f7ac96f35e1 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -124,6 +124,12 @@ # TensorDict # ------------------------------ +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore + import torch from tensordict import TensorDict @@ -172,7 +178,9 @@ # Here are some other functionalities of TensorDict. 
print( - "view(-1): ", tensordict.view(-1).batch_size, tensordict.view(-1).get("key 1").shape + "view(-1): ", + tensordict.view(-1).batch_size, + tensordict.view(-1).get("key 1").shape, ) print("to device: ", tensordict.to("cpu")) @@ -348,7 +356,8 @@ from torchrl.envs import ParallelEnv base_env = ParallelEnv( - 4, lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False) + 4, + lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False), ) env = TransformedEnv( base_env, Compose(NoopResetEnv(3), ToTensorImage()) @@ -384,7 +393,10 @@ # Example of a CNN model: cnn = ConvNet( - num_cells=[32, 64], kernel_sizes=[8, 4], strides=[2, 1], aggregator_class=SquashDims + num_cells=[32, 64], + kernel_sizes=[8, 4], + strides=[2, 1], + aggregator_class=SquashDims, ) print(cnn) print(cnn(torch.randn(10, 3, 32, 32)).shape) # last tensor is squashed @@ -393,7 +405,7 @@ # TensorDictModules # ------------------------------ -from torchrl.modules import TensorDictModule +from tensordict.nn import TensorDictModule tensordict = TensorDict({"key 1": torch.randn(10, 3)}, batch_size=[10]) module = nn.Linear(3, 4) @@ -405,7 +417,7 @@ # Sequences of Modules # ------------------------------ -from torchrl.modules import TensorDictSequential +from tensordict.nn import TensorDictSequential backbone_module = nn.Linear(5, 3) backbone = TensorDictModule( @@ -446,20 +458,21 @@ # Functional Programming (Ensembling / Meta-RL) # ---------------------------------------------- -fsequence, (params, buffers) = sequence.make_functional_with_buffers() -len(list(fsequence.parameters())) # functional modules have no parameters +from tensordict.nn import make_functional + +params = make_functional(sequence) +len(list(sequence.parameters())) # functional modules have no parameters ############################################################################### -fsequence(tensordict, params=params, buffers=buffers) +sequence(tensordict, params) 
############################################################################### -params_expand = [p.expand(4, *p.shape) for p in params] -buffers_expand = [b.expand(4, *b.shape) for b in buffers] -tensordict_exp = fsequence( - tensordict, params=params_expand, buffers=buffers, vmap=(0, 0, None) -) +import functorch + +params_expand = params.expand(4) +tensordict_exp = functorch.vmap(sequence, (None, 0))(tensordict, params_expand) print(tensordict_exp) ############################################################################### @@ -468,10 +481,11 @@ torch.manual_seed(0) from torchrl.data import NdBoundedTensorSpec +from torchrl.modules import SafeModule spec = NdBoundedTensorSpec(-torch.ones(3), torch.ones(3)) base_module = nn.Linear(5, 3) -module = TensorDictModule( +module = SafeModule( module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True ) tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[]) @@ -491,14 +505,12 @@ tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[]) actor(tensordict) # action is the default value +from tensordict.nn import ProbabilisticTensorDictModule + ############################################################################### # Probabilistic modules -from torchrl.modules import ( - NormalParamWrapper, - ProbabilisticTensorDictModule, - TanhNormal, -) +from torchrl.modules import NormalParamWrapper, TanhNormal td = TensorDict( {"input": torch.randn(3, 5)}, @@ -572,7 +584,7 @@ action_spec = env.action_spec actor_module = nn.Linear(3, 1) -actor = TensorDictModule( +actor = SafeModule( actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"] ) @@ -628,6 +640,8 @@ (tensordict_rollout == tensordicts_prealloc).all() +from tensordict.nn import TensorDictModule + # Collectors # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -637,7 +651,6 @@ from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv -from torchrl.modules import TensorDictModule # EnvCreator makes 
sure that we can send a lambda function from process to process parallel_env = ParallelEnv(3, EnvCreator(lambda: GymEnv("Pendulum-v1"))) @@ -666,6 +679,8 @@ print(d) # trajectories are split automatically in [6 workers x 10 steps] collector.update_policy_weights_() # make sure that our policies have the latest weights if working on multiple devices print(i) +collector.shutdown() +del collector ############################################################################### @@ -685,6 +700,7 @@ print(d) # trajectories are split automatically in [6 workers x 10 steps] collector.update_policy_weights_() # make sure that our policies have the latest weights if working on multiple devices print(i) +collector.shutdown() del collector ############################################################################### @@ -728,11 +744,11 @@ def forward(self, obs, action): ############################################################################### -loss_td +print(loss_td) ############################################################################### -tensordict +print(tensordict) ############################################################################### # State of the Library From 9a5f08b68bc3a635d04c275fb1f63089b7273c85 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Thu, 1 Dec 2022 13:29:33 +0000 Subject: [PATCH 14/14] [Doc] Update tutorial links in readme (#724) --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 149b85c4e0b..050c1956796 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ Here's another example of an off-policy training loop in TorchRL (assuming that ``` -Check our TorchRL-specific [TensorDict tutorial](tutorials/tensordict.ipynb) for more information. +Check our TorchRL-specific [TensorDict tutorial](https://pytorch.org/rl/tutorials/tensordict_tutorial.html) for more information. 
The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible! @@ -141,7 +141,7 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) ``` - The corresponding [tutorial](tutorials/tensordictmodule.ipynb) provides more context about its features. + The corresponding [tutorial](https://pytorch.org/rl/tutorials/tensordict_module.html) provides more context about its features. @@ -358,8 +358,8 @@ A series of [examples](examples/) are provided with an illustrative purpose: and many more to come! -We also provide [tutorials and demos](tutorials) that give a sense of what the -library can do. +We also provide [tutorials and demos](tutorials/README.md) that give a sense of +what the library can do. ## Installation Create a conda environment where the packages will be installed.