# TensorDict tutorial

`TensorDict` is a new tensor structure introduced in TorchRL. 

In RL, you need to handle several tensors at once, such as actions, observations and rewards. `TensorDict` aims at making it more convenient to work with multiple tensors at the same time.

Furthermore, different RL algorithms can take different inputs and produce different outputs. The `TensorDict` class makes it possible to abstract away the differences between these algorithms.

TensorDict combines the convenience of using `dict`s to organize your data with the power of PyTorch tensors.


#### Improving the modularity of code

Let's suppose we have 2 datasets: Dataset A which has images and labels and Dataset B which has images, segmentation maps and labels. 

Suppose we want to train a common algorithm over these two datasets (i.e. an algorithm that would ignore the mask or infer it when needed). 

In plain PyTorch, we would need to do the following:
```python
#Method A
for i in range(optim_steps):
    images, labels = get_data_A()
    loss = loss_module(images, labels)
    loss.backward()
    optim.step()
    optim.zero_grad()
```

```python
#Method B
for i in range(optim_steps):
    images, masks, labels = get_data_B()
    loss = loss_module(images, labels)
    loss.backward()
    optim.step()
    optim.zero_grad()
```

We can see that this limits the reusability of code. A lot of code has to be rewritten because of the modality difference between the 2 datasets.
The idea of TensorDict is to do the following:

```python
# General Method
for i in range(optim_steps):
    tensordict = get_data()
    loss = loss_module(tensordict)
    loss.backward()
    optim.step()
    optim.zero_grad()
```


Now we can reuse the same training loop across datasets and losses.
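Below is a minimal sketch of that loop in plain PyTorch, using ordinary dicts in place of TensorDicts. The functions `get_data_A` and `get_data_B` and the `DictLoss` module are hypothetical stand-ins for the placeholders above, and a fixed number of steps replaces `optim_steps`. The loss reads only the keys it needs, so the extra `"masks"` entry of dataset B is simply ignored.

```python
import torch
import torch.nn as nn


def get_data_A():
    # hypothetical dataset A: images and labels only
    return {"images": torch.randn(8, 3, 4, 4), "labels": torch.randint(0, 2, (8,))}


def get_data_B():
    # hypothetical dataset B: same entries as A plus a segmentation mask
    data = get_data_A()
    data["masks"] = torch.randint(0, 2, (8, 4, 4))
    return data


class DictLoss(nn.Module):
    """Hypothetical loss module that reads only the keys it needs."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 2))
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, data):
        # any additional key (e.g. "masks") is ignored
        logits = self.net(data["images"])
        return self.criterion(logits, data["labels"])


loss_module = DictLoss()
optim = torch.optim.SGD(loss_module.parameters(), lr=0.1)

# the very same loop body is used for both datasets
losses = []
for get_data in (get_data_A, get_data_B):
    for _ in range(3):
        data = get_data()
        loss = loss_module(data)
        loss.backward()
        optim.step()
        optim.zero_grad()
        losses.append(loss.item())
```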

#### Can't I do this with a Python dict?

One could argue that you could achieve the same result with a dataset that outputs a Python dict:
```python
import torch
from torch.utils.data import Dataset


class DictDataset(Dataset):
    ...  # other methods omitted

    def __getitem__(self, idx):
        # each sample is a plain Python dict of tensors
        return {"modality_A": torch.randn(2), "modality_B": torch.randn(2)}
```

However, to achieve this you would need to write a complicated collate function that makes sure that every modality is aggregated properly:

```python
import torch
from torch.utils.data import DataLoader


def collate_dict_fn(dict_list):
    final_dict = {}
    for key in dict_list[0].keys():
        # gather this modality across samples, then stack along a new batch dimension
        final_dict[key] = []
        for single_dict in dict_list:
            final_dict[key].append(single_dict[key])
        final_dict[key] = torch.stack(final_dict[key], dim=0)
    return final_dict


dataloader = DataLoader(DictDataset(), collate_fn=collate_dict_fn)
```
With TensorDicts this is now much simpler:

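```python
import torch
from torch.utils.data import Dataset
from torchrl.data import TensorDict


class DictDataset(Dataset):
    ...  # other methods omitted

    def __getitem__(self, idx):
        # each sample is a TensorDict with an empty batch size
        return TensorDict({"modality_A": torch.randn(2), "modality_B": torch.randn(2)}, batch_size=[])
```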


Here, the collate function is as simple as:
```python
from torch.utils.data import DataLoader

collate_tensordict_fn = lambda tds: torch.stack(tds, dim=0)

dataloader = DataLoader(DictDataset(), collate_fn=collate_tensordict_fn)
```

TensorDict borrows several properties from `torch.Tensor` and `dict`, which we detail further down.

## `TensorDict` dictionary features

`TensorDict` shares a lot of features with Python dictionaries.

In [1]:
from torchrl.data import TensorDict
import torch

In [2]:
a = torch.zeros(3, 4, 5)
b = torch.zeros(3, 4)
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])
print(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)


### `get(key)`
If we want to access a certain key, we can index the tensordict or alternatively use the `get` method:

In [3]:
print(tensordict["a"] is tensordict.get("a") is a)
print(tensordict["a"].shape)

True
torch.Size([3, 4, 5])


The `get` method also supports default values:

In [4]:
out = tensordict.get("foo", torch.ones(3))
out

tensor([1., 1., 1.])

### `set(key, value)`
The `set()` method can be used to set new values. Regular indexing also does the job:

In [5]:
c = torch.zeros((3, 4, 2, 2))
tensordict.set("c", c)
print(f"td[\"c\"] is c: {c is tensordict['c']}")

d = torch.zeros((3, 4, 2, 2))
tensordict["d"] = d
print(f"td[\"d\"] is d: {d is tensordict['d']}")

td["c"] is c: True
td["d"] is d: True


## Other methods:
### `keys`
We can access the keys of a tensordict:

In [6]:
for key in tensordict.keys():
    print(key)

a
b
c
d


### `values`
The values of a `TensorDict` can be retrieved with the `values()` method. Note that `values()` returns a generator (a Python `dict` returns a view object instead), so it can only be iterated over once.

In [7]:
for value in tensordict.values():
    print(value.shape)

torch.Size([3, 4, 5])
torch.Size([3, 4, 1])
torch.Size([3, 4, 2, 2])
torch.Size([3, 4, 2, 2])


### TensorDict.update()
The `update` method can be used to update a TensorDict with another one (or with a dict):

In [8]:
tensordict.update({"a": torch.ones((3, 4, 5)), "d": 2*torch.ones((3, 4, 2))})
# Also works with tensordict.update(TensorDict({"a":torch.ones((3, 4, 5)), "c":torch.ones((3, 4, 2))}, batch_size=[3,4]))
print(f"a is now equal to 1: {(tensordict['a'] == 1).all()}")
print(f"d is now equal to 2: {(tensordict['d'] == 2).all()}")

a is now equal to 1: True
d is now equal to 2: True


### TensorDict del key
TensorDict also supports key deletion with the `del` operator:

In [9]:
del tensordict["c"]
print(tensordict.keys())

dict_keys(['a', 'b', 'd'])


## TensorDict as a Tensor-like object

But wait, can't we do this with a classical dict?
Well, we would like the TensorDict to keep some nice PyTorch properties: TensorDict combines the advantages of a Python dictionary and of a PyTorch tensor.
A TensorDict has a batch size. It is not inferred automatically by looking at the tensors, but must be set when creating the TensorDict.

TensorDict is a tensor container where all tensors are stored as key-value pairs and where every entry shares at least the following features:
- device;
- memory location (shared, memory-mapped array, ...);
- batch size (i.e. the first n dimensions).
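The batch-size rule can be expressed in a few lines of plain PyTorch. The helper `check_batch_size` below is a hypothetical illustration of the invariant, not TorchRL's implementation: the leading dimensions of every tensor must equal the batch size, while the trailing dimensions are free.

```python
import torch


def check_batch_size(tensors, batch_size):
    """Hypothetical check: every tensor must start with the batch dimensions."""
    batch_size = torch.Size(batch_size)
    n = len(batch_size)
    for key, value in tensors.items():
        if value.shape[:n] != batch_size:
            raise RuntimeError(
                f"batch dimension mismatch for {key!r}: "
                f"expected {batch_size}, got {value.shape[:n]}"
            )


# valid: both tensors start with [3, 4]
check_batch_size({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, [3, 4])

# an empty batch size accepts any tensor
check_batch_size({"a": torch.zeros(3, 4, 5)}, [])

# invalid: the leading dims are [4, 3]
try:
    check_batch_size({"c": torch.zeros(4, 3, 1)}, [3, 4])
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True
```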

In [10]:
from torchrl.data import TensorDict
import torch

In [11]:
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
print(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)


#### Batch size

A TensorDict has a batch size which is shared across all tensors. The batch size can be empty (`[]`), one-dimensional or multidimensional, according to your needs.

In [13]:
print(f"Our Tensor dict is of size {tensordict.shape}")

Our Tensor dict is of size torch.Size([3, 4])


You cannot have items that don't share the batch size inside the same TensorDict:

In [15]:
# we cannot add tensors that violate the batch size:
try:
    tensordict.update({"c": torch.zeros(4, 3, 1)})
except RuntimeError as err:
    print(f"Caramba! We got this error: {err}")

Caramba! We got this error: batch dimension mismatch, got self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])


In [18]:
# If we reset the batch size, it has to be compatible with the shapes of the stored tensors
try:
    tensordict.batch_size = [4,4]
except RuntimeError as err:
    print(f"Caramba! We got this error: {err}")

Caramba! We got this error: the tensor a has shape torch.Size([3, 4, 5]) which is incompatible with the new shape torch.Size([4, 4])


### Device
TensorDict can be sent to the desired device like a PyTorch tensor, with `td.cuda()` or `td.to(device)`, where `device` is the desired device.
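Conceptually, `td.to(device)` moves every stored tensor to the same device. The plain-PyTorch sketch below mimics this on a regular dict; `dict_to` is a hypothetical helper, and the example falls back to the CPU when no GPU is available.

```python
import torch


def dict_to(tensors, device):
    """Hypothetical helper: move every tensor of a dict to `device`."""
    return {key: value.to(device) for key, value in tensors.items()}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}
moved = dict_to(data, device)

# all entries now live on the same device type
devices = {value.device.type for value in moved.values()}
```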

### Shared memory and memory mapping

On CPU, `tensordict.memmap_()` stores the tensordict as a collection of memory-mapped tensors, while `tensordict.share_memory_()` places it in shared memory.
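At the tensor level, `share_memory_()` can be illustrated with plain PyTorch: the hypothetical helper `share_dict_memory_` below calls `torch.Tensor.share_memory_()` on every tensor of a dict. It is a sketch of the idea, not TorchRL's implementation. Tensors in shared memory can be handed to other processes (e.g. through `torch.multiprocessing`) without copying their data.

```python
import torch


def share_dict_memory_(tensors):
    """Hypothetical helper: move every CPU tensor of a dict to shared memory, in place."""
    for value in tensors.values():
        value.share_memory_()
    return tensors


data = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}
share_dict_memory_(data)
all_shared = all(value.is_shared() for value in data.values())
```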

### Cloning
TensorDict supports cloning. Cloning returns an object of the same type as the original, and modifying the clone does not affect the original:

In [20]:
tensordict_clone = tensordict.clone()
tensordict_clone["a"] = torch.ones(*tensordict.shape, 5)
print("redefining a tensor in the clone does not impact the original tensordict: ", (tensordict["a"] == tensordict_clone["a"]).all())

redefining a tensor in the clone does not impact the original tensordict:  tensor(False)


### Tensor operations
We can perform tensor operations along the batch dimensions.

### Slicing and indexing
Slicing and indexing are supported along the batch dimensions:

In [22]:
tensordict[0]

TensorDict(
    fields={
        a: Tensor(torch.Size([4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)

In [23]:
tensordict[1:]

TensorDict(
    fields={
        a: Tensor(torch.Size([2, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([2, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([2, 4]),
    device=cpu,
    is_shared=False)

In [24]:
tensordict[:, 2:]

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 2, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 2, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 2]),
    device=cpu,
    is_shared=False)

#### Setting values with indexing
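We can also edit values at selected indices. Note first that indexing with a tensor of indices returns a `SubTensorDict`, which keeps track of the original data instead of copying it: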

In [25]:
subtd = tensordict[:, torch.tensor([1, 3])]  # a SubTensorDict keeps track of the original one: it does not create a copy in memory of the original data
tensordict.fill_("a", -1)
assert (subtd["a"] == -1).all()  # the "a" key-value pair has changed
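The key point is that `subtd` stores the index rather than a copy of the data, so later changes to `tensordict` are visible through it. In plain PyTorch, indexing with an index tensor copies the data, so a toy re-implementation has to apply the index lazily on every read. `LazySub` below is a hypothetical illustration of the idea, not TorchRL's `SubTensorDict`.

```python
import torch


class LazySub:
    """Hypothetical lazy sub-view: stores the source dict and the index."""

    def __init__(self, source, index):
        self.source = source
        self.index = index

    def __getitem__(self, key):
        # the index is applied at read time, so the result reflects
        # the current content of the source
        return self.source[key][self.index]


source = {"a": torch.zeros(3, 4, 5)}
sub = LazySub(source, (slice(None), torch.tensor([1, 3])))

# advanced indexing copies, so an eagerly indexed tensor goes stale
eager = source["a"][:, torch.tensor([1, 3])]

source["a"].fill_(-1)
```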

In [27]:
td2 = TensorDict({"a": torch.zeros(2, 4, 5), "b": torch.zeros(2, 4)}, batch_size=[2, 4])
tensordict[:-1] = td2
tensordict["a"], tensordict["b"]

(tensor([[[ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.]],
 
         [[ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.]],
 
         [[-1., -1., -1., -1., -1.],
          [-1., -1., -1., -1., -1.],
          [-1., -1., -1., -1., -1.],
          [-1., -1., -1., -1., -1.]]]),
 tensor([[[0.],
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          [0.]]]))

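As the cell above shows, values can be set simply by indexing the tensordict: `tensordict[:-1] = td2` overwrote the first two rows of `a` and `b`, while the last row of `a` keeps the `-1` it was filled with.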

### Masking
We can also index a tensordict with a boolean mask over its batch dimensions. The mask must be a tensor.

In [29]:
mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]]).bool()
tensordict[mask]

TensorDict(
    fields={
        a: Tensor(torch.Size([6, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([6, 1]), dtype=torch.float32)},
    batch_size=torch.Size([6]),
    device=cpu,
    is_shared=False)
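Masking applies the same boolean mask to the batch dimensions of every entry, which flattens the masked dimensions into one. The plain-PyTorch sketch below reproduces the shapes shown above on a regular dict (the `[3, 4, 1]` shape of `b` mirrors the printed output).

```python
import torch

data = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}
mask = torch.tensor([[1, 0, 1, 0]] * 3, dtype=torch.bool)

# a boolean mask over the first two dims keeps the trailing dims intact
masked = {key: value[mask] for key, value in data.items()}
new_batch_size = torch.Size([int(mask.sum())])
```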

TensorDict supports other tensor operations such as `torch.cat`, `reshape`, `unbind(dim)`, `view(*shape)`, `squeeze(dim)`, `unsqueeze(dim)` and `permute(*dims)`, provided that the operation is compatible with the batch size.

### View
The `view` operation returns a `ViewedTensorDict`. Use `to_tensordict()` to get back a regular `TensorDict`.

In [31]:
tensordict.view(-1)

ViewedTensorDict(
	source=TensorDict(
	    fields={
	        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
	        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
	    batch_size=torch.Size([3, 4]),
	    device=cpu,
	    is_shared=False), 
	op=view(size=torch.Size([-1])))
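At the tensor level, viewing the batch dimensions amounts to viewing the leading dimensions of each entry while keeping its trailing dimensions. The hypothetical helper `view_batch` below performs this eagerly on a plain dict; the printed `ViewedTensorDict` above instead keeps the source and the pending `op`.

```python
import torch


def view_batch(tensors, batch_dims, *shape):
    """Hypothetical helper: view the first `batch_dims` dims of each tensor as `shape`."""
    return {
        key: value.view(*shape, *value.shape[batch_dims:])
        for key, value in tensors.items()
    }


data = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}
flat = view_batch(data, 2, -1)
```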

#### Permute

In [33]:
tensordict.permute(1,0)

PermutedTensorDict(
	source=TensorDict(
	    fields={
	        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
	        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
	    batch_size=torch.Size([3, 4]),
	    device=cpu,
	    is_shared=False), 
	op=permute(dims=(1, 0)))
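Likewise, permuting the batch dimensions permutes the leading dimensions of each tensor and leaves the trailing ones in place. `permute_batch` is a hypothetical plain-PyTorch helper illustrating this.

```python
import torch


def permute_batch(tensors, *dims):
    """Hypothetical helper: permute the leading (batch) dims of each tensor."""
    batch_dims = len(dims)
    return {
        key: value.permute(*dims, *range(batch_dims, value.ndim))
        for key, value in tensors.items()
    }


data = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}
permuted = permute_batch(data, 1, 0)
```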

#### Reshape
`reshape` changes the batch size of the tensordict. Unlike `view`, it returns a regular `TensorDict`:

In [35]:
tensordict.reshape(-1)

TensorDict(
    fields={
        a: Tensor(torch.Size([12, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([12, 1]), dtype=torch.float32)},
    batch_size=torch.Size([12]),
    device=cpu,
    is_shared=False)

#### Unbind and Cat
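A TensorDict can be split along a batch dimension with `unbind`, and a sequence of tensordicts can be concatenated along a batch dimension with `torch.cat`:

In [36]:
# unbind along dim 0 into 3 tensordicts of batch size [4], then concatenate them back along dim 0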
list_tensordict = tensordict.unbind(0)
torch.cat(list_tensordict, dim=0)

TensorDict(
    fields={
        a: Tensor(torch.Size([12, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([12, 1]), dtype=torch.float32)},
    batch_size=torch.Size([12]),
    device=cpu,
    is_shared=False)

#### Squeeze and Unsqueeze
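TensorDict also supports `squeeze` and `unsqueeze`. Use `to_tensordict()` to retrieve a regular `TensorDict`. Note that, as for tensors, `squeeze(0)` leaves the batch size unchanged here, because the first dimension has size 3: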

In [37]:
print(tensordict.unsqueeze(0).to_tensordict())
print(tensordict.squeeze(0).to_tensordict())

TensorDict(
    fields={
        a: Tensor(torch.Size([1, 3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([1, 3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([1, 3, 4]),
    device=cpu,
    is_shared=False)
TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)


#### Stacking

TensorDict supports stacking. Stacking is done lazily and returns a `LazyStackedTensorDict`:

In [38]:
# Stack lazily: the result keeps references to the stacked tensordicts instead of copying their data
staked_tensordict = torch.stack([tensordict, tensordict.clone()], dim=0)
print(staked_tensordict)

LazyStackedTensorDict(
    fields={
        a: Tensor(torch.Size([2, 3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([2, 3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([2, 3, 4]),
    device=cpu,
    is_shared=False)


If we want a contiguous tensordict, we can call `.to_tensordict()` or `.contiguous()`. For efficiency, it is recommended to do so before accessing the values of the stacked tensordict.

In [39]:
assert isinstance(staked_tensordict.contiguous(), TensorDict)
assert isinstance(staked_tensordict.to_tensordict(), TensorDict)
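The idea behind lazy stacking can be sketched in plain PyTorch: keep references to the stacked dicts and only call `torch.stack` when a key is read. `LazyStack` below is a hypothetical toy, not TorchRL's `LazyStackedTensorDict`. Its `contiguous()` materializes every key once, which shows why doing so before repeated reads saves work: otherwise each read rebuilds the stacked tensor.

```python
import torch


class LazyStack:
    """Hypothetical toy lazy stack over a list of dicts of tensors."""

    def __init__(self, dicts, dim=0):
        self.dicts = list(dicts)  # references only, no data copied yet
        self.dim = dim

    def __getitem__(self, key):
        # the stacked tensor is rebuilt on every read
        return torch.stack([d[key] for d in self.dicts], dim=self.dim)

    def contiguous(self):
        # materialize every key once into a regular dict
        return {key: self[key] for key in self.dicts[0]}


td = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}
td_clone = {key: value.clone() for key, value in td.items()}
stacked = LazyStack([td, td_clone])

# the first element is the original dict, not a copy
first_is_original = stacked.dicts[0] is td
dense = stacked.contiguous()
```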

## How to use them in practice? The TensorDictModule

Now that we have seen the TensorDict object, how do we use it in practice? We introduce the `TensorDictModule`: an `nn.Module` that takes a TensorDict as input to its forward method. The user specifies the keys the module reads its inputs from (`in_keys`) and the keys under which it writes its outputs (`out_keys`) in the same TensorDict.
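The mechanism can be sketched in plain PyTorch with a regular dict: read the `in_keys` in order, call the wrapped module with them as positional arguments, and write each returned tensor under the matching `out_key`. `DictModule` below is a hypothetical toy, not TorchRL's `TensorDictModule`, but it follows the same `in_keys`/`out_keys` convention as the examples that follow.

```python
import torch
import torch.nn as nn


class DictModule(nn.Module):
    """Hypothetical toy wrapper mapping dict entries to module inputs and outputs."""

    def __init__(self, module, in_keys, out_keys):
        super().__init__()
        self.module = module
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def forward(self, data):
        # inputs are passed positionally, in the order of in_keys
        outputs = self.module(*(data[key] for key in self.in_keys))
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        # outputs are written back, in the order of out_keys
        for key, value in zip(self.out_keys, outputs):
            data[key] = value
        return data


data = {"a": torch.randn(5, 3), "b": torch.randn(5, 4, 3)}
linear = DictModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"])
linear(data)
```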

In [40]:
from torchrl.modules import TensorDictModule
import torch.nn as nn

### Example: Simple Linear layer

Let's imagine we have a TensorDict with 2 entries, `a` and `b`, and we only want to apply a linear layer to `a`.

In [41]:
tensordict = TensorDict({"a": torch.randn(5, 3), "b": torch.randn(5, 4, 3)}, batch_size=[5])
linear = TensorDictModule(nn.Linear(3, 10),in_keys=["a"], out_keys=["a_out"])
linear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        a_out: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

We can also do it in place, by writing the output back to the input key:

In [42]:
tensordict = TensorDict({"a":torch.randn(5, 3), "b":torch.randn(5, 4, 3)}, batch_size=[5])
linear = TensorDictModule(nn.Linear(3, 10),in_keys=["a"], out_keys=["a"])
linear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 10]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

### Example: merging 2 inputs with 2 linear layers

Now let's imagine a more complex network that takes 2 entries and averages them into a single output:

In [43]:
class MergeLinear(nn.Module):
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1  = nn.Linear(in_1,out)
        self.linear_2  = nn.Linear(in_2,out)
    def forward(self, x_1, x_2):
        return (self.linear_1(x_1) + self.linear_2(x_2))/2

In [44]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3), "c":torch.randn(5,4)}, batch_size=[5])
mergelinear = TensorDictModule(MergeLinear(3, 4, 10),in_keys=["a","c"], out_keys=["output"])
mergelinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        c: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

### Example: 1 input to 2 outputs
We can also map a single input to multiple outputs:

In [45]:
class MultiHeadLinear(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1  = nn.Linear(in_1,out_1)
        self.linear_2  = nn.Linear(in_1,out_2)
    def forward(self, x):
        return self.linear_1(x), self.linear_2(x)

In [46]:
tensordict = TensorDict({"a":torch.randn(5,3), "b":torch.randn(5,4,3)}, batch_size=[5])
multiheadlinear = TensorDictModule(MultiHeadLinear(3, 4, 10), in_keys=["a"], out_keys=["output_1", "output_2"])
multiheadlinear(tensordict)

TensorDict(
    fields={
        a: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        b: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
        output_1: Tensor(torch.Size([5, 4]), dtype=torch.float32),
        output_2: Tensor(torch.Size([5, 10]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

As shown previously, the TensorDictModule can wrap any `nn.Module` and run it on the entries of a TensorDict. With multiple input and output keys, make sure their order matches the order of the module's forward arguments and of its return values.
The TensorDictModule lets us use only the tensors we need and keeps the outputs inside the same object. It can even perform operations in place, by setting an output key to an already existing key.

### Example: A transformer with TensorDict?
Let's attempt to create a transformer with TensorDict and TensorDictModule

Disclaimer: This implementation does not claim to be "better" than a classical tensor-based implementation. It is only meant to showcase the TensorDictModule features.
For simplicity, we will not use positional encodings.

Let's first implement the classical transformer blocks:

In [47]:
class TokensToQKV(nn.Module):
    def __init__(self, to_dim, from_dim, latent_dim):
        super().__init__()
        self.q = nn.Linear(to_dim, latent_dim)
        self.k = nn.Linear(from_dim, latent_dim)
        self.v = nn.Linear(from_dim, latent_dim)
    def forward(self, X_to, X_from):
        Q = self.q(X_to)
        K = self.k(X_from)
        V = self.v(X_from)
        return Q, K, V

class SplitHeads(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
    def forward(self, Q, K, V):
        batch_size, to_num, latent_dim = Q.shape
        _, from_num, _ = K.shape
        d_tensor = latent_dim // self.num_heads
        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        return Q, K, V


class Attention(nn.Module):
    def __init__(self, latent_dim, to_dim):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.out = nn.Linear(latent_dim, to_dim)
    def forward(self, Q, K, V):
        batch_size, n_heads, to_num, d_in = Q.shape
        # scaled dot-product attention: scores are divided by sqrt(d_in)
        attn = self.softmax(Q @ K.transpose(2, 3) / d_in ** 0.5)
        out = attn @ V
        # merge the heads back into a single latent dimension
        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
        return out, attn


class SkipLayerNorm(nn.Module):
    def __init__(self, to_len, to_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm((to_len, to_dim))
    def forward(self, x_0, x_1):
        return self.layer_norm(x_0 + x_1)


class FFN(nn.Module):
    def __init__(self, to_dim, hidden_dim, dropout_rate = 0.2):
        super().__init__()
        self.FFN = nn.Sequential(
            nn.Linear(to_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, to_dim),
            nn.Dropout(dropout_rate)
        )
    def forward(self, X):
        return self.FFN(X)


Now we can build the transformer block with the TensorDictModule. Since every module reads from and writes to the same TensorDict, we just need to map each output to the right key so that it is picked up by the next block.

In [48]:
class TransformerBlockTensorDict(nn.Module):
    def __init__(self, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.transformer_block = nn.Sequential(
            TensorDictModule(TokensToQKV(to_dim, from_dim, latent_dim), in_keys=[to_name, from_name], out_keys=["Q", "K", "V"]),
            TensorDictModule(SplitHeads(num_heads), in_keys=["Q", "K", "V"], out_keys=["Q", "K", "V"]),
            TensorDictModule(Attention(latent_dim, to_dim), in_keys=["Q", "K", "V"], out_keys=["X_out","Attn"]),
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
            TensorDictModule(FFN(to_dim, 4*to_dim), in_keys=[to_name], out_keys=["X_out"]),
            TensorDictModule(SkipLayerNorm(to_len, to_dim), in_keys=[to_name, "X_out"], out_keys=[to_name]),
        )
    def forward(self, X_tensor_dict):
        # the modules read from and write to X_tensor_dict in place
        return self.transformer_block(X_tensor_dict)

In [49]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict({"X_to":torch.randn(batch_size, to_len, to_dim), "X_from":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])

transformer_block = TransformerBlockTensorDict("X_to", "X_from", to_dim, to_len, from_dim, latent_dim, num_heads)

transformer_block(tokens)

tokens

TensorDict(
    fields={
        Attn: Tensor(torch.Size([8, 2, 3, 10]), dtype=torch.float32),
        K: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        Q: Tensor(torch.Size([8, 2, 3, 5]), dtype=torch.float32),
        V: Tensor(torch.Size([8, 2, 10, 5]), dtype=torch.float32),
        X_from: Tensor(torch.Size([8, 10, 6]), dtype=torch.float32),
        X_out: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32),
        X_to: Tensor(torch.Size([8, 3, 5]), dtype=torch.float32)},
    batch_size=torch.Size([8]),
    device=cpu,
    is_shared=False)

The output of the transformer block can now be found at `tokens["X_to"]`:

In [50]:
tokens["X_to"]

tensor([[[ 1.3948,  0.9014, -0.3433,  0.7251, -1.0122],
         [ 0.1403,  1.2825, -1.2302, -1.4635,  0.7165],
         [ 0.8251, -1.5065, -1.0976, -0.0920,  0.7596]],

        [[ 1.5074, -0.4161,  0.5480,  1.1882, -1.3595],
         [ 0.4081, -0.2427,  1.0663,  1.1888, -1.2557],
         [ 0.1856, -0.5556, -1.4885,  0.5571, -1.3313]],

        [[-0.1015, -1.1877, -0.1132, -0.9613, -0.7382],
         [-0.9579,  1.7701,  1.8217,  0.5509, -0.6816],
         [ 1.6410,  0.2096, -0.6056, -0.8921,  0.2459]],

        [[-0.3648,  0.2833,  0.2137,  0.8552, -0.9622],
         [ 0.3903, -1.2456, -0.6480, -0.4978,  0.8822],
         [-0.2155,  1.4676,  0.3733,  1.6645, -2.1963]],

        [[ 0.3817, -1.3816, -0.0126,  1.1476,  1.8669],
         [ 0.5688,  0.9381, -1.3945,  0.6863, -1.5389],
         [-0.2369,  0.7993, -1.0803, -0.4629, -0.2810]],

        [[ 0.7018,  1.2853, -0.2395,  1.8107,  0.1803],
         [-1.4307, -0.3072,  0.6677,  0.5036, -1.2250],
         [ 0.6976,  0.3290, -2.0245, -

We can now easily build a transformer by stacking several blocks:

In [51]:
class TransformerTensorDict(nn.Module):
    def __init__(self, num_blocks, to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.transformer = nn.ModuleList([TransformerBlockTensorDict(to_name, from_name, to_dim, to_len, from_dim, latent_dim, num_heads) for _ in range(num_blocks)])
    def forward(self, X_tensor_dict):
        for transformer_block in self.transformer:
            transformer_block(X_tensor_dict)

In [52]:
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2

tokens = TensorDict({"X_to":torch.randn(batch_size, to_len, to_dim), "X_from":torch.randn(batch_size, from_len, from_dim)}, batch_size=[batch_size])




For an encoder, we just need to take the same tokens for both queries, keys and values.

In [53]:
transformer_encoder = TransformerTensorDict(6, "X_to", "X_to", to_dim, to_len, to_dim, latent_dim, num_heads)

transformer_encoder(tokens)
tokens["X_to"]

tensor([[[ 0.3641, -0.0979, -0.0854, -0.4905, -2.0029],
         [-1.5806,  0.5326,  0.6632,  0.7920, -1.1313],
         [ 0.5733, -0.1321,  2.1497,  0.6715, -0.2256]],

        [[-0.5301,  1.3014,  0.8339, -1.6371, -0.0447],
         [-0.1658,  0.9115, -0.1586, -0.1376, -1.6049],
         [-0.1376,  0.6356,  1.5934,  0.7616, -1.6211]],

        [[-0.7808,  0.1323,  0.3524, -1.8699, -1.2661],
         [-0.3532,  0.0802,  1.8762, -0.0252, -1.0859],
         [ 0.4886,  1.1930,  1.5847,  0.0510, -0.3773]],

        [[-0.5639,  0.4097,  1.2614, -1.1658,  0.8153],
         [ 0.9616, -1.0353,  0.8457,  0.3561, -1.7874],
         [-0.3008, -0.3083,  1.2756,  0.7932, -1.5570]],

        [[ 1.6996,  0.6104,  0.7166, -1.4791, -0.0746],
         [-1.3742, -0.0917,  0.3204,  0.4544, -1.6289],
         [ 0.7167,  0.8835,  1.0875, -0.8911, -0.9494]],

        [[-0.3857,  1.2185,  0.8044, -1.2474, -0.6684],
         [-2.6451,  0.3813,  0.8183,  0.2206, -0.0257],
         [-0.2125,  0.9058,  1.3390, -

For a decoder, we can now extract information from `X_from` into `X_to`: `X_to` is mapped to the queries, whereas `X_from` is mapped to the keys and values.

In [54]:
transformer_decoder = TransformerTensorDict(6, "X_to", "X_from", to_dim, to_len, from_dim, latent_dim, num_heads)

transformer_decoder(tokens)
tokens["X_to"]

tensor([[[ 0.4194,  1.1577, -0.2249, -0.2299, -1.6838],
         [-0.4603,  1.5326,  0.1736,  0.4194, -2.1895],
         [ 0.2842,  0.8026,  1.1243, -0.1371, -0.9882]],

        [[-0.3067,  1.4502,  0.5999, -1.0000, -0.6322],
         [-0.2193,  2.2219, -0.2861, -0.1852, -1.1817],
         [ 0.0121,  0.7552,  0.8953, -0.3838, -1.7396]],

        [[-0.2220,  0.9376,  0.4108, -0.8514, -1.4368],
         [-0.2243,  1.0650,  1.1564, -0.5117, -1.8274],
         [ 0.3239,  1.7627,  0.7365, -0.2319, -1.0873]],

        [[ 0.0923,  0.1925,  2.0104, -1.0669,  0.2943],
         [ 1.0389, -0.3586,  0.0279, -0.2836, -1.3698],
         [-0.2385,  0.1469,  1.1481,  0.5598, -2.1937]],

        [[ 1.3624,  0.0818,  0.6500, -1.9682, -0.3359],
         [-0.3614,  0.9265,  0.2931,  0.2009, -1.4732],
         [ 0.2946,  1.1101,  1.3561, -1.0921, -1.0449]],

        [[ 0.1569,  1.4944,  0.1292, -1.0900, -0.8406],
         [-1.6930,  0.3846,  0.9151,  0.1136, -1.2707],
         [ 0.4416,  1.6022,  1.1814, -

Now we can look at both models:

In [55]:
transformer_encoder

TransformerTensorDict(
  (transformer): ModuleList(
    (0): TransformerBlockTensorDict(
      (transformer_block): Sequential(
        (0): TensorDictModule(
            module=TokensToQKV(
              (q): Linear(in_features=5, out_features=10, bias=True)
              (k): Linear(in_features=5, out_features=10, bias=True)
              (v): Linear(in_features=5, out_features=10, bias=True)
            ), 
            device=cpu, 
            in_keys=['X_to', 'X_to'], 
            out_keys=['Q', 'K', 'V'])
        (1): TensorDictModule(
            module=SplitHeads(), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['Q', 'K', 'V'])
        (2): TensorDictModule(
            module=Attention(
              (softmax): Softmax(dim=-1)
              (out): Linear(in_features=10, out_features=5, bias=True)
            ), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['X_out', 'Attn'])
        (3): TensorDictMod

In [56]:
transformer_decoder

TransformerTensorDict(
  (transformer): ModuleList(
    (0): TransformerBlockTensorDict(
      (transformer_block): Sequential(
        (0): TensorDictModule(
            module=TokensToQKV(
              (q): Linear(in_features=5, out_features=10, bias=True)
              (k): Linear(in_features=6, out_features=10, bias=True)
              (v): Linear(in_features=6, out_features=10, bias=True)
            ), 
            device=cpu, 
            in_keys=['X_to', 'X_from'], 
            out_keys=['Q', 'K', 'V'])
        (1): TensorDictModule(
            module=SplitHeads(), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['Q', 'K', 'V'])
        (2): TensorDictModule(
            module=Attention(
              (softmax): Softmax(dim=-1)
              (out): Linear(in_features=10, out_features=5, bias=True)
            ), 
            device=cpu, 
            in_keys=['Q', 'K', 'V'], 
            out_keys=['X_out', 'Attn'])
        (3): TensorDictM