From 0c1bc6c7ed4c8963e5c006b370c29b90009212c4 Mon Sep 17 00:00:00 2001 From: vinhtran Date: Thu, 23 Feb 2023 10:29:00 +0700 Subject: [PATCH] fix len greedy search --- beginner_source/translation_transformer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/beginner_source/translation_transformer.py b/beginner_source/translation_transformer.py index 6e2538d1599..2a4f4ad8320 100644 --- a/beginner_source/translation_transformer.py +++ b/beginner_source/translation_transformer.py @@ -357,7 +357,7 @@ def greedy_decode(model, src, src_mask, max_len, start_symbol): memory = model.encode(src, src_mask) ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE) - for i in range(max_len-1): + for i in range(max_len): memory = memory.to(DEVICE) tgt_mask = (generate_square_subsequent_mask(ys.size(0)) .type(torch.bool)).to(DEVICE) @@ -371,7 +371,8 @@ def greedy_decode(model, src, src_mask, max_len, start_symbol): torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0) if next_word == EOS_IDX: break - return ys + # Return the target sequence without the start symbol + return ys[:, 1:] # actual function to translate input sentence into target language