From a1c44575374fb16f5dbb9a834349520b8e8d773e Mon Sep 17 00:00:00 2001
From: typoverflow
Date: Sun, 15 Oct 2023 21:05:48 +0800
Subject: [PATCH] fix: transformer refactor bugs

---
 offlinerllib/__init__.py                         |  2 +-
 offlinerllib/module/net/attention/transformer.py | 11 ++++-------
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/offlinerllib/__init__.py b/offlinerllib/__init__.py
index abe9369..47f13fe 100644
--- a/offlinerllib/__init__.py
+++ b/offlinerllib/__init__.py
@@ -1,2 +1,2 @@
-__version__ = "0.1.1"
+__version__ = "0.1.2"
diff --git a/offlinerllib/module/net/attention/transformer.py b/offlinerllib/module/net/attention/transformer.py
index 32339d3..de826a6 100644
--- a/offlinerllib/module/net/attention/transformer.py
+++ b/offlinerllib/module/net/attention/transformer.py
@@ -153,7 +153,7 @@ def _mha_block(self, input, key_value, attention_mask, key_padding_mask):
             key=key_value,
             value=key_value,
             need_weights=False,
-            attention_mask=attention_mask,
+            attn_mask=attention_mask,
             key_padding_mask=key_padding_mask
         )[0]
         return self.dropout2(input)
@@ -194,7 +194,7 @@ def __init__(
             ) for _ in range(num_layers)
         ])
-        self.out_ln = nn.LayerNorm() if out_ln else nn.Identity()
+        self.out_ln = nn.LayerNorm(embed_dim) if out_ln else nn.Identity()
         self.causal = causal

     def forward(
@@ -254,7 +254,7 @@ def __init__(
             ) for _ in range(num_layers)
         ])
-        self.out_ln = nn.LayerNorm() if out_ln else nn.Identity()
+        self.out_ln = nn.LayerNorm(embed_dim) if out_ln else nn.Identity()
        self.causal = causal

     def forward(
@@ -276,10 +276,7 @@ def forward(
         if tgt_attention_mask is not None:
             tgt_mask = torch.bitwise_or(tgt_attention_mask.to(torch.bool), tgt_mask)
         if do_embedding:
-            tgt = self.input_embed(tgt)
-            if timesteps is not None:
-                timesteps = torch.arange(L).repeat(B, 1).to(tgt.device)
-            tgt = tgt + self.pos_embed(timesteps)
+            tgt = self.pos_encoding(self.input_embed(tgt))
         output = self.embed_dropout(tgt)
         for i, block in enumerate(self.blocks):
             output = block(
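
Note (not part of the patch): the two PyTorch API details behind these fixes are that nn.MultiheadAttention.forward takes the mask via the attn_mask keyword (not attention_mask), and that nn.LayerNorm has no default for normalized_shape, so it must be constructed with the feature dimension. A minimal standalone sketch with illustrative sizes, independent of the offlinerllib modules:

import torch
import torch.nn as nn

embed_dim, num_heads, B, L = 64, 4, 2, 10   # illustrative sizes

# nn.LayerNorm requires normalized_shape; nn.LayerNorm() with no
# argument fails, hence nn.LayerNorm(embed_dim) in the patch.
out_ln = nn.LayerNorm(embed_dim)

# nn.MultiheadAttention.forward exposes the attention mask as
# `attn_mask`, alongside `key_padding_mask`.
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(B, L, embed_dim)
causal_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
attn_out, _ = mha(
    query=x, key=x, value=x,
    need_weights=False,
    attn_mask=causal_mask,
    key_padding_mask=None,
)
print(out_ln(attn_out).shape)  # torch.Size([2, 10, 64])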