Skip to content

Commit a5b37c9

Browse files
fix TokenEmbedding forward method
1 parent 6537199 commit a5b37c9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

beginner_source/translation_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __init__(self, vocab_size: int, emb_size):
132132
self.emb_size = emb_size
133133

134134
def forward(self, tokens: Tensor):
135-
return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
135+
return self.embedding(tokens.long()) / math.sqrt(self.emb_size)
136136

137137
# Seq2Seq Network
138138
class Seq2SeqTransformer(nn.Module):

0 commit comments

Comments
 (0)