# TensorDict tutorial

TensorDict is a new tensor structure introduced in torchrl. 

In RL, you need to deal with multiple tensors such as actions, observations and rewards. TensorDict aims at making it more convenient to handle multiple tensors at the same time.

Furthermore, different RL algorithms can have different inputs and outputs. TensorDict makes it possible to abstract away the differences between these algorithms.

TensorDict combines the convenience of using dicts to organize your data with the power of PyTorch tensors.


#### Improving the modularity of code

Let's suppose we have 2 datasets: Dataset A which has images and labels and Dataset B which has images, segmentation maps and labels. 

We want to train 2 methods (Algo A on dataset A and algo B on dataset B) that share the same training loop. 

In classic PyTorch, we would need to write the following:
```python
#Method A
for i in range(optim_steps):
    images, labels = get_data_A()
    loss = loss_module_A(images, labels)
    loss.backward()
    optim.step()
    optim.zero_grad()
```

```python
#Method B
for i in range(optim_steps):
    images, masks, labels = get_data_B()
    loss = loss_module_B(images, masks, labels)
    loss.backward()
    optim.step()
    optim.zero_grad()
```
We can see that this limits code reuse: a lot of code has to be rewritten because the two datasets provide different modalities.
The idea of TensorDict is to write a single loop instead:
```python
# General Method
# `instantiate` and `cfg` stand for a config system, e.g. Hydra's `hydra.utils.instantiate`
get_data = instantiate(cfg.data)
loss_module = instantiate(cfg.module)
for i in range(optim_steps):
    tensordict = get_data()
    loss = loss_module(tensordict)
    loss.backward()
    optim.step()
    optim.zero_grad()
```
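To make this concrete, here is a minimal runnable sketch of the shared loop in plain PyTorch, using ordinary dicts in place of TensorDicts. The data sources `get_data_A`/`get_data_B`, the loss modules `LossA`/`LossB` and the `train` helper are hypothetical stand-ins for what `instantiate(cfg.data)` and `instantiate(cfg.module)` could return; only the data source and the loss module change between the two methods.

```python
import torch
import torch.nn as nn


def get_data_A():
    # Hypothetical dataset A: images and labels
    return {"images": torch.randn(4, 3), "labels": torch.randn(4, 1)}


def get_data_B():
    # Hypothetical dataset B: images, segmentation masks and labels
    return {"images": torch.randn(4, 3), "masks": torch.randn(4, 3), "labels": torch.randn(4, 1)}


class LossA(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3, 1)

    def forward(self, data):
        # Method A only reads the keys it needs
        return ((self.net(data["images"]) - data["labels"]) ** 2).mean()


class LossB(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3, 1)

    def forward(self, data):
        # Method B additionally uses the masks
        return ((self.net(data["images"] * data["masks"]) - data["labels"]) ** 2).mean()


def train(get_data, loss_module, optim_steps=5):
    # One training loop shared by every method
    optim = torch.optim.SGD(loss_module.parameters(), lr=0.1)
    for _ in range(optim_steps):
        data = get_data()
        loss = loss_module(data)
        loss.backward()
        optim.step()
        optim.zero_grad()
    return loss.item()


final_loss_A = train(get_data_A, LossA())
final_loss_B = train(get_data_B, LossB())
```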


Now we can reuse the same training loop for all the methods we want. We just need to make sure that `instantiate(cfg.data)` and `instantiate(cfg.module)` map to the desired data and method.



#### Can't I do this with a Python dict?

One could argue that the same result can be achieved with a dataset that outputs a Python dict of tensors.
```python
import torch
from torch.utils.data import Dataset


class DictDataset(Dataset):
    ...

    def __getitem__(self, idx):
        ...
        # Each sample is a dict with one tensor per modality
        return {"modality_A": torch.randn(2), "modality_B": torch.randn(2)}
```
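However, the batch then has to be assembled key by key by a collate function that makes sure every modality is aggregated properly. PyTorch's default collate function does handle dicts of tensors, but as soon as the aggregation needs customizing, this logic has to be written by hand: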

```python
import torch
from torch.utils.data import DataLoader


def collate_dict_fn(dict_list):
    # Stack each modality separately along a new leading batch dimension
    final_dict = {}
    for key in dict_list[0].keys():
        final_dict[key] = torch.stack([single_dict[key] for single_dict in dict_list], dim=0)
    return final_dict


dataloader = DataLoader(DictDataset(), collate_fn=collate_dict_fn)
```
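A quick way to check such a collate function is to run it through a `DataLoader` on a toy dataset. The sketch below uses a hypothetical `ToyDictDataset` whose `modality_A` entry encodes the sample index, so we can see that samples are stacked in order along a new leading batch dimension.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDictDataset(Dataset):
    # Hypothetical toy dataset: 10 samples, each a dict with two modalities
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return {"modality_A": torch.full((2,), float(idx)), "modality_B": torch.randn(3)}


def collate_dict_fn(dict_list):
    # Stack each modality separately along a new leading batch dimension
    return {key: torch.stack([d[key] for d in dict_list], dim=0) for key in dict_list[0]}


loader = DataLoader(ToyDictDataset(), batch_size=4, collate_fn=collate_dict_fn)
batch = next(iter(loader))
print({key: tuple(value.shape) for key, value in batch.items()})
# {'modality_A': (4, 2), 'modality_B': (4, 3)}
```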
With TensorDict, this becomes much simpler:

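```python
import torch
from torch.utils.data import Dataset
from torchrl.data import TensorDict


class DictDataset(Dataset):
    ...

    def __getitem__(self, idx):
        ...
        # A single sample has no batch dimension, hence batch_size=[]
        return TensorDict({"modality_A": torch.randn(2), "modality_B": torch.randn(2)}, batch_size=[])
```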


Here, the collate function is as simple as:
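```python
import torch
from torch.utils.data import DataLoader

def collate_tensordict_fn(tds):
    # Stacking the samples stacks every entry along a new leading batch dimension
    return torch.stack(tds, dim=0)

dataloader = DataLoader(DictDataset(), collate_fn=collate_tensordict_fn)
```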
This is an example of how TensorDict can simplify such operations.

TensorDict also inherits several properties from torch tensors, detailed further down, which make it quite practical.

## TensorDict dictionary features

TensorDict shares a lot of features with Python dictionaries.

In [1]:
from torchrl.data import TensorDict
import torch

In [2]:
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
print(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)


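Note that `b` was created with shape `[3, 4]`, i.e. exactly the batch size, yet it is listed with shape `[3, 4, 1]`: in this version of TensorDict, such tensors are stored with an extra trailing dimension of size 1.

Accessing a key works as with a regular dict: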

In [3]:
print(tensordict["a"])
print(tensordict["a"].shape)

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])
torch.Size([3, 4, 5])


It also works with `get()`:

In [4]:
print(tensordict.get("a"))
print(tensordict.get("a").shape)

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])
torch.Size([3, 4, 5])


#### TensorDict.keys()
The keys of a TensorDict can be retrieved with `keys()`:

In [5]:
for key in tensordict.keys():
    print(key)

a
b


#### TensorDict.values()
The values of a TensorDict can be retrieved with the `values()` method. Unlike Python dicts, whose `values()` returns a view object, `TensorDict.values()` returns a generator that yields the tensors one at a time rather than building a list of them.

In [6]:
tensordict.values()

<generator object _TensorDict.values at 0x1273dd040>

In [7]:
for value in tensordict.values():
    print(value.shape)

torch.Size([3, 4, 5])
torch.Size([3, 4, 1])
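One practical consequence of `values()` returning a generator: like any Python generator, it can only be iterated once, so wrap it in `list(...)` if you need several passes. The sketch below illustrates this with a plain dict and a hypothetical `values_generator` helper that mimics a generator-based `values()`; it is not TensorDict's implementation.

```python
import torch

entries = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}


def values_generator(d):
    # Hypothetical helper: yields the tensors one at a time
    for key in d:
        yield d[key]


values = values_generator(entries)
first_pass = [tuple(v.shape) for v in values]
second_pass = [tuple(v.shape) for v in values]  # the generator is already exhausted
print(first_pass)   # [(3, 4, 5), (3, 4, 1)]
print(second_pass)  # []

# Materialize once if the values are needed several times
values_list = list(values_generator(entries))
```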


#### TensorDict.set()
The `set()` method stores a new value under a given key:

In [8]:
tensordict.set("c", torch.zeros((3, 4, 2, 2)))
print(f"c is set as {tensordict['c']}")

c is set as tensor([[[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]]])


#### TensorDict.update()
The `update()` method updates the TensorDict with the entries of another dict (or TensorDict):

In [9]:
tensordict.update({"a":torch.ones((3, 4, 5)), "d":torch.ones((3, 4, 2))})
# Also works with tensordict.update(TensorDict({"a": torch.ones((3, 4, 5)), "d": torch.ones((3, 4, 2))}, batch_size=[3, 4]))
print(f"a is now {tensordict['a']}")
print(f"d is set as {tensordict['d']}")

a is now tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])
d is set as tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]])


#### TensorDict del key
TensorDict also supports key deletion with the `del` operator:

In [10]:
del tensordict["c"]
print(tensordict.keys())

dict_keys(['a', 'b', 'd'])


## TensorDict as a pytorch Tensor

But wait, can't we do all this with a classical dict?
Well, we would like TensorDict to keep some nice PyTorch properties: TensorDict combines the advantages of a Python dictionary and of a PyTorch tensor.
A TensorDict has a batch size. It is not inferred automatically from the tensors, but must be set when creating the TensorDict.

TensorDict is a tensor container where all tensors are stored as key-value pairs and where all entries share at least the following features:
- device;
- memory location (shared, memory-mapped array, ...);
- batch size (i.e. the first n dimensions of every tensor).

In [11]:
from torchrl.data import TensorDict
import torch

In [12]:
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
print(tensordict)
print(f"Our Tensor dict is of size {tensordict.shape}")

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)
Our Tensor dict is of size torch.Size([3, 4])


#### Batch size

A TensorDict has a batch size which is shared across all its tensors. The batch size can be empty (`[]`), one-dimensional or multidimensional, according to your needs.

In [13]:
print(f"Our Tensor dict is of size {tensordict.shape}")

Our Tensor dict is of size torch.Size([3, 4])


You cannot have items that don't share the batch size inside the same TensorDict:

In [14]:
tensordict.update({"c": torch.zeros(4, 3, 1)})

RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])
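Conceptually, the check behind this error compares the leading dimensions of each new tensor with the batch size. Here is a plain-PyTorch sketch of such a check; `check_batch_size` is a hypothetical helper, not part of the TensorDict API, and its message only imitates the one above.

```python
import torch


def check_batch_size(entries, batch_size):
    # Hypothetical helper: every tensor must start with the batch dimensions
    batch_size = torch.Size(batch_size)
    for key, value in entries.items():
        leading = value.shape[: len(batch_size)]
        if leading != batch_size:
            raise RuntimeError(
                f"batch dimension mismatch for {key!r}: expected {batch_size}, got {leading}"
            )


# Compatible entries: every shape starts with [3, 4]
check_batch_size({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, [3, 4])

# Incompatible entry: the shape starts with [4, 3]
try:
    check_batch_size({"c": torch.zeros(4, 3, 1)}, [3, 4])
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True
```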

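In [None]:
# The batch size can be reassigned, but only to a value compatible with the leading dimensions
# of every entry; [4, 4] is not compatible with `a` (shape [3, 4, 5]), so it is left commented out:
# tensordict.batch_size = [4, 4]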

#### Devices

A TensorDict can be sent to the desired device like a PyTorch tensor, with `td.cuda()` or `td.to(device)`, where `device` is the desired device.
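For comparison, with a plain dict every tensor has to be moved individually, whereas `td.to(device)` does it for all entries in one call. A minimal sketch, falling back to the CPU when no GPU is available:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
entries = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}

# Move every entry to the same device, one tensor at a time
entries_on_device = {key: value.to(device) for key, value in entries.items()}
print({key: value.device for key, value in entries_on_device.items()})
```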

#### Memory sharing

When on CPU, you can use either `TensorDict.memmap_()` or `TensorDict.share_memory_()` to store your TensorDict as a memory-mapped array or to move it to shared memory, respectively.

#### Cloning
TensorDict supports cloning: `clone()` returns a new TensorDict whose entries are copies of the original tensors, so modifying the clone leaves the original untouched.

In [15]:
tensordict_clone = tensordict.clone()
tensordict_clone["a"] = torch.ones(*tensordict.shape, 5)
print(tensordict["a"])
print(tensordict_clone["a"])

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])
tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])
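Why cloning matters: copying the container alone is not enough, because the copy still points to the same tensors. The plain-dict sketch below contrasts a shallow copy, which shares its tensors with the original, with a per-tensor clone, which is what the clone above provides for every entry.

```python
import torch

original = {"a": torch.zeros(3, 4, 5)}

shallow = dict(original)                                        # new dict, same tensor objects
deep = {key: value.clone() for key, value in original.items()}  # new tensors

shallow["a"].add_(1.0)  # in-place update, also visible through `original`
deep["a"].add_(2.0)     # only affects the cloned tensor

print(original["a"].max().item())  # 1.0
print(deep["a"].max().item())      # 2.0
```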


### Tensor operations
We can perform tensor operations along the batch dimensions.

#### Slicing and indexing
Slicing and indexing are supported along the batch dimensions:

In [16]:
tensordict[0]

TensorDict(
    fields={
        a: Tensor(torch.Size([4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)

In [17]:
tensordict[1:]

TensorDict(
    fields={
        a: Tensor(torch.Size([2, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([2, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([2, 4]),
    device=cpu,
    is_shared=False)

In [18]:
tensordict[:, 2:]

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 2, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 2, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 2]),
    device=cpu,
    is_shared=False)
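Indexing a TensorDict amounts to applying the same index to the leading (batch) dimensions of every entry. A plain-PyTorch sketch with a hypothetical `index_entries` helper reproduces the shapes printed above for `tensordict[:, 2:]`:

```python
import torch

entries = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}


def index_entries(entries, index):
    # Hypothetical helper: the index only touches the batch dimensions,
    # trailing feature dimensions are left as they are
    return {key: value[index] for key, value in entries.items()}


sliced = index_entries(entries, (slice(None), slice(2, None)))  # same as [:, 2:]
print({key: tuple(value.shape) for key, value in sliced.items()})
# {'a': (3, 2, 5), 'b': (3, 2, 1)}
```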

#### Setting values with indexing
We can also overwrite part of the entries by indexing the TensorDict and assigning another TensorDict with a matching batch size:

In [19]:
td2 = TensorDict({"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4])
tensordict[:-1] = td2
tensordict["a"], tensordict["b"]

(tensor([[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]],
 
         [[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]],
 
         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]),
 tensor([[[1.],
          [1.],
          [1.],
          [1.]],
 
         [[1.],
          [1.],
          [1.],
          [1.]],
 
         [[0.],
          [0.],
          [0.],
          [0.]]]))
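Assignment works the same way, entry by entry: each tensor of `td2` is written into the corresponding slice of the matching tensor. Sketched with plain dicts, assuming both dicts have the same keys:

```python
import torch

entries = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}
update = {"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4, 1)}

# Write every entry of `update` into the first two batch elements
for key, value in update.items():
    entries[key][:-1] = value

print(entries["b"][:, 0, 0])  # tensor([1., 1., 0.])
```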

#### Masking

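We can also index with a boolean mask, which must be a tensor. Here, the mask has the shape of the batch size `[3, 4]`, and the 6 selected elements form a TensorDict with batch size `[6]`: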

In [20]:
mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]]).bool()
tensordict[mask]

TensorDict(
    fields={
        a: Tensor(torch.Size([6, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([6, 1]), dtype=torch.float32)},
    batch_size=torch.Size([6]),
    device=cpu,
    is_shared=False)
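Boolean masking follows regular PyTorch semantics: a mask covering the batch dimensions `[3, 4]` selects elements and flattens them into a single batch dimension, leaving the feature dimensions intact. The per-entry equivalent in plain PyTorch:

```python
import torch

entries = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}
mask = torch.tensor([[True, False, True, False]] * 3)  # shape [3, 4], 6 True values

masked = {key: value[mask] for key, value in entries.items()}
print({key: tuple(value.shape) for key, value in masked.items()})
# {'a': (6, 5), 'b': (6, 1)}
```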

TensorDict supports other tensor operations such as `torch.cat`, `reshape`, `unbind(dim)`, `view(*shape)`, `squeeze(dim)`, `unsqueeze(dim)` and `permute(*dims)`, as long as the operation is compatible with the batch size.

#### View

TensorDict supports the `view` operation, which returns a `ViewedTensorDict`. Use `to_tensordict()` to get back a regular TensorDict.

In [21]:
tensordict.view(-1)

ViewedTensorDict(
	source=TensorDict(
	    fields={
	        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
	        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
	    batch_size=torch.Size([3, 4]),
	    device=cpu,
	    is_shared=False), 
	op=view(size=torch.Size([-1])))
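Per entry, `view(-1)` on the batch size corresponds to flattening the batch dimensions while keeping the feature dimensions. A plain-PyTorch sketch, assuming 2 batch dimensions as here:

```python
import torch

entries = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}
batch_dims = 2

# Flatten the batch dimensions [3, 4] into [12], keep the trailing feature dims
viewed = {key: value.view(-1, *value.shape[batch_dims:]) for key, value in entries.items()}
print({key: tuple(value.shape) for key, value in viewed.items()})
# {'a': (12, 5), 'b': (12, 1)}

# Like any tensor view, it shares memory with the original entry
viewed["a"][0, 0] = 1.0
print(entries["a"][0, 0, 0].item())  # 1.0
```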

#### Permute
`permute` reorders the batch dimensions and returns a `PermutedTensorDict`:

In [22]:
tensordict.permute(1,0)

PermutedTensorDict(
	source=TensorDict(
	    fields={
	        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
	        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
	    batch_size=torch.Size([3, 4]),
	    device=cpu,
	    is_shared=False), 
	op=permute(dims=(1, 0)))
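`permute(1, 0)` swaps the two batch dimensions of every entry and leaves the feature dimensions in place. The per-tensor equivalent in plain PyTorch:

```python
import torch

entries = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}
batch_perm = (1, 0)

# Permute the batch dims, then append the feature dims unchanged
permuted = {
    key: value.permute(*batch_perm, *range(len(batch_perm), value.dim()))
    for key, value in entries.items()
}
print({key: tuple(value.shape) for key, value in permuted.items()})
# {'a': (4, 3, 5), 'b': (4, 3, 1)}
```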

#### Reshape
`reshape` changes the batch size of the TensorDict. Unlike `view`, it returns a regular TensorDict:

In [23]:
tensordict.reshape(-1)

TensorDict(
    fields={
        a: Tensor(torch.Size([12, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([12, 1]), dtype=torch.float32)},
    batch_size=torch.Size([12]),
    device=cpu,
    is_shared=False)

#### Unbind and Cat
A TensorDict can be split along one of its batch dimensions with `unbind`, and a sequence of TensorDicts can be concatenated along a batch dimension with `torch.cat`:

In [24]:
# Unbind along the first batch dimension, then concatenate the pieces back together
list_tensordict = tensordict.unbind(0)
print(list_tensordict)
torch.cat(list_tensordict, dim=0)

(TensorDict(
    fields={
        a: Tensor(torch.Size([4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False), TensorDict(
    fields={
        a: Tensor(torch.Size([4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False), TensorDict(
    fields={
        a: Tensor(torch.Size([4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False))


TensorDict(
    fields={
        a: Tensor(torch.Size([12, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([12, 1]), dtype=torch.float32)},
    batch_size=torch.Size([12]),
    device=cpu,
    is_shared=False)
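Per entry, this round trip is `torch.unbind` followed by `torch.cat`: unbinding along the first batch dimension yields 3 pieces with batch size `[4]`, and concatenating them along dimension 0 gives batch size `[12]`, as printed above. In plain PyTorch:

```python
import torch

entries = {"a": torch.arange(60.0).reshape(3, 4, 5), "b": torch.zeros(3, 4, 1)}

# Unbind every entry along dim 0, then regroup: 3 dicts with batch size [4]
unbound = {key: torch.unbind(value, dim=0) for key, value in entries.items()}
pieces = [{key: unbound[key][i] for key in entries} for i in range(entries["a"].shape[0])]

# Concatenate the pieces back along dim 0: batch size [12]
catted = {key: torch.cat([piece[key] for piece in pieces], dim=0) for key in entries}

print(len(pieces), tuple(pieces[0]["a"].shape))            # 3 (4, 5)
print(tuple(catted["a"].shape), tuple(catted["b"].shape))  # (12, 5) (12, 1)
```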

#### Squeeze and Unsqueeze
TensorDict also supports `squeeze` and `unsqueeze`. Use `to_tensordict()` to retrieve a regular TensorDict:

In [25]:
print(tensordict.unsqueeze(0).to_tensordict())
print(tensordict.squeeze(0).to_tensordict())  # dim 0 has size 3, not 1: the batch size is unchanged

TensorDict(
    fields={
        a: Tensor(torch.Size([1, 3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([1, 3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([1, 3, 4]),
    device=cpu,
    is_shared=False)
TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)
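These follow the usual PyTorch rules, applied to the batch dimensions of each entry: `unsqueeze(0)` adds a leading dimension of size 1, and `squeeze(0)` only removes dimension 0 if its size is 1, which is why squeezing the batch size `[3, 4]` above leaves it unchanged. Per tensor:

```python
import torch

a = torch.zeros(3, 4, 5)

print(tuple(a.unsqueeze(0).shape))             # (1, 3, 4, 5)
print(tuple(a.squeeze(0).shape))               # (3, 4, 5): dim 0 has size 3, nothing removed
print(tuple(a.unsqueeze(0).squeeze(0).shape))  # (3, 4, 5): the added dim is removed again
```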


#### Stacking

TensorDict supports stacking with `torch.stack`. Stacking is done lazily and returns a `LazyStackedTensorDict`, which keeps references to the stacked TensorDicts instead of copying them:

In [26]:
# Stack the TensorDict with a clone of itself along a new first dimension
stacked_tensordict = torch.stack([tensordict, tensordict.clone()], dim=0)
print(stacked_tensordict)
# Indexing the lazy stack returns the original TensorDict object, not a copy
if stacked_tensordict[0] is tensordict:
    print("every tensordict is awesome!")

LazyStackedTensorDict(
    fields={
        a: Tensor(torch.Size([2, 3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([2, 3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([2, 3, 4]),
    device=cpu,
    is_shared=False)
every tensordict is awesome!


If we want a contiguous TensorDict, we can call `.to_tensordict()` or `.contiguous()`. For efficiency, it is recommended to perform this operation before accessing the values of the stacked TensorDict.

In [27]:
print(stacked_tensordict.contiguous())
print(stacked_tensordict.to_tensordict())

TensorDict(
    fields={
        a: Tensor(torch.Size([2, 3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([2, 3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([2, 3, 4]),
    device=cpu,
    is_shared=False)
TensorDict(
    fields={
        a: Tensor(torch.Size([2, 3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([2, 3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([2, 3, 4]),
    device=cpu,
    is_shared=False)
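The difference between the lazy stack and its contiguous version can be illustrated with a hypothetical `LazyStack` class built on plain dicts: it stores references to the stacked dicts (so indexing it returns the original object, as in the `is` check above) and only calls `torch.stack` when `contiguous()` is requested. This is a simplified illustration, not torchrl's `LazyStackedTensorDict`.

```python
import torch


class LazyStack:
    """Hypothetical, simplified lazy stack of dicts of tensors."""

    def __init__(self, dicts):
        # Keep references to the inputs: nothing is copied at stacking time
        self.dicts = list(dicts)

    def __getitem__(self, idx):
        # Indexing the stacked dimension returns the original dict
        return self.dicts[idx]

    def contiguous(self):
        # Materialize: stack every entry along a new leading dimension
        keys = self.dicts[0].keys()
        return {key: torch.stack([d[key] for d in self.dicts], dim=0) for key in keys}


first = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}
second = {key: value.clone() for key, value in first.items()}

stack = LazyStack([first, second])
print(stack[0] is first)  # True
dense = stack.contiguous()
print({key: tuple(value.shape) for key, value in dense.items()})
# {'a': (2, 3, 4, 5), 'b': (2, 3, 4, 1)}
```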


## How to use them in practice? The TensorDictModule

Now that we have seen the TensorDict object, how do we use it in practice? We introduce the TensorDictModule: an `nn.Module` whose forward method takes a TensorDict. The user defines the keys the module reads as inputs (`in_keys`) and the keys under which it writes its outputs into the same TensorDict (`out_keys`).

In [28]:
from torchrl.modules import TensorDictModule
import torch.nn as nn
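To see what TensorDictModule does, here is a minimal sketch of the same idea on a plain dict. `KeyedModule` is a hypothetical class written for illustration, not torchrl's implementation: it reads the inputs listed in `in_keys` (in that order), calls the wrapped module, and writes the result(s) back under `out_keys`.

```python
import torch
import torch.nn as nn


class KeyedModule(nn.Module):
    """Hypothetical minimal version of the TensorDictModule idea on a plain dict."""

    def __init__(self, module, in_keys, out_keys):
        super().__init__()
        self.module = module
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def forward(self, data):
        # Pass the inputs positionally, in the order given by in_keys
        outputs = self.module(*(data[key] for key in self.in_keys))
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        # Write the outputs back into the same dict under out_keys
        for key, value in zip(self.out_keys, outputs):
            data[key] = value
        return data


data = {"a": torch.randn(5, 3), "b": torch.randn(5, 4, 3)}
linear = KeyedModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"])
linear(data)
print(tuple(data["a_out"].shape))  # (5, 10)
```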

### Example: Simple Linear layer

Let's imagine we have a TensorDict with 2 entries, `a` and `b`, and we only want to apply a linear layer to `a`:

In [29]:
tensordict = TensorDict({"a": torch.randn(5, 3), "b": torch.randn(5, 4, 3)}, batch_size=[5])
linear = TensorDictModule(nn.Linear(3, 10),in_keys=["a"], out_keys=["a_out"])
linear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

We can also do it in place, by using the same key as input and output:

In [30]:
tensordict = TensorDict({"a":torch.randn(5, 3), "b":torch.randn(5, 4, 3)}, batch_size=[5])
linear = TensorDictModule(nn.Linear(3, 10),in_keys=["a"], out_keys=["a"])
linear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

### Example: merging 2 inputs with 2 linear layers

Now let's imagine a more complex network that takes 2 entries, passes each through its own linear layer and averages the results into a single output:

In [31]:
class MergeLinear(nn.Module):
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1  = nn.Linear(in_1,out)
        self.linear_2  = nn.Linear(in_2,out)
    def forward(self, x_1, x_2):
        return (self.linear_1(x_1) + self.linear_2(x_2))/2

In [32]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3), "c":torch.randn(5,4)}, batch_size=[5])
mergelinear = TensorDictModule(MergeLinear(3, 4, 10),in_keys=["a","c"], out_keys=["output"])
mergelinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        c: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

### Example: 1 input mapped to 2 outputs
We can also map to multiple outputs

In [33]:
class MultiHeadLinear(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1  = nn.Linear(in_1,out_1)
        self.linear_2  = nn.Linear(in_1,out_2)
    def forward(self, x):
        return self.linear_1(x), self.linear_2(x)

In [34]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3)}, batch_size=[5])
multiheadlinear = TensorDictModule(MultiHeadLinear(3, 4, 10), in_keys=["a"], out_keys=["output_1", "output_2"])
multiheadlinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

As shown above, TensorDictModule can wrap any `nn.Module` and run it on entries of a TensorDict. When there are multiple input and output keys, make sure their order matches the order of the module's forward arguments and return values.
TensorDictModule lets us use only the tensors we need and keeps the outputs inside the same object. It can even perform the operations in place, by setting an output key to an already existing key.
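To see why the order matters: in the examples above, the entries listed in `in_keys` are passed positionally, in that order, to the module's forward method. With a copy of the `MergeLinear` module defined above, `in_keys=["a", "c"]` matches `forward(x_1, x_2)`, whereas swapping the keys would feed `c` (4 features) to `linear_1`, which expects 3. A plain-PyTorch sketch calling the module directly:

```python
import torch
import torch.nn as nn


class MergeLinear(nn.Module):
    # Same module as in the example above
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out)
        self.linear_2 = nn.Linear(in_2, out)

    def forward(self, x_1, x_2):
        return (self.linear_1(x_1) + self.linear_2(x_2)) / 2


merge = MergeLinear(3, 4, 10)
a, c = torch.randn(5, 3), torch.randn(5, 4)

out = merge(a, c)  # what in_keys=["a", "c"] amounts to
print(tuple(out.shape))  # (5, 10)

try:
    merge(c, a)  # what in_keys=["c", "a"] would amount to
    order_error = False
except RuntimeError:
    order_error = True  # shape mismatch: linear_1 expects 3 input features
```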

### Example: A transformer with TensorDict?
Let's attempt to create a transformer with TensorDict and TensorDictModule

Disclaimer: this implementation does not claim to be "better" than a classical tensor-based implementation. It is only meant to showcase the features of TensorDictModule.
For simplicity, we will not use positional encodings.

Let's first implement the classical transformer building blocks:

In [35]:
class TokensToQKV(nn.Module):
    def __init__(self, to_dim, from_dim, latent_dim):
        super().__init__()
        self.q = nn.Linear(to_dim, latent_dim)
        self.k = nn.Linear(from_dim, latent_dim)
        self.v = nn.Linear(from_dim, latent_dim)
    def forward(self, X_to, X_from):
        Q = self.q(X_to)
        K = self.k(X_from)
        V = self.v(X_from)
        return Q, K, V

class SplitHeads(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
    def forward(self, Q, K, V):
        batch_size, to_num, latent_dim = Q.shape
        _, from_num, _ = K.shape
        d_tensor = latent_dim // self.num_heads
        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        return Q, K, V
class Attention(nn.Module):
    def __init__(self, latent_dim, to_dim):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.out = nn.Linear(latent_dim, to_dim)
    def forward(self, Q, K, V):
        batch_size, n_heads, to_num, d_in = Q.shape
        # Scaled dot-product: scale the scores by sqrt(d_in), the per-head dimension
        attn = self.softmax(Q @ K.transpose(2, 3) / d_in ** 0.5)
        out = attn @ V
        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads*d_in))
        return out, attn
class SkipLayerNorm(nn.Module):
    def __init__(self, to_len, to_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm((to_len, to_dim))
    def forward(self, x_0, x_1):
        return self.layer_norm(x_0+x_1)
class FFN(nn.Module):
    def __init__(self, to_dim, hidden_dim, dropout_rate = 0.2):
        super().__init__()
        self.FFN = nn.Sequential(
            nn.Linear(to_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, to_dim),
            nn.Dropout(dropout_rate)
        )
    def forward(self, X):
        return self.FFN(X)


Now we can build the transformer block with TensorDictModule. Since each module writes its outputs into the TensorDict, we just need to map the outputs to the right keys so that they are picked up by the next module.

In [36]:
class TransformerBlockTensorDict(nn.Module):
    def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.transformer_block = nn.Sequential(
            TensorDictModule(TokensToQKV(to_dim, from_dim, latent_dim), in_keys=[to_name, from_name], out_keys=["Q", "K", "V"]),
            TensorDictModule(SplitHeads(num_heads), in_keys=["Q", "K", "V"], out_keys=["Q", "K", "V"]),
            TensorDictModule(Attention(latent_dim, to_dim), in_keys=["Q", "K", "V"], out_keys=["X_out","Attn"]),
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
            TensorDictModule(FFN(to_dim, 4*to_dim), in_keys=[to_name], out_keys=["X_out"]),
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
        )
    def forward(self, X_tensor_dict):
        # Each sub-module reads its inputs from, and writes its outputs to, X_tensor_dict
        return self.transformer_block(X_tensor_dict)

In [37]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict({"X_to":torch.randn(batch_size, to_len, to_dim), "X_from":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])

transformer_block = TransformerBlockTensorDict("X_to", "X_from", to_dim, to_len, from_dim, latent_dim, num_heads)

transformer_block(tokens)

tokens

TensorDict(
    fields={
        Attn: Tensor(torch.Size([8, 2, 3, 10]), dtype=torch.float32),
        K: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        Q: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),
        V: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),
        X_out: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),
        X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},
    batch_size=torch.Size([8]),
    device=cpu,
    is_shared=False)
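Because the modules only communicate through keys, a useful sanity check is to verify that each stage only reads keys that already exist when it runs. The sketch below does this with a hypothetical `check_pipeline` helper and the `(in_keys, out_keys)` pairs of the transformer block, written out by hand for `to_name="X_to"` and `from_name="X_from"`; the keys it ends up with match the entries of `tokens` printed above.

```python
# (in_keys, out_keys) of each TensorDictModule in TransformerBlockTensorDict,
# with to_name="X_to" and from_name="X_from"
stages = [
    (["X_to", "X_from"], ["Q", "K", "V"]),  # TokensToQKV
    (["Q", "K", "V"], ["Q", "K", "V"]),     # SplitHeads
    (["Q", "K", "V"], ["X_out", "Attn"]),   # Attention
    (["X_to", "X_out"], ["X_to"]),          # SkipLayerNorm
    (["X_to"], ["X_out"]),                  # FFN
    (["X_to", "X_out"], ["X_to"]),          # SkipLayerNorm
]


def check_pipeline(initial_keys, stages):
    # Hypothetical helper: fail if a stage reads a key that no earlier stage has written
    available = set(initial_keys)
    for i, (in_keys, out_keys) in enumerate(stages):
        missing = [key for key in in_keys if key not in available]
        if missing:
            raise KeyError(f"stage {i} reads missing keys {missing}")
        available.update(out_keys)
    return available


final_keys = check_pipeline(["X_to", "X_from"], stages)
print(sorted(final_keys))
# ['Attn', 'K', 'Q', 'V', 'X_from', 'X_out', 'X_to']
```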

The output of the transformer block can now be found in `tokens["X_to"]`:

In [38]:
tokens["X_to"]

tensor([[[ 1.3338,  0.5568,  1.5257, -0.9140, -0.7399],
         [-1.1602, -1.1171, -0.0355, -0.3293,  0.5658],
         [ 0.0528,  1.3808, -1.1389, -1.2089,  1.2281]],

        [[ 1.0998, -0.2481, -0.5027,  0.3425, -1.3240],
         [-1.2873, -2.0554,  1.1529, -0.5357, -0.2973],
         [ 1.0694,  0.3098,  0.3401,  0.3393,  1.5966]],

        [[ 0.2762, -0.3396, -1.6140,  0.0866,  1.0148],
         [-1.0849, -2.2679,  0.1260,  1.2162,  1.4198],
         [ 1.0519,  0.2551, -0.3484,  0.0038,  0.2044]],

        [[-1.6978, -1.0813, -0.9879, -0.1534, -0.8177],
         [ 1.0584, -1.3029, -0.5820,  0.4057,  1.6585],
         [ 0.1717,  0.3136,  0.8307,  1.1407,  1.0436]],

        [[ 0.9372, -1.2179, -0.5154, -1.0837,  0.3776],
         [-2.3519, -0.2721, -0.1398,  1.7179, -0.5435],
         [ 0.3842,  0.8391,  0.4074,  0.9533,  0.5076]],

        [[-0.0587, -0.1878,  0.5516,  0.0882,  0.9291],
         [ 1.1866, -1.2275, -0.3984, -0.0310, -0.4985],
         [-1.1856, -1.0034,  2.6606, -

We can now easily stack several blocks into a transformer:

In [39]:
class TransformerTensorDict(nn.Module):
    def __init__(self, num_blocks, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.transformer = nn.ModuleList([TransformerBlockTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])
    def forward(self, X_tensor_dict):
        for transformer_block in self.transformer:
            transformer_block(X_tensor_dict)

In [40]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict({"X_to":torch.randn(batch_size, to_len, to_dim), "X_from":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])




For an encoder, we just need to use the same tokens for the queries, keys and values.

In [41]:
transformer_encoder = TransformerTensorDict(6, "X_to", "X_to", to_dim, to_len, to_dim, latent_dim, num_heads)

transformer_encoder(tokens)
tokens["X_to"]

tensor([[[-2.5683e-01,  4.5024e-01,  9.4612e-01, -1.1000e+00,  6.6867e-01],
         [-6.6982e-01,  8.6136e-01,  1.5597e+00, -1.0453e+00,  1.2007e+00],
         [-6.9352e-01, -3.2060e-02, -4.3720e-01, -2.1913e+00,  7.3923e-01]],

        [[ 9.1422e-01, -2.7420e-01, -2.1522e-01, -1.1503e+00,  1.6155e+00],
         [ 3.1163e-01, -7.0292e-02,  1.0830e+00, -2.3652e+00,  1.4391e-01],
         [ 3.9072e-01, -3.0454e-01,  8.5802e-01, -1.4170e+00,  4.7974e-01]],

        [[-7.3433e-01,  6.3703e-01,  4.6386e-01, -1.3121e+00,  6.5461e-01],
         [-3.1634e-01,  2.7175e-01,  4.1725e-01, -6.2716e-01,  6.5264e-01],
         [-1.8677e+00,  1.3871e+00,  9.4117e-01, -1.7092e+00,  1.1414e+00]],

        [[-1.9929e-01,  7.1908e-01,  1.1399e+00, -2.0360e+00,  9.4091e-01],
         [ 1.8042e-01,  8.9284e-03,  5.7997e-01, -6.0130e-01,  1.1606e+00],
         [-3.7275e-01,  9.1703e-01, -1.2281e-01, -2.2440e+00, -7.0603e-02]],

        [[-7.2842e-01, -5.2198e-01, -8.3247e-01, -1.5439e-01, -3.7558e-01],
    

For a decoder, we can now extract information from `X_from` into `X_to`: `X_to` is mapped to the queries, whereas `X_from` is mapped to the keys and values.

In [42]:
transformer_decoder = TransformerTensorDict(6, "X_to", "X_from", to_dim, to_len, from_dim, latent_dim, num_heads)

transformer_decoder(tokens)
tokens["X_to"]

tensor([[[ 0.2650,  0.8114,  1.6235, -1.8246, -0.1827],
         [-0.5331,  0.8769,  1.4354, -0.7027,  0.0041],
         [-0.0240,  0.0151,  0.7448, -1.9483, -0.5608]],

        [[ 0.6278,  0.2316,  0.6965, -1.4007,  0.4978],
         [ 0.5148,  0.1977,  1.5761, -2.1702, -0.3718],
         [ 0.5269,  0.1489,  1.0460, -1.5976, -0.5239]],

        [[-0.2073,  1.7283,  0.7957, -1.2614, -0.3906],
         [-0.2481,  1.4483,  0.3833, -0.3303, -0.3051],
         [-0.7997,  1.4176,  0.4302, -1.9264, -0.7345]],

        [[ 0.4208,  0.4412,  1.3680, -2.0761, -0.3130],
         [ 0.6140,  0.1374,  1.0931,  0.1626,  0.0274],
         [-0.4122,  0.9318,  0.7548, -1.8592, -1.2908]],

        [[-0.7726,  0.0486,  0.0335, -0.0701, -0.9185],
         [-0.2630,  1.0334,  1.7184, -2.6263,  0.0084],
         [-0.5563,  1.1644,  0.8858, -0.0215,  0.3358]],

        [[ 1.0746,  0.6018,  0.5270, -0.9513,  0.1145],
         [-0.6757,  1.6353,  1.4940, -0.1454,  0.1573],
         [ 0.0575, -1.1170, -1.3292, -
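The only difference between the encoder and the decoder is where the keys and values come from, which shows up in the shape of the attention weights: with self-attention (encoder) each of the `to_len` queries attends over `to_len` tokens, with cross-attention (decoder) over `from_len` tokens. A plain-PyTorch sketch of standard scaled dot-product attention weights with the sizes used above (per-head dimension `latent_dim // num_heads = 5`); the random `q`/`k` tensors stand in for the projected queries and keys:

```python
import torch

batch_size, num_heads, d_head = 8, 2, 5
to_len, from_len = 3, 10

q = torch.randn(batch_size, num_heads, to_len, d_head)
k_self = torch.randn(batch_size, num_heads, to_len, d_head)     # encoder: keys from X_to
k_cross = torch.randn(batch_size, num_heads, from_len, d_head)  # decoder: keys from X_from

# Attention weights, normalized over the attended tokens
attn_self = torch.softmax(q @ k_self.transpose(2, 3) / d_head ** 0.5, dim=-1)
attn_cross = torch.softmax(q @ k_cross.transpose(2, 3) / d_head ** 0.5, dim=-1)

print(tuple(attn_self.shape))   # (8, 2, 3, 3)
print(tuple(attn_cross.shape))  # (8, 2, 3, 10)
```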

Now we can look at both models:

In [43]:
transformer_encoder

TransformerTensorDict(
  (transformer): ModuleList(
    (0): TransformerBlockTensorDict(
      (transformer_block): Sequential(
        (0): TensorDictModule(
            module=TokensToQKV(
              (q): Linear(in_features=5, out_features=10, bias=True)
              (k): Linear(in_features=5, out_features=10, bias=True)
              (v): Linear(in_features=5, out_features=10, bias=True)
            ), 
            device=cpu, 
            in_keys=['X_to', 'X_to'], 
            out_keys=['Q', 'K', 'V'])
        (1): TensorDictModule(
            module=SplitHeads(), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['Q', 'K', 'V'])
        (2): TensorDictModule(
            module=Attention(
              (softmax): Softmax(dim=-1)
              (out): Linear(in_features=10, out_features=5, bias=True)
            ), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['X_out', 'Attn'])
        (3): TensorDictMod

In [44]:
transformer_decoder

TransformerTensorDict(
  (transformer): ModuleList(
    (0): TransformerBlockTensorDict(
      (transformer_block): Sequential(
        (0): TensorDictModule(
            module=TokensToQKV(
              (q): Linear(in_features=5, out_features=10, bias=True)
              (k): Linear(in_features=6, out_features=10, bias=True)
              (v): Linear(in_features=6, out_features=10, bias=True)
            ), 
            device=cpu, 
            in_keys=['X_to', 'X_from'], 
            out_keys=['Q', 'K', 'V'])
        (1): TensorDictModule(
            module=SplitHeads(), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['Q', 'K', 'V'])
        (2): TensorDictModule(
            module=Attention(
              (softmax): Softmax(dim=-1)
              (out): Linear(in_features=10, out_features=5, bias=True)
            ), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['X_out', 'Attn'])
        (3): TensorDictM