diff --git a/tutorials/README.md b/tutorials/README.md
index 7ac2d8a1765..516f8d79540 100644
--- a/tutorials/README.md
+++ b/tutorials/README.md
@@ -4,8 +4,10 @@ Get a sense of TorchRL functionalities through our tutorials.
 
 For an overview of TorchRL, try the [TorchRL demo](demo.ipynb).
 
-Make sure you test the [TensorDict demo](tensordict.ipynb) to see what TensorDict
+Make sure you check out the [TensorDict tutorial](tensordict.ipynb) to see what TensorDict
 is about and what it can do.
 
-Checkout the [environment demo](envs.ipynb) for a deep dive in the envs
+To understand how to use `TensorDict` with PyTorch modules, make sure to check out the [TensorDictModule tutorial](tensordictmodule.ipynb).
+
+Check out the [environment tutorial](envs.ipynb) for a deep dive into the envs
 functionalities.
diff --git a/tutorials/media/transformer.png b/tutorials/media/transformer.png
new file mode 100644
index 00000000000..b18ca7b93e3
Binary files /dev/null and b/tutorials/media/transformer.png differ
diff --git a/tutorials/src/transformer.py b/tutorials/src/transformer.py
new file mode 100644
index 00000000000..352f08b9f44
--- /dev/null
+++ b/tutorials/src/transformer.py
@@ -0,0 +1,170 @@
+import torch.nn as nn
+
+
+class TokensToQKV(nn.Module):
+    def __init__(self, to_dim, from_dim, latent_dim):
+        super().__init__()
+        self.q = nn.Linear(to_dim, latent_dim)
+        self.k = nn.Linear(from_dim, latent_dim)
+        self.v = nn.Linear(from_dim, latent_dim)
+
+    def forward(self, X_to, X_from):
+        Q = self.q(X_to)
+        K = self.k(X_from)
+        V = self.v(X_from)
+        return Q, K, V
+
+
+class SplitHeads(nn.Module):
+    def __init__(self, num_heads):
+        super().__init__()
+        self.num_heads = num_heads
+
+    def forward(self, Q, K, V):
+        batch_size, to_num, latent_dim = Q.shape
+        _, from_num, _ = K.shape
+        d_tensor = latent_dim // self.num_heads
+        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
+        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
+        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
+        return Q, K, V
+
+
+class Attention(nn.Module):
+    def __init__(self, latent_dim, to_dim):
+        super().__init__()
+        self.softmax = nn.Softmax(dim=-1)
+        self.out = nn.Linear(latent_dim, to_dim)
+
+    def forward(self, Q, K, V):
+        batch_size, n_heads, to_num, d_in = Q.shape
+        attn = self.softmax(Q @ K.transpose(2, 3) / d_in**0.5)  # scale by sqrt(d_head), as in Vaswani et al., 2017
+        out = attn @ V
+        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
+        return out, attn
+
+
+class SkipLayerNorm(nn.Module):
+    def __init__(self, to_len, to_dim):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm((to_len, to_dim))
+
+    def forward(self, x_0, x_1):
+        return self.layer_norm(x_0 + x_1)
+
+
+class FFN(nn.Module):
+    def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
+        super().__init__()
+        self.FFN = nn.Sequential(
+            nn.Linear(to_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, to_dim),
+            nn.Dropout(dropout_rate),
+        )
+
+    def forward(self, X):
+        return self.FFN(X)
+
+
+class AttentionBlock(nn.Module):
+    def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
+        super().__init__()
+        self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim)
+        self.split_heads = SplitHeads(num_heads)
+        self.attention = Attention(latent_dim, to_dim)
+        self.skip = SkipLayerNorm(to_len, to_dim)
+
+    def forward(self, X_to, X_from):
+        Q, K, V = self.tokens_to_qkv(X_to, X_from)
+        Q, K, V = self.split_heads(Q, K, V)
+        out, attention = self.attention(Q, K, V)
+        out = self.skip(X_to, out)
+        return out
+
+
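+# A quick shape sketch (sizes are made up for illustration): cross-attending
+# from a length-3 "from" sequence into a length-10 "to" sequence with 2 heads:
+#   block = AttentionBlock(to_dim=6, to_len=10, from_dim=5, latent_dim=10, num_heads=2)
+#   X_to, X_from = torch.randn(8, 10, 6), torch.randn(8, 3, 5)
+#   block(X_to, X_from).shape  # -> torch.Size([8, 10, 6]), attended and layer-normed
+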
+class EncoderTransformerBlock(nn.Module):
+    def __init__(self, to_dim, to_len, latent_dim, num_heads):
+        super().__init__()
+        self.attention_block = AttentionBlock(
+            to_dim, to_len, to_dim, latent_dim, num_heads
+        )
+        self.FFN = FFN(to_dim, 4 * to_dim)
+        self.skip = SkipLayerNorm(to_len, to_dim)
+
+    def forward(self, X_to):
+        X_to = self.attention_block(X_to, X_to)
+        X_out = self.FFN(X_to)
+        return self.skip(X_out, X_to)
+
+
+class DecoderTransformerBlock(nn.Module):
+    def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
+        super().__init__()
+        self.attention_block = AttentionBlock(
+            to_dim, to_len, from_dim, latent_dim, num_heads
+        )
+        self.encoder_block = EncoderTransformerBlock(
+            to_dim, to_len, latent_dim, num_heads
+        )
+
+    def forward(self, X_to, X_from):
+        X_to = self.attention_block(X_to, X_from)
+        X_to = self.encoder_block(X_to)
+        return X_to
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads):
+        super().__init__()
+        self.encoder = nn.ModuleList(
+            [
+                EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads)
+                for _ in range(num_blocks)
+            ]
+        )
+
+    def forward(self, X_to):
+        for block in self.encoder:
+            X_to = block(X_to)
+        return X_to
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads):
+        super().__init__()
+        self.decoder = nn.ModuleList(
+            [
+                DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads)
+                for _ in range(num_blocks)
+            ]
+        )
+
+    def forward(self, X_to, X_from):
+        for block in self.decoder:
+            X_to = block(X_to, X_from)
+        return X_to
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
+    ):
+        super().__init__()
+        self.encoder = TransformerEncoder(
+            num_blocks, to_dim, to_len, latent_dim, num_heads
+        )
+        self.decoder = TransformerDecoder(
+            num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads
+        )
+
+    def forward(self, X_to, X_from):
+        X_to = self.encoder(X_to)
+        X_out = self.decoder(X_from, X_to)
+        return X_out
diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb
new file mode 100644
index 00000000000..71b2baf68d0
--- /dev/null
+++ b/tutorials/tensordictmodule.ipynb
@@ -0,0 +1,1214 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "3be0fafd",
+   "metadata": {},
+   "source": [
+    "# TensorDictModule"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "94bd315a",
+   "metadata": {},
+   "source": [
+    "We recommend reading the TensorDict tutorial before going through this one."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bbc7e457-48b5-42d2-a8cf-092f0419d2d4",
+   "metadata": {},
+   "source": [
+    "To use the `TensorDict` class conveniently with `nn.Module`, TorchRL provides an interface between the two, named `TensorDictModule`.
\n", + "The `TensorDictModule` class is an `nn.Module` that takes a `TensorDict` as input when called.
\n", + "It is up to the user to define the keys to be read as input and output." + ] + }, + { + "cell_type": "markdown", + "id": "129a6de9-cf97-4565-a229-c05ad18df882", + "metadata": {}, + "source": [ + "## `TensorDictModule` by examples" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5b0241ab", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from torchrl.data import TensorDict\n", + "from torchrl.modules import TensorDictModule, TensorDictSequence" + ] + }, + { + "cell_type": "markdown", + "id": "9d1c188a", + "metadata": {}, + "source": [ + "### Example 1: Simple usage" + ] + }, + { + "cell_type": "markdown", + "id": "1d21a711", + "metadata": {}, + "source": [ + "We have a `TensorDict` with 2 entries `\"a\"` and `\"b\"` but only the value associated with `\"a\"` has to be read by the network." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "6f33781f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TensorDict(\n", + " fields={\n", + " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n", + " a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),\n", + " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},\n", + " batch_size=torch.Size([5]),\n", + " device=cpu,\n", + " is_shared=False)\n" + ] + } + ], + "source": [ + "tensordict = TensorDict(\n", + " {\"a\": torch.randn(5, 3), \"b\": torch.zeros(5, 4, 3)},\n", + " batch_size=[5],\n", + ")\n", + "linear = TensorDictModule(\n", + " nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"a_out\"]\n", + ")\n", + "linear(tensordict)\n", + "assert (tensordict.get(\"b\") == 0).all()\n", + "print(tensordict)" + ] + }, + { + "cell_type": "markdown", + "id": "00035cbd", + "metadata": {}, + "source": [ + "### Example 2: Multiple inputs" + ] + }, + { + "cell_type": "markdown", + "id": "06a20c22", + "metadata": {}, + "source": [ + "Suppose we have a slightly more complex network that takes 2 entries and averages them into a single output tensor. To make a `TensorDictModule` instance read multiple input values, one must register them in the `in_keys` keyword argument of the constructor." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "69098393",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class MergeLinear(nn.Module):\n",
+    "    def __init__(self, in_1, in_2, out):\n",
+    "        super().__init__()\n",
+    "        self.linear_1 = nn.Linear(in_1, out)\n",
+    "        self.linear_2 = nn.Linear(in_2, out)\n",
+    "\n",
+    "    def forward(self, x_1, x_2):\n",
+    "        return (self.linear_1(x_1) + self.linear_2(x_2)) / 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "2dd686bb",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "TensorDict(\n",
+       "    fields={\n",
+       "        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n",
+       "        b: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n",
+       "        output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n",
+       "    batch_size=torch.Size([5]),\n",
+       "    device=cpu,\n",
+       "    is_shared=False)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "tensordict = TensorDict(\n",
+    "    {\n",
+    "        \"a\": torch.randn(5, 3),\n",
+    "        \"b\": torch.randn(5, 4),\n",
+    "    },\n",
+    "    batch_size=[5],\n",
+    ")\n",
+    "\n",
+    "mergelinear = TensorDictModule(\n",
+    "    MergeLinear(3, 4, 10), in_keys=[\"a\", \"b\"], out_keys=[\"output\"]\n",
+    ")\n",
+    "\n",
+    "mergelinear(tensordict)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "11256ae7",
+   "metadata": {},
+   "source": [
+    "### Example 3: Multiple outputs\n",
+    "Similarly, `TensorDictModule` supports not only multiple inputs but also multiple outputs. To make a `TensorDictModule` instance write to multiple output values, one must register them in the `out_keys` keyword argument of the constructor."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "0b7f709b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class MultiHeadLinear(nn.Module):\n",
+    "    def __init__(self, in_1, out_1, out_2):\n",
+    "        super().__init__()\n",
+    "        self.linear_1 = nn.Linear(in_1, out_1)\n",
+    "        self.linear_2 = nn.Linear(in_1, out_2)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        return self.linear_1(x), self.linear_2(x)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "1b2b465f",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "TensorDict(\n",
+       "    fields={\n",
+       "        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n",
+       "        output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n",
+       "        output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n",
+       "    batch_size=torch.Size([5]),\n",
+       "    device=cpu,\n",
+       "    is_shared=False)"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n",
+    "\n",
+    "splitlinear = TensorDictModule(\n",
+    "    MultiHeadLinear(3, 4, 10),\n",
+    "    in_keys=[\"a\"],\n",
+    "    out_keys=[\"output_1\", \"output_2\"],\n",
+    ")\n",
+    "splitlinear(tensordict)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "859630c3",
+   "metadata": {},
+   "source": [
+    "When using multiple input and output keys, make sure their order matches the order of the wrapped module's arguments and return values.\n",
+    "\n",
+    "`TensorDictModule` can work with `TensorDict` instances that contain more tensors than what the `in_keys` attribute indicates.\n",
+    "\n",
+    "Unless a `vmap` operator is used, the `TensorDict` is modified in-place."
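+    ,
+    "\n",
+    "\n",
+    "For instance (a small sketch of the in-place behavior, reusing `linear` and `tensordict` from Example 1):\n",
+    "\n",
+    "```python\n",
+    "out = linear(tensordict)\n",
+    "assert out is tensordict  # the very same object, updated in-place\n",
+    "assert \"a_out\" in tensordict.keys()\n",
+    "```"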
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "11d2d2a7-6a55-4f31-972b-041be387f9df",
+   "metadata": {},
+   "source": [
+    "### Example 4: Combining multiple `TensorDictModule` instances with `TensorDictSequence`"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "89b157d5-322c-45d6-bec9-20440b78a2bf",
+   "metadata": {},
+   "source": [
+    "To combine multiple `TensorDictModule` instances, we can use `TensorDictSequence`. We pass it the list of `TensorDictModule` instances to be executed sequentially. `TensorDictSequence` will read and write keys to the tensordict following the sequence of modules provided.\n",
+    "\n",
+    "We can also gather the inputs needed by `TensorDictSequence` with the `in_keys` property, and the output keys are found at the `out_keys` attribute, as sketched after the code below."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "7e36d071-df67-4232-a8a9-78e79b32fef2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n",
+    "\n",
+    "splitlinear = TensorDictModule(\n",
+    "    MultiHeadLinear(3, 4, 10),\n",
+    "    in_keys=[\"a\"],\n",
+    "    out_keys=[\"output_1\", \"output_2\"],\n",
+    ")\n",
+    "mergelinear = TensorDictModule(\n",
+    "    MergeLinear(4, 10, 13),\n",
+    "    in_keys=[\"output_1\", \"output_2\"],\n",
+    "    out_keys=[\"output\"],\n",
+    ")\n",
+    "\n",
+    "split_and_merge_linear = TensorDictSequence(splitlinear, mergelinear)\n",
+    "\n",
+    "assert split_and_merge_linear(tensordict)[\"output\"].shape == torch.Size([5, 13])"
+   ]
+  },
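+  {
+   "cell_type": "markdown",
+   "id": "a1b2c3d4",
+   "metadata": {},
+   "source": [
+    "As a quick sketch (not executed here; the exact key ordering may differ), the gathered keys can be inspected directly:\n",
+    "\n",
+    "```python\n",
+    "print(split_and_merge_linear.in_keys)   # e.g. [\"a\"]: keys produced inside the sequence are not inputs\n",
+    "print(split_and_merge_linear.out_keys)  # e.g. [\"output_1\", \"output_2\", \"output\"]\n",
+    "```"
+   ]
+  },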
+  {
+   "cell_type": "markdown",
+   "id": "760118ea",
+   "metadata": {},
+   "source": [
+    "### Example 5: Compatibility with functorch"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e2718a12",
+   "metadata": {},
+   "source": [
+    "`TensorDictModule` comes with its own `make_functional_with_buffers` method to make it functional (you should not use `functorch.make_functional_with_buffers(tensordictmodule)` directly; in general, it will not work)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "b553bed1",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "TensorDict(\n",
+       "    fields={\n",
+       "        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n",
+       "        output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n",
+       "        output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n",
+       "    batch_size=torch.Size([5]),\n",
+       "    device=cpu,\n",
+       "    is_shared=False)"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n",
+    "\n",
+    "splitlinear = TensorDictModule(\n",
+    "    MultiHeadLinear(3, 4, 10),\n",
+    "    in_keys=[\"a\"],\n",
+    "    out_keys=[\"output_1\", \"output_2\"],\n",
+    ")\n",
+    "func, (params, buffers) = splitlinear.make_functional_with_buffers()\n",
+    "func(tensordict, params=params, buffers=buffers)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "50ac0393",
+   "metadata": {},
+   "source": [
+    "We can also use the `vmap` operator; here is an example of model ensembling with it:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "86ccb7be",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "the output tensordict shape is: torch.Size([10, 5])\n"
+     ]
+    }
+   ],
+   "source": [
+    "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n",
+    "num_models = 10\n",
+    "model = TensorDictModule(\n",
+    "    nn.Linear(3, 4), in_keys=[\"a\"], out_keys=[\"output\"]\n",
+    ")\n",
+    "fmodel, (params, buffers) = model.make_functional_with_buffers()\n",
+    "params = [torch.randn(num_models, *p.shape, device=p.device) for p in params]\n",
+    "buffers = [torch.randn(num_models, *b.shape, device=b.device) for b in buffers]\n",
+    "result_td = fmodel(tensordict, params=params, buffers=buffers, vmap=True)\n",
+    "print(\"the output tensordict shape is: \", result_td.shape)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "31be6c45-10fb-4fd1-a52f-92214b76c00a",
+   "metadata": {},
+   "source": [
+    "## Do's and don'ts with `TensorDictModule`\n",
+    "\n",
+    "Don't wrap `TensorDictModule` components in plain `nn.Module` wrappers: this would break some `TensorDictModule` features, such as `functorch` compatibility.\n",
+    "\n",
+    "Don't use `nn.Sequential`: like other `nn.Module` wrappers, it would break features such as `functorch` compatibility. Do use `TensorDictSequence` instead.\n",
+    "\n",
+    "Don't assign the output tensordict to a new variable, as the output tensordict is just the input modified in-place:\n",
+    "\n",
+    "```python\n",
+    "tensordict = module(tensordict)  # ok!\n",
+    "tensordict_out = module(tensordict)  # don't!\n",
+    "```\n",
+    "\n",
+    "Don't use `make_functional_with_buffers` from `functorch` directly; use `TensorDictModule.make_functional_with_buffers` instead."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "22e65356-d8b3-4197-84b8-598330c1ddc8",
+   "metadata": {},
+   "source": [
+    "## TensorDictModule for RL"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8d49a911-933c-476f-8c9a-00e006ed043c",
+   "metadata": {},
+   "source": [
+    "TorchRL provides a few RL-specific `TensorDictModule` instances that serve domain-specific needs."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e33904a6-d405-45db-a713-47493ca8ee33",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "### `ProbabilisticTensorDictModule`"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fea4eead-47b4-4029-a8ff-e3c3faf51b0f",
+   "metadata": {},
+   "source": [
+    "`ProbabilisticTensorDictModule` is a special case of a `TensorDictModule` where the output is\n",
+    "sampled given some rule, specified by the input `default_interaction_mode`\n",
+    "argument and the `exploration_mode()` global function. If the two conflict, the context manager takes precedence.\n",
+    "\n",
+    "It consists of a wrapper around another `TensorDictModule` that returns a tensordict\n",
+    "updated with the distribution parameters. `ProbabilisticTensorDictModule` is\n",
+    "responsible for constructing the distribution (through the `get_dist()` method)\n",
+    "and/or sampling from this distribution (through a regular `__call__()` to the\n",
+    "module).\n",
+    "\n",
+    "One can find the parameters in the output tensordict, as well as the log probability if needed."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "9dd7846a-f12c-492e-a2ef-b0c67969234d",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "TensorDict before going through module: TensorDict(\n",
+      "    fields={\n",
+      "        hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),\n",
+      "        input: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n",
+      "    batch_size=torch.Size([3]),\n",
+      "    device=cpu,\n",
+      "    is_shared=False)\n",
+      "TensorDict after going through module now has keys action, loc and scale: TensorDict(\n",
+      "    fields={\n",
+      "        action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+      "        hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),\n",
+      "        input: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+      "        loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n",
+      "        sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n",
+      "        scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n",
+      "    batch_size=torch.Size([3]),\n",
+      "    device=cpu,\n",
+      "    is_shared=False)\n"
+     ]
+    }
+   ],
+   "source": [
+    "from torchrl.modules import ProbabilisticTensorDictModule\n",
+    "from torchrl.modules import TanhNormal, NormalParamWrapper\n",
+    "\n",
+    "td = TensorDict({\"input\": torch.randn(3, 4), \"hidden\": torch.randn(3, 8)}, [3])\n",
+    "net = NormalParamWrapper(torch.nn.GRUCell(4, 8))\n",
+    "module = TensorDictModule(net, in_keys=[\"input\", \"hidden\"], out_keys=[\"loc\", \"scale\"])\n",
+    "td_module = ProbabilisticTensorDictModule(\n",
+    "    module=module,\n",
+    "    dist_param_keys=[\"loc\", \"scale\"],\n",
+    "    out_key_sample=[\"action\"],\n",
+    "    distribution_class=TanhNormal,\n",
+    "    return_log_prob=True,\n",
+    ")\n",
+    "print(f\"TensorDict before going through module: {td}\")\n",
+    "td_module(td)\n",
+    "print(f\"TensorDict after going through module now has keys action, loc and scale: {td}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "406b1caa-bcec-4317-b685-10df23352154",
+   "metadata": {},
+   "source": [
+    "### `Actor`"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e139de7d-0250-49c0-b495-8b5a404821f5",
+   "metadata": {},
+   "source": [
+    "`Actor` inherits from `TensorDictModule` and comes with a default value of `[\"action\"]` for `out_keys`.\n"
+   ]
+  },
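+  {
+   "cell_type": "markdown",
+   "id": "e5f6a7b8",
+   "metadata": {},
+   "source": [
+    "A minimal sketch (not executed here), assuming `Actor` is importable from `torchrl.modules` and shares the `TensorDictModule` constructor signature:\n",
+    "\n",
+    "```python\n",
+    "from torchrl.modules import Actor\n",
+    "\n",
+    "td = TensorDict({\"observation\": torch.randn(3, 4)}, batch_size=[3])\n",
+    "actor = Actor(nn.Linear(4, 4), in_keys=[\"observation\"])  # out_keys defaults to [\"action\"]\n",
+    "actor(td)\n",
+    "assert \"action\" in td.keys()\n",
+    "```"
+   ]
+  },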
"4fd0f53e-90aa-49a9-9d8f-5a260255e556", + "metadata": {}, + "source": [ + "General class for probabilistic actors in RL that inherits from `ProbabilisticTensorDictModule`.\n", + "Similarly to `Actor`, it comes with default values for the `out_keys` (`[\"action\"]`).\n" + ] + }, + { + "cell_type": "markdown", + "id": "dbd48bb2-b93b-4766-b7a7-19d500f17e2d", + "metadata": {}, + "source": [ + "### `ActorCriticOperator`" + ] + }, + { + "cell_type": "markdown", + "id": "8cc42407-4e95-4bf0-8901-5d1a4e3b2044", + "metadata": {}, + "source": [ + "Similarly, `ActorCriticOperator` inherits from `TensorDictSequence`and wraps both an actor network and a value Network.\n", + "\n", + "`ActorCriticOperator` will first compute the action from the actor and then the value according to this action." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "5b6c6035-f9cc-41e7-bf3a-f88936f93b70", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TensorDict(\n", + " fields={\n", + " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n", + " batch_size=torch.Size([3]),\n", + " device=cpu,\n", + " is_shared=False)\n", + "TensorDict(\n", + " fields={\n", + " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n", + " scale: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)},\n", + " batch_size=torch.Size([3]),\n", + " device=cpu,\n", + " is_shared=False)\n", + "Policy: TensorDict(\n", + " fields={\n", + " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n", + " scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n", + " batch_size=torch.Size([3]),\n", + " device=cpu,\n", + " is_shared=False)\n", + "Critic: TensorDict(\n", + " fields={\n", + " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n", + " scale: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)},\n", + " batch_size=torch.Size([3]),\n", + " device=cpu,\n", + " is_shared=False)\n" + ] + } + ], + "source": [ + "from torchrl.modules import (\n", + " MLP,\n", + " ActorCriticOperator,\n", + " NormalParamWrapper,\n", + " TanhNormal,\n", + " ValueOperator,\n", + ")\n", + "from torchrl.modules.tensordict_module import ProbabilisticActor\n", + "\n", + "module_hidden = torch.nn.Linear(4, 4)\n", + "td_module_hidden = TensorDictModule(\n", + " module=module_hidden,\n", + " in_keys=[\"observation\"],\n", + " out_keys=[\"hidden\"],\n", + ")\n", + "module_action = NormalParamWrapper(torch.nn.Linear(4, 8))\n", + "module_action = TensorDictModule(\n", + " module_action, in_keys=[\"hidden\"], out_keys=[\"loc\", \"scale\"]\n", + 
")\n", + "td_module_action = ProbabilisticActor(\n", + " module=module_action,\n", + " dist_param_keys=[\"loc\", \"scale\"],\n", + " out_key_sample=[\"action\"],\n", + " distribution_class=TanhNormal,\n", + " return_log_prob=True,\n", + ")\n", + "module_value = MLP(in_features=8, out_features=1, num_cells=[])\n", + "td_module_value = ValueOperator(\n", + " module=module_value,\n", + " in_keys=[\"hidden\", \"action\"],\n", + " out_keys=[\"state_action_value\"],\n", + ")\n", + "td_module = ActorCriticOperator(\n", + " td_module_hidden, td_module_action, td_module_value\n", + ")\n", + "td = TensorDict(\n", + " {\"observation\": torch.randn(3, 4)},\n", + " [\n", + " 3,\n", + " ],\n", + ")\n", + "print(td)\n", + "td_clone = td_module(td.clone())\n", + "print(td_clone)\n", + "td_clone = td_module.get_policy_operator()(td.clone())\n", + "print(f\"Policy: {td_clone}\") # no value\n", + "td_clone = td_module.get_critic_operator()(td.clone())\n", + "print(f\"Critic: {td_clone}\") # no action" + ] + }, + { + "cell_type": "markdown", + "id": "11d0f8ea-0292-4ca0-9460-2a2149f7aeef", + "metadata": {}, + "source": [ + "Other blocks exist such as:\n", + "\n", + "The `ValueOperator` which is a general class for value functions in RL.\n", + "\n", + "the `ActorCriticWrapper` which wraps together an actor and a value model that do not share a common observation embedding network.\n", + "\n", + "The `ActorValueOperator` which wraps together an actor and a value model that share a common observation embedding network." + ] + }, + { + "cell_type": "markdown", + "id": "6304a098", + "metadata": { + "tags": [] + }, + "source": [ + "## Showcase: Implementing a transformer using TensorDictModule\n", + "To demonstrate the flexibility of `TensorDictModule`, we are going to create a transformer that reads `TensorDict` objects using `TensorDictModule`.\n", + "\n", + "The following figure shows the classical transformer architecture (Vaswani et al, 2017) \n", + "\n", + "\n", + "\n", + "We have let the positional encoders aside for simplicity.\n", + "\n", + "Let's first import the classical transformers blocks (see `src/transformer.py`for more details.)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "e1f7ba7b", + "metadata": {}, + "outputs": [], + "source": [ + "from tutorials.src.transformer import (\n", + " FFN,\n", + " Attention,\n", + " SkipLayerNorm,\n", + " SplitHeads,\n", + " TokensToQKV,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c3258540-acb2-4090-a374-822dfcb857bd", + "metadata": {}, + "source": [ + "We first create the `AttentionBlockTensorDict`, the attention block using `TensorDictModule` and `TensorDictSequence`.\n", + "\n", + "The wiring operation that connects the modules to each other requires us to indicate which key each of them must read and write. Unlike `nn.Sequence`, a `TensorDictSequence` can read/write more than one input/output. Moreover, its components inputs need not be identical to the previous layers outputs, allowing us to code complicated neural architecture." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "eb9775bd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class AttentionBlockTensorDict(TensorDictSequence):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        to_name,\n",
+    "        from_name,\n",
+    "        to_dim,\n",
+    "        to_len,\n",
+    "        from_dim,\n",
+    "        latent_dim,\n",
+    "        num_heads,\n",
+    "    ):\n",
+    "        super().__init__(\n",
+    "            TensorDictModule(\n",
+    "                TokensToQKV(to_dim, from_dim, latent_dim),\n",
+    "                in_keys=[to_name, from_name],\n",
+    "                out_keys=[\"Q\", \"K\", \"V\"],\n",
+    "            ),\n",
+    "            TensorDictModule(\n",
+    "                SplitHeads(num_heads),\n",
+    "                in_keys=[\"Q\", \"K\", \"V\"],\n",
+    "                out_keys=[\"Q\", \"K\", \"V\"],\n",
+    "            ),\n",
+    "            TensorDictModule(\n",
+    "                Attention(latent_dim, to_dim),\n",
+    "                in_keys=[\"Q\", \"K\", \"V\"],\n",
+    "                out_keys=[\"X_out\", \"Attn\"],\n",
+    "            ),\n",
+    "            TensorDictModule(\n",
+    "                SkipLayerNorm(to_len, to_dim),\n",
+    "                in_keys=[to_name, \"X_out\"],\n",
+    "                out_keys=[to_name],\n",
+    "            ),\n",
+    "        )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b5f6f291",
+   "metadata": {},
+   "source": [
+    "We build the encoder and decoder blocks that will be part of the transformer, thanks to `TensorDictModule`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "f902006d-3f89-4ea6-84e0-a193a53e42db",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class TransformerBlockEncoderTensorDict(TensorDictSequence):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        to_name,\n",
+    "        from_name,\n",
+    "        to_dim,\n",
+    "        to_len,\n",
+    "        from_dim,\n",
+    "        latent_dim,\n",
+    "        num_heads,\n",
+    "    ):\n",
+    "        super().__init__(\n",
+    "            AttentionBlockTensorDict(\n",
+    "                to_name,\n",
+    "                from_name,\n",
+    "                to_dim,\n",
+    "                to_len,\n",
+    "                from_dim,\n",
+    "                latent_dim,\n",
+    "                num_heads,\n",
+    "            ),\n",
+    "            TensorDictModule(\n",
+    "                FFN(to_dim, 4 * to_dim),\n",
+    "                in_keys=[to_name],\n",
+    "                out_keys=[\"X_out\"],\n",
+    "            ),\n",
+    "            TensorDictModule(\n",
+    "                SkipLayerNorm(to_len, to_dim),\n",
+    "                in_keys=[to_name, \"X_out\"],\n",
+    "                out_keys=[to_name],\n",
+    "            ),\n",
+    "        )\n",
+    "\n",
+    "\n",
+    "class TransformerBlockDecoderTensorDict(TensorDictSequence):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        to_name,\n",
+    "        from_name,\n",
+    "        to_dim,\n",
+    "        to_len,\n",
+    "        from_dim,\n",
+    "        latent_dim,\n",
+    "        num_heads,\n",
+    "    ):\n",
+    "        super().__init__(\n",
+    "            AttentionBlockTensorDict(\n",
+    "                to_name,\n",
+    "                to_name,\n",
+    "                to_dim,\n",
+    "                to_len,\n",
+    "                to_dim,\n",
+    "                latent_dim,\n",
+    "                num_heads,\n",
+    "            ),\n",
+    "            TransformerBlockEncoderTensorDict(\n",
+    "                to_name,\n",
+    "                from_name,\n",
+    "                to_dim,\n",
+    "                to_len,\n",
+    "                from_dim,\n",
+    "                latent_dim,\n",
+    "                num_heads,\n",
+    "            ),\n",
+    "        )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "42dbfae5",
+   "metadata": {},
+   "source": [
+    "We create the transformer encoder and decoder.\n",
+    "\n",
+    "For an encoder, we just need to take the same tokens for both queries, keys and values.\n",
+    "\n",
+    "For a decoder, we can now extract info from `X_from` into `X_to`: `X_to` will map to the queries, whereas `X_from` will map to the keys and values."
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "1c6c85b5", + "metadata": {}, + "outputs": [], + "source": [ + "class TransformerEncoderTensorDict(TensorDictSequence):\n", + " def __init__(\n", + " self,\n", + " num_blocks,\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", + " ):\n", + " super().__init__(\n", + " *[\n", + " TransformerBlockEncoderTensorDict(\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", + " )\n", + " for _ in range(num_blocks)\n", + " ]\n", + " )\n", + "\n", + "\n", + "class TransformerDecoderTensorDict(TensorDictSequence):\n", + " def __init__(\n", + " self,\n", + " num_blocks,\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", + " ):\n", + " super().__init__(\n", + " *[\n", + " TransformerBlockDecoderTensorDict(\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", + " )\n", + " for _ in range(num_blocks)\n", + " ]\n", + " )\n", + "\n", + "\n", + "class TransformerTensorDict(TensorDictSequence):\n", + " def __init__(\n", + " self,\n", + " num_blocks,\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " from_len,\n", + " latent_dim,\n", + " num_heads,\n", + " ):\n", + " super().__init__(\n", + " TransformerEncoderTensorDict(\n", + " num_blocks,\n", + " to_name,\n", + " to_name,\n", + " to_dim,\n", + " to_len,\n", + " to_dim,\n", + " latent_dim,\n", + " num_heads,\n", + " ),\n", + " TransformerDecoderTensorDict(\n", + " num_blocks,\n", + " from_name,\n", + " to_name,\n", + " from_dim,\n", + " from_len,\n", + " to_dim,\n", + " latent_dim,\n", + " num_heads,\n", + " ),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "15b1b4e2-918d-40bc-a245-15be0e9cc276", + "metadata": {}, + "source": [ + "We now test our new `TransformerTensorDict`" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "7a680452-1462-4ee6-ba04-dce0bb855870", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorDict(\n", + " fields={\n", + " Attn: Tensor(torch.Size([8, 2, 10, 3]), dtype=torch.float32),\n", + " K: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),\n", + " Q: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),\n", + " V: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),\n", + " X_decode: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),\n", + " X_encode: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),\n", + " X_out: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32)},\n", + " batch_size=torch.Size([8]),\n", + " device=cpu,\n", + " is_shared=False)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "to_dim = 5\n", + "from_dim = 6\n", + "latent_dim = 10\n", + "to_len = 3\n", + "from_len = 10\n", + "batch_size = 8\n", + "num_heads = 2\n", + "num_blocks = 6\n", + "\n", + "tokens = TensorDict(\n", + " {\n", + " \"X_encode\": torch.randn(batch_size, to_len, to_dim),\n", + " \"X_decode\": torch.randn(batch_size, from_len, from_dim),\n", + " },\n", + " batch_size=[batch_size],\n", + ")\n", + "\n", + "transformer = TransformerTensorDict(\n", + " num_blocks,\n", + " \"X_encode\",\n", + " \"X_decode\",\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " from_len,\n", + " latent_dim,\n", + " 
num_heads,\n",
+    ")\n",
+    "\n",
+    "transformer(tokens)\n",
+    "tokens"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3f6448dd-5d0d-43fd-9e57-a0ac3b30ecba",
+   "metadata": {},
+   "source": [
+    "We've managed to create a transformer with `TensorDictModule`. This shows that `TensorDictModule` is a flexible module that can implement complex operations."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bb30fb1b-ef8f-4638-af44-69374dd9cfe9",
+   "metadata": {},
+   "source": [
+    "### Benchmarking"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "f75eb50b-b5c4-47ef-9e33-4fa6dfb489ba",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from tutorials.src.transformer import Transformer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "c4ff0abf-1f01-45bd-9dfc-cd26374137c7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "to_dim = 5\n",
+    "from_dim = 6\n",
+    "latent_dim = 10\n",
+    "to_len = 3\n",
+    "from_len = 10\n",
+    "batch_size = 8\n",
+    "num_heads = 2\n",
+    "num_blocks = 6"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "3e08ff04-1086-4315-bf5e-caa960183c94",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "td_tokens = TensorDict(\n",
+    "    {\n",
+    "        \"X_encode\": torch.randn(batch_size, to_len, to_dim),\n",
+    "        \"X_decode\": torch.randn(batch_size, from_len, from_dim),\n",
+    "    },\n",
+    "    batch_size=[batch_size],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "665c4168-9ac8-45e5-98bc-6e5cc511a209",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "X_encode = torch.randn(batch_size, to_len, to_dim)\n",
+    "X_decode = torch.randn(batch_size, from_len, from_dim)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "f3c2fd50-bc9b-4911-bd7c-8f8f03bd4ea4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tdtransformer = TransformerTensorDict(\n",
+    "    num_blocks,\n",
+    "    \"X_encode\",\n",
+    "    \"X_decode\",\n",
+    "    to_dim,\n",
+    "    to_len,\n",
+    "    from_dim,\n",
+    "    from_len,\n",
+    "    latent_dim,\n",
+    "    num_heads,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "dfbadd6b-7847-4399-9b22-7e5c58524334",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "transformer = Transformer(\n",
+    "    num_blocks,\n",
+    "    to_dim,\n",
+    "    to_len,\n",
+    "    from_dim,\n",
+    "    from_len,\n",
+    "    latent_dim,\n",
+    "    num_heads,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6a63de8f-ee8e-4ddf-bf89-f72c2896e1c3",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "#### Inference time"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "02a4116b-2b75-47fc-8bc1-3903aa7cd504",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 6.25 ms, sys: 6.73 ms, total: 13 ms\n",
+      "Wall time: 7.57 ms\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "tokens = tdtransformer(td_tokens)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "40158aab-b53a-4a99-82cb-f5595eef7159",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 5.32 ms, sys: 9.24 ms, total: 14.6 ms\n",
+      "Wall time: 7.15 ms\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "X_out = transformer(X_encode, X_decode)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "664adff3-1466-47c3-9a80-a0f26171addd",
+   "metadata": {},
+   "source": [
+    "We can see in this minimal example that the overhead introduced by `TensorDictModule` is marginal."
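+    ,
+    "\n",
+    "\n",
+    "Single `%%time` runs are noisy; for steadier numbers, one could average over repeated calls (a sketch, not executed here):\n",
+    "\n",
+    "```python\n",
+    "%timeit -n 100 tdtransformer(td_tokens)\n",
+    "%timeit -n 100 transformer(X_encode, X_decode)\n",
+    "```"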
+ ] + }, + { + "cell_type": "markdown", + "id": "bd08362a-8bb8-49fb-8038-1a60c5c01ea2", + "metadata": {}, + "source": [ + "Have fun with TensorDictModule!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}