[BUG] Implementation errors on TransformerModel. #672

Open
Sangwon91 opened this issue Dec 6, 2021 · 0 comments
Labels: bug, good first issue


I have looked into the [Transformer](https://github.com/unit8co/darts/blame/master/darts/models/forecasting/transformer_model.py) implementation and found some errors.

First, in lines 167 and 170:

```python
src = self.encoder(src) * math.sqrt(self.input_size)
tgt = self.encoder(tgt) * math.sqrt(self.input_size)
```

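I don't think we need to multiply the inputs (src or tgt) by math.sqrt(self.input_size), because torch.nn.MultiheadAttention already takes care of this normalization (it divides the query-key dot products by the square root of the per-head dimension).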
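For context, here is a minimal pure-Python sketch of single-head scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V. This is the formula torch.nn.MultiheadAttention applies per head, and the sketch shows where that internal normalization happens. The function names are illustrative and this is not PyTorch's actual implementation.

```python
import math


def softmax(xs):
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q, k, v):
    """Single-head attention, softmax(q @ k^T / sqrt(d_k)) @ v, on plain lists.

    q: list of query vectors; k, v: lists of key and value vectors (same length).
    """
    d_k = len(k[0])
    out = []
    for qi in q:
        # Dot product of the query with every key, divided by sqrt(d_k):
        # this is the scaling that attention applies internally.
        logits = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k]
        weights = softmax(logits)
        # The output is the attention-weighted sum of the value vectors.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0]]
print(scaled_dot_product_attention(q, k, v))
```

Note that the division by sqrt(d_k) acts on the query-key dot products inside attention. The lines quoted above scale the encoder output before it enters the transformer, so the question is whether that extra scaling is still needed.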

Second, in lines 173-174:

```python
x = self.transformer(src=src,
                     tgt=tgt)
```

There is no tgt_mask in this call. To use teacher forcing during training, the forward function must pass a tgt_mask to the transformer, specifically the square subsequent mask produced by the helper shown below. Otherwise, the decoder input at time t can attend to future decoder inputs (e.g., t+1, t+2, ...), which do not exist at inference time.

```python
import torch
from torch import Tensor


# Helper from torch.nn.Transformer (a @staticmethod there), shown standalone.
def generate_square_subsequent_mask(sz: int) -> Tensor:
    r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
    """
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
```
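A minimal sketch of what the suggested fix could look like, assuming PyTorch's default (seq_len, batch, features) layout. `TinyTransformerModel` is a hypothetical, stripped-down stand-in for the Darts module. It keeps only an input projection and an `nn.Transformer` and is not the actual Darts code. The relevant change is passing `tgt_mask` to `self.transformer`.

```python
import torch
from torch import Tensor, nn


def generate_square_subsequent_mask(sz: int) -> Tensor:
    # Same construction as the nn.Transformer helper:
    # -inf above the diagonal, 0.0 on and below it.
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)


class TinyTransformerModel(nn.Module):
    """Hypothetical, minimal stand-in for the Darts module (not its real code)."""

    def __init__(self, input_size: int, d_model: int = 16, nhead: int = 2):
        super().__init__()
        self.encoder = nn.Linear(input_size, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=1,
            num_decoder_layers=1,
            dim_feedforward=32,
            dropout=0.0,
        )

    def forward(self, src: Tensor, tgt: Tensor) -> Tensor:
        # src, tgt: (seq_len, batch, input_size), the default batch_first=False layout.
        src = self.encoder(src)
        tgt = self.encoder(tgt)
        # Causal mask: decoder position t may only attend to positions <= t.
        tgt_mask = generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        return self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)


torch.manual_seed(0)
model = TinyTransformerModel(input_size=3).eval()
src = torch.randn(5, 2, 3)  # 5 input steps, batch of 2, 3 features
tgt = torch.randn(4, 2, 3)  # 4 decoder steps
out = model(src, tgt)       # shape (4, 2, 16)
```

Because the mask blocks attention to later steps, perturbing the last target step should leave the decoder outputs for earlier steps unchanged. This gives a quick check that no future information leaks during teacher forcing.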

I'm not sure these are actually errors, but in my opinion the current implementation does not look correct.

Thank you!

Sangwon91 added the bug and triage labels on Dec 6, 2021.
madtoinou added the good first issue label and removed the triage label on Feb 27, 2023.