# TensorDictModule

We recommand reading the TensorDict tutorial before going through this one.

For a convenient usage of the `TensorDict` class with `nn.Module`, TorchRL provides an interface between the two named `TensorDictModule`. <br/>
The `TensorDictModule` class is an `nn.Module` that takes a `TensorDict` as input when called. <br/>
It is up to the user to define the keys to be read as input and output.

In [1]:
from torchrl.modules import TensorDictModule, TensorDictSequence
from torchrl.data import TensorDict
import torch.nn as nn
import torch

### Example 1: Simple usage

Let's imagine we have 2 entries `TensorDict`, a and b and we only want to pass a to our network.

In [16]:
tensordict = TensorDict(
    {
        "a": torch.randn(5, 3),
        "b": torch.zeros(5, 4, 3)
    },
    batch_size=[5]
)
linear = TensorDictModule(nn.Linear(3, 10),in_keys=["a"], out_keys=["a_out"])
linear(tensordict)
assert (tensordict["b"] == torch.zeros(5, 4, 3)).all()
tensordict

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

We can also do it inplace:

In [18]:
tensordict = TensorDict(
    {
        "a": torch.randn(5, 3),
        "b": torch.zeros(5, 4, 3)
    },
    batch_size=[5]
)

linear = TensorDictModule(
    nn.Linear(3, 10),
    in_keys=["a"], 
    out_keys=["a"]
)

linear(tensordict)
assert tensordict["a"].shape == torch.Size([5,10])

### Example 2: Multiple inputs

Now lets imagine a more complex network that takes 2 entries and average them into a single output:

In [4]:
class MergeLinear(nn.Module):
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out)
        self.linear_2 = nn.Linear(in_2, out)
    def forward(self, x_1, x_2):
        return (self.linear_1(x_1) + self.linear_2(x_2))/2

In [5]:
tensordict = TensorDict(
    {
        "a": torch.randn(5, 3),
        "b": torch.randn(5, 4),
    }, 
    batch_size=[5]
)

mergelinear = TensorDictModule(
    MergeLinear(3, 4, 10),
    in_keys=["a","b"],
    out_keys=["output"]
)
    
mergelinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        c: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

### Example 3: Multiple outputs
We can also map to multiple outputs

In [6]:
class MultiHeadLinear(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out_1)
        self.linear_2 = nn.Linear(in_1, out_2)
    def forward(self, x):
        return self.linear_1(x), self.linear_2(x)

In [7]:
tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])

splitlinear = TensorDictModule(
    MultiHeadLinear(3, 4, 10),
    in_keys=["a"],
    out_keys=["output_1", "output_2"]
)
splitlinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

As we shown previously, the `TensorDictModule` can take any `nn.Module` and perform the operations on a `TensorDict`. When having multiple input keys and output keys, make sure they match the order in the module.
`TensorDictModule` can work with `TensorDict` instances that contain more tensors than what the `in_keys` attribute indicates. Unless a `vmap` operator is used, the `TensorDict` is modified in-place.

### Example 4: A transformer with TensorDict?
Let's attempt to create a transformer with `TensorDict` and `TensorDictModule`.

Here's a diagram that sums up the architecture:

<img src="./media/transformer.png" width = 1000px/>

Disclaimer: This implementation don't claim to be "better" than a classical tensor-based implementation. It is just meant to showcase the `TensorDictModule` features.
For simplicity we will not have positional encoders.

Let's first implement the classical transformers blocks.

In [8]:
class TokensToQKV(nn.Module):
    def __init__(self, to_dim, from_dim, latent_dim):
        super().__init__()
        self.q = nn.Linear(to_dim, latent_dim)
        self.k = nn.Linear(from_dim, latent_dim)
        self.v = nn.Linear(from_dim, latent_dim)
    def forward(self, X_to, X_from):
        Q = self.q(X_to)
        K = self.k(X_from)
        V = self.v(X_from)
        return Q, K, V

class SplitHeads(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
    def forward(self, Q, K, V):
        batch_size, to_num, latent_dim = Q.shape
        _, from_num, _ = K.shape
        d_tensor = latent_dim // self.num_heads
        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        return Q, K, V
class Attention(nn.Module):
    def __init__(self, latent_dim, to_dim):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.out = nn.Linear(latent_dim, to_dim)
    def forward(self, Q, K, V):
        batch_size, n_heads, to_num, d_in = Q.shape
        attn = self.softmax(Q @ K.transpose(2, 3) / d_in)
        out = attn @ V
        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads*d_in))
        return out, attn
class SkipLayerNorm(nn.Module):
    def __init__(self, to_len, to_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm((to_len, to_dim))
    def forward(self, x_0, x_1):
        return self.layer_norm(x_0+x_1)
class FFN(nn.Module):
    def __init__(self, to_dim, hidden_dim, dropout_rate = 0.2):
        super().__init__()
        self.FFN = nn.Sequential(
            nn.Linear(to_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, to_dim),
            nn.Dropout(dropout_rate)
        )
    def forward(self, X):
        return self.FFN(X)


Now, we can build the encoder and decoder blocks that will be part of the transformer thanks to the TensorDictModule. Since the changes affect the `TensorDict`, we just need to map outputs to the right name such as it is picked up by the next block.

In [9]:
class AttentionBlockTensorDict(TensorDictSequence):
    def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__(
            TensorDictModule(TokensToQKV(to_dim, from_dim, latent_dim), in_keys=[to_name, from_name], out_keys=["Q", "K", "V"]),
            TensorDictModule(SplitHeads(num_heads), in_keys=["Q", "K", "V"], out_keys=["Q", "K", "V"]),
            TensorDictModule(Attention(latent_dim, to_dim), in_keys=["Q", "K", "V"], out_keys=["X_out","Attn"]),
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
        )
class TransformerBlockEncoderTensorDict(TensorDictSequence):
    def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__(
            AttentionBlockTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads),
            TensorDictModule(FFN(to_dim, 4*to_dim), in_keys=[to_name], out_keys=["X_out"]),
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
        )
