diff --git a/tutorials/README.md b/tutorials/README.md
index 7ac2d8a1765..516f8d79540 100644
--- a/tutorials/README.md
+++ b/tutorials/README.md
@@ -4,8 +4,10 @@ Get a sense of TorchRL functionalities through our tutorials.
For an overview of TorchRL, try the [TorchRL demo](demo.ipynb).
-Make sure you test the [TensorDict demo](tensordict.ipynb) to see what TensorDict
+Make sure you test the [TensorDict tutorial](tensordict.ipynb) to see what TensorDict
is about and what it can do.
-Checkout the [environment demo](envs.ipynb) for a deep dive in the envs
+To understand how to use `TensorDict` with PyTorch modules, make sure to check out the [TensorDictModule tutorial](tensordictmodule.ipynb).
+
+Check out the [environment tutorial](envs.ipynb) for a deep dive into the envs
functionalities.
diff --git a/tutorials/media/transformer.png b/tutorials/media/transformer.png
new file mode 100644
index 00000000000..b18ca7b93e3
Binary files /dev/null and b/tutorials/media/transformer.png differ
diff --git a/tutorials/src/transformer.py b/tutorials/src/transformer.py
new file mode 100644
index 00000000000..352f08b9f44
--- /dev/null
+++ b/tutorials/src/transformer.py
@@ -0,0 +1,164 @@
+import torch.nn as nn
+
+
+class TokensToQKV(nn.Module):
+    """Project the target tokens (X_to) to queries and the source tokens (X_from) to keys and values."""
+ def __init__(self, to_dim, from_dim, latent_dim):
+ super().__init__()
+ self.q = nn.Linear(to_dim, latent_dim)
+ self.k = nn.Linear(from_dim, latent_dim)
+ self.v = nn.Linear(from_dim, latent_dim)
+
+ def forward(self, X_to, X_from):
+ Q = self.q(X_to)
+ K = self.k(X_from)
+ V = self.v(X_from)
+ return Q, K, V
+
+
+class SplitHeads(nn.Module):
+    """Reshape Q, K and V from (batch, seq, latent) to (batch, num_heads, seq, latent // num_heads)."""
+ def __init__(self, num_heads):
+ super().__init__()
+ self.num_heads = num_heads
+
+ def forward(self, Q, K, V):
+ batch_size, to_num, latent_dim = Q.shape
+ _, from_num, _ = K.shape
+ d_tensor = latent_dim // self.num_heads
+ Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
+ K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
+ V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
+ return Q, K, V
+
+
+class Attention(nn.Module):
+    """Scaled dot-product attention followed by a projection of the concatenated heads back to to_dim."""
+ def __init__(self, latent_dim, to_dim):
+ super().__init__()
+ self.softmax = nn.Softmax(dim=-1)
+ self.out = nn.Linear(latent_dim, to_dim)
+
+ def forward(self, Q, K, V):
+ batch_size, n_heads, to_num, d_in = Q.shape
+        # scaled dot-product attention: scale by sqrt(d_k), as in Vaswani et al., 2017
+        attn = self.softmax(Q @ K.transpose(2, 3) / d_in**0.5)
+ out = attn @ V
+ out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
+ return out, attn
+
+
+class SkipLayerNorm(nn.Module):
+ def __init__(self, to_len, to_dim):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm((to_len, to_dim))
+
+ def forward(self, x_0, x_1):
+ return self.layer_norm(x_0 + x_1)
+
+
+class FFN(nn.Module):
+ def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
+ super().__init__()
+ self.FFN = nn.Sequential(
+ nn.Linear(to_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, to_dim),
+ nn.Dropout(dropout_rate),
+ )
+
+ def forward(self, X):
+ return self.FFN(X)
+
+
+class AttentionBlock(nn.Module):
+ def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
+ super().__init__()
+ self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim)
+ self.split_heads = SplitHeads(num_heads)
+ self.attention = Attention(latent_dim, to_dim)
+ self.skip = SkipLayerNorm(to_len, to_dim)
+
+ def forward(self, X_to, X_from):
+ Q, K, V = self.tokens_to_qkv(X_to, X_from)
+ Q, K, V = self.split_heads(Q, K, V)
+ out, attention = self.attention(Q, K, V)
+ out = self.skip(X_to, out)
+ return out
+
+
+class EncoderTransformerBlock(nn.Module):
+ def __init__(self, to_dim, to_len, latent_dim, num_heads):
+ super().__init__()
+ self.attention_block = AttentionBlock(
+ to_dim, to_len, to_dim, latent_dim, num_heads
+ )
+ self.FFN = FFN(to_dim, 4 * to_dim)
+ self.skip = SkipLayerNorm(to_len, to_dim)
+
+ def forward(self, X_to):
+ X_to = self.attention_block(X_to, X_to)
+ X_out = self.FFN(X_to)
+ return self.skip(X_out, X_to)
+
+
+class DecoderTransformerBlock(nn.Module):
+ def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
+ super().__init__()
+ self.attention_block = AttentionBlock(
+ to_dim, to_len, from_dim, latent_dim, num_heads
+ )
+ self.encoder_block = EncoderTransformerBlock(
+ to_dim, to_len, latent_dim, num_heads
+ )
+
+ def forward(self, X_to, X_from):
+ X_to = self.attention_block(X_to, X_from)
+ X_to = self.encoder_block(X_to)
+ return X_to
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads):
+ super().__init__()
+ self.encoder = nn.ModuleList(
+ [
+ EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads)
+ for i in range(num_blocks)
+ ]
+ )
+
+ def forward(self, X_to):
+ for i in range(len(self.encoder)):
+ X_to = self.encoder[i](X_to)
+ return X_to
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads):
+ super().__init__()
+ self.decoder = nn.ModuleList(
+ [
+ DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads)
+ for i in range(num_blocks)
+ ]
+ )
+
+ def forward(self, X_to, X_from):
+ for i in range(len(self.decoder)):
+ X_to = self.decoder[i](X_to, X_from)
+ return X_to
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
+ ):
+ super().__init__()
+ self.encoder = TransformerEncoder(
+ num_blocks, to_dim, to_len, latent_dim, num_heads
+ )
+ self.decoder = TransformerDecoder(
+ num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads
+ )
+
+ def forward(self, X_to, X_from):
+ X_to = self.encoder(X_to)
+ X_out = self.decoder(X_from, X_to)
+ return X_out
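+
+
+if __name__ == "__main__":
+    # Minimal smoke test, for illustration only: run a random batch through the
+    # Transformer defined above and check the output shape. The dimensions below
+    # simply mirror the ones used in the accompanying notebook.
+    import torch
+
+    to_dim, to_len = 5, 3
+    from_dim, from_len = 6, 10
+    latent_dim, num_heads, num_blocks, batch_size = 10, 2, 6, 8
+    model = Transformer(
+        num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
+    )
+    X_to = torch.randn(batch_size, to_len, to_dim)
+    X_from = torch.randn(batch_size, from_len, from_dim)
+    out = model(X_to, X_from)
+    # the decoder stream keeps the "from" sequence length and feature dimension
+    assert out.shape == (batch_size, from_len, from_dim)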
diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb
new file mode 100644
index 00000000000..71b2baf68d0
--- /dev/null
+++ b/tutorials/tensordictmodule.ipynb
@@ -0,0 +1,1214 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "3be0fafd",
+ "metadata": {},
+ "source": [
+ "# TensorDictModule"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "94bd315a",
+ "metadata": {},
+ "source": [
+    "We recommend reading the TensorDict tutorial before going through this one."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bbc7e457-48b5-42d2-a8cf-092f0419d2d4",
+ "metadata": {},
+ "source": [
+    "For a convenient use of the `TensorDict` class with `nn.Module`, TorchRL provides an interface between the two, named `TensorDictModule`.\n",
+    "The `TensorDictModule` class is an `nn.Module` that takes a `TensorDict` as input when called.\n",
+    "It is up to the user to define the keys to be read as input and output."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "129a6de9-cf97-4565-a229-c05ad18df882",
+ "metadata": {},
+ "source": [
+ "## `TensorDictModule` by examples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "5b0241ab",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "from torchrl.data import TensorDict\n",
+ "from torchrl.modules import TensorDictModule, TensorDictSequence"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9d1c188a",
+ "metadata": {},
+ "source": [
+ "### Example 1: Simple usage"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1d21a711",
+ "metadata": {},
+ "source": [
+    "We have a `TensorDict` with 2 entries, `\"a\"` and `\"b\"`, but only the value associated with `\"a\"` needs to be read by the network."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "6f33781f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "TensorDict(\n",
+ " fields={\n",
+ " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n",
+ " a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),\n",
+ " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([5]),\n",
+ " device=cpu,\n",
+ " is_shared=False)\n"
+ ]
+ }
+ ],
+ "source": [
+ "tensordict = TensorDict(\n",
+ " {\"a\": torch.randn(5, 3), \"b\": torch.zeros(5, 4, 3)},\n",
+ " batch_size=[5],\n",
+ ")\n",
+ "linear = TensorDictModule(\n",
+ " nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"a_out\"]\n",
+ ")\n",
+ "linear(tensordict)\n",
+ "assert (tensordict.get(\"b\") == 0).all()\n",
+ "print(tensordict)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "00035cbd",
+ "metadata": {},
+ "source": [
+ "### Example 2: Multiple inputs"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "06a20c22",
+ "metadata": {},
+ "source": [
+ "Suppose we have a slightly more complex network that takes 2 entries and averages them into a single output tensor. To make a `TensorDictModule` instance read multiple input values, one must register them in the `in_keys` keyword argument of the constructor."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "69098393",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class MergeLinear(nn.Module):\n",
+ " def __init__(self, in_1, in_2, out):\n",
+ " super().__init__()\n",
+ " self.linear_1 = nn.Linear(in_1, out)\n",
+ " self.linear_2 = nn.Linear(in_2, out)\n",
+ "\n",
+ " def forward(self, x_1, x_2):\n",
+ " return (self.linear_1(x_1) + self.linear_2(x_2)) / 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "2dd686bb",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TensorDict(\n",
+ " fields={\n",
+ " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n",
+ " b: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n",
+ " output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([5]),\n",
+ " device=cpu,\n",
+ " is_shared=False)"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tensordict = TensorDict(\n",
+ " {\n",
+ " \"a\": torch.randn(5, 3),\n",
+ " \"b\": torch.randn(5, 4),\n",
+ " },\n",
+ " batch_size=[5],\n",
+ ")\n",
+ "\n",
+ "mergelinear = TensorDictModule(\n",
+ " MergeLinear(3, 4, 10), in_keys=[\"a\", \"b\"], out_keys=[\"output\"]\n",
+ ")\n",
+ "\n",
+ "mergelinear(tensordict)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "11256ae7",
+ "metadata": {},
+ "source": [
+ "### Example 3: Multiple outputs\n",
+ "Similarly, `TensorDictModule` not only supports multiple inputs but also multiple outputs. To make a `TensorDictModule` instance write to multiple output values, one must register them in the `out_keys` keyword argument of the constructor."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "0b7f709b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class MultiHeadLinear(nn.Module):\n",
+ " def __init__(self, in_1, out_1, out_2):\n",
+ " super().__init__()\n",
+ " self.linear_1 = nn.Linear(in_1, out_1)\n",
+ " self.linear_2 = nn.Linear(in_1, out_2)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return self.linear_1(x), self.linear_2(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "1b2b465f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TensorDict(\n",
+ " fields={\n",
+ " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n",
+ " output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n",
+ " output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([5]),\n",
+ " device=cpu,\n",
+ " is_shared=False)"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n",
+ "\n",
+ "splitlinear = TensorDictModule(\n",
+ " MultiHeadLinear(3, 4, 10),\n",
+ " in_keys=[\"a\"],\n",
+ " out_keys=[\"output_1\", \"output_2\"],\n",
+ ")\n",
+ "splitlinear(tensordict)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "859630c3",
+ "metadata": {},
+ "source": [
+    "When using multiple input and output keys, make sure their order matches the order of the wrapped module's arguments and return values.\n",
+    "\n",
+    "`TensorDictModule` can work with `TensorDict` instances that contain more tensors than what the `in_keys` attribute indicates. \n",
+    "\n",
+    "Unless a `vmap` operator is used, the `TensorDict` is modified in-place, as the quick check below shows."
+ ]
+ },
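+  {
+   "cell_type": "markdown",
+   "id": "a1f3c2d4-inplace-check",
+   "metadata": {},
+   "source": [
+    "As a quick, illustrative check of the in-place behaviour, the tensordict returned by the module is the very same object that was passed in:\n",
+    "\n",
+    "```python\n",
+    "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n",
+    "module = TensorDictModule(nn.Linear(3, 4), in_keys=[\"a\"], out_keys=[\"a_out\"])\n",
+    "out = module(tensordict)\n",
+    "assert out is tensordict  # same object, modified in-place\n",
+    "assert \"a_out\" in tensordict.keys()\n",
+    "```"
+   ]
+  },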
+ {
+ "cell_type": "markdown",
+ "id": "11d2d2a7-6a55-4f31-972b-041be387f9df",
+ "metadata": {},
+ "source": [
+    "### Example 4: Combining multiple `TensorDictModule` instances with `TensorDictSequence`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "89b157d5-322c-45d6-bec9-20440b78a2bf",
+ "metadata": {},
+ "source": [
+    "To combine multiple `TensorDictModule` instances, we can use `TensorDictSequence`: we provide the modules in the order in which they must be executed, and `TensorDictSequence` will read and write keys to the tensordict following that sequence of modules.\n",
+    "\n",
+    "We can also gather the inputs needed by a `TensorDictSequence` through its `in_keys` property, while its output keys are found in the `out_keys` attribute (see the short check after the example below)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "7e36d071-df67-4232-a8a9-78e79b32fef2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n",
+ "\n",
+ "splitlinear = TensorDictModule(\n",
+ " MultiHeadLinear(3, 4, 10),\n",
+ " in_keys=[\"a\"],\n",
+ " out_keys=[\"output_1\", \"output_2\"],\n",
+ ")\n",
+ "mergelinear = TensorDictModule(\n",
+ " MergeLinear(4, 10, 13),\n",
+ " in_keys=[\"output_1\", \"output_2\"],\n",
+ " out_keys=[\"output\"],\n",
+ ")\n",
+ "\n",
+ "split_and_merge_linear = TensorDictSequence(splitlinear, mergelinear)\n",
+ "\n",
+ "assert split_and_merge_linear(tensordict)[\n",
+ " \"output\"\n",
+ "].shape == torch.Size([5, 13])"
+ ]
+ },
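+  {
+   "cell_type": "markdown",
+   "id": "b2e4d6f8-seq-keys-check",
+   "metadata": {},
+   "source": [
+    "As mentioned above, we can ask the sequence which keys it needs and which keys it produces. A short sketch, reusing `split_and_merge_linear` from the cell above (the exact ordering of the keys may differ):\n",
+    "\n",
+    "```python\n",
+    "print(split_and_merge_linear.in_keys)   # expected to contain \"a\"\n",
+    "print(split_and_merge_linear.out_keys)  # expected to contain \"output_1\", \"output_2\" and \"output\"\n",
+    "```"
+   ]
+  },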
+ {
+ "cell_type": "markdown",
+ "id": "760118ea",
+ "metadata": {},
+ "source": [
+ "### Example 5: Compatibility with functorch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e2718a12",
+ "metadata": {},
+ "source": [
+    "`TensorDictModule` comes with its own `make_functional_with_buffers` method to make it functional (you should not use `functorch.make_functional_with_buffers(tensordictmodule)` directly, as it will not work in general)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "b553bed1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TensorDict(\n",
+ " fields={\n",
+ " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n",
+ " output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n",
+ " output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([5]),\n",
+ " device=cpu,\n",
+ " is_shared=False)"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n",
+ "\n",
+ "splitlinear = TensorDictModule(\n",
+ " MultiHeadLinear(3, 4, 10),\n",
+ " in_keys=[\"a\"],\n",
+ " out_keys=[\"output_1\", \"output_2\"],\n",
+ ")\n",
+ "func, (params, buffers) = splitlinear.make_functional_with_buffers()\n",
+ "func(tensordict, params=params, buffers=buffers)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "50ac0393",
+ "metadata": {},
+ "source": [
+    "We can also use the `vmap` operator. Here is an example of model ensembling with it:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "86ccb7be",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "the output tensordict shape is: torch.Size([10, 5])\n"
+ ]
+ }
+ ],
+ "source": [
+ "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n",
+ "num_models = 10\n",
+ "model = TensorDictModule(\n",
+ " nn.Linear(3, 4), in_keys=[\"a\"], out_keys=[\"output\"]\n",
+ " )\n",
+ "fmodel, (params, buffers) = model.make_functional_with_buffers()\n",
+ "params = [torch.randn(num_models, *p.shape, device=p.device) for p in params]\n",
+ "buffers = [torch.randn(num_models, *b.shape, device=b.device) for b in buffers]\n",
+ "result_td = fmodel(tensordict, params=params, buffers=buffers, vmap=True)\n",
+ "print(\"the output tensordict shape is: \", result_td.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "31be6c45-10fb-4fd1-a52f-92214b76c00a",
+ "metadata": {},
+ "source": [
+    "## Do's and don'ts with `TensorDictModule`\n",
+    "\n",
+    "Don't use plain `nn.Module` wrappers around `TensorDictModule` components. This would break some of the `TensorDictModule` features, such as `functorch` compatibility. \n",
+    "\n",
+    "Don't use `nn.Sequential` to chain `TensorDictModule` instances: as with plain `nn.Module` wrappers, it would break features such as `functorch` compatibility. Do use `TensorDictSequence` instead (see the sketch below).\n",
+    "\n",
+    "Don't assign the output tensordict to a new variable, as the output tensordict is just the input modified in-place:\n",
+    "\n",
+    "```python\n",
+    "tensordict = module(tensordict)  # ok!\n",
+    "tensordict_out = module(tensordict)  # don't!\n",
+    "```\n",
+    "\n",
+    "Don't use `make_functional_with_buffers` from `functorch` directly; use `TensorDictModule.make_functional_with_buffers` instead.\n",
+    "\n",
+    "\n"
+ ]
+ },
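+  {
+   "cell_type": "markdown",
+   "id": "c3f5e7a9-seq-sketch",
+   "metadata": {},
+   "source": [
+    "For instance, to chain two `TensorDictModule` instances, a minimal sketch (with hypothetical key names) would be:\n",
+    "\n",
+    "```python\n",
+    "module_1 = TensorDictModule(nn.Linear(3, 4), in_keys=[\"a\"], out_keys=[\"b\"])\n",
+    "module_2 = TensorDictModule(nn.Linear(4, 5), in_keys=[\"b\"], out_keys=[\"c\"])\n",
+    "seq = TensorDictSequence(module_1, module_2)  # do this\n",
+    "# seq = nn.Sequential(module_1, module_2)     # don't: breaks functorch compatibility\n",
+    "```"
+   ]
+  },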
+ {
+ "cell_type": "markdown",
+ "id": "22e65356-d8b3-4197-84b8-598330c1ddc8",
+ "metadata": {},
+ "source": [
+ "## TensorDictModule for RL"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8d49a911-933c-476f-8c9a-00e006ed043c",
+ "metadata": {},
+ "source": [
+    "TorchRL provides a few RL-specific `TensorDictModule` subclasses that serve domain-specific needs."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e33904a6-d405-45db-a713-47493ca8ee33",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "### `ProbabilisticTensorDictModule`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fea4eead-47b4-4029-a8ff-e3c3faf51b0f",
+ "metadata": {},
+ "source": [
+    "`ProbabilisticTensorDictModule` is a special case of a `TensorDictModule` where the output is\n",
+    "sampled given some rule, specified by the input `default_interaction_mode`\n",
+    "argument and the `exploration_mode()` global function. If they conflict, the context manager takes precedence.\n",
+    "\n",
+    "It consists of a wrapper around another `TensorDictModule` that returns a tensordict\n",
+    "updated with the distribution parameters. `ProbabilisticTensorDictModule` is\n",
+    "responsible for constructing the distribution (through the `get_dist()` method)\n",
+    "and/or sampling from this distribution (through a regular `__call__()` to the\n",
+    "module).\n",
+    "\n",
+    "One can find the distribution parameters in the output tensordict, as well as the log-probability if needed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "9dd7846a-f12c-492e-a2ef-b0c67969234d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "TensorDict before going through module: TensorDict(\n",
+ " fields={\n",
+ " hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),\n",
+ " input: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([3]),\n",
+ " device=cpu,\n",
+ " is_shared=False)\n",
+ "TensorDict after going through module now as keys action, loc and scale: TensorDict(\n",
+ " fields={\n",
+ " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),\n",
+ " input: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n",
+ " scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([3]),\n",
+ " device=cpu,\n",
+ " is_shared=False)\n"
+ ]
+ }
+ ],
+ "source": [
+ "from torchrl.modules import ProbabilisticTensorDictModule\n",
+ "from torchrl.modules import TanhNormal, NormalParamWrapper\n",
+ "import functorch\n",
+ "td = TensorDict({\"input\": torch.randn(3, 4), \"hidden\": torch.randn(3, 8)}, [3,])\n",
+ "net = NormalParamWrapper(torch.nn.GRUCell(4, 8))\n",
+ "module = TensorDictModule(net, in_keys=[\"input\", \"hidden\"], out_keys=[\"loc\", \"scale\"])\n",
+ "td_module = ProbabilisticTensorDictModule(\n",
+ " module=module,\n",
+ " dist_param_keys=[\"loc\", \"scale\"],\n",
+ " out_key_sample=[\"action\"],\n",
+ " distribution_class=TanhNormal,\n",
+ " return_log_prob=True,\n",
+ " )\n",
+ "print(f\"TensorDict before going through module: {td}\")\n",
+ "td_module(td)\n",
+ "print(f\"TensorDict after going through module now as keys action, loc and scale: {td}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "406b1caa-bcec-4317-b685-10df23352154",
+ "metadata": {},
+ "source": [
+ "### `Actor`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e139de7d-0250-49c0-b495-8b5a404821f5",
+ "metadata": {},
+ "source": [
+    "`Actor` inherits from `TensorDictModule` and comes with a default value for `out_keys` of `[\"action\"]` (see the sketch below).\n"
+ ]
+ },
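+  {
+   "cell_type": "markdown",
+   "id": "d4a6b8c0-actor-sketch",
+   "metadata": {},
+   "source": [
+    "A minimal sketch of what this looks like in practice (assuming `Actor` is imported from `torchrl.modules` and that `in_keys` defaults to `[\"observation\"]`):\n",
+    "\n",
+    "```python\n",
+    "from torchrl.modules import Actor\n",
+    "\n",
+    "td = TensorDict({\"observation\": torch.randn(3, 4)}, batch_size=[3])\n",
+    "actor = Actor(nn.Linear(4, 2))  # out_keys defaults to [\"action\"]\n",
+    "actor(td)\n",
+    "assert \"action\" in td.keys()\n",
+    "```"
+   ]
+  },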
+ {
+ "cell_type": "markdown",
+ "id": "cceeade9-47f1-4e92-897a-dd226c9371a6",
+ "metadata": {},
+ "source": [
+ "### `ProbabilisticActor`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4fd0f53e-90aa-49a9-9d8f-5a260255e556",
+ "metadata": {},
+ "source": [
+    "`ProbabilisticActor` is the general class for probabilistic actors in RL; it inherits from `ProbabilisticTensorDictModule`.\n",
+    "Similarly to `Actor`, it comes with a default value for `out_keys` (`[\"action\"]`).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dbd48bb2-b93b-4766-b7a7-19d500f17e2d",
+ "metadata": {},
+ "source": [
+ "### `ActorCriticOperator`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8cc42407-4e95-4bf0-8901-5d1a4e3b2044",
+ "metadata": {},
+ "source": [
+    "Similarly, `ActorCriticOperator` inherits from `TensorDictSequence` and wraps both an actor network and a value network.\n",
+ "\n",
+ "`ActorCriticOperator` will first compute the action from the actor and then the value according to this action."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "5b6c6035-f9cc-41e7-bf3a-f88936f93b70",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "TensorDict(\n",
+ " fields={\n",
+ " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([3]),\n",
+ " device=cpu,\n",
+ " is_shared=False)\n",
+ "TensorDict(\n",
+ " fields={\n",
+ " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n",
+ " scale: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([3]),\n",
+ " device=cpu,\n",
+ " is_shared=False)\n",
+ "Policy: TensorDict(\n",
+ " fields={\n",
+ " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n",
+ " scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([3]),\n",
+ " device=cpu,\n",
+ " is_shared=False)\n",
+ "Critic: TensorDict(\n",
+ " fields={\n",
+ " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n",
+ " scale: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+ " state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([3]),\n",
+ " device=cpu,\n",
+ " is_shared=False)\n"
+ ]
+ }
+ ],
+ "source": [
+ "from torchrl.modules import (\n",
+ " MLP,\n",
+ " ActorCriticOperator,\n",
+ " NormalParamWrapper,\n",
+ " TanhNormal,\n",
+ " ValueOperator,\n",
+ ")\n",
+ "from torchrl.modules.tensordict_module import ProbabilisticActor\n",
+ "\n",
+ "module_hidden = torch.nn.Linear(4, 4)\n",
+ "td_module_hidden = TensorDictModule(\n",
+ " module=module_hidden,\n",
+ " in_keys=[\"observation\"],\n",
+ " out_keys=[\"hidden\"],\n",
+ ")\n",
+ "module_action = NormalParamWrapper(torch.nn.Linear(4, 8))\n",
+ "module_action = TensorDictModule(\n",
+ " module_action, in_keys=[\"hidden\"], out_keys=[\"loc\", \"scale\"]\n",
+ ")\n",
+ "td_module_action = ProbabilisticActor(\n",
+ " module=module_action,\n",
+ " dist_param_keys=[\"loc\", \"scale\"],\n",
+ " out_key_sample=[\"action\"],\n",
+ " distribution_class=TanhNormal,\n",
+ " return_log_prob=True,\n",
+ ")\n",
+ "module_value = MLP(in_features=8, out_features=1, num_cells=[])\n",
+ "td_module_value = ValueOperator(\n",
+ " module=module_value,\n",
+ " in_keys=[\"hidden\", \"action\"],\n",
+ " out_keys=[\"state_action_value\"],\n",
+ ")\n",
+ "td_module = ActorCriticOperator(\n",
+ " td_module_hidden, td_module_action, td_module_value\n",
+ ")\n",
+ "td = TensorDict(\n",
+ " {\"observation\": torch.randn(3, 4)},\n",
+ " [\n",
+ " 3,\n",
+ " ],\n",
+ ")\n",
+ "print(td)\n",
+ "td_clone = td_module(td.clone())\n",
+ "print(td_clone)\n",
+ "td_clone = td_module.get_policy_operator()(td.clone())\n",
+ "print(f\"Policy: {td_clone}\") # no value\n",
+ "td_clone = td_module.get_critic_operator()(td.clone())\n",
+ "print(f\"Critic: {td_clone}\") # no action"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "11d0f8ea-0292-4ca0-9460-2a2149f7aeef",
+ "metadata": {},
+ "source": [
+    "Other blocks exist, such as:\n",
+    "\n",
+    "- `ValueOperator`, a general class for value functions in RL;\n",
+    "- `ActorCriticWrapper`, which wraps together an actor and a value model that do not share a common observation embedding network;\n",
+    "- `ActorValueOperator`, which wraps together an actor and a value model that share a common observation embedding network."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6304a098",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## Showcase: Implementing a transformer using TensorDictModule\n",
+ "To demonstrate the flexibility of `TensorDictModule`, we are going to create a transformer that reads `TensorDict` objects using `TensorDictModule`.\n",
+ "\n",
+    "The following figure shows the classical transformer architecture (Vaswani et al., 2017).\n",
+    "\n",
+    "![The transformer architecture](media/transformer.png)\n",
+    "\n",
+    "We have left the positional encoding aside for simplicity.\n",
+    "\n",
+    "Let's first import the classical transformer blocks (see `src/transformer.py` for more details)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "e1f7ba7b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tutorials.src.transformer import (\n",
+ " FFN,\n",
+ " Attention,\n",
+ " SkipLayerNorm,\n",
+ " SplitHeads,\n",
+ " TokensToQKV,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c3258540-acb2-4090-a374-822dfcb857bd",
+ "metadata": {},
+ "source": [
+    "We first create `AttentionBlockTensorDict`, an attention block built with `TensorDictModule` and `TensorDictSequence`.\n",
+    "\n",
+    "The wiring operation that connects the modules to each other requires us to indicate which key each of them must read and write. Unlike `nn.Sequential`, a `TensorDictSequence` can read/write more than one input/output. Moreover, its components' inputs need not be identical to the previous layers' outputs, which allows us to code complicated neural architectures."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "eb9775bd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class AttentionBlockTensorDict(TensorDictSequence):\n",
+ " def __init__(\n",
+ " self,\n",
+ " to_name,\n",
+ " from_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " ):\n",
+ " super().__init__(\n",
+ " TensorDictModule(\n",
+ " TokensToQKV(to_dim, from_dim, latent_dim),\n",
+ " in_keys=[to_name, from_name],\n",
+ " out_keys=[\"Q\", \"K\", \"V\"],\n",
+ " ),\n",
+ " TensorDictModule(\n",
+ " SplitHeads(num_heads),\n",
+ " in_keys=[\"Q\", \"K\", \"V\"],\n",
+ " out_keys=[\"Q\", \"K\", \"V\"],\n",
+ " ),\n",
+ " TensorDictModule(\n",
+ " Attention(latent_dim, to_dim),\n",
+ " in_keys=[\"Q\", \"K\", \"V\"],\n",
+ " out_keys=[\"X_out\", \"Attn\"],\n",
+ " ),\n",
+ " TensorDictModule(\n",
+ " SkipLayerNorm(to_len, to_dim),\n",
+ " in_keys=[to_name, \"X_out\"],\n",
+ " out_keys=[to_name],\n",
+ " ),\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b5f6f291",
+ "metadata": {},
+ "source": [
+ "We build the encoder and decoder blocks that will be part of the transformer thanks to `TensorDictModule`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "f902006d-3f89-4ea6-84e0-a193a53e42db",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class TransformerBlockEncoderTensorDict(TensorDictSequence):\n",
+ " def __init__(\n",
+ " self,\n",
+ " to_name,\n",
+ " from_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " ):\n",
+ " super().__init__(\n",
+ " AttentionBlockTensorDict(\n",
+ " to_name,\n",
+ " from_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " ),\n",
+ " TensorDictModule(\n",
+ " FFN(to_dim, 4 * to_dim),\n",
+ " in_keys=[to_name],\n",
+ " out_keys=[\"X_out\"],\n",
+ " ),\n",
+ " TensorDictModule(\n",
+ " SkipLayerNorm(to_len, to_dim),\n",
+ " in_keys=[to_name, \"X_out\"],\n",
+ " out_keys=[to_name],\n",
+ " ),\n",
+ " )\n",
+ "\n",
+ "\n",
+ "class TransformerBlockDecoderTensorDict(TensorDictSequence):\n",
+ " def __init__(\n",
+ " self,\n",
+ " to_name,\n",
+ " from_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " ):\n",
+ " super().__init__(\n",
+ " AttentionBlockTensorDict(\n",
+ " to_name,\n",
+ " to_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " to_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " ),\n",
+ " TransformerBlockEncoderTensorDict(\n",
+ " to_name,\n",
+ " from_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " ),\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "42dbfae5",
+ "metadata": {},
+ "source": [
+    "We create the transformer encoder and decoder.\n",
+    "\n",
+    "For an encoder, we just need to take the same tokens for queries, keys and values.\n",
+    "\n",
+    "For a decoder, we can now extract information from `X_from` into `X_to`: `X_to` will map to the queries whereas `X_from` will map to the keys and values."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "1c6c85b5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class TransformerEncoderTensorDict(TensorDictSequence):\n",
+ " def __init__(\n",
+ " self,\n",
+ " num_blocks,\n",
+ " to_name,\n",
+ " from_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " ):\n",
+ " super().__init__(\n",
+ " *[\n",
+ " TransformerBlockEncoderTensorDict(\n",
+ " to_name,\n",
+ " from_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " )\n",
+ " for _ in range(num_blocks)\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ "\n",
+ "class TransformerDecoderTensorDict(TensorDictSequence):\n",
+ " def __init__(\n",
+ " self,\n",
+ " num_blocks,\n",
+ " to_name,\n",
+ " from_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " ):\n",
+ " super().__init__(\n",
+ " *[\n",
+ " TransformerBlockDecoderTensorDict(\n",
+ " to_name,\n",
+ " from_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " )\n",
+ " for _ in range(num_blocks)\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ "\n",
+ "class TransformerTensorDict(TensorDictSequence):\n",
+ " def __init__(\n",
+ " self,\n",
+ " num_blocks,\n",
+ " to_name,\n",
+ " from_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " from_len,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " ):\n",
+ " super().__init__(\n",
+ " TransformerEncoderTensorDict(\n",
+ " num_blocks,\n",
+ " to_name,\n",
+ " to_name,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " to_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " ),\n",
+ " TransformerDecoderTensorDict(\n",
+ " num_blocks,\n",
+ " from_name,\n",
+ " to_name,\n",
+ " from_dim,\n",
+ " from_len,\n",
+ " to_dim,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ " ),\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "15b1b4e2-918d-40bc-a245-15be0e9cc276",
+ "metadata": {},
+ "source": [
+    "We now test our new `TransformerTensorDict`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "7a680452-1462-4ee6-ba04-dce0bb855870",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TensorDict(\n",
+ " fields={\n",
+ " Attn: Tensor(torch.Size([8, 2, 10, 3]), dtype=torch.float32),\n",
+ " K: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),\n",
+ " Q: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),\n",
+ " V: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),\n",
+ " X_decode: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),\n",
+ " X_encode: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),\n",
+ " X_out: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([8]),\n",
+ " device=cpu,\n",
+ " is_shared=False)"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "to_dim = 5\n",
+ "from_dim = 6\n",
+ "latent_dim = 10\n",
+ "to_len = 3\n",
+ "from_len = 10\n",
+ "batch_size = 8\n",
+ "num_heads = 2\n",
+ "num_blocks = 6\n",
+ "\n",
+ "tokens = TensorDict(\n",
+ " {\n",
+ " \"X_encode\": torch.randn(batch_size, to_len, to_dim),\n",
+ " \"X_decode\": torch.randn(batch_size, from_len, from_dim),\n",
+ " },\n",
+ " batch_size=[batch_size],\n",
+ ")\n",
+ "\n",
+ "transformer = TransformerTensorDict(\n",
+ " num_blocks,\n",
+ " \"X_encode\",\n",
+ " \"X_decode\",\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " from_len,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ ")\n",
+ "\n",
+ "transformer(tokens)\n",
+ "tokens"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3f6448dd-5d0d-43fd-9e57-a0ac3b30ecba",
+ "metadata": {},
+ "source": [
+    "We have managed to create a transformer with `TensorDictModule`. This shows that `TensorDictModule` is a flexible module that can implement complex operations."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bb30fb1b-ef8f-4638-af44-69374dd9cfe9",
+ "metadata": {},
+ "source": [
+ "### Benchmarking"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "f75eb50b-b5c4-47ef-9e33-4fa6dfb489ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tutorials.src.transformer import Transformer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "c4ff0abf-1f01-45bd-9dfc-cd26374137c7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "to_dim = 5\n",
+ "from_dim = 6\n",
+ "latent_dim = 10\n",
+ "to_len = 3\n",
+ "from_len = 10\n",
+ "batch_size = 8\n",
+ "num_heads = 2\n",
+ "num_blocks = 6"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "3e08ff04-1086-4315-bf5e-caa960183c94",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "td_tokens = TensorDict(\n",
+ " {\n",
+ " \"X_encode\": torch.randn(batch_size, to_len, to_dim),\n",
+ " \"X_decode\": torch.randn(batch_size, from_len, from_dim),\n",
+ " },\n",
+ " batch_size=[batch_size],\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "665c4168-9ac8-45e5-98bc-6e5cc511a209",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X_encode = torch.randn(batch_size, to_len, to_dim)\n",
+ "X_decode = torch.randn(batch_size, from_len, from_dim)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "f3c2fd50-bc9b-4911-bd7c-8f8f03bd4ea4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tdtransformer = TransformerTensorDict(\n",
+ " num_blocks,\n",
+ " \"X_encode\",\n",
+ " \"X_decode\",\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " from_len,\n",
+ " latent_dim,\n",
+ " num_heads,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "dfbadd6b-7847-4399-9b22-7e5c58524334",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "transformer = Transformer(\n",
+ " num_blocks,\n",
+ " to_dim,\n",
+ " to_len,\n",
+ " from_dim,\n",
+ " from_len,\n",
+ " latent_dim,\n",
+ " num_heads\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6a63de8f-ee8e-4ddf-bf89-f72c2896e1c3",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "#### Inference time"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "02a4116b-2b75-47fc-8bc1-3903aa7cd504",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 6.25 ms, sys: 6.73 ms, total: 13 ms\n",
+ "Wall time: 7.57 ms\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+    "tokens = tdtransformer(td_tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "40158aab-b53a-4a99-82cb-f5595eef7159",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 5.32 ms, sys: 9.24 ms, total: 14.6 ms\n",
+ "Wall time: 7.15 ms\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "X_out = transformer(X_encode, X_decode)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "664adff3-1466-47c3-9a80-a0f26171addd",
+ "metadata": {},
+ "source": [
+    "We can see from this minimal example that the overhead introduced by `TensorDictModule` is marginal."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bd08362a-8bb8-49fb-8038-1a60c5c01ea2",
+ "metadata": {},
+ "source": [
+ "Have fun with TensorDictModule!"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}