30 changes: 15 additions & 15 deletions README.md
@@ -36,7 +36,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
```diff
- obs, done = env.reset()
+ tensordict = env.reset()
policy = TensorDictModule(
policy = SafeModule(
model,
in_keys=["observation_pixels", "observation_vector"],
out_keys=["action"],
@@ -106,14 +106,14 @@ Here's another example of an off-policy training loop in TorchRL (assuming that

Check our TorchRL-specific [TensorDict tutorial](tutorials/tensordict.ipynb) for more information.

The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/common.py) is [functorch](https://github.com/pytorch/functorch)-compatible!
The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) is [functorch](https://github.com/pytorch/functorch)-compatible!

<details>
<summary>Code</summary>

```diff
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
+ td_module = TensorDictModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
+ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
@@ -122,19 +122,19 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm
+ out = tensordict["out"]
```
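To make the functorch claim above concrete, here is a minimal sketch of that workflow. It is illustrative only: a toy `nn.Linear` stands in for the transformer, and it simply applies functorch's standard `make_functional_with_buffers` contract to a `SafeModule`; the exact integration may differ across functorch versions.

```python
import functorch
import torch
import torch.nn as nn
from tensordict import TensorDict
from torchrl.modules import SafeModule

# Wrap a toy module, then make it stateless: parameters and buffers are
# extracted once and passed explicitly at every call.
module = SafeModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
fmodule, params, buffers = functorch.make_functional_with_buffers(module)

tensordict = TensorDict({"x": torch.randn(8, 3)}, batch_size=[8])
fmodule(params, buffers, tensordict)  # writes "y" (shape [8, 4]) into the tensordict
```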

The `TensorDictSequential` class makes it possible to branch sequences of `nn.Module` instances in a highly modular way.
The `SafeSequential` class makes it possible to branch sequences of `nn.Module` instances in a highly modular way.
For instance, here is an implementation of a transformer using the encoder and decoder blocks:
```python
encoder_module = TransformerEncoder(...)
encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
encoder = SafeModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
decoder_module = TransformerDecoder(...)
decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
transformer = TensorDictSequential(encoder, decoder)
decoder = SafeModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
transformer = SafeSequential(encoder, decoder)
assert transformer.in_keys == ["src", "src_mask", "tgt"]
assert transformer.out_keys == ["memory", "output"]
```
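As a self-contained illustration of the same chaining pattern, here is a toy end-to-end run; the `nn.Linear` stages standing in for the encoder and decoder are assumptions made purely for brevity.

```python
import torch
import torch.nn as nn
from tensordict import TensorDict
from torchrl.modules import SafeModule, SafeSequential

# Two stages chained through an intermediate "memory" key, mirroring the
# encoder/decoder wiring above.
encoder = SafeModule(nn.Linear(4, 8), in_keys=["src"], out_keys=["memory"])
decoder = SafeModule(nn.Linear(8, 2), in_keys=["memory"], out_keys=["output"])
model = SafeSequential(encoder, decoder)

tensordict = TensorDict({"src": torch.randn(5, 4)}, batch_size=[5])
model(tensordict)  # fills "memory", then "output", in place
assert tensordict["output"].shape == (5, 2)
```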

`TensorDictSequential` makes it possible to isolate subgraphs by querying a set of desired input / output keys:
`SafeSequential` makes it possible to isolate subgraphs by querying a set of desired input / output keys:
```python
transformer.select_subsequence(out_keys=["memory"]) # returns the encoder
transformer.select_subsequence(in_keys=["tgt", "memory"]) # returns the decoder
@@ -261,9 +261,9 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm
kernel_sizes=[8, 4, 3],
strides=[4, 2, 1],
)
# Wrap it in a TensorDictModule, indicating what key to read in and where to
# Wrap it in a SafeModule, indicating what key to read in and where to
# write out the output
common_module = TensorDictModule(
common_module = SafeModule(
common_module,
in_keys=["pixels"],
out_keys=["hidden"],
@@ -277,10 +277,10 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm
activation=nn.ELU,
)
)
# Wrap the nn.Module in a ProbabilisticTensorDictModule, indicating how
# Wrap the nn.Module in a SafeProbabilisticModule, indicating how
# to build the torch.distribution.Distribution object and what to do with it
policy_module = ProbabilisticTensorDictModule( # stochastic policy
TensorDictModule(
policy_module = SafeProbabilisticModule( # stochastic policy
SafeModule(
policy_module,
in_keys=["hidden"],
out_keys=["loc", "scale"],
@@ -409,7 +409,7 @@ pip3 install torchrl
This should work on Linux and macOS (not M1). For Windows and M1/M2 machines, one
should install the library locally (see below).

The **nightly build** can be installed via
The **nightly build** can be installed via
```
pip install torchrl-nightly
```
2 changes: 1 addition & 1 deletion docs/source/reference/envs.rst
@@ -50,7 +50,7 @@ With these, the following methods are implemented:
having reproducible results.
- :obj:`env.rollout(max_steps, policy)`: executes a rollout in the environment for
a maximum number of steps :obj:`max_steps` and using a policy :obj:`policy`.
The policy should be coded using a :obj:`TensorDictModule` (or any other
The policy should be coded using a :obj:`SafeModule` (or any other
:obj:`TensorDict`-compatible module).
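
A minimal sketch of that rollout pattern follows; the environment id, the step count, and the lazy linear policy are illustrative assumptions, not part of the documented API.

```python
import torch.nn as nn
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import SafeModule

env = GymEnv("Pendulum-v1")
# A TensorDict-compatible policy: reads "observation", writes "action".
policy = SafeModule(
    nn.LazyLinear(env.action_spec.shape[-1]),
    in_keys=["observation"],
    out_keys=["action"],
)
policy(env.reset())  # materialize the lazy layer before collecting
rollout = env.rollout(max_steps=100, policy=policy)  # TensorDict with up to 100 steps
```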


7 changes: 3 additions & 4 deletions docs/source/reference/modules.rst
@@ -11,10 +11,9 @@ TensorDict modules
:toctree: generated/
:template: rl_template_noinherit.rst

TensorDictModule
ProbabilisticTensorDictModule
TensorDictSequential
TensorDictModuleWrapper
SafeModule
SafeProbabilisticModule
SafeSequential
Actor
ProbabilisticActor
ValueOperator
2 changes: 1 addition & 1 deletion test/smoke_test.py
@@ -6,5 +6,5 @@ def test_imports():
) # noqa: F401
from torchrl.envs import Transform, TransformedEnv # noqa: F401
from torchrl.envs.gym_like import GymLikeEnv # noqa: F401
from torchrl.modules import TensorDictModule # noqa: F401
from torchrl.modules import SafeModule # noqa: F401
from torchrl.objectives.common import LossModule # noqa: F401
15 changes: 5 additions & 10 deletions test/test_collector.py
@@ -35,12 +35,7 @@
from torchrl.envs import EnvCreator, ParallelEnv, SerialEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv
from torchrl.envs.transforms import TransformedEnv, VecNorm
from torchrl.modules import (
Actor,
LSTMNet,
OrnsteinUhlenbeckProcessWrapper,
TensorDictModule,
)
from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule

# torch.set_default_dtype(torch.double)

@@ -754,7 +749,7 @@ def create_env():
return ContinuousActionVecMockEnv()

n_actions = ContinuousActionVecMockEnv().action_spec.shape[-1]
policy = TensorDictModule(
policy = SafeModule(
torch.nn.LazyLinear(n_actions), in_keys=["observation"], out_keys=["action"]
)
policy(create_env().reset())
@@ -898,7 +893,7 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe
next=CompositeSpec(hidden1=hidden_spec, hidden2=hidden_spec),
)

policy = TensorDictModule(**policy_kwargs)
policy = SafeModule(**policy_kwargs)

env_maker = lambda: GymEnv(PENDULUM_VERSIONED)

@@ -985,12 +980,12 @@ def test_auto_wrap_modules(self, collector_class, multiple_outputs, env_maker):

if collector_class is not SyncDataCollector:
assert all(
isinstance(p, TensorDictModule) for p in collector._policy_dict.values()
isinstance(p, SafeModule) for p in collector._policy_dict.values()
)
assert all(p.out_keys == out_keys for p in collector._policy_dict.values())
assert all(p.module is policy for p in collector._policy_dict.values())
else:
assert isinstance(collector.policy, TensorDictModule)
assert isinstance(collector.policy, SafeModule)
assert collector.policy.out_keys == out_keys
assert collector.policy.module is policy

66 changes: 26 additions & 40 deletions test/test_cost.py
@@ -6,7 +6,7 @@
import argparse
from copy import deepcopy

from torchrl.modules.functional_modules import FunctionalModuleWithBuffers
from tensordict.nn.functional_modules import FunctionalModuleWithBuffers

_has_functorch = True
try:
@@ -41,10 +41,10 @@
from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
from torchrl.modules import (
DistributionalQValueActor,
ProbabilisticTensorDictModule,
QValueActor,
TensorDictModule,
TensorDictSequential,
SafeModule,
SafeProbabilisticModule,
SafeSequential,
WorldModelWrapper,
)
from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
@@ -777,9 +777,9 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
spec=CompositeSpec(action=action_spec, loc=None, scale=None),
module=module,
@@ -1096,9 +1094,9 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
module=module,
distribution_class=TanhNormal,
@@ -1151,13 +1147,9 @@ def __init__(self):
def forward(self, hidden, act):
return self.linear(torch.cat([hidden, act], -1))

common = TensorDictModule(
CommonClass(), in_keys=["observation"], out_keys=["hidden"]
)
common = SafeModule(CommonClass(), in_keys=["observation"], out_keys=["hidden"])
actor_subnet = ProbabilisticActor(
TensorDictModule(
ActorClass(), in_keys=["hidden"], out_keys=["loc", "scale"]
),
SafeModule(ActorClass(), in_keys=["hidden"], out_keys=["loc", "scale"]),
dist_in_keys=["loc", "scale"],
distribution_class=TanhNormal,
return_log_prob=True,
@@ -1528,9 +1520,9 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
module=module,
distribution_class=TanhNormal,
@@ -1763,9 +1753,9 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
module=module,
distribution_class=TanhNormal,
@@ -1989,9 +1977,9 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value):
gamma = 0.9
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
actor_net = ProbabilisticActor(
module,
distribution_class=TanhNormal,
@@ -2138,7 +2124,7 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20

# World Model and reward model
rssm_rollout = RSSMRollout(
TensorDictModule(
SafeModule(
rssm_prior,
in_keys=["state", "belief", "action"],
out_keys=[
@@ -2148,7 +2134,7 @@
("next", "belief"),
],
),
TensorDictModule(
SafeModule(
rssm_posterior,
in_keys=[("next", "belief"), ("next", "encoded_latents")],
out_keys=[
@@ -2162,20 +2148,20 @@
out_features=1, depth=2, num_cells=mlp_num_units, activation_class=nn.ELU
)
# World Model and reward model
world_modeler = TensorDictSequential(
TensorDictModule(
world_modeler = SafeSequential(
SafeModule(
obs_encoder,
in_keys=[("next", "pixels")],
out_keys=[("next", "encoded_latents")],
),
rssm_rollout,
TensorDictModule(
SafeModule(
obs_decoder,
in_keys=[("next", "state"), ("next", "belief")],
out_keys=[("next", "reco_pixels")],
),
)
reward_module = TensorDictModule(
reward_module = SafeModule(
reward_module,
in_keys=[("next", "state"), ("next", "belief")],
out_keys=["reward"],
@@ -2209,8 +2195,8 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
reward_module = MLP(
out_features=1, depth=2, num_cells=mlp_num_units, activation_class=nn.ELU
)
transition_model = TensorDictSequential(
TensorDictModule(
transition_model = SafeSequential(
SafeModule(
rssm_prior,
in_keys=["state", "belief", "action"],
out_keys=[
@@ -2221,7 +2207,7 @@
],
),
)
reward_model = TensorDictModule(
reward_model = SafeModule(
reward_module,
in_keys=["state", "belief"],
out_keys=["reward"],
@@ -2255,8 +2241,8 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
num_cells=mlp_num_units,
activation_class=nn.ELU,
)
actor_model = ProbabilisticTensorDictModule(
TensorDictModule(
actor_model = SafeProbabilisticModule(
SafeModule(
actor_module,
in_keys=["state", "belief"],
out_keys=["loc", "scale"],
@@ -2278,7 +2264,7 @@
return actor_model

def _create_value_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
value_model = TensorDictModule(
value_model = SafeModule(
MLP(
out_features=1,
depth=3,
@@ -2380,7 +2366,7 @@ def test_dreamer_env(self, device, imagination_horizon, discount_loss):
# test reconstruction
with pytest.raises(ValueError, match="No observation decoder provided"):
mb_env.decode_obs(rollout)
mb_env.obs_decoder = TensorDictModule(
mb_env.obs_decoder = SafeModule(
nn.LazyLinear(4, device=device),
in_keys=["state"],
out_keys=["reco_observation"],
@@ -2896,13 +2882,13 @@ def test_shared_params(dest, expected_dtype, expected_device):
if torch.cuda.device_count() == 0 and dest == "cuda":
pytest.skip("no cuda device available")
module_hidden = torch.nn.Linear(4, 4)
td_module_hidden = TensorDictModule(
td_module_hidden = SafeModule(
module=module_hidden,
spec=None,
in_keys=["observation"],
out_keys=["hidden"],
)
module_action = TensorDictModule(
module_action = SafeModule(
NormalParamWrapper(torch.nn.Linear(4, 8)),
in_keys=["hidden"],
out_keys=["loc", "scale"],