In [1]:
from torchrl.envs import TransformedEnv, InitTracker
from torchrl.envs import GymEnv, SerialEnv
from torchrl.modules import MLP, LSTMModule, set_recurrent_mode
from torch import nn
from functools import partial
import torch
from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
# torch.set_default_dtype(torch.double)
env = SerialEnv(2, [partial(TransformedEnv, GymEnv("Pendulum-v1"), InitTracker()) for _ in range(2)])
lstm_module = LSTMModule(
    input_size=env.observation_spec["observation"].shape[-1],
    hidden_size=64,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
    python_based=True,
)
for p in lstm_module.parameters():
    p.data *= (1+torch.randn_like(p.data)/10)
mlp_value = MLP(num_cells=[64], out_features=1)
value_net = Seq(lstm_module, Mod(mlp_value, in_keys=["intermediate"], out_keys=["state_value"]))
mlp_policy = MLP(num_cells=[64], out_features=1)
policy_net = Seq(lstm_module, Mod(mlp_policy, in_keys=["intermediate"], out_keys=["action"]))
# value_net.select_out_keys("state_value")
env = env.append_transform(lstm_module.make_tensordict_primer())

  from .autonotebook import tqdm as notebook_tqdm
2025-05-07 14:07:43,673	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
from torchrl.objectives.value import GAE

In [3]:
vals = env.rollout(1000, policy_net, break_when_any_done=False)
vals["next", "is_init"] = vals["is_init"]

2025-05-07 14:07:44,241 [torchrl][INFO] transform container out of scope. Returning None for parent.
2025-05-07 14:07:44,242 [torchrl][INFO] transform container out of scope. Returning None for parent.


In [4]:
value_net(vals.copy());

In [5]:
gae = GAE(
    gamma=0.9,
    lmbda=0.99,
    value_network=value_net,
    shifted=True,
)
with set_recurrent_mode(True):
    r0 = gae(vals.copy())

data_in tensor([[   0,    0],
        [ 201,    0],
        [ 402,    0],
        [ 603,    0],
        [ 804,    0],
        [1005,    0],
        [1206,    0],
        [1407,    0],
        [1608,    0],
        [1809,    0]])
tensordict_shaped TensorDict(
    fields={
        is_init: Tensor(shape=torch.Size([10, 201, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([10, 201, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        rs_c: Tensor(shape=torch.Size([10, 201, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        rs_h: Tensor(shape=torch.Size([10, 201, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 201]),
    device=None,
    is_shared=False)


In [6]:
a0 = r0["advantage"]

In [7]:
gae = GAE(
    gamma=0.9,
    lmbda=0.99,
    value_network=value_net,
    shifted=False,
    deactivate_vmap=True,
)
with set_recurrent_mode(True):
    r1 = gae(vals.copy())
a1 = r1["advantage"]

data_next TensorDict(
    fields={
        is_init: Tensor(shape=torch.Size([2, 1000, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([2, 1000, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        rs_c: Tensor(shape=torch.Size([2, 1000, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        rs_h: Tensor(shape=torch.Size([2, 1000, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 1000]),
    device=None,
    is_shared=False)
tensordict_shaped TensorDict(
    fields={
        is_init: Tensor(shape=torch.Size([10, 200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([10, 200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        rs_c: Tensor(shape=torch.Size([10, 200, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        rs_h: Tensor(shape=torch.Size([10, 200, 1, 64]), device=cpu, dtype=torch.float32, i

In [8]:
a1/a0

tensor([[[0.9999],
         [0.9999],
         [1.0000],
         ...,
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0002],
         [1.0000],
         [1.0000],
         ...,
         [1.0000],
         [1.0000],
         [1.0000]]])

In [9]:
(a0/a1).view(-1)

tensor([1.0001, 1.0001, 1.0000,  ..., 1.0000, 1.0000, 1.0000])

In [10]:
(a0/a1)

tensor([[[1.0001],
         [1.0001],
         [1.0000],
         ...,
         [1.0000],
         [1.0000],
         [1.0000]],

        [[0.9998],
         [1.0000],
         [1.0000],
         ...,
         [1.0000],
         [1.0000],
         [1.0000]]])

In [11]:
d0 = r0["next", "done"]
d1 = r1["next", "done"]

(d0).view(-1)[190:220]

tensor([False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [12]:
abs(a0-a1).max()

tensor(0.0086)

In [13]:
import torch
torch.testing.assert_close(a0, a1)

AssertionError: Tensor-likes are not close!

Mismatched elements: 85 / 2000 (4.2%)
Greatest absolute difference: 0.008636474609375 at index (1, 400, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.0001738967257551849 at index (1, 0, 0) (up to 1.3e-06 allowed)