45 changes: 41 additions & 4 deletions test/test_modules.py
@@ -36,6 +36,8 @@
RSSMRollout,
)
from torchrl.modules.models.utils import SquashDims
from torchrl.modules.planners.mppi import MPPIPlanner
from torchrl.objectives.value import TDLambdaEstimate


@pytest.fixture
@@ -437,9 +439,9 @@ def test_lstm_net_nobatch(device, out_features, hidden_size):
torch.testing.assert_close(tds_vec["hidden1_out"][-1], tds_loop["hidden1_out"][-1])


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("batch_size", [3, 5])
class TestPlanner:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("batch_size", [3, 5])
def test_CEM_model_free_env(self, device, batch_size, seed=1):
env = MockBatchedUnLockedEnv(device=device)
torch.manual_seed(seed)
@@ -448,13 +450,48 @@ def test_CEM_model_free_env(self, device, batch_size, seed=1):
planning_horizon=10,
optim_steps=2,
num_candidates=100,
num_top_k_candidates=2,
top_k=2,
)
td = env.reset(TensorDict({}, batch_size=batch_size).to(device))
td_copy = td.clone()
td = planner(td)
assert td.get("action").shape[1:] == env.action_spec.shape
assert (
td.get("action").shape[-len(env.action_spec.shape) :]
== env.action_spec.shape
)
assert env.action_spec.is_in(td.get("action"))

for key in td.keys():
if key != "action":
assert torch.allclose(td[key], td_copy[key])

def test_MPPI(self, device, batch_size, seed=1):
torch.manual_seed(seed)
env = MockBatchedUnLockedEnv(device=device)
value_net = nn.LazyLinear(1, device=device)
value_net = ValueOperator(value_net, in_keys=["observation"])
advantage_module = TDLambdaEstimate(
0.99,
0.95,
value_net,
)
value_net(env.reset())
planner = MPPIPlanner(
env,
advantage_module,
temperature=1.0,
planning_horizon=10,
optim_steps=2,
num_candidates=100,
top_k=2,
)
td = env.reset(TensorDict({}, batch_size=batch_size).to(device))
td_copy = td.clone()
td = planner(td)
assert (
td.get("action").shape[-len(env.action_spec.shape) :]
== env.action_spec.shape
)
assert env.action_spec.is_in(td.get("action"))

for key in td.keys():
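For reference, the renamed keyword can be exercised outside the test suite along the same lines as the test above. The sketch below is a minimal, non-authoritative usage example: `MockBatchedUnLockedEnv` is assumed to be the batched mock environment shipped with the torchrl test helpers (`mocking_classes`) rather than a public API, and any batched `EnvBase` could stand in for it.

    from tensordict import TensorDict
    from torchrl.modules.planners.cem import CEMPlanner

    # Assumption: MockBatchedUnLockedEnv comes from the test-suite helpers
    # (mocking_classes); it is not part of the public torchrl API.
    from mocking_classes import MockBatchedUnLockedEnv

    env = MockBatchedUnLockedEnv(device="cpu")
    planner = CEMPlanner(
        env,
        planning_horizon=10,
        optim_steps=2,
        num_candidates=100,
        top_k=2,  # renamed from num_top_k_candidates
    )
    td = env.reset(TensorDict({}, batch_size=[3]))
    td = planner(td)  # writes the first planned action under td["action"]
    assert env.action_spec.is_in(td.get("action"))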
124 changes: 76 additions & 48 deletions torchrl/modules/planners/cem.py
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

import torch
from tensordict.tensordict import TensorDictBase
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl.envs import EnvBase
from torchrl.modules.planners.common import MPCPlannerBase
@@ -36,7 +36,7 @@ class CEMPlanner(MPCPlannerBase):
planner
num_candidates (int): The number of candidates to sample from the
Gaussian distributions.
num_top_k_candidates (int): The number of top candidates to use to
top_k (int): The number of top candidates to use to
update the mean and standard deviation of the Gaussian distribution.
reward_key (str, optional): The key in the TensorDict to use to
retrieve the reward. Defaults to "reward".
@@ -61,13 +61,17 @@ class CEMPlanner(MPCPlannerBase):
... self.reward_spec = UnboundedContinuousTensorSpec((1,))
...
... def _reset(self, tensordict: TensorDict) -> TensorDict:
... tensordict = TensorDict({},
... tensordict = TensorDict(
... {},
... batch_size=self.batch_size,
... device=self.device,
... )
... tensordict = tensordict.update(self.input_spec.rand(self.batch_size))
... tensordict = tensordict.update(self.observation_spec.rand(self.batch_size))
... tensordict = tensordict.update(
... self.input_spec.rand(self.batch_size))
... tensordict = tensordict.update(
... self.observation_spec.rand(self.batch_size))
... return tensordict
...
>>> from torchrl.modules import MLP, WorldModelWrapper
>>> import torch.nn as nn
>>> world_model = WorldModelWrapper(
@@ -91,7 +95,12 @@ class CEMPlanner(MPCPlannerBase):
action: Tensor(torch.Size([5, 1]), dtype=torch.float32),
done: Tensor(torch.Size([5, 1]), dtype=torch.bool),
hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32),
next_hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32),
next: LazyStackedTensorDict(
fields={
hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)},
batch_size=torch.Size([5]),
device=cpu,
is_shared=False),
reward: Tensor(torch.Size([5, 1]), dtype=torch.float32)},
batch_size=torch.Size([5]),
device=cpu,
@@ -104,88 +113,107 @@ def __init__(
planning_horizon: int,
optim_steps: int,
num_candidates: int,
num_top_k_candidates: int,
top_k: int,
reward_key: str = "reward",
action_key: str = "action",
):
super().__init__(env=env, action_key=action_key)
self.planning_horizon = planning_horizon
self.optim_steps = optim_steps
self.num_candidates = num_candidates
self.num_top_k_candidates = num_top_k_candidates
self.top_k = top_k
self.reward_key = reward_key

def planning(self, tensordict: TensorDictBase) -> torch.Tensor:
batch_size = tensordict.batch_size
expanded_original_tensordict = (
tensordict.unsqueeze(-1)
.expand(*batch_size, self.num_candidates)
.reshape(-1)
action_shape = (
*batch_size,
self.num_candidates,
self.planning_horizon,
*self.action_spec.shape,
)
flatten_batch_size = batch_size.numel()
actions_means = torch.zeros(
flatten_batch_size,
action_stats_shape = (
*batch_size,
1,
self.planning_horizon,
*self.action_spec.shape,
device=tensordict.device,
dtype=self.env.action_spec.dtype,
)
actions_stds = torch.ones(
flatten_batch_size,
1,
action_topk_shape = (
*batch_size,
self.top_k,
self.planning_horizon,
*self.action_spec.shape,
)
TIME_DIM = len(self.action_spec.shape) - 3
K_DIM = len(self.action_spec.shape) - 4
expanded_original_tensordict = (
tensordict.unsqueeze(-1)
.expand(*batch_size, self.num_candidates)
.to_tensordict()
)
_action_means = torch.zeros(
*action_stats_shape,
device=tensordict.device,
dtype=self.env.action_spec.dtype,
)
_action_stds = torch.ones_like(_action_means)
container = TensorDict(
{
"tensordict": expanded_original_tensordict,
"stats": TensorDict(
{
"_action_means": _action_means,
"_action_stds": _action_stds,
},
[*batch_size, 1, self.planning_horizon],
),
},
batch_size,
)

for _ in range(self.optim_steps):
actions_means = container.get(("stats", "_action_means"))
actions_stds = container.get(("stats", "_action_stds"))
actions = actions_means + actions_stds * torch.randn(
flatten_batch_size,
self.num_candidates,
self.planning_horizon,
*self.action_spec.shape,
device=tensordict.device,
dtype=self.env.action_spec.dtype,
*action_shape,
device=actions_means.device,
dtype=actions_means.dtype,
)
actions = actions.flatten(0, 1)
actions = self.env.action_spec.project(actions)
optim_tensordict = expanded_original_tensordict.to_tensordict()
optim_tensordict = container.get("tensordict").clone()
policy = _PrecomputedActionsSequentialSetter(actions)
optim_tensordict = self.env.rollout(
max_steps=self.planning_horizon,
policy=policy,
auto_reset=False,
tensordict=optim_tensordict,
)
rewards = (
optim_tensordict.get(self.reward_key)
.sum(dim=1)
.reshape(flatten_batch_size, self.num_candidates)
)
_, top_k = rewards.topk(self.num_top_k_candidates, dim=1)

best_actions = actions.unflatten(
0, (flatten_batch_size, self.num_candidates)
sum_rewards = optim_tensordict.get(self.reward_key).sum(
dim=TIME_DIM, keepdim=True
)
_, top_k = sum_rewards.topk(self.top_k, dim=K_DIM)
top_k = top_k.expand(action_topk_shape)
best_actions = actions.gather(K_DIM, top_k)
container.set_(
("stats", "_action_means"), best_actions.mean(dim=K_DIM, keepdim=True)
)
container.set_(
("stats", "_action_stds"), best_actions.std(dim=K_DIM, keepdim=True)
)
best_actions = best_actions[
torch.arange(flatten_batch_size, device=tensordict.device).unsqueeze(1),
top_k,
]
actions_means = best_actions.mean(dim=1, keepdim=True)
actions_stds = best_actions.std(dim=1, keepdim=True)
return actions_means[:, :, 0].reshape(*batch_size, *self.action_spec.shape)
action_means = container.get(("stats", "_action_means"))
return action_means[..., 0, 0, :]


class _PrecomputedActionsSequentialSetter:
def __init__(self, actions):
self.actions = actions
self.cmpt = 0

def __call__(self, td):
if self.cmpt >= self.actions.shape[1]:
raise ValueError("Precomputed actions are too short")
td = td.set("action", self.actions[:, self.cmpt])
def __call__(self, tensordict):
# check that the step count does not exceed the planning horizon
if self.cmpt >= self.actions.shape[-2]:
raise ValueError("Precomputed actions sequence is too short")
tensordict = tensordict.set("action", self.actions[..., self.cmpt, :])
self.cmpt += 1
return td
return tensordict
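As a reading aid for the refactored planning() above, the following self-contained sketch reproduces the elite-selection step (reward summation over the horizon, topk over candidates, gather along the candidate dimension, and refit of the Gaussian statistics) on plain tensors. The shapes assume a 1-dimensional action spec, for which the module's formulas give TIME_DIM = -2 and K_DIM = -3; the batch size, candidate count and horizon below are illustrative only.

    import torch

    batch, num_candidates, horizon, action_dim, top_k = 2, 100, 10, 1, 2
    K_DIM, TIME_DIM = -3, -2  # candidate and time dims for a 1D action spec

    # One reward per candidate and per time step, as returned by env.rollout().
    rewards = torch.randn(batch, num_candidates, horizon, 1)
    actions = torch.randn(batch, num_candidates, horizon, action_dim)

    # Sum rewards over the horizon and keep the indices of the best candidates.
    sum_rewards = rewards.sum(dim=TIME_DIM, keepdim=True)   # (batch, num_candidates, 1, 1)
    _, top_idx = sum_rewards.topk(top_k, dim=K_DIM)         # (batch, top_k, 1, 1)
    top_idx = top_idx.expand(batch, top_k, horizon, action_dim)

    # Gather the elite action sequences and refit the sampling statistics that
    # would seed the next optimization step.
    best_actions = actions.gather(K_DIM, top_idx)           # (batch, top_k, horizon, action_dim)
    action_means = best_actions.mean(dim=K_DIM, keepdim=True)
    action_stds = best_actions.std(dim=K_DIM, keepdim=True)

    # The planner ultimately returns the first action of the mean sequence.
    first_action = action_means[..., 0, 0, :]
    print(first_action.shape)  # torch.Size([2, 1])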