In [4]:
import sys
sys.path.append('../')
import torch
from transformer.layers import MultiHeadAttention, FeedForward, LayerNorm
from transformer.encoder import EncoderLayer, Encoder

In [5]:

# Build a Transformer encoder from its components and push a random batch
# through it as a smoke test of the output shape.

# Model and input dimensions
d_model = 512     # model (embedding) dimension
h = 8             # number of attention heads
dropout = 0.1     # dropout rate
num_layers = 6    # number of stacked encoder layers
d_ff = 2048       # hidden dimension of the feed-forward network
batch = 10        # sequences per batch
seq_len = 20      # tokens per sequence
seed = 0          # RNG seed for reproducible runs

# Seed before constructing any modules. Any random parameter initialisation
# inside them, the random input below, and any dropout applied during the
# forward pass then become reproducible across Restart & Run All.
torch.manual_seed(seed)

# Sub-layers consumed by the encoder layer
self_attention_engine = MultiHeadAttention(d_model, h, dropout)
feed_forward = FeedForward(d_model, d_ff, dropout)

# One encoder layer assembled from the self-attention and feed-forward sub-layers
encoder_layer = EncoderLayer(d_model, self_attention_engine, feed_forward, dropout)

# Full encoder with num_layers layers built from encoder_layer
# NOTE(review): assumes Encoder copies encoder_layer rather than reusing the
# same instance num_layers times; confirm in transformer/encoder.py.
encoder = Encoder(d_model, encoder_layer, num_layers)

# Random input batch of shape (batch, seq_len, d_model)
x = torch.rand(batch, seq_len, d_model)

# Forward pass through the encoder
output = encoder(x)

# Fail loudly on a shape regression instead of relying on eyeballing the print.
# torch.Size is a tuple subclass, so it compares equal to a plain tuple.
expected_shape = (batch, seq_len, d_model)
assert output.shape == expected_shape, f"expected {expected_shape}, got {tuple(output.shape)}"
print(output.shape)  # torch.Size([10, 20, 512]) with the values above

torch.Size([10, 20, 512])


In [6]:
# Display the encoder output; torch's repr abbreviates large tensors with "..."
output

tensor([[[-1.2745, -1.3341,  1.8563,  ..., -0.7987, -0.4595, -2.0738],
         [-2.1539,  0.0743,  1.8552,  ..., -1.3345, -1.8894, -2.2995],
         [-1.1547, -0.4120,  1.2284,  ..., -0.3696, -0.2565, -1.4057],
         ...,
         [-1.6552,  0.1109,  1.2367,  ..., -0.4721, -0.9330, -2.0241],
         [-0.8082, -0.1798,  1.3354,  ..., -1.2474, -0.7440, -1.6114],
         [-1.1402, -0.5075,  0.2664,  ..., -0.5437, -1.4057, -2.1198]],

        [[-1.0775, -0.6065,  0.9312,  ...,  0.3671, -1.0054, -1.7633],
         [-1.5911, -0.3359,  1.2845,  ..., -1.0076, -1.4768, -1.8935],
         [-1.0480, -0.7467,  1.7230,  ..., -0.6521, -0.8205, -2.0892],
         ...,
         [-1.3126, -1.3663,  1.4949,  ..., -0.6949, -1.2023, -1.4043],
         [-0.3771, -1.2371,  0.7752,  ..., -0.4985, -0.8243, -0.9701],
         [-1.1369, -0.0368,  1.4173,  ..., -0.9032, -0.7359, -1.7543]],

        [[-0.2336, -1.0617,  1.4119,  ..., -0.2019, -0.5825, -1.0645],
         [-0.0915, -1.4929,  1.8660,  ..., -1