# TensorDict tutorial

TensorDict is a tensor structure introduced in torchrl. In RL, you need to handle several tensors at once, such as actions, observations and rewards. TensorDict makes it more convenient to carry these tensors around together. Furthermore, different RL algorithms have different inputs and outputs. TensorDict lets us abstract away the differences between these algorithms.

### Example
As a concrete example, let us take DQN and PPO. DQN uses a deterministic policy that applies an argmax operator to a collection of values, one per action, computed for a given observation. PPO has a parametric policy that outputs a distribution over the space of available actions. Here is the pseudocode for each:

<code>
# DQN
data = []
for i in range(max_steps):
    action, values = value_network(observation)  # action = values.argmax(-1)
    observation, reward, done, *other = env.step(action)
    data.append((action, values, observation, reward, done))
</code>
<code>
# PPO
data = []
for i in range(max_steps):
    action, action_log_prob = policy(observation)
    observation, reward, done, *other = env.step(action)
    data.append((action, action_log_prob, observation, reward, done))

</code>

Ideally we would like to abstract this away into the same code:

<code>
collections = []
for i in range(max_steps):
    collection_of_values = policy(collection_of_values)
    collection_of_values = env_step(collection_of_values)
    collections.append(collection_of_values)
</code>
The differences between the algorithms now lie in the `policy`, the `env_step` and the initial `collection_of_values`, but the main loop is the same for both algorithms. This abstraction allows for more modular and reusable code.
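
To make the idea concrete, here is a minimal sketch of that shared loop using plain Python dictionaries instead of TensorDict. The policies, the toy environment and the helper `collect` are all hypothetical stand-ins invented for illustration; they are not torchrl APIs.

```python
import math
import random

def dqn_policy(data):
    # Hypothetical value function: one value per action, greedy action via argmax.
    values = [0.1 * data["observation"], 0.5, -0.2]
    data["values"] = values
    data["action"] = max(range(len(values)), key=values.__getitem__)
    return data

def ppo_policy(data):
    # Hypothetical stochastic policy: sample an action and record its log-probability.
    probs = [0.2, 0.5, 0.3]
    action = random.choices(range(len(probs)), weights=probs)[0]
    data["action"] = action
    data["action_log_prob"] = math.log(probs[action])
    return data

def env_step(data):
    # Toy environment: the observation is a step counter, action 1 is rewarded.
    data["observation"] = data["observation"] + 1
    data["reward"] = float(data["action"] == 1)
    data["done"] = data["observation"] >= 3
    return data

def collect(policy, max_steps=3):
    # The same loop serves both algorithms; only `policy` changes.
    collections = []
    collection_of_values = {"observation": 0}
    for _ in range(max_steps):
        # Copy the dict so that each stored step is a separate object.
        collection_of_values = policy(dict(collection_of_values))
        collection_of_values = env_step(collection_of_values)
        collections.append(collection_of_values)
    return collections

dqn_data = collect(dqn_policy)
ppo_data = collect(ppo_policy)
```

Each algorithm writes its own extra keys (`values` for the DQN-like policy, `action_log_prob` for the PPO-like one) while the collection loop stays untouched.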

## TensorDict Python dictionary behaviour

TensorDict shares many features with Python dictionaries.

In [1]:
from torchrl.data import TensorDict
import torch

In [2]:
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
print(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)

Note that `b` was created with shape `[3, 4]` but is reported with a trailing singleton dimension, `[3, 4, 1]`, in the output above.


To access the value stored under a given key, index the TensorDict with that key:

In [3]:
print(tensordict["a"])
print(tensordict["a"].shape)

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])
torch.Size([3, 4, 5])


#### TensorDict.keys()
We can iterate over the keys, as with a dict:

In [4]:
for key in tensordict.keys():
    print(key)

a
b


#### TensorDict.values()
We can also retrieve the values. Unlike a Python dict, `TensorDict.values()` returns a generator rather than a view or a list, for memory-efficiency reasons: Python dictionaries were not designed with tensor storage in mind.

In [5]:
tensordict.values()

<generator object _TensorDict.values at 0x10e85f3c0>

In [6]:
for value in tensordict.values():
    print(value.shape)

torch.Size([3, 4, 5])
torch.Size([3, 4, 1])
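
As a rough illustration of the generator behaviour, here is a minimal pure-Python sketch. The class `LazyValues` is a hypothetical stand-in, not torchrl's implementation; it only shows what it means for `values()` to yield items one at a time.

```python
import types

class LazyValues:
    """Hypothetical container whose values() is a generator."""
    def __init__(self, items):
        self._items = dict(items)

    def values(self):
        # Yield one value at a time instead of building a list of all values.
        for key in self._items:
            yield self._items[key]

container = LazyValues({"a": [0.0] * 5, "b": [0.0]})
values = container.values()
is_generator = isinstance(values, types.GeneratorType)
lengths = [len(value) for value in values]  # consumes the generator
leftover = list(values)                     # a generator can be iterated only once
```

One consequence worth keeping in mind: once consumed, a generator is exhausted, so call `values()` again if a second pass is needed.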


#### TensorDict.update()
We can also use the `update` method, as with dicts:

In [7]:
tensordict.update({"a":torch.ones((3, 4, 5))})
tensordict.update({"c":torch.ones((3, 4, 2))})
print(f"a is now {tensordict['a']}")
print(f"c is set as {tensordict['c']}")

a is now tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])
c is set as tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]])


#### TensorDict del key
TensorDict also supports key deletion:

In [8]:
del tensordict["c"]
print(tensordict.keys())

dict_keys(['a', 'b'])


## TensorDict as a PyTorch tensor

But wait, can't we do all of this with a regular dict?
Well, we would like TensorDict to keep some nice PyTorch properties. TensorDict combines the advantages of a Python dictionary with those of a PyTorch tensor.
TensorDict has a batch size. It is not inferred automatically from the tensors; it must be set when creating the TensorDict.

TensorDict is a tensor container where all tensors are stored as key-value pairs and where every entry shares at least the following features:
- device;
- memory location (shared, memory-mapped array, ...);
- batch size (i.e. the first n dimensions).

In [9]:
from torchrl.data.tensordict.tensordict import TensorDict
import torch

In [10]:
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
print(tensordict)
print(f"Our Tensor dict is of size {tensordict.shape}")

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)
Our Tensor dict is of size torch.Size([3, 4])


#### Batch size

TensorDict has a batch size, which is shared across all of its tensors:

In [11]:
print(f"Our Tensor dict is of size {tensordict.shape}")

Our Tensor dict is of size torch.Size([3, 4])


Items whose leading dimensions do not match the batch size cannot be stored in the TensorDict:

In [12]:
tensordict.update({"c":torch.zeros(4,3,1)})

RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])
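
The rule behind this error can be sketched in a few lines of plain Python: the first `len(batch_size)` dimensions of every entry must equal the batch size. The helper `check_batch_size` below is a hypothetical illustration that works on shape tuples; it is not the check torchrl actually runs.

```python
def check_batch_size(tensor_shape, batch_size):
    """Raise if the leading dims of tensor_shape differ from batch_size."""
    batch_dims = len(batch_size)
    leading = tuple(tensor_shape[:batch_dims])
    if leading != tuple(batch_size):
        raise RuntimeError(
            f"batch dimension mismatch, got batch_size={tuple(batch_size)} "
            f"and leading shape={leading}"
        )

# Shape (3, 4, 5) starts with (3, 4): accepted.
check_batch_size((3, 4, 5), (3, 4))

# Shape (4, 3, 1) starts with (4, 3): rejected.
try:
    check_batch_size((4, 3, 1), (3, 4))
    mismatch = None
except RuntimeError as err:
    mismatch = str(err)
```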

### Tensor operations
We can perform tensor operations along the batch dimensions.

#### Cloning
TensorDict supports cloning

In [13]:
tensordict_clone = tensordict.clone()
tensordict_clone["a"] = torch.ones(*tensordict.shape,5)
print(tensordict["a"])
print(tensordict_clone["a"])

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])
tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])


#### Slicing and indexing
Slicing and indexing are supported along the batch dimensions:

In [14]:
tensordict[0]

TensorDict(
    fields={
        a: Tensor(torch.Size([4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)

In [15]:
tensordict[1:]

TensorDict(
    fields={
        a: Tensor(torch.Size([2, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([2, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([2, 4]),
    device=cpu,
    is_shared=False)

In [16]:
tensordict[:,2:]

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 2, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 2, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 2]),
    device=cpu,
    is_shared=False)

TensorDict supports other tensor operations such as torch.cat, reshape, unbind(dim), view(\*shape), squeeze(dim), unsqueeze(dim) and permute(\*dims), provided the operation is compatible with the batch_size:

In [17]:
tensordict.reshape(-1)

TensorDict(
    fields={
        a: Tensor(torch.Size([12, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([12, 1]), dtype=torch.float32)},
    batch_size=torch.Size([12]),
    device=cpu,
    is_shared=False)
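
The shapes in the output above follow a simple pattern: only the batch dimensions are reshaped, and the trailing feature dimensions of each entry are kept. Here is a minimal sketch of that bookkeeping on shape tuples. The helper `reshaped_entry_shape` is hypothetical and only mimics the shape arithmetic; it does not touch any tensor.

```python
from math import prod

def reshaped_entry_shape(entry_shape, batch_size, new_batch_size):
    """Shape of an entry after reshaping only the batch dimensions."""
    batch_dims = len(batch_size)
    feature_dims = tuple(entry_shape[batch_dims:])
    numel = prod(batch_size)
    new_batch = list(new_batch_size)
    if -1 in new_batch:
        # Infer the missing dimension from the number of batch elements.
        known = prod(d for d in new_batch if d != -1)
        new_batch[new_batch.index(-1)] = numel // known
    if prod(new_batch) != numel:
        raise ValueError("new batch size is incompatible with the batch size")
    return tuple(new_batch) + feature_dims

shape_a = reshaped_entry_shape((3, 4, 5), (3, 4), (-1,))
shape_b = reshaped_entry_shape((3, 4, 1), (3, 4), (-1,))
```

With the batch size `[3, 4]` used in this tutorial, `shape_a` is `(12, 5)` and `shape_b` is `(12, 1)`, matching the printed TensorDict.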

#### Casting to device
TensorDict supports casting to a device with the .to(device) method, as with regular tensors.

## How to use them in practice? The TensorDictModule

Now that we have seen the TensorDict object, how do we use it in practice? We introduce the TensorDictModule. The TensorDictModule is an nn.Module that takes a TensorDict as input to its forward method. The user defines the keys that the module reads as input, and the module writes its output to the same TensorDict under a given set of keys.

In [18]:
from torchrl.modules import TensorDictModule
import torch.nn as nn

### Example: Simple Linear layer

Let's imagine we have a TensorDict with 2 entries, a and b, and we only want to transform a.

In [19]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3)}, batch_size=[5])
linear = TensorDictModule(nn.Linear(3,10),in_keys=["a"], out_keys=["a_out"])
linear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

We can also do it in place, by writing the output to the input key:

In [20]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3)}, batch_size=[5])
linear = TensorDictModule(nn.Linear(3,10),in_keys=["a"], out_keys=["a"])
linear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

### Example: merging 2 inputs with 2 linear layers

Now let's imagine a more complex network that takes 2 entries and averages them into a single output:

In [21]:
class MergeLinear(nn.Module):
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1  = nn.Linear(in_1,out)
        self.linear_2  = nn.Linear(in_2,out)
    def forward(self, x_1, x_2):
        return (self.linear_1(x_1) + self.linear_2(x_2))/2

In [22]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3), "c":torch.randn(5,4)}, batch_size=[5])
mergelinear = TensorDictModule(MergeLinear(3, 4, 10),in_keys=["a","c"], out_keys=["output"])
mergelinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        c: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

### Example: 1 input to 2 outputs linear layer
We can also map one input to multiple outputs:

In [23]:
class MultiHeadLinear(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1  = nn.Linear(in_1,out_1)
        self.linear_2  = nn.Linear(in_1,out_2)
    def forward(self, x):
        return self.linear_1(x), self.linear_2(x)

In [24]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3)}, batch_size=[5])
multiheadlinear = TensorDictModule(MultiHeadLinear(3, 4, 10),in_keys=["a"], out_keys=["output_1", "output_2"])
multiheadlinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

As shown above, the TensorDictModule can wrap any nn.Module and perform its operations on a TensorDict. When there are multiple input or output keys, make sure their order matches the order of the module's arguments and return values.
The TensorDictModule lets us use only the tensors we want and keeps the output inside the same object. It can even perform the operation in place, by setting the output key to an already existing key.
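
To see why key order matters, here is a minimal pure-Python sketch of the read-call-write pattern. The class `KeyedModule` is a hypothetical stand-in operating on a plain dict; it is not torchrl's TensorDictModule implementation, only an illustration of the idea.

```python
class KeyedModule:
    """Hypothetical key-based wrapper around a function (illustration only)."""
    def __init__(self, fn, in_keys, out_keys):
        self.fn = fn
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def __call__(self, data):
        # Inputs are passed positionally, in the order given by in_keys.
        outputs = self.fn(*(data[key] for key in self.in_keys))
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        if len(outputs) != len(self.out_keys):
            raise ValueError("number of outputs does not match out_keys")
        # Outputs are written back in the order given by out_keys.
        for key, value in zip(self.out_keys, outputs):
            data[key] = value
        return data

def sub_and_add(x, y):
    return x - y, x + y

ordered = KeyedModule(sub_and_add, in_keys=["a", "c"], out_keys=["diff", "total"])
result = ordered({"a": 5, "c": 2})

# Swapping the input keys changes which value is bound to x and which to y.
swapped = KeyedModule(sub_and_add, in_keys=["c", "a"], out_keys=["diff", "total"])
result_swapped = swapped({"a": 5, "c": 2})
```

Here `result["diff"]` is `3` while `result_swapped["diff"]` is `-3`: the wrapper has no way to know which key should feed which argument other than the order it is given.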

### Example: A transformer with TensorDict?
Let's attempt to create a transformer with TensorDict and TensorDictModule.

Disclaimer: This implementation is not meant to be "better" than a tensor-based implementation. It is only meant to showcase the TensorDictModule features.

Let's first implement the classical transformer blocks:

In [25]:
class TokensToQKV(nn.Module):
    def __init__(self, to_dim, from_dim, latent_dim):
        super().__init__()
        self.q = nn.Linear(to_dim, latent_dim)
        self.k = nn.Linear(from_dim, latent_dim)
        self.v = nn.Linear(from_dim, latent_dim)
    def forward(self, X_to, X_from):
        Q = self.q(X_to)
        K = self.k(X_from)
        V = self.v(X_from)
        return Q, K, V

class SplitHeads(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
    def forward(self, Q, K, V):
        batch_size, to_num, latent_dim = Q.shape
        _, from_num, _ = K.shape
        d_tensor = latent_dim // self.num_heads
        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        return Q, K, V

class Attention(nn.Module):
    def __init__(self, latent_dim, to_dim):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.out = nn.Linear(latent_dim, to_dim)
    def forward(self, Q, K, V):
        batch_size, n_heads, to_num, d_in = Q.shape
        # Scale the scores by sqrt(d_in), as in standard scaled dot-product attention.
        attn = self.softmax(Q @ K.transpose(2, 3) / d_in ** 0.5)
        out = attn @ V
        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads*d_in))
        return out, attn

class SkipLayerNorm(nn.Module):
    def __init__(self, to_len, to_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm((to_len, to_dim))
    def forward(self, x_0, x_1):
        return self.layer_norm(x_0+x_1)

class FFN(nn.Module):
    def __init__(self, to_dim, hidden_dim, dropout_rate = 0.2):
        super().__init__()
        self.FFN = nn.Sequential(
            nn.Linear(to_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, to_dim),
            nn.Dropout(dropout_rate)
        )
    def forward(self, X):
        return self.FFN(X)


In [26]:
class TransformerBlockTensorDict(nn.Module):
    def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        # Use to_name (rather than a hardcoded "X_to") so the block works with any key name.
        self.transformer_block = nn.Sequential(
            TensorDictModule(TokensToQKV(to_dim, from_dim, latent_dim), in_keys=[to_name, from_name], out_keys=["Q", "K", "V"]),
            TensorDictModule(SplitHeads(num_heads), in_keys=["Q", "K", "V"], out_keys=["Q", "K", "V"]),
            TensorDictModule(Attention(latent_dim, to_dim), in_keys=["Q", "K", "V"], out_keys=["X_out","Attn"]),
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
            TensorDictModule(FFN(to_dim, 4*to_dim), in_keys=[to_name], out_keys=["X_out"]),
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
        )
    def forward(self, X_tensor_dict):
        # The TensorDict is modified in place, so nothing is returned.
        self.transformer_block(X_tensor_dict)

In [27]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict({"X_to":torch.randn(batch_size, to_len, to_dim), "X_from":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])

transformer_block = TransformerBlockTensorDict("X_to", "X_from", to_dim, to_len, from_dim, latent_dim, num_heads)

transformer_block(tokens)

tokens

TensorDict(
    fields={
        Attn: Tensor(torch.Size([8, 2, 3, 10]), dtype=torch.float32),
        K: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        Q: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),
        V: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),
        X_out: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),
        X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},
    batch_size=torch.Size([8]),
    device=cpu,
    is_shared=False)
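
The shapes printed above can be derived from the hyperparameters alone. The following sketch does that bookkeeping in plain Python; `expected_shapes` is a hypothetical helper written for this tutorial's block, assuming `latent_dim` is divisible by `num_heads`.

```python
def expected_shapes(batch_size, to_len, from_len, to_dim, latent_dim, num_heads):
    """Shapes of the entries written by one transformer block."""
    head_dim = latent_dim // num_heads  # latent_dim is split across the heads
    return {
        "Q": (batch_size, num_heads, to_len, head_dim),
        "K": (batch_size, num_heads, from_len, head_dim),
        "V": (batch_size, num_heads, from_len, head_dim),
        # One attention weight per (query position, key position) pair and per head.
        "Attn": (batch_size, num_heads, to_len, from_len),
        # The output projection maps back to the dimension of the "to" tokens.
        "X_out": (batch_size, to_len, to_dim),
    }

shapes = expected_shapes(
    batch_size=8, to_len=3, from_len=10, to_dim=5, latent_dim=10, num_heads=2
)
```

Comparing `shapes` with the printed TensorDict is a quick way to check that the keys were wired in the intended order.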

The output of the transformer layer can now be found at tokens["X_to"]:

In [28]:
tokens["X_to"]

tensor([[[-2.2140e-01,  9.3885e-02,  1.5475e+00, -8.5764e-01, -1.0164e+00],
         [ 6.1844e-01,  4.2928e-01,  1.2559e+00, -1.7318e-02, -1.0785e+00],
         [ 3.6746e-01, -2.0010e+00,  5.7185e-01,  1.3700e+00, -1.0621e+00]],

        [[ 1.8018e+00,  1.7676e-01,  1.3061e+00,  5.5536e-04,  9.5903e-01],
         [-1.7914e-02, -1.1058e+00, -1.0330e+00, -5.6826e-01,  6.1741e-01],
         [ 1.0086e+00, -2.1221e+00,  5.2469e-02, -5.4785e-01, -5.2776e-01]],

        [[-5.7215e-03, -2.1323e+00, -6.8256e-01, -5.2980e-02,  2.1093e+00],
         [-3.1369e-01, -1.0205e+00,  1.2342e+00, -4.3158e-01,  1.8713e-01],
         [ 1.2250e+00, -3.5429e-01,  3.3673e-01, -7.2250e-01,  6.2381e-01]],

        [[ 1.1720e+00, -1.2023e+00, -4.1971e-01, -3.9235e-01, -5.0843e-01],
         [ 1.4071e+00, -7.0667e-01, -9.0403e-01,  9.7855e-01,  1.9402e+00],
         [ 5.9586e-01, -2.0236e-01, -5.3443e-01, -1.6470e+00,  4.2355e-01]],

        [[ 2.2837e-01, -2.2963e-02,  2.4110e-01, -9.2972e-01,  5.7418e-01],
    

We can now easily create a transformer by stacking blocks:

In [29]:
class TransformerTensorDict(nn.Module):
    def __init__(self, num_blocks, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.transformer = nn.ModuleList([TransformerBlockTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])
    def forward(self, X_tensor_dict):
        for transformer_block in self.transformer:
            transformer_block(X_tensor_dict)

In [30]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict({"X_to":torch.randn(batch_size, to_len, to_dim), "X_from":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])




For an encoder, the same key is used for both the "to" and "from" tokens:

In [31]:
transformer_encoder = TransformerTensorDict(6, "X_to", "X_to", to_dim, to_len, to_dim, latent_dim, num_heads)

transformer_encoder(tokens)
tokens["X_to"]

tensor([[[ 0.5433, -2.1209,  0.6917,  0.1949, -0.2383],
         [-0.2084,  1.1901,  0.1548, -1.2804,  0.3292],
         [ 1.8350, -1.2308,  0.2563, -0.9090,  0.7924]],

        [[ 0.7642, -0.8773,  0.6102, -1.9610, -0.2547],
         [ 0.8582, -0.8673,  1.6634, -0.5311, -1.3557],
         [ 0.5985, -0.6759,  0.8806, -0.0338,  1.1816]],

        [[ 0.1173, -1.0630, -1.1206, -0.9292,  0.6910],
         [-0.5904,  0.2967,  1.0146, -0.1985, -0.1382],
         [ 2.3575, -0.4132,  1.5436, -1.2285, -0.3390]],

        [[-0.0839,  0.3025,  1.2402, -1.1337, -0.3312],
         [ 0.1848, -0.5644, -0.9238, -1.2078,  0.7454],
         [ 0.9539, -0.4605,  2.5865, -0.8541, -0.4540]],

        [[ 2.2012,  0.3198, -0.1624, -1.0218, -0.2963],
         [-1.1105,  0.0330, -0.6841,  0.2308,  0.1053],
         [ 2.2178, -1.1405,  0.0187, -0.9278,  0.2166]],

        [[ 0.1762,  0.1206, -0.9111, -0.8480,  0.8383],
         [-0.0356,  1.9972,  1.2416, -1.9975,  0.2798],
         [-1.5899,  0.6967,  0.2188, -

For a decoder, the "to" and "from" tokens come from different keys:

In [32]:
transformer_decoder = TransformerTensorDict(6, "X_to", "X_from", to_dim, to_len, from_dim, latent_dim, num_heads)

transformer_decoder(tokens)
tokens["X_to"]

tensor([[[ 0.9111, -0.9565,  0.8885, -0.1153, -1.1195],
         [-1.0822,  1.9428, -0.1168, -1.2518, -0.4014],
         [ 1.6768, -1.1258,  0.4677, -0.1539,  0.4362]],

        [[ 0.2126, -0.9831,  0.8968, -1.4953, -1.0055],
         [ 0.8913,  0.0649,  0.6585, -0.7806, -2.0843],
         [ 1.5738,  0.4048,  0.6604,  0.7997,  0.1859]],

        [[-0.1787, -0.4171, -1.1724, -0.7935,  0.4380],
         [-0.3860,  1.7195,  0.8667, -1.2160, -0.2223],
         [ 1.6916,  0.1081,  1.4817, -1.3484, -0.5712]],

        [[ 0.3393,  0.4130,  0.7333, -1.0726, -0.8733],
         [ 0.9412, -0.4980, -0.0214, -0.0619,  0.4108],
         [ 1.5666, -0.0143,  1.5400, -1.2883, -2.1144]],

        [[ 2.3167,  0.4312, -0.4086,  0.1072, -1.3709],
         [-0.6496,  0.0503, -0.5773,  0.4249, -0.6153],
         [ 2.2431, -0.4351, -0.3872, -0.3861, -0.7434]],

        [[ 0.7115,  0.0280, -0.2837, -0.1418,  0.2284],
         [-0.3786,  1.7935,  1.5064, -1.9357, -0.8157],
         [-0.8887,  1.5851, -0.0724, -

In [33]:
transformer_encoder

TransformerTensorDict(
  (transformer): ModuleList(
    (0): TransformerBlockTensorDict(
      (transformer_block): Sequential(
        (0): TensorDictModule(
            module=TokensToQKV(
              (q): Linear(in_features=5, out_features=10, bias=True)
              (k): Linear(in_features=5, out_features=10, bias=True)
              (v): Linear(in_features=5, out_features=10, bias=True)
            ), 
            device=cpu, 
            in_keys=['X_to', 'X_to'], 
            out_keys=['Q', 'K', 'V'])
        (1): TensorDictModule(
            module=SplitHeads(), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['Q', 'K', 'V'])
        (2): TensorDictModule(
            module=Attention(
              (softmax): Softmax(dim=-1)
              (out): Linear(in_features=10, out_features=5, bias=True)
            ), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['X_out', 'Attn'])
        (3): TensorDictMod

In [34]:
transformer_decoder

TransformerTensorDict(
  (transformer): ModuleList(
    (0): TransformerBlockTensorDict(
      (transformer_block): Sequential(
        (0): TensorDictModule(
            module=TokensToQKV(
              (q): Linear(in_features=5, out_features=10, bias=True)
              (k): Linear(in_features=6, out_features=10, bias=True)
              (v): Linear(in_features=6, out_features=10, bias=True)
            ), 
            device=cpu, 
            in_keys=['X_to', 'X_from'], 
            out_keys=['Q', 'K', 'V'])
        (1): TensorDictModule(
            module=SplitHeads(), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['Q', 'K', 'V'])
        (2): TensorDictModule(
            module=Attention(
              (softmax): Softmax(dim=-1)
              (out): Linear(in_features=10, out_features=5, bias=True)
            ), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['X_out', 'Attn'])
        (3): TensorDictM