Commit

fix: transformer refactor bugs
typoverflow committed Oct 15, 2023
1 parent c2d8326 commit a1c4457
Showing 2 changed files with 5 additions and 8 deletions.
2 changes: 1 addition & 1 deletion offlinerllib/__init__.py
@@ -1,2 +1,2 @@

__version__ = "0.1.1"
__version__ = "0.1.2"
11 changes: 4 additions & 7 deletions offlinerllib/module/net/attention/transformer.py
@@ -153,7 +153,7 @@ def _mha_block(self, input, key_value, attention_mask, key_padding_mask):
key=key_value,
value=key_value,
need_weights=False,
-attention_mask=attention_mask,
+attn_mask=attention_mask,
key_padding_mask=key_padding_mask
)[0]
return self.dropout2(input)
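The hunk above fixes the keyword passed to torch.nn.MultiheadAttention: its forward signature accepts the mask as attn_mask, so the old attention_mask= keyword raised a TypeError at call time. A minimal sketch of the corrected call; the dimensions and masks here are illustrative, not taken from the repository:

import torch
import torch.nn as nn

embed_dim, num_heads, B, L = 64, 4, 2, 10
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(B, L, embed_dim)
# Boolean mask: True marks positions that may NOT be attended to (causal mask).
causal_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
key_padding_mask = torch.zeros(B, L, dtype=torch.bool)

# `attn_mask` is the keyword MultiheadAttention.forward expects;
# `attention_mask=...` fails with "got an unexpected keyword argument".
out = mha(
    query=x,
    key=x,
    value=x,
    need_weights=False,
    attn_mask=causal_mask,
    key_padding_mask=key_padding_mask,
)[0]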
@@ -194,7 +194,7 @@ def __init__(
) for _ in range(num_layers)
])

-self.out_ln = nn.LayerNorm() if out_ln else nn.Identity()
+self.out_ln = nn.LayerNorm(embed_dim) if out_ln else nn.Identity()
self.causal = causal

def forward(
@@ -254,7 +254,7 @@ def __init__(
) for _ in range(num_layers)
])

-self.out_ln = nn.LayerNorm() if out_ln else nn.Identity()
+self.out_ln = nn.LayerNorm(embed_dim) if out_ln else nn.Identity()
self.causal = causal

def forward(
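Both __init__ hunks above apply the same fix: nn.LayerNorm has a required normalized_shape argument, so nn.LayerNorm() fails as soon as the module is constructed. A minimal sketch, with embed_dim and out_ln standing in for the constructor arguments used in the diff:

import torch.nn as nn

embed_dim, out_ln = 64, True

# nn.LayerNorm() without arguments raises a TypeError (missing `normalized_shape`).
# Passing the embedding dimension normalizes over the last (feature) axis.
out_norm = nn.LayerNorm(embed_dim) if out_ln else nn.Identity()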
@@ -276,10 +276,7 @@ def forward(
if tgt_attention_mask is not None:
tgt_mask = torch.bitwise_or(tgt_attention_mask.to(torch.bool), tgt_mask)
if do_embedding:
-tgt = self.input_embed(tgt)
-if timesteps is not None:
-    timesteps = torch.arange(L).repeat(B, 1).to(tgt.device)
-tgt = tgt + self.pos_embed(timesteps)
+tgt = self.pos_encoding(self.input_embed(tgt))
output = self.embed_dropout(tgt)
for i, block in enumerate(self.blocks):
output = block(

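The final hunk removes the hand-rolled timestep handling, which overwrote any timesteps the caller passed in, and instead feeds the embedded targets through self.pos_encoding. The pos_encoding module itself is not shown in this diff; the sketch below is only a plausible stand-in (a learned positional embedding added to the input) with hypothetical names, not the repository's actual implementation:

import torch
import torch.nn as nn

class LearnedPositionalEncoding(nn.Module):
    # Hypothetical sketch; the real offlinerllib pos_encoding may differ.
    def __init__(self, embed_dim: int, max_len: int = 1024):
        super().__init__()
        self.pos_embed = nn.Embedding(max_len, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, L, embed_dim); add one embedding per position along L.
        B, L, _ = x.shape
        positions = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
        return x + self.pos_embed(positions)

# Usage mirroring the new line in the diff:
#   tgt = self.pos_encoding(self.input_embed(tgt))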