51 changes: 42 additions & 9 deletions intermediate_source/seq2seq_translation_tutorial.py
@@ -150,14 +150,15 @@

SOS_token = 0
EOS_token = 1
PAD_token = 2

class Lang:
def __init__(self, name):
self.name = name
self.word2index = {}
self.word2count = {}
self.index2word = {0: "SOS", 1: "EOS"}
self.n_words = 2 # Count SOS and EOS
self.index2word = {0: "SOS", 1: "EOS", 2: "PAD"}
self.n_words = 3 # Count SOS, EOS, and PAD

def addSentence(self, sentence):
for word in sentence.split(' '):
@@ -335,13 +336,23 @@ def __init__(self, input_size, hidden_size, dropout_p=0.1):
super(EncoderRNN, self).__init__()
self.hidden_size = hidden_size

self.embedding = nn.Embedding(input_size, hidden_size)
self.embedding = nn.Embedding(input_size, hidden_size, padding_idx=PAD_token)
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
self.dropout = nn.Dropout(dropout_p)

def forward(self, input):
# Compute actual lengths (excluding padding)
lengths = (input != PAD_token).sum(dim=1).cpu()

embedded = self.dropout(self.embedding(input))
output, hidden = self.gru(embedded)

# Pack padded sequences
packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
output, hidden = self.gru(packed)

# Unpack sequences
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

return output, hidden

######################################################################
@@ -375,7 +386,7 @@ def forward(self, input):
class DecoderRNN(nn.Module):
def __init__(self, hidden_size, output_size):
super(DecoderRNN, self).__init__()
self.embedding = nn.Embedding(output_size, hidden_size)
self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD_token)
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
self.out = nn.Linear(hidden_size, output_size)

@@ -480,7 +491,7 @@ def forward(self, query, keys):
class AttnDecoderRNN(nn.Module):
def __init__(self, hidden_size, output_size, dropout_p=0.1):
super(AttnDecoderRNN, self).__init__()
self.embedding = nn.Embedding(output_size, hidden_size)
self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD_token)
self.attention = BahdanauAttention(hidden_size)
self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
self.out = nn.Linear(hidden_size, output_size)
@@ -563,8 +574,8 @@ def get_dataloader(batch_size):
input_lang, output_lang, pairs = prepareData('eng', 'fra', True)

n = len(pairs)
input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
input_ids = np.full((n, MAX_LENGTH), PAD_token, dtype=np.int32)
target_ids = np.full((n, MAX_LENGTH), PAD_token, dtype=np.int32)

for idx, (inp, tgt) in enumerate(pairs):
inp_ids = indexesFromSentence(input_lang, inp)
@@ -583,6 +594,28 @@ def get_dataloader(batch_size):


######################################################################
# .. note::
# When working with batched sequences of variable lengths, proper padding
# handling is crucial (see the standalone sketch after this note):
#
# 1. **Padding Token**: We use a dedicated ``PAD_token`` (index 2) to pad
# shorter sequences up to ``MAX_LENGTH``. Reusing index 0 (the SOS token)
# as padding would conflate padding with a real token.
#
# 2. **Encoder Padding**: The encoder uses ``pack_padded_sequence`` and
# ``pad_packed_sequence`` to handle variable-length sequences efficiently.
# This ensures the GRU's final hidden state represents the actual sentence
# content, not padding tokens.
#
# 3. **Loss Masking**: The loss function uses ``ignore_index=PAD_token`` to
# exclude padding tokens from the loss computation. This prevents the model
# from learning to predict padding and ensures gradients only flow from
# actual target tokens.
#
# 4. **Embedding Padding**: All embedding layers use ``padding_idx=PAD_token``
# to ensure padding tokens have zero embeddings that don't get updated
# during training.
#
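# A standalone sketch (not part of this change) showing the four pieces above
# working together on a toy batch. The vocabulary size, hidden size, and
# token values below are made-up numbers for illustration only.

import torch
import torch.nn as nn
import torch.nn.functional as F

PAD, SOS, EOS = 2, 0, 1          # same indices as the tutorial's special tokens
vocab_size, hidden_size = 10, 8  # toy sizes chosen arbitrarily

# Two sequences of different lengths, right-padded with PAD
batch = torch.tensor([[4, 5, 6, EOS, PAD, PAD],
                      [7, 8, EOS, PAD, PAD, PAD]])
lengths = (batch != PAD).sum(dim=1).cpu()  # tensor([4, 3])

# padding_idx keeps the PAD embedding at zero and excludes it from updates
embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=PAD)
gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

# Pack, run the GRU, and unpack: the final hidden state reflects only the
# real tokens of each sequence, never the PAD positions
packed = nn.utils.rnn.pack_padded_sequence(
    embedding(batch), lengths, batch_first=True, enforce_sorted=False)
output, hidden = gru(packed)
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

# ignore_index drops PAD targets from the loss average
criterion = nn.NLLLoss(ignore_index=PAD)
logits = torch.randn(batch.numel(), vocab_size)   # stand-in decoder scores
loss = criterion(F.log_softmax(logits, dim=1), batch.view(-1))
print(output.shape, hidden.shape, loss.item())
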
# Training the Model
# ------------------
#
@@ -678,7 +711,7 @@ def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,

encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
criterion = nn.NLLLoss()
criterion = nn.NLLLoss(ignore_index=PAD_token)

for epoch in range(1, n_epochs + 1):
loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)