In [9]:
import sys
sys.path.append('../')
import torch
from transformer.layers import MultiHeadAttention, FeedForward, ResidualConnection
from transformer.decoder import DecoderLayer

# Demo: push one random batch through a single DecoderLayer and inspect the result.

# Configuration
d_model = 7  # feature dimension
d_ff = 2048  # feed-forward dimension
h = 1  # number of attention heads
batch_size = 1  # batch size
seq_len = 4  # sequence length
dropout = 0.1  # dropout ratio

# Seed the global RNG *before* the sub-layers are constructed. Without this,
# any parameters they draw from torch's global RNG at construction time depend
# on whatever state earlier cells left behind, so the printed output would
# differ after a kernel restart.
weight_init_seed = 0
torch.manual_seed(weight_init_seed)

# Build the three sub-layers the decoder layer is composed of
self_attention = MultiHeadAttention(d_model, h, dropout)
encoder_decoder_attention = MultiHeadAttention(d_model, h, dropout)
feed_forward = FeedForward(d_model, d_ff, dropout)

# Assemble the decoder layer from its sub-layers
decoder_layer = DecoderLayer(d_model, self_attention, encoder_decoder_attention, feed_forward, dropout)

# Random decoder-side input; re-seeded right before sampling so the tensor is
# the same regardless of how much randomness was consumed above
torch.manual_seed(68)
x = torch.rand(batch_size, seq_len, d_model)

# Random stand-in for the encoder output, seeded independently of x
torch.manual_seed(101)
encoder_output = torch.rand(batch_size, seq_len, d_model)

# Forward pass through the decoder layer.
# NOTE(review): if DecoderLayer is an nn.Module it is in training mode by
# default, so dropout would be active here; call decoder_layer.eval() first if
# a dropout-free output is wanted. Confirm which is intended.
output = decoder_layer(x, encoder_output)

# The recorded run shows the output keeping the input's (batch, seq, feature) shape
assert output.shape == (batch_size, seq_len, d_model), output.shape

print("Initial input tensor: \n", x)
print("Encoder output tensor: \n", encoder_output)
print("Output's shape: \n", output.shape)
print("Output: \n", output)

Initial input tensor: 
 tensor([[[0.3991, 0.5521, 0.1004, 0.2844, 0.9998, 0.7077, 0.8031],
         [0.2066, 0.3589, 0.8509, 0.8753, 0.4669, 0.6566, 0.6026],
         [0.2785, 0.1350, 0.2257, 0.9548, 0.8214, 0.1386, 0.6055],
         [0.2300, 0.7895, 0.4098, 0.0428, 0.4400, 0.2381, 0.4967]]])
Encoder output tensor: 
 tensor([[[0.1980, 0.4503, 0.0909, 0.8872, 0.2894, 0.0186, 0.9095],
         [0.3406, 0.4309, 0.7324, 0.4776, 0.0716, 0.5834, 0.7521],
         [0.7649, 0.1443, 0.7152, 0.3953, 0.6244, 0.3684, 0.8823],
         [0.3746, 0.1458, 0.3671, 0.5645, 0.5272, 0.1141, 0.0992]]])
Output's shape: 
 torch.Size([1, 4, 7])
Output: 
 tensor([[[ 0.0586,  0.5650, -0.0017,  0.3634,  1.0980,  1.2257,  0.5185],
         [ 0.4581,  0.2700,  0.6408,  1.0297,  0.0420,  1.0079,  0.1963],
         [-0.1930,  0.0786, -0.3701,  1.2002,  0.5512,  0.0592,  0.0707],
         [ 0.3331,  0.5842,  0.3511,  0.2846,  0.2464,  0.9387,  0.7481]]],
       grad_fn=<AddBackward0>)
