# The TensorDictModule

Make sure to read the TensorDict tutorial first.

How do we use TensorDict in practice? We introduce the TensorDictModule. A TensorDictModule is an nn.Module that takes a TensorDict as the input of its forward method. The user specifies the keys the module reads as inputs, and the module writes its outputs into the same TensorDict under a given set of keys.
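To make the idea concrete, here is a minimal sketch of that key routing written with a plain Python dict instead of a TensorDict. The class name `KeyedModule` is hypothetical and this is not torchrl's implementation; it only illustrates reading `in_keys`, calling the wrapped module with those entries as positional arguments, and writing the results back under `out_keys`.

```python
import torch
import torch.nn as nn


class KeyedModule(nn.Module):
    """Hypothetical sketch: route dict entries into a module and back."""

    def __init__(self, module, in_keys, out_keys):
        super().__init__()
        self.module = module
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def forward(self, data):
        # Gather the inputs in the order given by in_keys (positional arguments)
        outputs = self.module(*(data[key] for key in self.in_keys))
        # A single tensor is wrapped in a 1-tuple so it can be zipped with out_keys
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        for key, value in zip(self.out_keys, outputs):
            data[key] = value  # written back into the same container
        return data


data = {"a": torch.randn(5, 3), "b": torch.randn(5, 4, 3)}
result = KeyedModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"])(data)
print(result is data, tuple(data["a_out"].shape))  # True (5, 10)
```

Untouched entries such as "b" simply stay in the container, which is what the first example below shows with a real TensorDict.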

In [2]:
from torchrl.modules import TensorDictModule, TensorDictSequence
from torchrl.data import TensorDict
import torch.nn as nn
import torch

### Example: Simple Linear layer

Suppose we have a TensorDict with 2 entries, "a" and "b", and we only want to transform "a".

In [3]:
tensordict = TensorDict({"a": torch.randn(5, 3), "b": torch.randn(5, 4, 3)}, batch_size=[5])
linear = TensorDictModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"])
linear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

We can also do it in place by using the input key as the output key:

In [4]:
tensordict = TensorDict({"a": torch.randn(5, 3), "b": torch.randn(5, 4, 3)}, batch_size=[5])
linear = TensorDictModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a"])
linear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)
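In plain PyTorch terms, the in-place variant amounts to reassigning the same key, as in the sketch below (a plain dict stands in for the TensorDict). The original 3-feature tensor is replaced by the 10-feature output, so keep a separate reference, or pick a different output key, if you still need the input afterwards.

```python
import torch
import torch.nn as nn

data = {"a": torch.randn(5, 3), "b": torch.randn(5, 4, 3)}
linear = nn.Linear(3, 10)

original_a = data["a"]         # keep a handle on the input before overwriting it
data["a"] = linear(data["a"])  # out_keys == in_keys: the entry is replaced

print(tuple(original_a.shape), tuple(data["a"].shape))  # (5, 3) (5, 10)
```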

### Example: Merging 2 inputs with 2 linear layers

Now let's imagine a more complex network that takes 2 entries and averages them into a single output.

In [5]:
class MergeLinear(nn.Module):
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out)
        self.linear_2 = nn.Linear(in_2, out)

    def forward(self, x_1, x_2):
        return (self.linear_1(x_1) + self.linear_2(x_2)) / 2

In [6]:
tensordict = TensorDict({"a": torch.randn(5, 3), "b": torch.randn(5, 4, 3), "c": torch.randn(5, 4)}, batch_size=[5])
mergelinear = TensorDictModule(MergeLinear(3, 4, 10), in_keys=["a", "c"], out_keys=["output"])
mergelinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        c: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)
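With `in_keys=["a", "c"]`, the entries are handed to `forward` positionally in the listed order, so the call above should be equivalent to `MergeLinear(...)(tensordict["a"], tensordict["c"])`. The sketch below checks that equivalence with a plain dict; the unpacking line is an illustration of the routing, not torchrl's code.

```python
import torch
import torch.nn as nn


class MergeLinear(nn.Module):
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out)
        self.linear_2 = nn.Linear(in_2, out)

    def forward(self, x_1, x_2):
        return (self.linear_1(x_1) + self.linear_2(x_2)) / 2


data = {"a": torch.randn(5, 3), "c": torch.randn(5, 4)}
merge = MergeLinear(3, 4, 10)
in_keys = ["a", "c"]

# Unpack the entries positionally, in the order given by in_keys
routed = merge(*[data[key] for key in in_keys])
direct = merge(data["a"], data["c"])
print(torch.equal(routed, direct), tuple(routed.shape))  # True (5, 10)
```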

### Example: Linear layer with 1 input and 2 outputs
We can also map a single input to multiple outputs.

In [7]:
class MultiHeadLinear(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out_1)
        self.linear_2 = nn.Linear(in_1, out_2)

    def forward(self, x):
        return self.linear_1(x), self.linear_2(x)

In [8]:
tensordict = TensorDict({"a": torch.randn(5, 3), "b": torch.randn(5, 4, 3)}, batch_size=[5])
multiheadlinear = TensorDictModule(MultiHeadLinear(3, 4, 10), in_keys=["a"], out_keys=["output_1", "output_2"])
multiheadlinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)
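With several outputs, the tuple returned by `forward` is matched to `out_keys` by position. The plain-dict sketch below reproduces the example above without torchrl to make that mapping explicit; the loop is an illustration, not torchrl's code.

```python
import torch
import torch.nn as nn


class MultiHeadLinear(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out_1)
        self.linear_2 = nn.Linear(in_1, out_2)

    def forward(self, x):
        return self.linear_1(x), self.linear_2(x)


data = {"a": torch.randn(5, 3)}
out_keys = ["output_1", "output_2"]

# First returned tensor -> "output_1", second -> "output_2"
for key, value in zip(out_keys, MultiHeadLinear(3, 4, 10)(data["a"])):
    data[key] = value

print({key: tuple(value.shape) for key, value in data.items()})
# {'a': (5, 3), 'output_1': (5, 4), 'output_2': (5, 10)}
```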

As shown above, TensorDictModule can wrap any nn.Module and run it on the entries of a TensorDict. When there are multiple input or output keys, make sure their order matches the order of the module's forward arguments and return values.
TensorDictModule lets us select only the tensors we need and writes the outputs back into the same object. It can even operate in place by setting an output key to an already existing key.
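Why the order matters: keys are matched to the forward arguments by position, not by name. When the shapes happen to be compatible, a wrong order does not raise an error; it silently computes something else. The toy `Difference` module and the `run` helper below are hypothetical, and a plain dict stands in for the TensorDict.

```python
import torch
import torch.nn as nn


class Difference(nn.Module):
    """Toy module whose result depends on the argument order."""

    def forward(self, x_1, x_2):
        return x_1 - x_2


data = {"a": torch.tensor([3.0, 5.0]), "b": torch.tensor([1.0, 1.0])}
module = Difference()


def run(in_keys):
    # Entries are passed positionally: in_keys[0] -> x_1, in_keys[1] -> x_2
    return module(*[data[key] for key in in_keys])


print(run(["a", "b"]).tolist())  # [2.0, 4.0]
print(run(["b", "a"]).tolist())  # [-2.0, -4.0], and no error is raised
```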

### Example: A transformer with TensorDict?
Let's attempt to create a transformer with TensorDict and TensorDictModule.

Here's a diagram that sums up the architecture:

<img src="./media/transformer.png" width = 1000px/>
Disclaimer: This implementation does not claim to be "better" than a classical tensor-based implementation. It is only meant to showcase the features of TensorDictModule.
For simplicity, we will not use positional encodings.

Let's first implement the classical transformer blocks.

In [36]:
class TokensToQKV(nn.Module):
    """Project the target tokens to queries and the source tokens to keys and values."""
    def __init__(self, to_dim, from_dim, latent_dim):
        super().__init__()
        self.q = nn.Linear(to_dim, latent_dim)
        self.k = nn.Linear(from_dim, latent_dim)
        self.v = nn.Linear(from_dim, latent_dim)

    def forward(self, X_to, X_from):
        Q = self.q(X_to)
        K = self.k(X_from)
        V = self.v(X_from)
        return Q, K, V


class SplitHeads(nn.Module):
    """Reshape (batch, tokens, latent_dim) into (batch, heads, tokens, latent_dim // heads)."""
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, Q, K, V):
        batch_size, to_num, latent_dim = Q.shape
        _, from_num, _ = K.shape
        d_tensor = latent_dim // self.num_heads  # latent_dim must be divisible by num_heads
        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        return Q, K, V


class Attention(nn.Module):
    """Scaled dot-product attention followed by a projection back to to_dim."""
    def __init__(self, latent_dim, to_dim):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.out = nn.Linear(latent_dim, to_dim)

    def forward(self, Q, K, V):
        batch_size, n_heads, to_num, d_in = Q.shape
        # Scale the scores by sqrt(d_in), as in standard scaled dot-product attention
        attn = self.softmax(Q @ K.transpose(2, 3) / d_in ** 0.5)
        out = attn @ V
        # Merge the heads back into (batch, to_num, n_heads * d_in)
        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
        return out, attn


class SkipLayerNorm(nn.Module):
    """Residual connection followed by a LayerNorm over the (to_len, to_dim) dimensions."""
    def __init__(self, to_len, to_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm((to_len, to_dim))

    def forward(self, x_0, x_1):
        return self.layer_norm(x_0 + x_1)


class FFN(nn.Module):
    """Position-wise feed-forward network with dropout."""
    def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
        super().__init__()
        self.FFN = nn.Sequential(
            nn.Linear(to_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, to_dim),
            nn.Dropout(dropout_rate),
        )

    def forward(self, X):
        return self.FFN(X)
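SplitHeads and the final reshape in Attention are inverse operations. The standalone check below, with shapes chosen to match the example further down (batch 8, 3 target tokens, latent_dim 10, 2 heads), confirms that splitting a tensor into heads and merging it back recovers the original values.

```python
import torch

batch_size, to_num, latent_dim, num_heads = 8, 3, 10, 2
assert latent_dim % num_heads == 0  # otherwise the reshape below fails
d_tensor = latent_dim // num_heads  # 5 features per head

Q = torch.randn(batch_size, to_num, latent_dim)
# Split, as in SplitHeads: (batch, tokens, latent) -> (batch, heads, tokens, d_tensor)
heads = Q.reshape(batch_size, to_num, num_heads, d_tensor).transpose(1, 2)
# Merge, as in Attention.forward: transpose back, then flatten the heads
merged = heads.transpose(1, 2).reshape(batch_size, to_num, num_heads * d_tensor)

print(tuple(heads.shape), torch.equal(merged, Q))  # (8, 2, 3, 5) True
```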


Now we can build the transformer blocks with TensorDictModule. Since every module reads from and writes to the same TensorDict, we only need to map each output to the right key so that the next block picks it up.
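The chaining mechanism can be sketched without torchrl: run each step in order, read its input keys from one shared dict, and write its outputs back. The helper `run_sequence` below is hypothetical and only mirrors the idea behind TensorDictSequence; the toy lambdas stand in for real modules.

```python
import torch


def run_sequence(data, steps):
    """Hypothetical sketch: apply keyed steps in order on one shared dict."""
    for fn, in_keys, out_keys in steps:
        outputs = fn(*(data[key] for key in in_keys))
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)  # single output -> 1-tuple
        data.update(zip(out_keys, outputs))  # later steps can read these keys
    return data


steps = [
    # Step 1 writes two new keys, like TokensToQKV writing "Q", "K", "V"
    (lambda x: (x * 2, x + 1), ["x"], ["double", "plus_one"]),
    # Step 2 reads them and overwrites "x", like SkipLayerNorm writing back to to_name
    (lambda d, p: d - p, ["double", "plus_one"], ["x"]),
]
data = run_sequence({"x": torch.arange(3.0)}, steps)
print(data["x"].tolist())  # [-1.0, 0.0, 1.0]
```

Intermediate keys such as "double" stay in the dict afterwards, which is why Q, K, V, X_out and Attn remain visible in the TensorDict printed below.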

In [37]:
class AttentionBlockTensorDict(TensorDictSequence):
    # Attention from from_name into to_name, then a residual LayerNorm written back to to_name
    def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__(
            TensorDictModule(TokensToQKV(to_dim, from_dim, latent_dim), in_keys=[to_name, from_name], out_keys=["Q", "K", "V"]),
            TensorDictModule(SplitHeads(num_heads), in_keys=["Q", "K", "V"], out_keys=["Q", "K", "V"]),
            TensorDictModule(Attention(latent_dim, to_dim), in_keys=["Q", "K", "V"], out_keys=["X_out", "Attn"]),
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
        )


class TransformerBlockEncoderTensorDict(TensorDictSequence):
    # Attention block, then a feed-forward network with its own residual LayerNorm
    def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__(
            AttentionBlockTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads),
            TensorDictModule(FFN(to_dim, 4*to_dim), in_keys=[to_name], out_keys=["X_out"]),
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
        )


class TransformerBlockDecoderTensorDict(TensorDictSequence):
    # Self-attention on to_name, then an encoder block that attends to from_name
    def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__(
            AttentionBlockTensorDict(to_name, to_name, to_dim, to_len, to_dim, latent_dim, num_heads),
            TransformerBlockEncoderTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads)
        )

In [38]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict(
    {
        "X_to": torch.randn(batch_size, to_len, to_dim),
        "X_from": torch.randn(batch_size, from_len, from_dim)
    },
    batch_size=[batch_size]
)

transformer_block = AttentionBlockTensorDict(
    "X_to",
    "X_from",
    to_dim,
    to_len,
    from_dim,
    latent_dim,
    num_heads
)

transformer_block(tokens)

tokens

TensorDict(
    fields={
        Attn: Tensor(torch.Size([8, 2, 3, 10]), dtype=torch.float32),
        K: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        Q: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),
        V: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),
        X_out: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),
        X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},
    batch_size=torch.Size([8]),
    device=cpu,
    is_shared=False)
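The shapes printed above follow directly from the hyperparameters of the previous cell. As a sanity check, the pure-Python snippet below recomputes them from the reshaping done in SplitHeads and Attention; it does not run the model.

```python
to_dim, from_dim, latent_dim = 5, 6, 10
to_len, from_len, batch_size, num_heads = 3, 10, 8, 2
d_tensor = latent_dim // num_heads  # per-head feature size: 5

expected = {
    "Q": (batch_size, num_heads, to_len, d_tensor),
    "K": (batch_size, num_heads, from_len, d_tensor),
    "V": (batch_size, num_heads, from_len, d_tensor),
    # One weight per (query token, key token) pair, for each head
    "Attn": (batch_size, num_heads, to_len, from_len),
    # The output projection maps back to to_dim, so X_to keeps its shape
    "X_out": (batch_size, to_len, to_dim),
    "X_to": (batch_size, to_len, to_dim),
}
for key, shape in expected.items():
    print(key, shape)
```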

The output of the block is now stored in tokens["X_to"]:

In [39]:
tokens["X_to"]

tensor([[[-0.2024, -1.0720,  0.0880,  0.1581,  1.3610],
         [-0.5052, -1.8012,  0.7504, -0.6252, -1.9845],
         [ 0.7721,  0.9715,  0.5495,  1.2725,  0.2672]],

        [[ 0.5293, -0.4636, -1.3908, -0.7254,  0.7734],
         [-1.7271,  1.1405,  0.5576, -0.3882,  0.0643],
         [ 0.6418,  1.0622,  1.1772, -1.8478,  0.5967]],

        [[ 0.4203,  0.1880, -1.6405, -1.6656,  0.4852],
         [ 0.5613,  0.6962,  0.3396,  0.1042,  1.3068],
         [-0.0636, -0.3054,  1.7350, -1.7914, -0.3700]],

        [[-0.6542,  0.3826, -0.9735,  1.6878, -0.2295],
         [-0.6227,  0.1929,  1.3043,  1.3246, -1.2593],
         [ 0.7568, -0.5468, -1.7795, -0.4934,  0.9099]],

        [[ 0.6101,  1.1662, -0.1247, -0.0322,  0.3963],
         [-1.3019,  1.8116, -0.8462,  0.8816, -1.8484],
         [ 0.5544, -0.6557, -1.4419,  0.5660,  0.2648]],

        [[-0.4518,  1.6725, -1.2902, -0.8343,  0.7091],
         [-1.1318, -0.3141, -0.1082,  0.5590,  0.6859],
         [ 0.7135,  1.9085,  0.0153, -

We can now easily create a full transformer.

For an encoder, we just need to use the same tokens for the queries, keys, and values.

For a decoder, we can now extract information from X_from into X_to: X_to maps to the queries, whereas X_from maps to the keys and values.
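The difference between the two cases comes down to which tensors feed the queries and which feed the keys. The sketch below deliberately drops the learned projections and the multi-head split (so it is not the tutorial's Attention module, and it assumes both token sets share the same feature size) and only computes the attention weights: self-attention yields a to_len by to_len map per example, cross-attention a to_len by from_len map.

```python
import torch


def attention_weights(queries, keys):
    # Scaled dot-product scores, normalized over the key tokens
    scores = queries @ keys.transpose(-2, -1) / queries.shape[-1] ** 0.5
    return scores.softmax(dim=-1)


batch_size, to_len, from_len, dim = 8, 3, 10, 5
X_to = torch.randn(batch_size, to_len, dim)
X_from = torch.randn(batch_size, from_len, dim)

self_attn = attention_weights(X_to, X_to)     # encoder: same tokens everywhere
cross_attn = attention_weights(X_to, X_from)  # decoder: queries from X_to, keys from X_from
print(tuple(self_attn.shape), tuple(cross_attn.shape))  # (8, 3, 3) (8, 3, 10)
```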

In [40]:
class TransformerEncoderTensorDict(TensorDictSequence):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads
    ):
        super().__init__(*[TransformerBlockEncoderTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])


class TransformerDecoderTensorDict(TensorDictSequence):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads
    ):
        super().__init__(*[TransformerBlockDecoderTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])

class TransformerTensorDict(TensorDictSequence):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        from_len,
        latent_dim,
        num_heads
    ):
        super().__init__(
            # Encoder: self-attention over the to_name tokens
            TransformerEncoderTensorDict(
                num_blocks,
                to_name,
                to_name,
                to_dim,
                to_len,
                to_dim,
                latent_dim,
                num_heads
            ),
            # Decoder: self-attention over from_name, then cross-attention to the encoded to_name tokens
            TransformerDecoderTensorDict(
                num_blocks,
                from_name,
                to_name,
                from_dim,
                from_len,
                to_dim,
                latent_dim,
                num_heads
            )
        )


In [33]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict(
    {
        "X_to": torch.randn(batch_size, to_len, to_dim),
        "X_from": torch.randn(batch_size, from_len, from_dim)
    },
    batch_size=[batch_size]
)

In [42]:
transformer = TransformerTensorDict(
    6,
    "X_to",
    "X_from",
    to_dim,
    to_len,
    from_dim,
    from_len,
    latent_dim,
    num_heads
)

transformer(tokens)
tokens

TensorDict(
    fields={
        Attn: Tensor(torch.Size([8, 2, 10, 3]), dtype=torch.float32),
        K: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),
        Q: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        V: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),
        X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),
        X_out: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),
        X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},
    batch_size=torch.Size([8]),
    device=cpu,
    is_shared=False)

Now we can look at the model:

In [43]:
transformer

TransformerTensorDict(
    module=ModuleList(
      (0): TransformerEncoderTensorDict(
          module=ModuleList(
            (0): TransformerBlockEncoderTensorDict(
                module=ModuleList(
                  (0): AttentionBlockTensorDict(
                      module=ModuleList(
                        (0): TensorDictModule(
                            module=TokensToQKV(
                              (q): Linear(in_features=5, out_features=10, bias=True)
                              (k): Linear(in_features=5, out_features=10, bias=True)
                              (v): Linear(in_features=5, out_features=10, bias=True)
                            ), 
                            device=cpu, 
                            in_keys=['X_to', 'X_to'], 
                            out_keys=['Q', 'K', 'V'])
                        (1): TensorDictModule(
                            module=SplitHeads(), 
                            device=cpu, 
                            in_keys=[

In [17]:
transformer_decoder

TransformerTensorDict(
    module=ModuleList(
      (0): TransformerBlockTensorDict(
          module=ModuleList(
            (0): TensorDictModule(
                module=TokensToQKV(
                  (q): Linear(in_features=5, out_features=10, bias=True)
                  (k): Linear(in_features=6, out_features=10, bias=True)
                  (v): Linear(in_features=6, out_features=10, bias=True)
                ), 
                device=cpu, 
                in_keys=['X_to', 'X_from'], 
                out_keys=['Q', 'K', 'V'])
            (1): TensorDictModule(
                module=SplitHeads(), 
                device=cpu, 
                in_keys=['Q', 'K', 'V'], 
                out_keys=['Q', 'K', 'V'])
            (2): TensorDictModule(
                module=Attention(
                  (softmax): Softmax(dim=-1)
                  (out): Linear(in_features=10, out_features=5, bias=True)
                ), 
                device=cpu, 
                in_keys=['Q', 'K', '