fix: transformer decoder layer embed
typoverflow committed Oct 15, 2023
1 parent c2d8326 commit a03bda6
Showing 2 changed files with 2 additions and 5 deletions.
2 changes: 1 addition & 1 deletion offlinerllib/__init__.py
@@ -1,2 +1,2 @@
 
-__version__ = "0.1.1"
+__version__ = "0.1.2"
5 changes: 1 addition & 4 deletions offlinerllib/module/net/attention/transformer.py
@@ -276,10 +276,7 @@ def forward(
         if tgt_attention_mask is not None:
             tgt_mask = torch.bitwise_or(tgt_attention_mask.to(torch.bool), tgt_mask)
         if do_embedding:
-            tgt = self.input_embed(tgt)
-            if timesteps is not None:
-                timesteps = torch.arange(L).repeat(B, 1).to(tgt.device)
-            tgt = tgt + self.pos_embed(timesteps)
+            tgt = self.pos_encoding(self.input_embed(tgt))
         output = self.embed_dropout(tgt)
         for i, block in enumerate(self.blocks):
             output = block(
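For context, the change replaces the manual timestep handling (building a timesteps tensor and adding a learned pos_embed inside forward) with a single call to a pos_encoding module applied to the embedded target. Below is a minimal, self-contained sketch of that pattern, not the actual offlinerllib module: the TinyDecoderEmbed class, its dimensions, and the learned position table are illustrative assumptions.

import torch
import torch.nn as nn

class TinyDecoderEmbed(nn.Module):
    """Illustrative stand-in for the decoder's embedding path after this commit."""
    def __init__(self, input_dim: int, embed_dim: int, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.input_embed = nn.Linear(input_dim, embed_dim)       # token/state embedding
        self.position_table = nn.Embedding(max_len, embed_dim)   # hypothetical learned positions
        self.embed_dropout = nn.Dropout(dropout)

    def pos_encoding(self, x: torch.Tensor) -> torch.Tensor:
        # Add a position vector for each index 0..L-1; positions are built inside
        # the module, so forward() no longer manages a timesteps tensor itself.
        B, L, _ = x.shape
        positions = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
        return x + self.position_table(positions)

    def forward(self, tgt: torch.Tensor) -> torch.Tensor:
        # Mirrors the added diff line: tgt = self.pos_encoding(self.input_embed(tgt))
        tgt = self.pos_encoding(self.input_embed(tgt))
        return self.embed_dropout(tgt)

# Usage sketch: a batch of 2 sequences of length 10 with 17-dimensional inputs.
embed = TinyDecoderEmbed(input_dim=17, embed_dim=64)
out = embed(torch.randn(2, 10, 17))  # out.shape == torch.Size([2, 10, 64])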
