# TensorDict tutorial

TensorDict is a new tensor structure introduced in torchrl. In RL, you need to deal with multiple tensors at once, such as actions, observations and rewards. TensorDict aims to make it more convenient to handle these tensors together. Furthermore, different RL algorithms deal with different inputs and outputs, and TensorDict makes it possible to abstract away the differences between these algorithms.

### Motivation
As a concrete example, let us take DQN and PPO. The first uses a deterministic policy that applies an argmax operator to a collection of values associated with each action for a given observation. The second has a parametric policy that outputs a distribution over the space of available actions. Here is the pseudocode for each:

<code>
# DQN
data = []
for i in range(max_steps):
    action, values = value_network(observation)  # action = values.argmax(-1)
    observation, reward, done, *other = env.step(action)
    data.append((action, values, observation, reward, done))
</code>
<code>
# PPO
data = []
for i in range(max_steps):
    action, action_log_prob = policy(observation)
    observation, reward, done, *other = env.step(action)
    data.append((action, action_log_prob, observation, reward, done))

</code>

Ideally we would like to abstract this away into the same code:

<code>
collections = []
for i in range(max_steps):
    collection_of_values = policy(collection_of_values)
    collection_of_values = env_step(collection_of_values)
    collections.append(collection_of_values)
</code>
The differences between the algorithms now lie in the `policy`, the `env_step` and the initial `collection_of_values`, while the main loop is the same for both algorithms. This abstraction allows for more modular and reusable code.
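To make the idea concrete, here is a minimal runnable sketch of that shared loop. It is only an illustration under simplifying assumptions: plain Python dicts stand in for TensorDict, and `dqn_policy`, `ppo_policy`, `env_step` and `rollout` are hypothetical toy functions, not torchrl APIs. Each policy writes different keys into the same container, yet `rollout` never needs to know which algorithm it is running.

```python
import math
import random


def dqn_policy(data):
    # Deterministic policy: compute one value per action and take the argmax.
    values = [data["observation"] * weight for weight in (0.1, 0.5, 0.2)]
    data["values"] = values
    data["action"] = max(range(len(values)), key=values.__getitem__)
    return data


def ppo_policy(data):
    # Stochastic policy: sample an action and record its log-probability.
    probs = [0.2, 0.5, 0.3]
    action = random.choices(range(len(probs)), weights=probs)[0]
    data["action"] = action
    data["action_log_prob"] = math.log(probs[action])
    return data


def env_step(data):
    # Toy environment: the next observation and the reward depend on the action.
    data["next_observation"] = data["observation"] + data["action"]
    data["reward"] = float(data["action"])
    data["done"] = data["next_observation"] > 10
    return data


def rollout(policy, max_steps=5):
    # The loop is identical for both algorithms; only `policy` changes.
    collections = []
    observation = 1.0
    for _ in range(max_steps):
        collection_of_values = {"observation": observation}
        collection_of_values = policy(collection_of_values)
        collection_of_values = env_step(collection_of_values)
        collections.append(collection_of_values)
        observation = collection_of_values["next_observation"]
    return collections


dqn_data = rollout(dqn_policy)
ppo_data = rollout(ppo_policy)
print(sorted(dqn_data[0]))
print(sorted(ppo_data[0]))
```

The DQN steps carry a `values` entry and the PPO steps an `action_log_prob` entry; everything downstream of the loop only has to look up the keys it cares about.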

## TensorDict: Python dictionary behaviour

TensorDict shares a lot of features with Python dictionaries.

In [1]:
from torchrl.data import TensorDict
import torch

In [2]:
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
print(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)


Accessing a given key works as with a regular dictionary:

In [3]:
print(tensordict["a"])
print(tensordict["a"].shape)

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])
torch.Size([3, 4, 5])


The `get()` method works as well:

In [4]:
print(tensordict.get("a"))
print(tensordict.get("a").shape)

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])
torch.Size([3, 4, 5])


#### TensorDict.keys()
The keys of a TensorDict can be retrieved with `keys()`:

In [5]:
for key in tensordict.keys():
    print(key)

a
b


#### TensorDict.values()
The values of a TensorDict can be retrieved with the `values()` method. Unlike a Python dict, whose `values()` returns a view object, `TensorDict.values()` returns a generator that yields the tensors one at a time instead of building a list.

In [6]:
tensordict.values()

<generator object _TensorDict.values at 0x11711d580>

In [7]:
for value in tensordict.values():
    print(value.shape)

torch.Size([3, 4, 5])
torch.Size([3, 4, 1])
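A practical consequence is that the object returned by `tensordict.values()` can only be iterated once; call `values()` again to iterate a second time. The plain-Python sketch below contrasts a dict view with a generator; the generator expression merely stands in for `TensorDict.values()` and is not torchrl code.

```python
d = {"a": 1, "b": 2}

# A dict's values() is a live view: no copy is made, and it reflects later changes.
view = d.values()
d["c"] = 3
print(type(view).__name__)  # dict_values
print(list(view))           # [1, 2, 3]
print(list(view))           # a view can be iterated again: [1, 2, 3]

# A generator, by contrast, yields each item once and is then exhausted.
gen = (value for value in d.values())
print(list(gen))            # [1, 2, 3]
print(list(gen))            # []
```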


#### TensorDict.set()
The `set()` method adds a new entry to the TensorDict:

In [8]:
tensordict.set("c", torch.zeros((3, 4, 2, 2)))
# Equivalent to: tensordict["c"] = torch.zeros((3, 4, 2, 2))
print(f"c is set as {tensordict['c']}")

c is set as tensor([[[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]]])


#### TensorDict.update()
The `update()` method updates the TensorDict with the entries of another dict (or TensorDict): existing keys are overwritten and new keys are added.

In [9]:
tensordict.update({"a":torch.ones((3, 4, 5)), "d":torch.ones((3, 4, 2))})
# Also works with tensordict.update(TensorDict({"a": torch.ones((3, 4, 5)), "d": torch.ones((3, 4, 2))}, batch_size=[3, 4]))
print(f"a is now {tensordict['a']}")
print(f"d is set as {tensordict['d']}")

a is now tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])
d is set as tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]])


#### TensorDict del key
TensorDict also supports key deletion with the `del` operator:

In [10]:
del tensordict["c"]
print(tensordict.keys())

dict_keys(['a', 'b', 'd'])


## TensorDict as a PyTorch tensor

But wait, can't we do this with a classical dict?
Well, we would like the TensorDict to keep some nice PyTorch properties. TensorDict combines the advantages of a Python dictionary and of a PyTorch tensor.
TensorDict has a batch size. It is not inferred automatically from the tensors, but must be set when creating the TensorDict.

TensorDict is a tensor container in which all tensors are stored as key-value pairs and share at least the following features:
- device;
- memory location (shared, memory-mapped array, ...);
- batch size (i.e. the first n dimensions of every tensor).

In [11]:
from torchrl.data.tensordict.tensordict import TensorDict
import torch

In [12]:
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
print(tensordict)
print(f"Our Tensor dict is of size {tensordict.shape}")

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)
Our Tensor dict is of size torch.Size([3, 4])


#### Batch size

A TensorDict has a batch size that is shared by all of its tensors:

In [13]:
print(f"Our Tensor dict is of size {tensordict.shape}")

Our Tensor dict is of size torch.Size([3, 4])


You cannot have items that don't share the batch size inside the same TensorDict:

In [14]:
tensordict.update({"c":torch.zeros(4,3,1)})

RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])
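The rule behind this error can be sketched as a check on the leading dimensions of every tensor. `check_batch_size` below is a hypothetical helper written for illustration, not torchrl's implementation; the error message torchrl actually raises is the one shown above.

```python
import torch


def check_batch_size(tensors, batch_size):
    """Hypothetical helper: every tensor must start with the batch dimensions."""
    batch_size = torch.Size(batch_size)
    batch_dims = len(batch_size)
    for key, tensor in tensors.items():
        # Only the leading batch_dims dimensions are compared; trailing ones are free.
        if tensor.shape[:batch_dims] != batch_size:
            raise RuntimeError(
                f"batch dimension mismatch for {key!r}: expected {batch_size}, "
                f"got {tensor.shape[:batch_dims]}"
            )


ok = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}
check_batch_size(ok, [3, 4])  # passes: both tensors start with (3, 4)

bad = {"c": torch.zeros(4, 3, 1)}
try:
    check_batch_size(bad, [3, 4])
except RuntimeError as err:
    print(err)
```

Note that `(4, 3)` is rejected even though it has the same number of elements as `(3, 4)`: the batch dimensions must match exactly, in order.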

#### Cloning
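TensorDict supports cloning. The clone is independent of the original, so modifying it leaves the original TensorDict unchanged: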

In [15]:
tensordict_clone = tensordict.clone()
tensordict_clone["a"] = torch.ones(*tensordict.shape,5)
print(tensordict["a"])
print(tensordict_clone["a"])

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])
tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])
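The distinction matters most for in-place operations. The sketch below uses plain dicts to show why a per-tensor copy is needed: a shallow `dict(...)` copy shares the underlying tensors, while cloning each tensor gives independent storage. It assumes, as the example above suggests, that `TensorDict.clone()` behaves like the second case.

```python
import torch

original = {"a": torch.zeros(3, 4, 5)}

# A shallow copy of a plain dict shares the underlying tensors...
shallow = dict(original)
shallow["a"].add_(1)                 # in-place update
print(original["a"].sum().item())    # 60.0: the original changed too

# ...whereas cloning every tensor gives independent storage.
original = {"a": torch.zeros(3, 4, 5)}
deep = {key: value.clone() for key, value in original.items()}
deep["a"].add_(1)
print(original["a"].sum().item())    # 0.0: the original is untouched
```

Replacing an entry outright (as in `tensordict_clone["a"] = ...` above) would not affect the original even with a shallow copy; it is in-place updates such as `add_` that require real copies.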


### Tensor operations
We can perform tensor operations along the batch dimensions:

#### Slicing and indexing
Slicing and indexing are supported along the batch dimensions:

In [16]:
tensordict[0]

TensorDict(
    fields={
        a: Tensor(torch.Size([4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)

In [17]:
tensordict[1:]

TensorDict(
    fields={
        a: Tensor(torch.Size([2, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([2, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([2, 4]),
    device=cpu,
    is_shared=False)

In [18]:
tensordict[:,2:]

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 2, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 2, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 2]),
    device=cpu,
    is_shared=False)
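Conceptually, indexing a TensorDict applies the same index to the leading (batch) dimensions of every entry and leaves the trailing feature dimensions alone. Here is a plain-dict sketch of that behaviour, using a hypothetical helper `index_batch`; the shapes match the outputs above.

```python
import torch

tensors = {"a": torch.randn(3, 4, 5), "b": torch.randn(3, 4, 1)}


def index_batch(tensors, index):
    """Hypothetical helper: apply the same batch index to every entry."""
    return {key: value[index] for key, value in tensors.items()}


first = index_batch(tensors, 0)                             # like tensordict[0]
tail = index_batch(tensors, slice(1, None))                 # like tensordict[1:]
cols = index_batch(tensors, (slice(None), slice(2, None)))  # like tensordict[:, 2:]

print({key: tuple(value.shape) for key, value in cols.items()})  # {'a': (3, 2, 5), 'b': (3, 2, 1)}
```

This only works because the batch dimensions always come first, which is exactly what the shared batch size guarantees.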

TensorDict supports other tensor operations such as `torch.cat`, `reshape`, `unbind(dim)`, `view(*shape)`, `squeeze(dim)`, `unsqueeze(dim)` and `permute(*dims)`, as long as the operations are compatible with the batch size:

In [19]:
# Reshape
tensordict.reshape(-1)

TensorDict(
    fields={
        a: Tensor(torch.Size([12, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([12, 1]), dtype=torch.float32)},
    batch_size=torch.Size([12]),
    device=cpu,
    is_shared=False)

In [20]:
# Concatenate along the first batch dimension
torch.cat([tensordict, tensordict.clone()], dim=0)

TensorDict(
    fields={
        a: Tensor(torch.Size([6, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([6, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([6, 4]),
    device=cpu,
    is_shared=False)
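Both operations can be understood entry by entry: a reshape changes only the batch dimensions and keeps the trailing feature dimensions of each tensor, and a concatenation along a batch dimension concatenates the matching entries. The helpers `reshape_batch` and `cat_batch` below are hypothetical, plain-dict illustrations of this, not torchrl functions.

```python
import torch

batch_size = (3, 4)
batch_dims = len(batch_size)
tensors = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}


def reshape_batch(tensors, batch_dims, new_batch_size):
    """Hypothetical helper: reshape only the batch dimensions of each entry."""
    return {
        key: value.reshape(*new_batch_size, *value.shape[batch_dims:])
        for key, value in tensors.items()
    }


def cat_batch(list_of_tensors, dim):
    """Hypothetical helper: concatenate entry by entry; dim must be a batch dimension."""
    keys = list_of_tensors[0].keys()
    return {key: torch.cat([t[key] for t in list_of_tensors], dim=dim) for key in keys}


flat = reshape_batch(tensors, batch_dims, (-1,))  # like tensordict.reshape(-1)
stacked = cat_batch([tensors, tensors], dim=0)    # like torch.cat([td, td.clone()], dim=0)
print({key: tuple(value.shape) for key, value in flat.items()})     # {'a': (12, 5), 'b': (12, 1)}
print({key: tuple(value.shape) for key, value in stacked.items()})  # {'a': (6, 4, 5), 'b': (6, 4, 1)}
```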

#### Casting to device
TensorDict supports casting to a device with the `.to(device)` method, as with regular tensors.
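Casting a collection amounts to moving every entry to the same device; `.to(device)` does this for the whole TensorDict in one call. The sketch below shows the per-entry equivalent with a plain dict, picking a GPU only when one is available.

```python
import torch

tensors = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}

# Use a GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move every entry so that all tensors share the same device.
moved = {key: value.to(device) for key, value in tensors.items()}
print({key: str(value.device) for key, value in moved.items()})
```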

## How to use them in practice? The TensorDictModule

Now that we have seen the TensorDict object, how do we use it in practice? We introduce the TensorDictModule. The TensorDictModule is an nn.Module that takes a TensorDict as input to its forward method. The user defines the keys that the module reads as inputs and the keys under which it writes its outputs in the same TensorDict.

In [21]:
from torchrl.modules import TensorDictModule
import torch.nn as nn
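To see the mechanism, here is a simplified sketch of what such a wrapper does: read `in_keys` in order, call the wrapped module with them as positional arguments, and write each output back under the matching `out_key`. `KeyedModule` is a hypothetical class working on plain dicts; it is not torchrl's TensorDictModule and leaves out whatever extra bookkeeping the real class performs.

```python
import torch
import torch.nn as nn


class KeyedModule(nn.Module):
    """Hypothetical, simplified stand-in for TensorDictModule that works on plain dicts."""

    def __init__(self, module, in_keys, out_keys):
        super().__init__()
        self.module = module
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def forward(self, data):
        # Read the inputs in the order given by in_keys and pass them positionally.
        outputs = self.module(*(data[key] for key in self.in_keys))
        # Treat a single tensor as a one-element tuple.
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        if len(outputs) != len(self.out_keys):
            raise ValueError(f"expected {len(self.out_keys)} outputs, got {len(outputs)}")
        # Write each output under the matching out_key, in order.
        for key, value in zip(self.out_keys, outputs):
            data[key] = value
        return data


data = {"a": torch.randn(5, 3), "b": torch.randn(5, 4, 3)}
linear = KeyedModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"])
linear(data)
print({key: tuple(value.shape) for key, value in data.items()})
```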

### Example: Simple Linear layer

Let's imagine we have a TensorDict with two entries, a and b, and we only want to apply the layer to a.

In [22]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3)}, batch_size=[5])
linear = TensorDictModule(nn.Linear(3,10),in_keys=["a"], out_keys=["a_out"])
linear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

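We can also do it in place by using the input key as the output key. Note that "a" then holds the 10-dimensional output, so applying the same module a second time would fail.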

In [23]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3)}, batch_size=[5])
linear = TensorDictModule(nn.Linear(3,10),in_keys=["a"], out_keys=["a"])
linear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

### Example: merging 2 inputs with 2 linear layers

Now let's imagine a more complex network that takes 2 entries, passes each through its own linear layer and averages the results into a single output.

In [24]:
class MergeLinear(nn.Module):
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1  = nn.Linear(in_1,out)
        self.linear_2  = nn.Linear(in_2,out)
    def forward(self, x_1, x_2):
        return (self.linear_1(x_1) + self.linear_2(x_2))/2

In [25]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3), "c":torch.randn(5,4)}, batch_size=[5])
mergelinear = TensorDictModule(MergeLinear(3, 4, 10),in_keys=["a","c"], out_keys=["output"])
mergelinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        c: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

### Example: 1 input mapped to 2 outputs
We can also map to multiple outputs:

In [26]:
class MultiHeadLinear(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1  = nn.Linear(in_1,out_1)
        self.linear_2  = nn.Linear(in_1,out_2)
    def forward(self, x):
        return self.linear_1(x), self.linear_2(x)

In [27]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3)}, batch_size=[5])
multiheadlinear = TensorDictModule(MultiHeadLinear(3, 4, 10), in_keys=["a"], out_keys=["output_1", "output_2"])
multiheadlinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

As shown above, the TensorDictModule can wrap any nn.Module and run it on the entries of a TensorDict. When using multiple input and output keys, make sure their order matches the order of the module's arguments and outputs.
The TensorDictModule lets us use only the tensors we want and keeps the outputs inside the same object. It can even operate in place by setting the output key to an already existing key.
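The ordering rule can be seen directly with the `MergeLinear` module from above (redefined here so the snippet runs on its own): the keys are passed positionally, so `in_keys=["a", "c"]` feeds `"a"` to `x_1` and `"c"` to `x_2`. The dict comprehension below only imitates how the keys are unpacked; it is not torchrl code.

```python
import torch
import torch.nn as nn


class MergeLinear(nn.Module):
    # Same module as above: x_1 must have in_1 features, x_2 must have in_2 features.
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out)
        self.linear_2 = nn.Linear(in_2, out)

    def forward(self, x_1, x_2):
        return (self.linear_1(x_1) + self.linear_2(x_2)) / 2


data = {"a": torch.randn(5, 3), "c": torch.randn(5, 4)}
module = MergeLinear(3, 4, 10)

# Keys in the right order: "a" (3 features) -> x_1, "c" (4 features) -> x_2.
out = module(*(data[key] for key in ["a", "c"]))
print(out.shape)  # torch.Size([5, 10])

# Swapped order: "c" (4 features) reaches linear_1, which expects 3.
swapped_failed = False
try:
    module(*(data[key] for key in ["c", "a"]))
except RuntimeError:
    swapped_failed = True
print(swapped_failed)  # True
```

Here the mistake is caught by a shape error. When two inputs happen to have the same shape, a swapped order raises nothing and silently computes something different, which is why the order deserves care.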

### Example: A transformer with TensorDict?
Let's attempt to create a transformer with TensorDict and TensorDictModule.

Disclaimer: this implementation does not claim to be "better" than a classical tensor-based implementation. It is only meant to showcase the TensorDictModule features.
For simplicity, we will not use positional encodings.

Let's first implement the classical transformer blocks:

In [28]:
class TokensToQKV(nn.Module):
    def __init__(self, to_dim, from_dim, latent_dim):
        super().__init__()
        self.q = nn.Linear(to_dim, latent_dim)
        self.k = nn.Linear(from_dim, latent_dim)
        self.v = nn.Linear(from_dim, latent_dim)
    def forward(self, X_to, X_from):
        Q = self.q(X_to)
        K = self.k(X_from)
        V = self.v(X_from)
        return Q, K, V

class SplitHeads(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
    def forward(self, Q, K, V):
        batch_size, to_num, latent_dim = Q.shape
        _, from_num, _ = K.shape
        d_tensor = latent_dim // self.num_heads
        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        return Q, K, V
class Attention(nn.Module):
    def __init__(self, latent_dim, to_dim):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.out = nn.Linear(latent_dim, to_dim)
    def forward(self, Q, K, V):
        batch_size, n_heads, to_num, d_in = Q.shape
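        # Scaled dot-product attention: scale the scores by sqrt(d_in), as in the original Transformer
        attn = self.softmax(Q @ K.transpose(2, 3) / d_in ** 0.5)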
        out = attn @ V
        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads*d_in))
        return out, attn
class SkipLayerNorm(nn.Module):
    def __init__(self, to_len, to_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm((to_len, to_dim))
    def forward(self, x_0, x_1):
        return self.layer_norm(x_0+x_1)
class FFN(nn.Module):
    def __init__(self, to_dim, hidden_dim, dropout_rate = 0.2):
        super().__init__()
        self.FFN = nn.Sequential(
            nn.Linear(to_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, to_dim),
            nn.Dropout(dropout_rate)
        )
    def forward(self, X):
        return self.FFN(X)
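For reference, the standard scaled dot-product attention of the original Transformer computes, for each head,

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d}}\right) V,
$$

where $d$ is the per-head dimension (`d_in` in `Attention.forward`) and the softmax is taken over the key positions. With $Q$ of shape $(\text{to\_len}, d)$ and $K, V$ of shape $(\text{from\_len}, d)$, the attention matrix has shape $(\text{to\_len}, \text{from\_len})$.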


Now we can build the TransformerBlock using TensorDictModule. Since every module reads from and writes to the TensorDict, we just need to map each output to the right key so that it is picked up by the next block.
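Because the blocks communicate only through keys, a useful sanity check is that every stage reads keys that already exist when it runs. `check_key_flow` is a hypothetical helper, not part of torchrl; the stage list mirrors the `in_keys`/`out_keys` of the `TransformerBlockTensorDict` defined in the next cell, with `to_name="X_to"` and `from_name="X_from"`.

```python
def check_key_flow(initial_keys, stages):
    """Hypothetical helper: verify that every stage only reads keys that already exist.

    `stages` is a list of (name, in_keys, out_keys) tuples, applied in order.
    Returns the set of keys present after the last stage.
    """
    available = set(initial_keys)
    for name, in_keys, out_keys in stages:
        missing = [key for key in in_keys if key not in available]
        if missing:
            raise KeyError(f"stage {name!r} reads missing keys {missing}")
        available.update(out_keys)  # outputs become visible to later stages
    return available


block = [
    ("TokensToQKV", ["X_to", "X_from"], ["Q", "K", "V"]),
    ("SplitHeads", ["Q", "K", "V"], ["Q", "K", "V"]),
    ("Attention", ["Q", "K", "V"], ["X_out", "Attn"]),
    ("SkipLayerNorm", ["X_to", "X_out"], ["X_to"]),
    ("FFN", ["X_to"], ["X_out"]),
    ("SkipLayerNorm", ["X_to", "X_out"], ["X_to"]),
]
print(sorted(check_key_flow({"X_to", "X_from"}, block)))
# ['Attn', 'K', 'Q', 'V', 'X_from', 'X_out', 'X_to']
```

The resulting set of keys is exactly the set of fields printed for `tokens` after running the block.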

In [29]:
class TransformerBlockTensorDict(nn.Module):
    def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.transformer_block = nn.Sequential(
            TensorDictModule(TokensToQKV(to_dim, from_dim, latent_dim), in_keys=[to_name, from_name], out_keys=["Q", "K", "V"]),
            TensorDictModule(SplitHeads(num_heads), in_keys=["Q", "K", "V"], out_keys=["Q", "K", "V"]),
            TensorDictModule(Attention(latent_dim, to_dim), in_keys=["Q", "K", "V"], out_keys=["X_out","Attn"]),
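            # Use to_name rather than a hard-coded "X_to" so the block works with any input key
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
            TensorDictModule(FFN(to_dim, 4*to_dim), in_keys=[to_name], out_keys=["X_out"]),
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
        )
    def forward(self, X_tensor_dict):
        # The block updates X_tensor_dict in place; it is also returned for convenience
        self.transformer_block(X_tensor_dict)
        return X_tensor_dict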
        

In [30]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict({"X_to":torch.randn(batch_size, to_len, to_dim), "X_from":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])

transformer_block = TransformerBlockTensorDict("X_to", "X_from", to_dim, to_len, from_dim, latent_dim, num_heads)

transformer_block(tokens)

tokens

TensorDict(
    fields={
        Attn: Tensor(torch.Size([8, 2, 3, 10]), dtype=torch.float32),
        K: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        Q: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),
        V: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),
        X_out: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),
        X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},
    batch_size=torch.Size([8]),
    device=cpu,
    is_shared=False)

The output of the transformer block can now be found at `tokens["X_to"]`:

In [31]:
tokens["X_to"]

tensor([[[-0.8713, -1.2626,  1.3218, -0.2947,  1.6938],
         [-0.7374, -0.6038, -1.4958, -0.3975,  0.0741],
         [-0.5026,  0.0095,  1.8102,  0.2308,  1.0256]],

        [[ 0.4484,  0.0689,  1.0152, -1.1690,  1.2264],
         [-0.5891,  0.0737,  0.7038,  1.0404,  0.0131],
         [ 1.0050, -2.4708, -0.7890, -1.0160,  0.4391]],

        [[-1.8364, -0.5181, -0.5258,  0.5166,  1.8120],
         [ 1.3389,  0.1451, -0.1267, -0.7637,  1.6104],
         [-0.1859, -0.4134, -1.4359, -0.1131,  0.4961]],

        [[-1.0511,  0.1636, -0.9440, -0.2152, -0.4874],
         [ 1.4676,  2.0405,  0.2846,  0.5990,  1.0199],
         [-1.5073,  0.0980, -1.5943, -0.1160,  0.2421]],

        [[-0.6059,  0.3442,  0.6854, -0.0933,  1.8850],
         [-0.2040, -2.0479,  0.8991,  1.1162,  0.0855],
         [-1.6792, -0.4797,  0.4558,  0.4763, -0.8374]],

        [[-1.2463, -1.3887,  1.2930,  0.5651,  0.9994],
         [-0.1023, -0.4523, -2.0760,  1.6500,  0.5962],
         [ 0.7221,  0.1171,  0.1437, -

We can now easily build a full transformer by stacking several blocks:

In [32]:
class TransformerTensorDict(nn.Module):
    def __init__(self, num_blocks, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.transformer = nn.ModuleList([TransformerBlockTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])
    def forward(self, X_tensor_dict):
        # Each block updates X_tensor_dict in place
        for transformer_block in self.transformer:
            transformer_block(X_tensor_dict)
        return X_tensor_dict

In [33]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict({"X_to":torch.randn(batch_size, to_len, to_dim), "X_from":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])




For an encoder, we simply use the same tokens for the queries, keys and values.

In [34]:
transformer_encoder = TransformerTensorDict(6, "X_to", "X_to", to_dim, to_len, to_dim, latent_dim, num_heads)

transformer_encoder(tokens)
tokens["X_to"]

tensor([[[-7.0456e-01,  6.2688e-01, -1.1026e+00,  2.2786e-02, -4.8655e-02],
         [-9.7513e-01,  1.4375e+00,  9.4504e-01,  2.0078e+00, -8.5485e-01],
         [-1.3192e+00,  8.7940e-01, -1.2088e+00,  4.6428e-01, -1.7003e-01]],

        [[-2.4683e+00,  1.1900e+00, -4.6999e-01,  8.1202e-01, -1.0019e+00],
         [ 9.6744e-01,  1.0676e+00,  1.0292e+00,  4.7986e-01, -1.2177e+00],
         [ 3.2124e-02,  4.6237e-01, -7.1896e-01,  2.0720e-03, -1.6574e-01]],

        [[-2.2703e-01,  1.5631e+00,  1.1274e+00,  8.1163e-02, -1.5204e+00],
         [-1.8939e+00,  7.4907e-01, -1.6144e+00,  7.3381e-01,  7.8596e-01],
         [-7.7463e-02,  2.4682e-01,  4.6115e-01,  3.5791e-01, -7.7312e-01]],

        [[ 4.2789e-01,  7.7004e-02, -5.2232e-01, -1.3905e+00, -6.8685e-01],
         [-9.8940e-01,  6.9261e-02,  1.8176e+00,  1.2323e+00, -2.8591e-01],
         [-1.9008e+00,  1.2424e+00,  9.0587e-01,  3.6788e-01, -3.6444e-01]],

        [[ 1.4538e+00,  4.6922e-01, -8.1502e-01, -3.0426e-01,  3.3914e-01],
    

For a decoder, we can now extract information from X_from into X_to: X_to is mapped to the queries, whereas X_from is mapped to the keys and values.
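The practical difference between the two set-ups shows up in the shape of the attention matrix: with self-attention the queries attend over the same `to_len` tokens, while with cross-attention they attend over the `from_len` tokens of X_from. The sketch below uses random stand-in tensors with the per-head shapes from the example (batch 8, 2 heads, 5 features per head); `attention_weights` is a hypothetical helper, not part of the `Attention` module.

```python
import torch

batch_size, num_heads, d_head = 8, 2, 5
to_len, from_len = 3, 10


def attention_weights(Q, K):
    # Scaled dot-product scores, normalised over the key positions.
    return torch.softmax(Q @ K.transpose(-2, -1) / Q.shape[-1] ** 0.5, dim=-1)


# Random stand-ins with the shapes the SplitHeads block would produce.
Q = torch.randn(batch_size, num_heads, to_len, d_head)
K_self = torch.randn(batch_size, num_heads, to_len, d_head)     # encoder: keys from X_to
K_cross = torch.randn(batch_size, num_heads, from_len, d_head)  # decoder: keys from X_from

print(attention_weights(Q, K_self).shape)   # torch.Size([8, 2, 3, 3])
print(attention_weights(Q, K_cross).shape)  # torch.Size([8, 2, 3, 10])
```

The cross-attention shape matches the `Attn` entry of shape `[8, 2, 3, 10]` printed earlier for the single block run on `X_to` and `X_from`.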

In [35]:
transformer_decoder = TransformerTensorDict(6, "X_to", "X_from", to_dim, to_len, from_dim, latent_dim, num_heads)

transformer_decoder(tokens)
tokens["X_to"]

tensor([[[-0.3380,  0.7571, -1.0049,  0.7433, -0.1268],
         [-1.2908,  1.1133, -0.1441,  1.9060, -1.1740],
         [-1.2169,  1.1862, -0.9880,  0.9077, -0.3301]],

        [[-0.9989,  1.0225,  0.0516,  1.4268, -1.3924],
         [-0.4546,  1.6237, -0.9002,  1.0320, -1.1606],
         [-0.5801,  0.9752, -1.1864,  0.4610,  0.0803]],

        [[-0.9055,  1.3019, -0.1604,  1.2191, -1.8533],
         [-1.2209,  1.0449, -0.6953,  1.0392,  0.4056],
         [-0.3828,  0.6094,  0.0309,  0.9272, -1.3600]],

        [[-0.3067,  0.0189, -0.9604, -0.3966, -0.3716],
         [-1.0679,  0.4198,  1.2470,  1.9662, -0.8475],
         [-1.6165,  1.1001, -0.3376,  1.4791, -0.3264]],

        [[-0.0980,  0.6583, -1.2214,  0.2261, -0.4485],
         [-0.6073,  1.5547, -0.3092,  1.1512, -1.0601],
         [-1.4279,  1.2677,  0.1266,  1.4727, -1.2850]],

        [[-0.4164,  0.7499,  0.2536,  0.1846,  0.1112],
         [-0.4129,  1.0208, -0.3762,  0.7487, -1.3215],
         [-0.1598,  0.5029, -0.3288,  

Now we can look at both models:

In [36]:
transformer_encoder

TransformerTensorDict(
  (transformer): ModuleList(
    (0): TransformerBlockTensorDict(
      (transformer_block): Sequential(
        (0): TensorDictModule(
            module=TokensToQKV(
              (q): Linear(in_features=5, out_features=10, bias=True)
              (k): Linear(in_features=5, out_features=10, bias=True)
              (v): Linear(in_features=5, out_features=10, bias=True)
            ), 
            device=cpu, 
            in_keys=['X_to', 'X_to'], 
            out_keys=['Q', 'K', 'V'])
        (1): TensorDictModule(
            module=SplitHeads(), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['Q', 'K', 'V'])
        (2): TensorDictModule(
            module=Attention(
              (softmax): Softmax(dim=-1)
              (out): Linear(in_features=10, out_features=5, bias=True)
            ), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['X_out', 'Attn'])
        (3): TensorDictMod

In [37]:
transformer_decoder

TransformerTensorDict(
  (transformer): ModuleList(
    (0): TransformerBlockTensorDict(
      (transformer_block): Sequential(
        (0): TensorDictModule(
            module=TokensToQKV(
              (q): Linear(in_features=5, out_features=10, bias=True)
              (k): Linear(in_features=6, out_features=10, bias=True)
              (v): Linear(in_features=6, out_features=10, bias=True)
            ), 
            device=cpu, 
            in_keys=['X_to', 'X_from'], 
            out_keys=['Q', 'K', 'V'])
        (1): TensorDictModule(
            module=SplitHeads(), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['Q', 'K', 'V'])
        (2): TensorDictModule(
            module=Attention(
              (softmax): Softmax(dim=-1)
              (out): Linear(in_features=10, out_features=5, bias=True)
            ), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['X_out', 'Attn'])
        (3): TensorDictM