A notebook to test the stacking of the encoder blocks (and maybe the decoders too), in order to
fully understand two things:

1. how the stacking works with the input and output shapes
2. how the nn.ModuleList class works to achieve the above

In [1]:
import torch
import torch.nn as nn

from EncoderBlock import Encoder
from PositionalEncoding import PositionalEncoding

In [6]:
# Parameters
SEED = 0  # Fixed seed so the random token ids and embedding weights repeat across runs
vocab_size = 100  # Number of unique tokens in the vocabulary
embedding_dim = 512  # Dimensionality of embeddings
batch_size = 64  # Number of sequences in a batch
seq_len = 10  # Length of each sequence

# Seed PyTorch's global RNG before any random tensor or layer is created, so that
# "Restart Kernel -> Run All" reproduces the same values every time.
torch.manual_seed(SEED)

# Create random input data: batch_size sequences of length seq_len with values from 0 to vocab_size-1
# torch.randint returns int64 ("long") values by default, so no extra cast is needed
# before using them as embedding indices.
dummy_input = torch.randint(0, vocab_size, (batch_size, seq_len))
print('DUMMY_INPUT:', dummy_input.shape, dummy_input.dtype)

embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
pos_enc = PositionalEncoding(d_model=embedding_dim)

# Look up one embedding_dim-sized vector per token id, then pass the result through the
# project's PositionalEncoding module. `x` is the input used by the encoder experiments below.
x = embedding(dummy_input)
x = pos_enc(x)

DUMMY_INPUT: torch.Size([64, 10]) torch.int64


In [7]:
# Build three separate Encoder objects from identical constructor arguments.
# The generator is consumed left to right, so the instances are created in the
# same order as three explicit assignments would create them.
encoder1, encoder2, encoder3 = (
    Encoder(vocab_size, embedding_dim) for _ in range(3)
)

In [8]:
def _report(label, tensor):
    """Print `label` followed by the tensor's shape and dtype."""
    print(label, tensor.shape, tensor.dtype)


# Chain the encoders by hand: each encoder consumes the previous encoder's output.
# Every result is reported immediately after it is computed, so the printed lines
# appear in the same order as the forward calls.
output1 = encoder1(x)
_report('After encoder1:', output1)

output2 = encoder2(output1)
_report('After encoder2:', output2)

output3 = encoder3(output2)
_report('After encoder3:', output3)

ENCODER_BLOCK_OUTPUT torch.Size([64, 10, 512]) torch.float32
After encoder1: torch.Size([64, 10, 512]) torch.float32
ENCODER_BLOCK_OUTPUT torch.Size([64, 10, 512]) torch.float32
After encoder2: torch.Size([64, 10, 512]) torch.float32
ENCODER_BLOCK_OUTPUT torch.Size([64, 10, 512]) torch.float32
After encoder3: torch.Size([64, 10, 512]) torch.float32


In [9]:
import torch.nn as nn

# Collect the already-constructed encoders in the order they should be applied.
encoders = [encoder1, encoder2, encoder3]

# Wrap the plain Python list in nn.ModuleList. It can be indexed and iterated like
# a regular list, but the modules it holds are registered as submodules.
stacked_encoders = nn.ModuleList(modules=encoders)




In [10]:
# Thread the activations through a dedicated variable instead of rebinding `x`.
# This keeps the embedded input intact and makes the cell idempotent: re-running it
# always starts from the same `x` rather than feeding the previous run's output
# back into the stack.
stacked_output = x
for layer_idx, encoder in enumerate(stacked_encoders, start=1):
    # Each encoder receives the output of the one before it.
    stacked_output = encoder(stacked_output)
    print(f"Output after encoder {layer_idx}: {stacked_output.shape}")


ENCODER_BLOCK_OUTPUT torch.Size([64, 10, 512]) torch.float32
Output after encoder 1: torch.Size([64, 10, 512])
ENCODER_BLOCK_OUTPUT torch.Size([64, 10, 512]) torch.float32
Output after encoder 2: torch.Size([64, 10, 512])
ENCODER_BLOCK_OUTPUT torch.Size([64, 10, 512]) torch.float32
Output after encoder 3: torch.Size([64, 10, 512])
