# Task-specific policy

In [10]:
from torchrl.envs import DMControlEnv, TransformedEnv, CatTensors, Compose, DoubleToFloat
from torchrl.modules import TensorDictModule, TensorDictSequence, MLP
from torch import nn
import torch

In [23]:
def make_humanoid_env(task, task_obs_key):
    """Build a transformed DMControl "humanoid" environment for one task.

    Every task uses the same transform pipeline; only the task name and the
    key holding the task-specific copy of the observation differ, so the
    pipeline is defined once here instead of being copy-pasted per task.

    Args:
        task: DMControl task name within the "humanoid" domain, e.g. "stand".
        task_obs_key: Key that receives the task-specific concatenated
            observation, e.g. "next_observation_stand".

    Returns:
        The TransformedEnv wrapping the base DMControlEnv.
    """
    base_env = DMControlEnv("humanoid", task)
    # Read the keys from the *untransformed* env, before any CatTensors runs.
    obs_keys = tuple(base_env.observation_spec.keys())
    return TransformedEnv(
        base_env,
        Compose(
            # Task-specific copy of the full observation. del_keys=False keeps
            # the source entries, presumably so the next transform can still
            # read them -- confirm against the CatTensors docs.
            CatTensors(list(obs_keys), task_obs_key, del_keys=False),
            # Task-agnostic observation under a key shared by all tasks.
            CatTensors(list(obs_keys), "next_observation"),
            # Cast both concatenated observations to float (the reset outputs
            # show float32). keys_inv_in names "action" for the inverse pass
            # -- presumably float -> double before stepping; TODO confirm.
            DoubleToFloat(keys_in=[task_obs_key, "next_observation"], keys_inv_in=["action"]),
        ),
    )


# One environment per task. Both expose a shared observation key plus a
# task-specific one, which is what lets a single policy route on the keys.
env1 = make_humanoid_env("stand", "next_observation_stand")
env2 = make_humanoid_env("walk", "next_observation_walk")


In [24]:
# Reset the humanoid "stand" env and display the initial tensordict, to check
# which observation keys the transforms produce.
env1.reset()

TensorDict(
    fields={
        done: Tensor(torch.Size([1]), dtype=torch.bool),
        observation: Tensor(torch.Size([67]), dtype=torch.float32),
        observation_stand: Tensor(torch.Size([67]), dtype=torch.float32)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [25]:
# Reset the humanoid "walk" env and display the initial tensordict, to check
# which observation keys the transforms produce.
env2.reset()

TensorDict(
    fields={
        done: Tensor(torch.Size([1]), dtype=torch.bool),
        observation: Tensor(torch.Size([67]), dtype=torch.float32),
        observation_walk: Tensor(torch.Size([67]), dtype=torch.float32)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [26]:
# Feature widths used by the policy networks, named once instead of being
# repeated as bare literals.
OBS_DIM = 67  # width of the concatenated observation (see the reset outputs)
HIDDEN_DIM = 64  # width of the shared "hidden" embedding
HEAD_OUT_DIM = 64  # output width of each task-specific head


def make_task_head(task):
    """Build the task-specific head for `task`.

    The head reads "observation_<task>" and the shared "hidden" embedding and
    writes "action_<task>". Its input width is OBS_DIM + HIDDEN_DIM, i.e. the
    two inputs side by side (presumably concatenated inside MLP -- confirm).

    Args:
        task: Task suffix used in the tensordict keys, e.g. "stand".

    Returns:
        A TensorDictModule wrapping a depth-2 MLP.
    """
    return TensorDictModule(
        MLP(OBS_DIM + HIDDEN_DIM, HEAD_OUT_DIM, depth=2),
        in_keys=[f"observation_{task}", "hidden"],
        out_keys=[f"action_{task}"],
    )


# Shared trunk: embeds the task-agnostic observation for every task.
policy_common = TensorDictModule(
    nn.Linear(OBS_DIM, HIDDEN_DIM),
    in_keys=["observation"],
    out_keys=["hidden"],
)
policy_stand = make_task_head("stand")
policy_walk = make_task_head("walk")

# A single sequence holds both heads. With partial_tolerant=True, a head
# whose input keys are missing from the tensordict does not produce an output
# (see the outputs of the cells that call `seq`).
seq = TensorDictSequence(policy_common, policy_stand, policy_walk, partial_tolerant=True)

In [27]:
# "stand" data carries `observation_stand` but no `observation_walk`, so the
# output contains `hidden` and `action_stand` and no `action_walk`.
seq(env1.reset())

TensorDict(
    fields={
        action_stand: Tensor(torch.Size([64]), dtype=torch.float32),
        done: Tensor(torch.Size([1]), dtype=torch.bool),
        hidden: Tensor(torch.Size([64]), dtype=torch.float32),
        observation: Tensor(torch.Size([67]), dtype=torch.float32),
        observation_stand: Tensor(torch.Size([67]), dtype=torch.float32)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [28]:
# "walk" data carries `observation_walk` but no `observation_stand`, so the
# output contains `hidden` and `action_walk` and no `action_stand`.
seq(env2.reset())

TensorDict(
    fields={
        action_walk: Tensor(torch.Size([64]), dtype=torch.float32),
        done: Tensor(torch.Size([1]), dtype=torch.bool),
        hidden: Tensor(torch.Size([64]), dtype=torch.float32),
        observation: Tensor(torch.Size([67]), dtype=torch.float32),
        observation_walk: Tensor(torch.Size([67]), dtype=torch.float32)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)