diff --git a/intermediate_source/seq2seq_translation_tutorial.py b/intermediate_source/seq2seq_translation_tutorial.py
index 5de4bb4ca3e..4403cf663af 100755
--- a/intermediate_source/seq2seq_translation_tutorial.py
+++ b/intermediate_source/seq2seq_translation_tutorial.py
@@ -150,14 +150,15 @@
 
 SOS_token = 0
 EOS_token = 1
+PAD_token = 2
 
 class Lang:
     def __init__(self, name):
         self.name = name
         self.word2index = {}
         self.word2count = {}
-        self.index2word = {0: "SOS", 1: "EOS"}
-        self.n_words = 2  # Count SOS and EOS
+        self.index2word = {0: "SOS", 1: "EOS", 2: "PAD"}
+        self.n_words = 3  # Count SOS, EOS, and PAD
 
     def addSentence(self, sentence):
         for word in sentence.split(' '):
@@ -335,13 +336,23 @@ def __init__(self, input_size, hidden_size, dropout_p=0.1):
         super(EncoderRNN, self).__init__()
         self.hidden_size = hidden_size
 
-        self.embedding = nn.Embedding(input_size, hidden_size)
+        self.embedding = nn.Embedding(input_size, hidden_size, padding_idx=PAD_token)
         self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
         self.dropout = nn.Dropout(dropout_p)
 
     def forward(self, input):
+        # Compute actual lengths (excluding padding)
+        lengths = (input != PAD_token).sum(dim=1).cpu()
+
         embedded = self.dropout(self.embedding(input))
-        output, hidden = self.gru(embedded)
+
+        # Pack padded sequences
+        packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
+        output, hidden = self.gru(packed)
+
+        # Unpack sequences
+        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
+
         return output, hidden
 
 ######################################################################
@@ -375,7 +386,7 @@ def forward(self, input):
 class DecoderRNN(nn.Module):
     def __init__(self, hidden_size, output_size):
         super(DecoderRNN, self).__init__()
-        self.embedding = nn.Embedding(output_size, hidden_size)
+        self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD_token)
         self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
         self.out = nn.Linear(hidden_size, output_size)
 
@@ -480,7 +491,7 @@ def forward(self, query, keys):
 class AttnDecoderRNN(nn.Module):
     def __init__(self, hidden_size, output_size, dropout_p=0.1):
         super(AttnDecoderRNN, self).__init__()
-        self.embedding = nn.Embedding(output_size, hidden_size)
+        self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD_token)
         self.attention = BahdanauAttention(hidden_size)
         self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
         self.out = nn.Linear(hidden_size, output_size)
@@ -563,8 +574,8 @@ def get_dataloader(batch_size):
     input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
 
     n = len(pairs)
-    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
-    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
+    input_ids = np.full((n, MAX_LENGTH), PAD_token, dtype=np.int32)
+    target_ids = np.full((n, MAX_LENGTH), PAD_token, dtype=np.int32)
 
     for idx, (inp, tgt) in enumerate(pairs):
         inp_ids = indexesFromSentence(input_lang, inp)
@@ -583,6 +594,28 @@ def get_dataloader(batch_size):
 
 
 ######################################################################
+# .. note::
+#    When working with batched sequences of variable lengths, proper padding
+#    handling is crucial:
+#
+#    1. **Padding Token**: We use a dedicated ``PAD_token`` (index 2) to pad
+#       shorter sequences up to the fixed ``MAX_LENGTH``. This is better than
+#       padding with 0, which is already taken by the SOS token.
+#
+#    2. **Encoder Padding**: The encoder uses ``pack_padded_sequence`` and
+#       ``pad_packed_sequence`` to handle variable-length sequences efficiently.
+#       This ensures the GRU's final hidden state represents the actual sentence
+#       content, not padding tokens.
+#
+#    3. **Loss Masking**: The loss function uses ``ignore_index=PAD_token`` to
+#       exclude padding tokens from the loss computation. This prevents the model
+#       from learning to predict padding and ensures gradients only flow from
+#       actual target tokens.
+#
+#    4. **Embedding Padding**: All embedding layers use ``padding_idx=PAD_token``
+#       to ensure padding tokens have zero embeddings that don't get updated
+#       during training.
+#
 # Training the Model
 # ------------------
 #
@@ -678,7 +711,7 @@ def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
 
     encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
     decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
-    criterion = nn.NLLLoss()
+    criterion = nn.NLLLoss(ignore_index=PAD_token)
 
     for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
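
The encoder-side changes can be exercised in isolation. The short sketch below is not part of the patch; the toy vocabulary size, hidden size, and input batch are made up purely for illustration. It shows how ``padding_idx`` keeps PAD embeddings at zero and how ``pack_padded_sequence`` / ``pad_packed_sequence`` stop the GRU from running over padded positions::

    import torch
    import torch.nn as nn

    PAD_token = 2
    vocab_size, hidden_size = 10, 4

    embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=PAD_token)
    gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    # Two sequences of different lengths, right-padded with PAD_token
    batch = torch.tensor([[5, 6, 7, 1],
                          [8, 9, 1, PAD_token]])
    lengths = (batch != PAD_token).sum(dim=1).cpu()   # tensor([4, 3])

    embedded = embedding(batch)   # (2, 4, 4); embeddings at PAD positions are all zeros
    packed = nn.utils.rnn.pack_padded_sequence(
        embedded, lengths, batch_first=True, enforce_sorted=False)
    packed_output, hidden = gru(packed)
    output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

    print(output.shape)   # torch.Size([2, 4, 4]); positions past each length are zero
    print(hidden.shape)   # torch.Size([1, 2, 4]); final state of each unpadded sequence

Because the sequences are packed, ``hidden`` holds the GRU state at the last real token of each sentence rather than at the last padded position.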
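Likewise, the effect of ``ignore_index=PAD_token`` in the loss can be verified with a few lines (again illustrative only, with made-up numbers): targets equal to ``PAD_token`` contribute neither to the loss value nor to the gradients, and the default ``mean`` reduction averages only over the non-ignored positions::

    import torch
    import torch.nn as nn

    PAD_token = 2
    criterion = nn.NLLLoss(ignore_index=PAD_token)

    logits = torch.randn(3, 5, requires_grad=True)
    log_probs = logits.log_softmax(dim=-1)
    targets = torch.tensor([4, PAD_token, 0])   # middle position is padding

    loss = criterion(log_probs, targets)        # averaged over the 2 real targets only
    loss.backward()
    print(logits.grad[1])                       # all zeros: the PAD position is masked out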