From c0b889d5f43150f288ecdd5dbd16c146d79e5bdf Mon Sep 17 00:00:00 2001 From: Alex Shroyer Date: Sun, 26 Nov 2023 20:10:21 -0500 Subject: [PATCH] simpler subsequent mask generator (#1198) --- word_language_model/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/word_language_model/model.py b/word_language_model/model.py index 1b972abb91..94d773bd1e 100644 --- a/word_language_model/model.py +++ b/word_language_model/model.py @@ -120,9 +120,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): self.init_weights() def _generate_square_subsequent_mask(self, sz): - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) - return mask + return torch.log(torch.tril(torch.ones(sz,sz))) def init_weights(self): initrange = 0.1