I have looked into [Transformer](https://github.com/unit8co/darts/blame/master/darts/models/forecasting/transformer_model.py) and found some errors.

First, in lines 167 and 170:
```python
src = self.encoder(src) * math.sqrt(self.input_size)
tgt = self.encoder(tgt) * math.sqrt(self.input_size)
```
I don't think we need to multiply the inputs (`src` or `tgt`) by `math.sqrt(self.input_size)`, because `torch.nn.MultiheadAttention` already takes care of this normalization.
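For reference, here is a minimal, self-contained sketch of the API in question (the dimensions are made up for illustration; this is not the darts code): `nn.MultiheadAttention` performs scaled dot-product attention internally, without the caller applying any scaling of their own.

```python
import torch
import torch.nn as nn

# Minimal sketch, not the darts code; dimensions are made up for illustration.
# nn.MultiheadAttention computes scaled dot-product attention internally
# (the attention scores are divided by sqrt(head_dim) inside the module).
d_model, n_heads, seq_len, batch = 16, 4, 10, 2
attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads)

x = torch.randn(seq_len, batch, d_model)  # (seq, batch, features) layout
out, weights = attn(x, x, x)              # no manual sqrt(d_model) factor here
print(out.shape)      # torch.Size([10, 2, 16])
print(weights.shape)  # torch.Size([2, 10, 10])
```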
Second, in lines 173-174:
```python
x = self.transformer(src=src, tgt=tgt)
```
There is no `tgt_mask` in this call. To use teacher forcing at the training stage, the user must feed a `tgt_mask` to the forward function (specifically the square subsequent mask defined below). Otherwise, decoder inputs before time t can see future decoder inputs (e.g., t+1, t+2, ...), which do not exist at the inference stage.
This is `generate_square_subsequent_mask` from the PyTorch docs:

```python
@staticmethod
def generate_square_subsequent_mask(sz: int) -> Tensor:
    r"""Generate a square mask for the sequence. The masked positions are
    filled with float('-inf'). Unmasked positions are filled with float(0.0).
    """
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
```
I'm not sure whether these things are actual errors, but in my opinion the current implementation is not correct.
Thank you!