class TransformerBlockDecoderTensorDict(TensorDictSequence):
    def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__(
            AttentionBlockTensorDict(to_name, to_name, to_dim, to_len, to_dim, latent_dim, num_heads),
            TransformerBlockEncoderTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads)
        )

In [10]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict(
    {
        "X_to": torch.randn(batch_size, to_len, to_dim),
        "X_from": torch.randn(batch_size, from_len, from_dim)
    },
    batch_size=[batch_size]
)

transformer_block = AttentionBlockTensorDict(
    "X_to",
    "X_from",
    to_dim,
    to_len,
    from_dim,
    latent_dim,
    num_heads
)

transformer_block(tokens)

tokens

TensorDict(
    fields={
        Attn: Tensor(torch.Size([8, 2, 3, 10]), dtype=torch.float32),
        K: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        Q: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),
        V: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),
        X_out: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),
        X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},
    batch_size=torch.Size([8]),
    device=cpu,
    is_shared=False)

The output of the transformer layer can now be found at tokens["X_to"]

In [11]:
tokens["X_to"]

tensor([[[-4.1014e-01, -3.8609e-02, -4.6190e-01, -1.1984e+00, -1.0029e+00],
         [ 4.8135e-01,  2.1887e+00,  1.0424e+00,  2.3389e-01,  2.2858e-01],
         [ 9.2826e-02,  1.7278e+00, -1.3197e+00, -8.0110e-01, -7.6280e-01]],

        [[-8.1625e-01,  9.2371e-01, -3.4789e-01, -1.7687e+00, -5.5227e-01],
         [-4.1375e-01, -6.0962e-01, -6.1183e-01,  2.3965e-01,  1.3284e+00],
         [-7.7074e-01,  6.0351e-01, -2.2673e-01,  2.4082e+00,  6.1431e-01]],

        [[ 1.1999e+00,  9.0440e-01,  1.9596e-01, -1.7704e+00,  4.7291e-01],
         [ 1.2238e+00,  1.9033e-01,  2.7292e-01, -2.5550e-01, -1.3451e+00],
         [ 1.0397e+00,  1.1162e+00, -1.2123e+00, -1.0223e+00, -1.0106e+00]],

        [[ 4.9356e-01, -1.4560e-02,  2.1211e-02,  8.7976e-01, -8.7540e-02],
         [-1.4565e-01,  1.0329e+00,  5.8444e-01,  6.7036e-01, -1.5396e+00],
         [ 6.7115e-01, -9.1910e-01,  1.5387e+00, -8.9005e-01, -2.2955e+00]],

        [[ 1.1135e+00,  7.9529e-01, -1.6116e+00, -7.5607e-01, -8.4692e-01],
    

We can now create a transformer easily

For an encoder, we just need to take the same tokens for both queries, keys and values.

For a decoder, we now can extract info from `X_from` into `X_to`. `X_from` will map to queries whereas X`_from` will map to keys and values.

In [12]:
class TransformerEncoderTensorDict(TensorDictSequence):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads
    ):
        super().__init__(
            *[TransformerBlockEncoderTensorDict(
                to_name,
                from_name,
                to_dim,
                to_len,
                from_dim,
                latent_dim,
                num_heads
            ) for _ in range(num_blocks)
        ])
class TransformerDecoderTensorDict(TensorDictSequence):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads
    ):
        super().__init__(
            *[TransformerBlockDecoderTensorDict(
                to_name,
                from_name,
                to_dim,
                to_len,
                from_dim,
                latent_dim,
                num_heads
            ) for _ in range(num_blocks)
             ])
        
class TransformerTensorDict(TensorDictSequence):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads
    ):
        super().__init__(
            TransformerEncoderTensorDict(
                num_blocks,
                to_name,
                to_name,
                to_dim,
                to_len,
                to_dim,
                latent_dim,
                num_heads
            ),
            TransformerDecoderTensorDict(
                num_blocks,
                from_name,
                to_name,
                from_dim,
                from_len,
                to_dim,
                latent_dim,
                num_heads
            )
            
        )    


In [13]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict(
    {
        "X_encode":torch.randn(batch_size, to_len, to_dim),
        "X_decode":torch.randn(batch_size, from_len, from_dim)
    },
    batch_size=[batch_size]
)

In [14]:
transformer =  TransformerTensorDict(
    6,
    "X_to",
    "X_from",
    to_dim,
    to_len,
    from_dim,
    latent_dim,
    num_heads
)

transformer(tokens)
tokens

TypeError: linear(): argument 'input' (position 1) must be Tensor, not NoneType

Now we can look at the model:

In [None]:
transformer