From 228b158f9bfba2b29afa5bc30abb5ad5cc3807de Mon Sep 17 00:00:00 2001
From: Nicolas Dufour
Date: Thu, 7 Jul 2022 18:03:30 +0100
Subject: [PATCH 01/21] Added TensorDict tutorial

---
 tutorials/tensor_dict.ipynb | 1949 +++++++++++++++++++++++++++++++++++
 1 file changed, 1949 insertions(+)
 create mode 100644 tutorials/tensor_dict.ipynb

diff --git a/tutorials/tensor_dict.ipynb b/tutorials/tensor_dict.ipynb
new file mode 100644
index 00000000000..c37215bd9fc
--- /dev/null
+++ b/tutorials/tensor_dict.ipynb
@@ -0,0 +1,1949 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "c61f510e",
+ "metadata": {},
+ "source": [
+ "# TensorDict tutorial"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "125c0ad4",
+ "metadata": {},
+ "source": [
+ "TensorDict is a new tensor structure introduced in torchrl. In RL, you need to deal with multiple tensors at once, such as actions, observations and rewards. TensorDict aims at making it more convenient to handle several tensors at the same time. Furthermore, different RL algorithms deal with different inputs and outputs; TensorDict allows us to abstract away the differences between these algorithms.\n",
+ "\n",
+ "### Example\n",
+ "As a concrete example, let us take DQN and PPO. The first uses a deterministic policy that applies an argmax operator to a collection of values associated with each action for a given observation. The second has a parametric policy that outputs a distribution over the space of the available actions. Here is the pseudocode:\n",
+ "\n",
+ "\n",
+ "# DQN\n",
+ "data = []\n",
+ "for i in range(max_steps):\n",
+ "    action, values = value_network(observation)  # action = values.argmax(-1)\n",
+ "    observation, reward, done, *other = env.step(action)\n",
+ "    data.append((action, values, observation, reward, done))\n",
+ "\n",
+ "\n",
+ "# PPO\n",
+ "data = []\n",
+ "for i in range(max_steps):\n",
+ "    action, action_log_prob = policy(observation)\n",
+ "    observation, reward, done, *other = env.step(action)\n",
+ "    data.append((action, action_log_prob, observation, reward, done))\n",
+ "\n",
+ "\n",
+ "\n",
+ "Ideally, we would like to abstract this away into the same code:\n",
+ "\n",
+ "\n",
+ "collections = []\n",
+ "for i in range(max_steps):\n",
+ "    collection_of_values = policy(collection_of_values)\n",
+ "    collection_of_values = env_step(collection_of_values)\n",
+ "    collections.append(collection_of_values)\n",
+ "\n",
+ "The differences between the algorithms now lie in the `policy`, the `env_step` and the initial `collection_of_values`, but the main loop is the same for both algorithms. This abstraction allows for more modular and reusable code.\n",
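+ "\n",
+ "To make this concrete, here is a minimal sketch of what such a unified loop could look like with a TensorDict. The `policy` and `env_step` callables below are hypothetical stand-ins (not torchrl APIs), written only to show the idea:\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "from torchrl.data import TensorDict\n",
+ "\n",
+ "def policy(td):\n",
+ "    # stand-in policy: derives an action from the observation\n",
+ "    td.update({\"action\": td[\"observation\"].argmax(-1, keepdim=True).float()})\n",
+ "    return td\n",
+ "\n",
+ "def env_step(td):\n",
+ "    # stand-in environment: writes the next observation and a reward\n",
+ "    td.update({\"observation\": torch.randn(1, 4), \"reward\": torch.randn(1, 1)})\n",
+ "    return td\n",
+ "\n",
+ "td = TensorDict({\"observation\": torch.randn(1, 4)}, batch_size=[1])\n",
+ "collections = []\n",
+ "for i in range(3):\n",
+ "    td = policy(td)\n",
+ "    td = env_step(td)\n",
+ "    collections.append(td.clone())\n",
+ "```"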
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ac324978",
+ "metadata": {},
+ "source": [
+ "## TensorDict's Python dictionary behaviour"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "61f4b705",
+ "metadata": {},
+ "source": [
+ "TensorDict shares a lot of features with Python dictionaries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "791a3488",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torchrl.data import TensorDict\n",
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "f46c8c8f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "TensorDict(\n",
+ " fields={\n",
+ " a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),\n",
+ " b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([3, 4]),\n",
+ " device=cpu,\n",
+ " is_shared=False)\n"
+ ]
+ }
+ ],
+ "source": [
+ "tensordict = TensorDict({\"a\": torch.zeros(3, 4, 5), \"b\": torch.zeros(3, 4)}, batch_size=[3, 4])\n",
+ "print(tensordict)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "73cc767a",
+ "metadata": {},
+ "source": [
+ "If we want to access a certain key, it is explicit:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "16360ee0",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([[[0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.]],\n",
+ "\n",
+ " [[0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.]],\n",
+ "\n",
+ " [[0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.]]])\n",
+ "torch.Size([3, 4, 5])\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(tensordict[\"a\"])\n",
+ "print(tensordict[\"a\"].shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "423e1ee6",
+ "metadata": {},
+ "source": [
+ "#### TensorDict.keys()\n",
+ "We can access the dict keys"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "6212e301",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "a\n",
+ "b\n"
+ ]
+ }
+ ],
+ "source": [
+ "for key in tensordict.keys():\n",
+ "    print(key)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "43c4fe4c",
+ "metadata": {},
+ "source": [
+ "#### TensorDict.values()\n",
+ "We can also retrieve the values of the dict. Unlike Python dicts, we return a generator and not a list, for memory efficiency reasons: Python dictionaries were not designed with tensor storage in mind.\n",
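+ "\n",
+ "A quick sketch of what this implies in practice (plain Python, no extra torchrl API assumed): if a list is really needed, the generator can simply be materialized.\n",
+ "\n",
+ "```python\n",
+ "# values() yields the stored tensors lazily; wrap it to get a list\n",
+ "all_values = list(tensordict.values())\n",
+ "total_elements = sum(v.numel() for v in all_values)\n",
+ "```"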
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "f4538e4d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tensordict.values()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "3030acda",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([3, 4, 5])\n",
+ "torch.Size([3, 4, 1])\n"
+ ]
+ }
+ ],
+ "source": [
+ "for value in tensordict.values():\n",
+ "    print(value.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bcc512bc",
+ "metadata": {},
+ "source": [
+ "#### TensorDict.update()\n",
+ "We can also use the update function, as with dicts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "884c3fed",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "a is now tensor([[[1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.]],\n",
+ "\n",
+ " [[1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.]],\n",
+ "\n",
+ " [[1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.]]])\n",
+ "c is set as tensor([[[1., 1.],\n",
+ " [1., 1.],\n",
+ " [1., 1.],\n",
+ " [1., 1.]],\n",
+ "\n",
+ " [[1., 1.],\n",
+ " [1., 1.],\n",
+ " [1., 1.],\n",
+ " [1., 1.]],\n",
+ "\n",
+ " [[1., 1.],\n",
+ " [1., 1.],\n",
+ " [1., 1.],\n",
+ " [1., 1.]]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "tensordict.update({\"a\": torch.ones((3, 4, 5))})\n",
+ "tensordict.update({\"c\": torch.ones((3, 4, 2))})\n",
+ "print(f\"a is now {tensordict['a']}\")\n",
+ "print(f\"c is set as {tensordict['c']}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dbadaafb",
+ "metadata": {},
+ "source": [
+ "#### TensorDict del key\n",
+ "TensorDict also supports key deletion"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "d7b5920c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "dict_keys(['a', 'b'])\n"
+ ]
+ }
+ ],
+ "source": [
+ "del tensordict[\"c\"]\n",
+ "print(tensordict.keys())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "072b6665",
+ "metadata": {},
+ "source": [
+ "## TensorDict as a PyTorch tensor"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5f292ab8",
+ "metadata": {},
+ "source": [
+ "But wait, can't we do this with a regular dict?\n",
+ "Well, we would like TensorDict to keep some nice PyTorch properties: TensorDict combines the advantages of a Python dictionary and of a PyTorch tensor.\n",
+ "TensorDict has a batch size. It is not inferred automatically by looking at the tensors, but must be set when creating the TensorDict.\n",
+ "\n",
+ "TensorDict is a tensor container where all tensors are stored in a key-value fashion and where each element shares at least the following features (illustrated in the sketch below):\n",
+ "- device;\n",
+ "- memory location (shared, memory-mapped array, ...);\n",
+ "- batch size (i.e. the first n dimensions).\n",
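+ "\n",
+ "A minimal sketch of these shared properties, assuming the tensordict created above (the CUDA transfer is guarded so it only runs when a GPU is available):\n",
+ "\n",
+ "```python\n",
+ "print(tensordict.batch_size)  # torch.Size([3, 4]), shared by every entry\n",
+ "print(tensordict.device)      # cpu: all stored tensors live on this device\n",
+ "\n",
+ "# the whole container can be moved across devices in one call\n",
+ "if torch.cuda.is_available():\n",
+ "    tensordict_cuda = tensordict.to(torch.device(\"cuda\"))\n",
+ "```"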
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "3a3da095",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torchrl.data.tensordict.tensordict import TensorDict\n",
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "14fd8b61",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "TensorDict(\n",
+ " fields={\n",
+ " a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),\n",
+ " b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([3, 4]),\n",
+ " device=cpu,\n",
+ " is_shared=False)\n",
+ "Our TensorDict is of size torch.Size([3, 4])\n"
+ ]
+ }
+ ],
+ "source": [
+ "tensordict = TensorDict({\"a\": torch.zeros(3, 4, 5), \"b\": torch.zeros(3, 4)}, batch_size=[3, 4])\n",
+ "print(tensordict)\n",
+ "print(f\"Our TensorDict is of size {tensordict.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "74c47805",
+ "metadata": {},
+ "source": [
+ "#### Batch size"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "58c926a9",
+ "metadata": {},
+ "source": [
+ "A TensorDict has a batch size, which is shared across all its tensors"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "bc68f307",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Our TensorDict is of size torch.Size([3, 4])\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"Our TensorDict is of size {tensordict.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "69fc180b",
+ "metadata": {},
+ "source": [
+ "You cannot store items that do not share the batch size inside the TensorDict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "41737242",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "RuntimeError",
+ "evalue": "batch dimension mismatch, got self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "Input \u001b[0;32mIn [12]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtensordict\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mc\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mzeros\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:356\u001b[0m, in \u001b[0;36m_TensorDict.update\u001b[0;34m(self, input_dict_or_td, clone, inplace, **kwargs)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m clone:\n\u001b[1;32m 355\u001b[0m value \u001b[38;5;241m=\u001b[39m value\u001b[38;5;241m.\u001b[39mclone()\n\u001b[0;32m--> 356\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43minplace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minplace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 357\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:1646\u001b[0m, in \u001b[0;36mTensorDict.set\u001b[0;34m(self, key, value, inplace, _run_checks, _meta_val)\u001b[0m\n\u001b[1;32m 1644\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict \u001b[38;5;129;01mand\u001b[39;00m inplace:\n\u001b[1;32m 1645\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mset_(key, value)\n\u001b[0;32m-> 1646\u001b[0m proc_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_process_tensor\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1647\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1648\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_tensor_shape\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_run_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1649\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_shared\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1650\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_device\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_run_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1651\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# check_tensor_shape=_run_checks\u001b[39;00m\n\u001b[1;32m 1652\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict[key] \u001b[38;5;241m=\u001b[39m proc_value\n\u001b[1;32m 1653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict_meta[key] \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1654\u001b[0m MetaTensor(\n\u001b[1;32m 1655\u001b[0m proc_value,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1660\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m _meta_val\n\u001b[1;32m 1661\u001b[0m )\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:480\u001b[0m, in \u001b[0;36m_TensorDict._process_tensor\u001b[0;34m(self, input, check_device, check_tensor_shape, check_shared)\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcheck_shared is not authorized anymore\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_tensor_shape \u001b[38;5;129;01mand\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mshape[: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_dims] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_size:\n\u001b[0;32m--> 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch dimension mismatch, got self.batch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m 
\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_size\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m and tensor.shape[:self.batch_dims]\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 483\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtensor\u001b[38;5;241m.\u001b[39mshape[: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_dims]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 484\u001b[0m )\n\u001b[1;32m 486\u001b[0m \u001b[38;5;66;03m# minimum ndimension is 1\u001b[39;00m\n\u001b[1;32m 487\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mndimension() \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mndimension() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: batch dimension mismatch, got self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])"
+ ]
+ }
+ ],
+ "source": [
+ "tensordict.update({\"c\": torch.zeros(4, 3, 1)})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "92a02a32",
+ "metadata": {},
+ "source": [
+ "### Tensor operations\n",
+ "We can perform tensor operations along the batch dimensions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8a624475",
+ "metadata": {},
+ "source": [
+ "#### Cloning\n",
+ "TensorDict supports cloning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "f579c627",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([[[0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.]],\n",
+ "\n",
+ " [[0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.]],\n",
+ "\n",
+ " [[0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0.]]])\n",
+ "tensor([[[1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.]],\n",
+ "\n",
+ " [[1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.]],\n",
+ "\n",
+ " [[1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.],\n",
+ " [1., 1., 1., 1., 1.]]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "tensordict_clone = tensordict.clone()\n",
+ "tensordict_clone[\"a\"] = torch.ones(*tensordict.shape, 5)\n",
+ "print(tensordict[\"a\"])\n",
+ "print(tensordict_clone[\"a\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "71075821",
+ "metadata": {},
+ "source": [
+ "#### Slicing and indexing\n",
+ "Slicing and indexing are supported along the batch dimensions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "710d8b3b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
"text/plain": [ + "TensorDict(\n", + " fields={\n", + " a: Tensor(torch.Size([2, 4, 5]), dtype=torch.float32),\n", + " b: Tensor(torch.Size([2, 4, 1]), dtype=torch.float32)},\n", + " batch_size=torch.Size([2, 4]),\n", + " device=cpu,\n", + " is_shared=False)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tensordict[1:]" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "d8230353", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorDict(\n", + " fields={\n", + " a: Tensor(torch.Size([3, 2, 5]), dtype=torch.float32),\n", + " b: Tensor(torch.Size([3, 2, 1]), dtype=torch.float32)},\n", + " batch_size=torch.Size([3, 2]),\n", + " device=cpu,\n", + " is_shared=False)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tensordict[:,2:]" + ] + }, + { + "cell_type": "markdown", + "id": "9bb47f65", + "metadata": {}, + "source": [ + "TensorDict support other tensor operations such as torch.cat, reshape, undind(dim), view(\\*shape), squeeze(dim), unsqueeze(dim), permute(\\*dims) requiring the operations to comply with the batch_size" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "fb3b3002", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorDict(\n", + " fields={\n", + " a: Tensor(torch.Size([12, 5]), dtype=torch.float32),\n", + " b: Tensor(torch.Size([12, 1]), dtype=torch.float32)},\n", + " batch_size=torch.Size([12]),\n", + " device=cpu,\n", + " is_shared=False)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tensordict.reshape(-1)" + ] + }, + { + "cell_type": "markdown", + "id": "d6d74f22", + "metadata": {}, + "source": [ + "#### Casting to device\n", + "TensorDict supports casting to devices with the .to(device) function as with regular tensors" + ] + }, + { + "cell_type": "markdown", + "id": "0053ef9f", + "metadata": {}, + "source": [ + "## How to use them in practice? The tensor the TensorDictModule" + ] + }, + { + "cell_type": "markdown", + "id": "f94af8a2", + "metadata": {}, + "source": [ + "Now that we have seen the TensorDict object, how do we use it in pratice? We introduce the TensorDictModule. The TensorDictModule is an nn.Module that takes a TensorDict in his forward method. The user defines the keys that the module will take as an input and write the output in the same TensorDict at a given set of key." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "4439dc82", + "metadata": {}, + "outputs": [], + "source": [ + "from torchrl.modules import TensorDictModule\n", + "import torch.nn as nn" + ] + }, + { + "cell_type": "markdown", + "id": "88f76450", + "metadata": {}, + "source": [ + "### Example: Simple Linear layer" + ] + }, + { + "cell_type": "markdown", + "id": "2614d32f", + "metadata": {}, + "source": [ + "Let's imagine we have 2 entries Tensor dict, a and b and we only want to affect a." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "1acaf3d6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TensorDict(\n",
+ " fields={\n",
+ " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n",
+ " a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),\n",
+ " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([5]),\n",
+ " device=cpu,\n",
+ " is_shared=False)"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tensordict = TensorDict({\"a\": torch.randn(5, 3), \"b\": torch.randn(5, 4, 3)}, batch_size=[5])\n",
+ "linear = TensorDictModule(nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"a_out\"])\n",
+ "linear(tensordict)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "30a879d0",
+ "metadata": {},
+ "source": [
+ "We can also do it in place"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "9f22bdd9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TensorDict(\n",
+ " fields={\n",
+ " a: Tensor(torch.Size([5, 10]), dtype=torch.float32),\n",
+ " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([5]),\n",
+ " device=cpu,\n",
+ " is_shared=False)"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tensordict = TensorDict({\"a\": torch.randn(5, 3), \"b\": torch.randn(5, 4, 3)}, batch_size=[5])\n",
+ "linear = TensorDictModule(nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"a\"])\n",
+ "linear(tensordict)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "19f2866e",
+ "metadata": {},
+ "source": [
+ "### Example: merging two inputs with two linear layers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8166868f",
+ "metadata": {},
+ "source": [
+ "Now let's imagine a more complex network that takes two entries and averages them into a single output"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "1a5ba235",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class MergeLinear(nn.Module):\n",
+ "    def __init__(self, in_1, in_2, out):\n",
+ "        super().__init__()\n",
+ "        self.linear_1 = nn.Linear(in_1, out)\n",
+ "        self.linear_2 = nn.Linear(in_2, out)\n",
+ "\n",
+ "    def forward(self, x_1, x_2):\n",
+ "        return (self.linear_1(x_1) + self.linear_2(x_2)) / 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "021c97e1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TensorDict(\n",
+ " fields={\n",
+ " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n",
+ " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),\n",
+ " c: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n",
+ " output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n",
+ " batch_size=torch.Size([5]),\n",
+ " device=cpu,\n",
+ " is_shared=False)"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tensordict = TensorDict({\"a\": torch.randn(5, 3), \"b\": torch.randn(5, 4, 3), \"c\": torch.randn(5, 4)}, batch_size=[5])\n",
+ "mergelinear = TensorDictModule(MergeLinear(3, 4, 10), in_keys=[\"a\", \"c\"], out_keys=[\"output\"])\n",
+ "mergelinear(tensordict)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fd579195",
+ "metadata": {},
+ "source": [
+ "### Example: one input mapped to two outputs\n",
+ "We can also map a single input to multiple outputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "632426fe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
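+ "# a single input is projected by two independent linear heads, yielding two outputs\n",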
"class MultiHeadLinear(nn.Module):\n", + " def __init__(self, in_1, out_1, out_2):\n", + " super().__init__()\n", + " self.linear_1 = nn.Linear(in_1,out_1)\n", + " self.linear_2 = nn.Linear(in_1,out_2)\n", + " def forward(self, x):\n", + " return self.linear_1(x), self.linear_2(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "435ae253", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorDict(\n", + " fields={\n", + " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n", + " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),\n", + " output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n", + " output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n", + " batch_size=torch.Size([5]),\n", + " device=cpu,\n", + " is_shared=False)" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tensordict = TensorDict({\"a\":torch.randn(5,3), \"b\":torch.randn(5,4,3)}, batch_size=[5])\n", + "mergelinear = TensorDictModule(MultiHeadLinear(3, 4, 10),in_keys=[\"a\"], out_keys=[\"output_1\", \"output_2\"])\n", + "mergelinear(tensordict)" + ] + }, + { + "cell_type": "markdown", + "id": "d7594721", + "metadata": {}, + "source": [ + "As we shown previously, the TensorDictModule can take any nn.Module and perform the operations inside a TensorDict. When having multiple input keys and output keys, make sure they match the order in the module.\n", + "The tensordictmodule allows to use only the tensors that we want and keep the output inside the same object. It can even perform the operations inplace by setting the output key to be the same as an already set key." + ] + }, + { + "cell_type": "markdown", + "id": "786a2ed0", + "metadata": {}, + "source": [ + "### Example: A transformer with TensorDict?\n", + "Let's attempt to create a transformer with TensorDict and TensorDictModule\n", + "\n", + "Disclaimer: This implementation isn't to be \"better\" than a tensor based implementation. 
 It is just meant to showcase the TensorDictModule features.\n",
+ "\n",
+ "Let's first implement the classical transformer blocks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "ff5e05d1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class TokensToQKV(nn.Module):\n",
+ "    def __init__(self, to_dim, from_dim, latent_dim):\n",
+ "        super().__init__()\n",
+ "        self.q = nn.Linear(to_dim, latent_dim)\n",
+ "        self.k = nn.Linear(from_dim, latent_dim)\n",
+ "        self.v = nn.Linear(from_dim, latent_dim)\n",
+ "\n",
+ "    def forward(self, X_to, X_from):\n",
+ "        Q = self.q(X_to)\n",
+ "        K = self.k(X_from)\n",
+ "        V = self.v(X_from)\n",
+ "        return Q, K, V\n",
+ "\n",
+ "\n",
+ "class SplitHeads(nn.Module):\n",
+ "    def __init__(self, num_heads):\n",
+ "        super().__init__()\n",
+ "        self.num_heads = num_heads\n",
+ "\n",
+ "    def forward(self, Q, K, V):\n",
+ "        batch_size, to_num, latent_dim = Q.shape\n",
+ "        _, from_num, _ = K.shape\n",
+ "        d_tensor = latent_dim // self.num_heads\n",
+ "        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)\n",
+ "        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)\n",
+ "        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)\n",
+ "        return Q, K, V\n",
+ "\n",
+ "\n",
+ "class Attention(nn.Module):\n",
+ "    def __init__(self, latent_dim, to_dim):\n",
+ "        super().__init__()\n",
+ "        self.softmax = nn.Softmax(dim=-1)\n",
+ "        self.out = nn.Linear(latent_dim, to_dim)\n",
+ "\n",
+ "    def forward(self, Q, K, V):\n",
+ "        batch_size, n_heads, to_num, d_in = Q.shape\n",
+ "        # scaled dot-product attention: scale scores by sqrt of the per-head dimension\n",
+ "        attn = self.softmax(Q @ K.transpose(2, 3) / d_in ** 0.5)\n",
+ "        out = attn @ V\n",
+ "        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))\n",
+ "        return out, attn\n",
+ "\n",
+ "\n",
+ "class SkipLayerNorm(nn.Module):\n",
+ "    def __init__(self, to_len, to_dim):\n",
+ "        super().__init__()\n",
+ "        self.layer_norm = nn.LayerNorm((to_len, to_dim))\n",
+ "\n",
+ "    def forward(self, x_0, x_1):\n",
+ "        return self.layer_norm(x_0 + x_1)\n",
+ "\n",
+ "\n",
+ "class FFN(nn.Module):\n",
+ "    def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):\n",
+ "        super().__init__()\n",
+ "        self.FFN = nn.Sequential(\n",
+ "            nn.Linear(to_dim, hidden_dim),\n",
+ "            nn.ReLU(),\n",
+ "            nn.Linear(hidden_dim, to_dim),\n",
+ "            nn.Dropout(dropout_rate),\n",
+ "        )\n",
+ "\n",
+ "    def forward(self, X):\n",
+ "        return self.FFN(X)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "d4c49b49",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class TransformerBlockTensorDict(nn.Module):\n",
+ "    def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):\n",
+ "        super().__init__()\n",
+ "        self.transformer_block = nn.Sequential(\n",
+ "            TensorDictModule(TokensToQKV(to_dim, from_dim, latent_dim), in_keys=[to_name, from_name], out_keys=[\"Q\", \"K\", \"V\"]),\n",
+ "            TensorDictModule(SplitHeads(num_heads), in_keys=[\"Q\", \"K\", \"V\"], out_keys=[\"Q\", \"K\", \"V\"]),\n",
+ "            TensorDictModule(Attention(latent_dim, to_dim), in_keys=[\"Q\", \"K\", \"V\"], out_keys=[\"X_out\", \"Attn\"]),\n",
+ "            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[\"X_to\", \"X_out\"], out_keys=[\"X_to\"]),\n",
+ "            TensorDictModule(FFN(to_dim, 4 * to_dim), in_keys=[\"X_to\"], out_keys=[\"X_out\"]),\n",
+ "            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[\"X_to\", \"X_out\"], out_keys=[\"X_to\"]),\n",
+ "        )\n",
+ "\n",
+ "    def forward(self, X_tensor_dict):\n",
+ "        self.transformer_block(X_tensor_dict)\n",
+ "        return X_tensor_dict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "9fa8d7ff",
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorDict(\n", + " fields={\n", + " Attn: Tensor(torch.Size([8, 2, 3, 10]), dtype=torch.float32),\n", + " K: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),\n", + " Q: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),\n", + " V: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),\n", + " X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),\n", + " X_out: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),\n", + " X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},\n", + " batch_size=torch.Size([8]),\n", + " device=cpu,\n", + " is_shared=False)" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "to_dim = 5\n", + "from_dim = 6\n", + "latent_dim = 10\n", + "to_len = 3\n", + "from_len = 10\n", + "batch_size = 8\n", + "num_heads = 2\n", + "\n", + "tokens = TensorDict({\"X_to\":torch.randn(batch_size, to_len, to_dim), \"X_from\":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])\n", + "\n", + "transformer_block = TransformerBlockTensorDict(\"X_to\", \"X_from\", to_dim, to_len, from_dim, latent_dim, num_heads)\n", + "\n", + "transformer_block(tokens)\n", + "\n", + "tokens" + ] + }, + { + "cell_type": "markdown", + "id": "23a53298", + "metadata": {}, + "source": [ + "The output of the transformer layer can now be found at tokens[\"X_to\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "12c66c3f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-2.2140e-01, 9.3885e-02, 1.5475e+00, -8.5764e-01, -1.0164e+00],\n", + " [ 6.1844e-01, 4.2928e-01, 1.2559e+00, -1.7318e-02, -1.0785e+00],\n", + " [ 3.6746e-01, -2.0010e+00, 5.7185e-01, 1.3700e+00, -1.0621e+00]],\n", + "\n", + " [[ 1.8018e+00, 1.7676e-01, 1.3061e+00, 5.5536e-04, 9.5903e-01],\n", + " [-1.7914e-02, -1.1058e+00, -1.0330e+00, -5.6826e-01, 6.1741e-01],\n", + " [ 1.0086e+00, -2.1221e+00, 5.2469e-02, -5.4785e-01, -5.2776e-01]],\n", + "\n", + " [[-5.7215e-03, -2.1323e+00, -6.8256e-01, -5.2980e-02, 2.1093e+00],\n", + " [-3.1369e-01, -1.0205e+00, 1.2342e+00, -4.3158e-01, 1.8713e-01],\n", + " [ 1.2250e+00, -3.5429e-01, 3.3673e-01, -7.2250e-01, 6.2381e-01]],\n", + "\n", + " [[ 1.1720e+00, -1.2023e+00, -4.1971e-01, -3.9235e-01, -5.0843e-01],\n", + " [ 1.4071e+00, -7.0667e-01, -9.0403e-01, 9.7855e-01, 1.9402e+00],\n", + " [ 5.9586e-01, -2.0236e-01, -5.3443e-01, -1.6470e+00, 4.2355e-01]],\n", + "\n", + " [[ 2.2837e-01, -2.2963e-02, 2.4110e-01, -9.2972e-01, 5.7418e-01],\n", + " [ 1.1679e+00, -7.8923e-01, -1.7175e-01, 3.6727e-01, 1.1527e+00],\n", + " [ 1.8616e+00, -3.0584e-01, -2.1084e+00, -1.4724e+00, 2.0712e-01]],\n", + "\n", + " [[-2.6172e-01, -2.2121e+00, -7.4566e-01, 1.2002e+00, -7.5803e-01],\n", + " [ 5.2748e-01, 3.9661e-01, 1.6845e+00, 9.5950e-01, -3.9241e-01],\n", + " [-7.8969e-01, -5.8632e-01, 6.8664e-01, -7.5952e-01, 1.0505e+00]],\n", + "\n", + " [[ 1.8870e+00, -3.2287e-01, -9.8743e-03, -8.0076e-01, -9.8274e-01],\n", + " [-9.8188e-02, 1.4190e+00, -2.8326e-01, 8.4426e-01, -5.9501e-01],\n", + " [-2.5563e-01, -2.1496e+00, -3.9752e-01, 5.6409e-01, 1.1811e+00]],\n", + "\n", + " [[ 4.9171e-01, -8.9689e-01, 8.8269e-01, 1.7786e+00, -6.3821e-02],\n", + " [-3.8188e-01, -1.4144e+00, -7.0545e-01, -1.0574e+00, 1.4502e+00],\n", + " [ 2.6302e-01, -1.0660e+00, 4.6133e-01, -1.0306e+00, 1.2888e+00]]],\n", + " grad_fn=)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"tokens[\"X_to\"]" + ] + }, + { + "cell_type": "markdown", + "id": "55fd3c7b", + "metadata": {}, + "source": [ + "We can now create a transformer easily" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "19e0659e", + "metadata": {}, + "outputs": [], + "source": [ + "class TransformerTensorDict(nn.Module):\n", + " def __init__(self, num_blocks, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):\n", + " super().__init__()\n", + " self.transformer = nn.ModuleList([TransformerBlockTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])\n", + " def forward(self, X_tensor_dict):\n", + " for transformer_block in self.transformer:\n", + " transformer_block(X_tensor_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "fd865d55", + "metadata": {}, + "outputs": [], + "source": [ + "to_dim = 5\n", + "from_dim = 6\n", + "latent_dim = 10\n", + "to_len = 3\n", + "from_len = 10\n", + "batch_size = 8\n", + "num_heads = 2\n", + "\n", + "tokens = TensorDict({\"X_to\":torch.randn(batch_size, to_len, to_dim), \"X_from\":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "86885fbc", + "metadata": {}, + "source": [ + "For an encoder, we can do it easily as follow" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "9c9d6d10", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.5433, -2.1209, 0.6917, 0.1949, -0.2383],\n", + " [-0.2084, 1.1901, 0.1548, -1.2804, 0.3292],\n", + " [ 1.8350, -1.2308, 0.2563, -0.9090, 0.7924]],\n", + "\n", + " [[ 0.7642, -0.8773, 0.6102, -1.9610, -0.2547],\n", + " [ 0.8582, -0.8673, 1.6634, -0.5311, -1.3557],\n", + " [ 0.5985, -0.6759, 0.8806, -0.0338, 1.1816]],\n", + "\n", + " [[ 0.1173, -1.0630, -1.1206, -0.9292, 0.6910],\n", + " [-0.5904, 0.2967, 1.0146, -0.1985, -0.1382],\n", + " [ 2.3575, -0.4132, 1.5436, -1.2285, -0.3390]],\n", + "\n", + " [[-0.0839, 0.3025, 1.2402, -1.1337, -0.3312],\n", + " [ 0.1848, -0.5644, -0.9238, -1.2078, 0.7454],\n", + " [ 0.9539, -0.4605, 2.5865, -0.8541, -0.4540]],\n", + "\n", + " [[ 2.2012, 0.3198, -0.1624, -1.0218, -0.2963],\n", + " [-1.1105, 0.0330, -0.6841, 0.2308, 0.1053],\n", + " [ 2.2178, -1.1405, 0.0187, -0.9278, 0.2166]],\n", + "\n", + " [[ 0.1762, 0.1206, -0.9111, -0.8480, 0.8383],\n", + " [-0.0356, 1.9972, 1.2416, -1.9975, 0.2798],\n", + " [-1.5899, 0.6967, 0.2188, -0.2017, 0.0147]],\n", + "\n", + " [[ 0.5766, -1.0382, -0.3026, -0.9678, 0.1475],\n", + " [-0.2206, 0.6656, -0.0131, -0.7690, -0.2181],\n", + " [ 1.7309, -0.6701, 2.5907, -0.5225, -0.9894]],\n", + "\n", + " [[ 0.8108, -0.8168, 0.2928, -0.8560, 0.6063],\n", + " [ 1.3993, -1.0990, 2.3342, -0.6302, -0.6805],\n", + " [-0.3689, -1.0603, 1.0844, -0.3315, -0.6845]]],\n", + " grad_fn=)" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "transformer_encoder = TransformerTensorDict(6, \"X_to\", \"X_to\", to_dim, to_len, to_dim, latent_dim, num_heads)\n", + "\n", + "transformer_encoder(tokens)\n", + "tokens[\"X_to\"]" + ] + }, + { + "cell_type": "markdown", + "id": "d994eaa5", + "metadata": {}, + "source": [ + "For a decoder we have " + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "a1c3783e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.9111, -0.9565, 0.8885, -0.1153, -1.1195],\n", + " [-1.0822, 1.9428, -0.1168, -1.2518, 
-0.4014],\n", + " [ 1.6768, -1.1258, 0.4677, -0.1539, 0.4362]],\n", + "\n", + " [[ 0.2126, -0.9831, 0.8968, -1.4953, -1.0055],\n", + " [ 0.8913, 0.0649, 0.6585, -0.7806, -2.0843],\n", + " [ 1.5738, 0.4048, 0.6604, 0.7997, 0.1859]],\n", + "\n", + " [[-0.1787, -0.4171, -1.1724, -0.7935, 0.4380],\n", + " [-0.3860, 1.7195, 0.8667, -1.2160, -0.2223],\n", + " [ 1.6916, 0.1081, 1.4817, -1.3484, -0.5712]],\n", + "\n", + " [[ 0.3393, 0.4130, 0.7333, -1.0726, -0.8733],\n", + " [ 0.9412, -0.4980, -0.0214, -0.0619, 0.4108],\n", + " [ 1.5666, -0.0143, 1.5400, -1.2883, -2.1144]],\n", + "\n", + " [[ 2.3167, 0.4312, -0.4086, 0.1072, -1.3709],\n", + " [-0.6496, 0.0503, -0.5773, 0.4249, -0.6153],\n", + " [ 2.2431, -0.4351, -0.3872, -0.3861, -0.7434]],\n", + "\n", + " [[ 0.7115, 0.0280, -0.2837, -0.1418, 0.2284],\n", + " [-0.3786, 1.7935, 1.5064, -1.9357, -0.8157],\n", + " [-0.8887, 1.5851, -0.0724, -0.4470, -0.8894]],\n", + "\n", + " [[ 0.7716, -0.7878, 0.1092, -0.8322, 0.1236],\n", + " [-0.2652, 0.5158, 0.1450, -1.0936, -0.9748],\n", + " [ 1.3011, -0.0770, 2.7170, -0.5874, -1.0653]],\n", + "\n", + " [[ 0.7893, -0.3034, 0.4487, -0.8292, 0.2726],\n", + " [ 1.0496, -0.7812, 1.6862, -0.5165, -1.8188],\n", + " [ 0.6703, -0.0561, 1.4446, -0.5093, -1.5469]]],\n", + " grad_fn=)" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "transformer_decoder = TransformerTensorDict(6, \"X_to\", \"X_from\", to_dim, to_len, from_dim, latent_dim, num_heads)\n", + "\n", + "transformer_decoder(tokens)\n", + "tokens[\"X_to\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "346ebe20", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TransformerTensorDict(\n", + " (transformer): ModuleList(\n", + " (0): TransformerBlockTensorDict(\n", + " (transformer_block): Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " (1): TransformerBlockTensorDict(\n", + " (transformer_block): 
Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " (2): TransformerBlockTensorDict(\n", + " (transformer_block): Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " (3): TransformerBlockTensorDict(\n", + " (transformer_block): Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, 
out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " (4): TransformerBlockTensorDict(\n", + " (transformer_block): Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " (5): TransformerBlockTensorDict(\n", + " (transformer_block): Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " 
out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "transformer_encoder" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "e309c6ca", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TransformerTensorDict(\n", + " (transformer): ModuleList(\n", + " (0): TransformerBlockTensorDict(\n", + " (transformer_block): Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " (1): TransformerBlockTensorDict(\n", + " (transformer_block): Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + 
" (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " (2): TransformerBlockTensorDict(\n", + " (transformer_block): Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " (3): TransformerBlockTensorDict(\n", + " (transformer_block): Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " 
in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " (4): TransformerBlockTensorDict(\n", + " (transformer_block): Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " (5): TransformerBlockTensorDict(\n", + " (transformer_block): Sequential(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " 
in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "transformer_decoder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59b7ecfa", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 1f0319969cfc1b23c7092dfc02c2e95ade8fd475 Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Thu, 7 Jul 2022 18:59:05 +0100 Subject: [PATCH 02/21] Fixed english mistakes and small refactoring --- tutorials/tensor_dict.ipynb | 574 +++++++++++++++++++++++------------- 1 file changed, 363 insertions(+), 211 deletions(-) diff --git a/tutorials/tensor_dict.ipynb b/tutorials/tensor_dict.ipynb index c37215bd9fc..52adae109fc 100644 --- a/tutorials/tensor_dict.ipynb +++ b/tutorials/tensor_dict.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "c61f510e", + "id": "ed498bdd", "metadata": {}, "source": [ "# TensorDict tutorial" @@ -10,12 +10,12 @@ }, { "cell_type": "markdown", - "id": "125c0ad4", + "id": "02a938a2", "metadata": {}, "source": [ "TensorDict is a new tensor structure introduced in torchrl. With RL, you need to be able to deal with multiple tensors such as actions, observations and reward. TensorDict aims at making it more convenient to deal with multiple tensors at the same time. Furthermore, different RL algorithms can deal with different input and outputs. The TensorDict allows to abstract away the differences between these algorithmes\n", "\n", - "### Example\n", + "### Motivation\n", "As a concrete example, let us take DQN and PPO. The first uses a deterministic policy that applies an argmax operator to a collection of values associated with each action for a given observation. The second has a parametric policy that outputs a distribution over the space of the available actions. 
Here are the pseudos codes:\n", "\n", "\n", @@ -50,7 +50,7 @@ }, { "cell_type": "markdown", - "id": "ac324978", + "id": "f0e44de1", "metadata": {}, "source": [ "## Tensor Dict Python Dictionary behaviour" @@ -58,7 +58,7 @@ }, { "cell_type": "markdown", - "id": "61f4b705", + "id": "a1b27f7f", "metadata": {}, "source": [ "TensorDict shares a lot of features with python dictionaries" @@ -67,7 +67,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "791a3488", + "id": "ff626668", "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "f46c8c8f", + "id": "1985ef1b", "metadata": {}, "outputs": [ { @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "73cc767a", + "id": "7f66356d", "metadata": {}, "source": [ "If we want to access a certain key, it is explicit:" @@ -111,7 +111,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "16360ee0", + "id": "d0d4689c", "metadata": {}, "outputs": [ { @@ -143,17 +143,58 @@ }, { "cell_type": "markdown", - "id": "423e1ee6", + "id": "887c4d65", "metadata": {}, "source": [ - "#### TensorDict.keys()\n", - "We can access the dict keys" + "Also works with get()" ] }, { "cell_type": "code", "execution_count": 4, - "id": "6212e301", + "id": "7212a431", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]],\n", + "\n", + " [[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]],\n", + "\n", + " [[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]]])\n", + "torch.Size([3, 4, 5])\n" + ] + } + ], + "source": [ + "print(tensordict.get(\"a\"))\n", + "print(tensordict.get(\"a\").shape)" + ] + }, + { + "cell_type": "markdown", + "id": "f3c82ca1", + "metadata": {}, + "source": [ + "#### TensorDict.keys()\n", + "Keys can be retrieved to TensorDict" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "bef88aa2", "metadata": {}, "outputs": [ { @@ -172,26 +213,26 @@ }, { "cell_type": "markdown", - "id": "43c4fe4c", + "id": "43066ab2", "metadata": {}, "source": [ "#### TensorDict.values()\n", - "We can also retrieve the values of the dict. On the contrary of python dicts, we return a generator and not a list for memory efficiency reasons. Indeed, python dictionnary are not designed to store tensors in mind." + "The values of a TensorDict can be retrieved with the values() function. On the contrary of python dicts, the values() function return a generator and not a list for memory efficiency reasons. Indeed, python dictionnary are not designed to store tensors which can take a lot of space in memory." 
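Since values() is a generator, iterating over keys and reading entries by key is often the clearest pattern. A minimal sketch, reusing the tensordict built at the top of this notebook:

from torchrl.data import TensorDict
import torch

tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])

# keys() behaves like a dict view; entries are fetched lazily by key
for key in tensordict.keys():
    print(key, tensordict[key].shape)

# materialise the generator only when a full sequence is really needed
all_values = list(tensordict.values())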
] }, { "cell_type": "code", - "execution_count": 5, - "id": "f4538e4d", + "execution_count": 6, + "id": "2fa0d99b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -202,8 +243,8 @@ }, { "cell_type": "code", - "execution_count": 6, - "id": "3030acda", + "execution_count": 7, + "id": "270dc4f0", "metadata": {}, "outputs": [ { @@ -222,17 +263,82 @@ }, { "cell_type": "markdown", - "id": "bcc512bc", + "id": "17079498", + "metadata": {}, + "source": [ + "#### TensorDict.set()\n", + "The set function can be used to set new values" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c088852f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "c is set as tensor([[[[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]]]])\n" + ] + } + ], + "source": [ + "tensordict.set(\"c\", torch.zeros((3, 4, 2, 2)))\n", + "# Also works with tensordict.update(TensorDict({\"a\":torch.ones((3, 4, 5)), \"c\":torch.ones((3, 4, 2))}, batch_size=[3,4]))\n", + "print(f\"c is set as {tensordict['c']}\")" + ] + }, + { + "cell_type": "markdown", + "id": "7b8783e1", "metadata": {}, "source": [ "#### TensorDict.update()\n", - "We can also use the update function like for dicts" + "The update function can be used to update the dict with other dict values (Or TensorDict)" ] }, { "cell_type": "code", - "execution_count": 7, - "id": "884c3fed", + "execution_count": 9, + "id": "5c7061f8", "metadata": {}, "outputs": [ { @@ -253,7 +359,7 @@ " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.]]])\n", - "c is set as tensor([[[1., 1.],\n", + "d is set as tensor([[[1., 1.],\n", " [1., 1.],\n", " [1., 1.],\n", " [1., 1.]],\n", @@ -271,32 +377,32 @@ } ], "source": [ - "tensordict.update({\"a\":torch.ones((3, 4, 5))})\n", - "tensordict.update({\"c\":torch.ones((3, 4, 2))})\n", + "tensordict.update({\"a\":torch.ones((3, 4, 5)), \"d\":torch.ones((3, 4, 2))})\n", + "# Also works with tensordict.update(TensorDict({\"a\":torch.ones((3, 4, 5)), \"c\":torch.ones((3, 4, 2))}, batch_size=[3,4]))\n", "print(f\"a is now {tensordict['a']}\")\n", - "print(f\"c is set as {tensordict['c']}\")" + "print(f\"d is set as {tensordict['d']}\")" ] }, { "cell_type": "markdown", - "id": "dbadaafb", + "id": "1e7434c1", "metadata": {}, "source": [ "#### TensorDict del key\n", - "Tensor Dict also support keys deletion" + "TensorDict also support keys deletion with the del operator:" ] }, { "cell_type": "code", - "execution_count": 8, - "id": "d7b5920c", + "execution_count": 10, + "id": "a83a86b3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "dict_keys(['a', 'b'])\n" + "dict_keys(['a', 'b', 'd'])\n" ] } ], @@ -307,7 +413,7 @@ }, { "cell_type": "markdown", - "id": "072b6665", + "id": "dada11e0", "metadata": {}, "source": [ "## TensorDict as a pytorch Tensor" @@ -315,12 +421,12 @@ }, { "cell_type": "markdown", - "id": 
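Before turning to the tensor side of TensorDict, the dictionary-style mutation API covered above condenses into a short sketch (same (3, 4) batch size as everywhere else in this tutorial):

from torchrl.data import TensorDict
import torch

td = TensorDict({"a": torch.zeros(3, 4, 5)}, batch_size=[3, 4])

td.set("c", torch.zeros(3, 4, 2, 2))            # create a new entry
td.update({"a": torch.ones(3, 4, 5),            # overwrite an existing entry...
           "d": torch.ones(3, 4, 2)})           # ...and add another in the same call
del td["c"]                                     # delete an entry with the del operator

print(td.keys())                                # dict_keys(['a', 'd'])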
"5f292ab8", + "id": "3af83e0f", "metadata": {}, "source": [ "But wait? Can't we do this with a classical dict? \n", "Well, we would like the TensorDict to keep some nice Pytorch properties. TensorDict combines the advantages of the Python dictionary and of a Pytorch Tensor.\n", - "TensorDict has a batch size. It is not inferred automatically by looking at the tensors, but must be set when creating the TensorDict\n", + "TensorDict has a batch size. It is not inferred automatically by looking at the tensors, but must be set when creating the TensorDict.\n", "\n", "TensorDict is a tensor container where all tensors are stored in akey-value pair fashion and where each element shares at least the following features:\n", "- device;\n", @@ -330,8 +436,8 @@ }, { "cell_type": "code", - "execution_count": 9, - "id": "3a3da095", + "execution_count": 11, + "id": "345e0964", "metadata": {}, "outputs": [], "source": [ @@ -341,8 +447,8 @@ }, { "cell_type": "code", - "execution_count": 10, - "id": "14fd8b61", + "execution_count": 12, + "id": "5f1b27ff", "metadata": {}, "outputs": [ { @@ -368,7 +474,7 @@ }, { "cell_type": "markdown", - "id": "74c47805", + "id": "fdc15741", "metadata": {}, "source": [ "#### Batch size" @@ -376,7 +482,7 @@ }, { "cell_type": "markdown", - "id": "58c926a9", + "id": "cd455eaa", "metadata": {}, "source": [ "Tensor dict has a batch size which is shared across all tensors" @@ -384,8 +490,8 @@ }, { "cell_type": "code", - "execution_count": 11, - "id": "bc68f307", + "execution_count": 13, + "id": "e060c706", "metadata": {}, "outputs": [ { @@ -402,16 +508,16 @@ }, { "cell_type": "markdown", - "id": "69fc180b", + "id": "b5d4a862", "metadata": {}, "source": [ - "You cannot have items that don't share the batch size inside the TensorDict" + "You cannot have items that don't share the batch size inside the same TensorDict:" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "41737242", + "execution_count": 14, + "id": "1d00f2b7", "metadata": {}, "outputs": [ { @@ -421,7 +527,7 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [12]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtensordict\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mc\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mzeros\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n", + "Input \u001b[0;32mIn [14]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m 
\u001b[43mtensordict\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mc\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mzeros\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:356\u001b[0m, in \u001b[0;36m_TensorDict.update\u001b[0;34m(self, input_dict_or_td, clone, inplace, **kwargs)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m clone:\n\u001b[1;32m 355\u001b[0m value \u001b[38;5;241m=\u001b[39m value\u001b[38;5;241m.\u001b[39mclone()\n\u001b[0;32m--> 356\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minplace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minplace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 357\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n", "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:1646\u001b[0m, in \u001b[0;36mTensorDict.set\u001b[0;34m(self, key, value, inplace, _run_checks, _meta_val)\u001b[0m\n\u001b[1;32m 1644\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict \u001b[38;5;129;01mand\u001b[39;00m inplace:\n\u001b[1;32m 1645\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mset_(key, value)\n\u001b[0;32m-> 1646\u001b[0m proc_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_process_tensor\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1647\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1648\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_tensor_shape\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_run_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1649\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_shared\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1650\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_device\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_run_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1651\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# check_tensor_shape=_run_checks\u001b[39;00m\n\u001b[1;32m 1652\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict[key] \u001b[38;5;241m=\u001b[39m proc_value\n\u001b[1;32m 1653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict_meta[key] \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1654\u001b[0m MetaTensor(\n\u001b[1;32m 1655\u001b[0m proc_value,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1660\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m _meta_val\n\u001b[1;32m 1661\u001b[0m )\n", "File 
\u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:480\u001b[0m, in \u001b[0;36m_TensorDict._process_tensor\u001b[0;34m(self, input, check_device, check_tensor_shape, check_shared)\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcheck_shared is not authorized anymore\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_tensor_shape \u001b[38;5;129;01mand\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mshape[: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_dims] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_size:\n\u001b[0;32m--> 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch dimension mismatch, got self.batch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_size\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m and tensor.shape[:self.batch_dims]\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 483\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtensor\u001b[38;5;241m.\u001b[39mshape[: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_dims]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 484\u001b[0m )\n\u001b[1;32m 486\u001b[0m \u001b[38;5;66;03m# minimum ndimension is 1\u001b[39;00m\n\u001b[1;32m 487\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mndimension() \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mndimension() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", @@ -435,16 +541,7 @@ }, { "cell_type": "markdown", - "id": "92a02a32", - "metadata": {}, - "source": [ - "### Tensor operations\n", - "We can perform tensor operations among the batch dimensions" - ] - }, - { - "cell_type": "markdown", - "id": "8a624475", + "id": "0fa5d799", "metadata": {}, "source": [ "#### Cloning\n", @@ -453,8 +550,8 @@ }, { "cell_type": "code", - "execution_count": 13, - "id": "f579c627", + "execution_count": 15, + "id": "1f1ab1ec", "metadata": {}, "outputs": [ { @@ -501,7 +598,16 @@ }, { "cell_type": "markdown", - "id": "71075821", + "id": "857823f7", + "metadata": {}, + "source": [ + "### Tensor operations\n", + "We can perform tensor operations among the batch dimensions:" + ] + }, + { + "cell_type": "markdown", + "id": "b0ff6353", "metadata": {}, "source": [ "#### Slicing and indexing\n", @@ -510,8 +616,8 @@ }, { "cell_type": "code", - "execution_count": 14, - "id": "710d8b3b", + "execution_count": 16, + "id": "e170a64e", "metadata": {}, "outputs": [ { @@ -526,7 +632,7 @@ " is_shared=False)" ] }, - "execution_count": 14, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -537,8 +643,8 @@ }, { "cell_type": "code", - "execution_count": 15, - "id": "5ef348fc", + "execution_count": 17, + "id": "2a12a7d3", "metadata": {}, "outputs": [ { @@ -553,7 +659,7 @@ " is_shared=False)" ] }, - "execution_count": 15, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -564,8 +670,8 
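The batch-size contract and the indexing behaviour demonstrated here fit in one sketch; the leading dimensions of every entry are validated on set/update, which is exactly what the traceback above illustrates:

import torch
from torchrl.data import TensorDict

td = TensorDict({"a": torch.zeros(3, 4, 5)}, batch_size=[3, 4])

# indexing acts on the batch dimensions and returns a new TensorDict
row = td[1]          # batch_size becomes torch.Size([4])
col = td[:, 2]       # batch_size becomes torch.Size([3])

# a tensor whose leading dimensions differ from (3, 4) is rejected
try:
    td.update({"c": torch.zeros(4, 3, 1)})
except RuntimeError as err:
    print(err)       # batch dimension mismatch ...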
@@ }, { "cell_type": "code", - "execution_count": 16, - "id": "d8230353", + "execution_count": 18, + "id": "c89ff33a", "metadata": {}, "outputs": [ { @@ -580,7 +686,7 @@ " is_shared=False)" ] }, - "execution_count": 16, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -591,7 +697,7 @@ }, { "cell_type": "markdown", - "id": "9bb47f65", + "id": "e6007915", "metadata": {}, "source": [ "TensorDict support other tensor operations such as torch.cat, reshape, undind(dim), view(\\*shape), squeeze(dim), unsqueeze(dim), permute(\\*dims) requiring the operations to comply with the batch_size" @@ -599,8 +705,8 @@ }, { "cell_type": "code", - "execution_count": 17, - "id": "fb3b3002", + "execution_count": 19, + "id": "49ae45c5", "metadata": {}, "outputs": [ { @@ -615,18 +721,47 @@ " is_shared=False)" ] }, - "execution_count": 17, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# Reshape\n", "tensordict.reshape(-1)" ] }, + { + "cell_type": "code", + "execution_count": 20, + "id": "491a987d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorDict(\n", + " fields={\n", + " a: Tensor(torch.Size([6, 4, 5]), dtype=torch.float32),\n", + " b: Tensor(torch.Size([6, 4, 1]), dtype=torch.float32)},\n", + " batch_size=torch.Size([6, 4]),\n", + " device=cpu,\n", + " is_shared=False)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#Cat\n", + "torch.cat([tensordict, tensordict.clone()], dim=0)" + ] + }, { "cell_type": "markdown", - "id": "d6d74f22", + "id": "1a293442", "metadata": {}, "source": [ "#### Casting to device\n", @@ -635,7 +770,7 @@ }, { "cell_type": "markdown", - "id": "0053ef9f", + "id": "0548af86", "metadata": {}, "source": [ "## How to use them in practice? The tensor the TensorDictModule" @@ -643,7 +778,7 @@ }, { "cell_type": "markdown", - "id": "f94af8a2", + "id": "d2ef1786", "metadata": {}, "source": [ "Now that we have seen the TensorDict object, how do we use it in pratice? We introduce the TensorDictModule. The TensorDictModule is an nn.Module that takes a TensorDict in his forward method. The user defines the keys that the module will take as an input and write the output in the same TensorDict at a given set of key." @@ -651,8 +786,8 @@ }, { "cell_type": "code", - "execution_count": 18, - "id": "4439dc82", + "execution_count": 21, + "id": "61c59b0b", "metadata": {}, "outputs": [], "source": [ @@ -662,7 +797,7 @@ }, { "cell_type": "markdown", - "id": "88f76450", + "id": "fbab1825", "metadata": {}, "source": [ "### Example: Simple Linear layer" @@ -670,7 +805,7 @@ }, { "cell_type": "markdown", - "id": "2614d32f", + "id": "10cd4be6", "metadata": {}, "source": [ "Let's imagine we have 2 entries Tensor dict, a and b and we only want to affect a." 
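The cell implementing this sits in the unchanged context of the diff, but it boils down to the following sketch. The keyword construction (module=, in_keys=, out_keys=) and the import path are assumptions based on the reprs and file layout visible in these patches:

import torch
from torch import nn
from torchrl.data import TensorDict
from torchrl.modules import TensorDictModule

td = TensorDict({"a": torch.randn(3, 4, 5), "b": torch.randn(3, 4, 1)}, batch_size=[3, 4])

# reads td["a"] (last dim 5) and writes the result back under "a" (last dim 10);
# "b" is left untouched. Use a fresh out_key (e.g. "a_out") to keep the input.
linear = TensorDictModule(module=nn.Linear(5, 10), in_keys=["a"], out_keys=["a"])
linear(td)
print(td["a"].shape)   # torch.Size([3, 4, 10])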
@@ -678,8 +813,8 @@ }, { "cell_type": "code", - "execution_count": 19, - "id": "1acaf3d6", + "execution_count": 22, + "id": "120540b2", "metadata": {}, "outputs": [ { @@ -695,7 +830,7 @@ " is_shared=False)" ] }, - "execution_count": 19, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -708,7 +843,7 @@ }, { "cell_type": "markdown", - "id": "30a879d0", + "id": "6590ddfb", "metadata": {}, "source": [ "We can also do it inplace" @@ -716,8 +851,8 @@ }, { "cell_type": "code", - "execution_count": 20, - "id": "9f22bdd9", + "execution_count": 23, + "id": "6d28e94f", "metadata": {}, "outputs": [ { @@ -732,7 +867,7 @@ " is_shared=False)" ] }, - "execution_count": 20, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -745,7 +880,7 @@ }, { "cell_type": "markdown", - "id": "19f2866e", + "id": "69a6fec6", "metadata": {}, "source": [ "### Example: 2 input merging with 2 linear layer" @@ -753,7 +888,7 @@ }, { "cell_type": "markdown", - "id": "8166868f", + "id": "dc7bf86b", "metadata": {}, "source": [ "Now lets imagine a more complex network that takes 2 entries and average them into a single output" @@ -761,8 +896,8 @@ }, { "cell_type": "code", - "execution_count": 21, - "id": "1a5ba235", + "execution_count": 24, + "id": "e5871748", "metadata": {}, "outputs": [], "source": [ @@ -777,8 +912,8 @@ }, { "cell_type": "code", - "execution_count": 22, - "id": "021c97e1", + "execution_count": 25, + "id": "4a615ad1", "metadata": {}, "outputs": [ { @@ -795,7 +930,7 @@ " is_shared=False)" ] }, - "execution_count": 22, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -808,7 +943,7 @@ }, { "cell_type": "markdown", - "id": "fd579195", + "id": "f4eea40a", "metadata": {}, "source": [ "### Example: 1 input to 2 outputs linear layer\n", @@ -817,8 +952,8 @@ }, { "cell_type": "code", - "execution_count": 23, - "id": "632426fe", + "execution_count": 26, + "id": "d969a7d2", "metadata": {}, "outputs": [], "source": [ @@ -833,8 +968,8 @@ }, { "cell_type": "code", - "execution_count": 24, - "id": "435ae253", + "execution_count": 27, + "id": "04bd7977", "metadata": {}, "outputs": [ { @@ -851,7 +986,7 @@ " is_shared=False)" ] }, - "execution_count": 24, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -864,7 +999,7 @@ }, { "cell_type": "markdown", - "id": "d7594721", + "id": "ff2d6658", "metadata": {}, "source": [ "As we shown previously, the TensorDictModule can take any nn.Module and perform the operations inside a TensorDict. When having multiple input keys and output keys, make sure they match the order in the module.\n", @@ -873,21 +1008,22 @@ }, { "cell_type": "markdown", - "id": "786a2ed0", + "id": "62d946d8", "metadata": {}, "source": [ "### Example: A transformer with TensorDict?\n", "Let's attempt to create a transformer with TensorDict and TensorDictModule\n", "\n", - "Disclaimer: This implementation isn't to be \"better\" than a tensor based implementation. It is just meant to showcase the TensorDictModule features.\n", + "Disclaimer: This implementation don't claim to be \"better\" than a classical tensor-based implementation. 
It is just meant to showcase the TensorDictModule features.\n", + "For simplicity we will not have positional encoders.\n", "\n", "Let's first implement the classical transformers blocks" ] }, { "cell_type": "code", - "execution_count": 25, - "id": "ff5e05d1", + "execution_count": 28, + "id": "f21f79a6", "metadata": {}, "outputs": [], "source": [ @@ -945,10 +1081,18 @@ " return self.FFN(X)\n" ] }, + { + "cell_type": "markdown", + "id": "acfd6a21", + "metadata": {}, + "source": [ + "Now, we can build the TransformerBlock thanks to the TensorDictModule. Since the changes affect the tensor dict, we just need to map outputs to the right name such as it is picked up by the next block." + ] + }, { "cell_type": "code", - "execution_count": 26, - "id": "d4c49b49", + "execution_count": 29, + "id": "4b52dc71", "metadata": {}, "outputs": [], "source": [ @@ -970,8 +1114,8 @@ }, { "cell_type": "code", - "execution_count": 27, - "id": "9fa8d7ff", + "execution_count": 30, + "id": "3aa84439", "metadata": {}, "outputs": [ { @@ -991,7 +1135,7 @@ " is_shared=False)" ] }, - "execution_count": 27, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -1016,7 +1160,7 @@ }, { "cell_type": "markdown", - "id": "23a53298", + "id": "653f125e", "metadata": {}, "source": [ "The output of the transformer layer can now be found at tokens[\"X_to\"]" @@ -1024,48 +1168,48 @@ }, { "cell_type": "code", - "execution_count": 28, - "id": "12c66c3f", + "execution_count": 31, + "id": "31bc9d71", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-2.2140e-01, 9.3885e-02, 1.5475e+00, -8.5764e-01, -1.0164e+00],\n", - " [ 6.1844e-01, 4.2928e-01, 1.2559e+00, -1.7318e-02, -1.0785e+00],\n", - " [ 3.6746e-01, -2.0010e+00, 5.7185e-01, 1.3700e+00, -1.0621e+00]],\n", + "tensor([[[-0.8713, -1.2626, 1.3218, -0.2947, 1.6938],\n", + " [-0.7374, -0.6038, -1.4958, -0.3975, 0.0741],\n", + " [-0.5026, 0.0095, 1.8102, 0.2308, 1.0256]],\n", "\n", - " [[ 1.8018e+00, 1.7676e-01, 1.3061e+00, 5.5536e-04, 9.5903e-01],\n", - " [-1.7914e-02, -1.1058e+00, -1.0330e+00, -5.6826e-01, 6.1741e-01],\n", - " [ 1.0086e+00, -2.1221e+00, 5.2469e-02, -5.4785e-01, -5.2776e-01]],\n", + " [[ 0.4484, 0.0689, 1.0152, -1.1690, 1.2264],\n", + " [-0.5891, 0.0737, 0.7038, 1.0404, 0.0131],\n", + " [ 1.0050, -2.4708, -0.7890, -1.0160, 0.4391]],\n", "\n", - " [[-5.7215e-03, -2.1323e+00, -6.8256e-01, -5.2980e-02, 2.1093e+00],\n", - " [-3.1369e-01, -1.0205e+00, 1.2342e+00, -4.3158e-01, 1.8713e-01],\n", - " [ 1.2250e+00, -3.5429e-01, 3.3673e-01, -7.2250e-01, 6.2381e-01]],\n", + " [[-1.8364, -0.5181, -0.5258, 0.5166, 1.8120],\n", + " [ 1.3389, 0.1451, -0.1267, -0.7637, 1.6104],\n", + " [-0.1859, -0.4134, -1.4359, -0.1131, 0.4961]],\n", "\n", - " [[ 1.1720e+00, -1.2023e+00, -4.1971e-01, -3.9235e-01, -5.0843e-01],\n", - " [ 1.4071e+00, -7.0667e-01, -9.0403e-01, 9.7855e-01, 1.9402e+00],\n", - " [ 5.9586e-01, -2.0236e-01, -5.3443e-01, -1.6470e+00, 4.2355e-01]],\n", + " [[-1.0511, 0.1636, -0.9440, -0.2152, -0.4874],\n", + " [ 1.4676, 2.0405, 0.2846, 0.5990, 1.0199],\n", + " [-1.5073, 0.0980, -1.5943, -0.1160, 0.2421]],\n", "\n", - " [[ 2.2837e-01, -2.2963e-02, 2.4110e-01, -9.2972e-01, 5.7418e-01],\n", - " [ 1.1679e+00, -7.8923e-01, -1.7175e-01, 3.6727e-01, 1.1527e+00],\n", - " [ 1.8616e+00, -3.0584e-01, -2.1084e+00, -1.4724e+00, 2.0712e-01]],\n", + " [[-0.6059, 0.3442, 0.6854, -0.0933, 1.8850],\n", + " [-0.2040, -2.0479, 0.8991, 1.1162, 0.0855],\n", + " [-1.6792, -0.4797, 0.4558, 0.4763, -0.8374]],\n", "\n", - " [[-2.6172e-01, -2.2121e+00, 
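The key-matching idea deserves a small standalone sketch before the full block: since every TensorDictModule consumes and returns the same TensorDict, a plain nn.Sequential chains them, provided each module's out_keys line up with the in_keys of the next (the names below are illustrative):

import torch
from torch import nn
from torchrl.data import TensorDict
from torchrl.modules import TensorDictModule

stage1 = TensorDictModule(module=nn.Linear(5, 10), in_keys=["X"], out_keys=["H"])
stage2 = TensorDictModule(module=nn.ReLU(), in_keys=["H"], out_keys=["H"])  # same key in and out

pipeline = nn.Sequential(stage1, stage2)   # each stage passes the tensordict along
td = TensorDict({"X": torch.randn(2, 3, 5)}, batch_size=[2])
pipeline(td)
print(td["H"].shape)   # torch.Size([2, 3, 10])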
-7.4566e-01, 1.2002e+00, -7.5803e-01],\n", - " [ 5.2748e-01, 3.9661e-01, 1.6845e+00, 9.5950e-01, -3.9241e-01],\n", - " [-7.8969e-01, -5.8632e-01, 6.8664e-01, -7.5952e-01, 1.0505e+00]],\n", + " [[-1.2463, -1.3887, 1.2930, 0.5651, 0.9994],\n", + " [-0.1023, -0.4523, -2.0760, 1.6500, 0.5962],\n", + " [ 0.7221, 0.1171, 0.1437, -0.2839, -0.5374]],\n", "\n", - " [[ 1.8870e+00, -3.2287e-01, -9.8743e-03, -8.0076e-01, -9.8274e-01],\n", - " [-9.8188e-02, 1.4190e+00, -2.8326e-01, 8.4426e-01, -5.9501e-01],\n", - " [-2.5563e-01, -2.1496e+00, -3.9752e-01, 5.6409e-01, 1.1811e+00]],\n", + " [[-0.3637, -1.2232, -2.0972, -0.7830, -0.1663],\n", + " [-0.5369, 0.0789, -0.9869, 0.6352, 1.5502],\n", + " [ 0.3313, 0.2594, 1.5771, 0.6336, 1.0914]],\n", "\n", - " [[ 4.9171e-01, -8.9689e-01, 8.8269e-01, 1.7786e+00, -6.3821e-02],\n", - " [-3.8188e-01, -1.4144e+00, -7.0545e-01, -1.0574e+00, 1.4502e+00],\n", - " [ 2.6302e-01, -1.0660e+00, 4.6133e-01, -1.0306e+00, 1.2888e+00]]],\n", + " [[-1.0102, 0.4594, 0.5603, 0.1587, 0.2164],\n", + " [-1.3799, -0.1682, -1.9153, 0.9154, 1.8860],\n", + " [-0.0694, -0.8951, -0.6851, 0.6067, 1.3203]]],\n", " grad_fn=)" ] }, - "execution_count": 28, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -1076,7 +1220,7 @@ }, { "cell_type": "markdown", - "id": "55fd3c7b", + "id": "3c9e6820", "metadata": {}, "source": [ "We can now create a transformer easily" @@ -1084,8 +1228,8 @@ }, { "cell_type": "code", - "execution_count": 29, - "id": "19e0659e", + "execution_count": 32, + "id": "928bcebc", "metadata": {}, "outputs": [], "source": [ @@ -1100,8 +1244,8 @@ }, { "cell_type": "code", - "execution_count": 30, - "id": "fd865d55", + "execution_count": 33, + "id": "d672d546", "metadata": {}, "outputs": [], "source": [ @@ -1120,56 +1264,56 @@ }, { "cell_type": "markdown", - "id": "86885fbc", + "id": "f522bd3a", "metadata": {}, "source": [ - "For an encoder, we can do it easily as follow" + "For an encoder, we just need to take the same tokens for both queries, keys and values." 
] }, { "cell_type": "code", - "execution_count": 31, - "id": "9c9d6d10", + "execution_count": 34, + "id": "f51e79cf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[ 0.5433, -2.1209, 0.6917, 0.1949, -0.2383],\n", - " [-0.2084, 1.1901, 0.1548, -1.2804, 0.3292],\n", - " [ 1.8350, -1.2308, 0.2563, -0.9090, 0.7924]],\n", + "tensor([[[-7.0456e-01, 6.2688e-01, -1.1026e+00, 2.2786e-02, -4.8655e-02],\n", + " [-9.7513e-01, 1.4375e+00, 9.4504e-01, 2.0078e+00, -8.5485e-01],\n", + " [-1.3192e+00, 8.7940e-01, -1.2088e+00, 4.6428e-01, -1.7003e-01]],\n", "\n", - " [[ 0.7642, -0.8773, 0.6102, -1.9610, -0.2547],\n", - " [ 0.8582, -0.8673, 1.6634, -0.5311, -1.3557],\n", - " [ 0.5985, -0.6759, 0.8806, -0.0338, 1.1816]],\n", + " [[-2.4683e+00, 1.1900e+00, -4.6999e-01, 8.1202e-01, -1.0019e+00],\n", + " [ 9.6744e-01, 1.0676e+00, 1.0292e+00, 4.7986e-01, -1.2177e+00],\n", + " [ 3.2124e-02, 4.6237e-01, -7.1896e-01, 2.0720e-03, -1.6574e-01]],\n", "\n", - " [[ 0.1173, -1.0630, -1.1206, -0.9292, 0.6910],\n", - " [-0.5904, 0.2967, 1.0146, -0.1985, -0.1382],\n", - " [ 2.3575, -0.4132, 1.5436, -1.2285, -0.3390]],\n", + " [[-2.2703e-01, 1.5631e+00, 1.1274e+00, 8.1163e-02, -1.5204e+00],\n", + " [-1.8939e+00, 7.4907e-01, -1.6144e+00, 7.3381e-01, 7.8596e-01],\n", + " [-7.7463e-02, 2.4682e-01, 4.6115e-01, 3.5791e-01, -7.7312e-01]],\n", "\n", - " [[-0.0839, 0.3025, 1.2402, -1.1337, -0.3312],\n", - " [ 0.1848, -0.5644, -0.9238, -1.2078, 0.7454],\n", - " [ 0.9539, -0.4605, 2.5865, -0.8541, -0.4540]],\n", + " [[ 4.2789e-01, 7.7004e-02, -5.2232e-01, -1.3905e+00, -6.8685e-01],\n", + " [-9.8940e-01, 6.9261e-02, 1.8176e+00, 1.2323e+00, -2.8591e-01],\n", + " [-1.9008e+00, 1.2424e+00, 9.0587e-01, 3.6788e-01, -3.6444e-01]],\n", "\n", - " [[ 2.2012, 0.3198, -0.1624, -1.0218, -0.2963],\n", - " [-1.1105, 0.0330, -0.6841, 0.2308, 0.1053],\n", - " [ 2.2178, -1.1405, 0.0187, -0.9278, 0.2166]],\n", + " [[ 1.4538e+00, 4.6922e-01, -8.1502e-01, -3.0426e-01, 3.3914e-01],\n", + " [ 4.5448e-02, 5.8241e-01, -5.1411e-02, 2.1194e-01, -3.5672e-01],\n", + " [-2.6675e+00, 1.0907e+00, 9.3493e-01, 4.4590e-01, -1.3785e+00]],\n", "\n", - " [[ 0.1762, 0.1206, -0.9111, -0.8480, 0.8383],\n", - " [-0.0356, 1.9972, 1.2416, -1.9975, 0.2798],\n", - " [-1.5899, 0.6967, 0.2188, -0.2017, 0.0147]],\n", + " [[-8.6485e-01, 1.1146e+00, -9.4181e-02, -1.1407e-01, 6.2847e-01],\n", + " [-9.5162e-01, 1.4542e+00, 1.2157e-02, -2.0240e-01, -1.5317e+00],\n", + " [-9.3001e-01, 9.8748e-01, 9.6898e-01, 1.2264e+00, -1.7035e+00]],\n", "\n", - " [[ 0.5766, -1.0382, -0.3026, -0.9678, 0.1475],\n", - " [-0.2206, 0.6656, -0.0131, -0.7690, -0.2181],\n", - " [ 1.7309, -0.6701, 2.5907, -0.5225, -0.9894]],\n", + " [[ 1.8318e-01, 8.8817e-01, -4.1605e-01, -7.2077e-02, 1.0248e+00],\n", + " [ 1.1950e+00, -3.1990e-01, -3.0087e+00, -8.1052e-01, 4.2547e-01],\n", + " [ 4.8264e-02, 8.2594e-01, -4.8231e-01, -2.2883e-01, 7.4760e-01]],\n", "\n", - " [[ 0.8108, -0.8168, 0.2928, -0.8560, 0.6063],\n", - " [ 1.3993, -1.0990, 2.3342, -0.6302, -0.6805],\n", - " [-0.3689, -1.0603, 1.0844, -0.3315, -0.6845]]],\n", + " [[-2.0185e+00, 2.7137e-01, 1.9260e-01, -3.7461e-01, -1.4131e+00],\n", + " [ 3.0497e-01, 4.7218e-01, 1.4051e+00, -7.2168e-01, -7.7026e-01],\n", + " [ 1.3967e+00, -6.0792e-01, 1.1108e-01, 1.7145e+00, 3.7664e-02]]],\n", " grad_fn=)" ] }, - "execution_count": 31, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -1183,56 +1327,56 @@ }, { "cell_type": "markdown", - "id": "d994eaa5", + "id": "2fd013d3", "metadata": {}, "source": [ - "For a decoder we have 
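The call producing the output above is hidden in the unchanged context; conceptually it reduces to feeding the same tensor as both X_to and X_from, along the lines of this sketch (it assumes the transformer_encoder instantiated earlier in the notebook, built with matching query/key embedding sizes):

import torch
from torchrl.data import TensorDict

X = torch.randn(8, 3, 5)                       # batch of 8, 3 tokens, latent dim 5
tokens = TensorDict({"X_to": X, "X_from": X},  # self-attention: one sequence plays all roles
                    batch_size=[8])
transformer_encoder(tokens)                    # defined in the cells above
out = tokens["X_to"]                           # refreshed token representations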
" + "For a decoder, we now can extract info from X_from into X_to. X_to will map to queries whereas X_from will map to keys and values." ] }, { "cell_type": "code", - "execution_count": 32, - "id": "a1c3783e", + "execution_count": 35, + "id": "21d299fa", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[ 0.9111, -0.9565, 0.8885, -0.1153, -1.1195],\n", - " [-1.0822, 1.9428, -0.1168, -1.2518, -0.4014],\n", - " [ 1.6768, -1.1258, 0.4677, -0.1539, 0.4362]],\n", + "tensor([[[-0.3380, 0.7571, -1.0049, 0.7433, -0.1268],\n", + " [-1.2908, 1.1133, -0.1441, 1.9060, -1.1740],\n", + " [-1.2169, 1.1862, -0.9880, 0.9077, -0.3301]],\n", "\n", - " [[ 0.2126, -0.9831, 0.8968, -1.4953, -1.0055],\n", - " [ 0.8913, 0.0649, 0.6585, -0.7806, -2.0843],\n", - " [ 1.5738, 0.4048, 0.6604, 0.7997, 0.1859]],\n", + " [[-0.9989, 1.0225, 0.0516, 1.4268, -1.3924],\n", + " [-0.4546, 1.6237, -0.9002, 1.0320, -1.1606],\n", + " [-0.5801, 0.9752, -1.1864, 0.4610, 0.0803]],\n", "\n", - " [[-0.1787, -0.4171, -1.1724, -0.7935, 0.4380],\n", - " [-0.3860, 1.7195, 0.8667, -1.2160, -0.2223],\n", - " [ 1.6916, 0.1081, 1.4817, -1.3484, -0.5712]],\n", + " [[-0.9055, 1.3019, -0.1604, 1.2191, -1.8533],\n", + " [-1.2209, 1.0449, -0.6953, 1.0392, 0.4056],\n", + " [-0.3828, 0.6094, 0.0309, 0.9272, -1.3600]],\n", "\n", - " [[ 0.3393, 0.4130, 0.7333, -1.0726, -0.8733],\n", - " [ 0.9412, -0.4980, -0.0214, -0.0619, 0.4108],\n", - " [ 1.5666, -0.0143, 1.5400, -1.2883, -2.1144]],\n", + " [[-0.3067, 0.0189, -0.9604, -0.3966, -0.3716],\n", + " [-1.0679, 0.4198, 1.2470, 1.9662, -0.8475],\n", + " [-1.6165, 1.1001, -0.3376, 1.4791, -0.3264]],\n", "\n", - " [[ 2.3167, 0.4312, -0.4086, 0.1072, -1.3709],\n", - " [-0.6496, 0.0503, -0.5773, 0.4249, -0.6153],\n", - " [ 2.2431, -0.4351, -0.3872, -0.3861, -0.7434]],\n", + " [[-0.0980, 0.6583, -1.2214, 0.2261, -0.4485],\n", + " [-0.6073, 1.5547, -0.3092, 1.1512, -1.0601],\n", + " [-1.4279, 1.2677, 0.1266, 1.4727, -1.2850]],\n", "\n", - " [[ 0.7115, 0.0280, -0.2837, -0.1418, 0.2284],\n", - " [-0.3786, 1.7935, 1.5064, -1.9357, -0.8157],\n", - " [-0.8887, 1.5851, -0.0724, -0.4470, -0.8894]],\n", + " [[-0.4164, 0.7499, 0.2536, 0.1846, 0.1112],\n", + " [-0.4129, 1.0208, -0.3762, 0.7487, -1.3215],\n", + " [-0.1598, 0.5029, -0.3288, 1.9525, -2.5087]],\n", "\n", - " [[ 0.7716, -0.7878, 0.1092, -0.8322, 0.1236],\n", - " [-0.2652, 0.5158, 0.1450, -1.0936, -0.9748],\n", - " [ 1.3011, -0.0770, 2.7170, -0.5874, -1.0653]],\n", + " [[-0.3682, 0.8304, -0.8471, 0.5485, 0.4438],\n", + " [-0.8762, 1.1699, -2.6552, 0.4682, 0.1896],\n", + " [-0.3391, 1.0842, -1.0764, 0.5564, 0.8712]],\n", "\n", - " [[ 0.7893, -0.3034, 0.4487, -0.8292, 0.2726],\n", - " [ 1.0496, -0.7812, 1.6862, -0.5165, -1.8188],\n", - " [ 0.6703, -0.0561, 1.4446, -0.5093, -1.5469]]],\n", + " [[-1.0284, 0.3804, 0.4793, 0.4096, -2.1200],\n", + " [-0.2695, 1.0064, -0.2066, 0.1593, -0.7439],\n", + " [ 0.9516, 0.8927, -0.9734, 1.9427, -0.8802]]],\n", " grad_fn=)" ] }, - "execution_count": 32, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } @@ -1244,10 +1388,18 @@ "tokens[\"X_to\"]" ] }, + { + "cell_type": "markdown", + "id": "b6483e4f", + "metadata": {}, + "source": [ + "Now we can look at both models:" + ] + }, { "cell_type": "code", - "execution_count": 33, - "id": "346ebe20", + "execution_count": 36, + "id": "a5548e93", "metadata": {}, "outputs": [ { @@ -1571,7 +1723,7 @@ ")" ] }, - "execution_count": 33, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -1582,8 +1734,8 @@ }, { "cell_type": 
"code", - "execution_count": 34, - "id": "e309c6ca", + "execution_count": 37, + "id": "af5a02ea", "metadata": {}, "outputs": [ { @@ -1907,7 +2059,7 @@ ")" ] }, - "execution_count": 34, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -1919,7 +2071,7 @@ { "cell_type": "code", "execution_count": null, - "id": "59b7ecfa", + "id": "4fec2070", "metadata": {}, "outputs": [], "source": [] From f4039fa9bb0053da6ff8d4ced3667cc2501a1c56 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 8 Jul 2022 17:39:34 +0100 Subject: [PATCH 03/21] init --- torchrl/modules/tensordict_module/sequence.py | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 3e0ff4a7b50..b9964a4cc3b 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -118,27 +118,18 @@ def __init__( in_keys = [] out_keys = [] for module in modules: - in_keys += module.in_keys + # we sometimes use in_keys to select keys of a tensordict that are + # necessary to run a TensorDictModule. If a key is an intermediary in + # the chain, there is no reason why it should belong to the input + # TensorDict. + in_keys += [key for key in module.in_keys if key not in out_keys] out_keys += module.out_keys - # in_keys = [] - # for in_key in in_keys_tmp: - # if (in_key not in in_keys) and (in_key not in out_keys): - # in_keys.append(in_key) - # if not len(in_keys): - # raise RuntimeError( - # "in_keys empty. Please ensure that there is at least one input " - # "key that is not part of the output key set." - # ) + out_keys = [ out_key for i, out_key in enumerate(out_keys) if out_key not in out_keys[i + 1 :] ] - # we sometimes use in_keys to select keys of a tensordict that are - # necessary to run a TensorDictModule. If a key is an intermediary in - # the chain, there is not reason why it should belong to the input - # TensorDict. - in_keys = [in_key for in_key in in_keys if in_key not in out_keys] super().__init__( spec=None, From 1337768c43287e73e6c6c3d328db493203f1d9f9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 8 Jul 2022 17:40:24 +0100 Subject: [PATCH 04/21] init From 04fe61cb5a0225ee92bbbff012f0573713512f91 Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Fri, 8 Jul 2022 18:20:47 +0100 Subject: [PATCH 05/21] Added suggered changes and cleaned up --- tutorials/tensordict.ipynb | 473 ++++----- tutorials/tensordictmodule.ipynb | 1590 ++++++++++++++++-------------- 2 files changed, 1040 insertions(+), 1023 deletions(-) diff --git a/tutorials/tensordict.ipynb b/tutorials/tensordict.ipynb index c21c3d56009..40270f3d0d5 100644 --- a/tutorials/tensordict.ipynb +++ b/tutorials/tensordict.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "182c2c65", + "id": "c0174624", "metadata": {}, "source": [ "# TensorDict tutorial" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "30d2c125", + "id": "139b8238", "metadata": {}, "source": [ "TensorDict is a new tensor structure introduced in torchrl. \n", @@ -24,7 +24,7 @@ }, { "cell_type": "markdown", - "id": "e84a89f7", + "id": "5adeede6", "metadata": {}, "source": [ "#### Improving the modularity of codes" @@ -32,7 +32,7 @@ }, { "cell_type": "markdown", - "id": "f34a8870", + "id": "3a11d7f7", "metadata": {}, "source": [ "Let's suppose we have 2 datasets: Dataset A which has images and labels and Dataset B which has images, segmentation maps and labels. 
\n", @@ -80,7 +80,7 @@ }, { "cell_type": "markdown", - "id": "7d54fa30", + "id": "0d128d7c", "metadata": {}, "source": [ "#### Can't i do this with a python dict?" @@ -88,7 +88,7 @@ }, { "cell_type": "markdown", - "id": "0f677019", + "id": "c2e2536a", "metadata": {}, "source": [ "One could argue that you could achieve the same results with a dataset that outputs a pytorch dict. \n", @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "0cf1c399", + "id": "b1303de8", "metadata": {}, "source": [ "## TensorDict dictionary features" @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "c14aaa07", + "id": "79458114", "metadata": {}, "source": [ "TensorDict shares a lot of features with python dictionaries" @@ -164,7 +164,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "f00c4e68", + "id": "70a39ae6", "metadata": {}, "outputs": [], "source": [ @@ -175,7 +175,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "85250bcb", + "id": "4009045f", "metadata": {}, "outputs": [ { @@ -199,7 +199,7 @@ }, { "cell_type": "markdown", - "id": "63b37416", + "id": "0f543f6c", "metadata": {}, "source": [ "If we want to access a certain key, it is explicit:" @@ -208,7 +208,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "1340830f", + "id": "0235a1a7", "metadata": {}, "outputs": [ { @@ -240,7 +240,7 @@ }, { "cell_type": "markdown", - "id": "c6797535", + "id": "47026db0", "metadata": {}, "source": [ "Also works with get()" @@ -249,7 +249,7 @@ { "cell_type": "code", "execution_count": 4, - "id": "9a57a8df", + "id": "360f7c26", "metadata": {}, "outputs": [ { @@ -281,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "75891945", + "id": "1f200416", "metadata": {}, "source": [ "#### TensorDict.keys()\n", @@ -291,7 +291,7 @@ { "cell_type": "code", "execution_count": 5, - "id": "8c25f688", + "id": "a9882c10", "metadata": {}, "outputs": [ { @@ -310,7 +310,7 @@ }, { "cell_type": "markdown", - "id": "4eb389f9", + "id": "1fecc00f", "metadata": {}, "source": [ "#### TensorDict.values()\n", @@ -320,13 +320,13 @@ { "cell_type": "code", "execution_count": 6, - "id": "eb78d850", + "id": "0e72b88a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -341,7 +341,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "496e6904", + "id": "5fd21818", "metadata": {}, "outputs": [ { @@ -360,7 +360,7 @@ }, { "cell_type": "markdown", - "id": "f0ad92f9", + "id": "a820a435", "metadata": {}, "source": [ "#### TensorDict.set()\n", @@ -370,7 +370,7 @@ { "cell_type": "code", "execution_count": 8, - "id": "b7983bac", + "id": "96c29680", "metadata": {}, "outputs": [ { @@ -424,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "99b51010", + "id": "45e59abc", "metadata": {}, "source": [ "#### TensorDict.update()\n", @@ -434,7 +434,7 @@ { "cell_type": "code", "execution_count": 9, - "id": "002ede14", + "id": "96f216bf", "metadata": {}, "outputs": [ { @@ -481,7 +481,7 @@ }, { "cell_type": "markdown", - "id": "ad0ac1b6", + "id": "9feb8632", "metadata": {}, "source": [ "#### TensorDict del key\n", @@ -491,7 +491,7 @@ { "cell_type": "code", "execution_count": 10, - "id": "e8d2b807", + "id": "7afc943e", "metadata": {}, "outputs": [ { @@ -509,7 +509,7 @@ }, { "cell_type": "markdown", - "id": "b0a07ee4", + "id": "5fb04f41", "metadata": {}, "source": [ "## TensorDict as a pytorch Tensor" @@ -517,7 +517,7 @@ }, { "cell_type": "markdown", - "id": "db0c756e", + "id": "1c834589", "metadata": {}, "source": [ "But wait? Can't we do this with a classical dict? 
\n", @@ -533,7 +533,7 @@ { "cell_type": "code", "execution_count": 11, - "id": "fa5978eb", + "id": "f22d4b66", "metadata": {}, "outputs": [], "source": [ @@ -544,7 +544,7 @@ { "cell_type": "code", "execution_count": 12, - "id": "d6a72121", + "id": "daaa4e28", "metadata": {}, "outputs": [ { @@ -570,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "742a73a0", + "id": "0aa2e7cd", "metadata": {}, "source": [ "#### Batch size" @@ -578,7 +578,7 @@ }, { "cell_type": "markdown", - "id": "2929efaa", + "id": "6d17d8b4", "metadata": {}, "source": [ "Tensor dict has a batch size which is shared across all tensors. The batch size can be [], unidimensional or multidimensional according to your needs." @@ -587,7 +587,7 @@ { "cell_type": "code", "execution_count": 13, - "id": "4cca800c", + "id": "3c0217e0", "metadata": {}, "outputs": [ { @@ -604,7 +604,7 @@ }, { "cell_type": "markdown", - "id": "d228d0ec", + "id": "bda9c6aa", "metadata": {}, "source": [ "You cannot have items that don't share the batch size inside the same TensorDict:" @@ -613,7 +613,7 @@ { "cell_type": "code", "execution_count": 14, - "id": "d1a40031", + "id": "d08b9d11", "metadata": {}, "outputs": [ { @@ -624,9 +624,9 @@ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "Input \u001b[0;32mIn [14]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtensordict\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mc\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mzeros\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:356\u001b[0m, in \u001b[0;36m_TensorDict.update\u001b[0;34m(self, input_dict_or_td, clone, inplace, **kwargs)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m clone:\n\u001b[1;32m 355\u001b[0m value \u001b[38;5;241m=\u001b[39m value\u001b[38;5;241m.\u001b[39mclone()\n\u001b[0;32m--> 356\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minplace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minplace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 357\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:1646\u001b[0m, in \u001b[0;36mTensorDict.set\u001b[0;34m(self, key, value, inplace, _run_checks, _meta_val)\u001b[0m\n\u001b[1;32m 1644\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict \u001b[38;5;129;01mand\u001b[39;00m inplace:\n\u001b[1;32m 1645\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mset_(key, value)\n\u001b[0;32m-> 1646\u001b[0m proc_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_process_tensor\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1647\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1648\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_tensor_shape\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_run_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1649\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_shared\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1650\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_device\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_run_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1651\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# check_tensor_shape=_run_checks\u001b[39;00m\n\u001b[1;32m 1652\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict[key] \u001b[38;5;241m=\u001b[39m proc_value\n\u001b[1;32m 1653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict_meta[key] \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1654\u001b[0m MetaTensor(\n\u001b[1;32m 1655\u001b[0m proc_value,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1660\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m _meta_val\n\u001b[1;32m 1661\u001b[0m )\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:480\u001b[0m, in \u001b[0;36m_TensorDict._process_tensor\u001b[0;34m(self, input, check_device, check_tensor_shape, check_shared)\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcheck_shared is not authorized anymore\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_tensor_shape \u001b[38;5;129;01mand\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mshape[: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_dims] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_size:\n\u001b[0;32m--> 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch dimension mismatch, got self.batch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_size\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m and tensor.shape[:self.batch_dims]\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 483\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtensor\u001b[38;5;241m.\u001b[39mshape[: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_dims]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 484\u001b[0m )\n\u001b[1;32m 486\u001b[0m \u001b[38;5;66;03m# minimum ndimension is 1\u001b[39;00m\n\u001b[1;32m 487\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mndimension() \u001b[38;5;241m-\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mndimension() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:361\u001b[0m, in \u001b[0;36m_TensorDict.update\u001b[0;34m(self, input_dict_or_td, clone, inplace, **kwargs)\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m clone:\n\u001b[1;32m 360\u001b[0m value \u001b[38;5;241m=\u001b[39m value\u001b[38;5;241m.\u001b[39mclone()\n\u001b[0;32m--> 361\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minplace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minplace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:1657\u001b[0m, in \u001b[0;36mTensorDict.set\u001b[0;34m(self, key, value, inplace, _run_checks, _meta_val)\u001b[0m\n\u001b[1;32m 1655\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict \u001b[38;5;129;01mand\u001b[39;00m inplace:\n\u001b[1;32m 1656\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mset_(key, value)\n\u001b[0;32m-> 1657\u001b[0m proc_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_process_tensor\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1658\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1659\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_tensor_shape\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_run_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1660\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_shared\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1661\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_device\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_run_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1662\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# check_tensor_shape=_run_checks\u001b[39;00m\n\u001b[1;32m 1663\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict[key] \u001b[38;5;241m=\u001b[39m proc_value\n\u001b[1;32m 1664\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict_meta[key] \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1665\u001b[0m MetaTensor(\n\u001b[1;32m 1666\u001b[0m proc_value,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1671\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m _meta_val\n\u001b[1;32m 1672\u001b[0m )\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:485\u001b[0m, in \u001b[0;36m_TensorDict._process_tensor\u001b[0;34m(self, input, check_device, check_tensor_shape, check_shared)\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcheck_shared is not authorized 
anymore\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 484\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_tensor_shape \u001b[38;5;129;01mand\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mshape[: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_dims] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_size:\n\u001b[0;32m--> 485\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 486\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch dimension mismatch, got self.batch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 487\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_size\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m and tensor.shape[:self.batch_dims]\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 488\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtensor\u001b[38;5;241m.\u001b[39mshape[: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_dims]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 489\u001b[0m )\n\u001b[1;32m 491\u001b[0m \u001b[38;5;66;03m# minimum ndimension is 1\u001b[39;00m\n\u001b[1;32m 492\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mndimension() \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mndimension() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", "\u001b[0;31mRuntimeError\u001b[0m: batch dimension mismatch, got self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])" ] } @@ -638,7 +638,7 @@ { "cell_type": "code", "execution_count": null, - "id": "068e6ff1", + "id": "f8ec89ec", "metadata": {}, "outputs": [], "source": [ @@ -647,7 +647,7 @@ }, { "cell_type": "markdown", - "id": "ff6b1c53", + "id": "93ba3948", "metadata": {}, "source": [ "#### Devices" @@ -655,7 +655,7 @@ }, { "cell_type": "markdown", - "id": "90079564", + "id": "12efa0f5", "metadata": {}, "source": [ "TensorDict can be sent to the desired devices like a pytorch tensor with `td.cuda()` or `td.to(device)` with `device`the desired device" @@ -663,7 +663,7 @@ }, { "cell_type": "markdown", - "id": "f2a19c16", + "id": "6e7c9d12", "metadata": {}, "source": [ "#### Memory sharing" @@ -671,7 +671,7 @@ }, { "cell_type": "markdown", - "id": "5e9333ff", + "id": "fd540505", "metadata": {}, "source": [ "When on cpu, you can use either `TensorDict.memmap_()` or `TensorDict.share_memory_()` to setup you tensor dict as a memmap or send it to shared memory resp." 
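A short sketch of the device and sharing calls just mentioned; the CUDA branch is guarded so it also runs on CPU-only machines:

import torch
from torchrl.data import TensorDict

td = TensorDict({"a": torch.zeros(3, 4, 5)}, batch_size=[3, 4])

if torch.cuda.is_available():
    td_gpu = td.to("cuda")   # all entries move to the device together

td.share_memory_()           # place every tensor in shared memory (cpu)
# td.memmap_()               # alternative: back the tensors with memory-mapped files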
@@ -679,7 +679,7 @@
  },
  {
   "cell_type": "markdown",
-   "id": "6146d69b",
+   "id": "fc71180a",
   "metadata": {},
   "source": [
    "#### Cloning\n",
@@ -689,7 +689,7 @@
  {
   "cell_type": "code",
   "execution_count": 15,
-   "id": "bbb6e3bf",
+   "id": "c88f387f",
   "metadata": {},
   "outputs": [
    {
@@ -736,7 +736,7 @@
  },
  {
   "cell_type": "markdown",
-   "id": "b2c791fd",
+   "id": "9251e116",
   "metadata": {},
   "source": [
    "### Tensor operations\n",
@@ -745,7 +745,7 @@
  },
  {
   "cell_type": "markdown",
-   "id": "e9876782",
+   "id": "5f8159df",
   "metadata": {},
   "source": [
    "#### Slicing and indexing\n",
@@ -755,7 +755,7 @@
  {
   "cell_type": "code",
   "execution_count": 16,
-   "id": "1c0ff66b",
+   "id": "8409cfb4",
   "metadata": {},
   "outputs": [
    {
@@ -782,7 +782,7 @@
  {
   "cell_type": "code",
   "execution_count": 17,
-   "id": "6c5aa4aa",
+   "id": "af1387cf",
   "metadata": {},
   "outputs": [
    {
@@ -809,7 +809,7 @@
  {
   "cell_type": "code",
   "execution_count": 18,
-   "id": "e4ba5aa9",
+   "id": "d7133060",
   "metadata": {},
   "outputs": [
    {
@@ -835,7 +835,7 @@
  },
  {
   "cell_type": "markdown",
-   "id": "f797c29b",
+   "id": "c69d84a1",
   "metadata": {},
   "source": [
    "#### Setting values with indexing\n",
@@ -845,7 +845,7 @@
  {
   "cell_type": "code",
   "execution_count": 19,
-   "id": "8cf86358",
+   "id": "19f235ee",
   "metadata": {},
   "outputs": [
    {
@@ -894,7 +894,7 @@
  },
  {
   "cell_type": "markdown",
-   "id": "7436749a",
+   "id": "3148a34a",
   "metadata": {},
   "source": [
    "#### Masking"
@@ -902,7 +902,7 @@
  },
  {
   "cell_type": "markdown",
-   "id": "6bb0f008",
+   "id": "b8c2075f",
   "metadata": {},
   "source": [
    "We can also index a TensorDict with a boolean mask. The mask must be a tensor."
@@ -911,7 +911,7 @@
  {
   "cell_type": "code",
   "execution_count": 20,
-   "id": "9a18e36e",
+   "id": "7c6bf5e1",
   "metadata": {},
   "outputs": [
    {
@@ -938,7 +938,7 @@
  },
  {
   "cell_type": "markdown",
-   "id": "c09de567",
+   "id": "e9835c39",
   "metadata": {},
   "source": [
    "TensorDict supports other tensor operations such as torch.cat, reshape, unbind(dim), view(\*shape), squeeze(dim), unsqueeze(dim) and permute(\*dims), requiring the operations to comply with the batch_size"
@@ -946,7 +946,7 @@
  },
  {
   "cell_type": "markdown",
-   "id": "b3e73bbb",
+   "id": "d74d0fb9",
   "metadata": {},
   "source": [
    "#### View" @@ -954,7 +954,7 @@ }, { "cell_type": "markdown", - "id": "b99a9340", + "id": "c5ec195a", "metadata": {}, "source": [ "The view operation returns a lazy `ViewedTensorDict`. 
Use `to_tensordict()` to get back a regular TensorDict" @@ -963,7 +963,7 @@ { "cell_type": "code", "execution_count": 21, - "id": "6eb7e6ac", + "id": "d909b9cd", "metadata": {}, "outputs": [ { @@ -991,7 +991,7 @@ }, { "cell_type": "markdown", - "id": "cd31b3b7", + "id": "6cf6665c", "metadata": {}, "source": [ "#### Permute" @@ -999,8 +999,8 @@ }, { "cell_type": "code", - "execution_count": 23, - "id": "8329a908", + "execution_count": 22, + "id": "f1774e04", "metadata": {}, "outputs": [ { @@ -1017,7 +1017,7 @@ "\top=permute(dims=(1, 0)))" ] }, - "execution_count": 23, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -1028,7 +1028,7 @@ }, { "cell_type": "markdown", - "id": "48979edd", + "id": "e57c354e", "metadata": {}, "source": [ "#### Reshape\n", @@ -1037,8 +1037,8 @@ }, { "cell_type": "code", - "execution_count": 155, - "id": "a3871109", + "execution_count": 23, + "id": "c9b3ab59", "metadata": {}, "outputs": [ { @@ -1053,7 +1053,7 @@ " is_shared=False)" ] }, - "execution_count": 155, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -1062,48 +1062,9 @@ "tensordict.reshape(-1)" ] }, - { - "cell_type": "code", - "execution_count": 185, - "id": "9d650b34", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([4]),\n", - " device=cpu,\n", - " is_shared=False),\n", - " TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([4]),\n", - " device=cpu,\n", - " is_shared=False),\n", - " TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([4]),\n", - " device=cpu,\n", - " is_shared=False))" - ] - }, - "execution_count": 185, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] - }, { "cell_type": "markdown", - "id": "9c7a6e55", + "id": "48933b38", "metadata": {}, "source": [ "#### Unbind and Cat\n", @@ -1112,8 +1073,8 @@ }, { "cell_type": "code", - "execution_count": 188, - "id": "2480deab", + "execution_count": 24, + "id": "68e8975f", "metadata": {}, "outputs": [ { @@ -1153,7 +1114,7 @@ " is_shared=False)" ] }, - "execution_count": 188, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -1167,7 +1128,7 @@ }, { "cell_type": "markdown", - "id": "78381ee8", + "id": "3a7a3f68", "metadata": {}, "source": [ "#### Squeeze and Unsqueeze\n", @@ -1176,8 +1137,8 @@ }, { "cell_type": "code", - "execution_count": 164, - "id": "a38f2b53", + "execution_count": 25, + "id": "0c3a54a0", "metadata": {}, "outputs": [ { @@ -1208,7 +1169,7 @@ }, { "cell_type": "markdown", - "id": "242e845f", + "id": "eadb1bc0", "metadata": {}, "source": [ "#### Stacking" @@ -1216,7 +1177,7 @@ }, { "cell_type": "markdown", - "id": "5593343a", + "id": "f1dff375", "metadata": {}, "source": [ "TensorDict supports stacking. Stacking is done lazily, returning a LazyStackedTensorDict."
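An editor's sketch of the lazy shape operations this section walks through, using the same `tensordict` with batch_size [3, 4] as above (calling `torch.stack` directly on tensordicts is an assumption based on the LazyStackedTensorDict mentioned here; the stack cell itself follows in the patch):

    viewed = tensordict.view(-1)          # lazy ViewedTensorDict over batch_size [12]
    back = viewed.to_tensordict()         # materialize a regular TensorDict again

    permuted = tensordict.permute(1, 0)   # lazy permutation, batch_size becomes [4, 3]
    flat = tensordict.reshape(-1)         # batch_size [12], every entry reshaped accordingly

    stacked = torch.stack([tensordict, tensordict.clone()], 0)  # LazyStackedTensorDict, batch_size [2, 3, 4]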
@@ -1224,8 +1185,8 @@ }, { "cell_type": "code", - "execution_count": 157, - "id": "01075269", + "execution_count": 26, + "id": "8f90e26a", "metadata": {}, "outputs": [ { @@ -1253,7 +1214,7 @@ }, { "cell_type": "markdown", - "id": "ba554281", + "id": "22161bd3", "metadata": {}, "source": [ "If we want a contiguous tensordict, we can call `.to_tensordict()` or `.contiguous()`. For efficiency reasons, it is recommended to perform this operation before accessing the values of the stacked tensordict." @@ -1261,8 +1222,8 @@ }, { "cell_type": "code", - "execution_count": 92, - "id": "0f546995", + "execution_count": 27, + "id": "b8c66c4a", "metadata": {}, "outputs": [ { @@ -1293,7 +1254,7 @@ }, { "cell_type": "markdown", - "id": "17ea6285", + "id": "e86afd08", "metadata": {}, "source": [ "## How to use them in practice? The TensorDictModule" @@ -1301,7 +1262,7 @@ }, { "cell_type": "markdown", - "id": "1bb76da1", + "id": "78192367", "metadata": {}, "source": [ "Now that we have seen the TensorDict object, how do we use it in practice? We introduce the TensorDictModule, an nn.Module that takes a TensorDict in its forward method. The user defines the keys that the module reads as inputs and the keys at which it writes its outputs in the same TensorDict." @@ -1309,8 +1270,8 @@ }, { "cell_type": "code", - "execution_count": 21, - "id": "ac5424b8", + "execution_count": 28, + "id": "7a9c34d7", "metadata": {}, "outputs": [], "source": [ @@ -1320,7 +1281,7 @@ }, { "cell_type": "markdown", - "id": "2cff8307", + "id": "2198427a", "metadata": {}, "source": [ "### Example: Simple Linear layer" @@ -1328,7 +1289,7 @@ }, { "cell_type": "markdown", - "id": "5f8bfc04", + "id": "a9286f32", "metadata": {}, "source": [ "Let's imagine we have a TensorDict with 2 entries, a and b, and we only want to modify a."
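The cell implementing this example is largely unchanged by the patch and so is not visible in the diff; a hedged sketch of what it does (the output key name "a_out" and the exact tensor shapes are assumptions):

    module = TensorDictModule(nn.Linear(5, 10), in_keys=["a"], out_keys=["a_out"])
    td = TensorDict({"a": torch.randn(3, 5), "b": torch.randn(3, 2)}, batch_size=[3])
    module(td)                         # reads td["a"], writes td["a_out"], leaves "b" untouched
    assert td["a_out"].shape == torch.Size([3, 10])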
@@ -1336,8 +1297,8 @@ }, { "cell_type": "code", - "execution_count": 22, - "id": "c45b16aa", + "execution_count": 29, + "id": "f42fd847", "metadata": {}, "outputs": [ { @@ -1353,7 +1314,7 @@ " is_shared=False)" ] }, - "execution_count": 22, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1366,7 +1327,7 @@ }, { "cell_type": "markdown", - "id": "d9cad800", + "id": "f1452e41", "metadata": {}, "source": [ "We can also do it in place" @@ -1374,8 +1335,8 @@ }, { "cell_type": "code", - "execution_count": 23, - "id": "d18be016", + "execution_count": 30, + "id": "c2e6db70", "metadata": {}, "outputs": [ { @@ -1390,7 +1351,7 @@ " is_shared=False)" ] }, - "execution_count": 23, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -1403,7 +1364,7 @@ }, { "cell_type": "markdown", - "id": "3269f1d3", + "id": "7fbae8ef", "metadata": {}, "source": [ "### Example: Merging 2 inputs with 2 linear layers" @@ -1411,7 +1372,7 @@ }, { "cell_type": "markdown", - "id": "e1c1f6d0", + "id": "a4c36399", "metadata": {}, "source": [ "Now let's imagine a more complex network that takes 2 entries and averages them into a single output" @@ -1419,8 +1380,8 @@ }, { "cell_type": "code", - "execution_count": 24, - "id": "bfcdec34", + "execution_count": 31, + "id": "f4692e76", "metadata": {}, "outputs": [], "source": [ @@ -1435,8 +1396,8 @@ }, { "cell_type": "code", - "execution_count": 25, - "id": "9183ab24", + "execution_count": 32, + "id": "55f66f45", "metadata": {}, "outputs": [ { @@ -1453,7 +1414,7 @@ " is_shared=False)" ] }, - "execution_count": 25, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -1466,7 +1427,7 @@ }, { "cell_type": "markdown", - "id": "371b9214", + "id": "f13b043e", "metadata": {}, "source": [ "### Example: 1 input to 2 outputs linear layer\n", @@ -1475,8 +1436,8 @@ }, { "cell_type": "code", - "execution_count": 26, - "id": "9d9fccc3", + "execution_count": 33, + "id": "c3dc55f1", "metadata": {}, "outputs": [], "source": [ @@ -1491,8 +1452,8 @@ }, { "cell_type": "code", - "execution_count": 27, - "id": "cd5a891f", + "execution_count": 34, + "id": "52fc40c4", "metadata": {}, "outputs": [ { @@ -1509,7 +1470,7 @@ " is_shared=False)" ] }, - "execution_count": 27, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -1522,7 +1483,7 @@ }, { "cell_type": "markdown", - "id": "e70a2ec1", + "id": "a40b756d", "metadata": {}, "source": [ "As shown previously, the TensorDictModule can wrap any nn.Module and perform its operations inside a TensorDict. When there are multiple input and output keys, make sure they match the order expected by the module.\n", @@ -1531,7 +1492,7 @@ }, { "cell_type": "markdown", - "id": "8e7d459c", + "id": "897a7533", "metadata": {}, "source": [ "### Example: A transformer with TensorDict?\n", @@ -1545,8 +1506,8 @@ }, { "cell_type": "code", - "execution_count": 28, - "id": "80f2ddab", + "execution_count": 35, + "id": "05e6a450", "metadata": {}, "outputs": [], "source": [ @@ -1606,7 +1567,7 @@ }, { "cell_type": "markdown", - "id": "98188b17", + "id": "f5086060", "metadata": {}, "source": [ "Now, we can build the TransformerBlock thanks to the TensorDictModule. Since the changes affect the tensordict, we just need to map each output to the right name so that it is picked up by the next block."
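To make the key-mapping idea concrete, here is a hedged sketch of how such a block can be wired (constructor arguments of TokensToQKV, SplitHeads, Attention and SkipLayerNorm are assumptions, since their definitions live in an unchanged notebook cell; the FFN stage and second skip connection are omitted for brevity; TensorDictSequence is the class used later in this patch):

    block = TensorDictSequence(
        TensorDictModule(TokensToQKV(to_dim, from_dim, latent_dim),
                         in_keys=["X_to", "X_from"], out_keys=["Q", "K", "V"]),
        TensorDictModule(SplitHeads(num_heads),
                         in_keys=["Q", "K", "V"], out_keys=["Q", "K", "V"]),
        TensorDictModule(Attention(latent_dim, to_dim),
                         in_keys=["Q", "K", "V"], out_keys=["X_out", "Attn"]),
        TensorDictModule(SkipLayerNorm(to_len, to_dim),
                         in_keys=["X_to", "X_out"], out_keys=["X_to"]),
    )
    block(tokens)   # every stage reads from and writes to the same TensorDict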
@@ -1614,8 +1575,8 @@ }, { "cell_type": "code", - "execution_count": 29, - "id": "3a88bcc6", + "execution_count": 36, + "id": "d3f130d9", "metadata": {}, "outputs": [], "source": [ @@ -1637,8 +1598,8 @@ }, { "cell_type": "code", - "execution_count": 30, - "id": "15ea7d75", + "execution_count": 37, + "id": "c2d2a519", "metadata": {}, "outputs": [ { @@ -1658,7 +1619,7 @@ " is_shared=False)" ] }, - "execution_count": 30, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -1683,7 +1644,7 @@ }, { "cell_type": "markdown", - "id": "e960bbd6", + "id": "e513287f", "metadata": {}, "source": [ "The output of the transformer layer can now be found at tokens[\"X_to\"]" @@ -1691,48 +1652,48 @@ }, { "cell_type": "code", - "execution_count": 31, - "id": "858f205a", + "execution_count": 38, + "id": "d58f4d89", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-0.8713, -1.2626, 1.3218, -0.2947, 1.6938],\n", - " [-0.7374, -0.6038, -1.4958, -0.3975, 0.0741],\n", - " [-0.5026, 0.0095, 1.8102, 0.2308, 1.0256]],\n", + "tensor([[[ 1.3338, 0.5568, 1.5257, -0.9140, -0.7399],\n", + " [-1.1602, -1.1171, -0.0355, -0.3293, 0.5658],\n", + " [ 0.0528, 1.3808, -1.1389, -1.2089, 1.2281]],\n", "\n", - " [[ 0.4484, 0.0689, 1.0152, -1.1690, 1.2264],\n", - " [-0.5891, 0.0737, 0.7038, 1.0404, 0.0131],\n", - " [ 1.0050, -2.4708, -0.7890, -1.0160, 0.4391]],\n", + " [[ 1.0998, -0.2481, -0.5027, 0.3425, -1.3240],\n", + " [-1.2873, -2.0554, 1.1529, -0.5357, -0.2973],\n", + " [ 1.0694, 0.3098, 0.3401, 0.3393, 1.5966]],\n", "\n", - " [[-1.8364, -0.5181, -0.5258, 0.5166, 1.8120],\n", - " [ 1.3389, 0.1451, -0.1267, -0.7637, 1.6104],\n", - " [-0.1859, -0.4134, -1.4359, -0.1131, 0.4961]],\n", + " [[ 0.2762, -0.3396, -1.6140, 0.0866, 1.0148],\n", + " [-1.0849, -2.2679, 0.1260, 1.2162, 1.4198],\n", + " [ 1.0519, 0.2551, -0.3484, 0.0038, 0.2044]],\n", "\n", - " [[-1.0511, 0.1636, -0.9440, -0.2152, -0.4874],\n", - " [ 1.4676, 2.0405, 0.2846, 0.5990, 1.0199],\n", - " [-1.5073, 0.0980, -1.5943, -0.1160, 0.2421]],\n", + " [[-1.6978, -1.0813, -0.9879, -0.1534, -0.8177],\n", + " [ 1.0584, -1.3029, -0.5820, 0.4057, 1.6585],\n", + " [ 0.1717, 0.3136, 0.8307, 1.1407, 1.0436]],\n", "\n", - " [[-0.6059, 0.3442, 0.6854, -0.0933, 1.8850],\n", - " [-0.2040, -2.0479, 0.8991, 1.1162, 0.0855],\n", - " [-1.6792, -0.4797, 0.4558, 0.4763, -0.8374]],\n", + " [[ 0.9372, -1.2179, -0.5154, -1.0837, 0.3776],\n", + " [-2.3519, -0.2721, -0.1398, 1.7179, -0.5435],\n", + " [ 0.3842, 0.8391, 0.4074, 0.9533, 0.5076]],\n", "\n", - " [[-1.2463, -1.3887, 1.2930, 0.5651, 0.9994],\n", - " [-0.1023, -0.4523, -2.0760, 1.6500, 0.5962],\n", - " [ 0.7221, 0.1171, 0.1437, -0.2839, -0.5374]],\n", + " [[-0.0587, -0.1878, 0.5516, 0.0882, 0.9291],\n", + " [ 1.1866, -1.2275, -0.3984, -0.0310, -0.4985],\n", + " [-1.1856, -1.0034, 2.6606, -0.9745, 0.1493]],\n", "\n", - " [[-0.3637, -1.2232, -2.0972, -0.7830, -0.1663],\n", - " [-0.5369, 0.0789, -0.9869, 0.6352, 1.5502],\n", - " [ 0.3313, 0.2594, 1.5771, 0.6336, 1.0914]],\n", + " [[-0.5460, 0.6498, -0.2246, -2.5625, -0.1683],\n", + " [ 0.2090, -0.9892, 0.5802, 0.5241, 0.5589],\n", + " [ 2.2136, 0.5872, 0.0154, -0.2186, -0.6290]],\n", "\n", - " [[-1.0102, 0.4594, 0.5603, 0.1587, 0.2164],\n", - " [-1.3799, -0.1682, -1.9153, 0.9154, 1.8860],\n", - " [-0.0694, -0.8951, -0.6851, 0.6067, 1.3203]]],\n", + " [[-1.3134, 0.0524, 0.3805, 0.3787, 0.6957],\n", + " [ 1.8009, -0.5925, -1.7172, -0.9240, 1.6469],\n", + " [-0.1907, -0.8003, 0.1441, -0.6741, 1.1129]]],\n", " grad_fn=)" ] }, - "execution_count": 
31, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } @@ -1743,7 +1704,7 @@ }, { "cell_type": "markdown", - "id": "6d7a09de", + "id": "bc2a210d", "metadata": {}, "source": [ "We can now create a transformer easily" @@ -1751,8 +1712,8 @@ }, { "cell_type": "code", - "execution_count": 32, - "id": "a48be821", + "execution_count": 39, + "id": "d6241e1f", "metadata": {}, "outputs": [], "source": [ @@ -1767,8 +1728,8 @@ }, { "cell_type": "code", - "execution_count": 33, - "id": "9e749bfc", + "execution_count": 40, + "id": "544f6335", "metadata": {}, "outputs": [], "source": [ @@ -1787,7 +1748,7 @@ }, { "cell_type": "markdown", - "id": "723d4062", + "id": "a704bddb", "metadata": {}, "source": [ "For an encoder, we just need to take the same tokens for both queries, keys and values." @@ -1795,48 +1756,48 @@ }, { "cell_type": "code", - "execution_count": 34, - "id": "1c02bf4e", + "execution_count": 41, + "id": "52dbf4fb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-7.0456e-01, 6.2688e-01, -1.1026e+00, 2.2786e-02, -4.8655e-02],\n", - " [-9.7513e-01, 1.4375e+00, 9.4504e-01, 2.0078e+00, -8.5485e-01],\n", - " [-1.3192e+00, 8.7940e-01, -1.2088e+00, 4.6428e-01, -1.7003e-01]],\n", + "tensor([[[-2.5683e-01, 4.5024e-01, 9.4612e-01, -1.1000e+00, 6.6867e-01],\n", + " [-6.6982e-01, 8.6136e-01, 1.5597e+00, -1.0453e+00, 1.2007e+00],\n", + " [-6.9352e-01, -3.2060e-02, -4.3720e-01, -2.1913e+00, 7.3923e-01]],\n", "\n", - " [[-2.4683e+00, 1.1900e+00, -4.6999e-01, 8.1202e-01, -1.0019e+00],\n", - " [ 9.6744e-01, 1.0676e+00, 1.0292e+00, 4.7986e-01, -1.2177e+00],\n", - " [ 3.2124e-02, 4.6237e-01, -7.1896e-01, 2.0720e-03, -1.6574e-01]],\n", + " [[ 9.1422e-01, -2.7420e-01, -2.1522e-01, -1.1503e+00, 1.6155e+00],\n", + " [ 3.1163e-01, -7.0292e-02, 1.0830e+00, -2.3652e+00, 1.4391e-01],\n", + " [ 3.9072e-01, -3.0454e-01, 8.5802e-01, -1.4170e+00, 4.7974e-01]],\n", "\n", - " [[-2.2703e-01, 1.5631e+00, 1.1274e+00, 8.1163e-02, -1.5204e+00],\n", - " [-1.8939e+00, 7.4907e-01, -1.6144e+00, 7.3381e-01, 7.8596e-01],\n", - " [-7.7463e-02, 2.4682e-01, 4.6115e-01, 3.5791e-01, -7.7312e-01]],\n", + " [[-7.3433e-01, 6.3703e-01, 4.6386e-01, -1.3121e+00, 6.5461e-01],\n", + " [-3.1634e-01, 2.7175e-01, 4.1725e-01, -6.2716e-01, 6.5264e-01],\n", + " [-1.8677e+00, 1.3871e+00, 9.4117e-01, -1.7092e+00, 1.1414e+00]],\n", "\n", - " [[ 4.2789e-01, 7.7004e-02, -5.2232e-01, -1.3905e+00, -6.8685e-01],\n", - " [-9.8940e-01, 6.9261e-02, 1.8176e+00, 1.2323e+00, -2.8591e-01],\n", - " [-1.9008e+00, 1.2424e+00, 9.0587e-01, 3.6788e-01, -3.6444e-01]],\n", + " [[-1.9929e-01, 7.1908e-01, 1.1399e+00, -2.0360e+00, 9.4091e-01],\n", + " [ 1.8042e-01, 8.9284e-03, 5.7997e-01, -6.0130e-01, 1.1606e+00],\n", + " [-3.7275e-01, 9.1703e-01, -1.2281e-01, -2.2440e+00, -7.0603e-02]],\n", "\n", - " [[ 1.4538e+00, 4.6922e-01, -8.1502e-01, -3.0426e-01, 3.3914e-01],\n", - " [ 4.5448e-02, 5.8241e-01, -5.1411e-02, 2.1194e-01, -3.5672e-01],\n", - " [-2.6675e+00, 1.0907e+00, 9.3493e-01, 4.4590e-01, -1.3785e+00]],\n", + " [[-7.2842e-01, -5.2198e-01, -8.3247e-01, -1.5439e-01, -3.7558e-01],\n", + " [-8.5731e-01, 1.0901e+00, 1.5317e+00, -2.1651e+00, 1.4644e+00],\n", + " [-4.1692e-01, 4.9001e-01, 9.7057e-01, -3.6126e-01, 8.6664e-01]],\n", "\n", - " [[-8.6485e-01, 1.1146e+00, -9.4181e-02, -1.1407e-01, 6.2847e-01],\n", - " [-9.5162e-01, 1.4542e+00, 1.2157e-02, -2.0240e-01, -1.5317e+00],\n", - " [-9.3001e-01, 9.8748e-01, 9.6898e-01, 1.2264e+00, -1.7035e+00]],\n", + " [[ 9.5534e-01, 1.4917e-01, -7.1200e-03, -2.0727e+00, 1.1170e+00],\n", + " 
[-2.9475e-02, 8.1727e-01, 1.3990e+00, -2.5198e-01, 7.2835e-01],\n", + " [-1.2625e+00, -9.9376e-01, -8.0062e-01, -8.8569e-01, 1.1377e+00]],\n", "\n", - " [[ 1.8318e-01, 8.8817e-01, -4.1605e-01, -7.2077e-02, 1.0248e+00],\n", - " [ 1.1950e+00, -3.1990e-01, -3.0087e+00, -8.1052e-01, 4.2547e-01],\n", - " [ 4.8264e-02, 8.2594e-01, -4.8231e-01, -2.2883e-01, 7.4760e-01]],\n", + " [[-7.5436e-01, 4.3068e-01, 7.9711e-01, -1.2966e+00, 9.4933e-03],\n", + " [ 4.6172e-01, 9.7587e-01, 1.4781e+00, -8.3284e-01, 8.6749e-01],\n", + " [-1.6161e+00, 3.7033e-01, -4.2724e-02, -1.8422e+00, 9.9399e-01]],\n", "\n", - " [[-2.0185e+00, 2.7137e-01, 1.9260e-01, -3.7461e-01, -1.4131e+00],\n", - " [ 3.0497e-01, 4.7218e-01, 1.4051e+00, -7.2168e-01, -7.7026e-01],\n", - " [ 1.3967e+00, -6.0792e-01, 1.1108e-01, 1.7145e+00, 3.7664e-02]]],\n", + " [[-2.0276e-01, -1.3223e-01, -2.9282e-01, -7.4340e-02, 1.6503e+00],\n", + " [ 1.0073e+00, -1.5842e-01, 1.0905e+00, -1.8213e+00, 6.9849e-01],\n", + " [-1.8112e-02, -9.7116e-01, 1.1706e+00, -1.9438e+00, -2.3435e-03]]],\n", " grad_fn=)" ] }, - "execution_count": 34, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } @@ -1850,7 +1811,7 @@ }, { "cell_type": "markdown", - "id": "3b3a19b6", + "id": "f379ac76", "metadata": {}, "source": [ "For a decoder, we now can extract info from X_from into X_to. X_to will map to queries whereas X_from will map to keys and values." @@ -1858,48 +1819,48 @@ }, { "cell_type": "code", - "execution_count": 35, - "id": "ec28da52", + "execution_count": 42, + "id": "913b3979", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-0.3380, 0.7571, -1.0049, 0.7433, -0.1268],\n", - " [-1.2908, 1.1133, -0.1441, 1.9060, -1.1740],\n", - " [-1.2169, 1.1862, -0.9880, 0.9077, -0.3301]],\n", + "tensor([[[ 0.2650, 0.8114, 1.6235, -1.8246, -0.1827],\n", + " [-0.5331, 0.8769, 1.4354, -0.7027, 0.0041],\n", + " [-0.0240, 0.0151, 0.7448, -1.9483, -0.5608]],\n", "\n", - " [[-0.9989, 1.0225, 0.0516, 1.4268, -1.3924],\n", - " [-0.4546, 1.6237, -0.9002, 1.0320, -1.1606],\n", - " [-0.5801, 0.9752, -1.1864, 0.4610, 0.0803]],\n", + " [[ 0.6278, 0.2316, 0.6965, -1.4007, 0.4978],\n", + " [ 0.5148, 0.1977, 1.5761, -2.1702, -0.3718],\n", + " [ 0.5269, 0.1489, 1.0460, -1.5976, -0.5239]],\n", "\n", - " [[-0.9055, 1.3019, -0.1604, 1.2191, -1.8533],\n", - " [-1.2209, 1.0449, -0.6953, 1.0392, 0.4056],\n", - " [-0.3828, 0.6094, 0.0309, 0.9272, -1.3600]],\n", + " [[-0.2073, 1.7283, 0.7957, -1.2614, -0.3906],\n", + " [-0.2481, 1.4483, 0.3833, -0.3303, -0.3051],\n", + " [-0.7997, 1.4176, 0.4302, -1.9264, -0.7345]],\n", "\n", - " [[-0.3067, 0.0189, -0.9604, -0.3966, -0.3716],\n", - " [-1.0679, 0.4198, 1.2470, 1.9662, -0.8475],\n", - " [-1.6165, 1.1001, -0.3376, 1.4791, -0.3264]],\n", + " [[ 0.4208, 0.4412, 1.3680, -2.0761, -0.3130],\n", + " [ 0.6140, 0.1374, 1.0931, 0.1626, 0.0274],\n", + " [-0.4122, 0.9318, 0.7548, -1.8592, -1.2908]],\n", "\n", - " [[-0.0980, 0.6583, -1.2214, 0.2261, -0.4485],\n", - " [-0.6073, 1.5547, -0.3092, 1.1512, -1.0601],\n", - " [-1.4279, 1.2677, 0.1266, 1.4727, -1.2850]],\n", + " [[-0.7726, 0.0486, 0.0335, -0.0701, -0.9185],\n", + " [-0.2630, 1.0334, 1.7184, -2.6263, 0.0084],\n", + " [-0.5563, 1.1644, 0.8858, -0.0215, 0.3358]],\n", "\n", - " [[-0.4164, 0.7499, 0.2536, 0.1846, 0.1112],\n", - " [-0.4129, 1.0208, -0.3762, 0.7487, -1.3215],\n", - " [-0.1598, 0.5029, -0.3288, 1.9525, -2.5087]],\n", + " [[ 1.0746, 0.6018, 0.5270, -0.9513, 0.1145],\n", + " [-0.6757, 1.6353, 1.4940, -0.1454, 0.1573],\n", + " [ 0.0575, -1.1170, -1.3292, 
-1.9089, 0.4654]],\n", "\n", - " [[-0.3682, 0.8304, -0.8471, 0.5485, 0.4438],\n", - " [-0.8762, 1.1699, -2.6552, 0.4682, 0.1896],\n", - " [-0.3391, 1.0842, -1.0764, 0.5564, 0.8712]],\n", + " [[-0.3958, 0.6817, 0.7094, -1.1415, -0.1320],\n", + " [ 0.2856, 1.6527, 1.2931, -0.6426, 0.4360],\n", + " [-0.9961, 0.7746, -0.0636, -2.4226, -0.0389]],\n", "\n", - " [[-1.0284, 0.3804, 0.4793, 0.4096, -2.1200],\n", - " [-0.2695, 1.0064, -0.2066, 0.1593, -0.7439],\n", - " [ 0.9516, 0.8927, -0.9734, 1.9427, -0.8802]]],\n", + " [[ 0.0635, -0.2169, 0.2886, -0.9645, 0.0584],\n", + " [ 1.1510, 0.3362, 1.9643, -0.8624, -0.6656],\n", + " [ 0.8300, -0.7676, 1.4487, -1.8439, -0.8198]]],\n", " grad_fn=)" ] }, - "execution_count": 35, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -1913,7 +1874,7 @@ }, { "cell_type": "markdown", - "id": "b17e1d61", + "id": "4c3698ba", "metadata": {}, "source": [ "Now we can look at both models:" @@ -1921,8 +1882,8 @@ }, { "cell_type": "code", - "execution_count": 36, - "id": "2c270fd5", + "execution_count": 43, + "id": "f9c19fc1", "metadata": {}, "outputs": [ { @@ -2246,7 +2207,7 @@ ")" ] }, - "execution_count": 36, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } @@ -2257,8 +2218,8 @@ }, { "cell_type": "code", - "execution_count": 37, - "id": "ca1e26ab", + "execution_count": 44, + "id": "4fb30839", "metadata": {}, "outputs": [ { @@ -2582,7 +2543,7 @@ ")" ] }, - "execution_count": 37, + "execution_count": 44, "metadata": {}, "output_type": "execute_result" } @@ -2590,14 +2551,6 @@ "source": [ "transformer_decoder" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33c8bd5a", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb index 07d59f43be2..3e5974b0bc1 100644 --- a/tutorials/tensordictmodule.ipynb +++ b/tutorials/tensordictmodule.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "bf31fa1a", + "id": "86dc2115", "metadata": {}, "source": [ "# The TensorDictModule" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "7270f31e", + "id": "cdeec13d", "metadata": {}, "source": [ "Make sure to read the tensordict tutorial first" @@ -18,7 +18,7 @@ }, { "cell_type": "markdown", - "id": "f4a1987c", + "id": "97ca082a", "metadata": {}, "source": [ "How do we use the TensorDict in practice? We introduce the TensorDictModule, an nn.Module that takes a TensorDict in its forward method. The user defines the keys that the module reads as inputs and the keys at which it writes its outputs in the same TensorDict." @@ -27,7 +27,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "e26ce6cb", + "id": "97c23098", "metadata": {}, "outputs": [], "source": [ @@ -39,7 +39,7 @@ }, { "cell_type": "markdown", - "id": "33bf9e9e", + "id": "4f1d8733", "metadata": {}, "source": [ "### Example: Simple Linear layer" @@ -47,7 +47,7 @@ }, { "cell_type": "markdown", - "id": "9ea1a71b", + "id": "311e8412", "metadata": {}, "source": [ "Let's imagine we have a TensorDict with 2 entries, a and b, and we only want to modify a."
@@ -56,7 +56,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "c7ddc935", + "id": "d27e2c36", "metadata": {}, "outputs": [ { @@ -85,7 +85,7 @@ }, { "cell_type": "markdown", - "id": "f77864a9", + "id": "6c8c4078", "metadata": {}, "source": [ "We can also do it in place" @@ -94,7 +94,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "cf7cbec9", + "id": "b9b645c5", "metadata": {}, "outputs": [ { @@ -122,7 +122,7 @@ }, { "cell_type": "markdown", - "id": "da06f3be", + "id": "60788b57", "metadata": {}, "source": [ "### Example: Merging 2 inputs with 2 linear layers" @@ -130,7 +130,7 @@ }, { "cell_type": "markdown", - "id": "44bb7628", + "id": "dd382cb3", "metadata": {}, "source": [ "Now let's imagine a more complex network that takes 2 entries and averages them into a single output" @@ -139,7 +139,7 @@ { "cell_type": "code", "execution_count": 4, - "id": "f7843416", + "id": "d9934513", "metadata": {}, "outputs": [], "source": [ @@ -155,7 +155,7 @@ { "cell_type": "code", "execution_count": 5, - "id": "250ea5ea", + "id": "5f3173f0", "metadata": {}, "outputs": [ { @@ -185,7 +185,7 @@ }, { "cell_type": "markdown", - "id": "c59bba03", + "id": "0b1ca7da", "metadata": {}, "source": [ "### Example: 1 input to 2 outputs linear layer\n", @@ -195,7 +195,7 @@ { "cell_type": "code", "execution_count": 6, - "id": "beb948ef", + "id": "2b35f920", "metadata": {}, "outputs": [], "source": [ @@ -211,7 +211,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "544962c2", + "id": "28409cae", "metadata": {}, "outputs": [ { @@ -241,7 +241,7 @@ }, { "cell_type": "markdown", - "id": "48405202", + "id": "ad4a37df", "metadata": {}, "source": [ "As shown previously, the TensorDictModule can wrap any nn.Module and perform its operations inside a TensorDict. When there are multiple input and output keys, make sure they match the order expected by the module.\n", @@ -250,7 +250,7 @@ }, { "cell_type": "markdown", - "id": "ed69ff38", + "id": "ad9b159e", "metadata": {}, "source": [ "### Example: A transformer with TensorDict?\n", @@ -265,7 +265,7 @@ { "cell_type": "code", "execution_count": 8, - "id": "6b6278c1", + "id": "e6cd37b9", "metadata": {}, "outputs": [], "source": [ @@ -325,7 +325,7 @@ }, { "cell_type": "markdown", - "id": "cca50494", + "id": "47df2142", "metadata": {}, "source": [ "Now, we can build the TransformerBlock thanks to the TensorDictModule. Since the changes affect the tensordict, we just need to map each output to the right name so that it is picked up by the next block."
@@ -333,8 +333,8 @@ }, { "cell_type": "code", - "execution_count": 31, - "id": "da8fd12f", + "execution_count": 9, + "id": "a8568f1c", "metadata": {}, "outputs": [], "source": [ @@ -352,8 +352,8 @@ }, { "cell_type": "code", - "execution_count": 32, - "id": "ff60c29e", + "execution_count": 10, + "id": "a0ccbbd5", "metadata": {}, "outputs": [ { @@ -373,7 +373,7 @@ " is_shared=False)" ] }, - "execution_count": 32, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -412,7 +412,7 @@ }, { "cell_type": "markdown", - "id": "62fc34e6", + "id": "12118015", "metadata": {}, "source": [ "The output of the transformer layer can now be found at tokens[\"X_to\"]" @@ -420,48 +420,48 @@ }, { "cell_type": "code", - "execution_count": 31, - "id": "be55bd8a", + "execution_count": 11, + "id": "5c28e899", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-0.8713, -1.2626, 1.3218, -0.2947, 1.6938],\n", - " [-0.7374, -0.6038, -1.4958, -0.3975, 0.0741],\n", - " [-0.5026, 0.0095, 1.8102, 0.2308, 1.0256]],\n", + "tensor([[[ 0.1732, 1.0383, -1.5069, -0.8307, -0.5521],\n", + " [-0.8138, 0.8951, -0.6802, -1.3608, 0.3892],\n", + " [-0.3947, 0.7221, -0.4580, 1.4660, 1.9133]],\n", "\n", - " [[ 0.4484, 0.0689, 1.0152, -1.1690, 1.2264],\n", - " [-0.5891, 0.0737, 0.7038, 1.0404, 0.0131],\n", - " [ 1.0050, -2.4708, -0.7890, -1.0160, 0.4391]],\n", + " [[-0.6783, 0.5550, 0.4649, 1.5167, 1.6206],\n", + " [ 0.6388, -0.3218, 0.7403, 0.2852, 0.1950],\n", + " [-0.6361, -1.8308, 0.0806, -1.9018, -0.7281]],\n", "\n", - " [[-1.8364, -0.5181, -0.5258, 0.5166, 1.8120],\n", - " [ 1.3389, 0.1451, -0.1267, -0.7637, 1.6104],\n", - " [-0.1859, -0.4134, -1.4359, -0.1131, 0.4961]],\n", + " [[ 0.9160, -0.4022, -0.8094, 0.6584, -1.5115],\n", + " [ 0.9658, -1.1464, -0.9496, -1.4521, 0.9001],\n", + " [ 1.2254, 1.4921, -0.1453, 0.7521, -0.4933]],\n", "\n", - " [[-1.0511, 0.1636, -0.9440, -0.2152, -0.4874],\n", - " [ 1.4676, 2.0405, 0.2846, 0.5990, 1.0199],\n", - " [-1.5073, 0.0980, -1.5943, -0.1160, 0.2421]],\n", + " [[-0.3162, 1.2837, 0.3168, 1.7858, 0.0061],\n", + " [-0.1342, -0.4649, -1.2371, -0.1789, -0.4095],\n", + " [ 1.8135, -0.2081, -0.6934, -1.9989, 0.4353]],\n", "\n", - " [[-0.6059, 0.3442, 0.6854, -0.0933, 1.8850],\n", - " [-0.2040, -2.0479, 0.8991, 1.1162, 0.0855],\n", - " [-1.6792, -0.4797, 0.4558, 0.4763, -0.8374]],\n", + " [[-0.4506, 0.7247, -1.1347, -0.1918, 1.0423],\n", + " [-0.4250, 0.6152, -0.7718, -1.5260, 0.6411],\n", + " [ 2.3155, 0.8685, -0.9266, 0.1958, -0.9766]],\n", "\n", - " [[-1.2463, -1.3887, 1.2930, 0.5651, 0.9994],\n", - " [-0.1023, -0.4523, -2.0760, 1.6500, 0.5962],\n", - " [ 0.7221, 0.1171, 0.1437, -0.2839, -0.5374]],\n", + " [[ 1.4929, -0.4170, -0.6244, 0.0319, -0.4917],\n", + " [-0.9035, 0.1552, 0.6767, -0.2369, 1.3653],\n", + " [ 0.0061, -1.9266, -1.4868, 1.3227, 1.0363]],\n", "\n", - " [[-0.3637, -1.2232, -2.0972, -0.7830, -0.1663],\n", - " [-0.5369, 0.0789, -0.9869, 0.6352, 1.5502],\n", - " [ 0.3313, 0.2594, 1.5771, 0.6336, 1.0914]],\n", + " [[ 0.5236, -0.9703, 0.8447, -0.1412, 1.3548],\n", + " [ 0.3037, 1.2733, -1.3643, 0.1551, 1.8588],\n", + " [-1.2808, -0.5453, -0.0203, -1.3418, -0.6499]],\n", "\n", - " [[-1.0102, 0.4594, 0.5603, 0.1587, 0.2164],\n", - " [-1.3799, -0.1682, -1.9153, 0.9154, 1.8860],\n", - " [-0.0694, -0.8951, -0.6851, 0.6067, 1.3203]]],\n", + " [[-0.8650, -0.4386, -0.3220, 1.9545, -0.8611],\n", + " [-1.2402, 0.2068, 1.1227, -0.1070, 2.0424],\n", + " [-0.4647, 0.0610, 0.5333, -0.2866, -1.3355]]],\n", " grad_fn=)" ] }, - "execution_count": 31, + 
"execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -472,7 +472,7 @@ }, { "cell_type": "markdown", - "id": "7084629a", + "id": "a91fd810", "metadata": {}, "source": [ "We can now create a transformer easily" @@ -480,24 +480,30 @@ }, { "cell_type": "code", - "execution_count": 32, - "id": "85026014", + "execution_count": 12, + "id": "3dbdb5c0", "metadata": {}, "outputs": [], "source": [ - "class TransformerTensorDict(nn.Module):\n", - " def __init__(self, num_blocks, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):\n", - " super().__init__()\n", - " self.transformer = nn.ModuleList([TransformerBlockTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])\n", - " def forward(self, X_tensor_dict):\n", - " for transformer_block in self.transformer:\n", - " transformer_block(X_tensor_dict)" + "class TransformerTensorDict(TensorDictSequence):\n", + " def __init__(\n", + " self,\n", + " num_blocks,\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads\n", + " ):\n", + " super().__init__(*[TransformerBlockTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])" ] }, { "cell_type": "code", - "execution_count": 33, - "id": "b1355ec6", + "execution_count": 13, + "id": "317e3161", "metadata": {}, "outputs": [], "source": [ @@ -509,14 +515,18 @@ "batch_size = 8\n", "num_heads = 2\n", "\n", - "tokens = TensorDict({\"X_to\":torch.randn(batch_size, to_len, to_dim), \"X_from\":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])\n", - "\n", - "\n" + "tokens = TensorDict(\n", + " {\n", + " \"X_to\":torch.randn(batch_size, to_len, to_dim),\n", + " \"X_from\":torch.randn(batch_size, from_len, from_dim)\n", + " },\n", + " batch_size=[batch_size]\n", + ")" ] }, { "cell_type": "markdown", - "id": "51c8c402", + "id": "f1dd80dc", "metadata": {}, "source": [ "For an encoder, we just need to take the same tokens for both queries, keys and values." 
@@ -524,54 +534,63 @@ }, { "cell_type": "code", - "execution_count": 34, - "id": "0ccbbf16", + "execution_count": 14, + "id": "011c42e6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-7.0456e-01, 6.2688e-01, -1.1026e+00, 2.2786e-02, -4.8655e-02],\n", - " [-9.7513e-01, 1.4375e+00, 9.4504e-01, 2.0078e+00, -8.5485e-01],\n", - " [-1.3192e+00, 8.7940e-01, -1.2088e+00, 4.6428e-01, -1.7003e-01]],\n", + "tensor([[[ 0.1218, 1.6380, 0.5283, -1.2097, -2.1976],\n", + " [ 0.2963, 0.8917, 0.7497, -0.1242, -0.3194],\n", + " [ 0.5004, 1.3847, -0.8797, -0.3203, -1.0600]],\n", "\n", - " [[-2.4683e+00, 1.1900e+00, -4.6999e-01, 8.1202e-01, -1.0019e+00],\n", - " [ 9.6744e-01, 1.0676e+00, 1.0292e+00, 4.7986e-01, -1.2177e+00],\n", - " [ 3.2124e-02, 4.6237e-01, -7.1896e-01, 2.0720e-03, -1.6574e-01]],\n", + " [[ 0.3451, 1.8204, 1.1999, -1.6162, -0.2519],\n", + " [-0.1433, 1.0322, -0.6018, -0.6799, -0.5113],\n", + " [ 0.4102, 0.7934, 0.8460, -0.9515, -1.6914]],\n", "\n", - " [[-2.2703e-01, 1.5631e+00, 1.1274e+00, 8.1163e-02, -1.5204e+00],\n", - " [-1.8939e+00, 7.4907e-01, -1.6144e+00, 7.3381e-01, 7.8596e-01],\n", - " [-7.7463e-02, 2.4682e-01, 4.6115e-01, 3.5791e-01, -7.7312e-01]],\n", + " [[-0.2757, 1.4416, 0.8119, -0.6991, -1.2287],\n", + " [ 1.0713, 1.3480, -0.0162, -1.2294, -2.0809],\n", + " [-0.1340, 0.6964, 0.6220, 0.3045, -0.6318]],\n", "\n", - " [[ 4.2789e-01, 7.7004e-02, -5.2232e-01, -1.3905e+00, -6.8685e-01],\n", - " [-9.8940e-01, 6.9261e-02, 1.8176e+00, 1.2323e+00, -2.8591e-01],\n", - " [-1.9008e+00, 1.2424e+00, 9.0587e-01, 3.6788e-01, -3.6444e-01]],\n", + " [[ 0.1715, 2.3998, 1.4253, -1.2790, -1.2869],\n", + " [ 0.1807, 0.6704, 0.4473, -1.1606, -0.7941],\n", + " [-0.0601, 0.0541, 0.4753, -0.2883, -0.9554]],\n", "\n", - " [[ 1.4538e+00, 4.6922e-01, -8.1502e-01, -3.0426e-01, 3.3914e-01],\n", - " [ 4.5448e-02, 5.8241e-01, -5.1411e-02, 2.1194e-01, -3.5672e-01],\n", - " [-2.6675e+00, 1.0907e+00, 9.3493e-01, 4.4590e-01, -1.3785e+00]],\n", + " [[ 0.2194, 1.3032, -0.0246, -0.1858, -1.3101],\n", + " [-0.5537, 0.8703, -1.1531, -0.8894, -1.7302],\n", + " [ 1.3058, 0.2501, 1.4498, 1.0468, -0.5985]],\n", "\n", - " [[-8.6485e-01, 1.1146e+00, -9.4181e-02, -1.1407e-01, 6.2847e-01],\n", - " [-9.5162e-01, 1.4542e+00, 1.2157e-02, -2.0240e-01, -1.5317e+00],\n", - " [-9.3001e-01, 9.8748e-01, 9.6898e-01, 1.2264e+00, -1.7035e+00]],\n", + " [[ 0.8845, 0.2856, 0.9751, -0.6513, 0.4614],\n", + " [-0.5029, 1.6529, 1.3564, -0.8560, -1.0680],\n", + " [-0.7042, 0.5084, 0.6723, -1.2830, -1.7314]],\n", "\n", - " [[ 1.8318e-01, 8.8817e-01, -4.1605e-01, -7.2077e-02, 1.0248e+00],\n", - " [ 1.1950e+00, -3.1990e-01, -3.0087e+00, -8.1052e-01, 4.2547e-01],\n", - " [ 4.8264e-02, 8.2594e-01, -4.8231e-01, -2.2883e-01, 7.4760e-01]],\n", + " [[ 0.2185, 0.6484, -1.4813, -0.9310, -0.6846],\n", + " [-0.0514, 0.6041, 0.8461, -0.6310, -0.9047],\n", + " [ 0.5254, 2.2600, 1.2691, -0.4674, -1.2203]],\n", "\n", - " [[-2.0185e+00, 2.7137e-01, 1.9260e-01, -3.7461e-01, -1.4131e+00],\n", - " [ 3.0497e-01, 4.7218e-01, 1.4051e+00, -7.2168e-01, -7.7026e-01],\n", - " [ 1.3967e+00, -6.0792e-01, 1.1108e-01, 1.7145e+00, 3.7664e-02]]],\n", + " [[ 1.1002, 1.4143, -1.1888, -0.1106, -1.7328],\n", + " [ 0.2170, 0.5818, 0.2111, 0.4387, -1.0139],\n", + " [-0.4513, 1.5186, 1.0518, -1.1227, -0.9134]]],\n", " grad_fn=)" ] }, - "execution_count": 34, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "transformer_encoder = TransformerTensorDict(6, \"X_to\", \"X_to\", to_dim, to_len, to_dim, latent_dim, 
num_heads)\n", + "transformer_encoder = TransformerTensorDict(\n", + " 6,\n", + " \"X_to\",\n", + " \"X_to\",\n", + " to_dim,\n", + " to_len,\n", + " to_dim,\n", + " latent_dim,\n", + " num_heads\n", + ")\n", "\n", "transformer_encoder(tokens)\n", "tokens[\"X_to\"]" @@ -579,7 +598,7 @@ }, { "cell_type": "markdown", - "id": "ddc5fe38", + "id": "63d709b2", "metadata": {}, "source": [ "For a decoder, we now can extract info from X_from into X_to. X_to will map to queries whereas X_from will map to keys and values." @@ -587,54 +606,63 @@ }, { "cell_type": "code", - "execution_count": 35, - "id": "16a72a4d", + "execution_count": 15, + "id": "f9ed7016", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-0.3380, 0.7571, -1.0049, 0.7433, -0.1268],\n", - " [-1.2908, 1.1133, -0.1441, 1.9060, -1.1740],\n", - " [-1.2169, 1.1862, -0.9880, 0.9077, -0.3301]],\n", + "tensor([[[-0.5958, 1.9269, 1.4385, -1.7099, -1.2024],\n", + " [-0.1923, 0.6554, 0.8646, -0.9089, 0.4366],\n", + " [-0.1544, 0.5690, 0.1171, -1.3729, 0.1287]],\n", "\n", - " [[-0.9989, 1.0225, 0.0516, 1.4268, -1.3924],\n", - " [-0.4546, 1.6237, -0.9002, 1.0320, -1.1606],\n", - " [-0.5801, 0.9752, -1.1864, 0.4610, 0.0803]],\n", + " [[-0.6611, 1.6326, 0.9590, -1.3373, 0.1869],\n", + " [-0.5724, 0.9246, 0.3437, -1.2203, -0.2542],\n", + " [-0.2347, 0.9846, 1.5354, -1.5329, -0.7540]],\n", "\n", - " [[-0.9055, 1.3019, -0.1604, 1.2191, -1.8533],\n", - " [-1.2209, 1.0449, -0.6953, 1.0392, 0.4056],\n", - " [-0.3828, 0.6094, 0.0309, 0.9272, -1.3600]],\n", + " [[-1.0947, 1.5092, 1.3402, -1.0084, -0.4252],\n", + " [-0.3901, 1.3286, 1.2193, -1.4316, -1.3246],\n", + " [-0.7211, 0.7870, 0.2914, -0.2573, 0.1774]],\n", "\n", - " [[-0.3067, 0.0189, -0.9604, -0.3966, -0.3716],\n", - " [-1.0679, 0.4198, 1.2470, 1.9662, -0.8475],\n", - " [-1.6165, 1.1001, -0.3376, 1.4791, -0.3264]],\n", + " [[-1.0461, 2.1735, 1.8569, -0.7953, -0.6545],\n", + " [-0.3190, 0.8452, 0.8193, -1.0471, -0.7304],\n", + " [-0.2703, 0.0935, 0.5782, -0.6885, -0.8154]],\n", "\n", - " [[-0.0980, 0.6583, -1.2214, 0.2261, -0.4485],\n", - " [-0.6073, 1.5547, -0.3092, 1.1512, -1.0601],\n", - " [-1.4279, 1.2677, 0.1266, 1.4727, -1.2850]],\n", + " [[-0.3828, 1.8378, 0.1272, -0.4188, -0.1151],\n", + " [-0.9871, 0.3603, -1.4740, -1.2111, -1.6092],\n", + " [ 0.2361, 0.9122, 1.5031, 0.4612, 0.7603]],\n", "\n", - " [[-0.4164, 0.7499, 0.2536, 0.1846, 0.1112],\n", - " [-0.4129, 1.0208, -0.3762, 0.7487, -1.3215],\n", - " [-0.1598, 0.5029, -0.3288, 1.9525, -2.5087]],\n", + " [[ 0.1466, 0.6786, 1.1725, -0.5156, 0.1942],\n", + " [-0.9457, 1.7856, 1.7606, -0.9622, -0.9245],\n", + " [-0.8863, 0.1064, 0.6751, -1.0641, -1.2214]],\n", "\n", - " [[-0.3682, 0.8304, -0.8471, 0.5485, 0.4438],\n", - " [-0.8762, 1.1699, -2.6552, 0.4682, 0.1896],\n", - " [-0.3391, 1.0842, -1.0764, 0.5564, 0.8712]],\n", + " [[-0.1457, 0.3580, -0.4237, -0.6528, -0.9734],\n", + " [-0.6525, 0.8373, 0.8691, -0.7156, -0.2167],\n", + " [-0.8809, 2.1679, 2.0259, -0.8261, -0.7707]],\n", "\n", - " [[-1.0284, 0.3804, 0.4793, 0.4096, -2.1200],\n", - " [-0.2695, 1.0064, -0.2066, 0.1593, -0.7439],\n", - " [ 0.9516, 0.8927, -0.9734, 1.9427, -0.8802]]],\n", + " [[ 0.1682, 1.5683, -1.0913, -0.9862, -0.5086],\n", + " [-0.1570, 0.8816, 0.4636, -0.5736, -0.1888],\n", + " [-0.9000, 1.8169, 1.4407, -1.5377, -0.3962]]],\n", " grad_fn=)" ] }, - "execution_count": 35, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "transformer_decoder = TransformerTensorDict(6, \"X_to\", \"X_from\", to_dim, to_len, 
from_dim, latent_dim, num_heads)\n", + "transformer_decoder = TransformerTensorDict(\n", + " 6,\n", + " \"X_to\",\n", + " \"X_from\",\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads\n", + ")\n", "\n", "transformer_decoder(tokens)\n", "tokens[\"X_to\"]" @@ -642,7 +670,7 @@ }, { "cell_type": "markdown", - "id": "1571146e", + "id": "ef98efd7", "metadata": {}, "source": [ "Now we can look at both models:" @@ -650,332 +678,346 @@ }, { "cell_type": "code", - "execution_count": 36, - "id": "42291d5d", + "execution_count": 16, + "id": "43888021", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TransformerTensorDict(\n", - " (transformer): ModuleList(\n", - " (0): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (1): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " 
(0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (2): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (3): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, 
inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (4): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (5): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " 
module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " )\n", - ")" + " module=ModuleList(\n", + " (0): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " (1): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " 
(layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " (2): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " (3): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", 
+ " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " (4): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " (5): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=5, out_features=10, bias=True)\n", + " (v): Linear(in_features=5, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " 
out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_to', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])" ] }, - "execution_count": 36, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -986,332 +1028,346 @@ }, { "cell_type": "code", - "execution_count": 37, - "id": "9d9af81a", + "execution_count": 17, + "id": "c619bb6d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TransformerTensorDict(\n", - " (transformer): ModuleList(\n", - " (0): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (1): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, 
elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (2): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (3): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " 
module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (4): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (5): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, 
out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " )\n", - ")" + " module=ModuleList(\n", + " (0): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " (1): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + 
" (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " (2): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " (3): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): 
Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " (4): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " (5): TransformerBlockTensorDict(\n", + " module=ModuleList(\n", + " (0): TensorDictModule(\n", + " module=TokensToQKV(\n", + " (q): Linear(in_features=5, out_features=10, bias=True)\n", + " (k): Linear(in_features=6, out_features=10, bias=True)\n", + " (v): Linear(in_features=6, out_features=10, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (1): TensorDictModule(\n", + " module=SplitHeads(), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['Q', 'K', 'V'])\n", + " (2): TensorDictModule(\n", + " module=Attention(\n", + " (softmax): Softmax(dim=-1)\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['Q', 'K', 'V'], \n", + " out_keys=['X_out', 'Attn'])\n", + " (3): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " (4): TensorDictModule(\n", + " module=FFN(\n", + " (FFN): Sequential(\n", + " (0): Linear(in_features=5, out_features=20, 
bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=20, out_features=5, bias=True)\n", + " (3): Dropout(p=0.2, inplace=False)\n", + " )\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to'], \n", + " out_keys=['X_out'])\n", + " (5): TensorDictModule(\n", + " module=SkipLayerNorm(\n", + " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_out'], \n", + " out_keys=['X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from', 'X_to'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", + " ), \n", + " device=cpu, \n", + " in_keys=['X_to', 'X_from', 'X_to', 'X_from', 'X_from', 'X_from', 'X_from', 'X_from'], \n", + " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])" ] }, - "execution_count": 37, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], @@ -1319,6 +1375,14 @@ "source": [ "transformer_decoder" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8900b6bb", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From 97bb800a2459f2a3c2cd9a9867bee9345179345d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 11 Jul 2022 14:24:51 +0100 Subject: [PATCH 06/21] init --- tutorials/tensordict.ipynb | 724 +++++++++++++++---------------- 1 file changed, 288 insertions(+), 436 deletions(-) diff --git a/tutorials/tensordict.ipynb b/tutorials/tensordict.ipynb index 40270f3d0d5..28dc11a17b2 100644 --- a/tutorials/tensordict.ipynb +++ b/tutorials/tensordict.ipynb @@ -13,13 +13,13 @@ "id": "139b8238", "metadata": {}, "source": [ - "TensorDict is a new tensor structure introduced in torchrl. \n", + "`TensorDict` is a new tensor structure introduced in TorchRL. \n", "\n", - "With RL, you need to be able to deal with multiple tensors such as actions, observations and reward. TensorDict aims at making it more convenient to deal with multiple tensors at the same time. \n", + "With RL, you need to be able to deal with multiple tensors such as actions, observations and rewards. `TensorDict` aims at making it more convenient to deal with multiple tensors at the same time. \n", "\n", - "Furthermore, different RL algorithms can deal with different input and outputs. The TensorDict allows to abstract away the differences between these algorithmes. \n", + "Furthermore, different RL algorithms can deal with different inputs and outputs. The `TensorDict` class makes it possible to abstract away the differences between these algorithms. \n", "\n", - "TensorDict combines the convinience of using dicts to organize your data with the power of pytorch tensors.\n" + "TensorDict combines the convenience of using `dict`s to organize your data with the power of PyTorch tensors.\n" ] }, { @@ -37,14 +37,14 @@ "source": [ "Let's suppose we have 2 datasets: Dataset A which has images and labels and Dataset B which has images, segmentation maps and labels. \n", "\n", - "We want to train 2 methods (Algo A on dataset A and algo B on dataset B) that share the same training loop. \n", + "Suppose we want to train a common algorithm over these two datasets (i.e. an algorithm that would ignore the mask or infer it when needed). 
\n", "\n", "In classical pytorch we would need to do the following:\n", "```python\n", "#Method A\n", "for i in range(optim_steps):\n", " images, labels = get_data_A()\n", - " loss = loss_module_A(images, labels)\n", + " loss = loss_module(images, labels)\n", " loss.backward()\n", " optim.step()\n", " optim.zero_grad()\n", @@ -54,17 +54,17 @@ "#Method B\n", "for i in range(optim_steps):\n", " images, masks, labels = get_data_B()\n", - " loss = loss_module_B(images, masks, labels)\n", + " loss = loss_module(images, labels)\n", " loss.backward()\n", " optim.step()\n", " optim.zero_grad()\n", "```\n", + "\n", "We can see that this limits the reusability of code. A lot of code has to be rewriten because of the modality difference between the 2 datasets.\n", "The idea of TensorDict is to do the following:\n", + "\n", "```python\n", "# General Method\n", - "get_data = instantiate(cfg.data)\n", - "loss_module = instantiate(cfg.module)\n", "for i in range(optim_steps):\n", " tensordict = get_data()\n", " loss = loss_module(tensordict)\n", @@ -74,8 +74,7 @@ "```\n", "\n", "\n", - "Now we can reuse the same training loop for all methods that we want. We just need to make sure that instantiate(cfg.data) and instantiate(cfg.module) maps to the desired method and data.\n", - "\n" + "Now we can reuse the same training loop across datasets and losses." ] }, { @@ -103,6 +102,7 @@ " return {\"modality_A\": torch.Tensor(torch.randn(2)), \"modality_B\": torch.Tensor(torch.randn(2))}\n", " \n", "```\n", + "\n", "However to achieve this you would need to write a complicated collate function that make sure that every modality is agregated properly.\n", "\n", "```python\n", @@ -140,9 +140,8 @@ "\n", "dataloader = Dataloader(DictDataset(), collate_fn = collate_tensordict_fn)\n", "```\n", - "This is an exemple of how TensorDict could facilitate such operations.\n", "\n", - "TensorDict inherits multiple properties from torch tensors that we will detail furtherdown, which make them quite practical." + "TensorDict inherits multiple properties from `torch.Tensor` and `dict` that we will detail furtherdown." 
] }, { @@ -150,7 +149,7 @@ "id": "b1303de8", "metadata": {}, "source": [ - "## TensorDict dictionary features" + "## `TensorDict` dictionary features" ] }, { @@ -158,7 +157,7 @@ "id": "79458114", "metadata": {}, "source": [ - "TensorDict shares a lot of features with python dictionaries" + "`TensorDict` shares a lot of features with python dictionaries" ] }, { @@ -193,7 +192,9 @@ } ], "source": [ - "tensordict = TensorDict({\"a\": torch.zeros(3, 4, 5), \"b\": torch.zeros(3, 4)}, batch_size=[3, 4])\n", + "a = torch.zeros(3, 4, 5)\n", + "b = torch.zeros(3, 4)\n", + "tensordict = TensorDict({\"a\": a, \"b\": b}, batch_size=[3, 4])\n", "print(tensordict)" ] }, @@ -202,7 +203,8 @@ "id": "0f543f6c", "metadata": {}, "source": [ - "If we want to access a certain key, it is explicit:" + "### `get(key)`\n", + "If we want to access a certain key, we can index the tensordict or alternatively use the `get` method:" ] }, { @@ -215,211 +217,141 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]]])\n", + "True\n", "torch.Size([3, 4, 5])\n" ] } ], "source": [ - "print(tensordict[\"a\"])\n", + "print(tensordict[\"a\"] is tensordict.get(\"a\") is a)\n", "print(tensordict[\"a\"].shape)" ] }, { "cell_type": "markdown", - "id": "47026db0", + "id": "3cc9df67-8834-4a75-8e19-6be21322c8a5", "metadata": {}, "source": [ - "Also works with get()" + "The `get` method also supports default values:" ] }, { "cell_type": "code", "execution_count": 4, - "id": "360f7c26", + "id": "d6638e51-628e-4bb1-b106-989b2b8c34be", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]]])\n", - "torch.Size([3, 4, 5])\n" - ] + "data": { + "text/plain": [ + "tensor([1., 1., 1.])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "print(tensordict.get(\"a\"))\n", - "print(tensordict.get(\"a\").shape)" + "out = tensordict.get(\"foo\", torch.ones(3))\n", + "out" ] }, { "cell_type": "markdown", - "id": "1f200416", + "id": "ea8db3b3-a9c1-4ba4-8c42-dd346c8fcbdd", "metadata": {}, "source": [ - "#### TensorDict.keys()\n", - "Keys can be retrieved to TensorDict" + "## `set(key, value)`\n", + "The `set()` method can be used to set new values. 
Regular indexing also does the job:" ] }, { "cell_type": "code", "execution_count": 5, - "id": "a9882c10", + "id": "853e0b14-a8e6-4c29-ad98-1226a21c2ff6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "a\n", - "b\n" + "td[\"c\"] is c: True\n", + "td[\"d\"] is d: True\n" ] } ], "source": [ - "for key in tensordict.keys():\n", - " print(key)" + "c = torch.zeros((3, 4, 2, 2))\n", + "tensordict.set(\"c\", c)\n", + "print(f\"td[\\\"c\\\"] is c: {c is tensordict['c']}\")\n", + "\n", + "d = torch.zeros((3, 4, 2, 2))\n", + "tensordict[\"d\"] = d\n", + "print(f\"td[\\\"d\\\"] is d: {d is tensordict['d']}\")" ] }, { "cell_type": "markdown", - "id": "1fecc00f", + "id": "1f200416", "metadata": {}, "source": [ - "#### TensorDict.values()\n", - "The values of a TensorDict can be retrieved with the values() function. On the contrary of python dicts, the values() function return a generator and not a list for memory efficiency reasons. Indeed, python dictionnary are not designed to store tensors which can take a lot of space in memory." + "## Other methods:\n", + "### `keys`\n", + "We can access the keys of a tensordict:" ] }, { "cell_type": "code", "execution_count": 6, - "id": "0e72b88a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tensordict.values()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "5fd21818", + "id": "a9882c10", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "torch.Size([3, 4, 5])\n", - "torch.Size([3, 4, 1])\n" + "a\n", + "b\n", + "c\n", + "d\n" ] } ], "source": [ - "for value in tensordict.values():\n", - " print(value.shape)" + "for key in tensordict.keys():\n", + " print(key)" ] }, { "cell_type": "markdown", - "id": "a820a435", + "id": "1fecc00f", "metadata": {}, "source": [ - "#### TensorDict.set()\n", - "The set function can be used to set new values" + "### `values`\n", + "The values of a `TensorDict` can be retrieved with the `values()` function. Note that, unlike python `dict`s, the `values()` method returns a generator and not a list." 
] }, { "cell_type": "code", - "execution_count": 8, - "id": "96c29680", + "execution_count": 7, + "id": "5fd21818", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "c is set as tensor([[[[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]]],\n", - "\n", - "\n", - " [[[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]]],\n", - "\n", - "\n", - " [[[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]]]])\n" + "torch.Size([3, 4, 5])\n", + "torch.Size([3, 4, 1])\n", + "torch.Size([3, 4, 2, 2])\n", + "torch.Size([3, 4, 2, 2])\n" ] } ], "source": [ - "tensordict.set(\"c\", torch.zeros((3, 4, 2, 2)))\n", - "print(f\"c is set as {tensordict['c']}\")" + "for value in tensordict.values():\n", + " print(value.shape)" ] }, { @@ -427,13 +359,13 @@ "id": "45e59abc", "metadata": {}, "source": [ - "#### TensorDict.update()\n", - "The update function can be used to update the dict with other dict values (Or TensorDict)" + "### TensorDict.update()\n", + "The `update` method can be used to update a TensorDict with another one (or with a dict):" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "96f216bf", "metadata": {}, "outputs": [ @@ -441,42 +373,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "a is now tensor([[[1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.]],\n", - "\n", - " [[1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.]],\n", - "\n", - " [[1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.]]])\n", - "d is set as tensor([[[1., 1.],\n", - " [1., 1.],\n", - " [1., 1.],\n", - " [1., 1.]],\n", - "\n", - " [[1., 1.],\n", - " [1., 1.],\n", - " [1., 1.],\n", - " [1., 1.]],\n", - "\n", - " [[1., 1.],\n", - " [1., 1.],\n", - " [1., 1.],\n", - " [1., 1.]]])\n" + "a is now equal to 1: True\n", + "d is now equal to 2: True\n" ] } ], "source": [ - "tensordict.update({\"a\":torch.ones((3, 4, 5)), \"d\":torch.ones((3, 4, 2))})\n", + "tensordict.update({\"a\": torch.ones((3, 4, 5)), \"d\": 2*torch.ones((3, 4, 2))})\n", "# Also works with tensordict.update(TensorDict({\"a\":torch.ones((3, 4, 5)), \"c\":torch.ones((3, 4, 2))}, batch_size=[3,4]))\n", - "print(f\"a is now {tensordict['a']}\")\n", - "print(f\"d is set as {tensordict['d']}\")" + "print(f\"a is now equal to 1: {(tensordict['a'] == 1).all()}\")\n", + "print(f\"d is now equal to 2: {(tensordict['d'] == 2).all()}\")" ] }, { @@ -484,13 +390,13 @@ "id": "9feb8632", "metadata": {}, "source": [ - "#### TensorDict del key\n", - "TensorDict also support keys deletion with the del operator:" + "### TensorDict del key\n", + "TensorDict also support keys deletion with the `del` operator:" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "7afc943e", "metadata": {}, "outputs": [ @@ -512,7 +418,7 @@ "id": "5fb04f41", "metadata": {}, "source": [ - "## TensorDict as a pytorch Tensor" + "## TensorDict as a Tensor-like object" ] }, { @@ -532,7 +438,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "f22d4b66", 
"metadata": {}, "outputs": [], @@ -543,7 +449,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "daaa4e28", "metadata": {}, "outputs": [ @@ -557,15 +463,13 @@ " b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},\n", " batch_size=torch.Size([3, 4]),\n", " device=cpu,\n", - " is_shared=False)\n", - "Our Tensor dict is of size torch.Size([3, 4])\n" + " is_shared=False)\n" ] } ], "source": [ "tensordict = TensorDict({\"a\": torch.zeros(3, 4, 5), \"b\": torch.zeros(3, 4)}, batch_size=[3, 4])\n", - "print(tensordict)\n", - "print(f\"Our Tensor dict is of size {tensordict.shape}\")" + "print(tensordict)" ] }, { @@ -612,37 +516,46 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "d08b9d11", "metadata": {}, "outputs": [ { - "ename": "RuntimeError", - "evalue": "batch dimension mismatch, got self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [14]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtensordict\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mc\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mzeros\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:361\u001b[0m, in \u001b[0;36m_TensorDict.update\u001b[0;34m(self, input_dict_or_td, clone, inplace, **kwargs)\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m clone:\n\u001b[1;32m 360\u001b[0m value \u001b[38;5;241m=\u001b[39m value\u001b[38;5;241m.\u001b[39mclone()\n\u001b[0;32m--> 361\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minplace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minplace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:1657\u001b[0m, in \u001b[0;36mTensorDict.set\u001b[0;34m(self, key, value, inplace, _run_checks, _meta_val)\u001b[0m\n\u001b[1;32m 1655\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict \u001b[38;5;129;01mand\u001b[39;00m inplace:\n\u001b[1;32m 1656\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mset_(key, value)\n\u001b[0;32m-> 1657\u001b[0m proc_value \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_process_tensor\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1658\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1659\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_tensor_shape\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_run_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1660\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_shared\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1661\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_device\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_run_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1662\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# check_tensor_shape=_run_checks\u001b[39;00m\n\u001b[1;32m 1663\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict[key] \u001b[38;5;241m=\u001b[39m proc_value\n\u001b[1;32m 1664\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tensordict_meta[key] \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1665\u001b[0m MetaTensor(\n\u001b[1;32m 1666\u001b[0m proc_value,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1671\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m _meta_val\n\u001b[1;32m 1672\u001b[0m )\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/data/tensordict/tensordict.py:485\u001b[0m, in \u001b[0;36m_TensorDict._process_tensor\u001b[0;34m(self, input, check_device, check_tensor_shape, check_shared)\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcheck_shared is not authorized anymore\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 484\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_tensor_shape \u001b[38;5;129;01mand\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mshape[: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_dims] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_size:\n\u001b[0;32m--> 485\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 486\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch dimension mismatch, got self.batch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 487\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_size\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m and tensor.shape[:self.batch_dims]\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 488\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtensor\u001b[38;5;241m.\u001b[39mshape[: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_dims]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 489\u001b[0m )\n\u001b[1;32m 491\u001b[0m \u001b[38;5;66;03m# minimum ndimension is 1\u001b[39;00m\n\u001b[1;32m 492\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mndimension() \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mndimension() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", - "\u001b[0;31mRuntimeError\u001b[0m: batch dimension mismatch, got 
self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])" + "name": "stdout", + "output_type": "stream", + "text": [ + "Caramba! We got this error: batch dimension mismatch, got self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])\n" ] } ], "source": [ - "tensordict.update({\"c\": torch.zeros(4, 3, 1)})" + "# we cannot add tensors that violate the batch size:\n", + "try:\n", + " tensordict.update({\"c\": torch.zeros(4, 3, 1)})\n", + "except RuntimeError as err:\n", + " print(f\"Caramba! We got this error: {err}\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "f8ec89ec", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Caramba! We got this error: the tensor a has shape torch.Size([3, 4, 5]) which is incompatible with the new shape torch.Size([4, 4])\n" + ] + } + ], "source": [ - "tensordict.batch_size = [4,4]" + "# If we reset the batch size, the stored tensors' shapes have to comply with it\n", + "try:\n", + " tensordict.batch_size = [4,4]\n", + "except RuntimeError as err:\n", + " print(f\"Caramba! We got this error: {err}\")" ] }, { @@ -658,6 +571,7 @@ "id": "12efa0f5", "metadata": {}, "source": [ + "### Device\n", "TensorDict can be sent to the desired device like a PyTorch tensor with `td.cuda()` or `td.to(device)`, with `device` being the desired device." ] }, { "cell_type": "markdown", "id": "6e7c9d12", "metadata": {}, "source": [ - "#### Memory sharing" + "### Memory sharing" ] }, { "cell_type": "markdown", "id": "fd540505", "metadata": {}, "source": [ - "When on cpu, you can use either `TensorDict.memmap_()` or `TensorDict.share_memory_()` to setup you tensor dict as a memmap or send it to shared memory resp." + "When on CPU, one can use either `tensordict.memmap_()` or `tensordict.share_memory_()` to represent a `tensordict` as a memory-mapped collection of tensors or to place it in shared memory, respectively." ] }, { "cell_type": "markdown", "id": "fc71180a", "metadata": {}, "source": [ - "#### Cloning\n", + "### Cloning\n", "TensorDict supports cloning. Cloning returns a new `TensorDict` whose values are copies of the original ones: modifying the clone does not modify the original."
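Before looking at cloning in practice, here is a short sketch of the memory-sharing methods just mentioned. It assumes the in-place `share_memory_()` and `memmap_()` methods described in this section; the exact repr may vary between versions:

```python
import torch
from torchrl.data import TensorDict

td = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
td.share_memory_()  # move the underlying storage to shared memory, in place
print(td)  # the repr should now display is_shared=True

# Alternatively, on a fresh (non-shared) tensordict, the tensors can be stored
# in memory-mapped files on disk:
# td.memmap_()
```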
] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 20, "id": "c88f387f", "metadata": {}, "outputs": [ @@ -696,42 +610,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]]])\n", - "tensor([[[1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.]],\n", - "\n", - " [[1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.]],\n", - "\n", - " [[1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.]]])\n" + "redefining a tensor in the clone does not impact the original tensordict: tensor(False)\n" ] } ], "source": [ "tensordict_clone = tensordict.clone()\n", "tensordict_clone[\"a\"] = torch.ones(*tensordict.shape, 5)\n", - "print(tensordict[\"a\"])\n", - "print(tensordict_clone[\"a\"])" + "print(\"redefining a tensor in the clone does not impact the original tensordict: \", (tensordict[\"a\"] == tensordict_clone[\"a\"]).all())" ] }, { @@ -748,13 +634,13 @@ "id": "5f8159df", "metadata": {}, "source": [ - "#### Slicing and indexing\n", - "Slicing and indexing is supported among the batch dimension" + "### Slicing and indexing\n", + "Slicing and indexing is supported along the batch dimensions" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 22, "id": "8409cfb4", "metadata": {}, "outputs": [ @@ -770,7 +656,7 @@ " is_shared=False)" ] }, - "execution_count": 16, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -781,7 +667,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 23, "id": "af1387cf", "metadata": {}, "outputs": [ @@ -797,7 +683,7 @@ " is_shared=False)" ] }, - "execution_count": 17, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -808,7 +694,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 24, "id": "d7133060", "metadata": {}, "outputs": [ @@ -824,7 +710,7 @@ " is_shared=False)" ] }, - "execution_count": 18, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -844,36 +730,48 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 25, + "id": "a9a2a860-e9c3-4112-9ae3-d6ba5b8473f7", + "metadata": {}, + "outputs": [], + "source": [ + "subtd = tensordict[:, torch.tensor([1, 3])] # a SubTensorDict keeps track of the original one: it does not create a copy in memory of the original data\n", + "tensordict.fill_(\"a\", -1)\n", + "assert (subtd[\"a\"] == -1).all() # the \"a\" key-value pair has changed" + ] + }, + { + "cell_type": "code", + "execution_count": 27, "id": "19f235ee", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(tensor([[[1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.]],\n", + "(tensor([[[ 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0.]],\n", " \n", - " [[1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.]],\n", + " [[ 0., 0., 0., 0., 0.],\n", + 
" [ 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0.]],\n", " \n", - " [[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]]]),\n", - " tensor([[[1.],\n", - " [1.],\n", - " [1.],\n", - " [1.]],\n", + " [[-1., -1., -1., -1., -1.],\n", + " [-1., -1., -1., -1., -1.],\n", + " [-1., -1., -1., -1., -1.],\n", + " [-1., -1., -1., -1., -1.]]]),\n", + " tensor([[[0.],\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", " \n", - " [[1.],\n", - " [1.],\n", - " [1.],\n", - " [1.]],\n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", " \n", " [[0.],\n", " [0.],\n", @@ -881,17 +779,25 @@ " [0.]]]))" ] }, - "execution_count": 19, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "td2 = TensorDict({\"a\": torch.ones(2, 4, 5), \"b\": torch.ones(2, 4)}, batch_size=[2, 4])\n", + "td2 = TensorDict({\"a\": torch.zeros(2, 4, 5), \"b\": torch.zeros(2, 4)}, batch_size=[2, 4])\n", "tensordict[:-1] = td2\n", "tensordict[\"a\"], tensordict[\"b\"]" ] }, + { + "cell_type": "markdown", + "id": "f523aa6d-42ff-4fa1-9cc2-ea778aa1bc1b", + "metadata": {}, + "source": [ + "We can set values easily just by indexing the tensordict:" + ] + }, { "cell_type": "markdown", "id": "3148a34a", @@ -905,12 +811,13 @@ "id": "b8c2075f", "metadata": {}, "source": [ + "### Masking\n", "We can perform masking on the indexes. Mask must be a tensor." ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 29, "id": "7c6bf5e1", "metadata": {}, "outputs": [ @@ -926,7 +833,7 @@ " is_shared=False)" ] }, - "execution_count": 20, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -944,25 +851,18 @@ "TensorDict support other tensor operations such as torch.cat, reshape, undind(dim), view(\\*shape), squeeze(dim), unsqueeze(dim), permute(\\*dims) requiring the operations to comply with the batch_size" ] }, - { - "cell_type": "markdown", - "id": "d74d0fb9", - "metadata": {}, - "source": [ - "#### View" - ] - }, { "cell_type": "markdown", "id": "c5ec195a", "metadata": {}, "source": [ + "### View\n", "Support for the view operation returning a `ViewedTensorDict`. 
{ "cell_type": "markdown", "id": "c5ec195a", "metadata": {}, "source": [ + "### View\n", "Support for the view operation returning a `ViewedTensorDict`. Use `to_tensordict` to come back to a regular `TensorDict`." ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 31, "id": "d909b9cd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ @@ -980,7 +880,7 @@ "\top=view(size=torch.Size([-1])))" ] }, - "execution_count": 21, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], @@ -999,7 +899,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 33, "id": "f1774e04", "metadata": {}, "outputs": [ { "data": { "text/plain": [ @@ -1017,7 +917,7 @@ "\top=permute(dims=(1, 0)))" ] }, - "execution_count": 22, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], @@ -1037,7 +937,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 35, "id": "c9b3ab59", "metadata": {}, "outputs": [ { "data": { "text/plain": [ @@ -1053,7 +953,7 @@ " is_shared=False)" ] }, - "execution_count": 23, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], @@ -1073,35 +973,10 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 36, "id": "68e8975f", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([4]),\n", - " device=cpu,\n", - " is_shared=False), TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([4]),\n", - " device=cpu,\n", - " is_shared=False), TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([4]),\n", - " device=cpu,\n", - " is_shared=False))\n" - ] - }, { "data": { "text/plain": [ @@ -1114,7 +989,7 @@ " is_shared=False)" ] }, - "execution_count": 24, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#Cat\n", "list_tensordict = tensordict.unbind(0)\n", - "print(list_tensordict)\n", "torch.cat(list_tensordict, dim=0)" ] }, @@ -1137,7 +1011,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 37, "id": "0c3a54a0", "metadata": {}, "outputs": [ @@ -1185,7 +1059,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 38, "id": "8f90e26a", "metadata": {}, "outputs": [ @@ -1199,8 +1073,7 @@ " b: Tensor(torch.Size([2, 3, 4, 1]), dtype=torch.float32)},\n", " batch_size=torch.Size([2, 3, 4]),\n", " device=cpu,\n", - " is_shared=False)\n", - "every tensordict is awesome!\n" + " is_shared=False)\n" ] } ], @@ -1208,7 +1081,7 @@ "#Stack\n", "staked_tensordict = torch.stack([tensordict, tensordict.clone()], dim=0)\n", "print(staked_tensordict)\n", - "if staked_tensordict[0] is tensordict:\n", + "if staked_tensordict[0] is tensordict and staked_tensordict[0] is not tensordict:\n", " print(\"every tensordict is awesome!\")" ] }, @@ -1222,34 +1095,13 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 39, "id": "b8c66c4a", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TensorDict(\n", - " fields={\n", - " a: 
Tensor(torch.Size([2, 3, 4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([2, 3, 4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([2, 3, 4]),\n", - " device=cpu,\n", - " is_shared=False)\n" - ] - } - ], + "outputs": [], "source": [ - "print(staked_tensordict.contiguous())\n", - "print(staked_tensordict.to_tensordict())" + "assert isinstance(staked_tensordict.contiguous(), TensorDict)\n", + "assert isinstance(staked_tensordict.to_tensordict(), TensorDict)" ] }, { @@ -1270,7 +1122,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 40, "id": "7a9c34d7", "metadata": {}, "outputs": [], @@ -1297,7 +1149,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 41, "id": "f42fd847", "metadata": {}, "outputs": [ @@ -1314,7 +1166,7 @@ " is_shared=False)" ] }, - "execution_count": 29, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } @@ -1335,7 +1187,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 42, "id": "c2e6db70", "metadata": {}, "outputs": [ @@ -1351,7 +1203,7 @@ " is_shared=False)" ] }, - "execution_count": 30, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -1380,7 +1232,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 43, "id": "f4692e76", "metadata": {}, "outputs": [], @@ -1396,7 +1248,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 44, "id": "55f66f45", "metadata": {}, "outputs": [ @@ -1414,7 +1266,7 @@ " is_shared=False)" ] }, - "execution_count": 32, + "execution_count": 44, "metadata": {}, "output_type": "execute_result" } @@ -1436,7 +1288,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 45, "id": "c3dc55f1", "metadata": {}, "outputs": [], @@ -1452,7 +1304,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 46, "id": "52fc40c4", "metadata": {}, "outputs": [ @@ -1470,7 +1322,7 @@ " is_shared=False)" ] }, - "execution_count": 34, + "execution_count": 46, "metadata": {}, "output_type": "execute_result" } @@ -1506,7 +1358,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 47, "id": "05e6a450", "metadata": {}, "outputs": [], @@ -1575,7 +1427,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 48, "id": "d3f130d9", "metadata": {}, "outputs": [], @@ -1598,7 +1450,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 49, "id": "c2d2a519", "metadata": {}, "outputs": [ @@ -1619,7 +1471,7 @@ " is_shared=False)" ] }, - "execution_count": 37, + "execution_count": 49, "metadata": {}, "output_type": "execute_result" } @@ -1652,48 +1504,48 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 50, "id": "d58f4d89", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[ 1.3338, 0.5568, 1.5257, -0.9140, -0.7399],\n", - " [-1.1602, -1.1171, -0.0355, -0.3293, 0.5658],\n", - " [ 0.0528, 1.3808, -1.1389, -1.2089, 1.2281]],\n", + "tensor([[[ 1.3948, 0.9014, -0.3433, 0.7251, -1.0122],\n", + " [ 0.1403, 1.2825, -1.2302, -1.4635, 0.7165],\n", + " [ 0.8251, -1.5065, -1.0976, -0.0920, 0.7596]],\n", "\n", - " [[ 1.0998, -0.2481, -0.5027, 0.3425, -1.3240],\n", - " [-1.2873, -2.0554, 1.1529, -0.5357, -0.2973],\n", - " [ 1.0694, 0.3098, 0.3401, 0.3393, 1.5966]],\n", + " [[ 1.5074, -0.4161, 0.5480, 1.1882, -1.3595],\n", + " [ 0.4081, -0.2427, 1.0663, 1.1888, -1.2557],\n", + " [ 0.1856, -0.5556, -1.4885, 0.5571, -1.3313]],\n", "\n", - " [[ 0.2762, 
-0.3396, -1.6140, 0.0866, 1.0148],\n", - " [-1.0849, -2.2679, 0.1260, 1.2162, 1.4198],\n", - " [ 1.0519, 0.2551, -0.3484, 0.0038, 0.2044]],\n", + " [[-0.1015, -1.1877, -0.1132, -0.9613, -0.7382],\n", + " [-0.9579, 1.7701, 1.8217, 0.5509, -0.6816],\n", + " [ 1.6410, 0.2096, -0.6056, -0.8921, 0.2459]],\n", "\n", - " [[-1.6978, -1.0813, -0.9879, -0.1534, -0.8177],\n", - " [ 1.0584, -1.3029, -0.5820, 0.4057, 1.6585],\n", - " [ 0.1717, 0.3136, 0.8307, 1.1407, 1.0436]],\n", + " [[-0.3648, 0.2833, 0.2137, 0.8552, -0.9622],\n", + " [ 0.3903, -1.2456, -0.6480, -0.4978, 0.8822],\n", + " [-0.2155, 1.4676, 0.3733, 1.6645, -2.1963]],\n", "\n", - " [[ 0.9372, -1.2179, -0.5154, -1.0837, 0.3776],\n", - " [-2.3519, -0.2721, -0.1398, 1.7179, -0.5435],\n", - " [ 0.3842, 0.8391, 0.4074, 0.9533, 0.5076]],\n", + " [[ 0.3817, -1.3816, -0.0126, 1.1476, 1.8669],\n", + " [ 0.5688, 0.9381, -1.3945, 0.6863, -1.5389],\n", + " [-0.2369, 0.7993, -1.0803, -0.4629, -0.2810]],\n", "\n", - " [[-0.0587, -0.1878, 0.5516, 0.0882, 0.9291],\n", - " [ 1.1866, -1.2275, -0.3984, -0.0310, -0.4985],\n", - " [-1.1856, -1.0034, 2.6606, -0.9745, 0.1493]],\n", + " [[ 0.7018, 1.2853, -0.2395, 1.8107, 0.1803],\n", + " [-1.4307, -0.3072, 0.6677, 0.5036, -1.2250],\n", + " [ 0.6976, 0.3290, -2.0245, -0.5033, -0.4459]],\n", "\n", - " [[-0.5460, 0.6498, -0.2246, -2.5625, -0.1683],\n", - " [ 0.2090, -0.9892, 0.5802, 0.5241, 0.5589],\n", - " [ 2.2136, 0.5872, 0.0154, -0.2186, -0.6290]],\n", + " [[ 0.1244, -0.1468, 0.2397, 0.2596, -0.5887],\n", + " [ 1.0288, 1.4734, -1.6026, 0.9561, -0.8791],\n", + " [ 0.0268, 0.7016, 0.7880, 0.0460, -2.4271]],\n", "\n", - " [[-1.3134, 0.0524, 0.3805, 0.3787, 0.6957],\n", - " [ 1.8009, -0.5925, -1.7172, -0.9240, 1.6469],\n", - " [-0.1907, -0.8003, 0.1441, -0.6741, 1.1129]]],\n", + " [[-2.4209, 0.3694, 1.0487, 0.8913, 0.6070],\n", + " [ 1.2738, -0.4474, 0.0166, -0.1589, 0.2163],\n", + " [-1.1257, 0.6097, 0.2396, -1.6827, 0.5633]]],\n", " grad_fn=)" ] }, - "execution_count": 38, + "execution_count": 50, "metadata": {}, "output_type": "execute_result" } @@ -1712,7 +1564,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 51, "id": "d6241e1f", "metadata": {}, "outputs": [], @@ -1728,7 +1580,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 52, "id": "544f6335", "metadata": {}, "outputs": [], @@ -1756,48 +1608,48 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 53, "id": "52dbf4fb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-2.5683e-01, 4.5024e-01, 9.4612e-01, -1.1000e+00, 6.6867e-01],\n", - " [-6.6982e-01, 8.6136e-01, 1.5597e+00, -1.0453e+00, 1.2007e+00],\n", - " [-6.9352e-01, -3.2060e-02, -4.3720e-01, -2.1913e+00, 7.3923e-01]],\n", + "tensor([[[ 0.3641, -0.0979, -0.0854, -0.4905, -2.0029],\n", + " [-1.5806, 0.5326, 0.6632, 0.7920, -1.1313],\n", + " [ 0.5733, -0.1321, 2.1497, 0.6715, -0.2256]],\n", "\n", - " [[ 9.1422e-01, -2.7420e-01, -2.1522e-01, -1.1503e+00, 1.6155e+00],\n", - " [ 3.1163e-01, -7.0292e-02, 1.0830e+00, -2.3652e+00, 1.4391e-01],\n", - " [ 3.9072e-01, -3.0454e-01, 8.5802e-01, -1.4170e+00, 4.7974e-01]],\n", + " [[-0.5301, 1.3014, 0.8339, -1.6371, -0.0447],\n", + " [-0.1658, 0.9115, -0.1586, -0.1376, -1.6049],\n", + " [-0.1376, 0.6356, 1.5934, 0.7616, -1.6211]],\n", "\n", - " [[-7.3433e-01, 6.3703e-01, 4.6386e-01, -1.3121e+00, 6.5461e-01],\n", - " [-3.1634e-01, 2.7175e-01, 4.1725e-01, -6.2716e-01, 6.5264e-01],\n", - " [-1.8677e+00, 1.3871e+00, 9.4117e-01, -1.7092e+00, 1.1414e+00]],\n", + " [[-0.7808, 
0.1323, 0.3524, -1.8699, -1.2661],\n", + " [-0.3532, 0.0802, 1.8762, -0.0252, -1.0859],\n", + " [ 0.4886, 1.1930, 1.5847, 0.0510, -0.3773]],\n", "\n", - " [[-1.9929e-01, 7.1908e-01, 1.1399e+00, -2.0360e+00, 9.4091e-01],\n", - " [ 1.8042e-01, 8.9284e-03, 5.7997e-01, -6.0130e-01, 1.1606e+00],\n", - " [-3.7275e-01, 9.1703e-01, -1.2281e-01, -2.2440e+00, -7.0603e-02]],\n", + " [[-0.5639, 0.4097, 1.2614, -1.1658, 0.8153],\n", + " [ 0.9616, -1.0353, 0.8457, 0.3561, -1.7874],\n", + " [-0.3008, -0.3083, 1.2756, 0.7932, -1.5570]],\n", "\n", - " [[-7.2842e-01, -5.2198e-01, -8.3247e-01, -1.5439e-01, -3.7558e-01],\n", - " [-8.5731e-01, 1.0901e+00, 1.5317e+00, -2.1651e+00, 1.4644e+00],\n", - " [-4.1692e-01, 4.9001e-01, 9.7057e-01, -3.6126e-01, 8.6664e-01]],\n", + " [[ 1.6996, 0.6104, 0.7166, -1.4791, -0.0746],\n", + " [-1.3742, -0.0917, 0.3204, 0.4544, -1.6289],\n", + " [ 0.7167, 0.8835, 1.0875, -0.8911, -0.9494]],\n", "\n", - " [[ 9.5534e-01, 1.4917e-01, -7.1200e-03, -2.0727e+00, 1.1170e+00],\n", - " [-2.9475e-02, 8.1727e-01, 1.3990e+00, -2.5198e-01, 7.2835e-01],\n", - " [-1.2625e+00, -9.9376e-01, -8.0062e-01, -8.8569e-01, 1.1377e+00]],\n", + " [[-0.3857, 1.2185, 0.8044, -1.2474, -0.6684],\n", + " [-2.6451, 0.3813, 0.8183, 0.2206, -0.0257],\n", + " [-0.2125, 0.9058, 1.3390, -0.0635, -0.4396]],\n", "\n", - " [[-7.5436e-01, 4.3068e-01, 7.9711e-01, -1.2966e+00, 9.4933e-03],\n", - " [ 4.6172e-01, 9.7587e-01, 1.4781e+00, -8.3284e-01, 8.6749e-01],\n", - " [-1.6161e+00, 3.7033e-01, -4.2724e-02, -1.8422e+00, 9.9399e-01]],\n", + " [[ 0.4440, -0.4641, 1.3416, 0.4208, 0.0535],\n", + " [-0.0167, 0.5248, 0.4601, 1.1979, -2.6929],\n", + " [-0.1802, 0.8493, 0.2579, -1.4159, -0.7799]],\n", "\n", - " [[-2.0276e-01, -1.3223e-01, -2.9282e-01, -7.4340e-02, 1.6503e+00],\n", - " [ 1.0073e+00, -1.5842e-01, 1.0905e+00, -1.8213e+00, 6.9849e-01],\n", - " [-1.8112e-02, -9.7116e-01, 1.1706e+00, -1.9438e+00, -2.3435e-03]]],\n", + " [[ 0.0957, 0.2631, 1.9058, -0.7983, -2.1256],\n", + " [-0.8988, -0.3532, 1.4373, 0.5188, -0.8031],\n", + " [ 0.4242, -0.5689, 0.5862, 0.9721, -0.6552]]],\n", " grad_fn=)" ] }, - "execution_count": 41, + "execution_count": 53, "metadata": {}, "output_type": "execute_result" } @@ -1819,48 +1671,48 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 54, "id": "913b3979", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[ 0.2650, 0.8114, 1.6235, -1.8246, -0.1827],\n", - " [-0.5331, 0.8769, 1.4354, -0.7027, 0.0041],\n", - " [-0.0240, 0.0151, 0.7448, -1.9483, -0.5608]],\n", + "tensor([[[ 0.4194, 1.1577, -0.2249, -0.2299, -1.6838],\n", + " [-0.4603, 1.5326, 0.1736, 0.4194, -2.1895],\n", + " [ 0.2842, 0.8026, 1.1243, -0.1371, -0.9882]],\n", "\n", - " [[ 0.6278, 0.2316, 0.6965, -1.4007, 0.4978],\n", - " [ 0.5148, 0.1977, 1.5761, -2.1702, -0.3718],\n", - " [ 0.5269, 0.1489, 1.0460, -1.5976, -0.5239]],\n", + " [[-0.3067, 1.4502, 0.5999, -1.0000, -0.6322],\n", + " [-0.2193, 2.2219, -0.2861, -0.1852, -1.1817],\n", + " [ 0.0121, 0.7552, 0.8953, -0.3838, -1.7396]],\n", "\n", - " [[-0.2073, 1.7283, 0.7957, -1.2614, -0.3906],\n", - " [-0.2481, 1.4483, 0.3833, -0.3303, -0.3051],\n", - " [-0.7997, 1.4176, 0.4302, -1.9264, -0.7345]],\n", + " [[-0.2220, 0.9376, 0.4108, -0.8514, -1.4368],\n", + " [-0.2243, 1.0650, 1.1564, -0.5117, -1.8274],\n", + " [ 0.3239, 1.7627, 0.7365, -0.2319, -1.0873]],\n", "\n", - " [[ 0.4208, 0.4412, 1.3680, -2.0761, -0.3130],\n", - " [ 0.6140, 0.1374, 1.0931, 0.1626, 0.0274],\n", - " [-0.4122, 0.9318, 0.7548, -1.8592, -1.2908]],\n", + " [[ 0.0923, 0.1925, 
2.0104, -1.0669, 0.2943],\n", + " [ 1.0389, -0.3586, 0.0279, -0.2836, -1.3698],\n", + " [-0.2385, 0.1469, 1.1481, 0.5598, -2.1937]],\n", "\n", - " [[-0.7726, 0.0486, 0.0335, -0.0701, -0.9185],\n", - " [-0.2630, 1.0334, 1.7184, -2.6263, 0.0084],\n", - " [-0.5563, 1.1644, 0.8858, -0.0215, 0.3358]],\n", + " [[ 1.3624, 0.0818, 0.6500, -1.9682, -0.3359],\n", + " [-0.3614, 0.9265, 0.2931, 0.2009, -1.4732],\n", + " [ 0.2946, 1.1101, 1.3561, -1.0921, -1.0449]],\n", "\n", - " [[ 1.0746, 0.6018, 0.5270, -0.9513, 0.1145],\n", - " [-0.6757, 1.6353, 1.4940, -0.1454, 0.1573],\n", - " [ 0.0575, -1.1170, -1.3292, -1.9089, 0.4654]],\n", + " [[ 0.1569, 1.4944, 0.1292, -1.0900, -0.8406],\n", + " [-1.6930, 0.3846, 0.9151, 0.1136, -1.2707],\n", + " [ 0.4416, 1.6022, 1.1814, -0.6376, -0.8870]],\n", "\n", - " [[-0.3958, 0.6817, 0.7094, -1.1415, -0.1320],\n", - " [ 0.2856, 1.6527, 1.2931, -0.6426, 0.4360],\n", - " [-0.9961, 0.7746, -0.0636, -2.4226, -0.0389]],\n", + " [[ 0.6334, -0.2367, 1.1179, -0.5009, -0.3828],\n", + " [ 0.0639, 1.2607, 0.4867, 0.1112, -2.5056],\n", + " [ 0.6137, 1.0620, 0.6164, -1.4468, -0.8929]],\n", "\n", - " [[ 0.0635, -0.2169, 0.2886, -0.9645, 0.0584],\n", - " [ 1.1510, 0.3362, 1.9643, -0.8624, -0.6656],\n", - " [ 0.8300, -0.7676, 1.4487, -1.8439, -0.8198]]],\n", + " [[-0.1152, 1.1531, 1.4088, -0.6771, -2.7837],\n", + " [-0.1132, 0.3721, 1.0719, 0.0737, -1.1365],\n", + " [ 0.3906, 0.4455, 0.2721, 0.2531, -0.6153]]],\n", " grad_fn=)" ] }, - "execution_count": 42, + "execution_count": 54, "metadata": {}, "output_type": "execute_result" } @@ -1882,7 +1734,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 55, "id": "f9c19fc1", "metadata": {}, "outputs": [ @@ -2207,7 +2059,7 @@ ")" ] }, - "execution_count": 43, + "execution_count": 55, "metadata": {}, "output_type": "execute_result" } @@ -2218,7 +2070,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 56, "id": "4fb30839", "metadata": {}, "outputs": [ @@ -2543,7 +2395,7 @@ ")" ] }, - "execution_count": 44, + "execution_count": 56, "metadata": {}, "output_type": "execute_result" } @@ -2569,7 +2421,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.8.3" } }, "nbformat": 4, From 4fb2571e95ff7acb7cad88154497c0237780a434 Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Mon, 11 Jul 2022 17:42:32 +0100 Subject: [PATCH 07/21] TensorDictModule initial commit --- tutorials/tensordict.ipynb | 2429 ------------------------------------ 1 file changed, 2429 deletions(-) delete mode 100644 tutorials/tensordict.ipynb diff --git a/tutorials/tensordict.ipynb b/tutorials/tensordict.ipynb deleted file mode 100644 index 28dc11a17b2..00000000000 --- a/tutorials/tensordict.ipynb +++ /dev/null @@ -1,2429 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "c0174624", - "metadata": {}, - "source": [ - "# TensorDict tutorial" - ] - }, - { - "cell_type": "markdown", - "id": "139b8238", - "metadata": {}, - "source": [ - "`TensorDict` is a new tensor structure introduced in TorchRL. \n", - "\n", - "With RL, you need to be able to deal with multiple tensors such as actions, observations and reward. `TensorDict` aims at making it more convenient to deal with multiple tensors at the same time. \n", - "\n", - "Furthermore, different RL algorithms can deal with different input and outputs. The `TensorDict` class makes it possible to abstract away the differences between these algorithmes. 
\n", - "\n", - "TensorDict combines the convenience of using `dict`s to organize your data with the power of pytorch tensors.\n" - ] - }, - { - "cell_type": "markdown", - "id": "5adeede6", - "metadata": {}, - "source": [ - "#### Improving the modularity of codes" - ] - }, - { - "cell_type": "markdown", - "id": "3a11d7f7", - "metadata": {}, - "source": [ - "Let's suppose we have 2 datasets: Dataset A, which has images and labels, and Dataset B, which has images, segmentation maps and labels. \n", - "\n", - "Suppose we want to train a common algorithm over these two datasets (i.e. an algorithm that would ignore the mask or infer it when needed). \n", - "\n", - "In classical pytorch we would need to do the following:\n", - "```python\n", - "#Method A\n", - "for i in range(optim_steps):\n", - " images, labels = get_data_A()\n", - " loss = loss_module(images, labels)\n", - " loss.backward()\n", - " optim.step()\n", - " optim.zero_grad()\n", - "```\n", - "\n", - "```python\n", - "#Method B\n", - "for i in range(optim_steps):\n", - " images, masks, labels = get_data_B()\n", - " loss = loss_module(images, labels)\n", - " loss.backward()\n", - " optim.step()\n", - " optim.zero_grad()\n", - "```\n", - "\n", - "We can see that this limits the reusability of code. A lot of code has to be rewritten because of the modality difference between the 2 datasets.\n", - "The idea of TensorDict is to do the following:\n", - "\n", - "```python\n", - "# General Method\n", - "for i in range(optim_steps):\n", - " tensordict = get_data()\n", - " loss = loss_module(tensordict)\n", - " loss.backward()\n", - " optim.step()\n", - " optim.zero_grad()\n", - "```\n", - "\n", - "\n", - "Now we can reuse the same training loop across datasets and losses (a sketch of such a loss module is given below)." - ] - }, - { - "cell_type": "markdown", - "id": "0d128d7c", - "metadata": {}, - "source": [ - "#### Can't I do this with a python dict?" - ] - }, - { - "cell_type": "markdown", - "id": "c2e2536a", - "metadata": {}, - "source": [ - "One could argue that you could achieve the same results with a dataset that outputs a pytorch dict. 
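To make the general loop above concrete, here is a minimal sketch of a loss module that consumes a tensordict directly. The class name `TensorDictLoss`, the key names and the toy shapes are illustrative assumptions rather than torchrl API; the point is simply that the module fetches the keys it needs and ignores any extra modality, such as a mask:

```python
import torch
import torch.nn as nn
from torchrl.data import TensorDict


class TensorDictLoss(nn.Module):
    # hypothetical loss module: reads "images" and "labels", ignores other keys
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(10, 2)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, tensordict):
        logits = self.net(tensordict["images"])
        return self.criterion(logits, tensordict["labels"])


loss_module = TensorDictLoss()
# the same call works whether or not the batch carries a "masks" entry
batch = TensorDict(
    {"images": torch.randn(4, 10),
     "labels": torch.randint(2, (4,)),
     "masks": torch.zeros(4, 10)},
    batch_size=[4],
)
loss = loss_module(batch)
```

Coming back to the plain-dict idea, such a dataset would look like this: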
\n", - "```python\n", - "class DictDataset(Dataset):\n", - " ...\n", - " \n", - " def __getitem__(self, idx)\n", - " \n", - " ...\n", - " \n", - " return {\"modality_A\": torch.Tensor(torch.randn(2)), \"modality_B\": torch.Tensor(torch.randn(2))}\n", - " \n", - "```\n", - "\n", - "However to achieve this you would need to write a complicated collate function that make sure that every modality is agregated properly.\n", - "\n", - "```python\n", - "\n", - "def collate_dict_fn(dict_list):\n", - " final_dict = {}\n", - " for key in dict_list[0].keys():\n", - " final_dict[key]= []\n", - " for single_dict in dict_list:\n", - " final_dict[key].append(single_dict[key])\n", - " final_dict[key] = torch.stack(final_dict[key], dim=0)\n", - " return final_dict\n", - "\n", - "\n", - "dataloader = Dataloader(DictDataset(), collate_fn = collate_dict_fn)\n", - "\n", - "````\n", - "With TensorDicts this is now much simpler:\n", - "\n", - "```python\n", - "class DictDataset(Dataset):\n", - " ...\n", - " \n", - " def __getitem__(self, idx)\n", - " \n", - " ...\n", - " \n", - " return TensorDict({\"modality_A\": torch.Tensor(torch.randn(2)), \"modality_B\": torch.Tensor(torch.randn(2)), batch_size=[]})\n", - "```\n", - "\n", - "\n", - "Here, the collate function is as simple as:\n", - "```python\n", - "collate_tensordict_fn = lambda tds : torch.stack(tds, dim=0)\n", - "\n", - "dataloader = Dataloader(DictDataset(), collate_fn = collate_tensordict_fn)\n", - "```\n", - "\n", - "TensorDict inherits multiple properties from `torch.Tensor` and `dict` that we will detail furtherdown." - ] - }, - { - "cell_type": "markdown", - "id": "b1303de8", - "metadata": {}, - "source": [ - "## `TensorDict` dictionary features" - ] - }, - { - "cell_type": "markdown", - "id": "79458114", - "metadata": {}, - "source": [ - "`TensorDict` shares a lot of features with python dictionaries" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "70a39ae6", - "metadata": {}, - "outputs": [], - "source": [ - "from torchrl.data import TensorDict\n", - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "4009045f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([3, 4]),\n", - " device=cpu,\n", - " is_shared=False)\n" - ] - } - ], - "source": [ - "a = torch.zeros(3, 4, 5)\n", - "b = torch.zeros(3, 4)\n", - "tensordict = TensorDict({\"a\": a, \"b\": b}, batch_size=[3, 4])\n", - "print(tensordict)" - ] - }, - { - "cell_type": "markdown", - "id": "0f543f6c", - "metadata": {}, - "source": [ - "### `get(key)`\n", - "If we want to access a certain key, we can index the tensordict or alternatively use the `get` method:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "0235a1a7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "torch.Size([3, 4, 5])\n" - ] - } - ], - "source": [ - "print(tensordict[\"a\"] is tensordict.get(\"a\") is a)\n", - "print(tensordict[\"a\"].shape)" - ] - }, - { - "cell_type": "markdown", - "id": "3cc9df67-8834-4a75-8e19-6be21322c8a5", - "metadata": {}, - "source": [ - "The `get` method also supports default values:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "d6638e51-628e-4bb1-b106-989b2b8c34be", - "metadata": {}, - "outputs": [ - { - 
"data": { - "text/plain": [ - "tensor([1., 1., 1.])" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "out = tensordict.get(\"foo\", torch.ones(3))\n", - "out" - ] - }, - { - "cell_type": "markdown", - "id": "ea8db3b3-a9c1-4ba4-8c42-dd346c8fcbdd", - "metadata": {}, - "source": [ - "## `set(key, value)`\n", - "The `set()` method can be used to set new values. Regular indexing also does the job:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "853e0b14-a8e6-4c29-ad98-1226a21c2ff6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "td[\"c\"] is c: True\n", - "td[\"d\"] is d: True\n" - ] - } - ], - "source": [ - "c = torch.zeros((3, 4, 2, 2))\n", - "tensordict.set(\"c\", c)\n", - "print(f\"td[\\\"c\\\"] is c: {c is tensordict['c']}\")\n", - "\n", - "d = torch.zeros((3, 4, 2, 2))\n", - "tensordict[\"d\"] = d\n", - "print(f\"td[\\\"d\\\"] is d: {d is tensordict['d']}\")" - ] - }, - { - "cell_type": "markdown", - "id": "1f200416", - "metadata": {}, - "source": [ - "## Other methods:\n", - "### `keys`\n", - "We can access the keys of a tensordict:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "a9882c10", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "a\n", - "b\n", - "c\n", - "d\n" - ] - } - ], - "source": [ - "for key in tensordict.keys():\n", - " print(key)" - ] - }, - { - "cell_type": "markdown", - "id": "1fecc00f", - "metadata": {}, - "source": [ - "### `values`\n", - "The values of a `TensorDict` can be retrieved with the `values()` function. Note that, unlike python `dict`s, the `values()` method returns a generator and not a list." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "5fd21818", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([3, 4, 5])\n", - "torch.Size([3, 4, 1])\n", - "torch.Size([3, 4, 2, 2])\n", - "torch.Size([3, 4, 2, 2])\n" - ] - } - ], - "source": [ - "for value in tensordict.values():\n", - " print(value.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "45e59abc", - "metadata": {}, - "source": [ - "### TensorDict.update()\n", - "The `update` method can be used to update a TensorDict with another one (or with a dict):" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "96f216bf", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "a is now equal to 1: True\n", - "d is now equal to 2: True\n" - ] - } - ], - "source": [ - "tensordict.update({\"a\": torch.ones((3, 4, 5)), \"d\": 2*torch.ones((3, 4, 2))})\n", - "# Also works with tensordict.update(TensorDict({\"a\":torch.ones((3, 4, 5)), \"c\":torch.ones((3, 4, 2))}, batch_size=[3,4]))\n", - "print(f\"a is now equal to 1: {(tensordict['a'] == 1).all()}\")\n", - "print(f\"d is now equal to 2: {(tensordict['d'] == 2).all()}\")" - ] - }, - { - "cell_type": "markdown", - "id": "9feb8632", - "metadata": {}, - "source": [ - "### TensorDict del key\n", - "TensorDict also support keys deletion with the `del` operator:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "7afc943e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_keys(['a', 'b', 'd'])\n" - ] - } - ], - "source": [ - "del tensordict[\"c\"]\n", - "print(tensordict.keys())" - ] - }, - { - "cell_type": "markdown", - "id": "5fb04f41", - "metadata": 
{}, - "source": [ - "## TensorDict as a Tensor-like object" - ] - }, - { - "cell_type": "markdown", - "id": "1c834589", - "metadata": {}, - "source": [ - "But wait? Can't we do this with a classical dict? \n", - "Well, we would like the TensorDict to keep some nice Pytorch properties. TensorDict combines the advantages of the Python dictionary and of a Pytorch Tensor.\n", - "TensorDict has a batch size. It is not inferred automatically by looking at the tensors, but must be set when creating the TensorDict.\n", - "\n", - "TensorDict is a tensor container where all tensors are stored in akey-value pair fashion and where each element shares at least the following features:\n", - "- device;\n", - "- memory location (shared, memory-mapped array, ...);\n", - "- batch size (i.e. n^th first dimensions)." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "f22d4b66", - "metadata": {}, - "outputs": [], - "source": [ - "from torchrl.data import TensorDict\n", - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "daaa4e28", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([3, 4]),\n", - " device=cpu,\n", - " is_shared=False)\n" - ] - } - ], - "source": [ - "tensordict = TensorDict({\"a\": torch.zeros(3, 4, 5), \"b\": torch.zeros(3, 4)}, batch_size=[3, 4])\n", - "print(tensordict)" - ] - }, - { - "cell_type": "markdown", - "id": "0aa2e7cd", - "metadata": {}, - "source": [ - "#### Batch size" - ] - }, - { - "cell_type": "markdown", - "id": "6d17d8b4", - "metadata": {}, - "source": [ - "Tensor dict has a batch size which is shared across all tensors. The batch size can be [], unidimensional or multidimensional according to your needs." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "3c0217e0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Our Tensor dict is of size torch.Size([3, 4])\n" - ] - } - ], - "source": [ - "print(f\"Our Tensor dict is of size {tensordict.shape}\")" - ] - }, - { - "cell_type": "markdown", - "id": "bda9c6aa", - "metadata": {}, - "source": [ - "You cannot have items that don't share the batch size inside the same TensorDict:" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "d08b9d11", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Caramba! We got this error: batch dimension mismatch, got self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])\n" - ] - } - ], - "source": [ - "# we cannot add tensors that violate the batch size:\n", - "try:\n", - " tensordict.update({\"c\": torch.zeros(4, 3, 1)})\n", - "except RuntimeError as err:\n", - " print(f\"Caramba! We got this error: {err}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "f8ec89ec", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Caramba! We got this error: the tensor a has shape torch.Size([3, 4, 5]) which is incompatible with the new shape torch.Size([4, 4])\n" - ] - } - ], - "source": [ - "# If we reset the batch size, it has to comply with the tensordict batch size\n", - "try:\n", - " tensordict.batch_size = [4,4]\n", - "except RuntimeError as err:\n", - " print(f\"Caramba! 
We got this error: {err}\")" - ] - }, - { - "cell_type": "markdown", - "id": "93ba3948", - "metadata": {}, - "source": [ - "#### Devices" - ] - }, - { - "cell_type": "markdown", - "id": "12efa0f5", - "metadata": {}, - "source": [ - "### Device\n", - "TensorDict can be sent to the desired devices like a pytorch tensor with `td.cuda()` or `td.to(device)` with `device`the desired device" - ] - }, - { - "cell_type": "markdown", - "id": "6e7c9d12", - "metadata": {}, - "source": [ - "### Memory sharing via physical memory usage" - ] - }, - { - "cell_type": "markdown", - "id": "fd540505", - "metadata": {}, - "source": [ - "When on cpu, one can use either `tensordict.memmap_()` or `tensordict.share_memory_()` to send a `tensordict` to represent it as a memory-mapped collection of tensors or put it in shared memory resp." - ] - }, - { - "cell_type": "markdown", - "id": "fc71180a", - "metadata": {}, - "source": [ - "### Cloning\n", - "TensorDict supports cloning. Cloning returns the same SubTensorDict item than the original item." - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "c88f387f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "redefining a tensor in the clone does not impact the original tensordict: tensor(False)\n" - ] - } - ], - "source": [ - "tensordict_clone = tensordict.clone()\n", - "tensordict_clone[\"a\"] = torch.ones(*tensordict.shape, 5)\n", - "print(\"redefining a tensor in the clone does not impact the original tensordict: \", (tensordict[\"a\"] == tensordict_clone[\"a\"]).all())" - ] - }, - { - "cell_type": "markdown", - "id": "9251e116", - "metadata": {}, - "source": [ - "### Tensor operations\n", - "We can perform tensor operations among the batch dimensions:" - ] - }, - { - "cell_type": "markdown", - "id": "5f8159df", - "metadata": {}, - "source": [ - "### Slicing and indexing\n", - "Slicing and indexing is supported along the batch dimensions" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "8409cfb4", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([4]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tensordict[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "af1387cf", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([2, 4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([2, 4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([2, 4]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tensordict[1:]" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "d7133060", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([3, 2, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([3, 2, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([3, 2]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tensordict[:, 2:]" - ] - }, - { - "cell_type": 
"markdown", - "id": "c69d84a1", - "metadata": {}, - "source": [ - "#### Setting values with indexing\n", - "We can also edit certain tensor features by deliminting certain indexes:" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "a9a2a860-e9c3-4112-9ae3-d6ba5b8473f7", - "metadata": {}, - "outputs": [], - "source": [ - "subtd = tensordict[:, torch.tensor([1, 3])] # a SubTensorDict keeps track of the original one: it does not create a copy in memory of the original data\n", - "tensordict.fill_(\"a\", -1)\n", - "assert (subtd[\"a\"] == -1).all() # the \"a\" key-value pair has changed" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "19f235ee", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(tensor([[[ 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0.]],\n", - " \n", - " [[ 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0.]],\n", - " \n", - " [[-1., -1., -1., -1., -1.],\n", - " [-1., -1., -1., -1., -1.],\n", - " [-1., -1., -1., -1., -1.],\n", - " [-1., -1., -1., -1., -1.]]]),\n", - " tensor([[[0.],\n", - " [0.],\n", - " [0.],\n", - " [0.]],\n", - " \n", - " [[0.],\n", - " [0.],\n", - " [0.],\n", - " [0.]],\n", - " \n", - " [[0.],\n", - " [0.],\n", - " [0.],\n", - " [0.]]]))" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "td2 = TensorDict({\"a\": torch.zeros(2, 4, 5), \"b\": torch.zeros(2, 4)}, batch_size=[2, 4])\n", - "tensordict[:-1] = td2\n", - "tensordict[\"a\"], tensordict[\"b\"]" - ] - }, - { - "cell_type": "markdown", - "id": "f523aa6d-42ff-4fa1-9cc2-ea778aa1bc1b", - "metadata": {}, - "source": [ - "We can set values easily just by indexing the tensordict:" - ] - }, - { - "cell_type": "markdown", - "id": "3148a34a", - "metadata": {}, - "source": [ - "#### Masking" - ] - }, - { - "cell_type": "markdown", - "id": "b8c2075f", - "metadata": {}, - "source": [ - "### Masking\n", - "We can perform masking on the indexes. Mask must be a tensor." - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "7c6bf5e1", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([6, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([6, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([6]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]]).bool()\n", - "tensordict[mask]" - ] - }, - { - "cell_type": "markdown", - "id": "e9835c39", - "metadata": {}, - "source": [ - "TensorDict support other tensor operations such as torch.cat, reshape, undind(dim), view(\\*shape), squeeze(dim), unsqueeze(dim), permute(\\*dims) requiring the operations to comply with the batch_size" - ] - }, - { - "cell_type": "markdown", - "id": "c5ec195a", - "metadata": {}, - "source": [ - "### View\n", - "Support for the view operation returning a `ViewedTensorDict`. 
Use `to_tensordict` to comeback to retrieve TensorDict" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "d909b9cd", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ViewedTensorDict(\n", - "\tsource=TensorDict(\n", - "\t fields={\n", - "\t a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),\n", - "\t b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},\n", - "\t batch_size=torch.Size([3, 4]),\n", - "\t device=cpu,\n", - "\t is_shared=False), \n", - "\top=view(size=torch.Size([-1])))" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tensordict.view(-1)" - ] - }, - { - "cell_type": "markdown", - "id": "6cf6665c", - "metadata": {}, - "source": [ - "#### Permute" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "f1774e04", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "PermutedTensorDict(\n", - "\tsource=TensorDict(\n", - "\t fields={\n", - "\t a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),\n", - "\t b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},\n", - "\t batch_size=torch.Size([3, 4]),\n", - "\t device=cpu,\n", - "\t is_shared=False), \n", - "\top=permute(dims=(1, 0)))" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tensordict.permute(1,0)" - ] - }, - { - "cell_type": "markdown", - "id": "e57c354e", - "metadata": {}, - "source": [ - "#### Reshape\n", - "Reshape allows reshaping the tensordict batch size" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "c9b3ab59", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([12, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([12, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([12]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tensordict.reshape(-1)" - ] - }, - { - "cell_type": "markdown", - "id": "48933b38", - "metadata": {}, - "source": [ - "#### Unbind and Cat\n", - "TensorDict can unbind and cat among a dim over the tensordict batch size" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "68e8975f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([12, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([12, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([12]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#Cat\n", - "list_tensordict = tensordict.unbind(0)\n", - "torch.cat(list_tensordict, dim=0)" - ] - }, - { - "cell_type": "markdown", - "id": "3a7a3f68", - "metadata": {}, - "source": [ - "#### Squeeze and Unsqueeze\n", - "Tensordict also supports squeeze and unsqueeze. 
Use `to_tensordict` to retrieve a regular tensordict" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "0c3a54a0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([1, 3, 4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([1, 3, 4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([1, 3, 4]),\n", - " device=cpu,\n", - " is_shared=False)\n", - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([3, 4]),\n", - " device=cpu,\n", - " is_shared=False)\n" - ] - } - ], - "source": [ - "print(tensordict.unsqueeze(0).to_tensordict())\n", - "print(tensordict.squeeze(0).to_tensordict())" - ] - }, - { - "cell_type": "markdown", - "id": "eadb1bc0", - "metadata": {}, - "source": [ - "#### Stacking" - ] - }, - { - "cell_type": "markdown", - "id": "f1dff375", - "metadata": {}, - "source": [ - "TensorDict supports stacking. Stacking is done in a lazy fashion, returning a LazyStackedTensorDict." - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "8f90e26a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "LazyStackedTensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([2, 3, 4, 5]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([2, 3, 4, 1]), dtype=torch.float32)},\n", - " batch_size=torch.Size([2, 3, 4]),\n", - " device=cpu,\n", - " is_shared=False)\n" - ] - } - ], - "source": [ - "#Stack\n", - "staked_tensordict = torch.stack([tensordict, tensordict.clone()], dim=0)\n", - "print(staked_tensordict)\n", - "# stacking is lazy: indexing returns the original tensordict, not a copy\n", - "assert staked_tensordict[0] is tensordict" - ] - }, - { - "cell_type": "markdown", - "id": "22161bd3", - "metadata": {}, - "source": [ - "If we want a contiguous tensordict, we can call `.to_tensordict()` or `.contiguous()`. It is recommended to perform this operation before accessing the values of the stacked tensordict, for efficiency purposes." - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "b8c66c4a", - "metadata": {}, - "outputs": [], - "source": [ - "assert isinstance(staked_tensordict.contiguous(), TensorDict)\n", - "assert isinstance(staked_tensordict.to_tensordict(), TensorDict)" - ] - }, - { - "cell_type": "markdown", - "id": "e86afd08", - "metadata": {}, - "source": [ - "## How to use them in practice? The TensorDictModule" - ] - }, - { - "cell_type": "markdown", - "id": "78192367", - "metadata": {}, - "source": [ - "Now that we have seen the TensorDict object, how do we use it in practice? We introduce the TensorDictModule. The TensorDictModule is an nn.Module that takes a TensorDict in its forward method. The user defines the keys the module reads as input and the keys under which it writes its outputs in the same TensorDict." 
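Before the worked examples, here is a minimal sketch of that contract, with arbitrary key names and sizes. The assumption made in this sketch (and checked by the asserts) is that the module reads its `in_keys`, applies the wrapped module, writes the result under `out_keys`, and hands back the very tensordict it received:

```python
import torch
import torch.nn as nn
from torchrl.data import TensorDict
from torchrl.modules import TensorDictModule

td = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
module = TensorDictModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"])
out = module(td)
# assumed behaviour: the input tensordict is modified in place and returned
assert out is td
assert out["a_out"].shape == torch.Size([5, 10])
```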
- ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "7a9c34d7", - "metadata": {}, - "outputs": [], - "source": [ - "from torchrl.modules import TensorDictModule\n", - "import torch.nn as nn" - ] - }, - { - "cell_type": "markdown", - "id": "2198427a", - "metadata": {}, - "source": [ - "### Example: Simple Linear layer" - ] - }, - { - "cell_type": "markdown", - "id": "a9286f32", - "metadata": {}, - "source": [ - "Let's imagine we have 2 entries Tensor dict, a and b and we only want to affect a." - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "f42fd847", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n", - " a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},\n", - " batch_size=torch.Size([5]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 41, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tensordict = TensorDict({\"a\": torch.randn(5, 3), \"b\": torch.randn(5, 4, 3)}, batch_size=[5])\n", - "linear = TensorDictModule(nn.Linear(3, 10),in_keys=[\"a\"], out_keys=[\"a_out\"])\n", - "linear(tensordict)" - ] - }, - { - "cell_type": "markdown", - "id": "f1452e41", - "metadata": {}, - "source": [ - "We can also do it inplace" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "c2e6db70", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([5, 10]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},\n", - " batch_size=torch.Size([5]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tensordict = TensorDict({\"a\":torch.randn(5, 3), \"b\":torch.randn(5, 4, 3)}, batch_size=[5])\n", - "linear = TensorDictModule(nn.Linear(3, 10),in_keys=[\"a\"], out_keys=[\"a\"])\n", - "linear(tensordict)" - ] - }, - { - "cell_type": "markdown", - "id": "7fbae8ef", - "metadata": {}, - "source": [ - "### Example: 2 input merging with 2 linear layer" - ] - }, - { - "cell_type": "markdown", - "id": "a4c36399", - "metadata": {}, - "source": [ - "Now lets imagine a more complex network that takes 2 entries and average them into a single output" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "f4692e76", - "metadata": {}, - "outputs": [], - "source": [ - "class MergeLinear(nn.Module):\n", - " def __init__(self, in_1, in_2, out):\n", - " super().__init__()\n", - " self.linear_1 = nn.Linear(in_1,out)\n", - " self.linear_2 = nn.Linear(in_2,out)\n", - " def forward(self, x_1, x_2):\n", - " return (self.linear_1(x_1) + self.linear_2(x_2))/2" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "55f66f45", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),\n", - " c: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n", - " output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n", - " batch_size=torch.Size([5]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tensordict = TensorDict({\"a\":torch.randn(5,3), 
\"b\":torch.randn(5,4,3), \"c\":torch.randn(5,4)}, batch_size=[5])\n", - "mergelinear = TensorDictModule(MergeLinear(3, 4, 10),in_keys=[\"a\",\"c\"], out_keys=[\"output\"])\n", - "mergelinear(tensordict)" - ] - }, - { - "cell_type": "markdown", - "id": "f13b043e", - "metadata": {}, - "source": [ - "### Example: 1 input to 2 outputs linear layer\n", - "We can also map to multiple outputs" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "c3dc55f1", - "metadata": {}, - "outputs": [], - "source": [ - "class MultiHeadLinear(nn.Module):\n", - " def __init__(self, in_1, out_1, out_2):\n", - " super().__init__()\n", - " self.linear_1 = nn.Linear(in_1,out_1)\n", - " self.linear_2 = nn.Linear(in_1,out_2)\n", - " def forward(self, x):\n", - " return self.linear_1(x), self.linear_2(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "52fc40c4", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),\n", - " output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n", - " output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n", - " batch_size=torch.Size([5]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tensordict = TensorDict({\"a\":torch.randn(5,3), \"b\":torch.randn(5,4,3)}, batch_size=[5])\n", - "mergelinear = TensorDictModule(MultiHeadLinear(3, 4, 10),in_keys=[\"a\"], out_keys=[\"output_1\", \"output_2\"])\n", - "mergelinear(tensordict)" - ] - }, - { - "cell_type": "markdown", - "id": "a40b756d", - "metadata": {}, - "source": [ - "As we shown previously, the TensorDictModule can take any nn.Module and perform the operations inside a TensorDict. When having multiple input keys and output keys, make sure they match the order in the module.\n", - "The tensordictmodule allows to use only the tensors that we want and keep the output inside the same object. It can even perform the operations inplace by setting the output key to be the same as an already set key." - ] - }, - { - "cell_type": "markdown", - "id": "897a7533", - "metadata": {}, - "source": [ - "### Example: A transformer with TensorDict?\n", - "Let's attempt to create a transformer with TensorDict and TensorDictModule\n", - "\n", - "Disclaimer: This implementation don't claim to be \"better\" than a classical tensor-based implementation. 
It is just meant to showcase the TensorDictModule features.\n", - "For simplicity we will not have positional encoders.\n", - "\n", - "Let's first implement the classical transformers blocks" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "05e6a450", - "metadata": {}, - "outputs": [], - "source": [ - "class TokensToQKV(nn.Module):\n", - " def __init__(self, to_dim, from_dim, latent_dim):\n", - " super().__init__()\n", - " self.q = nn.Linear(to_dim, latent_dim)\n", - " self.k = nn.Linear(from_dim, latent_dim)\n", - " self.v = nn.Linear(from_dim, latent_dim)\n", - " def forward(self, X_to, X_from):\n", - " Q = self.q(X_to)\n", - " K = self.k(X_from)\n", - " V = self.v(X_from)\n", - " return Q, K, V\n", - "\n", - "class SplitHeads(nn.Module):\n", - " def __init__(self, num_heads):\n", - " super().__init__()\n", - " self.num_heads = num_heads\n", - " def forward(self, Q, K, V):\n", - " batch_size, to_num, latent_dim = Q.shape\n", - " _, from_num, _ = K.shape\n", - " d_tensor = latent_dim // self.num_heads\n", - " Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)\n", - " K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)\n", - " V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)\n", - " return Q, K, V\n", - "class Attention(nn.Module):\n", - " def __init__(self, latent_dim, to_dim):\n", - " super().__init__()\n", - " self.softmax = nn.Softmax(dim=-1)\n", - " self.out = nn.Linear(latent_dim, to_dim)\n", - " def forward(self, Q, K, V):\n", - " batch_size, n_heads, to_num, d_in = Q.shape\n", - " attn = self.softmax(Q @ K.transpose(2,3) / d_in)\n", - " out = attn @ V\n", - " out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads*d_in))\n", - " return out, attn\n", - "class SkipLayerNorm(nn.Module):\n", - " def __init__(self, to_len, to_dim):\n", - " super().__init__()\n", - " self.layer_norm = nn.LayerNorm((to_len, to_dim))\n", - " def forward(self, x_0, x_1):\n", - " return self.layer_norm(x_0+x_1)\n", - "class FFN(nn.Module):\n", - " def __init__(self, to_dim, hidden_dim, dropout_rate = 0.2):\n", - " super().__init__()\n", - " self.FFN = nn.Sequential(\n", - " nn.Linear(to_dim, hidden_dim),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_dim, to_dim),\n", - " nn.Dropout(dropout_rate)\n", - " )\n", - " def forward(self, X):\n", - " return self.FFN(X)\n" - ] - }, - { - "cell_type": "markdown", - "id": "f5086060", - "metadata": {}, - "source": [ - "Now, we can build the TransformerBlock thanks to the TensorDictModule. Since the changes affect the tensor dict, we just need to map outputs to the right name such as it is picked up by the next block." 
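Before the full block, the wiring can be sketched in isolation, with toy sizes and hypothetical key names: each `TensorDictModule` writes under a key that the next one lists in its `in_keys`, so a plain `nn.Sequential` is enough to chain them, just as the transformer block below does:

```python
import torch
import torch.nn as nn
from torchrl.data import TensorDict
from torchrl.modules import TensorDictModule

# the first module writes "hidden", which the second module reads
chain = nn.Sequential(
    TensorDictModule(nn.Linear(3, 8), in_keys=["x"], out_keys=["hidden"]),
    TensorDictModule(nn.ReLU(), in_keys=["hidden"], out_keys=["y"]),
)
td = TensorDict({"x": torch.randn(4, 3)}, batch_size=[4])
chain(td)  # td now carries "x", "hidden" and "y"
```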
- ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "d3f130d9", - "metadata": {}, - "outputs": [], - "source": [ - "class TransformerBlockTensorDict(nn.Module):\n", - " def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):\n", - " super().__init__()\n", - " self.transformer_block = nn.Sequential(\n", - " TensorDictModule(TokensToQKV(to_dim, from_dim, latent_dim), in_keys=[to_name, from_name], out_keys=[\"Q\", \"K\", \"V\"]),\n", - " TensorDictModule(SplitHeads(num_heads), in_keys=[\"Q\", \"K\", \"V\"], out_keys=[\"Q\", \"K\", \"V\"]),\n", - " TensorDictModule(Attention(latent_dim, to_dim), in_keys=[\"Q\", \"K\", \"V\"], out_keys=[\"X_out\",\"Attn\"]),\n", - " TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[\"X_to\", \"X_out\"], out_keys=[\"X_to\"]),\n", - " TensorDictModule(FFN(to_dim, 4*to_dim), in_keys=[\"X_to\"], out_keys=[\"X_out\"]),\n", - " TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[\"X_to\", \"X_out\"], out_keys=[\"X_to\"]),\n", - " )\n", - " def forward(self, X_tensor_dict):\n", - " self.transformer_block(X_tensor_dict)\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "c2d2a519", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " Attn: Tensor(torch.Size([8, 2, 3, 10]), dtype=torch.float32),\n", - " K: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),\n", - " Q: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),\n", - " V: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),\n", - " X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),\n", - " X_out: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),\n", - " X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},\n", - " batch_size=torch.Size([8]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 49, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "to_dim = 5\n", - "from_dim = 6\n", - "latent_dim = 10\n", - "to_len = 3\n", - "from_len = 10\n", - "batch_size = 8\n", - "num_heads = 2\n", - "\n", - "tokens = TensorDict({\"X_to\":torch.randn(batch_size, to_len, to_dim), \"X_from\":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])\n", - "\n", - "transformer_block = TransformerBlockTensorDict(\"X_to\", \"X_from\", to_dim, to_len, from_dim, latent_dim, num_heads)\n", - "\n", - "transformer_block(tokens)\n", - "\n", - "tokens" - ] - }, - { - "cell_type": "markdown", - "id": "e513287f", - "metadata": {}, - "source": [ - "The output of the transformer layer can now be found at tokens[\"X_to\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "d58f4d89", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[ 1.3948, 0.9014, -0.3433, 0.7251, -1.0122],\n", - " [ 0.1403, 1.2825, -1.2302, -1.4635, 0.7165],\n", - " [ 0.8251, -1.5065, -1.0976, -0.0920, 0.7596]],\n", - "\n", - " [[ 1.5074, -0.4161, 0.5480, 1.1882, -1.3595],\n", - " [ 0.4081, -0.2427, 1.0663, 1.1888, -1.2557],\n", - " [ 0.1856, -0.5556, -1.4885, 0.5571, -1.3313]],\n", - "\n", - " [[-0.1015, -1.1877, -0.1132, -0.9613, -0.7382],\n", - " [-0.9579, 1.7701, 1.8217, 0.5509, -0.6816],\n", - " [ 1.6410, 0.2096, -0.6056, -0.8921, 0.2459]],\n", - "\n", - " [[-0.3648, 0.2833, 0.2137, 0.8552, -0.9622],\n", - " [ 0.3903, -1.2456, -0.6480, -0.4978, 0.8822],\n", - " [-0.2155, 1.4676, 0.3733, 1.6645, -2.1963]],\n", - "\n", - " [[ 0.3817, -1.3816, -0.0126, 1.1476, 1.8669],\n", - " [ 0.5688, 0.9381, 
-1.3945, 0.6863, -1.5389],\n", - " [-0.2369, 0.7993, -1.0803, -0.4629, -0.2810]],\n", - "\n", - " [[ 0.7018, 1.2853, -0.2395, 1.8107, 0.1803],\n", - " [-1.4307, -0.3072, 0.6677, 0.5036, -1.2250],\n", - " [ 0.6976, 0.3290, -2.0245, -0.5033, -0.4459]],\n", - "\n", - " [[ 0.1244, -0.1468, 0.2397, 0.2596, -0.5887],\n", - " [ 1.0288, 1.4734, -1.6026, 0.9561, -0.8791],\n", - " [ 0.0268, 0.7016, 0.7880, 0.0460, -2.4271]],\n", - "\n", - " [[-2.4209, 0.3694, 1.0487, 0.8913, 0.6070],\n", - " [ 1.2738, -0.4474, 0.0166, -0.1589, 0.2163],\n", - " [-1.1257, 0.6097, 0.2396, -1.6827, 0.5633]]],\n", - " grad_fn=)" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tokens[\"X_to\"]" - ] - }, - { - "cell_type": "markdown", - "id": "bc2a210d", - "metadata": {}, - "source": [ - "We can now create a transformer easily" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "id": "d6241e1f", - "metadata": {}, - "outputs": [], - "source": [ - "class TransformerTensorDict(nn.Module):\n", - " def __init__(self, num_blocks, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):\n", - " super().__init__()\n", - " self.transformer = nn.ModuleList([TransformerBlockTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])\n", - " def forward(self, X_tensor_dict):\n", - " for transformer_block in self.transformer:\n", - " transformer_block(X_tensor_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "id": "544f6335", - "metadata": {}, - "outputs": [], - "source": [ - "to_dim = 5\n", - "from_dim = 6\n", - "latent_dim = 10\n", - "to_len = 3\n", - "from_len = 10\n", - "batch_size = 8\n", - "num_heads = 2\n", - "\n", - "tokens = TensorDict({\"X_to\":torch.randn(batch_size, to_len, to_dim), \"X_from\":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "id": "a704bddb", - "metadata": {}, - "source": [ - "For an encoder, we just need to take the same tokens for both queries, keys and values." 
- ] - }, - { - "cell_type": "code", - "execution_count": 53, - "id": "52dbf4fb", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[ 0.3641, -0.0979, -0.0854, -0.4905, -2.0029],\n", - " [-1.5806, 0.5326, 0.6632, 0.7920, -1.1313],\n", - " [ 0.5733, -0.1321, 2.1497, 0.6715, -0.2256]],\n", - "\n", - " [[-0.5301, 1.3014, 0.8339, -1.6371, -0.0447],\n", - " [-0.1658, 0.9115, -0.1586, -0.1376, -1.6049],\n", - " [-0.1376, 0.6356, 1.5934, 0.7616, -1.6211]],\n", - "\n", - " [[-0.7808, 0.1323, 0.3524, -1.8699, -1.2661],\n", - " [-0.3532, 0.0802, 1.8762, -0.0252, -1.0859],\n", - " [ 0.4886, 1.1930, 1.5847, 0.0510, -0.3773]],\n", - "\n", - " [[-0.5639, 0.4097, 1.2614, -1.1658, 0.8153],\n", - " [ 0.9616, -1.0353, 0.8457, 0.3561, -1.7874],\n", - " [-0.3008, -0.3083, 1.2756, 0.7932, -1.5570]],\n", - "\n", - " [[ 1.6996, 0.6104, 0.7166, -1.4791, -0.0746],\n", - " [-1.3742, -0.0917, 0.3204, 0.4544, -1.6289],\n", - " [ 0.7167, 0.8835, 1.0875, -0.8911, -0.9494]],\n", - "\n", - " [[-0.3857, 1.2185, 0.8044, -1.2474, -0.6684],\n", - " [-2.6451, 0.3813, 0.8183, 0.2206, -0.0257],\n", - " [-0.2125, 0.9058, 1.3390, -0.0635, -0.4396]],\n", - "\n", - " [[ 0.4440, -0.4641, 1.3416, 0.4208, 0.0535],\n", - " [-0.0167, 0.5248, 0.4601, 1.1979, -2.6929],\n", - " [-0.1802, 0.8493, 0.2579, -1.4159, -0.7799]],\n", - "\n", - " [[ 0.0957, 0.2631, 1.9058, -0.7983, -2.1256],\n", - " [-0.8988, -0.3532, 1.4373, 0.5188, -0.8031],\n", - " [ 0.4242, -0.5689, 0.5862, 0.9721, -0.6552]]],\n", - " grad_fn=)" - ] - }, - "execution_count": 53, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "transformer_encoder = TransformerTensorDict(6, \"X_to\", \"X_to\", to_dim, to_len, to_dim, latent_dim, num_heads)\n", - "\n", - "transformer_encoder(tokens)\n", - "tokens[\"X_to\"]" - ] - }, - { - "cell_type": "markdown", - "id": "f379ac76", - "metadata": {}, - "source": [ - "For a decoder, we now can extract info from X_from into X_to. X_to will map to queries whereas X_from will map to keys and values." 
- ] - }, - { - "cell_type": "code", - "execution_count": 54, - "id": "913b3979", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[ 0.4194, 1.1577, -0.2249, -0.2299, -1.6838],\n", - " [-0.4603, 1.5326, 0.1736, 0.4194, -2.1895],\n", - " [ 0.2842, 0.8026, 1.1243, -0.1371, -0.9882]],\n", - "\n", - " [[-0.3067, 1.4502, 0.5999, -1.0000, -0.6322],\n", - " [-0.2193, 2.2219, -0.2861, -0.1852, -1.1817],\n", - " [ 0.0121, 0.7552, 0.8953, -0.3838, -1.7396]],\n", - "\n", - " [[-0.2220, 0.9376, 0.4108, -0.8514, -1.4368],\n", - " [-0.2243, 1.0650, 1.1564, -0.5117, -1.8274],\n", - " [ 0.3239, 1.7627, 0.7365, -0.2319, -1.0873]],\n", - "\n", - " [[ 0.0923, 0.1925, 2.0104, -1.0669, 0.2943],\n", - " [ 1.0389, -0.3586, 0.0279, -0.2836, -1.3698],\n", - " [-0.2385, 0.1469, 1.1481, 0.5598, -2.1937]],\n", - "\n", - " [[ 1.3624, 0.0818, 0.6500, -1.9682, -0.3359],\n", - " [-0.3614, 0.9265, 0.2931, 0.2009, -1.4732],\n", - " [ 0.2946, 1.1101, 1.3561, -1.0921, -1.0449]],\n", - "\n", - " [[ 0.1569, 1.4944, 0.1292, -1.0900, -0.8406],\n", - " [-1.6930, 0.3846, 0.9151, 0.1136, -1.2707],\n", - " [ 0.4416, 1.6022, 1.1814, -0.6376, -0.8870]],\n", - "\n", - " [[ 0.6334, -0.2367, 1.1179, -0.5009, -0.3828],\n", - " [ 0.0639, 1.2607, 0.4867, 0.1112, -2.5056],\n", - " [ 0.6137, 1.0620, 0.6164, -1.4468, -0.8929]],\n", - "\n", - " [[-0.1152, 1.1531, 1.4088, -0.6771, -2.7837],\n", - " [-0.1132, 0.3721, 1.0719, 0.0737, -1.1365],\n", - " [ 0.3906, 0.4455, 0.2721, 0.2531, -0.6153]]],\n", - " grad_fn=)" - ] - }, - "execution_count": 54, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "transformer_decoder = TransformerTensorDict(6, \"X_to\", \"X_from\", to_dim, to_len, from_dim, latent_dim, num_heads)\n", - "\n", - "transformer_decoder(tokens)\n", - "tokens[\"X_to\"]" - ] - }, - { - "cell_type": "markdown", - "id": "4c3698ba", - "metadata": {}, - "source": [ - "Now we can look at both models:" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "id": "f9c19fc1", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TransformerTensorDict(\n", - " (transformer): ModuleList(\n", - " (0): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " 
in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (1): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (2): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, 
elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (3): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (4): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (5): 
TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " )\n", - ")" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "transformer_encoder" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "4fb30839", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TransformerTensorDict(\n", - " (transformer): ModuleList(\n", - " (0): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " 
(layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (1): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (2): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " 
out_keys=['X_to'])\n", - " )\n", - " )\n", - " (3): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (4): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " )\n", - " )\n", - " (5): TransformerBlockTensorDict(\n", - " (transformer_block): Sequential(\n", - " (0): TensorDictModule(\n", - " 
module=TokensToQKV(\n", - "            (q): Linear(in_features=5, out_features=10, bias=True)\n", - "            (k): Linear(in_features=6, out_features=10, bias=True)\n", - "            (v): Linear(in_features=6, out_features=10, bias=True)\n", - "          ), \n", - "          device=cpu, \n", - "          in_keys=['X_to', 'X_from'], \n", - "          out_keys=['Q', 'K', 'V'])\n", - "        (1): TensorDictModule(\n", - "          module=SplitHeads(), \n", - "          device=cpu, \n", - "          in_keys=['Q', 'K', 'V'], \n", - "          out_keys=['Q', 'K', 'V'])\n", - "        (2): TensorDictModule(\n", - "          module=Attention(\n", - "            (softmax): Softmax(dim=-1)\n", - "            (out): Linear(in_features=10, out_features=5, bias=True)\n", - "          ), \n", - "          device=cpu, \n", - "          in_keys=['Q', 'K', 'V'], \n", - "          out_keys=['X_out', 'Attn'])\n", - "        (3): TensorDictModule(\n", - "          module=SkipLayerNorm(\n", - "            (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - "          ), \n", - "          device=cpu, \n", - "          in_keys=['X_to', 'X_out'], \n", - "          out_keys=['X_to'])\n", - "        (4): TensorDictModule(\n", - "          module=FFN(\n", - "            (FFN): Sequential(\n", - "              (0): Linear(in_features=5, out_features=20, bias=True)\n", - "              (1): ReLU()\n", - "              (2): Linear(in_features=20, out_features=5, bias=True)\n", - "              (3): Dropout(p=0.2, inplace=False)\n", - "            )\n", - "          ), \n", - "          device=cpu, \n", - "          in_keys=['X_to'], \n", - "          out_keys=['X_out'])\n", - "        (5): TensorDictModule(\n", - "          module=SkipLayerNorm(\n", - "            (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - "          ), \n", - "          device=cpu, \n", - "          in_keys=['X_to', 'X_out'], \n", - "          out_keys=['X_to'])\n", - "      )\n", - "    )\n", - "  )\n", - ")" - ] - }, - "execution_count": 56, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "transformer_decoder" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 38e8bebb6736db965a2789a9cb066f5bc2f7711b Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Wed, 13 Jul 2022 10:19:54 +0100 Subject: [PATCH 08/21] Details fixed --- tutorials/tensordictmodule.ipynb | 1691 ++++-------------------------- 1 file changed, 193 insertions(+), 1498 deletions(-) diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb index e200d1c1a56..4a4b8f2e630 100644 --- a/tutorials/tensordictmodule.ipynb +++ b/tutorials/tensordictmodule.ipynb @@ -2,32 +2,34 @@ "cells": [ { "cell_type": "markdown", - "id": "2b194ee2", + "id": "8a90f5b4", "metadata": {}, "source": [ - "# The TensorDictModule" + "# TensorDictModule" ] }, { "cell_type": "markdown", - "id": "fc268c66", + "id": "a31301d9", "metadata": {}, "source": [ - "Make sure to first read the tensordict tutorial" + "We recommend reading the TensorDict tutorial before going through this one." ] }, { "cell_type": "markdown", - "id": "b16eeedd", + "id": "9206ae33", "metadata": {}, "source": [ - "How do we use the TensorDict it in pratice? We introduce the TensorDictModule. The TensorDictModule is an nn.Module that takes a TensorDict in his forward method. The user defines the keys that the module will take as an input and write the output in the same TensorDict at a given set of key."
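To make the calling convention concrete, here is a simplified sketch of what happens when a `TensorDictModule` is called. This is a rough reconstruction based on the `TensorDictModule.forward` code visible in a traceback further down in this patch; the real implementation also handles `tensordict_out`, functional parameters and `vmap`:

    # Simplified sketch of TensorDictModule.forward, for intuition only
    # (not the actual TorchRL source).
    def forward(self, tensordict):
        # Gather the inputs: a key missing from the tensordict silently
        # yields None, which then fails inside the wrapped module (see the
        # "linear(): argument 'input' ... not NoneType" error below).
        tensors = tuple(tensordict.get(in_key, None) for in_key in self.in_keys)
        out = self.module(*tensors)
        if not isinstance(out, tuple):
            out = (out,)
        # Write each result back into the same tensordict, in out_keys order.
        for out_key, tensor in zip(self.out_keys, out):
            tensordict.set(out_key, tensor)
        return tensordict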
+ "For a convenient usage of the `TensorDict` class with `nn.Module`, TorchRL provides an interface between the two named `TensorDictModule`.
\n", + "The `TensorDictModule` class is an `nn.Module` that takes a `TensorDict` as input when called.
\n", + "It is up to the user to define the keys to be read as input and output." ] }, { "cell_type": "code", - "execution_count": 2, - "id": "b2e43fd6", + "execution_count": 1, + "id": "4d5d1a6d", "metadata": {}, "outputs": [], "source": [ @@ -39,24 +41,24 @@ }, { "cell_type": "markdown", - "id": "a46be707", + "id": "271e8714", "metadata": {}, "source": [ - "### Example: Simple Linear layer" + "### Example 1: Simple usage" ] }, { "cell_type": "markdown", - "id": "2a7e3f58", + "id": "db3d429d", "metadata": {}, "source": [ - "Let's imagine we have 2 entries Tensor dict, a and b and we only want to affect a." + "Let's imagine we have 2 entries `TensorDict`, a and b and we only want to pass a to our network." ] }, { "cell_type": "code", - "execution_count": 3, - "id": "6c688617", + "execution_count": 16, + "id": "989a3bea", "metadata": {}, "outputs": [ { @@ -72,90 +74,94 @@ " is_shared=False)" ] }, - "execution_count": 3, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "tensordict = TensorDict({\"a\": torch.randn(5, 3), \"b\": torch.randn(5, 4, 3)}, batch_size=[5])\n", + "tensordict = TensorDict(\n", + " {\n", + " \"a\": torch.randn(5, 3),\n", + " \"b\": torch.zeros(5, 4, 3)\n", + " },\n", + " batch_size=[5]\n", + ")\n", "linear = TensorDictModule(nn.Linear(3, 10),in_keys=[\"a\"], out_keys=[\"a_out\"])\n", - "linear(tensordict)" + "linear(tensordict)\n", + "assert (tensordict[\"b\"] == torch.zeros(5, 4, 3)).all()\n", + "tensordict" ] }, { "cell_type": "markdown", - "id": "1f71e15f", + "id": "cde8dbcb", "metadata": {}, "source": [ - "We can also do it inplace" + "We can also do it inplace:" ] }, { "cell_type": "code", - "execution_count": 4, - "id": "4a33bfae", + "execution_count": 18, + "id": "23128445", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([5, 10]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},\n", - " batch_size=torch.Size([5]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "tensordict = TensorDict({\"a\": torch.randn(5, 3), \"b\": torch.randn(5, 4, 3)}, batch_size=[5])\n", - "linear = TensorDictModule(nn.Linear(3, 10),in_keys=[\"a\"], out_keys=[\"a\"])\n", - "linear(tensordict)" + "tensordict = TensorDict(\n", + " {\n", + " \"a\": torch.randn(5, 3),\n", + " \"b\": torch.zeros(5, 4, 3)\n", + " },\n", + " batch_size=[5]\n", + ")\n", + "\n", + "linear = TensorDictModule(\n", + " nn.Linear(3, 10),\n", + " in_keys=[\"a\"], \n", + " out_keys=[\"a\"]\n", + ")\n", + "\n", + "linear(tensordict)\n", + "assert tensordict[\"a\"].shape == torch.Size([5,10])" ] }, { "cell_type": "markdown", - "id": "963909a1", + "id": "e88255a9", "metadata": {}, "source": [ - "### Example: 2 input merging with 2 linear layer" + "### Example 2: Multiple inputs" ] }, { "cell_type": "markdown", - "id": "b2a6f397", + "id": "63714f8c", "metadata": {}, "source": [ - "Now lets imagine a more complex network that takes 2 entries and average them into a single output" + "Now lets imagine a more complex network that takes 2 entries and average them into a single output:" ] }, { "cell_type": "code", - "execution_count": 5, - "id": "41f0e812", + "execution_count": 4, + "id": "1e0cce94", "metadata": {}, "outputs": [], "source": [ "class MergeLinear(nn.Module):\n", " def __init__(self, in_1, in_2, out):\n", " super().__init__()\n", - " 
self.linear_1 = nn.Linear(in_1,out)\n", - " self.linear_2 = nn.Linear(in_2,out)\n", + " self.linear_1 = nn.Linear(in_1, out)\n", + " self.linear_2 = nn.Linear(in_2, out)\n", " def forward(self, x_1, x_2):\n", " return (self.linear_1(x_1) + self.linear_2(x_2))/2" ] }, { "cell_type": "code", - "execution_count": 6, - "id": "c6d3ddbb", + "execution_count": 5, + "id": "29c6242b", "metadata": {}, "outputs": [ { @@ -172,46 +178,58 @@ " is_shared=False)" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "tensordict = TensorDict({\"a\": torch.randn(5, 3), \"b\": torch.randn(5, 4, 3), \"c\":torch.randn(5, 4)}, batch_size=[5])\n", - "mergelinear = TensorDictModule(MergeLinear(3, 4, 10),in_keys=[\"a\",\"c\"], out_keys=[\"output\"])\n", + "tensordict = TensorDict(\n", + " {\n", + " \"a\": torch.randn(5, 3),\n", + " \"b\": torch.randn(5, 4),\n", + " }, \n", + " batch_size=[5]\n", + ")\n", + "\n", + "mergelinear = TensorDictModule(\n", + " MergeLinear(3, 4, 10),\n", + " in_keys=[\"a\", \"b\"],\n", + " out_keys=[\"output\"]\n", + ")\n", + " \n", "mergelinear(tensordict)" ] }, { "cell_type": "markdown", - "id": "ee41e873", + "id": "c831e815", "metadata": {}, "source": [ - "### Example: 1 input to 2 outputs linear layer\n", + "### Example 3: Multiple outputs\n", "We can also map to multiple outputs" ] }, { "cell_type": "code", - "execution_count": 7, - "id": "b86d4867", + "execution_count": 6, + "id": "ff71ceb2", "metadata": {}, "outputs": [], "source": [ "class MultiHeadLinear(nn.Module):\n", " def __init__(self, in_1, out_1, out_2):\n", " super().__init__()\n", - " self.linear_1 = nn.Linear(in_1,out_1)\n", - " self.linear_2 = nn.Linear(in_1,out_2)\n", + " self.linear_1 = nn.Linear(in_1, out_1)\n", + " self.linear_2 = nn.Linear(in_1, out_2)\n", " def forward(self, x):\n", " return self.linear_1(x), self.linear_2(x)" ] }, { "cell_type": "code", - "execution_count": 8, - "id": "0f132410", + "execution_count": 7, + "id": "de83ddac", "metadata": {}, "outputs": [ { @@ -228,38 +246,44 @@ " is_shared=False)" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "tensordict = TensorDict({\"a\": torch.randn(5, 3), \"b\": torch.randn(5, 4, 3)}, batch_size=[5])\n", - "mergelinear = TensorDictModule(MultiHeadLinear(3, 4, 10),in_keys=[\"a\"], out_keys=[\"output_1\", \"output_2\"])\n", - "mergelinear(tensordict)" + "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n", + "\n", + "splitlinear = TensorDictModule(\n", + " MultiHeadLinear(3, 4, 10),\n", + " in_keys=[\"a\"],\n", + " out_keys=[\"output_1\", \"output_2\"]\n", + ")\n", + "splitlinear(tensordict)" ] }, { "cell_type": "markdown", - "id": "0b852dff", + "id": "f31279e8", "metadata": {}, "source": [ - "As we shown previously, the TensorDictModule can take any nn.Module and perform the operations inside a TensorDict. When having multiple input keys and output keys, make sure they match the order in the module.\n", - "The tensordictmodule allows to use only the tensors that we want and keep the output inside the same object. It can even perform the operations inplace by setting the output key to be the same as an already set key." + "As shown previously, the `TensorDictModule` can take any `nn.Module` and perform the operations on a `TensorDict`. 
When there are multiple input and output keys, make sure they match the order expected by the module.\n", + "`TensorDictModule` can work with `TensorDict` instances that contain more tensors than what the `in_keys` attribute indicates. Unless a `vmap` operator is used, the `TensorDict` is modified in-place." ] }, { "cell_type": "markdown", - "id": "078a538a", + "id": "503a1e34", "metadata": {}, "source": [ - "### Example: A transformer with TensorDict?\n", - "Let's attempt to create a transformer with TensorDict and TensorDictModule.\n", + "### Example 4: A transformer with TensorDict?\n", + "Let's attempt to create a transformer with `TensorDict` and `TensorDictModule`.\n", "\n", "Here's a diagram that sums up the architecture:\n", "\n", "\n", + "\n", - "Disclaimer: This implementation don't claim to be \"better\" than a classical tensor-based implementation. It is just meant to showcase the TensorDictModule features.\n", + "Disclaimer: This implementation does not claim to be \"better\" than a classical tensor-based implementation. It is just meant to showcase the `TensorDictModule` features.\n", "For simplicity we will not have positional encoders.\n", "\n", "Let's first implement the classical transformer blocks." ] }, { "cell_type": "code", - "execution_count": 36, - "id": "ba74542a", + "execution_count": 8, + "id": "f039f790", "metadata": {}, "outputs": [], "source": [ @@ -328,16 +352,16 @@ }, { "cell_type": "markdown", - "id": "9651d3c2", + "id": "8569fc1d", "metadata": {}, "source": [ - "Now, we can build the TransformerBlock thanks to the TensorDictModule. Since the changes affect the tensor dict, we just need to map outputs to the right name such as it is picked up by the next block." + "Now, we can build the encoder and decoder blocks that will be part of the transformer thanks to `TensorDictModule`. Since the changes affect the `TensorDict`, we just need to map each output to the right name so that it is picked up by the next block."
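Before looking at the attention blocks themselves, here is a minimal, hypothetical sketch of this key plumbing. The key names `hidden` and `out` are illustrative only, and we assume `TensorDictSequence` is imported alongside `TensorDictModule` (as it is in the classes below):

    # Two chained modules: the first writes "hidden", which the second
    # declares as its in_key, so data flows from block to block purely
    # through key names.
    chain = TensorDictSequence(
        TensorDictModule(nn.Linear(3, 8), in_keys=["a"], out_keys=["hidden"]),
        TensorDictModule(nn.Linear(8, 2), in_keys=["hidden"], out_keys=["out"]),
    )
    tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
    chain(tensordict)  # tensordict now also contains "hidden" and "out"

The blocks below follow the same pattern, writing back to "X_to" so that each block reads the output of the previous one.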
] }, { "cell_type": "code", - "execution_count": 37, - "id": "afd16e7e", + "execution_count": 9, + "id": "8f1496d2", "metadata": {}, "outputs": [], "source": [ @@ -366,8 +390,8 @@ }, { "cell_type": "code", - "execution_count": 38, - "id": "eba41ed9", + "execution_count": 10, + "id": "06571f56", "metadata": {}, "outputs": [ { @@ -387,7 +411,7 @@ " is_shared=False)" ] }, - "execution_count": 38, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -426,7 +450,7 @@ }, { "cell_type": "markdown", - "id": "a89c3eea", + "id": "19a1c3b1", "metadata": {}, "source": [ "The output of the transformer layer can now be found at tokens[\"X_to\"]" @@ -434,48 +458,48 @@ }, { "cell_type": "code", - "execution_count": 39, - "id": "cb06a532", + "execution_count": 11, + "id": "cf2662f7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-0.2024, -1.0720, 0.0880, 0.1581, 1.3610],\n", - " [-0.5052, -1.8012, 0.7504, -0.6252, -1.9845],\n", - " [ 0.7721, 0.9715, 0.5495, 1.2725, 0.2672]],\n", + "tensor([[[-4.1014e-01, -3.8609e-02, -4.6190e-01, -1.1984e+00, -1.0029e+00],\n", + " [ 4.8135e-01, 2.1887e+00, 1.0424e+00, 2.3389e-01, 2.2858e-01],\n", + " [ 9.2826e-02, 1.7278e+00, -1.3197e+00, -8.0110e-01, -7.6280e-01]],\n", "\n", - " [[ 0.5293, -0.4636, -1.3908, -0.7254, 0.7734],\n", - " [-1.7271, 1.1405, 0.5576, -0.3882, 0.0643],\n", - " [ 0.6418, 1.0622, 1.1772, -1.8478, 0.5967]],\n", + " [[-8.1625e-01, 9.2371e-01, -3.4789e-01, -1.7687e+00, -5.5227e-01],\n", + " [-4.1375e-01, -6.0962e-01, -6.1183e-01, 2.3965e-01, 1.3284e+00],\n", + " [-7.7074e-01, 6.0351e-01, -2.2673e-01, 2.4082e+00, 6.1431e-01]],\n", "\n", - " [[ 0.4203, 0.1880, -1.6405, -1.6656, 0.4852],\n", - " [ 0.5613, 0.6962, 0.3396, 0.1042, 1.3068],\n", - " [-0.0636, -0.3054, 1.7350, -1.7914, -0.3700]],\n", + " [[ 1.1999e+00, 9.0440e-01, 1.9596e-01, -1.7704e+00, 4.7291e-01],\n", + " [ 1.2238e+00, 1.9033e-01, 2.7292e-01, -2.5550e-01, -1.3451e+00],\n", + " [ 1.0397e+00, 1.1162e+00, -1.2123e+00, -1.0223e+00, -1.0106e+00]],\n", "\n", - " [[-0.6542, 0.3826, -0.9735, 1.6878, -0.2295],\n", - " [-0.6227, 0.1929, 1.3043, 1.3246, -1.2593],\n", - " [ 0.7568, -0.5468, -1.7795, -0.4934, 0.9099]],\n", + " [[ 4.9356e-01, -1.4560e-02, 2.1211e-02, 8.7976e-01, -8.7540e-02],\n", + " [-1.4565e-01, 1.0329e+00, 5.8444e-01, 6.7036e-01, -1.5396e+00],\n", + " [ 6.7115e-01, -9.1910e-01, 1.5387e+00, -8.9005e-01, -2.2955e+00]],\n", "\n", - " [[ 0.6101, 1.1662, -0.1247, -0.0322, 0.3963],\n", - " [-1.3019, 1.8116, -0.8462, 0.8816, -1.8484],\n", - " [ 0.5544, -0.6557, -1.4419, 0.5660, 0.2648]],\n", + " [[ 1.1135e+00, 7.9529e-01, -1.6116e+00, -7.5607e-01, -8.4692e-01],\n", + " [ 6.8204e-01, -1.1673e-01, 9.6905e-01, -8.8737e-01, -1.1990e+00],\n", + " [ 1.9570e+00, -3.7739e-01, 8.7870e-01, 2.8001e-01, -8.8057e-01]],\n", "\n", - " [[-0.4518, 1.6725, -1.2902, -0.8343, 0.7091],\n", - " [-1.1318, -0.3141, -0.1082, 0.5590, 0.6859],\n", - " [ 0.7135, 1.9085, 0.0153, -1.5834, -0.5498]],\n", + " [[ 1.4367e+00, -4.3311e-01, 1.5850e+00, -5.8594e-01, -1.2526e+00],\n", + " [ 1.5208e+00, -1.0612e+00, -8.0319e-01, -1.2658e+00, 8.2322e-01],\n", + " [-1.2821e-01, 6.0637e-01, 8.4822e-01, -7.5417e-01, -5.3608e-01]],\n", "\n", - " [[ 1.4865, -0.3687, -0.8707, 0.4826, 0.2675],\n", - " [ 0.6597, 0.9290, -1.1941, -0.9715, -1.8818],\n", - " [ 1.4414, 0.7599, -0.5614, -0.9248, 0.7463]],\n", + " [[ 1.7111e-01, 5.6883e-01, 1.9995e-01, 4.3374e-01, -2.2862e+00],\n", + " [-1.4681e+00, 9.7458e-01, 2.6624e-01, -1.6163e-01, -1.2614e+00],\n", + " [ 5.9899e-01, -1.4938e-03, 1.7823e+00, 
-5.5919e-01, 7.4231e-01]],\n", "\n", - " [[ 0.3178, -0.6035, 1.0795, 0.4479, 0.5360],\n", - " [ 1.3815, 0.2269, -0.1231, -2.4940, -0.1504],\n", - " [ 1.2107, 0.1268, 0.2812, -1.6662, -0.5712]]],\n", + " [[ 3.2855e-01, -1.3951e+00, 8.6517e-01, 1.3137e+00, 6.4360e-01],\n", + " [-1.5505e-01, -1.6713e+00, 6.1532e-01, -1.2615e+00, 1.7464e+00],\n", + " [-1.1965e+00, -6.1894e-01, 4.0306e-01, 4.8236e-01, -9.9691e-02]]],\n", " grad_fn=)" ] }, - "execution_count": 39, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -486,7 +510,7 @@ }, { "cell_type": "markdown", - "id": "75e1c95c", + "id": "c1396907", "metadata": {}, "source": [ "We can now create a transformer easily" @@ -494,7 +518,7 @@ }, { "cell_type": "markdown", - "id": "5125af26", + "id": "6fcdb817", "metadata": {}, "source": [ "For an encoder, we just need to take the same tokens for both queries, keys and values." @@ -502,16 +526,16 @@ }, { "cell_type": "markdown", - "id": "f909b289", + "id": "ecc1fe0f", "metadata": {}, "source": [ - "For a decoder, we now can extract info from X_from into X_to. X_from will map to queries whereas X_from will map to keys and values." + "For a decoder, we now can extract info from `X_from` into `X_to`. `X_from` will map to queries whereas X`_from` will map to keys and values." ] }, { "cell_type": "code", - "execution_count": 40, - "id": "6b9f19ef", + "execution_count": 12, + "id": "a3e3027d", "metadata": {}, "outputs": [], "source": [ @@ -527,7 +551,17 @@ " latent_dim,\n", " num_heads\n", " ):\n", - " super().__init__(*[TransformerBlockEncoderTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])\n", + " super().__init__(\n", + " *[TransformerBlockEncoderTensorDict(\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads\n", + " ) for _ in range(num_blocks)\n", + " ])\n", "class TransformerDecoderTensorDict(TensorDictSequence):\n", " def __init__(\n", " self,\n", @@ -540,7 +574,17 @@ " latent_dim,\n", " num_heads\n", " ):\n", - " super().__init__(*[TransformerBlockDecoderTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])\n", + " super().__init__(\n", + " *[TransformerBlockDecoderTensorDict(\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads\n", + " ) for _ in range(num_blocks)\n", + " ])\n", " \n", "class TransformerTensorDict(TensorDictSequence):\n", " def __init__(\n", @@ -581,8 +625,8 @@ }, { "cell_type": "code", - "execution_count": 33, - "id": "5e22913e", + "execution_count": 13, + "id": "cc6ce12a", "metadata": {}, "outputs": [], "source": [ @@ -605,30 +649,33 @@ }, { "cell_type": "code", - "execution_count": 42, - "id": "b7b99195", + "execution_count": 14, + "id": "19f9fff2", "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " Attn: Tensor(torch.Size([8, 2, 10, 3]), dtype=torch.float32),\n", - " K: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),\n", - " Q: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),\n", - " V: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),\n", - " X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),\n", - " X_out: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),\n", - " X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},\n", - " batch_size=torch.Size([8]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - 
"execution_count": 42, - "metadata": {}, - "output_type": "execute_result" + "ename": "TypeError", + "evalue": "linear(): argument 'input' (position 1) must be Tensor, not NoneType", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Input \u001b[0;32mIn [14]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m transformer \u001b[38;5;241m=\u001b[39m TransformerTensorDict(\n\u001b[1;32m 2\u001b[0m \u001b[38;5;241m6\u001b[39m,\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_to\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 9\u001b[0m num_heads\n\u001b[1;32m 10\u001b[0m )\n\u001b[0;32m---> 12\u001b[0m \u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m tokens\n", + "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/sequence.py:228\u001b[0m, in \u001b[0;36mTensorDictSequence.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(kwargs):\n\u001b[1;32m 227\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule:\n\u001b[0;32m--> 228\u001b[0m tensordict \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTensorDictSequence does not support keyword arguments other than 
\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtensordict_out\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m and \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvmap\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 232\u001b[0m )\n", + "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/sequence.py:228\u001b[0m, in \u001b[0;36mTensorDictSequence.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(kwargs):\n\u001b[1;32m 227\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule:\n\u001b[0;32m--> 228\u001b[0m tensordict \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTensorDictSequence does not support keyword arguments other than \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtensordict_out\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m and 
\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvmap\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 232\u001b[0m )\n", + " \u001b[0;31m[... skipping similar frames: Module._call_impl at line 1130 (2 times), TensorDictSequence.forward at line 228 (1 times)]\u001b[0m\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/sequence.py:228\u001b[0m, in \u001b[0;36mTensorDictSequence.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(kwargs):\n\u001b[1;32m 227\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule:\n\u001b[0;32m--> 228\u001b[0m tensordict \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTensorDictSequence does not support keyword arguments other than \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtensordict_out\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m and \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvmap\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 232\u001b[0m )\n", + "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:346\u001b[0m, in \u001b[0;36mTensorDictModule.forward\u001b[0;34m(self, tensordict, 
tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 341\u001b[0m tensordict: _TensorDict,\n\u001b[1;32m 342\u001b[0m tensordict_out: Optional[_TensorDict] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 343\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 344\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m _TensorDict:\n\u001b[1;32m 345\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(tensordict\u001b[38;5;241m.\u001b[39mget(in_key, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;28;01mfor\u001b[39;00m in_key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_keys)\n\u001b[0;32m--> 346\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensors, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 348\u001b[0m tensors \u001b[38;5;241m=\u001b[39m (tensors,)\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:336\u001b[0m, in \u001b[0;36mTensorDictModule._call_module\u001b[0;34m(self, tensors, **kwargs)\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n\u001b[1;32m 335\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 336\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 337\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n", + "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "Input \u001b[0;32mIn [8]\u001b[0m, in \u001b[0;36mTokensToQKV.forward\u001b[0;34m(self, X_to, X_from)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, X_to, X_from):\n\u001b[0;32m----> 8\u001b[0m Q \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mq\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_to\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m K \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mk(X_from)\n\u001b[1;32m 10\u001b[0m V \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mv(X_from)\n", + "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", + 
"\u001b[0;31mTypeError\u001b[0m: linear(): argument 'input' (position 1) must be Tensor, not NoneType" + ] } ], "source": [ @@ -649,7 +696,7 @@ }, { "cell_type": "markdown", - "id": "8006a21d", + "id": "23b2a0c6", "metadata": {}, "source": [ "Now we can look at the model:" @@ -657,1365 +704,13 @@ }, { "cell_type": "code", - "execution_count": 43, - "id": "db35f95f", + "execution_count": null, + "id": "0881f74f", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TransformerTensorDict(\n", - " module=ModuleList(\n", - " (0): TransformerEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_to'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " (1): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " 
module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_to'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " (2): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_to'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " (3): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): 
TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_to'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " (4): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_to'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " (5): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " 
(0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_to'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " (1): TransformerDecoderTensorDict(\n", - " module=ModuleList(\n", - " (0): TransformerBlockDecoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - 
" module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=6, out_features=24, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=24, out_features=6, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " (1): TransformerBlockDecoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, 
out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=6, out_features=24, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=24, out_features=6, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " (2): TransformerBlockDecoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - 
" in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=6, out_features=24, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=24, out_features=6, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " (3): TransformerBlockDecoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", 
- " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=6, out_features=24, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=24, out_features=6, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " (4): TransformerBlockDecoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " 
module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=6, out_features=24, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=24, out_features=6, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " (5): TransformerBlockDecoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TransformerBlockEncoderTensorDict(\n", - " module=ModuleList(\n", - " (0): AttentionBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=6, out_features=10, bias=True)\n", - " (k): Linear(in_features=5, out_features=10, bias=True)\n", - " (v): Linear(in_features=5, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=6, bias=True)\n", - " ), 
\n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'X_out', 'Attn', 'X_from'])\n", - " (1): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=6, out_features=24, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=24, out_features=6, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from'], \n", - " out_keys=['X_out'])\n", - " (2): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((10, 6), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_out'], \n", - " out_keys=['X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_from', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_to', 'X_from', 'X_from'], \n", - " out_keys=['X_to', 'Q', 'K', 'V', 'Attn', 'X_out', 'X_from'])" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "transformer" ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "1a17dda7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TransformerTensorDict(\n", - " module=ModuleList(\n", - " (0): TransformerBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", 
- " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " (1): TransformerBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " (2): TransformerBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, 
elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " (3): TransformerBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " (4): TransformerBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " 
(layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " (5): TransformerBlockTensorDict(\n", - " module=ModuleList(\n", - " (0): TensorDictModule(\n", - " module=TokensToQKV(\n", - " (q): Linear(in_features=5, out_features=10, bias=True)\n", - " (k): Linear(in_features=6, out_features=10, bias=True)\n", - " (v): Linear(in_features=6, out_features=10, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (1): TensorDictModule(\n", - " module=SplitHeads(), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['Q', 'K', 'V'])\n", - " (2): TensorDictModule(\n", - " module=Attention(\n", - " (softmax): Softmax(dim=-1)\n", - " (out): Linear(in_features=10, out_features=5, bias=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['Q', 'K', 'V'], \n", - " out_keys=['X_out', 'Attn'])\n", - " (3): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " (4): TensorDictModule(\n", - " module=FFN(\n", - " (FFN): Sequential(\n", - " (0): Linear(in_features=5, out_features=20, bias=True)\n", - " (1): ReLU()\n", - " (2): Linear(in_features=20, out_features=5, bias=True)\n", - " (3): Dropout(p=0.2, inplace=False)\n", - " )\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to'], \n", - " out_keys=['X_out'])\n", - " (5): TensorDictModule(\n", - " module=SkipLayerNorm(\n", - " (layer_norm): LayerNorm((3, 5), eps=1e-05, elementwise_affine=True)\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_out'], \n", - " out_keys=['X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from', 'X_to'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])\n", - " ), \n", - " device=cpu, \n", - " in_keys=['X_to', 'X_from', 'X_to', 'X_from', 'X_from', 'X_from', 'X_from', 'X_from'], \n", - " out_keys=['Q', 'K', 'V', 'Attn', 'X_out', 'X_to'])" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "transformer_decoder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f7e42002", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { From f3a12bea100e2c1e1f04c81baf69699c80e4d4d8 Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Wed, 13 Jul 2022 16:28:05 +0100 Subject: [PATCH 09/21] Made suggered modifications --- torchrl/version.py | 2 + tutorials/tensordictmodule.ipynb | 704 ++++++++++++++++++++----------- 2 files changed, 451 insertions(+), 255 deletions(-) create mode 100644 torchrl/version.py diff --git a/torchrl/version.py b/torchrl/version.py new file mode 100644 index 00000000000..363b21c4ca6 --- /dev/null +++ b/torchrl/version.py @@ -0,0 +1,2 @@ +__version__ = '0.0.1a0+38e8beb' +git_version = '38e8bebb6736db965a2789a9cb066f5bc2f7711b' diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb index 4a4b8f2e630..950c7dd0fe2 100644 --- a/tutorials/tensordictmodule.ipynb +++ b/tutorials/tensordictmodule.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "8a90f5b4", + "id": "3be0fafd", "metadata": {}, "source": [ "# TensorDictModule" @@ -10,7 +10,7 @@ 
}, { "cell_type": "markdown", - "id": "a31301d9", + "id": "94bd315a", "metadata": {}, "source": [ "We recommend reading the TensorDict tutorial before going through this one." @@ -18,7 +18,7 @@ }, { "cell_type": "markdown", - "id": "9206ae33", + "id": "0652352c", "metadata": {}, "source": [ "For a convenient usage of the `TensorDict` class with `nn.Module`, TorchRL provides an interface between the two, named `TensorDictModule`.
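 A `TensorDictModule` reads the entries listed in `in_keys` from the input tensordict, feeds them to the wrapped module as positional arguments, and writes the module's outputs back into the tensordict under `out_keys`. As a purely illustrative sketch (the names here are hypothetical: `net` stands for any `nn.Module`, and the key names are made up):\n", + "\n", + "```python\n", + "module = TensorDictModule(net, in_keys=[\"obs\"], out_keys=[\"action\"])\n", + "module(tensordict)  # reads tensordict[\"obs\"], writes tensordict[\"action\"]\n", + "```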
\n", @@ -28,20 +28,30 @@ }, { "cell_type": "code", - "execution_count": 1, - "id": "4d5d1a6d", + "execution_count": null, + "id": "5b0241ab", "metadata": {}, "outputs": [], "source": [ - "from torchrl.modules import TensorDictModule, TensorDictSequence\n", - "from torchrl.data import TensorDict\n", + "import torch\n", "import torch.nn as nn\n", - "import torch" + "\n", + "from torchrl.data import TensorDict\n", + "from torchrl.modules import TensorDictModule, TensorDictSequence" ] }, { "cell_type": "markdown", - "id": "271e8714", + "id": "129a6de9-cf97-4565-a229-c05ad18df882", + "metadata": {}, + "source": [ + "## `TensorDictModule` by examples\n", + "Let's learn about `TensorDictModule by exploring some examples" + ] + }, + { + "cell_type": "markdown", + "id": "9d1c188a", "metadata": {}, "source": [ "### Example 1: Simple usage" @@ -49,7 +59,7 @@ }, { "cell_type": "markdown", - "id": "db3d429d", + "id": "1d21a711", "metadata": {}, "source": [ "Let's imagine we have 2 entries `TensorDict`, a and b and we only want to pass a to our network." @@ -57,8 +67,8 @@ }, { "cell_type": "code", - "execution_count": 16, - "id": "989a3bea", + "execution_count": 3, + "id": "6f33781f", "metadata": {}, "outputs": [ { @@ -74,20 +84,16 @@ " is_shared=False)" ] }, - "execution_count": 16, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tensordict = TensorDict(\n", - " {\n", - " \"a\": torch.randn(5, 3),\n", - " \"b\": torch.zeros(5, 4, 3)\n", - " },\n", - " batch_size=[5]\n", + " {\"a\": torch.randn(5, 3), \"b\": torch.zeros(5, 4, 3)}, batch_size=[5]\n", ")\n", - "linear = TensorDictModule(nn.Linear(3, 10),in_keys=[\"a\"], out_keys=[\"a_out\"])\n", + "linear = TensorDictModule(nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"a_out\"])\n", "linear(tensordict)\n", "assert (tensordict[\"b\"] == torch.zeros(5, 4, 3)).all()\n", "tensordict" @@ -95,7 +101,7 @@ }, { "cell_type": "markdown", - "id": "cde8dbcb", + "id": "3a818629", "metadata": {}, "source": [ "We can also do it inplace:" @@ -103,32 +109,24 @@ }, { "cell_type": "code", - "execution_count": 18, - "id": "23128445", + "execution_count": 4, + "id": "4d1342d0", "metadata": {}, "outputs": [], "source": [ "tensordict = TensorDict(\n", - " {\n", - " \"a\": torch.randn(5, 3),\n", - " \"b\": torch.zeros(5, 4, 3)\n", - " },\n", - " batch_size=[5]\n", + " {\"a\": torch.randn(5, 3), \"b\": torch.zeros(5, 4, 3)}, batch_size=[5]\n", ")\n", "\n", - "linear = TensorDictModule(\n", - " nn.Linear(3, 10),\n", - " in_keys=[\"a\"], \n", - " out_keys=[\"a\"]\n", - ")\n", + "linear = TensorDictModule(nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"a\"])\n", "\n", "linear(tensordict)\n", - "assert tensordict[\"a\"].shape == torch.Size([5,10])" + "assert tensordict[\"a\"].shape == torch.Size([5, 10])" ] }, { "cell_type": "markdown", - "id": "e88255a9", + "id": "00035cbd", "metadata": {}, "source": [ "### Example 2: Multiple inputs" @@ -136,7 +134,7 @@ }, { "cell_type": "markdown", - "id": "63714f8c", + "id": "06a20c22", "metadata": {}, "source": [ "Now lets imagine a more complex network that takes 2 entries and average them into a single output:" @@ -144,8 +142,8 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "1e0cce94", + "execution_count": 5, + "id": "69098393", "metadata": {}, "outputs": [], "source": [ @@ -154,14 +152,15 @@ " super().__init__()\n", " self.linear_1 = nn.Linear(in_1, out)\n", " self.linear_2 = nn.Linear(in_2, out)\n", + "\n", " def forward(self, x_1, x_2):\n", - " return (self.linear_1(x_1) + 
self.linear_2(x_2))/2" + " return (self.linear_1(x_1) + self.linear_2(x_2)) / 2" ] }, { "cell_type": "code", - "execution_count": 5, - "id": "29c6242b", + "execution_count": 6, + "id": "2dd686bb", "metadata": {}, "outputs": [ { @@ -170,15 +169,14 @@ "TensorDict(\n", " fields={\n", " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),\n", - " c: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n", + " b: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n", " output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n", " batch_size=torch.Size([5]),\n", " device=cpu,\n", " is_shared=False)" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -188,22 +186,20 @@ " {\n", " \"a\": torch.randn(5, 3),\n", " \"b\": torch.randn(5, 4),\n", - " }, \n", - " batch_size=[5]\n", + " },\n", + " batch_size=[5],\n", ")\n", "\n", "mergelinear = TensorDictModule(\n", - " MergeLinear(3, 4, 10),\n", - " in_keys=[\"a\",\"b\"],\n", - " out_keys=[\"output\"]\n", + " MergeLinear(3, 4, 10), in_keys=[\"a\", \"b\"], out_keys=[\"output\"]\n", ")\n", - " \n", + "\n", "mergelinear(tensordict)" ] }, { "cell_type": "markdown", - "id": "c831e815", + "id": "11256ae7", "metadata": {}, "source": [ "### Example 3: Multiple outputs\n", @@ -212,8 +208,8 @@ }, { "cell_type": "code", - "execution_count": 6, - "id": "ff71ceb2", + "execution_count": 7, + "id": "0b7f709b", "metadata": {}, "outputs": [], "source": [ @@ -222,14 +218,15 @@ " super().__init__()\n", " self.linear_1 = nn.Linear(in_1, out_1)\n", " self.linear_2 = nn.Linear(in_1, out_2)\n", + "\n", " def forward(self, x):\n", " return self.linear_1(x), self.linear_2(x)" ] }, { "cell_type": "code", - "execution_count": 7, - "id": "de83ddac", + "execution_count": 8, + "id": "1b2b465f", "metadata": {}, "outputs": [ { @@ -238,7 +235,6 @@ "TensorDict(\n", " fields={\n", " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),\n", " output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n", " output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n", " batch_size=torch.Size([5]),\n", @@ -246,7 +242,7 @@ " is_shared=False)" ] }, - "execution_count": 7, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -255,16 +251,14 @@ "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n", "\n", "splitlinear = TensorDictModule(\n", - " MultiHeadLinear(3, 4, 10),\n", - " in_keys=[\"a\"],\n", - " out_keys=[\"output_1\", \"output_2\"]\n", + " MultiHeadLinear(3, 4, 10), in_keys=[\"a\"], out_keys=[\"output_1\", \"output_2\"]\n", ")\n", "splitlinear(tensordict)" ] }, { "cell_type": "markdown", - "id": "f31279e8", + "id": "859630c3", "metadata": {}, "source": [ "As we shown previously, the `TensorDictModule` can take any `nn.Module` and perform the operations on a `TensorDict`. When having multiple input keys and output keys, make sure they match the order in the module.\n", @@ -273,10 +267,163 @@ }, { "cell_type": "markdown", - "id": "503a1e34", + "id": "760118ea", + "metadata": {}, + "source": [ + "### Example 4: Compatibility with functorch" + ] + }, + { + "cell_type": "markdown", + "id": "e2718a12", "metadata": {}, "source": [ - "### Example 4: A transformer with TensorDict?\n", + "TensorDictModule is compatible with functorch. We can use make_functional_with_buffers on top of it." 
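+ ,"\n", + "As the next cell shows, calling `make_functional_with_buffers()` on the module returns a functional version of it together with its parameters and buffers; the functional call then receives them explicitly, as in `func(tensordict, params=params, buffers=buffers)` below."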
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b553bed1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorDict(\n", + " fields={\n", + " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n", + " output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),\n", + " output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n", + " batch_size=torch.Size([5]),\n", + " device=cpu,\n", + " is_shared=False)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from functorch import make_functional_with_buffers, vmap\n", + "\n", + "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n", + "\n", + "splitlinear = TensorDictModule(\n", + " MultiHeadLinear(3, 4, 10), in_keys=[\"a\"], out_keys=[\"output_1\", \"output_2\"]\n", + ")\n", + "func, (params, buffers) = splitlinear.make_functional_with_buffers()\n", + "func(tensordict, params=params, buffers=buffers)" + ] + }, + { + "cell_type": "markdown", + "id": "50ac0393", + "metadata": {}, + "source": [ + "We can also use vmap. Let's do some model ensembling with it." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "86ccb7be", + "metadata": {}, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "batched == nullptr INTERNAL ASSERT FAILED at \"/private/var/folders/fn/c72nxv0x0xj4bzdgq3b0r86r0000gn/T/pip-req-build-i1bssf2g/functorch/csrc/Interpreter.cpp\":95, please report a bug to PyTorch. ", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 28\u001b[0m params \u001b[38;5;241m=\u001b[39m transpose_stack(params)\n\u001b[1;32m 29\u001b[0m buffers \u001b[38;5;241m=\u001b[39m transpose_stack(buffers)\n\u001b[0;32m---> 30\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbuffers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvmap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", + "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
_global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:346\u001b[0m, in \u001b[0;36mTensorDictModule.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 341\u001b[0m tensordict: _TensorDict,\n\u001b[1;32m 342\u001b[0m tensordict_out: Optional[_TensorDict] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 343\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 344\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m _TensorDict:\n\u001b[1;32m 345\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(tensordict\u001b[38;5;241m.\u001b[39mget(in_key, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;28;01mfor\u001b[39;00m in_key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_keys)\n\u001b[0;32m--> 346\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensors, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 348\u001b[0m tensors \u001b[38;5;241m=\u001b[39m (tensors,)\n", + "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:333\u001b[0m, in \u001b[0;36mTensorDictModule._call_module\u001b[0;34m(self, tensors, **kwargs)\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(err_msg\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 328\u001b[0m kwargs_pruned \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 329\u001b[0m key: item\n\u001b[1;32m 330\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, item \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvmap\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 332\u001b[0m }\n\u001b[0;32m--> 333\u001b[0m out \u001b[38;5;241m=\u001b[39m 
\u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mparams\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbuffers\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs_pruned\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n\u001b[1;32m 335\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "File \u001b[0;32m~/.local/lib/python3.9/site-packages/functorch/_src/vmap.py:361\u001b[0m, in \u001b[0;36mvmap..wrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 359\u001b[0m _check_out_dims_is_int_or_int_pytree(out_dims, func)\n\u001b[1;32m 360\u001b[0m batch_size, flat_in_dims, flat_args, args_spec \u001b[38;5;241m=\u001b[39m _process_batched_inputs(in_dims, args, func)\n\u001b[0;32m--> 361\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_flat_vmap\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 362\u001b[0m \u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflat_in_dims\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflat_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs_spec\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout_dims\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrandomness\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 363\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.9/site-packages/functorch/_src/vmap.py:487\u001b[0m, in \u001b[0;36m_flat_vmap\u001b[0;34m(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 486\u001b[0m batched_inputs \u001b[38;5;241m=\u001b[39m _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)\n\u001b[0;32m--> 487\u001b[0m batched_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mbatched_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 488\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)\n\u001b[1;32m 489\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n", + "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m~/.local/lib/python3.9/site-packages/functorch/_src/make_functional.py:282\u001b[0m, in \u001b[0;36mFunctionalModuleWithBuffers.forward\u001b[0;34m(self, params, buffers, *args, **kwargs)\u001b[0m\n\u001b[1;32m 277\u001b[0m old_state \u001b[38;5;241m=\u001b[39m _swap_state(\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstateless_model,\n\u001b[1;32m 279\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mall_names_map,\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28mlist\u001b[39m(params) \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlist\u001b[39m(buffers))\n\u001b[1;32m 281\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 282\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstateless_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 283\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 284\u001b[0m \u001b[38;5;66;03m# Remove the loaded state on self.stateless_model\u001b[39;00m\n\u001b[1;32m 285\u001b[0m _swap_state(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstateless_model, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mall_names_map, old_state)\n", + "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
_global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mRuntimeError\u001b[0m: batched == nullptr INTERNAL ASSERT FAILED at \"/private/var/folders/fn/c72nxv0x0xj4bzdgq3b0r86r0000gn/T/pip-req-build-i1bssf2g/functorch/csrc/Interpreter.cpp\":95, please report a bug to PyTorch. " + ] + } + ], + "source": [ + "from functorch import make_functional_with_buffers, vmap\n", + "\n", + "num_models = 10\n", + "\n", + "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n", + "\n", + "splitlinear_models = [\n", + " TensorDictModule(nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"output\"])\n", + " for _ in range(num_models)\n", + "]\n", + "\n", + "\n", + "def transpose_stack(tuple_of_tuple_of_tensors):\n", + " tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))\n", + " results = tuple(\n", + " torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors\n", + " )\n", + " return results\n", + "\n", + "\n", + "func = splitlinear_models[0].make_functional_with_buffers()[0]\n", + "params, buffers = zip(\n", + " *[\n", + " splitlinear.make_functional_with_buffers()[1]\n", + " for splitlinear in splitlinear_models\n", + " ]\n", + ")\n", + "params = transpose_stack(params)\n", + "buffers = transpose_stack(buffers)\n", + "func(tensordict, params=params, buffers=buffers, vmap=True).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "1ed39eab", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorDict(\n", + " fields={\n", + " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n", + " output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},\n", + " batch_size=torch.Size([5]),\n", + " device=cpu,\n", + " is_shared=False)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from functorch import make_functional_with_buffers\n", + "\n", + "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n", + "\n", + "splitlinear = TensorDictModule(nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"output\"])\n", + "func, param, buffers = 
make_functional_with_buffers(splitlinear)\n", + "func(param, buffers, tensordict)" + ] + }, + { + "cell_type": "markdown", + "id": "6304a098", + "metadata": {}, + "source": [ + "### Example 5: A transformer with TensorDict?\n", "Let's attempt to create a transformer with `TensorDict` and `TensorDictModule`.\n", "\n", "Here's a diagram that sums up the architecture:\n", @@ -291,8 +438,8 @@ }, { "cell_type": "code", - "execution_count": 8, - "id": "f039f790", + "execution_count": 12, + "id": "e1f7ba7b", "metadata": {}, "outputs": [], "source": [ @@ -302,16 +449,19 @@ " self.q = nn.Linear(to_dim, latent_dim)\n", " self.k = nn.Linear(from_dim, latent_dim)\n", " self.v = nn.Linear(from_dim, latent_dim)\n", + "\n", " def forward(self, X_to, X_from):\n", " Q = self.q(X_to)\n", " K = self.k(X_from)\n", " V = self.v(X_from)\n", " return Q, K, V\n", "\n", + "\n", "class SplitHeads(nn.Module):\n", " def __init__(self, num_heads):\n", " super().__init__()\n", " self.num_heads = num_heads\n", + "\n", " def forward(self, Q, K, V):\n", " batch_size, to_num, latent_dim = Q.shape\n", " _, from_num, _ = K.shape\n", @@ -320,39 +470,48 @@ " K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)\n", " V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)\n", " return Q, K, V\n", + "\n", + "\n", "class Attention(nn.Module):\n", " def __init__(self, latent_dim, to_dim):\n", " super().__init__()\n", " self.softmax = nn.Softmax(dim=-1)\n", " self.out = nn.Linear(latent_dim, to_dim)\n", + "\n", " def forward(self, Q, K, V):\n", " batch_size, n_heads, to_num, d_in = Q.shape\n", " attn = self.softmax(Q @ K.transpose(2, 3) / d_in)\n", " out = attn @ V\n", - " out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads*d_in))\n", + " out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))\n", " return out, attn\n", + "\n", + "\n", "class SkipLayerNorm(nn.Module):\n", " def __init__(self, to_len, to_dim):\n", " super().__init__()\n", " self.layer_norm = nn.LayerNorm((to_len, to_dim))\n", + "\n", " def forward(self, x_0, x_1):\n", - " return self.layer_norm(x_0+x_1)\n", + " return self.layer_norm(x_0 + x_1)\n", + "\n", + "\n", "class FFN(nn.Module):\n", - " def __init__(self, to_dim, hidden_dim, dropout_rate = 0.2):\n", + " def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):\n", " super().__init__()\n", " self.FFN = nn.Sequential(\n", " nn.Linear(to_dim, hidden_dim),\n", " nn.ReLU(),\n", " nn.Linear(hidden_dim, to_dim),\n", - " nn.Dropout(dropout_rate)\n", + " nn.Dropout(dropout_rate),\n", " )\n", + "\n", " def forward(self, X):\n", - " return self.FFN(X)\n" + " return self.FFN(X)" ] }, { "cell_type": "markdown", - "id": "8569fc1d", + "id": "b5f6f291", "metadata": {}, "source": [ "Now, we can build the encoder and decoder blocks that will be part of the transformer thanks to the TensorDictModule. Since the changes affect the `TensorDict`, we just need to map outputs to the right name such as it is picked up by the next block." 
@@ -360,38 +519,74 @@ }, { "cell_type": "code", - "execution_count": 9, - "id": "8f1496d2", + "execution_count": 13, + "id": "eb9775bd", "metadata": {}, "outputs": [], "source": [ "class AttentionBlockTensorDict(TensorDictSequence):\n", - " def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):\n", + " def __init__(\n", + " self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " ):\n", " super().__init__(\n", - " TensorDictModule(TokensToQKV(to_dim, from_dim, latent_dim), in_keys=[to_name, from_name], out_keys=[\"Q\", \"K\", \"V\"]),\n", - " TensorDictModule(SplitHeads(num_heads), in_keys=[\"Q\", \"K\", \"V\"], out_keys=[\"Q\", \"K\", \"V\"]),\n", - " TensorDictModule(Attention(latent_dim, to_dim), in_keys=[\"Q\", \"K\", \"V\"], out_keys=[\"X_out\",\"Attn\"]),\n", - " TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, \"X_out\"], out_keys=[to_name]),\n", + " TensorDictModule(\n", + " TokensToQKV(to_dim, from_dim, latent_dim),\n", + " in_keys=[to_name, from_name],\n", + " out_keys=[\"Q\", \"K\", \"V\"],\n", + " ),\n", + " TensorDictModule(\n", + " SplitHeads(num_heads), in_keys=[\"Q\", \"K\", \"V\"], out_keys=[\"Q\", \"K\", \"V\"]\n", + " ),\n", + " TensorDictModule(\n", + " Attention(latent_dim, to_dim),\n", + " in_keys=[\"Q\", \"K\", \"V\"],\n", + " out_keys=[\"X_out\", \"Attn\"],\n", + " ),\n", + " TensorDictModule(\n", + " SkipLayerNorm(to_len, to_dim),\n", + " in_keys=[to_name, \"X_out\"],\n", + " out_keys=[to_name],\n", + " ),\n", " )\n", + "\n", + "\n", "class TransformerBlockEncoderTensorDict(TensorDictSequence):\n", - " def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):\n", + " def __init__(\n", + " self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " ):\n", " super().__init__(\n", - " AttentionBlockTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads),\n", - " TensorDictModule(FFN(to_dim, 4*to_dim), in_keys=[to_name], out_keys=[\"X_out\"]),\n", - " TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, \"X_out\"], out_keys=[to_name]),\n", + " AttentionBlockTensorDict(\n", + " to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " ),\n", + " TensorDictModule(\n", + " FFN(to_dim, 4 * to_dim), in_keys=[to_name], out_keys=[\"X_out\"]\n", + " ),\n", + " TensorDictModule(\n", + " SkipLayerNorm(to_len, to_dim),\n", + " in_keys=[to_name, \"X_out\"],\n", + " out_keys=[to_name],\n", + " ),\n", " )\n", + "\n", + "\n", "class TransformerBlockDecoderTensorDict(TensorDictSequence):\n", - " def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):\n", + " def __init__(\n", + " self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " ):\n", " super().__init__(\n", - " AttentionBlockTensorDict(to_name, to_name, to_dim, to_len, to_dim, latent_dim, num_heads),\n", - " TransformerBlockEncoderTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads)\n", + " AttentionBlockTensorDict(\n", + " to_name, to_name, to_dim, to_len, to_dim, latent_dim, num_heads\n", + " ),\n", + " TransformerBlockEncoderTensorDict(\n", + " to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " ),\n", " )" ] }, { "cell_type": "code", - "execution_count": 10, - "id": "06571f56", + "execution_count": 14, + "id": "e9601f5a", "metadata": {}, "outputs": [ { @@ -411,7 +606,7 @@ " is_shared=False)" ] }, - "execution_count": 10, + 
"execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -428,19 +623,13 @@ "tokens = TensorDict(\n", " {\n", " \"X_to\": torch.randn(batch_size, to_len, to_dim),\n", - " \"X_from\": torch.randn(batch_size, from_len, from_dim)\n", + " \"X_from\": torch.randn(batch_size, from_len, from_dim),\n", " },\n", - " batch_size=[batch_size]\n", + " batch_size=[batch_size],\n", ")\n", "\n", "transformer_block = AttentionBlockTensorDict(\n", - " \"X_to\",\n", - " \"X_from\",\n", - " to_dim,\n", - " to_len,\n", - " from_dim,\n", - " latent_dim,\n", - " num_heads\n", + " \"X_to\", \"X_from\", to_dim, to_len, from_dim, latent_dim, num_heads\n", ")\n", "\n", "transformer_block(tokens)\n", @@ -450,92 +639,28 @@ }, { "cell_type": "markdown", - "id": "19a1c3b1", - "metadata": {}, - "source": [ - "The output of the transformer layer can now be found at tokens[\"X_to\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "cf2662f7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[-4.1014e-01, -3.8609e-02, -4.6190e-01, -1.1984e+00, -1.0029e+00],\n", - " [ 4.8135e-01, 2.1887e+00, 1.0424e+00, 2.3389e-01, 2.2858e-01],\n", - " [ 9.2826e-02, 1.7278e+00, -1.3197e+00, -8.0110e-01, -7.6280e-01]],\n", - "\n", - " [[-8.1625e-01, 9.2371e-01, -3.4789e-01, -1.7687e+00, -5.5227e-01],\n", - " [-4.1375e-01, -6.0962e-01, -6.1183e-01, 2.3965e-01, 1.3284e+00],\n", - " [-7.7074e-01, 6.0351e-01, -2.2673e-01, 2.4082e+00, 6.1431e-01]],\n", - "\n", - " [[ 1.1999e+00, 9.0440e-01, 1.9596e-01, -1.7704e+00, 4.7291e-01],\n", - " [ 1.2238e+00, 1.9033e-01, 2.7292e-01, -2.5550e-01, -1.3451e+00],\n", - " [ 1.0397e+00, 1.1162e+00, -1.2123e+00, -1.0223e+00, -1.0106e+00]],\n", - "\n", - " [[ 4.9356e-01, -1.4560e-02, 2.1211e-02, 8.7976e-01, -8.7540e-02],\n", - " [-1.4565e-01, 1.0329e+00, 5.8444e-01, 6.7036e-01, -1.5396e+00],\n", - " [ 6.7115e-01, -9.1910e-01, 1.5387e+00, -8.9005e-01, -2.2955e+00]],\n", - "\n", - " [[ 1.1135e+00, 7.9529e-01, -1.6116e+00, -7.5607e-01, -8.4692e-01],\n", - " [ 6.8204e-01, -1.1673e-01, 9.6905e-01, -8.8737e-01, -1.1990e+00],\n", - " [ 1.9570e+00, -3.7739e-01, 8.7870e-01, 2.8001e-01, -8.8057e-01]],\n", - "\n", - " [[ 1.4367e+00, -4.3311e-01, 1.5850e+00, -5.8594e-01, -1.2526e+00],\n", - " [ 1.5208e+00, -1.0612e+00, -8.0319e-01, -1.2658e+00, 8.2322e-01],\n", - " [-1.2821e-01, 6.0637e-01, 8.4822e-01, -7.5417e-01, -5.3608e-01]],\n", - "\n", - " [[ 1.7111e-01, 5.6883e-01, 1.9995e-01, 4.3374e-01, -2.2862e+00],\n", - " [-1.4681e+00, 9.7458e-01, 2.6624e-01, -1.6163e-01, -1.2614e+00],\n", - " [ 5.9899e-01, -1.4938e-03, 1.7823e+00, -5.5919e-01, 7.4231e-01]],\n", - "\n", - " [[ 3.2855e-01, -1.3951e+00, 8.6517e-01, 1.3137e+00, 6.4360e-01],\n", - " [-1.5505e-01, -1.6713e+00, 6.1532e-01, -1.2615e+00, 1.7464e+00],\n", - " [-1.1965e+00, -6.1894e-01, 4.0306e-01, 4.8236e-01, -9.9691e-02]]],\n", - " grad_fn=)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tokens[\"X_to\"]" - ] - }, - { - "cell_type": "markdown", - "id": "c1396907", - "metadata": {}, - "source": [ - "We can now create a transformer easily" - ] - }, - { - "cell_type": "markdown", - "id": "6fcdb817", + "id": "34d2b6cb", "metadata": {}, "source": [ - "For an encoder, we just need to take the same tokens for both queries, keys and values." 
+ "The output of the transformer layer can now be found at `tokens[\"X_to\"]`" ] }, { "cell_type": "markdown", - "id": "ecc1fe0f", + "id": "42dbfae5", "metadata": {}, "source": [ + "We can now create the transformer encoder and decoder.\n", + "\n", + "For an encoder, we just need to take the same tokens for both queries, keys and values.\n", + "\n", "For a decoder, we now can extract info from `X_from` into `X_to`. `X_from` will map to queries whereas X`_from` will map to keys and values." ] }, { "cell_type": "code", - "execution_count": 12, - "id": "a3e3027d", + "execution_count": 15, + "id": "1c6c85b5", "metadata": {}, "outputs": [], "source": [ @@ -549,19 +674,18 @@ " to_len,\n", " from_dim,\n", " latent_dim,\n", - " num_heads\n", + " num_heads,\n", " ):\n", " super().__init__(\n", - " *[TransformerBlockEncoderTensorDict(\n", - " to_name,\n", - " from_name,\n", - " to_dim,\n", - " to_len,\n", - " from_dim,\n", - " latent_dim,\n", - " num_heads\n", - " ) for _ in range(num_blocks)\n", - " ])\n", + " *[\n", + " TransformerBlockEncoderTensorDict(\n", + " to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " )\n", + " for _ in range(num_blocks)\n", + " ]\n", + " )\n", + "\n", + "\n", "class TransformerDecoderTensorDict(TensorDictSequence):\n", " def __init__(\n", " self,\n", @@ -572,20 +696,18 @@ " to_len,\n", " from_dim,\n", " latent_dim,\n", - " num_heads\n", + " num_heads,\n", " ):\n", " super().__init__(\n", - " *[TransformerBlockDecoderTensorDict(\n", - " to_name,\n", - " from_name,\n", - " to_dim,\n", - " to_len,\n", - " from_dim,\n", - " latent_dim,\n", - " num_heads\n", - " ) for _ in range(num_blocks)\n", - " ])\n", - " \n", + " *[\n", + " TransformerBlockDecoderTensorDict(\n", + " to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " )\n", + " for _ in range(num_blocks)\n", + " ]\n", + " )\n", + "\n", + "\n", "class TransformerTensorDict(TensorDictSequence):\n", " def __init__(\n", " self,\n", @@ -596,7 +718,7 @@ " to_len,\n", " from_dim,\n", " latent_dim,\n", - " num_heads\n", + " num_heads,\n", " ):\n", " super().__init__(\n", " TransformerEncoderTensorDict(\n", @@ -607,7 +729,7 @@ " to_len,\n", " to_dim,\n", " latent_dim,\n", - " num_heads\n", + " num_heads,\n", " ),\n", " TransformerDecoderTensorDict(\n", " num_blocks,\n", @@ -617,18 +739,39 @@ " from_len,\n", " to_dim,\n", " latent_dim,\n", - " num_heads\n", - " )\n", - " \n", - " ) \n" + " num_heads,\n", + " ),\n", + " )" ] }, { "cell_type": "code", - "execution_count": 13, - "id": "cc6ce12a", + "execution_count": 16, + "id": "09fa9f9a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "TensorDict(\n", + " fields={\n", + " Attn: Tensor(torch.Size([8, 2, 10, 3]), dtype=torch.float32),\n", + " K: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),\n", + " Q: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),\n", + " V: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),\n", + " X_decode: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),\n", + " X_encode: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),\n", + " X_out: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32)},\n", + " batch_size=torch.Size([8]),\n", + " device=cpu,\n", + " is_shared=False)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "to_dim = 5\n", "from_dim = 6\n", @@ -640,82 +783,133 @@ "\n", "tokens = TensorDict(\n", " {\n", - " \"X_encode\":torch.randn(batch_size, to_len, to_dim),\n", - " 
\"X_decode\":torch.randn(batch_size, from_len, from_dim)\n", + " \"X_encode\": torch.randn(batch_size, to_len, to_dim),\n", + " \"X_decode\": torch.randn(batch_size, from_len, from_dim),\n", " },\n", - " batch_size=[batch_size]\n", - ")" + " batch_size=[batch_size],\n", + ")\n", + "\n", + "transformer = TransformerTensorDict(\n", + " 6, \"X_encode\", \"X_decode\", to_dim, to_len, from_dim, latent_dim, num_heads\n", + ")\n", + "\n", + "transformer(tokens)\n", + "tokens" ] }, { - "cell_type": "code", - "execution_count": 14, - "id": "19f9fff2", + "cell_type": "markdown", + "id": "3f6448dd-5d0d-43fd-9e57-a0ac3b30ecba", "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "linear(): argument 'input' (position 1) must be Tensor, not NoneType", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [14]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m transformer \u001b[38;5;241m=\u001b[39m TransformerTensorDict(\n\u001b[1;32m 2\u001b[0m \u001b[38;5;241m6\u001b[39m,\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_to\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 9\u001b[0m num_heads\n\u001b[1;32m 10\u001b[0m )\n\u001b[0;32m---> 12\u001b[0m \u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m tokens\n", - "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/sequence.py:228\u001b[0m, in \u001b[0;36mTensorDictSequence.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(kwargs):\n\u001b[1;32m 227\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m 
module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule:\n\u001b[0;32m--> 228\u001b[0m tensordict \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTensorDictSequence does not support keyword arguments other than \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtensordict_out\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m and \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvmap\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 232\u001b[0m )\n", - "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/sequence.py:228\u001b[0m, in \u001b[0;36mTensorDictSequence.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(kwargs):\n\u001b[1;32m 227\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule:\n\u001b[0;32m--> 228\u001b[0m tensordict \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 231\u001b[0m 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTensorDictSequence does not support keyword arguments other than \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtensordict_out\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m and \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvmap\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 232\u001b[0m )\n", - " \u001b[0;31m[... skipping similar frames: Module._call_impl at line 1130 (2 times), TensorDictSequence.forward at line 228 (1 times)]\u001b[0m\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/sequence.py:228\u001b[0m, in \u001b[0;36mTensorDictSequence.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(kwargs):\n\u001b[1;32m 227\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule:\n\u001b[0;32m--> 228\u001b[0m tensordict \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTensorDictSequence does not support keyword arguments other than \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtensordict_out\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m and \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvmap\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 232\u001b[0m )\n", - "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:346\u001b[0m, in \u001b[0;36mTensorDictModule.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 341\u001b[0m tensordict: _TensorDict,\n\u001b[1;32m 342\u001b[0m tensordict_out: Optional[_TensorDict] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 343\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 344\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m _TensorDict:\n\u001b[1;32m 345\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(tensordict\u001b[38;5;241m.\u001b[39mget(in_key, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;28;01mfor\u001b[39;00m in_key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_keys)\n\u001b[0;32m--> 346\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensors, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 348\u001b[0m tensors \u001b[38;5;241m=\u001b[39m (tensors,)\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:336\u001b[0m, in \u001b[0;36mTensorDictModule._call_module\u001b[0;34m(self, tensors, **kwargs)\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n\u001b[1;32m 335\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 336\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 337\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n", - "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "Input \u001b[0;32mIn [8]\u001b[0m, in \u001b[0;36mTokensToQKV.forward\u001b[0;34m(self, X_to, X_from)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, X_to, X_from):\n\u001b[0;32m----> 8\u001b[0m Q \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mq\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_to\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m K \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mk(X_from)\n\u001b[1;32m 10\u001b[0m V \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mv(X_from)\n", - "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/ENTER/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 
114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mTypeError\u001b[0m: linear(): argument 'input' (position 1) must be Tensor, not NoneType" - ] - } - ], "source": [ - "transformer = TransformerTensorDict(\n", - " 6,\n", - " \"X_to\",\n", - " \"X_from\",\n", - " to_dim,\n", - " to_len,\n", - " from_dim,\n", - " latent_dim,\n", - " num_heads\n", - ")\n", + "We've achieved to create a transformer with `TensorDictModule`. This shows that `TensorDictModule`is a flexible module that can implement complex operarations" + ] + }, + { + "cell_type": "markdown", + "id": "22e65356-d8b3-4197-84b8-598330c1ddc8", + "metadata": {}, + "source": [ + "## TensorDictModule for RL" + ] + }, + { + "cell_type": "markdown", + "id": "8d49a911-933c-476f-8c9a-00e006ed043c", + "metadata": {}, + "source": [ + "In the context of RL torchrl offers a few wrappers on `TensorDictModule`" + ] + }, + { + "cell_type": "markdown", + "id": "e33904a6-d405-45db-a713-47493ca8ee33", + "metadata": { + "tags": [] + }, + "source": [ + "### `ProbabilisticTensorDictModule`" + ] + }, + { + "cell_type": "markdown", + "id": "fea4eead-47b4-4029-a8ff-e3c3faf51b0f", + "metadata": {}, + "source": [ + "`ProbabilisticTDModule` is a special case of a `TensorDictModule` where the output is\n", + "sampled given some rule, specified by the input `default_interaction_mode`\n", + "argument and the `exploration_mode()` global function.\n", "\n", - "transformer(tokens)\n", - "tokens" + "It consists in a wrapper around another `TensorDictModule` that returns a tensordict\n", + "updated with the distribution parameters. `ProbabilisticTensorDictModule` is\n", + "responsible for constructing the distribution (through the `get_dist()` method)\n", + "and/or sampling from this distribution (through a regular `__call__()` to the\n", + "module)." 
] }, { "cell_type": "markdown", - "id": "23b2a0c6", + "id": "406b1caa-bcec-4317-b685-10df23352154", "metadata": {}, "source": [ - "Now we can look at the model:" + "### `Actor`" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "0881f74f", + "cell_type": "markdown", + "id": "e139de7d-0250-49c0-b495-8b5a404821f5", + "metadata": {}, + "source": [ + "Actor inherits from `TensorDictModule` and comes with a default value for `out_keys` of `[\"action\"]`.\n" + ] + }, + { + "cell_type": "markdown", + "id": "cceeade9-47f1-4e92-897a-dd226c9371a6", + "metadata": {}, + "source": [ + "### `ProbabilisticActor`" + ] + }, + { + "cell_type": "markdown", + "id": "4fd0f53e-90aa-49a9-9d8f-5a260255e556", + "metadata": {}, + "source": [ + "General class for probabilistic actors in RL that inherits from `ProbabilisticTensorDictModule`.\n", + "Similarly to `Actor`, it comes with default values for the `out_keys` (`[\"action\"]`).\n" + ] + }, + { + "cell_type": "markdown", + "id": "dbd48bb2-b93b-4766-b7a7-19d500f17e2d", + "metadata": {}, + "source": [ + "### `ActorCriticOperator`" + ] + }, + { + "cell_type": "markdown", + "id": "8cc42407-4e95-4bf0-8901-5d1a4e3b2044", + "metadata": {}, + "source": [ + "Similarly, `ActorCriticOperator` inherits from `TensorDictSequence`.\n", + "and wraps both and actor network and a value Network. \n", + "`ActorCriticOperator` will first compute the action from the actor and then the value according to this action." + ] + }, + { + "cell_type": "markdown", + "id": "bd08362a-8bb8-49fb-8038-1a60c5c01ea2", "metadata": {}, - "outputs": [], "source": [ - "transformer" + "Have fun with TensorDictModule!" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -729,7 +923,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.9.12" } }, "nbformat": 4, From f7d862250a0923ca09ee759e897f6fef5cdea893 Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Wed, 13 Jul 2022 16:30:18 +0100 Subject: [PATCH 10/21] Made suggered modifications --- torchrl/version.py | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 torchrl/version.py diff --git a/torchrl/version.py b/torchrl/version.py deleted file mode 100644 index 363b21c4ca6..00000000000 --- a/torchrl/version.py +++ /dev/null @@ -1,2 +0,0 @@ -__version__ = '0.0.1a0+38e8beb' -git_version = '38e8bebb6736db965a2789a9cb066f5bc2f7711b' From 7010b5482430eee6ba682dd2cfae125c8a949ae4 Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Wed, 13 Jul 2022 17:28:49 +0100 Subject: [PATCH 11/21] Made changes --- tutorials/tensordictmodule.ipynb | 281 ++++++++++++++++++++----------- 1 file changed, 185 insertions(+), 96 deletions(-) diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb index 950c7dd0fe2..5ab06be52bc 100644 --- a/tutorials/tensordictmodule.ipynb +++ b/tutorials/tensordictmodule.ipynb @@ -26,9 +26,17 @@ "It is up to the user to define the keys to be read as input and output." 
] }, + { + "cell_type": "markdown", + "id": "129a6de9-cf97-4565-a229-c05ad18df882", + "metadata": {}, + "source": [ + "## `TensorDictModule` by examples" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "5b0241ab", "metadata": {}, "outputs": [], @@ -40,15 +48,6 @@ "from torchrl.modules import TensorDictModule, TensorDictSequence" ] }, - { - "cell_type": "markdown", - "id": "129a6de9-cf97-4565-a229-c05ad18df882", - "metadata": {}, - "source": [ - "## `TensorDictModule` by examples\n", - "Let's learn about `TensorDictModule by exploring some examples" - ] - }, { "cell_type": "markdown", "id": "9d1c188a", @@ -62,66 +61,41 @@ "id": "1d21a711", "metadata": {}, "source": [ - "Let's imagine we have 2 entries `TensorDict`, a and b and we only want to pass a to our network." + "Let's suppose we have `TensorDict` with 2 entries `\"a\"` and `\"b\"` but only the value associated with `\"a\"` has to be read by the network." ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 15, "id": "6f33781f", "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n", - " a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),\n", - " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},\n", - " batch_size=torch.Size([5]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "TensorDict(\n", + " fields={\n", + " a: Tensor(torch.Size([5, 3]), dtype=torch.float32),\n", + " a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),\n", + " b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},\n", + " batch_size=torch.Size([5]),\n", + " device=cpu,\n", + " is_shared=False)\n" + ] } ], "source": [ "tensordict = TensorDict(\n", - " {\"a\": torch.randn(5, 3), \"b\": torch.zeros(5, 4, 3)}, batch_size=[5]\n", + " {\"a\": torch.randn(5, 3), \"b\": torch.zeros(5, 4, 3)},\n", + " batch_size=[5],\n", ")\n", - "linear = TensorDictModule(nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"a_out\"])\n", - "linear(tensordict)\n", - "assert (tensordict[\"b\"] == torch.zeros(5, 4, 3)).all()\n", - "tensordict" - ] - }, - { - "cell_type": "markdown", - "id": "3a818629", - "metadata": {}, - "source": [ - "We can also do it inplace:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "4d1342d0", - "metadata": {}, - "outputs": [], - "source": [ - "tensordict = TensorDict(\n", - " {\"a\": torch.randn(5, 3), \"b\": torch.zeros(5, 4, 3)}, batch_size=[5]\n", + "linear = TensorDictModule(\n", + " nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"a_out\"]\n", ")\n", - "\n", - "linear = TensorDictModule(nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"a\"])\n", - "\n", "linear(tensordict)\n", - "assert tensordict[\"a\"].shape == torch.Size([5, 10])" + "assert (tensordict.get(\"b\") == 0).all()\n", + "print(tensordict)" ] }, { @@ -137,12 +111,12 @@ "id": "06a20c22", "metadata": {}, "source": [ - "Now lets imagine a more complex network that takes 2 entries and average them into a single output:" + "Suppose we have a slightly more complex network that takes 2 entries and averages them into a single output tensor." 
] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "69098393", "metadata": {}, "outputs": [], @@ -159,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "2dd686bb", "metadata": {}, "outputs": [ @@ -176,7 +150,7 @@ " is_shared=False)" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -203,12 +177,12 @@ "metadata": {}, "source": [ "### Example 3: Multiple outputs\n", - "We can also map to multiple outputs" + "TensorDictModule not only supports multiple inputs but also multiple outputs." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "0b7f709b", "metadata": {}, "outputs": [], @@ -225,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "1b2b465f", "metadata": {}, "outputs": [ @@ -242,7 +216,7 @@ " is_shared=False)" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -251,7 +225,9 @@ "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n", "\n", "splitlinear = TensorDictModule(\n", - " MultiHeadLinear(3, 4, 10), in_keys=[\"a\"], out_keys=[\"output_1\", \"output_2\"]\n", + " MultiHeadLinear(3, 4, 10),\n", + " in_keys=[\"a\"],\n", + " out_keys=[\"output_1\", \"output_2\"],\n", ")\n", "splitlinear(tensordict)" ] @@ -265,12 +241,51 @@ "`TensorDictModule` can work with `TensorDict` instances that contain more tensors than what the `in_keys` attribute indicates. Unless a `vmap` operator is used, the `TensorDict` is modified in-place." ] }, + { + "cell_type": "markdown", + "id": "11d2d2a7-6a55-4f31-972b-041be387f9df", + "metadata": {}, + "source": [ + "### Example 4: Combining multiples `TensorDictModule` with `TensorDictSequence`" + ] + }, + { + "cell_type": "markdown", + "id": "89b157d5-322c-45d6-bec9-20440b78a2bf", + "metadata": {}, + "source": [ + "To combine multiples `TensorDictModule`instances, we can une `TensorDictSequence`. 
This block will take the input of the n-1th `TensorDictModule` in a list and feed it to the nth `TensorDictModule`" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "7e36d071-df67-4232-a8a9-78e79b32fef2", + "metadata": {}, + "outputs": [], + "source": [ + "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n", + "\n", + "splitlinear = TensorDictModule(\n", + " MultiHeadLinear(3, 4, 10),\n", + " in_keys=[\"a\"],\n", + " out_keys=[\"output_1\", \"output_2\"],\n", + ")\n", + "mergelinear = TensorDictModule(\n", + " MergeLinear(4, 10, 13), in_keys=[\"output_1\", \"output_2\"], out_keys=[\"output\"]\n", + ")\n", + "\n", + "split_and_merge_linear = TensorDictSequence(splitlinear, mergelinear)\n", + "\n", + "assert split_and_merge_linear(tensordict)['output'].shape == torch.Size([5, 13])" + ] + }, { "cell_type": "markdown", "id": "760118ea", "metadata": {}, "source": [ - "### Example 4: Compatibility with functorch" + "### Example 5: Compatibility with functorch" ] }, { @@ -306,12 +321,12 @@ } ], "source": [ - "from functorch import make_functional_with_buffers, vmap\n", - "\n", "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n", "\n", "splitlinear = TensorDictModule(\n", - " MultiHeadLinear(3, 4, 10), in_keys=[\"a\"], out_keys=[\"output_1\", \"output_2\"]\n", + " MultiHeadLinear(3, 4, 10),\n", + " in_keys=[\"a\"],\n", + " out_keys=[\"output_1\", \"output_2\"],\n", ")\n", "func, (params, buffers) = splitlinear.make_functional_with_buffers()\n", "func(tensordict, params=params, buffers=buffers)" @@ -353,14 +368,14 @@ } ], "source": [ - "from functorch import make_functional_with_buffers, vmap\n", - "\n", "num_models = 10\n", "\n", "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n", "\n", "splitlinear_models = [\n", - " TensorDictModule(nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"output\"])\n", + " TensorDictModule(\n", + " nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"output\"]\n", + " )\n", " for _ in range(num_models)\n", "]\n", "\n", @@ -368,7 +383,8 @@ "def transpose_stack(tuple_of_tuple_of_tensors):\n", " tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))\n", " results = tuple(\n", - " torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors\n", + " torch.stack(shards).detach()\n", + " for shards in tuple_of_tuple_of_tensors\n", " )\n", " return results\n", "\n", @@ -413,7 +429,9 @@ "\n", "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n", "\n", - "splitlinear = TensorDictModule(nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"output\"])\n", + "splitlinear = TensorDictModule(\n", + " nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"output\"]\n", + ")\n", "func, param, buffers = make_functional_with_buffers(splitlinear)\n", "func(param, buffers, tensordict)" ] @@ -423,15 +441,14 @@ "id": "6304a098", "metadata": {}, "source": [ - "### Example 5: A transformer with TensorDict?\n", - "Let's attempt to create a transformer with `TensorDict` and `TensorDictModule`.\n", + "### Example 6: Implementing a transformer using TensorDictModule\n", + "We can easily create a transformer that reads TensorDict objects using TensorDictModule.\n", "\n", - "Here's a diagram that sums up the architecture:\n", + "The following figure shows the classical transformer architecture (Vaswani et al, 2017) \n", "\n", "\n", "\n", - "Disclaimer: This implementation don't claim to be \"better\" than a classical tensor-based implementation. 
It is just meant to showcase the `TensorDictModule` features.\n", - "For simplicity we will not have positional encoders.\n", + "We have let the positional encoders aside for simplicity.\n", "\n", "Let's first implement the classical transformers blocks." ] @@ -466,9 +483,15 @@ " batch_size, to_num, latent_dim = Q.shape\n", " _, from_num, _ = K.shape\n", " d_tensor = latent_dim // self.num_heads\n", - " Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)\n", - " K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)\n", - " V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)\n", + " Q = Q.reshape(\n", + " batch_size, to_num, self.num_heads, d_tensor\n", + " ).transpose(1, 2)\n", + " K = K.reshape(\n", + " batch_size, from_num, self.num_heads, d_tensor\n", + " ).transpose(1, 2)\n", + " V = V.reshape(\n", + " batch_size, from_num, self.num_heads, d_tensor\n", + " ).transpose(1, 2)\n", " return Q, K, V\n", "\n", "\n", @@ -482,7 +505,11 @@ " batch_size, n_heads, to_num, d_in = Q.shape\n", " attn = self.softmax(Q @ K.transpose(2, 3) / d_in)\n", " out = attn @ V\n", - " out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))\n", + " out = self.out(\n", + " out.transpose(1, 2).reshape(\n", + " batch_size, to_num, n_heads * d_in\n", + " )\n", + " )\n", " return out, attn\n", "\n", "\n", @@ -514,7 +541,7 @@ "id": "b5f6f291", "metadata": {}, "source": [ - "Now, we can build the encoder and decoder blocks that will be part of the transformer thanks to the TensorDictModule. Since the changes affect the `TensorDict`, we just need to map outputs to the right name such as it is picked up by the next block." + "We can build the encoder and decoder blocks that will be part of the transformer thanks to the `TensorDictModule`. Since the changes affect the `TensorDict`, we just need to map outputs to the right name such as it is picked up by the next block." 
] }, { @@ -526,7 +553,14 @@ "source": [ "class AttentionBlockTensorDict(TensorDictSequence):\n", " def __init__(\n", - " self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " self,\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", " ):\n", " super().__init__(\n", " TensorDictModule(\n", @@ -535,7 +569,9 @@ " out_keys=[\"Q\", \"K\", \"V\"],\n", " ),\n", " TensorDictModule(\n", - " SplitHeads(num_heads), in_keys=[\"Q\", \"K\", \"V\"], out_keys=[\"Q\", \"K\", \"V\"]\n", + " SplitHeads(num_heads),\n", + " in_keys=[\"Q\", \"K\", \"V\"],\n", + " out_keys=[\"Q\", \"K\", \"V\"],\n", " ),\n", " TensorDictModule(\n", " Attention(latent_dim, to_dim),\n", @@ -552,14 +588,29 @@ "\n", "class TransformerBlockEncoderTensorDict(TensorDictSequence):\n", " def __init__(\n", - " self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " self,\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", " ):\n", " super().__init__(\n", " AttentionBlockTensorDict(\n", - " to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", " ),\n", " TensorDictModule(\n", - " FFN(to_dim, 4 * to_dim), in_keys=[to_name], out_keys=[\"X_out\"]\n", + " FFN(to_dim, 4 * to_dim),\n", + " in_keys=[to_name],\n", + " out_keys=[\"X_out\"],\n", " ),\n", " TensorDictModule(\n", " SkipLayerNorm(to_len, to_dim),\n", @@ -571,14 +622,33 @@ "\n", "class TransformerBlockDecoderTensorDict(TensorDictSequence):\n", " def __init__(\n", - " self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " self,\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", " ):\n", " super().__init__(\n", " AttentionBlockTensorDict(\n", - " to_name, to_name, to_dim, to_len, to_dim, latent_dim, num_heads\n", + " to_name,\n", + " to_name,\n", + " to_dim,\n", + " to_len,\n", + " to_dim,\n", + " latent_dim,\n", + " num_heads,\n", " ),\n", " TransformerBlockEncoderTensorDict(\n", - " to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", " ),\n", " )" ] @@ -642,7 +712,7 @@ "id": "34d2b6cb", "metadata": {}, "source": [ - "The output of the transformer layer can now be found at `tokens[\"X_to\"]`" + "The output of the attention can now be found at `tokens[\"X_to\"]`" ] }, { @@ -650,7 +720,7 @@ "id": "42dbfae5", "metadata": {}, "source": [ - "We can now create the transformer encoder and decoder.\n", + "We create the transformer encoder and decoder.\n", "\n", "For an encoder, we just need to take the same tokens for both queries, keys and values.\n", "\n", @@ -679,7 +749,13 @@ " super().__init__(\n", " *[\n", " TransformerBlockEncoderTensorDict(\n", - " to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " to_name,\n", + " from_name,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", " )\n", " for _ in range(num_blocks)\n", " ]\n", @@ -701,7 +777,13 @@ " super().__init__(\n", " *[\n", " TransformerBlockDecoderTensorDict(\n", - " to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads\n", + " to_name,\n", + " from_name,\n", + " 
to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", " )\n", " for _ in range(num_blocks)\n", " ]\n", @@ -790,7 +872,14 @@ ")\n", "\n", "transformer = TransformerTensorDict(\n", - " 6, \"X_encode\", \"X_decode\", to_dim, to_len, from_dim, latent_dim, num_heads\n", + " 6,\n", + " \"X_encode\",\n", + " \"X_decode\",\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " latent_dim,\n", + " num_heads,\n", ")\n", "\n", "transformer(tokens)\n", From 7836c02e372520382997c61c8d214fa4794ef5fc Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Thu, 14 Jul 2022 11:38:24 +0100 Subject: [PATCH 12/21] Suggested changes and do and dont --- tutorials/src/transformer.py | 78 ++++ tutorials/tensordictmodule.ipynb | 624 ++++++++++++++++--------------- 2 files changed, 410 insertions(+), 292 deletions(-) create mode 100644 tutorials/src/transformer.py diff --git a/tutorials/src/transformer.py b/tutorials/src/transformer.py new file mode 100644 index 00000000000..fad1fd0b84a --- /dev/null +++ b/tutorials/src/transformer.py @@ -0,0 +1,78 @@ +import torch +import torch.nn as nn + + +class TokensToQKV(nn.Module): + def __init__(self, to_dim, from_dim, latent_dim): + super().__init__() + self.q = nn.Linear(to_dim, latent_dim) + self.k = nn.Linear(from_dim, latent_dim) + self.v = nn.Linear(from_dim, latent_dim) + + def forward(self, X_to, X_from): + Q = self.q(X_to) + K = self.k(X_from) + V = self.v(X_from) + return Q, K, V + + +class SplitHeads(nn.Module): + def __init__(self, num_heads): + super().__init__() + self.num_heads = num_heads + + def forward(self, Q, K, V): + batch_size, to_num, latent_dim = Q.shape + _, from_num, _ = K.shape + d_tensor = latent_dim // self.num_heads + Q = Q.reshape( + batch_size, to_num, self.num_heads, d_tensor + ).transpose(1, 2) + K = K.reshape( + batch_size, from_num, self.num_heads, d_tensor + ).transpose(1, 2) + V = V.reshape( + batch_size, from_num, self.num_heads, d_tensor + ).transpose(1, 2) + return Q, K, V + + +class Attention(nn.Module): + def __init__(self, latent_dim, to_dim): + super().__init__() + self.softmax = nn.Softmax(dim=-1) + self.out = nn.Linear(latent_dim, to_dim) + + def forward(self, Q, K, V): + batch_size, n_heads, to_num, d_in = Q.shape + attn = self.softmax(Q @ K.transpose(2, 3) / d_in) + out = attn @ V + out = self.out( + out.transpose(1, 2).reshape( + batch_size, to_num, n_heads * d_in + ) + ) + return out, attn + + +class SkipLayerNorm(nn.Module): + def __init__(self, to_len, to_dim): + super().__init__() + self.layer_norm = nn.LayerNorm((to_len, to_dim)) + + def forward(self, x_0, x_1): + return self.layer_norm(x_0 + x_1) + + +class FFN(nn.Module): + def __init__(self, to_dim, hidden_dim, dropout_rate=0.2): + super().__init__() + self.FFN = nn.Sequential( + nn.Linear(to_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, to_dim), + nn.Dropout(dropout_rate), + ) + + def forward(self, X): + return self.FFN(X) \ No newline at end of file diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb index 5ab06be52bc..590b4a8d252 100644 --- a/tutorials/tensordictmodule.ipynb +++ b/tutorials/tensordictmodule.ipynb @@ -36,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "5b0241ab", "metadata": {}, "outputs": [], @@ -61,12 +61,12 @@ "id": "1d21a711", "metadata": {}, "source": [ - "Let's suppose we have `TensorDict` with 2 entries `\"a\"` and `\"b\"` but only the value associated with `\"a\"` has to be read by the network." 
+ "We have a `TensorDict` with 2 entries `\"a\"` and `\"b\"` but only the value associated with `\"a\"` has to be read by the network." ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 4, "id": "6f33781f", "metadata": {}, "outputs": [ @@ -111,12 +111,12 @@ "id": "06a20c22", "metadata": {}, "source": [ - "Suppose we have a slightly more complex network that takes 2 entries and averages them into a single output tensor." + "Suppose we have a slightly more complex network that takes 2 entries and averages them into a single output tensor. To make a `TensorDictModule` instance read multiple input values, one must register them in the `in_keys` keyword argument of the constructor.\"" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "69098393", "metadata": {}, "outputs": [], @@ -133,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "2dd686bb", "metadata": {}, "outputs": [ @@ -150,7 +150,7 @@ " is_shared=False)" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -177,12 +177,12 @@ "metadata": {}, "source": [ "### Example 3: Multiple outputs\n", - "TensorDictModule not only supports multiple inputs but also multiple outputs." + "Similarly, `TensorDictModule` not only supports multiple inputs but also multiple outputs. To make a `TensorDictModule` instance write to multiple output values values, one must register them in the `out_keys` keyword argument of the constructor.\"" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "0b7f709b", "metadata": {}, "outputs": [], @@ -199,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "1b2b465f", "metadata": {}, "outputs": [ @@ -216,7 +216,7 @@ " is_shared=False)" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -254,12 +254,14 @@ "id": "89b157d5-322c-45d6-bec9-20440b78a2bf", "metadata": {}, "source": [ - "To combine multiples `TensorDictModule`instances, we can une `TensorDictSequence`. This block will take the input of the n-1th `TensorDictModule` in a list and feed it to the nth `TensorDictModule`" + "To combine multiples `TensorDictModule` instances, we can use `TensorDictSequence`. We create a list where each `TensorDictModule` must be executed sequentially. `TensorDictSequence` will take the output of the n-1th `TensorDictModule` in the list and feed it to the nth `TensorDictModule`.\n", + "\n", + "We can also gather the input needed by `TensorDictSequence` with the `in_keys` property, and the output keys are found at the `out_keys` attribute." 
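As a quick illustration, here is a minimal sketch of how those two attributes could be inspected on a small sequence; the exact content and ordering of `in_keys`/`out_keys` on a `TensorDictSequence` is an assumption based on the description above, and the import paths follow this tutorial.

```python
import torch
import torch.nn as nn
from torchrl.data import TensorDict
from torchrl.modules import TensorDictModule, TensorDictSequence

layer_1 = TensorDictModule(nn.Linear(3, 4), in_keys=["a"], out_keys=["hidden"])
layer_2 = TensorDictModule(nn.Linear(4, 2), in_keys=["hidden"], out_keys=["output"])
seq = TensorDictSequence(layer_1, layer_2)

# in_keys should expose what the sequence needs from the caller ("a" here,
# since "hidden" is produced internally), while out_keys lists what it writes.
print(seq.in_keys)
print(seq.out_keys)

tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
seq(tensordict)
assert tensordict["output"].shape == torch.Size([5, 2])
```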
] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "id": "7e36d071-df67-4232-a8a9-78e79b32fef2", "metadata": {}, "outputs": [], @@ -272,12 +274,16 @@ " out_keys=[\"output_1\", \"output_2\"],\n", ")\n", "mergelinear = TensorDictModule(\n", - " MergeLinear(4, 10, 13), in_keys=[\"output_1\", \"output_2\"], out_keys=[\"output\"]\n", + " MergeLinear(4, 10, 13),\n", + " in_keys=[\"output_1\", \"output_2\"],\n", + " out_keys=[\"output\"],\n", ")\n", "\n", "split_and_merge_linear = TensorDictSequence(splitlinear, mergelinear)\n", "\n", - "assert split_and_merge_linear(tensordict)['output'].shape == torch.Size([5, 13])" + "assert split_and_merge_linear(tensordict)[\n", + " \"output\"\n", + "].shape == torch.Size([5, 13])" ] }, { @@ -293,12 +299,12 @@ "id": "e2718a12", "metadata": {}, "source": [ - "TensorDictModule is compatible with functorch. We can use make_functional_with_buffers on top of it." + "`TensorDictModule` comes with its own `make_functional_with_buffers` method to make it functional." ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "b553bed1", "metadata": {}, "outputs": [ @@ -315,7 +321,7 @@ " is_shared=False)" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -337,23 +343,23 @@ "id": "50ac0393", "metadata": {}, "source": [ - "We can also use vmap. Let's do some model ensembling with it." + "We can also use `vmap`. We can do model ensembling with it." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "86ccb7be", "metadata": {}, "outputs": [ { "ename": "RuntimeError", - "evalue": "batched == nullptr INTERNAL ASSERT FAILED at \"/private/var/folders/fn/c72nxv0x0xj4bzdgq3b0r86r0000gn/T/pip-req-build-i1bssf2g/functorch/csrc/Interpreter.cpp\":95, please report a bug to PyTorch. ", + "evalue": "batched == nullptr INTERNAL ASSERT FAILED at \"/private/var/folders/fn/c72nxv0x0xj4bzdgq3b0r86r0000gn/T/pip-req-build-ulv3uwfj/functorch/csrc/Interpreter.cpp\":95, please report a bug to PyTorch. 
", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 28\u001b[0m params \u001b[38;5;241m=\u001b[39m transpose_stack(params)\n\u001b[1;32m 29\u001b[0m buffers \u001b[38;5;241m=\u001b[39m transpose_stack(buffers)\n\u001b[0;32m---> 30\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbuffers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvmap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", + "Input \u001b[0;32mIn [11]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 32\u001b[0m buffers \u001b[38;5;241m=\u001b[39m transpose_stack(buffers)\n\u001b[1;32m 34\u001b[0m \u001b[38;5;66;03m### Finally we can apply the function to the stack of params and buffer using vmap\u001b[39;00m\n\u001b[0;32m---> 35\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbuffers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvmap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File 
\u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:346\u001b[0m, in \u001b[0;36mTensorDictModule.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 341\u001b[0m tensordict: _TensorDict,\n\u001b[1;32m 342\u001b[0m tensordict_out: Optional[_TensorDict] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 343\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 344\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m _TensorDict:\n\u001b[1;32m 345\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(tensordict\u001b[38;5;241m.\u001b[39mget(in_key, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;28;01mfor\u001b[39;00m in_key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_keys)\n\u001b[0;32m--> 346\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensors, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 348\u001b[0m tensors \u001b[38;5;241m=\u001b[39m (tensors,)\n", "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:333\u001b[0m, in \u001b[0;36mTensorDictModule._call_module\u001b[0;34m(self, tensors, **kwargs)\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(err_msg\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 328\u001b[0m kwargs_pruned \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 329\u001b[0m key: item\n\u001b[1;32m 330\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, item \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvmap\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 332\u001b[0m }\n\u001b[0;32m--> 333\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mparams\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbuffers\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs_pruned\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n\u001b[1;32m 335\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", @@ -363,7 +369,7 @@ "File \u001b[0;32m~/.local/lib/python3.9/site-packages/functorch/_src/make_functional.py:282\u001b[0m, in \u001b[0;36mFunctionalModuleWithBuffers.forward\u001b[0;34m(self, params, buffers, *args, **kwargs)\u001b[0m\n\u001b[1;32m 277\u001b[0m old_state \u001b[38;5;241m=\u001b[39m _swap_state(\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstateless_model,\n\u001b[1;32m 279\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mall_names_map,\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28mlist\u001b[39m(params) \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlist\u001b[39m(buffers))\n\u001b[1;32m 281\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 282\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstateless_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 283\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 284\u001b[0m \u001b[38;5;66;03m# Remove the loaded state on self.stateless_model\u001b[39;00m\n\u001b[1;32m 285\u001b[0m _swap_state(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstateless_model, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mall_names_map, old_state)\n", "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 113\u001b[0m 
\u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mRuntimeError\u001b[0m: batched == nullptr INTERNAL ASSERT FAILED at \"/private/var/folders/fn/c72nxv0x0xj4bzdgq3b0r86r0000gn/T/pip-req-build-i1bssf2g/functorch/csrc/Interpreter.cpp\":95, please report a bug to PyTorch. " + "\u001b[0;31mRuntimeError\u001b[0m: batched == nullptr INTERNAL ASSERT FAILED at \"/private/var/folders/fn/c72nxv0x0xj4bzdgq3b0r86r0000gn/T/pip-req-build-ulv3uwfj/functorch/csrc/Interpreter.cpp\":95, please report a bug to PyTorch. " ] } ], @@ -383,13 +389,15 @@ "def transpose_stack(tuple_of_tuple_of_tensors):\n", " tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))\n", " results = tuple(\n", - " torch.stack(shards).detach()\n", - " for shards in tuple_of_tuple_of_tensors\n", + " torch.stack(shards) for shards in tuple_of_tuple_of_tensors\n", " )\n", " return results\n", "\n", "\n", + "### Let's extract the functional version of the common model\n", "func = splitlinear_models[0].make_functional_with_buffers()[0]\n", + "\n", + "### We also extract parameters and buffers for every block and stack them together\n", "params, buffers = zip(\n", " *[\n", " splitlinear.make_functional_with_buffers()[1]\n", @@ -398,51 +406,280 @@ ")\n", "params = transpose_stack(params)\n", "buffers = transpose_stack(buffers)\n", + "\n", + "### Finally we can apply the function to the stack of params and buffer using vmap\n", "func(tensordict, params=params, buffers=buffers, vmap=True).shape" ] }, + { + "cell_type": "markdown", + "id": "31be6c45-10fb-4fd1-a52f-92214b76c00a", + "metadata": {}, + "source": [ + "## Do's and don't with `TensorDictModule`\n", + "\n", + "When `TensorDictModule`, we need to use `TensorDictSequence` if we want to merge it with other operations.\n", + "\n", + "Do not use `nn.Module` wrappers. This would break some of `TensorDictModule` features such as `functorch` compatibility." 
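To make the recommendation concrete, here is a minimal sketch contrasting the anti-pattern with the intended composition; the `Wrapper` class is hypothetical and only written to show what to avoid.

```python
import torch
import torch.nn as nn
from torchrl.data import TensorDict
from torchrl.modules import TensorDictModule, TensorDictSequence

inner = TensorDictModule(nn.Linear(3, 4), in_keys=["a"], out_keys=["b"])


# Don't: burying a TensorDictModule inside a plain nn.Module hides its
# in_keys/out_keys bookkeeping and its dedicated functional API.
class Wrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, tensordict):
        return self.module(tensordict)


# Do: compose with TensorDictSequence, which keeps the keys visible and
# preserves the functional conversion on the whole pipeline.
head = TensorDictModule(nn.Linear(4, 2), in_keys=["b"], out_keys=["c"])
pipeline = TensorDictSequence(inner, head)

tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
pipeline(tensordict)
assert tensordict["c"].shape == torch.Size([5, 2])
```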
+ ] }, { "cell_type": "markdown", "id": "22e65356-d8b3-4197-84b8-598330c1ddc8", "metadata": {}, "source": [ "## TensorDictModule for RL" ] }, { "cell_type": "markdown", "id": "8d49a911-933c-476f-8c9a-00e006ed043c", "metadata": {}, "source": [ "In the context of RL, torchrl offers a few wrappers around `TensorDictModule`." ] }, { "cell_type": "markdown", "id": "e33904a6-d405-45db-a713-47493ca8ee33", "metadata": { "tags": [] }, "source": [ "### `ProbabilisticTensorDictModule`" ] }, { "cell_type": "markdown", "id": "fea4eead-47b4-4029-a8ff-e3c3faf51b0f", "metadata": {}, "source": [ "`ProbabilisticTensorDictModule` is a special case of a `TensorDictModule` where the output is\n", "sampled given some rule, specified by the input `default_interaction_mode`\n", "argument and the `exploration_mode()` global function.\n", "\n", "It consists of a wrapper around another `TensorDictModule` that returns a tensordict\n", "updated with the distribution parameters. `ProbabilisticTensorDictModule` is\n", "responsible for constructing the distribution (through the `get_dist()` method)\n", "and/or sampling from this distribution (through a regular `__call__()` to the\n", "module)." ] }, { "cell_type": "code", "execution_count": 12, "id": "9f25a6ba-c9f5-4eea-8e4b-c4f456d19b4c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TensorDict(\n", " fields={\n", " hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),\n", " input: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n", " batch_size=torch.Size([3]),\n", " device=cpu,\n", " is_shared=False)\n", "TensorDict(\n", " fields={\n", " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", " hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),\n", " input: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", " loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", " sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n", " scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n", " batch_size=torch.Size([3]),\n", " device=cpu,\n", " is_shared=False)\n" ] } ], "source": [ "from torchrl.modules import ProbabilisticTensorDictModule\n", "from torchrl.modules import TanhNormal, NormalParamWrapper\n", "import functorch\n", "td = TensorDict({\"input\": torch.randn(3, 4), \"hidden\": torch.randn(3, 8)}, [3,])\n", "net = NormalParamWrapper(torch.nn.GRUCell(4, 8))\n", "module = TensorDictModule(net, in_keys=[\"input\", \"hidden\"], out_keys=[\"loc\", \"scale\"])\n", "td_module = ProbabilisticTensorDictModule(\n", " module=module,\n", " dist_param_keys=[\"loc\", \"scale\"],\n", " out_key_sample=[\"action\"],\n", " distribution_class=TanhNormal,\n", " return_log_prob=True,\n", " )\n", "print(td)\n", "td_module(td)\n", "print(td)" ] }, { "cell_type": "markdown", "id": "406b1caa-bcec-4317-b685-10df23352154",
"metadata": {}, + "source": [ + "### `Actor`" + ] + }, + { + "cell_type": "markdown", + "id": "e139de7d-0250-49c0-b495-8b5a404821f5", + "metadata": {}, + "source": [ + "Actor inherits from `TensorDictModule` and comes with a default value for `out_keys` of `[\"action\"]`.\n" + ] + }, + { + "cell_type": "markdown", + "id": "cceeade9-47f1-4e92-897a-dd226c9371a6", + "metadata": {}, + "source": [ + "### `ProbabilisticActor`" + ] + }, + { + "cell_type": "markdown", + "id": "4fd0f53e-90aa-49a9-9d8f-5a260255e556", + "metadata": {}, + "source": [ + "General class for probabilistic actors in RL that inherits from `ProbabilisticTensorDictModule`.\n", + "Similarly to `Actor`, it comes with default values for the `out_keys` (`[\"action\"]`).\n" + ] + }, + { + "cell_type": "markdown", + "id": "dbd48bb2-b93b-4766-b7a7-19d500f17e2d", + "metadata": {}, + "source": [ + "### `ActorCriticOperator`" + ] + }, + { + "cell_type": "markdown", + "id": "8cc42407-4e95-4bf0-8901-5d1a4e3b2044", + "metadata": {}, + "source": [ + "Similarly, `ActorCriticOperator` inherits from `TensorDictSequence`.\n", + "and wraps both and actor network and a value Network. \n", + "`ActorCriticOperator` will first compute the action from the actor and then the value according to this action." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "5b6c6035-f9cc-41e7-bf3a-f88936f93b70", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TensorDict(\n", + " fields={\n", + " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n", + " batch_size=torch.Size([3]),\n", + " device=cpu,\n", + " is_shared=False)\n", + "TensorDict(\n", + " fields={\n", + " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n", + " scale: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)},\n", + " batch_size=torch.Size([3]),\n", + " device=cpu,\n", + " is_shared=False)\n", + "TensorDict(\n", + " fields={\n", + " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n", + " scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n", + " batch_size=torch.Size([3]),\n", + " device=cpu,\n", + " is_shared=False)\n", + "TensorDict(\n", + " fields={\n", + " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " observation: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),\n", + " scale: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", + " state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)},\n", + " batch_size=torch.Size([3]),\n", + " device=cpu,\n", + " is_shared=False)\n" + ] + } + ], + "source": [ + "from torchrl.modules import (\n", + " MLP,\n", + " ActorCriticOperator,\n", + " NormalParamWrapper,\n", + " TanhNormal,\n", + " ValueOperator,\n", + ")\n", + 
"from torchrl.modules.tensordict_module import ProbabilisticActor\n", "\n", - "splitlinear = TensorDictModule(\n", - " nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"output\"]\n", + "module_hidden = torch.nn.Linear(4, 4)\n", + "td_module_hidden = TensorDictModule(\n", + " module=module_hidden,\n", + " in_keys=[\"observation\"],\n", + " out_keys=[\"hidden\"],\n", + ")\n", + "module_action = NormalParamWrapper(torch.nn.Linear(4, 8))\n", + "module_action = TensorDictModule(\n", + " module_action, in_keys=[\"hidden\"], out_keys=[\"loc\", \"scale\"]\n", + ")\n", + "td_module_action = ProbabilisticActor(\n", + " module=module_action,\n", + " dist_param_keys=[\"loc\", \"scale\"],\n", + " out_key_sample=[\"action\"],\n", + " distribution_class=TanhNormal,\n", + " return_log_prob=True,\n", + ")\n", + "module_value = MLP(in_features=8, out_features=1, num_cells=[])\n", + "td_module_value = ValueOperator(\n", + " module=module_value,\n", + " in_keys=[\"hidden\", \"action\"],\n", + " out_keys=[\"state_action_value\"],\n", + ")\n", + "td_module = ActorCriticOperator(\n", + " td_module_hidden, td_module_action, td_module_value\n", + ")\n", + "td = TensorDict(\n", + " {\"observation\": torch.randn(3, 4)},\n", + " [\n", + " 3,\n", + " ],\n", ")\n", - "func, param, buffers = make_functional_with_buffers(splitlinear)\n", - "func(param, buffers, tensordict)" + "print(td)\n", + "td_clone = td_module(td.clone())\n", + "print(td_clone)\n", + "td_clone = td_module.get_policy_operator()(td.clone())\n", + "print(td_clone) # no value\n", + "td_clone = td_module.get_critic_operator()(td.clone())\n", + "print(td_clone) # no action" ] }, { "cell_type": "markdown", "id": "6304a098", - "metadata": {}, + "metadata": { + "tags": [] + }, "source": [ - "### Example 6: Implementing a transformer using TensorDictModule\n", - "We can easily create a transformer that reads TensorDict objects using TensorDictModule.\n", + "## Showcase: Implementing a transformer using TensorDictModule\n", + "To demonstrate the flexibility of `TensorDictModule`, we are going to create a transformer that reads `TensorDict` objects using `TensorDictModule`.\n", "\n", "The following figure shows the classical transformer architecture (Vaswani et al, 2017) \n", "\n", @@ -450,103 +687,38 @@ "\n", "We have let the positional encoders aside for simplicity.\n", "\n", - "Let's first implement the classical transformers blocks." 
+ "Let's first import the classical transformers blocks (see `src/transformer.py`for more details.)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "id": "e1f7ba7b", "metadata": {}, "outputs": [], "source": [ - "class TokensToQKV(nn.Module):\n", - " def __init__(self, to_dim, from_dim, latent_dim):\n", - " super().__init__()\n", - " self.q = nn.Linear(to_dim, latent_dim)\n", - " self.k = nn.Linear(from_dim, latent_dim)\n", - " self.v = nn.Linear(from_dim, latent_dim)\n", - "\n", - " def forward(self, X_to, X_from):\n", - " Q = self.q(X_to)\n", - " K = self.k(X_from)\n", - " V = self.v(X_from)\n", - " return Q, K, V\n", - "\n", - "\n", - "class SplitHeads(nn.Module):\n", - " def __init__(self, num_heads):\n", - " super().__init__()\n", - " self.num_heads = num_heads\n", - "\n", - " def forward(self, Q, K, V):\n", - " batch_size, to_num, latent_dim = Q.shape\n", - " _, from_num, _ = K.shape\n", - " d_tensor = latent_dim // self.num_heads\n", - " Q = Q.reshape(\n", - " batch_size, to_num, self.num_heads, d_tensor\n", - " ).transpose(1, 2)\n", - " K = K.reshape(\n", - " batch_size, from_num, self.num_heads, d_tensor\n", - " ).transpose(1, 2)\n", - " V = V.reshape(\n", - " batch_size, from_num, self.num_heads, d_tensor\n", - " ).transpose(1, 2)\n", - " return Q, K, V\n", - "\n", - "\n", - "class Attention(nn.Module):\n", - " def __init__(self, latent_dim, to_dim):\n", - " super().__init__()\n", - " self.softmax = nn.Softmax(dim=-1)\n", - " self.out = nn.Linear(latent_dim, to_dim)\n", - "\n", - " def forward(self, Q, K, V):\n", - " batch_size, n_heads, to_num, d_in = Q.shape\n", - " attn = self.softmax(Q @ K.transpose(2, 3) / d_in)\n", - " out = attn @ V\n", - " out = self.out(\n", - " out.transpose(1, 2).reshape(\n", - " batch_size, to_num, n_heads * d_in\n", - " )\n", - " )\n", - " return out, attn\n", - "\n", - "\n", - "class SkipLayerNorm(nn.Module):\n", - " def __init__(self, to_len, to_dim):\n", - " super().__init__()\n", - " self.layer_norm = nn.LayerNorm((to_len, to_dim))\n", - "\n", - " def forward(self, x_0, x_1):\n", - " return self.layer_norm(x_0 + x_1)\n", - "\n", - "\n", - "class FFN(nn.Module):\n", - " def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):\n", - " super().__init__()\n", - " self.FFN = nn.Sequential(\n", - " nn.Linear(to_dim, hidden_dim),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_dim, to_dim),\n", - " nn.Dropout(dropout_rate),\n", - " )\n", - "\n", - " def forward(self, X):\n", - " return self.FFN(X)" + "from tutorials.src.transformer import (\n", + " FFN,\n", + " Attention,\n", + " SkipLayerNorm,\n", + " SplitHeads,\n", + " TokensToQKV,\n", + ")" ] }, { "cell_type": "markdown", - "id": "b5f6f291", + "id": "c3258540-acb2-4090-a374-822dfcb857bd", "metadata": {}, "source": [ - "We can build the encoder and decoder blocks that will be part of the transformer thanks to the `TensorDictModule`. Since the changes affect the `TensorDict`, we just need to map outputs to the right name such as it is picked up by the next block." + "We first create the `AttentionBlockTensorDict`, the attention block using `TensorDictModule` and `TensorDictSequence`.\n", + "\n", + "The wiring operation that connects the modules to each other requires us to indicate which key each of them must read and write. Unlike `nn.Sequence`, a `TensorDictSequence` can read/write more than one input/output. Moreover, its components inputs need not be identical to the previous layers outputs, allowing us to code complicated neural architecture." 
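Here is a minimal, self-contained sketch of that wiring freedom before the real blocks: the second module reads both an intermediate result and one of the original inputs. The `Mix` module is a hypothetical illustration, not part of the tutorial's transformer code, and the import paths are assumed from earlier in this tutorial.

```python
import torch
import torch.nn as nn
from torchrl.data import TensorDict
from torchrl.modules import TensorDictModule, TensorDictSequence


class Mix(nn.Module):
    def __init__(self, d_hidden, d_from, d_out):
        super().__init__()
        self.lin = nn.Linear(d_hidden + d_from, d_out)

    def forward(self, hidden, x_from):
        # Receives the two tensors registered in in_keys, in order.
        return self.lin(torch.cat([hidden, x_from], dim=-1))


project = TensorDictModule(nn.Linear(5, 8), in_keys=["X_to"], out_keys=["hidden"])
# mix reads "hidden" (produced by project) AND the untouched input "X_from":
# the wiring is declared through key names, not through call order alone.
mix = TensorDictModule(Mix(8, 6, 5), in_keys=["hidden", "X_from"], out_keys=["X_to"])
block = TensorDictSequence(project, mix)

tokens = TensorDict(
    {"X_to": torch.randn(2, 5), "X_from": torch.randn(2, 6)}, batch_size=[2]
)
block(tokens)
assert tokens["X_to"].shape == torch.Size([2, 5])
```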
] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "id": "eb9775bd", "metadata": {}, "outputs": [], @@ -583,9 +755,24 @@ " in_keys=[to_name, \"X_out\"],\n", " out_keys=[to_name],\n", " ),\n", - " )\n", - "\n", - "\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "b5f6f291", + "metadata": {}, + "source": [ + "We build the encoder and decoder blocks that will be part of the transformer thanks to `TensorDictModule`." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f902006d-3f89-4ea6-84e0-a193a53e42db", + "metadata": {}, + "outputs": [], + "source": [ "class TransformerBlockEncoderTensorDict(TensorDictSequence):\n", " def __init__(\n", " self,\n", @@ -653,68 +840,6 @@ " )" ] }, - { - "cell_type": "code", - "execution_count": 14, - "id": "e9601f5a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TensorDict(\n", - " fields={\n", - " Attn: Tensor(torch.Size([8, 2, 3, 10]), dtype=torch.float32),\n", - " K: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),\n", - " Q: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),\n", - " V: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),\n", - " X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),\n", - " X_out: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),\n", - " X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},\n", - " batch_size=torch.Size([8]),\n", - " device=cpu,\n", - " is_shared=False)" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "to_dim = 5\n", - "from_dim = 6\n", - "latent_dim = 10\n", - "to_len = 3\n", - "from_len = 10\n", - "batch_size = 8\n", - "num_heads = 2\n", - "\n", - "tokens = TensorDict(\n", - " {\n", - " \"X_to\": torch.randn(batch_size, to_len, to_dim),\n", - " \"X_from\": torch.randn(batch_size, from_len, from_dim),\n", - " },\n", - " batch_size=[batch_size],\n", - ")\n", - "\n", - "transformer_block = AttentionBlockTensorDict(\n", - " \"X_to\", \"X_from\", to_dim, to_len, from_dim, latent_dim, num_heads\n", - ")\n", - "\n", - "transformer_block(tokens)\n", - "\n", - "tokens" - ] - }, - { - "cell_type": "markdown", - "id": "34d2b6cb", - "metadata": {}, - "source": [ - "The output of the attention can now be found at `tokens[\"X_to\"]`" - ] - }, { "cell_type": "markdown", "id": "42dbfae5", @@ -729,7 +854,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "id": "1c6c85b5", "metadata": {}, "outputs": [], @@ -826,10 +951,18 @@ " )" ] }, + { + "cell_type": "markdown", + "id": "15b1b4e2-918d-40bc-a245-15be0e9cc276", + "metadata": {}, + "source": [ + "We now test our new `TransformerTensorDict`" + ] + }, { "cell_type": "code", - "execution_count": 16, - "id": "09fa9f9a", + "execution_count": 18, + "id": "7a680452-1462-4ee6-ba04-dce0bb855870", "metadata": {}, "outputs": [ { @@ -849,7 +982,7 @@ " is_shared=False)" ] }, - "execution_count": 16, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -894,99 +1027,6 @@ "We've achieved to create a transformer with `TensorDictModule`. 
This shows that `TensorDictModule` is a flexible module that can implement complex operations" ] }, - { - "cell_type": "markdown", - "id": "22e65356-d8b3-4197-84b8-598330c1ddc8", - "metadata": {}, - "source": [ - "## TensorDictModule for RL" - ] - }, - { - "cell_type": "markdown", - "id": "8d49a911-933c-476f-8c9a-00e006ed043c", - "metadata": {}, - "source": [ - "In the context of RL torchrl offers a few wrappers on `TensorDictModule`" - ] - }, - { - "cell_type": "markdown", - "id": "e33904a6-d405-45db-a713-47493ca8ee33", - "metadata": { - "tags": [] - }, - "source": [ - "### `ProbabilisticTensorDictModule`" - ] - }, - { - "cell_type": "markdown", - "id": "fea4eead-47b4-4029-a8ff-e3c3faf51b0f", - "metadata": {}, - "source": [ - "`ProbabilisticTDModule` is a special case of a `TensorDictModule` where the output is\n", - "sampled given some rule, specified by the input `default_interaction_mode`\n", - "argument and the `exploration_mode()` global function.\n", - "\n", - "It consists in a wrapper around another `TensorDictModule` that returns a tensordict\n", - "updated with the distribution parameters. `ProbabilisticTensorDictModule` is\n", - "responsible for constructing the distribution (through the `get_dist()` method)\n", - "and/or sampling from this distribution (through a regular `__call__()` to the\n", - "module)." - ] - }, - { - "cell_type": "markdown", - "id": "406b1caa-bcec-4317-b685-10df23352154", - "metadata": {}, - "source": [ - "### `Actor`" - ] - }, - { - "cell_type": "markdown", - "id": "e139de7d-0250-49c0-b495-8b5a404821f5", - "metadata": {}, - "source": [ - "Actor inherits from `TensorDictModule` and comes with a default value for `out_keys` of `[\"action\"]`.\n" - ] - }, - { - "cell_type": "markdown", - "id": "cceeade9-47f1-4e92-897a-dd226c9371a6", - "metadata": {}, - "source": [ - "### `ProbabilisticActor`" - ] - }, - { - "cell_type": "markdown", - "id": "4fd0f53e-90aa-49a9-9d8f-5a260255e556", - "metadata": {}, - "source": [ - "General class for probabilistic actors in RL that inherits from `ProbabilisticTensorDictModule`.\n", - "Similarly to `Actor`, it comes with default values for the `out_keys` (`[\"action\"]`).\n" - ] - }, - { - "cell_type": "markdown", - "id": "dbd48bb2-b93b-4766-b7a7-19d500f17e2d", - "metadata": {}, - "source": [ - "### `ActorCriticOperator`" - ] - }, - { - "cell_type": "markdown", - "id": "8cc42407-4e95-4bf0-8901-5d1a4e3b2044", - "metadata": {}, - "source": [ - "Similarly, `ActorCriticOperator` inherits from `TensorDictSequence`.\n", - "and wraps both and actor network and a value Network. \n", - "`ActorCriticOperator` will first compute the action from the actor and then the value according to this action."
- ] - }, { "cell_type": "markdown", "id": "bd08362a-8bb8-49fb-8038-1a60c5c01ea2", From 7dd7efb752258cab9b84bda0244cbdccd2e9ec6f Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Thu, 14 Jul 2022 13:44:38 +0100 Subject: [PATCH 13/21] Formating --- tutorials/src/transformer.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/tutorials/src/transformer.py b/tutorials/src/transformer.py index fad1fd0b84a..38cea672034 100644 --- a/tutorials/src/transformer.py +++ b/tutorials/src/transformer.py @@ -25,15 +25,9 @@ def forward(self, Q, K, V): batch_size, to_num, latent_dim = Q.shape _, from_num, _ = K.shape d_tensor = latent_dim // self.num_heads - Q = Q.reshape( - batch_size, to_num, self.num_heads, d_tensor - ).transpose(1, 2) - K = K.reshape( - batch_size, from_num, self.num_heads, d_tensor - ).transpose(1, 2) - V = V.reshape( - batch_size, from_num, self.num_heads, d_tensor - ).transpose(1, 2) + Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2) + K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2) + V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2) return Q, K, V @@ -47,11 +41,7 @@ def forward(self, Q, K, V): batch_size, n_heads, to_num, d_in = Q.shape attn = self.softmax(Q @ K.transpose(2, 3) / d_in) out = attn @ V - out = self.out( - out.transpose(1, 2).reshape( - batch_size, to_num, n_heads * d_in - ) - ) + out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in)) return out, attn @@ -75,4 +65,4 @@ def __init__(self, to_dim, hidden_dim, dropout_rate=0.2): ) def forward(self, X): - return self.FFN(X) \ No newline at end of file + return self.FFN(X) From 81be639380aded5cb66f7c5ef8390ec24a3e7966 Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Thu, 14 Jul 2022 13:47:48 +0100 Subject: [PATCH 14/21] Formating --- tutorials/src/transformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tutorials/src/transformer.py b/tutorials/src/transformer.py index 38cea672034..3cfd8d307a5 100644 --- a/tutorials/src/transformer.py +++ b/tutorials/src/transformer.py @@ -1,4 +1,3 @@ -import torch import torch.nn as nn From 78a40b286c6f121a8382e0ff9404997818064a4e Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Fri, 15 Jul 2022 14:12:47 +0100 Subject: [PATCH 15/21] Did some changes --- tutorials/tensordictmodule.ipynb | 87 +++++++++++++++++++------------- 1 file changed, 52 insertions(+), 35 deletions(-) diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb index 590b4a8d252..af7405cd28f 100644 --- a/tutorials/tensordictmodule.ipynb +++ b/tutorials/tensordictmodule.ipynb @@ -36,10 +36,19 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "id": "5b0241ab", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ndufour/anaconda3/envs/torch_rl/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", @@ -66,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "6f33781f", "metadata": {}, "outputs": [ @@ -83,6 +92,14 @@ " device=cpu,\n", " is_shared=False)\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ndufour/anaconda3/envs/torch_rl/lib/python3.9/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", + " and should_run_async(code)\n" + ] } ], "source": [ @@ -111,12 +128,12 @@ "id": "06a20c22", "metadata": {}, "source": [ - "Suppose we have a slightly more complex network that takes 2 entries and averages them into a single output tensor. To make a `TensorDictModule` instance read multiple input values, one must register them in the `in_keys` keyword argument of the constructor.\"" + "Suppose we have a slightly more complex network that takes 2 entries and averages them into a single output tensor. To make a `TensorDictModule` instance read multiple input values, one must register them in the `in_keys` keyword argument of the constructor." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "id": "69098393", "metadata": {}, "outputs": [], @@ -133,7 +150,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "2dd686bb", "metadata": {}, "outputs": [ @@ -150,7 +167,7 @@ " is_shared=False)" ] }, - "execution_count": 6, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -177,12 +194,12 @@ "metadata": {}, "source": [ "### Example 3: Multiple outputs\n", - "Similarly, `TensorDictModule` not only supports multiple inputs but also multiple outputs. To make a `TensorDictModule` instance write to multiple output values, one must register them in the `out_keys` keyword argument of the constructor.\"" + "Similarly, `TensorDictModule` not only supports multiple inputs but also multiple outputs. To make a `TensorDictModule` instance write to multiple output values, one must register them in the `out_keys` keyword argument of the constructor." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "0b7f709b", "metadata": {}, "outputs": [], @@ -199,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "id": "1b2b465f", "metadata": {}, "outputs": [ @@ -216,7 +233,7 @@ " is_shared=False)" ] }, - "execution_count": 8, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -237,7 +254,7 @@ "id": "859630c3", "metadata": {}, "source": [ - "As we shown previously, the `TensorDictModule` can take any `nn.Module` and perform the operations on a `TensorDict`. When having multiple input keys and output keys, make sure they match the order in the module.\n", + "As shown previously, `TensorDictModule` can take any `nn.Module` and perform the operations on a `TensorDict`. When having multiple input keys and output keys, make sure they match the order in the module.\n", "`TensorDictModule` can work with `TensorDict` instances that contain more tensors than what the `in_keys` attribute indicates.
Unless a `vmap` operator is used, the `TensorDict` is modified in-place." ] }, @@ -261,7 +278,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "id": "7e36d071-df67-4232-a8a9-78e79b32fef2", "metadata": {}, "outputs": [], @@ -304,7 +321,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "id": "b553bed1", "metadata": {}, "outputs": [ @@ -321,7 +338,7 @@ " is_shared=False)" ] }, - "execution_count": 10, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -348,7 +365,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "86ccb7be", "metadata": {}, "outputs": [ @@ -359,7 +376,7 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [11]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 32\u001b[0m buffers \u001b[38;5;241m=\u001b[39m transpose_stack(buffers)\n\u001b[1;32m 34\u001b[0m \u001b[38;5;66;03m### Finally we can apply the function to the stack of params and buffer using vmap\u001b[39;00m\n\u001b[0;32m---> 35\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbuffers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvmap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", + "Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 25\u001b[0m buffers \u001b[38;5;241m=\u001b[39m [torch\u001b[38;5;241m.\u001b[39mrandn(num_models, \u001b[38;5;241m*\u001b[39mb\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;28;01mfor\u001b[39;00m b \u001b[38;5;129;01min\u001b[39;00m buffers[\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m 27\u001b[0m \u001b[38;5;66;03m### Finally we can apply the function to the stack of params and buffer using vmap\u001b[39;00m\n\u001b[0;32m---> 28\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbuffers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvmap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:346\u001b[0m, in \u001b[0;36mTensorDictModule.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 341\u001b[0m tensordict: _TensorDict,\n\u001b[1;32m 342\u001b[0m tensordict_out: Optional[_TensorDict] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 343\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 344\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m _TensorDict:\n\u001b[1;32m 345\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(tensordict\u001b[38;5;241m.\u001b[39mget(in_key, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;28;01mfor\u001b[39;00m in_key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_keys)\n\u001b[0;32m--> 346\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensors, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 348\u001b[0m tensors \u001b[38;5;241m=\u001b[39m (tensors,)\n", "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:333\u001b[0m, in \u001b[0;36mTensorDictModule._call_module\u001b[0;34m(self, tensors, **kwargs)\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(err_msg\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 328\u001b[0m kwargs_pruned \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 329\u001b[0m key: item\n\u001b[1;32m 330\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, item \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m\"\u001b[39m, 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvmap\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 332\u001b[0m }\n\u001b[0;32m--> 333\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mparams\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbuffers\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs_pruned\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n\u001b[1;32m 335\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", @@ -386,14 +403,6 @@ "]\n", "\n", "\n", - "def transpose_stack(tuple_of_tuple_of_tensors):\n", - " tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))\n", - " results = tuple(\n", - " torch.stack(shards) for shards in tuple_of_tuple_of_tensors\n", - " )\n", - " return results\n", - "\n", - "\n", "### Let's extract the functional version of the common model\n", "func = splitlinear_models[0].make_functional_with_buffers()[0]\n", "\n", @@ -404,8 +413,9 @@ " for splitlinear in splitlinear_models\n", " ]\n", ")\n", - "params = transpose_stack(params)\n", - "buffers = transpose_stack(buffers)\n", + "## For simplicity we reinit the params. In a real application you need to stack all params\n", + "params = [torch.randn(num_models, *p.shape) for p in params[0]]\n", + "buffers = [torch.randn(num_models, *b.shape) for b in buffers[0]]\n", "\n", "### Finally we can apply the function to the stack of params and buffer using vmap\n", "func(tensordict, params=params, buffers=buffers, vmap=True).shape" @@ -418,9 +428,16 @@ "source": [ "## Do's and don't with `TensorDictModule`\n", "\n", - "When `TensorDictModule`, we need to use `TensorDictSequence` if we want to merge it with other operations.\n", + "Don't use `nn.Module` wrappers with `TensorDictModule` componants. This would break some of `TensorDictModule` features such as `functorch` compatibility. Do use `TensorDictSequence` instead.\n", + "\n", + "Don't use a different name for the output of a `TensorDictModule` as the output tensordict is just the input modified in-place:\n", + "\n", + "```python\n", + "tensordict = module(tensordict) # ok!\n", + "tensordict_out = module(tensordict) # don't!\n", + "```\n", "\n", - "Do not use `nn.Module` wrappers. This would break some of `TensorDictModule` features such as `functorch` compatibility." + "Don't use `nn.Sequence` but do usw `TensorDictSequence`. 
Same as nn.Module, it would break `functorch` compatibility" ] }, { @@ -467,22 +484,22 @@ }, { "cell_type": "code", - "execution_count": 12, - "id": "9f25a6ba-c9f5-4eea-8e4b-c4f456d19b4c", + "execution_count": 14, + "id": "9dd7846a-f12c-492e-a2ef-b0c67969234d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "TensorDict(\n", + "TensorDict before going through module: TensorDict(\n", " fields={\n", " hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),\n", " input: Tensor(torch.Size([3, 4]), dtype=torch.float32)},\n", " batch_size=torch.Size([3]),\n", " device=cpu,\n", " is_shared=False)\n", - "TensorDict(\n", + "TensorDict after going through module now has keys action, loc and scale: TensorDict(\n", " fields={\n", " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", " hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),\n", @@ -510,9 +527,9 @@ " distribution_class=TanhNormal,\n", " return_log_prob=True,\n", " )\n", - "print(td)\n", + "print(f\"TensorDict before going through module: {td}\")\n", "td_module(td)\n", - "print(td)" + "print(f\"TensorDict after going through module now has keys action, loc and scale: {td}\")" ] }, { From 49edf5dff82c1d73f44dcf573f6ac75ff7effc27 Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Mon, 18 Jul 2022 11:05:50 +0100 Subject: [PATCH 16/21] Made suggested changes --- tutorials/tensordictmodule.ipynb | 71 ++++++++++++++++---------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb index af7405cd28f..1b42bb522db 100644 --- a/tutorials/tensordictmodule.ipynb +++ b/tutorials/tensordictmodule.ipynb @@ -8,6 +8,14 @@ "# TensorDictModule" ] }, + { + "cell_type": "raw", + "id": "06f6d0ee-a4df-42e9-a2ab-ceb62e2ad5bf", + "metadata": {}, + "source": [ + "# TensorDictModule" + ] + }, { "cell_type": "markdown", "id": "94bd315a", "metadata": {}, @@ -18,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "0652352c", + "id": "bbc7e457-48b5-42d2-a8cf-092f0419d2d4", "metadata": {}, "source": [ "For a convenient usage of the `TensorDict` class with `nn.Module`, TorchRL provides an interface between the two, named `TensorDictModule`.
\n", @@ -36,19 +44,10 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "id": "5b0241ab", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/ndufour/anaconda3/envs/torch_rl/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", @@ -75,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "id": "6f33781f", "metadata": {}, "outputs": [ @@ -92,14 +91,6 @@ " device=cpu,\n", " is_shared=False)\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/ndufour/anaconda3/envs/torch_rl/lib/python3.9/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] } ], "source": [ @@ -133,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "id": "69098393", "metadata": {}, "outputs": [], @@ -150,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "id": "2dd686bb", "metadata": {}, "outputs": [ @@ -167,7 +158,7 @@ " is_shared=False)" ] }, - "execution_count": 4, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -199,7 +190,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "id": "0b7f709b", "metadata": {}, "outputs": [], @@ -216,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "id": "1b2b465f", "metadata": {}, "outputs": [ @@ -233,7 +224,7 @@ " is_shared=False)" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -271,14 +262,14 @@ "id": "89b157d5-322c-45d6-bec9-20440b78a2bf", "metadata": {}, "source": [ - "To combine multiples `TensorDictModule` instances, we can use `TensorDictSequence`. We create a list where each `TensorDictModule` must be executed sequentially. `TensorDictSequence` will take the output of the n-1th `TensorDictModule` in the list and feed it to the nth `TensorDictModule`.\n", + "To combine multiples `TensorDictModule` instances, we can use `TensorDictSequence`. We create a list where each `TensorDictModule` must be executed sequentially. `TensorDictSequence` will read and write keys to the tensordict following the sequence of modules provided.\n", "\n", - "We can also gather the input needed by `TensorDictSequence` with the `in_keys` property, and the output keys are found at the `out_keys` attribute." + "We can also gather the inputs needed by `TensorDictSequence` with the `in_keys` property, and the outputs keys are found at the `out_keys` attribute." 
] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "id": "7e36d071-df67-4232-a8a9-78e79b32fef2", "metadata": {}, "outputs": [], @@ -321,7 +312,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "id": "b553bed1", "metadata": {}, "outputs": [ @@ -338,7 +329,7 @@ " is_shared=False)" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -365,7 +356,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "86ccb7be", "metadata": {}, "outputs": [ @@ -376,7 +367,7 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 25\u001b[0m buffers \u001b[38;5;241m=\u001b[39m [torch\u001b[38;5;241m.\u001b[39mrandn(num_models, \u001b[38;5;241m*\u001b[39mb\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;28;01mfor\u001b[39;00m b \u001b[38;5;129;01min\u001b[39;00m buffers[\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m 27\u001b[0m \u001b[38;5;66;03m### Finally we can apply the function to the stack of params and buffer using vmap\u001b[39;00m\n\u001b[0;32m---> 28\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbuffers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvmap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", + "Input \u001b[0;32mIn [11]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 25\u001b[0m buffers \u001b[38;5;241m=\u001b[39m [torch\u001b[38;5;241m.\u001b[39mrandn(num_models, \u001b[38;5;241m*\u001b[39mb\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;28;01mfor\u001b[39;00m b \u001b[38;5;129;01min\u001b[39;00m buffers[\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m 27\u001b[0m \u001b[38;5;66;03m### Finally we can apply the function to the stack of params and buffer using vmap\u001b[39;00m\n\u001b[0;32m---> 28\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbuffers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvmap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:346\u001b[0m, in \u001b[0;36mTensorDictModule.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 341\u001b[0m tensordict: _TensorDict,\n\u001b[1;32m 342\u001b[0m tensordict_out: Optional[_TensorDict] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 343\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 344\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m _TensorDict:\n\u001b[1;32m 345\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(tensordict\u001b[38;5;241m.\u001b[39mget(in_key, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;28;01mfor\u001b[39;00m in_key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_keys)\n\u001b[0;32m--> 346\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensors, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 348\u001b[0m tensors \u001b[38;5;241m=\u001b[39m (tensors,)\n", "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:333\u001b[0m, in \u001b[0;36mTensorDictModule._call_module\u001b[0;34m(self, tensors, **kwargs)\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(err_msg\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 328\u001b[0m kwargs_pruned \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 329\u001b[0m key: item\n\u001b[1;32m 330\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, item \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m\"\u001b[39m, 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvmap\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 332\u001b[0m }\n\u001b[0;32m--> 333\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mparams\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbuffers\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs_pruned\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n\u001b[1;32m 335\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", @@ -428,7 +419,11 @@ "source": [ "## Do's and don't with `TensorDictModule`\n", "\n", - "Don't use `nn.Module` wrappers with `TensorDictModule` componants. This would break some of `TensorDictModule` features such as `functorch` compatibility. Do use `TensorDictSequence` instead.\n", + "Don't use `nn.Module` wrappers with `TensorDictModule` componants. This would break some of `TensorDictModule` features such as `functorch` compatibility. \n", + "\n", + "Don't use `nn.Sequence`, similar to nn.Module, it would break features such as `functorch` compatibility.\n", + "\n", + "Do use `TensorDictSequence` instead.\n", "\n", "Don't use a different name for the output of a `TensorDictModule` as the output tensordict is just the input modified in-place:\n", "\n", @@ -437,7 +432,11 @@ "tensordict_out = module(tensordict) # don't!\n", "```\n", "\n", - "Don't use `nn.Sequence` but do usw `TensorDictSequence`. Same as nn.Module, it would break `functorch` compatibility" + "Don't use `make_functional_with_buffers` from `functorch` directly.\n", + "\n", + "Do use `TensorDictModule.make_functional_with_buffers` instead.\n", + "\n", + "\n" ] }, { From 84de9745679c88d5a6aed9fabdd26f0ff6230c7a Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Tue, 19 Jul 2022 11:03:03 +0100 Subject: [PATCH 17/21] Added tensordictmodule tutorial to README.MD --- tutorials/README.md | 2 ++ tutorials/tensordictmodule.ipynb | 8 -------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/tutorials/README.md b/tutorials/README.md index 7ac2d8a1765..a094f78e710 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -7,5 +7,7 @@ For an overview of TorchRL, try the [TorchRL demo](demo.ipynb). Make sure you test the [TensorDict demo](tensordict.ipynb) to see what TensorDict is about and what it can do. +To understand how to use `TensorDict` with pytorch modules, make sure to check out the [TensorDictModule demo](tensordictmodule.ipynb). + Checkout the [environment demo](envs.ipynb) for a deep dive in the envs functionalities. 
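To make the do's and don'ts of the patch above concrete, here is a short sketch. The names follow this tutorial's API (`TensorDictSequence` and the module-level `make_functional_with_buffers`), which may differ in other torchrl versions:

```python
import torch.nn as nn
from torchrl.modules import TensorDictModule, TensorDictSequence

mod1 = TensorDictModule(nn.Linear(4, 8), in_keys=["obs"], out_keys=["hidden"])
mod2 = TensorDictModule(nn.Linear(8, 2), in_keys=["hidden"], out_keys=["action"])

policy = TensorDictSequence(mod1, mod2)  # do: keeps functorch compatibility
# policy = nn.Sequential(mod1, mod2)     # don't: a plain nn wrapper breaks it

# do: call the method carried by the module itself ...
fpolicy, (params, buffers) = policy.make_functional_with_buffers()
# ... don't: functorch.make_functional_with_buffers(policy), which bypasses
# the TensorDict-aware handling of inputs and outputs
```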
diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb index 1b42bb522db..be98271a13e 100644 --- a/tutorials/tensordictmodule.ipynb +++ b/tutorials/tensordictmodule.ipynb @@ -8,14 +8,6 @@ "# TensorDictModule" ] }, - { - "cell_type": "raw", - "id": "06f6d0ee-a4df-42e9-a2ab-ceb62e2ad5bf", - "metadata": {}, - "source": [ - "# TensorDictModule" - ] - }, { "cell_type": "markdown", "id": "94bd315a", From 9999c2b6500a896758141f4a5154f51de814e592 Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Tue, 19 Jul 2022 11:36:32 +0100 Subject: [PATCH 18/21] Clean rerun --- tutorials/tensordictmodule.ipynb | 85 +++++++++++++++++--------------- 1 file changed, 46 insertions(+), 39 deletions(-) diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb index be98271a13e..b0534ce46a6 100644 --- a/tutorials/tensordictmodule.ipynb +++ b/tutorials/tensordictmodule.ipynb @@ -36,10 +36,19 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "id": "5b0241ab", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ndufour/anaconda3/envs/torch_rl/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import torch\n", "import torch.nn as nn\n", @@ -66,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "6f33781f", "metadata": {}, "outputs": [ @@ -83,6 +92,14 @@ " device=cpu,\n", " is_shared=False)\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ndufour/anaconda3/envs/torch_rl/lib/python3.9/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", + " and should_run_async(code)\n" + ] } ], "source": [ @@ -116,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "id": "69098393", "metadata": {}, "outputs": [], @@ -133,7 +150,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "2dd686bb", "metadata": {}, "outputs": [ @@ -150,7 +167,7 @@ " is_shared=False)" ] }, - "execution_count": 6, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -182,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "0b7f709b", "metadata": {}, "outputs": [], @@ -199,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "id": "1b2b465f", "metadata": {}, "outputs": [ @@ -216,7 +233,7 @@ " is_shared=False)" ] }, - "execution_count": 8, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -261,7 +278,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "id": "7e36d071-df67-4232-a8a9-78e79b32fef2", "metadata": {}, "outputs": [], @@ -304,7 +321,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "id": "b553bed1", "metadata": {}, "outputs": [ @@ -321,7 +338,7 @@ " is_shared=False)" ] }, - "execution_count": 10, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -348,29 +365,19 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "id": "86ccb7be", "metadata": {}, "outputs": [ { - "ename": "RuntimeError", - "evalue": "batched == nullptr INTERNAL ASSERT FAILED at \"/private/var/folders/fn/c72nxv0x0xj4bzdgq3b0r86r0000gn/T/pip-req-build-ulv3uwfj/functorch/csrc/Interpreter.cpp\":95, please report a bug to PyTorch. 
", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [11]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 25\u001b[0m buffers \u001b[38;5;241m=\u001b[39m [torch\u001b[38;5;241m.\u001b[39mrandn(num_models, \u001b[38;5;241m*\u001b[39mb\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;28;01mfor\u001b[39;00m b \u001b[38;5;129;01min\u001b[39;00m buffers[\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m 27\u001b[0m \u001b[38;5;66;03m### Finally we can apply the function to the stack of params and buffer using vmap\u001b[39;00m\n\u001b[0;32m---> 28\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbuffers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvmap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", - "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:346\u001b[0m, in \u001b[0;36mTensorDictModule.forward\u001b[0;34m(self, tensordict, tensordict_out, **kwargs)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 341\u001b[0m tensordict: _TensorDict,\n\u001b[1;32m 342\u001b[0m tensordict_out: Optional[_TensorDict] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 343\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 344\u001b[0m ) 
\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m _TensorDict:\n\u001b[1;32m 345\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(tensordict\u001b[38;5;241m.\u001b[39mget(in_key, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;28;01mfor\u001b[39;00m in_key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_keys)\n\u001b[0;32m--> 346\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensors, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 348\u001b[0m tensors \u001b[38;5;241m=\u001b[39m (tensors,)\n", - "File \u001b[0;32m~/Documents/pytorch/rl/torchrl/modules/tensordict_module/common.py:333\u001b[0m, in \u001b[0;36mTensorDictModule._call_module\u001b[0;34m(self, tensors, **kwargs)\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(err_msg\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 328\u001b[0m kwargs_pruned \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 329\u001b[0m key: item\n\u001b[1;32m 330\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, item \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbuffers\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvmap\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 332\u001b[0m }\n\u001b[0;32m--> 333\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mparams\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbuffers\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs_pruned\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n\u001b[1;32m 335\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", - "File \u001b[0;32m~/.local/lib/python3.9/site-packages/functorch/_src/vmap.py:361\u001b[0m, in \u001b[0;36mvmap..wrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 359\u001b[0m _check_out_dims_is_int_or_int_pytree(out_dims, func)\n\u001b[1;32m 360\u001b[0m batch_size, flat_in_dims, flat_args, args_spec \u001b[38;5;241m=\u001b[39m _process_batched_inputs(in_dims, args, func)\n\u001b[0;32m--> 361\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43m_flat_vmap\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 362\u001b[0m \u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflat_in_dims\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflat_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs_spec\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout_dims\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrandomness\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 363\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.9/site-packages/functorch/_src/vmap.py:487\u001b[0m, in \u001b[0;36m_flat_vmap\u001b[0;34m(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 486\u001b[0m batched_inputs \u001b[38;5;241m=\u001b[39m _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)\n\u001b[0;32m--> 487\u001b[0m batched_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mbatched_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 488\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)\n\u001b[1;32m 489\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n", - "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/.local/lib/python3.9/site-packages/functorch/_src/make_functional.py:282\u001b[0m, in \u001b[0;36mFunctionalModuleWithBuffers.forward\u001b[0;34m(self, params, buffers, *args, **kwargs)\u001b[0m\n\u001b[1;32m 
277\u001b[0m old_state \u001b[38;5;241m=\u001b[39m _swap_state(\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstateless_model,\n\u001b[1;32m 279\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mall_names_map,\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28mlist\u001b[39m(params) \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlist\u001b[39m(buffers))\n\u001b[1;32m 281\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 282\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstateless_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 283\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 284\u001b[0m \u001b[38;5;66;03m# Remove the loaded state on self.stateless_model\u001b[39;00m\n\u001b[1;32m 285\u001b[0m _swap_state(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstateless_model, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mall_names_map, old_state)\n", - "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py:1186\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1185\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mRuntimeError\u001b[0m: batched == nullptr INTERNAL ASSERT FAILED at \"/private/var/folders/fn/c72nxv0x0xj4bzdgq3b0r86r0000gn/T/pip-req-build-ulv3uwfj/functorch/csrc/Interpreter.cpp\":95, please report a bug to PyTorch. " - ] + "data": { + "text/plain": [ + "torch.Size([10, 5])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -475,7 +482,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "id": "9dd7846a-f12c-492e-a2ef-b0c67969234d", "metadata": {}, "outputs": [ @@ -576,7 +583,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "id": "5b6c6035-f9cc-41e7-bf3a-f88936f93b70", "metadata": {}, "outputs": [ @@ -700,7 +707,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 12, "id": "e1f7ba7b", "metadata": {}, "outputs": [], @@ -726,7 +733,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "id": "eb9775bd", "metadata": {}, "outputs": [], @@ -776,7 +783,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "id": "f902006d-3f89-4ea6-84e0-a193a53e42db", "metadata": {}, "outputs": [], @@ -862,7 +869,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 15, "id": "1c6c85b5", "metadata": {}, "outputs": [], @@ -969,7 +976,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 16, "id": "7a680452-1462-4ee6-ba04-dce0bb855870", "metadata": {}, "outputs": [ @@ -990,7 +997,7 @@ " is_shared=False)" ] }, - "execution_count": 18, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } From a5acde6b366f37c079a14aa024db6f6dd4dbc241 Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Thu, 21 Jul 2022 15:47:05 +0100 Subject: [PATCH 19/21] Added benchmark --- tutorials/README.md | 6 +- tutorials/src/transformer.py | 97 +++++++++ tutorials/tensordictmodule.ipynb | 326 +++++++++++++++++++++++++------ 3 files changed, 369 insertions(+), 60 deletions(-) diff --git a/tutorials/README.md b/tutorials/README.md index a094f78e710..516f8d79540 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -4,10 +4,10 @@ Get a sense of TorchRL functionalities through our tutorials. For an overview of TorchRL, try the [TorchRL demo](demo.ipynb). -Make sure you test the [TensorDict demo](tensordict.ipynb) to see what TensorDict +Make sure you test the [TensorDict tutorial](tensordict.ipynb) to see what TensorDict is about and what it can do. -To understand how to use `TensorDict` with pytorch modules, make sure to check out the [TensorDictModule demo](tensordictmodule.ipynb). +To understand how to use `TensorDict` with pytorch modules, make sure to check out the [TensorDictModule tutorial](tensordictmodule.ipynb). -Checkout the [environment demo](envs.ipynb) for a deep dive in the envs +Checkout the [environment tutorial](envs.ipynb) for a deep dive in the envs functionalities. 
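One refinement of the model-ensembling cell above before moving on: the comment there notes that a real application should stack the parameters of all models rather than re-draw random ones. A sketch of that stacking, along the lines of the `transpose_stack` helper this series removes, reusing `num_models`, `tensordict` and the imports from that cell, and assuming `make_functional_with_buffers()` returns `(fmodule, (params, buffers))` as shown there:

```python
import torch

# Build the ensemble members, then take the functional form of the first one.
models = [
    TensorDictModule(nn.Linear(3, 4), in_keys=["a"], out_keys=["output"])
    for _ in range(num_models)
]
fmodel = models[0].make_functional_with_buffers()[0]

# Transpose and stack: group the i-th parameter of every model, then stack the
# group into one tensor with a leading model dimension for vmap.
all_params, all_buffers = zip(*(m.make_functional_with_buffers()[1] for m in models))
params = [torch.stack(shards) for shards in zip(*all_params)]
buffers = [torch.stack(shards) for shards in zip(*all_buffers)]

result_td = fmodel(tensordict, params=params, buffers=buffers, vmap=True)
```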
diff --git a/tutorials/src/transformer.py b/tutorials/src/transformer.py index 3cfd8d307a5..352f08b9f44 100644 --- a/tutorials/src/transformer.py +++ b/tutorials/src/transformer.py @@ -65,3 +65,100 @@ def __init__(self, to_dim, hidden_dim, dropout_rate=0.2): def forward(self, X): return self.FFN(X) + + +class AttentionBlock(nn.Module): + def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads): + super().__init__() + self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim) + self.split_heads = SplitHeads(num_heads) + self.attention = Attention(latent_dim, to_dim) + self.skip = SkipLayerNorm(to_len, to_dim) + + def forward(self, X_to, X_from): + Q, K, V = self.tokens_to_qkv(X_to, X_from) + Q, K, V = self.split_heads(Q, K, V) + out, attention = self.attention(Q, K, V) + out = self.skip(X_to, out) + return out + + +class EncoderTransformerBlock(nn.Module): + def __init__(self, to_dim, to_len, latent_dim, num_heads): + super().__init__() + self.attention_block = AttentionBlock( + to_dim, to_len, to_dim, latent_dim, num_heads + ) + self.FFN = FFN(to_dim, 4 * to_dim) + self.skip = SkipLayerNorm(to_len, to_dim) + + def forward(self, X_to): + X_to = self.attention_block(X_to, X_to) + X_out = self.FFN(X_to) + return self.skip(X_out, X_to) + + +class DecoderTransformerBlock(nn.Module): + def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads): + super().__init__() + self.attention_block = AttentionBlock( + to_dim, to_len, from_dim, latent_dim, num_heads + ) + self.encoder_block = EncoderTransformerBlock( + to_dim, to_len, latent_dim, num_heads + ) + + def forward(self, X_to, X_from): + X_to = self.attention_block(X_to, X_from) + X_to = self.encoder_block(X_to) + return X_to + + +class TransformerEncoder(nn.Module): + def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads): + super().__init__() + self.encoder = nn.ModuleList( + [ + EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads) + for i in range(num_blocks) + ] + ) + + def forward(self, X_to): + for i in range(len(self.encoder)): + X_to = self.encoder[i](X_to) + return X_to + + +class TransformerDecoder(nn.Module): + def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads): + super().__init__() + self.decoder = nn.ModuleList( + [ + DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads) + for i in range(num_blocks) + ] + ) + + def forward(self, X_to, X_from): + for i in range(len(self.decoder)): + X_to = self.decoder[i](X_to, X_from) + return X_to + + +class Transformer(nn.Module): + def __init__( + self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads + ): + super().__init__() + self.encoder = TransformerEncoder( + num_blocks, to_dim, to_len, latent_dim, num_heads + ) + self.decoder = TransformerDecoder( + num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads + ) + + def forward(self, X_to, X_from): + X_to = self.encoder(X_to) + X_out = self.decoder(X_from, X_to) + return X_out diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb index b0534ce46a6..24b9c77452f 100644 --- a/tutorials/tensordictmodule.ipynb +++ b/tutorials/tensordictmodule.ipynb @@ -194,7 +194,7 @@ "metadata": {}, "source": [ "### Example 3: Multiple outputs\n", - "Similarly, `TensorDictModule` not only supports multiple inputs but also multiple outputs. To make a `TensorDictModule` instance write to multiple output values values, one must register them in the `out_keys` keyword argument of the constructor." 
+ "Similarly, `TensorDictModule` not only supports multiple inputs but also multiple outputs. To make a `TensorDictModule` instance write to multiple output values, one must register them in the `out_keys` keyword argument of the constructor." ] }, { @@ -254,8 +254,11 @@ "id": "859630c3", "metadata": {}, "source": [ - "As we shown previously,`TensorDictModule` can take any `nn.Module` and perform the operations on a `TensorDict`. When having multiple input keys and output keys, make sure they match the order in the module.\n", - "`TensorDictModule` can work with `TensorDict` instances that contain more tensors than what the `in_keys` attribute indicates. Unless a `vmap` operator is used, the `TensorDict` is modified in-place." + "When having multiple input keys and output keys, make sure they match the order in the module.\n", + "\n", + "`TensorDictModule` can work with `TensorDict` instances that contain more tensors than what the `in_keys` attribute indicates. \n", + "\n", + "Unless a `vmap` operator is used, the `TensorDict` is modified in-place." ] }, { @@ -316,7 +319,7 @@ "id": "e2718a12", "metadata": {}, "source": [ - "`TensorDictModule` comes with its own `make_functional_with_buffers` method to make it functional." + "`TensorDictModule` comes with its own `make_functional_with_buffers` method to make it functional (you should not be using `functorch.make_functional_with_buffers(tensordictmodule)`, that will not work in general)." ] }, { @@ -360,7 +363,7 @@ "id": "50ac0393", "metadata": {}, "source": [ - "We can also use `vmap`. We can do model ensembling with it." + "We can also use the `vmap` operator, here's an example of model ensembling with it:" ] }, { @@ -370,45 +373,24 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "torch.Size([10, 5])" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "the output tensordict shape is: torch.Size([10, 5])\n" + ] } ], "source": [ - "num_models = 10\n", - "\n", "tensordict = TensorDict({\"a\": torch.randn(5, 3)}, batch_size=[5])\n", - "\n", - "splitlinear_models = [\n", - " TensorDictModule(\n", - " nn.Linear(3, 10), in_keys=[\"a\"], out_keys=[\"output\"]\n", + "num_models = 10\n", + "model = TensorDictModule(\n", + " nn.Linear(3, 4), in_keys=[\"a\"], out_keys=[\"output\"]\n", " )\n", - " for _ in range(num_models)\n", - "]\n", - "\n", - "\n", - "### Let's extract the functional version of the common model\n", - "func = splitlinear_models[0].make_functional_with_buffers()[0]\n", - "\n", - "### We also extract parameters and buffers for every block and stack them together\n", - "params, buffers = zip(\n", - " *[\n", - " splitlinear.make_functional_with_buffers()[1]\n", - " for splitlinear in splitlinear_models\n", - " ]\n", - ")\n", - "## For simplicity we reinit the params. 
In a real application you need to stack all params\n", - "params = [torch.randn(num_models, *p.shape) for p in params[0]]\n", - "buffers = [torch.randn(num_models, *b.shape) for b in buffers[0]]\n", - "\n", - "### Finally we can apply the function to the stack of params and buffer using vmap\n", - "func(tensordict, params=params, buffers=buffers, vmap=True).shape" + "fmodel, (params, buffers) = model.make_functional_with_buffers()\n", + "params = [torch.randn(num_models, *p.shape, device=p.device) for p in params]\n", + "buffers = [torch.randn(num_models, *b.shape, device=b.device) for b in buffers]\n", + "result_td = fmodel(tensordict, params=params, buffers=buffers, vmap=True)\n", + "print(\"the output tensordict shape is: \", result_td.shape)" ] }, { @@ -420,20 +402,16 @@ "\n", "Don't use `nn.Module` wrappers with `TensorDictModule` componants. This would break some of `TensorDictModule` features such as `functorch` compatibility. \n", "\n", - "Don't use `nn.Sequence`, similar to nn.Module, it would break features such as `functorch` compatibility.\n", + "Don't use `nn.Sequence`, similar to `nn.Module`, it would break features such as `functorch` compatibility. Do use `TensorDictSequence` instead.\n", "\n", - "Do use `TensorDictSequence` instead.\n", - "\n", - "Don't use a different name for the output of a `TensorDictModule` as the output tensordict is just the input modified in-place:\n", + "Don't assign the output tensordict to a new variable, as the output tensordict is just the input modified in-place:\n", "\n", "```python\n", "tensordict = module(tensordict) # ok!\n", "tensordict_out = module(tensordict) # don't!\n", "```\n", "\n", - "Don't use `make_functional_with_buffers` from `functorch` directly.\n", - "\n", - "Do use `TensorDictModule.make_functional_with_buffers` instead.\n", + "Don't use `make_functional_with_buffers` from `functorch` directly but use `TensorDictModule.make_functional_with_buffers` instead.\n", "\n", "\n" ] @@ -451,7 +429,7 @@ "id": "8d49a911-933c-476f-8c9a-00e006ed043c", "metadata": {}, "source": [ - "In the context of RL torchrl offers a few wrappers on `TensorDictModule`" + "TorchRL provides a few RL-specific `TensorDictModule` instances that serves domain-specific needs." ] }, { @@ -469,15 +447,17 @@ "id": "fea4eead-47b4-4029-a8ff-e3c3faf51b0f", "metadata": {}, "source": [ - "`ProbabilisticTDModule` is a special case of a `TensorDictModule` where the output is\n", + "`ProbabilisticTensorDictModule` is a special case of a `TensorDictModule` where the output is\n", "sampled given some rule, specified by the input `default_interaction_mode`\n", - "argument and the `exploration_mode()` global function.\n", + "argument and the `exploration_mode()` global function. If they conflict, the context manager precedes.\n", "\n", "It consists in a wrapper around another `TensorDictModule` that returns a tensordict\n", "updated with the distribution parameters. `ProbabilisticTensorDictModule` is\n", "responsible for constructing the distribution (through the `get_dist()` method)\n", "and/or sampling from this distribution (through a regular `__call__()` to the\n", - "module)." + "module).\n", + "\n", + "One can find the parameters in the output tensordict as well as the log probability if needed" ] }, { @@ -576,8 +556,8 @@ "id": "8cc42407-4e95-4bf0-8901-5d1a4e3b2044", "metadata": {}, "source": [ - "Similarly, `ActorCriticOperator` inherits from `TensorDictSequence`.\n", - "and wraps both and actor network and a value Network. 
\n", + "Similarly, `ActorCriticOperator` inherits from `TensorDictSequence`and wraps both an actor network and a value Network.\n", + "\n", "`ActorCriticOperator` will first compute the action from the actor and then the value according to this action." ] }, @@ -609,7 +589,7 @@ " batch_size=torch.Size([3]),\n", " device=cpu,\n", " is_shared=False)\n", - "TensorDict(\n", + "Policy: TensorDict(\n", " fields={\n", " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", " hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", @@ -620,7 +600,7 @@ " batch_size=torch.Size([3]),\n", " device=cpu,\n", " is_shared=False)\n", - "TensorDict(\n", + "Critic: TensorDict(\n", " fields={\n", " action: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", " hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),\n", @@ -681,9 +661,23 @@ "td_clone = td_module(td.clone())\n", "print(td_clone)\n", "td_clone = td_module.get_policy_operator()(td.clone())\n", - "print(td_clone) # no value\n", + "print(f\"Policy: {td_clone}\") # no value\n", "td_clone = td_module.get_critic_operator()(td.clone())\n", - "print(td_clone) # no action" + "print(f\"Critic: {td_clone}\") # no action" + ] + }, + { + "cell_type": "markdown", + "id": "11d0f8ea-0292-4ca0-9460-2a2149f7aeef", + "metadata": {}, + "source": [ + "Other blocks exist such as:\n", + "\n", + "The `ValueOperator` which is a general class for value functions in RL.\n", + "\n", + "the `ActorCriticWrapper` which wraps together an actor and a value model that do not share a common observation embedding network.\n", + "\n", + "The `ActorValueOperator` which wraps together an actor and a value model that share a common observation embedding network." ] }, { @@ -939,6 +933,7 @@ " to_dim,\n", " to_len,\n", " from_dim,\n", + " from_len,\n", " latent_dim,\n", " num_heads,\n", " ):\n", @@ -1010,6 +1005,7 @@ "from_len = 10\n", "batch_size = 8\n", "num_heads = 2\n", + "num_blocks = 6\n", "\n", "tokens = TensorDict(\n", " {\n", @@ -1020,12 +1016,13 @@ ")\n", "\n", "transformer = TransformerTensorDict(\n", - " 6,\n", + " num_blocks,\n", " \"X_encode\",\n", " \"X_decode\",\n", " to_dim,\n", " to_len,\n", " from_dim,\n", + " from_len,\n", " latent_dim,\n", " num_heads,\n", ")\n", @@ -1042,6 +1039,221 @@ "We've achieved to create a transformer with `TensorDictModule`. 
This shows that `TensorDictModule`is a flexible module that can implement complex operarations" ] }, + { + "cell_type": "markdown", + "id": "bb30fb1b-ef8f-4638-af44-69374dd9cfe9", + "metadata": {}, + "source": [ + "### Benchmarking" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "f75eb50b-b5c4-47ef-9e33-4fa6dfb489ba", + "metadata": {}, + "outputs": [], + "source": [ + "from tutorials.src.transformer import Transformer" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "c4ff0abf-1f01-45bd-9dfc-cd26374137c7", + "metadata": {}, + "outputs": [], + "source": [ + "to_dim = 5\n", + "from_dim = 6\n", + "latent_dim = 10\n", + "to_len = 3\n", + "from_len = 10\n", + "batch_size = 8\n", + "num_heads = 2\n", + "num_blocks = 6" + ] + }, + { + "cell_type": "markdown", + "id": "6dd7805a-8f4d-4cab-a723-f230a510a96a", + "metadata": {}, + "source": [ + "#### Init of tokens" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "3e08ff04-1086-4315-bf5e-caa960183c94", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 219 µs, sys: 128 µs, total: 347 µs\n", + "Wall time: 217 µs\n" + ] + } + ], + "source": [ + "%%time\n", + "td_tokens = TensorDict(\n", + " {\n", + " \"X_encode\": torch.randn(batch_size, to_len, to_dim),\n", + " \"X_decode\": torch.randn(batch_size, from_len, from_dim),\n", + " },\n", + " batch_size=[batch_size],\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "665c4168-9ac8-45e5-98bc-6e5cc511a209", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 96 µs, sys: 21 µs, total: 117 µs\n", + "Wall time: 91.1 µs\n" + ] + } + ], + "source": [ + "%%time\n", + "X_encode = torch.randn(batch_size, to_len, to_dim)\n", + "X_decode = torch.randn(batch_size, from_len, from_dim)" + ] + }, + { + "cell_type": "markdown", + "id": "53ba793f-af3d-4e09-89a2-72b3f8964158", + "metadata": {}, + "source": [ + "#### Init of models" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "f3c2fd50-bc9b-4911-bd7c-8f8f03bd4ea4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 5.82 ms, sys: 1.11 ms, total: 6.93 ms\n", + "Wall time: 6.11 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "tdtransformer = TransformerTensorDict(\n", + " num_blocks,\n", + " \"X_encode\",\n", + " \"X_decode\",\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " from_len,\n", + " latent_dim,\n", + " num_heads,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "dfbadd6b-7847-4399-9b22-7e5c58524334", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 40.2 ms, sys: 1.59 ms, total: 41.8 ms\n", + "Wall time: 40.9 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "transformer = Transformer(\n", + " num_blocks,\n", + " to_dim,\n", + " to_len,\n", + " from_dim,\n", + " from_len,\n", + " latent_dim,\n", + " num_heads\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "6a63de8f-ee8e-4ddf-bf89-f72c2896e1c3", + "metadata": { + "tags": [] + }, + "source": [ + "#### Inference time" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "02a4116b-2b75-47fc-8bc1-3903aa7cd504", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 6.35 ms, sys: 6.95 ms, total: 13.3 ms\n", + "Wall time: 7.87 ms\n" + 
] + } + ], + "source": [ + "%%time\n", + "tokens = tdtransformer(td_tokens)\n", + "#" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "40158aab-b53a-4a99-82cb-f5595eef7159", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 5.5 ms, sys: 8.56 ms, total: 14.1 ms\n", + "Wall time: 7.18 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "X_out = transformer(X_encode, X_decode)" + ] + }, + { + "cell_type": "markdown", + "id": "664adff3-1466-47c3-9a80-a0f26171addd", + "metadata": {}, + "source": [ + "We can see on this minimal example that the overhead introduced by `TensorDictModule` is minimal." + ] + }, { "cell_type": "markdown", "id": "bd08362a-8bb8-49fb-8038-1a60c5c01ea2", From 52172ba5753617b22b96f47e097b245912ef2d1c Mon Sep 17 00:00:00 2001 From: Nicolas Dufour Date: Thu, 21 Jul 2022 16:45:40 +0100 Subject: [PATCH 20/21] Warning clean-up --- tutorials/tensordictmodule.ipynb | 45 ++++++++++---------------------- 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb index 24b9c77452f..16a1ea4b872 100644 --- a/tutorials/tensordictmodule.ipynb +++ b/tutorials/tensordictmodule.ipynb @@ -39,16 +39,7 @@ "execution_count": 1, "id": "5b0241ab", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/ndufour/anaconda3/envs/torch_rl/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", @@ -92,14 +83,6 @@ " device=cpu,\n", " is_shared=False)\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/ndufour/anaconda3/envs/torch_rl/lib/python3.9/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
+  {
+   "cell_type": "markdown",
+   "id": "bd08362a-8bb8-49fb-8038-1a60c5c01ea2",

From 52172ba5753617b22b96f47e097b245912ef2d1c Mon Sep 17 00:00:00 2001
From: Nicolas Dufour
Date: Thu, 21 Jul 2022 16:45:40 +0100
Subject: [PATCH 20/21] Warning clean-up

---
 tutorials/tensordictmodule.ipynb | 45 ++++++++++----------------------
 1 file changed, 14 insertions(+), 31 deletions(-)

diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb
index 24b9c77452f..16a1ea4b872 100644
--- a/tutorials/tensordictmodule.ipynb
+++ b/tutorials/tensordictmodule.ipynb
@@ -39,16 +39,7 @@
    "execution_count": 1,
    "id": "5b0241ab",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/Users/ndufour/anaconda3/envs/torch_rl/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-      "  from .autonotebook import tqdm as notebook_tqdm\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import torch\n",
     "import torch.nn as nn\n",
@@ -92,14 +83,6 @@
     " device=cpu,\n",
     " is_shared=False)\n"
    ]
-   },
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "/Users/ndufour/anaconda3/envs/torch_rl/lib/python3.9/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
-     "  and should_run_async(code)\n"
-    ]
   }
  ],
  "source": [
@@ -1092,8 +1075,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "CPU times: user 219 µs, sys: 128 µs, total: 347 µs\n",
-    "Wall time: 217 µs\n"
+    "CPU times: user 399 µs, sys: 387 µs, total: 786 µs\n",
+    "Wall time: 469 µs\n"
    ]
   }
  ],
@@ -1118,8 +1101,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "CPU times: user 96 µs, sys: 21 µs, total: 117 µs\n",
-    "Wall time: 91.1 µs\n"
+    "CPU times: user 295 µs, sys: 312 µs, total: 607 µs\n",
+    "Wall time: 412 µs\n"
    ]
   }
  ],
@@ -1147,8 +1130,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "CPU times: user 5.82 ms, sys: 1.11 ms, total: 6.93 ms\n",
-    "Wall time: 6.11 ms\n"
+    "CPU times: user 41.6 ms, sys: 1.91 ms, total: 43.5 ms\n",
+    "Wall time: 42.6 ms\n"
    ]
   }
  ],
@@ -1177,8 +1160,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "CPU times: user 40.2 ms, sys: 1.59 ms, total: 41.8 ms\n",
-    "Wall time: 40.9 ms\n"
+    "CPU times: user 6.14 ms, sys: 1.75 ms, total: 7.89 ms\n",
+    "Wall time: 6.68 ms\n"
    ]
   }
  ],
@@ -1215,8 +1198,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "CPU times: user 6.35 ms, sys: 6.95 ms, total: 13.3 ms\n",
-    "Wall time: 7.87 ms\n"
+    "CPU times: user 7.72 ms, sys: 8.6 ms, total: 16.3 ms\n",
+    "Wall time: 9.65 ms\n"
    ]
   }
  ],
@@ -1236,8 +1219,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "CPU times: user 5.5 ms, sys: 8.56 ms, total: 14.1 ms\n",
-    "Wall time: 7.18 ms\n"
+    "CPU times: user 6.02 ms, sys: 9.13 ms, total: 15.2 ms\n",
+    "Wall time: 7.83 ms\n"
    ]
   }
  ],
@@ -1265,7 +1248,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },

From 45ccd8b63767a680a34811991eb10c9edf69ee6e Mon Sep 17 00:00:00 2001
From: Nicolas Dufour
Date: Fri, 22 Jul 2022 10:45:30 +0100
Subject: [PATCH 21/21] Made suggested changes

---
 tutorials/tensordictmodule.ipynb | 74 ++++----------------------------
 1 file changed, 9 insertions(+), 65 deletions(-)

diff --git a/tutorials/tensordictmodule.ipynb b/tutorials/tensordictmodule.ipynb
index 16a1ea4b872..71b2baf68d0 100644
--- a/tutorials/tensordictmodule.ipynb
+++ b/tutorials/tensordictmodule.ipynb
@@ -1057,31 +1057,13 @@
    "num_blocks = 6"
   ]
  },
- {
-  "cell_type": "markdown",
-  "id": "6dd7805a-8f4d-4cab-a723-f230a510a96a",
-  "metadata": {},
-  "source": [
-   "#### Init of tokens"
-  ]
- },
 {
  "cell_type": "code",
  "execution_count": 19,
  "id": "3e08ff04-1086-4315-bf5e-caa960183c94",
  "metadata": {},
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "CPU times: user 399 µs, sys: 387 µs, total: 786 µs\n",
-     "Wall time: 469 µs\n"
-    ]
-   }
-  ],
+  "outputs": [],
  "source": [
-  "%%time\n",
   "td_tokens = TensorDict(\n",
   "    {\n",
   "        \"X_encode\": torch.randn(batch_size, to_len, to_dim),\n",
   "        \"X_decode\": torch.randn(batch_size, from_len, from_dim),\n",
   "    },\n",
   "    batch_size=[batch_size],\n",
@@ -1096,47 +1078,19 @@
  "execution_count": 20,
  "id": "665c4168-9ac8-45e5-98bc-6e5cc511a209",
  "metadata": {},
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "CPU times: user 295 µs, sys: 312 µs, total: 607 µs\n",
-     "Wall time: 412 µs\n"
-    ]
-   }
-  ],
+  "outputs": [],
  "source": [
-  "%%time\n",
   "X_encode = torch.randn(batch_size, to_len, to_dim)\n",
   "X_decode = torch.randn(batch_size, from_len, from_dim)"
  ]
 },
- {
-  "cell_type": "markdown",
-  "id": "53ba793f-af3d-4e09-89a2-72b3f8964158",
-  "metadata": {},
-  "source": [
[ - "#### Init of models" - ] - }, { "cell_type": "code", "execution_count": 21, "id": "f3c2fd50-bc9b-4911-bd7c-8f8f03bd4ea4", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 41.6 ms, sys: 1.91 ms, total: 43.5 ms\n", - "Wall time: 42.6 ms\n" - ] - } - ], + "outputs": [], "source": [ - "%%time\n", "tdtransformer = TransformerTensorDict(\n", " num_blocks,\n", " \"X_encode\",\n", @@ -1155,18 +1109,8 @@ "execution_count": 22, "id": "dfbadd6b-7847-4399-9b22-7e5c58524334", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 6.14 ms, sys: 1.75 ms, total: 7.89 ms\n", - "Wall time: 6.68 ms\n" - ] - } - ], + "outputs": [], "source": [ - "%%time\n", "transformer = Transformer(\n", " num_blocks,\n", " to_dim,\n", @@ -1198,8 +1142,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 7.72 ms, sys: 8.6 ms, total: 16.3 ms\n", - "Wall time: 9.65 ms\n" + "CPU times: user 6.25 ms, sys: 6.73 ms, total: 13 ms\n", + "Wall time: 7.57 ms\n" ] } ], @@ -1219,8 +1163,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 6.02 ms, sys: 9.13 ms, total: 15.2 ms\n", - "Wall time: 7.83 ms\n" + "CPU times: user 5.32 ms, sys: 9.24 ms, total: 14.6 ms\n", + "Wall time: 7.15 ms\n" ] } ], @@ -1234,7 +1178,7 @@ "id": "664adff3-1466-47c3-9a80-a0f26171addd", "metadata": {}, "source": [ - "We can see on this minimal example that the overhead introduced by `TensorDictModule` is minimal." + "We can see on this minimal example that the overhead introduced by `TensorDictModule` is marginal." ] }, {