Merged

Changes from all commits · 28 commits
228b158
Added TensorDict tutorial
nicolas-dufour Jul 7, 2022
1f03199
Fixed english mistakes and small refactoring
nicolas-dufour Jul 7, 2022
f4039fa
init
vmoens Jul 8, 2022
1337768
init
vmoens Jul 8, 2022
f4db1c2
Retrieved bug fixes
nicolas-dufour Jul 8, 2022
ab49165
Merge remote-tracking branch 'origin_fb/bugfix_in_keys_exclusion' int…
nicolas-dufour Jul 8, 2022
04fe61c
Added suggered changes and cleaned up
nicolas-dufour Jul 8, 2022
0399f49
Added transformer figure
nicolas-dufour Jul 11, 2022
97bb800
init
vmoens Jul 11, 2022
dd726bf
Merge pull request #1 from vmoens/pr-255
nicolas-dufour Jul 11, 2022
9ecff5e
Merge branch 'main' of github.com:nicolas-dufour/torchrl into tensord…
nicolas-dufour Jul 11, 2022
4fb2571
TensorDictModule initial commit
nicolas-dufour Jul 11, 2022
38e8beb
Details fixed
nicolas-dufour Jul 13, 2022
f3a12be
Made suggered modifications
nicolas-dufour Jul 13, 2022
f7d8622
Made suggered modifications
nicolas-dufour Jul 13, 2022
7010b54
Made changes
nicolas-dufour Jul 13, 2022
7836c02
Suggested changes and do and dont
nicolas-dufour Jul 14, 2022
7dd7efb
Formating
nicolas-dufour Jul 14, 2022
81be639
Formating
nicolas-dufour Jul 14, 2022
8300358
Merge branch 'main' of github.com:nicolas-dufour/torchrl into tensord…
nicolas-dufour Jul 14, 2022
78a40b2
Did some changes
nicolas-dufour Jul 15, 2022
49edf5d
Made suggested changes
nicolas-dufour Jul 18, 2022
0ee7a42
Merge branch 'main' into tensordictmodule_tutorial
nicolas-dufour Jul 19, 2022
84de974
Added tensordictmodule tutorial to README.MD
nicolas-dufour Jul 19, 2022
9999c2b
Clean rerun
nicolas-dufour Jul 19, 2022
a5acde6
Added benchmark
nicolas-dufour Jul 21, 2022
52172ba
Warning clean-up
nicolas-dufour Jul 21, 2022
45ccd8b
Made suggested changes
nicolas-dufour Jul 22, 2022
6 changes: 4 additions & 2 deletions tutorials/README.md
@@ -4,8 +4,10 @@ Get a sense of TorchRL functionalities through our tutorials.
 
 For an overview of TorchRL, try the [TorchRL demo](demo.ipynb).
 
-Make sure you test the [TensorDict demo](tensordict.ipynb) to see what TensorDict
+Make sure you test the [TensorDict tutorial](tensordict.ipynb) to see what TensorDict
 is about and what it can do.
 
-Checkout the [environment demo](envs.ipynb) for a deep dive in the envs
+To understand how to use `TensorDict` with pytorch modules, make sure to check out the [TensorDictModule tutorial](tensordictmodule.ipynb).
+
+Checkout the [environment tutorial](envs.ipynb) for a deep dive in the envs
 functionalities.
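For readers skimming this diff, a minimal sketch of the pattern the new TensorDictModule tutorial covers: wrapping an ordinary `nn.Module` so that it reads its inputs from, and writes its outputs into, a TensorDict. The import paths below are an assumption (at the time of this PR, TensorDict and TensorDictModule lived inside torchrl itself) and may need adjusting to the installed version; the key names and sizes are made up for illustration.

import torch
import torch.nn as nn

# Assumed import locations; adjust to your TorchRL/tensordict version.
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# A plain module that knows nothing about TensorDict.
net = nn.Linear(4, 2)

# Wrap it so it reads "observation" and writes "action".
module = TensorDictModule(net, in_keys=["observation"], out_keys=["action"])

# A batch of 8 observations stored in a TensorDict.
td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])

module(td)                 # runs the linear layer and stores the result in place
print(td["action"].shape)  # torch.Size([8, 2])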
Binary file added tutorials/media/transformer.png
164 changes: 164 additions & 0 deletions tutorials/src/transformer.py
@@ -0,0 +1,164 @@
import torch.nn as nn


class TokensToQKV(nn.Module):
def __init__(self, to_dim, from_dim, latent_dim):
super().__init__()
self.q = nn.Linear(to_dim, latent_dim)
self.k = nn.Linear(from_dim, latent_dim)
self.v = nn.Linear(from_dim, latent_dim)

def forward(self, X_to, X_from):
Q = self.q(X_to)
K = self.k(X_from)
V = self.v(X_from)
return Q, K, V


class SplitHeads(nn.Module):
def __init__(self, num_heads):
super().__init__()
self.num_heads = num_heads

def forward(self, Q, K, V):
batch_size, to_num, latent_dim = Q.shape
_, from_num, _ = K.shape
d_tensor = latent_dim // self.num_heads
Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
return Q, K, V


class Attention(nn.Module):
def __init__(self, latent_dim, to_dim):
super().__init__()
self.softmax = nn.Softmax(dim=-1)
self.out = nn.Linear(latent_dim, to_dim)

def forward(self, Q, K, V):
batch_size, n_heads, to_num, d_in = Q.shape
        # scaled dot-product attention; note the scores here are divided by the
        # head dimension d_in rather than the usual sqrt(d_in)
        attn = self.softmax(Q @ K.transpose(2, 3) / d_in)
out = attn @ V
out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
return out, attn


class SkipLayerNorm(nn.Module):
def __init__(self, to_len, to_dim):
super().__init__()
self.layer_norm = nn.LayerNorm((to_len, to_dim))

def forward(self, x_0, x_1):
return self.layer_norm(x_0 + x_1)


class FFN(nn.Module):
def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
super().__init__()
self.FFN = nn.Sequential(
nn.Linear(to_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, to_dim),
nn.Dropout(dropout_rate),
)

def forward(self, X):
return self.FFN(X)


class AttentionBlock(nn.Module):
def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
super().__init__()
self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim)
self.split_heads = SplitHeads(num_heads)
self.attention = Attention(latent_dim, to_dim)
self.skip = SkipLayerNorm(to_len, to_dim)

def forward(self, X_to, X_from):
Q, K, V = self.tokens_to_qkv(X_to, X_from)
Q, K, V = self.split_heads(Q, K, V)
out, attention = self.attention(Q, K, V)
out = self.skip(X_to, out)
return out


class EncoderTransformerBlock(nn.Module):
def __init__(self, to_dim, to_len, latent_dim, num_heads):
super().__init__()
self.attention_block = AttentionBlock(
to_dim, to_len, to_dim, latent_dim, num_heads
)
self.FFN = FFN(to_dim, 4 * to_dim)
self.skip = SkipLayerNorm(to_len, to_dim)

def forward(self, X_to):
X_to = self.attention_block(X_to, X_to)
X_out = self.FFN(X_to)
return self.skip(X_out, X_to)


class DecoderTransformerBlock(nn.Module):
def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
super().__init__()
self.attention_block = AttentionBlock(
to_dim, to_len, from_dim, latent_dim, num_heads
)
self.encoder_block = EncoderTransformerBlock(
to_dim, to_len, latent_dim, num_heads
)

def forward(self, X_to, X_from):
X_to = self.attention_block(X_to, X_from)
X_to = self.encoder_block(X_to)
return X_to


class TransformerEncoder(nn.Module):
def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads):
super().__init__()
self.encoder = nn.ModuleList(
[
EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads)
for i in range(num_blocks)
]
)

def forward(self, X_to):
for i in range(len(self.encoder)):
X_to = self.encoder[i](X_to)
return X_to


class TransformerDecoder(nn.Module):
def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads):
super().__init__()
self.decoder = nn.ModuleList(
[
DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads)
for i in range(num_blocks)
]
)

def forward(self, X_to, X_from):
for i in range(len(self.decoder)):
X_to = self.decoder[i](X_to, X_from)
return X_to


class Transformer(nn.Module):
def __init__(
self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
):
super().__init__()
self.encoder = TransformerEncoder(
num_blocks, to_dim, to_len, latent_dim, num_heads
)
self.decoder = TransformerDecoder(
num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads
)

    def forward(self, X_to, X_from):
        # encode the "to" sequence, then let the "from" sequence attend to the
        # encoded representation through the decoder
        X_to = self.encoder(X_to)
        X_out = self.decoder(X_from, X_to)
        return X_out
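To sanity-check the interface of the new `Transformer` module, here is a minimal usage sketch. The sizes below are made up only to exercise the forward pass, and the import assumes the file is importable from the repository root; the tutorial notebook remains the authoritative usage.

import torch

# Assumes tutorials/src is importable from the repository root.
from tutorials.src.transformer import Transformer

# Hypothetical sizes, chosen only for illustration.
to_dim, to_len = 32, 10        # "to" (encoder-side) embedding size / sequence length
from_dim, from_len = 32, 16    # "from" (decoder-side) embedding size / sequence length
latent_dim, num_heads, num_blocks = 64, 4, 2   # latent_dim must be divisible by num_heads

model = Transformer(num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads)

X_to = torch.randn(8, to_len, to_dim)        # batch of encoder-side sequences
X_from = torch.randn(8, from_len, from_dim)  # batch of decoder-side sequences

out = model(X_to, X_from)
print(out.shape)  # torch.Size([8, 16, 32]); the output follows X_from's length and dim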