# TensorDictModule

We recommend reading the TensorDict tutorial before going through this one.

To make the `TensorDict` class convenient to use with `nn.Module`, TorchRL provides an interface between the two named `TensorDictModule`.
The `TensorDictModule` class is an `nn.Module` that takes a `TensorDict` as input when called.
The user specifies which keys are read as inputs (`in_keys`) and which keys the outputs are written to (`out_keys`).
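The read-call-write pattern described above can be sketched without TorchRL. The `KeyedModule` class below is a hypothetical, simplified stand-in for `TensorDictModule`, not its actual implementation: a plain `dict` replaces the `TensorDict` and any Python callable replaces the `nn.Module`.

```python
# Hypothetical stand-in for TensorDictModule: a plain dict plays the role of
# the TensorDict and any callable plays the role of the nn.Module.
class KeyedModule:
    def __init__(self, module, in_keys, out_keys):
        self.module = module
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def __call__(self, data):
        # 1. read the inputs, in the order given by in_keys
        inputs = [data[key] for key in self.in_keys]
        outputs = self.module(*inputs)
        # 2. a single output is treated as a 1-tuple
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        if len(outputs) != len(self.out_keys):
            raise ValueError("number of outputs does not match out_keys")
        # 3. write the outputs back into the same dict, in place
        for key, value in zip(self.out_keys, outputs):
            data[key] = value
        return data


data = {"a": 3.0, "b": 0.0}
double = KeyedModule(lambda x: 2 * x, in_keys=["a"], out_keys=["a_out"])
result = double(data)
print(result)  # {'a': 3.0, 'b': 0.0, 'a_out': 6.0}
```

Entries that are not listed in `in_keys` (here `"b"`) are neither read nor modified.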

## `TensorDictModule` by examples

```python
import torch
import torch.nn as nn

from torchrl.data import TensorDict
from torchrl.modules import TensorDictModule, TensorDictSequence
```

### Example 1: Simple usage

Suppose we have a `TensorDict` with two entries, `"a"` and `"b"`, but only the value associated with `"a"` has to be read by the network.

```python
tensordict = TensorDict(
    {"a": torch.randn(5, 3), "b": torch.zeros(5, 4, 3)},
    batch_size=[5],
)
linear = TensorDictModule(
    nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"]
)
linear(tensordict)
assert (tensordict.get("b") == 0).all()
print(tensordict)
```

```text
TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)
```


### Example 2: Multiple inputs

Suppose we have a slightly more complex network that takes two entries, projects each one with its own linear layer, and averages the results into a single output tensor.

```python
class MergeLinear(nn.Module):
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out)
        self.linear_2 = nn.Linear(in_2, out)

    def forward(self, x_1, x_2):
        return (self.linear_1(x_1) + self.linear_2(x_2)) / 2
```

```python
tensordict = TensorDict(
    {
        "a": torch.randn(5, 3),
        "b": torch.randn(5, 4),
    },
    batch_size=[5],
)

mergelinear = TensorDictModule(
    MergeLinear(3, 4, 10), in_keys=["a", "b"], out_keys=["output"]
)

mergelinear(tensordict)
```

```text
TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)
```

### Example 3: Multiple outputs

`TensorDictModule` supports multiple outputs as well as multiple inputs.

```python
class MultiHeadLinear(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out_1)
        self.linear_2 = nn.Linear(in_1, out_2)

    def forward(self, x):
        return self.linear_1(x), self.linear_2(x)
```

```python
tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])

splitlinear = TensorDictModule(
    MultiHeadLinear(3, 4, 10),
    in_keys=["a"],
    out_keys=["output_1", "output_2"],
)
splitlinear(tensordict)
```

```text
TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)
```

As shown above, `TensorDictModule` can wrap any `nn.Module` and run it on a `TensorDict`. When there are several input or output keys, make sure their order matches the module: the entries listed in `in_keys` are passed to `forward` positionally, in that order, and the returned tensors are written to `out_keys` in the order they are returned.
`TensorDictModule` can work with `TensorDict` instances that contain more entries than the ones listed in `in_keys`. Unless a `vmap` operator is used, the `TensorDict` is modified in place.
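The ordering rule can be illustrated with a hypothetical plain-Python stand-in (`KeyedModule`, not TorchRL code, with a `dict` in place of the `TensorDict`). Swapping the order of `in_keys` swaps the arguments the wrapped function receives, extra entries such as `"unused"` are ignored, and every call writes into the same dict.

```python
# Hypothetical stand-in for TensorDictModule: inputs are passed
# positionally, in in_keys order; outputs are assigned in out_keys order.
class KeyedModule:
    def __init__(self, module, in_keys, out_keys):
        self.module, self.in_keys, self.out_keys = module, in_keys, out_keys

    def __call__(self, data):
        outputs = self.module(*(data[k] for k in self.in_keys))
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        for key, value in zip(self.out_keys, outputs):
            data[key] = value  # written in place
        return data


def subtract(x_1, x_2):
    return x_1 - x_2


def split(x):
    return x, -x


data = {"a": 5.0, "b": 2.0, "unused": 100.0}
KeyedModule(subtract, in_keys=["a", "b"], out_keys=["a_minus_b"])(data)
KeyedModule(subtract, in_keys=["b", "a"], out_keys=["b_minus_a"])(data)
KeyedModule(split, in_keys=["a"], out_keys=["pos", "neg"])(data)
print(data["a_minus_b"], data["b_minus_a"])  # 3.0 -3.0
print(data["pos"], data["neg"])  # 5.0 -5.0
```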

### Example 4: Combining multiple `TensorDictModule` instances with `TensorDictSequence`

To combine multiple `TensorDictModule` instances, we can use `TensorDictSequence`. It calls the modules in the order they are given, on the same `TensorDict`, so the entries written by the (n-1)-th module are available as inputs to the n-th module.
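The chaining behaviour can be sketched in plain Python. `KeyedModule` and `KeyedSequence` are hypothetical simplified stand-ins for `TensorDictModule` and `TensorDictSequence`, and the lambdas play the roles of `MultiHeadLinear` and `MergeLinear` with scalars instead of tensors.

```python
class KeyedModule:  # hypothetical stand-in for TensorDictModule
    def __init__(self, module, in_keys, out_keys):
        self.module, self.in_keys, self.out_keys = module, in_keys, out_keys

    def __call__(self, data):
        outputs = self.module(*(data[k] for k in self.in_keys))
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        data.update(zip(self.out_keys, outputs))
        return data


class KeyedSequence:  # hypothetical stand-in for TensorDictSequence
    def __init__(self, *modules):
        self.modules = modules

    def __call__(self, data):
        # each module sees every entry written by the modules before it
        for module in self.modules:
            data = module(data)
        return data


split = KeyedModule(
    lambda x: (x + 1, x * 10), in_keys=["a"], out_keys=["output_1", "output_2"]
)
merge = KeyedModule(
    lambda x_1, x_2: (x_1 + x_2) / 2,
    in_keys=["output_1", "output_2"],
    out_keys=["output"],
)
pipeline = KeyedSequence(split, merge)
data = pipeline({"a": 2.0})
print(data)  # {'a': 2.0, 'output_1': 3.0, 'output_2': 20.0, 'output': 11.5}
```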

```python
tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])

splitlinear = TensorDictModule(
    MultiHeadLinear(3, 4, 10),
    in_keys=["a"],
    out_keys=["output_1", "output_2"],
)
mergelinear = TensorDictModule(
    MergeLinear(4, 10, 13), in_keys=["output_1", "output_2"], out_keys=["output"]
)

split_and_merge_linear = TensorDictSequence(splitlinear, mergelinear)

assert split_and_merge_linear(tensordict)["output"].shape == torch.Size([5, 13])
```

### Example 5: Compatibility with functorch

`TensorDictModule` is compatible with functorch. For instance, its `make_functional_with_buffers()` method returns a stateless function together with the module's parameters and buffers.

```python
tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])

splitlinear = TensorDictModule(
    MultiHeadLinear(3, 4, 10),
    in_keys=["a"],
    out_keys=["output_1", "output_2"],
)
func, (params, buffers) = splitlinear.make_functional_with_buffers()
func(tensordict, params=params, buffers=buffers)
```

```text
TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)
```

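We can also use `vmap`. Let's use it for model ensembling: the parameters of several models are stacked along a new leading dimension, and the same input is run through every model at once.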

```python
num_models = 10

tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])

# independently initialized models sharing the same architecture
splitlinear_models = [
    TensorDictModule(
        nn.Linear(3, 10), in_keys=["a"], out_keys=["output"]
    )
    for _ in range(num_models)
]


def transpose_stack(tuple_of_tuple_of_tensors):
    # turn ((w_1, b_1), (w_2, b_2), ...) into (stack(w_i), stack(b_i))
    tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
    results = tuple(
        torch.stack(shards).detach()
        for shards in tuple_of_tuple_of_tensors
    )
    return results


# one stateless function shared by all models...
func = splitlinear_models[0].make_functional_with_buffers()[0]
# ...and one (params, buffers) pair per model
params, buffers = zip(
    *[
        splitlinear.make_functional_with_buffers()[1]
        for splitlinear in splitlinear_models
    ]
)
params = transpose_stack(params)
buffers = transpose_stack(buffers)
func(tensordict, params=params, buffers=buffers, vmap=True).shape
```

When this notebook was run, the vmapped call above failed with the following error:

```text
RuntimeError: batched == nullptr INTERNAL ASSERT FAILED at "/private/var/folders/fn/c72nxv0x0xj4bzdgq3b0r86r0000gn/T/pip-req-build-i1bssf2g/functorch/csrc/Interpreter.cpp":95, please report a bug to PyTorch.
```
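What the vmapped ensemble is meant to compute can be written as an explicit loop. The sketch below is plain Python (no functorch, no TorchRL); `linear`, `stack_params` and `ensemble_apply` are hypothetical helpers, and each "model" is a scalar `(weight, bias)` pair. `stack_params` plays the role of `transpose_stack`, and the loop over the leading "model" dimension is the computation that `vmap` vectorizes.

```python
def linear(params, x):
    # stateless "module": all of its state is passed in explicitly
    weight, bias = params
    return weight * x + bias


def stack_params(per_model_params):
    # [(w_1, b_1), (w_2, b_2), ...] -> ([w_1, w_2, ...], [b_1, b_2, ...])
    return tuple(list(shards) for shards in zip(*per_model_params))


def ensemble_apply(func, stacked_params, x):
    # loop over the leading "model" dimension; vmap would vectorize this loop
    num_models = len(stacked_params[0])
    return [
        func(tuple(p[i] for p in stacked_params), x)
        for i in range(num_models)
    ]


models = [(1.0, 0.0), (2.0, 1.0), (-1.0, 3.0)]
stacked = stack_params(models)
print(stacked)  # ([1.0, 2.0, -1.0], [0.0, 1.0, 3.0])
print(ensemble_apply(linear, stacked, 2.0))  # [2.0, 5.0, 1.0]
```

The result has one entry per model, which is why the vmapped call is expected to return outputs with an extra leading dimension of size `num_models`.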

```python
from functorch import make_functional_with_buffers

tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])

splitlinear = TensorDictModule(
    nn.Linear(3, 10), in_keys=["a"], out_keys=["output"]
)
func, params, buffers = make_functional_with_buffers(splitlinear)
func(params, buffers, tensordict)
```

```text
TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)
```

### Example 6: Implementing a transformer using `TensorDictModule`
Using `TensorDictModule`, we can build a transformer that reads and writes `TensorDict` objects.

The following figure shows the classical transformer architecture (Vaswani et al., 2017):

<img src="./media/transformer.png" width="1000"/>

We leave positional encodings aside for simplicity.

Let's first implement the classical transformer building blocks.

```python
class TokensToQKV(nn.Module):
    # projects the "to" tokens into queries and the "from" tokens into keys and values
    def __init__(self, to_dim, from_dim, latent_dim):
        super().__init__()
        self.q = nn.Linear(to_dim, latent_dim)
        self.k = nn.Linear(from_dim, latent_dim)
        self.v = nn.Linear(from_dim, latent_dim)

    def forward(self, X_to, X_from):
        Q = self.q(X_to)
        K = self.k(X_from)
        V = self.v(X_from)
        return Q, K, V


class SplitHeads(nn.Module):
    # reshapes (batch, tokens, latent_dim) into
    # (batch, num_heads, tokens, latent_dim // num_heads);
    # latent_dim must be divisible by num_heads
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, Q, K, V):
        batch_size, to_num, latent_dim = Q.shape
        _, from_num, _ = K.shape
        d_tensor = latent_dim // self.num_heads
        Q = Q.reshape(
            batch_size, to_num, self.num_heads, d_tensor
        ).transpose(1, 2)
        K = K.reshape(
            batch_size, from_num, self.num_heads, d_tensor
        ).transpose(1, 2)
        V = V.reshape(
            batch_size, from_num, self.num_heads, d_tensor
        ).transpose(1, 2)
        return Q, K, V


class Attention(nn.Module):
    def __init__(self, latent_dim, to_dim):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.out = nn.Linear(latent_dim, to_dim)

    def forward(self, Q, K, V):
        batch_size, n_heads, to_num, d_in = Q.shape
        # scaled dot-product attention: softmax(Q K^T / sqrt(d_in)) V
        attn = self.softmax(Q @ K.transpose(2, 3) / d_in ** 0.5)
        out = attn @ V
        # merge the heads back together and project to to_dim
        out = self.out(
            out.transpose(1, 2).reshape(
                batch_size, to_num, n_heads * d_in
            )
        )
        return out, attn


class SkipLayerNorm(nn.Module):
    # residual connection followed by layer normalization
    def __init__(self, to_len, to_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm((to_len, to_dim))

    def forward(self, x_0, x_1):
        return self.layer_norm(x_0 + x_1)


class FFN(nn.Module):
    # position-wise feed-forward network
    def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
        super().__init__()
        self.FFN = nn.Sequential(
            nn.Linear(to_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, to_dim),
            nn.Dropout(dropout_rate),
        )

    def forward(self, X):
        return self.FFN(X)
```
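For reference, the classical transformer (Vaswani et al., 2017) uses scaled dot-product attention, `softmax(Q K^T / sqrt(d_k)) V`, where `d_k` is the per-head feature size. The pure-Python sketch below computes it for a single head and a single batch element, so the role of each term is explicit; the function names and the toy inputs are hypothetical.

```python
import math


def softmax(row):
    # subtract the max for numerical stability
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(Q, K, V):
    # Q: to_num x d_k, K: from_num x d_k, V: from_num x d_v (lists of lists)
    d_k = len(Q[0])
    scores = [
        [sum(q_i * k_i for q_i, k_i in zip(q, k)) / math.sqrt(d_k) for k in K]
        for q in Q
    ]
    attn = [softmax(row) for row in scores]  # each row sums to 1
    # each output row is an attention-weighted average of the rows of V
    out = [
        [sum(a * v[j] for a, v in zip(row, V)) for j in range(len(V[0]))]
        for row in attn
    ]
    return out, attn


Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0], [20.0]]
out, attn = scaled_dot_product_attention(Q, K, V)
print([round(a, 3) for a in attn[0]])  # [0.67, 0.33]
print(out)
```

The query is more similar to the first key, so the output is pulled towards the first value (about 13.3 rather than the plain average of 15).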

Using `TensorDictModule`, we can assemble these blocks into the encoder and decoder blocks of the transformer. Because every module writes its outputs into the shared `TensorDict`, we only need to give each output the key that the next block reads.
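The key-mapping idea can be seen in isolation. In the `AttentionBlockTensorDict` class defined next, `SplitHeads` reads `"Q"`, `"K"`, `"V"` and writes its results back to the same keys, and `SkipLayerNorm` writes back to `to_name`. The hypothetical plain-Python sketch below (scalars instead of tensors, `KeyedModule` as a simplified stand-in for `TensorDictModule`) shows why a block that reads and overwrites the same key can be stacked any number of times.

```python
class KeyedModule:  # hypothetical stand-in for TensorDictModule
    def __init__(self, module, in_keys, out_keys):
        self.module, self.in_keys, self.out_keys = module, in_keys, out_keys

    def __call__(self, data):
        outputs = self.module(*(data[k] for k in self.in_keys))
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        data.update(zip(self.out_keys, outputs))
        return data


def run_in_sequence(modules, data):
    for module in modules:
        data = module(data)
    return data


block = [
    # "sublayer": writes an intermediate entry
    KeyedModule(lambda x: 0.5 * x, in_keys=["X_to"], out_keys=["X_out"]),
    # residual update: reads X_to and X_out, then overwrites X_to
    KeyedModule(lambda x, y: x + y, in_keys=["X_to", "X_out"], out_keys=["X_to"]),
]
# stacking the block twice works because it reads and writes the same key
data = run_in_sequence(block + block, {"X_to": 4.0})
print(data)  # {'X_to': 9.0, 'X_out': 3.0}
```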

```python
class AttentionBlockTensorDict(TensorDictSequence):
    def __init__(
        self,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            TensorDictModule(
                TokensToQKV(to_dim, from_dim, latent_dim),
                in_keys=[to_name, from_name],
                out_keys=["Q", "K", "V"],
            ),
            TensorDictModule(
                SplitHeads(num_heads),
                in_keys=["Q", "K", "V"],
                out_keys=["Q", "K", "V"],
            ),
            TensorDictModule(
                Attention(latent_dim, to_dim),
                in_keys=["Q", "K", "V"],
                out_keys=["X_out", "Attn"],
            ),
            TensorDictModule(
                SkipLayerNorm(to_len, to_dim),
                in_keys=[to_name, "X_out"],
                out_keys=[to_name],
            ),
        )


class TransformerBlockEncoderTensorDict(TensorDictSequence):
    def __init__(
        self,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            AttentionBlockTensorDict(
                to_name,
                from_name,
                to_dim,
                to_len,
                from_dim,
                latent_dim,
                num_heads,
            ),
            TensorDictModule(
                FFN(to_dim, 4 * to_dim),
                in_keys=[to_name],
                out_keys=["X_out"],
            ),
            TensorDictModule(
                SkipLayerNorm(to_len, to_dim),
                in_keys=[to_name, "X_out"],
                out_keys=[to_name],
            ),
        )


class TransformerBlockDecoderTensorDict(TensorDictSequence):
    def __init__(
        self,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            AttentionBlockTensorDict(
                to_name,
                to_name,
                to_dim,
                to_len,
                to_dim,
                latent_dim,
                num_heads,
            ),
            TransformerBlockEncoderTensorDict(
                to_name,
                from_name,
                to_dim,
                to_len,
                from_dim,
                latent_dim,
                num_heads,
            ),
        )
```

```python
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict(
    {
        "X_to": torch.randn(batch_size, to_len, to_dim),
        "X_from": torch.randn(batch_size, from_len, from_dim),
    },
    batch_size=[batch_size],
)

transformer_block = AttentionBlockTensorDict(
    "X_to", "X_from", to_dim, to_len, from_dim, latent_dim, num_heads
)

transformer_block(tokens)

tokens
```

```text
TensorDict(
    fields={
        Attn: Tensor(torch.Size([8, 2, 3, 10]), dtype=torch.float32),
        K: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        Q: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),
        V: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),
        X_out: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),
        X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},
    batch_size=torch.Size([8]),
    device=cpu,
    is_shared=False)
```

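The output of the attention block, after the residual connection and layer normalization, is written back to `tokens["X_to"]`. The raw attention output is kept in `tokens["X_out"]` and the attention weights in `tokens["Attn"]`.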
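The shapes printed above follow directly from the dimensions chosen in the cell. The helper below is a hypothetical bookkeeping function (not part of TorchRL) that reproduces them: `SplitHeads` gives each head `latent_dim // num_heads` features, and the attention weights have one row per "to" token and one column per "from" token.

```python
def attention_block_shapes(batch_size, to_len, to_dim, from_len, latent_dim, num_heads):
    # per-head feature size, as computed in SplitHeads
    if latent_dim % num_heads:
        raise ValueError("latent_dim must be divisible by num_heads")
    d_head = latent_dim // num_heads
    return {
        "Q": (batch_size, num_heads, to_len, d_head),
        "K": (batch_size, num_heads, from_len, d_head),
        "V": (batch_size, num_heads, from_len, d_head),
        "Attn": (batch_size, num_heads, to_len, from_len),
        "X_out": (batch_size, to_len, to_dim),
    }


shapes = attention_block_shapes(
    batch_size=8, to_len=3, to_dim=5, from_len=10, latent_dim=10, num_heads=2
)
for key, shape in shapes.items():
    print(key, shape)
```

Swapping the roles of the two token sets (`to_len=10`, `to_dim=6`, `from_len=3`) gives the shapes produced by the last cross-attention of the decoder.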

We now create the transformer encoder and decoder.

For an encoder, the same tokens provide the queries, keys and values (self-attention).

For a decoder, we extract information from `X_from` into `X_to`. Each decoder block first applies self-attention to `X_to`, then cross-attention in which `X_to` maps to the queries whereas `X_from` maps to the keys and values.

```python
class TransformerEncoderTensorDict(TensorDictSequence):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            *[
                TransformerBlockEncoderTensorDict(
                    to_name,
                    from_name,
                    to_dim,
                    to_len,
                    from_dim,
                    latent_dim,
                    num_heads,
                )
                for _ in range(num_blocks)
            ]
        )


class TransformerDecoderTensorDict(TensorDictSequence):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            *[
                TransformerBlockDecoderTensorDict(
                    to_name,
                    from_name,
                    to_dim,
                    to_len,
                    from_dim,
                    latent_dim,
                    num_heads,
                )
                for _ in range(num_blocks)
            ]
        )


class TransformerTensorDict(TensorDictSequence):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        from_len,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            # encoder: self-attention on the to_name tokens
            TransformerEncoderTensorDict(
                num_blocks,
                to_name,
                to_name,
                to_dim,
                to_len,
                to_dim,
                latent_dim,
                num_heads,
            ),
            # decoder: updates the from_name tokens, reading the encoded to_name tokens
            TransformerDecoderTensorDict(
                num_blocks,
                from_name,
                to_name,
                from_dim,
                from_len,
                to_dim,
                latent_dim,
                num_heads,
            ),
        )
```

```python
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict(
    {
        "X_encode": torch.randn(batch_size, to_len, to_dim),
        "X_decode": torch.randn(batch_size, from_len, from_dim),
    },
    batch_size=[batch_size],
)

transformer = TransformerTensorDict(
    6,  # num_blocks
    "X_encode",
    "X_decode",
    to_dim,
    to_len,
    from_dim,
    from_len,
    latent_dim,
    num_heads,
)

transformer(tokens)
tokens
```

```text
TensorDict(
    fields={
        Attn: Tensor(torch.Size([8, 2, 10, 3]), dtype=torch.float32),
        K: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),
        Q: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        V: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),
        X_decode: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),
        X_encode: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),
        X_out: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32)},
    batch_size=torch.Size([8]),
    device=cpu,
    is_shared=False)
```

We have built a transformer with `TensorDictModule`. This shows that `TensorDictModule` is flexible enough to implement complex operations.

## TensorDictModule for RL

For RL, TorchRL provides a few wrappers around `TensorDictModule`.

### `ProbabilisticTensorDictModule`

`ProbabilisticTensorDictModule` is a special case of a `TensorDictModule` whose output is sampled according to a rule specified by the `default_interaction_mode` argument and the global `exploration_mode()` function.

It wraps another `TensorDictModule` that returns a tensordict updated with the distribution parameters. `ProbabilisticTensorDictModule` is responsible for constructing the distribution (through the `get_dist()` method) and/or sampling from this distribution (through a regular `__call__()` to the module).
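The division of labour described above (an inner module writes distribution parameters, the wrapper turns them into an output according to an interaction mode) can be sketched in plain Python. Everything below is hypothetical: `KeyedModule` and `ProbabilisticKeyedModule` are simplified stand-ins rather than TorchRL classes, the distribution is assumed to be a Gaussian stored under `"loc"` and `"scale"`, and the mode strings `"random"` and `"mode"` are assumptions of this sketch.

```python
import random


class KeyedModule:  # hypothetical stand-in for TensorDictModule
    def __init__(self, module, in_keys, out_keys):
        self.module, self.in_keys, self.out_keys = module, in_keys, out_keys

    def __call__(self, data):
        outputs = self.module(*(data[k] for k in self.in_keys))
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        data.update(zip(self.out_keys, outputs))
        return data


class ProbabilisticKeyedModule:
    # hypothetical sketch: wraps a module that writes Gaussian parameters
    # ("loc", "scale") and turns them into an output value
    def __init__(self, module, out_key="action", default_interaction_mode="random"):
        self.module = module
        self.out_key = out_key
        self.default_interaction_mode = default_interaction_mode

    def get_dist_params(self, data):
        # run the wrapped module, then read the distribution parameters
        data = self.module(data)
        return data["loc"], data["scale"]

    def __call__(self, data, interaction_mode=None):
        mode = interaction_mode or self.default_interaction_mode
        loc, scale = self.get_dist_params(data)
        if mode == "mode":
            data[self.out_key] = loc  # the mode of a Gaussian is its mean
        elif mode == "random":
            data[self.out_key] = random.gauss(loc, scale)
        else:
            raise ValueError(f"unknown interaction mode: {mode}")
        return data


random.seed(0)
params_net = KeyedModule(
    lambda obs: (2.0 * obs, 0.1), in_keys=["observation"], out_keys=["loc", "scale"]
)
policy = ProbabilisticKeyedModule(params_net)
print(policy({"observation": 1.5}, interaction_mode="mode")["action"])  # 3.0
sample = policy({"observation": 1.5})["action"]  # a random draw around 3.0
```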

### `Actor`

`Actor` inherits from `TensorDictModule` and comes with `["action"]` as the default value for `out_keys`.
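A default `out_keys` value amounts to a thin subclass. The sketch below is hypothetical and plain Python: `MiniActor` extends a simplified `KeyedModule` stand-in rather than TorchRL's `TensorDictModule`, and the `"observation"` key is an assumption.

```python
class KeyedModule:  # hypothetical stand-in for TensorDictModule
    def __init__(self, module, in_keys, out_keys):
        self.module, self.in_keys, self.out_keys = module, in_keys, out_keys

    def __call__(self, data):
        outputs = self.module(*(data[k] for k in self.in_keys))
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        data.update(zip(self.out_keys, outputs))
        return data


class MiniActor(KeyedModule):
    # hypothetical: identical to KeyedModule, but out_keys defaults to ["action"]
    def __init__(self, module, in_keys, out_keys=None):
        if out_keys is None:
            out_keys = ["action"]
        super().__init__(module, in_keys, out_keys)


actor = MiniActor(lambda obs: -obs, in_keys=["observation"])
print(actor.out_keys)  # ['action']
print(actor({"observation": 0.25}))  # {'observation': 0.25, 'action': -0.25}
```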


### `ProbabilisticActor`

`ProbabilisticActor` is a general class for probabilistic actors in RL; it inherits from `ProbabilisticTensorDictModule`.
Like `Actor`, it uses `["action"]` as the default value for `out_keys`.


### `ActorCriticOperator`

Similarly, `ActorCriticOperator` inherits from `TensorDictSequence` and wraps both an actor network and a value network.
`ActorCriticOperator` first computes the action with the actor, then computes the value of this action.
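The actor-then-critic ordering can be sketched with hypothetical plain-Python stand-ins for `TensorDictModule` and `TensorDictSequence`. The key names `"observation"`, `"action"` and `"state_action_value"` and the toy functions are assumptions, and the sketch does not reproduce `ActorCriticOperator`'s actual constructor signature; it only shows that the critic can read the action because the actor runs first on the same dict.

```python
class KeyedModule:  # hypothetical stand-in for TensorDictModule
    def __init__(self, module, in_keys, out_keys):
        self.module, self.in_keys, self.out_keys = module, in_keys, out_keys

    def __call__(self, data):
        outputs = self.module(*(data[k] for k in self.in_keys))
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        data.update(zip(self.out_keys, outputs))
        return data


class KeyedSequence:  # hypothetical stand-in for TensorDictSequence
    def __init__(self, *modules):
        self.modules = modules

    def __call__(self, data):
        for module in self.modules:
            data = module(data)
        return data


# actor: observation -> action
actor = KeyedModule(lambda obs: -0.5 * obs, in_keys=["observation"], out_keys=["action"])
# critic: (observation, action) -> value of taking this action
critic = KeyedModule(
    lambda obs, act: -((obs + act) ** 2),
    in_keys=["observation", "action"],
    out_keys=["state_action_value"],
)
actor_critic = KeyedSequence(actor, critic)
data = actor_critic({"observation": 2.0})
print(data)  # {'observation': 2.0, 'action': -1.0, 'state_action_value': -1.0}
```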

Have fun with TensorDictModule!