diff --git a/test/test_modules.py b/test/test_modules.py index 25bcd07028f..f3bb917aca7 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -36,6 +36,8 @@ RSSMRollout, ) from torchrl.modules.models.utils import SquashDims +from torchrl.modules.planners.mppi import MPPIPlanner +from torchrl.objectives.value import TDLambdaEstimate @pytest.fixture @@ -437,9 +439,9 @@ def test_lstm_net_nobatch(device, out_features, hidden_size): torch.testing.assert_close(tds_vec["hidden1_out"][-1], tds_loop["hidden1_out"][-1]) +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("batch_size", [3, 5]) class TestPlanner: - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("batch_size", [3, 5]) def test_CEM_model_free_env(self, device, batch_size, seed=1): env = MockBatchedUnLockedEnv(device=device) torch.manual_seed(seed) @@ -448,13 +450,48 @@ def test_CEM_model_free_env(self, device, batch_size, seed=1): planning_horizon=10, optim_steps=2, num_candidates=100, - num_top_k_candidates=2, + top_k=2, ) td = env.reset(TensorDict({}, batch_size=batch_size).to(device)) td_copy = td.clone() td = planner(td) - assert td.get("action").shape[1:] == env.action_spec.shape + assert ( + td.get("action").shape[-len(env.action_spec.shape) :] + == env.action_spec.shape + ) + assert env.action_spec.is_in(td.get("action")) + + for key in td.keys(): + if key != "action": + assert torch.allclose(td[key], td_copy[key]) + def test_MPPI(self, device, batch_size, seed=1): + torch.manual_seed(seed) + env = MockBatchedUnLockedEnv(device=device) + value_net = nn.LazyLinear(1, device=device) + value_net = ValueOperator(value_net, in_keys=["observation"]) + advantage_module = TDLambdaEstimate( + 0.99, + 0.95, + value_net, + ) + value_net(env.reset()) + planner = MPPIPlanner( + env, + advantage_module, + temperature=1.0, + planning_horizon=10, + optim_steps=2, + num_candidates=100, + top_k=2, + ) + td = env.reset(TensorDict({}, batch_size=batch_size).to(device)) + td_copy = td.clone() + td = planner(td) + assert ( + td.get("action").shape[-len(env.action_spec.shape) :] + == env.action_spec.shape + ) assert env.action_spec.is_in(td.get("action")) for key in td.keys(): diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index aa1a8f40b8b..491f02f3391 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch -from tensordict.tensordict import TensorDictBase +from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.envs import EnvBase from torchrl.modules.planners.common import MPCPlannerBase @@ -36,7 +36,7 @@ class CEMPlanner(MPCPlannerBase): planner num_candidates (int): The number of candidates to sample from the Gaussian distributions. - num_top_k_candidates (int): The number of top candidates to use to + top_k (int): The number of top candidates to use to update the mean and standard deviation of the Gaussian distribution. reward_key (str, optional): The key in the TensorDict to use to retrieve the reward. Defaults to "reward". @@ -61,13 +61,17 @@ class CEMPlanner(MPCPlannerBase): ... self.reward_spec = UnboundedContinuousTensorSpec((1,)) ... ... def _reset(self, tensordict: TensorDict) -> TensorDict: - ... tensordict = TensorDict({}, + ... tensordict = TensorDict( + ... {}, ... batch_size=self.batch_size, ... device=self.device, ... ) - ... 
tensordict = tensordict.update(self.input_spec.rand(self.batch_size)) - ... tensordict = tensordict.update(self.observation_spec.rand(self.batch_size)) + ... tensordict = tensordict.update( + ... self.input_spec.rand(self.batch_size)) + ... tensordict = tensordict.update( + ... self.observation_spec.rand(self.batch_size)) ... return tensordict + ... >>> from torchrl.modules import MLP, WorldModelWrapper >>> import torch.nn as nn >>> world_model = WorldModelWrapper( @@ -91,7 +95,12 @@ class CEMPlanner(MPCPlannerBase): action: Tensor(torch.Size([5, 1]), dtype=torch.float32), done: Tensor(torch.Size([5, 1]), dtype=torch.bool), hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), - next_hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + next: LazyStackedTensorDict( + fields={ + hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), reward: Tensor(torch.Size([5, 1]), dtype=torch.float32)}, batch_size=torch.Size([5]), device=cpu, @@ -104,7 +113,7 @@ def __init__( planning_horizon: int, optim_steps: int, num_candidates: int, - num_top_k_candidates: int, + top_k: int, reward_key: str = "reward", action_key: str = "action", ): @@ -112,46 +121,66 @@ def __init__( self.planning_horizon = planning_horizon self.optim_steps = optim_steps self.num_candidates = num_candidates - self.num_top_k_candidates = num_top_k_candidates + self.top_k = top_k self.reward_key = reward_key def planning(self, tensordict: TensorDictBase) -> torch.Tensor: batch_size = tensordict.batch_size - expanded_original_tensordict = ( - tensordict.unsqueeze(-1) - .expand(*batch_size, self.num_candidates) - .reshape(-1) + action_shape = ( + *batch_size, + self.num_candidates, + self.planning_horizon, + *self.action_spec.shape, ) - flatten_batch_size = batch_size.numel() - actions_means = torch.zeros( - flatten_batch_size, + action_stats_shape = ( + *batch_size, 1, self.planning_horizon, *self.action_spec.shape, - device=tensordict.device, - dtype=self.env.action_spec.dtype, ) - actions_stds = torch.ones( - flatten_batch_size, - 1, + action_topk_shape = ( + *batch_size, + self.top_k, self.planning_horizon, *self.action_spec.shape, + ) + TIME_DIM = len(self.action_spec.shape) - 3 + K_DIM = len(self.action_spec.shape) - 4 + expanded_original_tensordict = ( + tensordict.unsqueeze(-1) + .expand(*batch_size, self.num_candidates) + .to_tensordict() + ) + _action_means = torch.zeros( + *action_stats_shape, device=tensordict.device, dtype=self.env.action_spec.dtype, ) + _action_stds = torch.ones_like(_action_means) + container = TensorDict( + { + "tensordict": expanded_original_tensordict, + "stats": TensorDict( + { + "_action_means": _action_means, + "_action_stds": _action_stds, + }, + [*batch_size, 1, self.planning_horizon], + ), + }, + batch_size, + ) for _ in range(self.optim_steps): + actions_means = container.get(("stats", "_action_means")) + actions_stds = container.get(("stats", "_action_stds")) actions = actions_means + actions_stds * torch.randn( - flatten_batch_size, - self.num_candidates, - self.planning_horizon, - *self.action_spec.shape, - device=tensordict.device, - dtype=self.env.action_spec.dtype, + *action_shape, + device=actions_means.device, + dtype=actions_means.dtype, ) - actions = actions.flatten(0, 1) actions = self.env.action_spec.project(actions) - optim_tensordict = expanded_original_tensordict.to_tensordict() + optim_tensordict = container.get("tensordict").clone() policy = 
_PrecomputedActionsSequentialSetter(actions) optim_tensordict = self.env.rollout( max_steps=self.planning_horizon, @@ -159,23 +188,21 @@ def planning(self, tensordict: TensorDictBase) -> torch.Tensor: auto_reset=False, tensordict=optim_tensordict, ) - rewards = ( - optim_tensordict.get(self.reward_key) - .sum(dim=1) - .reshape(flatten_batch_size, self.num_candidates) - ) - _, top_k = rewards.topk(self.num_top_k_candidates, dim=1) - best_actions = actions.unflatten( - 0, (flatten_batch_size, self.num_candidates) + sum_rewards = optim_tensordict.get(self.reward_key).sum( + dim=TIME_DIM, keepdim=True + ) + _, top_k = sum_rewards.topk(self.top_k, dim=K_DIM) + top_k = top_k.expand(action_topk_shape) + best_actions = actions.gather(K_DIM, top_k) + container.set_( + ("stats", "_action_means"), best_actions.mean(dim=K_DIM, keepdim=True) + ) + container.set_( + ("stats", "_action_stds"), best_actions.std(dim=K_DIM, keepdim=True) ) - best_actions = best_actions[ - torch.arange(flatten_batch_size, device=tensordict.device).unsqueeze(1), - top_k, - ] - actions_means = best_actions.mean(dim=1, keepdim=True) - actions_stds = best_actions.std(dim=1, keepdim=True) - return actions_means[:, :, 0].reshape(*batch_size, *self.action_spec.shape) + action_means = container.get(("stats", "_action_means")) + return action_means[..., 0, 0, :] class _PrecomputedActionsSequentialSetter: @@ -183,9 +210,10 @@ def __init__(self, actions): self.actions = actions self.cmpt = 0 - def __call__(self, td): - if self.cmpt >= self.actions.shape[1]: - raise ValueError("Precomputed actions are too short") - td = td.set("action", self.actions[:, self.cmpt]) + def __call__(self, tensordict): + # checks that the step count is lower or equal to the horizon + if self.cmpt >= self.actions.shape[-2]: + raise ValueError("Precomputed actions sequence is too short") + tensordict = tensordict.set("action", self.actions[..., self.cmpt, :]) self.cmpt += 1 - return td + return tensordict diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py new file mode 100644 index 00000000000..f1a5fe9b255 --- /dev/null +++ b/torchrl/modules/planners/mppi.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from tensordict.tensordict import TensorDict, TensorDictBase +from torch import nn + +from torchrl.envs import EnvBase +from torchrl.modules.planners.common import MPCPlannerBase + + +class MPPIPlanner(MPCPlannerBase): + """MPPI Planner Module. + + Reference: + - Model predictive path integral control using covariance variable importance + sampling. (Williams, G., Aldrich, A., and Theodorou, E. A.) https://arxiv.org/abs/1509.01149 + - Temporal Difference Learning for Model Predictive Control + (Hansen N., Wang X., Su H.) https://arxiv.org/abs/2203.04955 + + This module will perform a MPPI planning step when given a TensorDict + containing initial states. + + A call to the module returns the actions that empirically maximised the + returns given a planning horizon + + Args: + env (EnvBase): The environment to perform the planning step on (can be + `ModelBasedEnv` or :obj:`EnvBase`). + planning_horizon (int): The length of the simulated trajectories + optim_steps (int): The number of optimization steps used by the MPC + planner + num_candidates (int): The number of candidates to sample from the + Gaussian distributions. 
+ top_k (int): The number of top candidates to use to + update the mean and standard deviation of the Gaussian distribution. + reward_key (str, optional): The key in the TensorDict to use to + retrieve the reward. Defaults to "reward". + action_key (str, optional): The key in the TensorDict to use to store + the action. Defaults to "action" + + Examples: + >>> from tensordict import TensorDict + >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec + >>> from torchrl.envs.model_based import ModelBasedEnvBase + >>> from torchrl.modules import TensorDictModule, ValueOperator + >>> from torchrl.objectives.value import TDLambdaEstimate + >>> class MyMBEnv(ModelBasedEnvBase): + ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): + ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) + ... self.observation_spec = CompositeSpec( + ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)) + ... ) + ... self.input_spec = CompositeSpec( + ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)), + ... action=NdUnboundedContinuousTensorSpec((1,)), + ... ) + ... self.reward_spec = NdUnboundedContinuousTensorSpec((1,)) + ... + ... def _reset(self, tensordict: TensorDict) -> TensorDict: + ... tensordict = TensorDict( + ... {}, + ... batch_size=self.batch_size, + ... device=self.device, + ... ) + ... tensordict = tensordict.update( + ... self.input_spec.rand(self.batch_size)) + ... tensordict = tensordict.update( + ... self.observation_spec.rand(self.batch_size)) + ... return tensordict + >>> from torchrl.modules import MLP, WorldModelWrapper + >>> import torch.nn as nn + >>> world_model = WorldModelWrapper( + ... TensorDictModule( + ... MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0), + ... in_keys=["hidden_observation", "action"], + ... out_keys=["hidden_observation"], + ... ), + ... TensorDictModule( + ... nn.Linear(4, 1), + ... in_keys=["hidden_observation"], + ... out_keys=["reward"], + ... ), + ... ) + >>> env = MyMBEnv(world_model) + >>> value_net = nn.Linear(4, 1) + >>> value_net = ValueOperator(value_net, in_keys=["hidden_observation"]) + >>> adv = TDLambdaEstimate( + ... 0.99, + ... 0.95, + ... value_net, + ... ) + >>> # Build a planner and use it as actor + >>> planner = MPPIPlanner( + ... env, + ... adv, + ... temperature=1.0, + ... planning_horizon=10, + ... optim_steps=11, + ... num_candidates=7, + ... 
top_k=3) + >>> env.rollout(5, planner) + TensorDict( + fields={ + action: Tensor(torch.Size([5, 1]), dtype=torch.float32), + done: Tensor(torch.Size([5, 1]), dtype=torch.bool), + hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + next: LazyStackedTensorDict( + fields={ + hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + reward: Tensor(torch.Size([5, 1]), dtype=torch.float32)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False) + """ + + def __init__( + self, + env: EnvBase, + advantage_module: nn.Module, + temperature: float, + planning_horizon: int, + optim_steps: int, + num_candidates: int, + top_k: int, + reward_key: str = "reward", + action_key: str = "action", + ): + super().__init__(env=env, action_key=action_key) + self.advantage_module = advantage_module + self.planning_horizon = planning_horizon + self.optim_steps = optim_steps + self.num_candidates = num_candidates + self.top_k = top_k + self.reward_key = reward_key + self.register_buffer("temperature", torch.tensor(temperature)) + + def planning(self, tensordict: TensorDictBase) -> torch.Tensor: + batch_size = tensordict.batch_size + action_shape = ( + *batch_size, + self.num_candidates, + self.planning_horizon, + *self.action_spec.shape, + ) + action_stats_shape = ( + *batch_size, + 1, + self.planning_horizon, + *self.action_spec.shape, + ) + action_topk_shape = ( + *batch_size, + self.top_k, + self.planning_horizon, + *self.action_spec.shape, + ) + adv_topk_shape = ( + *batch_size, + self.top_k, + 1, + 1, + ) + K_DIM = len(self.action_spec.shape) - 4 + expanded_original_tensordict = ( + tensordict.unsqueeze(-1) + .expand(*batch_size, self.num_candidates) + .to_tensordict() + ) + _action_means = torch.zeros( + *action_stats_shape, + device=tensordict.device, + dtype=self.env.action_spec.dtype, + ) + _action_stds = torch.ones_like(_action_means) + container = TensorDict( + { + "tensordict": expanded_original_tensordict, + "stats": TensorDict( + { + "_action_means": _action_means, + "_action_stds": _action_stds, + }, + [*batch_size, 1, self.planning_horizon], + ), + }, + batch_size, + ) + + for _ in range(self.optim_steps): + actions_means = container.get(("stats", "_action_means")) + actions_stds = container.get(("stats", "_action_stds")) + actions = actions_means + actions_stds * torch.randn( + *action_shape, + device=actions_means.device, + dtype=actions_means.dtype, + ) + actions = self.env.action_spec.project(actions) + optim_tensordict = container.get("tensordict").clone() + policy = _PrecomputedActionsSequentialSetter(actions) + optim_tensordict = self.env.rollout( + max_steps=self.planning_horizon, + policy=policy, + auto_reset=False, + tensordict=optim_tensordict, + ) + # compute advantage + self.advantage_module(optim_tensordict) + # get advantage of the current state + advantage = optim_tensordict["advantage"][..., :1, :] + # get top-k trajectories + _, top_k = advantage.topk(self.top_k, dim=K_DIM) + # get omega weights for each top-k trajectory + vals = advantage.gather(K_DIM, top_k.expand(adv_topk_shape)) + Omegas = (self.temperature * vals).exp() + + # gather best actions + best_actions = actions.gather(K_DIM, top_k.expand(action_topk_shape)) + + # compute weighted average + _action_means = (Omegas * best_actions).sum( + dim=K_DIM, keepdim=True + ) / Omegas.sum(K_DIM, True) + _action_stds = ( + (Omegas * (best_actions - _action_means).pow(2)).sum( + dim=K_DIM, keepdim=True + ) + / Omegas.sum(K_DIM, 
True)
+            ).sqrt()
+            container.set_(("stats", "_action_means"), _action_means)
+            container.set_(("stats", "_action_stds"), _action_stds)
+        action_means = container.get(("stats", "_action_means"))
+        return action_means[..., 0, 0, :]
+
+
+class _PrecomputedActionsSequentialSetter:
+    def __init__(self, actions):
+        self.actions = actions
+        self.cmpt = 0
+
+    def __call__(self, tensordict):
+        # check that the step count is still within the planning horizon
+        if self.cmpt >= self.actions.shape[-2]:
+            raise ValueError("Precomputed actions sequence is too short")
+        tensordict = tensordict.set("action", self.actions[..., self.cmpt, :])
+        self.cmpt += 1
+        return tensordict
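
Note for reviewers: the core of the MPPIPlanner.planning method added above is the exponentially weighted (Omega) update of the sampling statistics over the top-k candidate trajectories. Below is a minimal standalone sketch of that update with plain tensors, kept outside the patch; the helper name mppi_update and the simplified, batch-free shapes are illustrative assumptions, not library API.

import torch


def mppi_update(actions, advantage, temperature, top_k):
    # actions:   (num_candidates, planning_horizon, action_dim) sampled action sequences
    # advantage: (num_candidates, 1) advantage of each candidate's first planning step
    k_dim = 0
    _, top_idx = advantage.topk(top_k, dim=k_dim)                   # indices of the best candidates
    best_actions = actions[top_idx.squeeze(-1)]                     # (top_k, horizon, action_dim)
    omegas = (temperature * advantage[top_idx.squeeze(-1)]).exp()   # (top_k, 1) trajectory weights
    omegas = omegas.unsqueeze(-1)                                   # broadcast over (horizon, action_dim)
    # weighted mean and standard deviation of the top-k action sequences
    means = (omegas * best_actions).sum(k_dim, keepdim=True) / omegas.sum(k_dim, keepdim=True)
    stds = (
        (omegas * (best_actions - means).pow(2)).sum(k_dim, keepdim=True)
        / omegas.sum(k_dim, keepdim=True)
    ).sqrt()
    return means, stds


# Example: 100 candidates, horizon 10, 1-d actions, keep the 2 best.
actions = torch.randn(100, 10, 1)
advantage = torch.randn(100, 1)
means, stds = mppi_update(actions, advantage, temperature=1.0, top_k=2)
print(means.shape, stds.shape)  # torch.Size([1, 10, 1]) torch.Size([1, 10, 1])

As the temperature goes to zero the weights become uniform and the update reduces to the plain top-k mean/std used by CEMPlanner; larger temperatures concentrate the update on the highest-advantage trajectories.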