
reset ptl trainer when loading torch models #1124

Merged: 4 commits merged on Aug 7, 2022
Conversation

dennisbader (Collaborator) commented Aug 6, 2022

Fixes #1116.

Summary

  • Fixes an error when loading TorchForecastingModels where the trainer is not associated with a model (see the sketch below).
  • Fixes an issue where custom feed forward modules were ignored in a previous PyTorch version.
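
The gist of the first fix is that a deserialized model should not hold on to a stale (or missing) PyTorch Lightning Trainer; instead, a fresh trainer should be built before the next fit()/predict() call. Below is a rough sketch of that idea only; the class, attribute, and method names are illustrative and do not mirror Darts' actual internals:

    # Illustrative sketch only -- names do not mirror Darts' internals.
    import pickle


    class TorchModelSketch:
        def __init__(self):
            # A pl.Trainer would normally be created lazily before training/inference.
            self.trainer = None

        def save(self, path: str) -> None:
            with open(path, "wb") as f:
                pickle.dump(self, f)

        @staticmethod
        def load(path: str) -> "TorchModelSketch":
            with open(path, "rb") as f:
                model = pickle.load(f)
            # Reset the trainer so a fresh one is instantiated the next time
            # the model is trained or used for prediction.
            model.trainer = None
            return model

For the actual implementation, see the changes to darts/models/forecasting/torch_forecasting_model.py in the coverage report below.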

codecov-commenter commented Aug 6, 2022

Codecov Report

Merging #1124 (be82df8) into master (caccce1) will increase coverage by 0.08%.
The diff coverage is 100.00%.

    @@            Coverage Diff             @@
    ##           master    #1124      +/-   ##
    ==========================================
    + Coverage   93.65%   93.74%   +0.08%
    ==========================================
      Files          79       80       +1
      Lines        8137     8147      +10
    ==========================================
    + Hits         7621     7637      +16
    + Misses        516      510       -6
| Impacted Files | Coverage Δ |
|---|---|
| `darts/timeseries.py` | 92.15% <ø> (-0.07%) ⬇️ |
| `darts/dataprocessing/dtw/dtw.py` | 94.20% <100.00%> (-0.13%) ⬇️ |
| `darts/models/components/transformer.py` | 100.00% <100.00%> (ø) |
| `...arts/models/forecasting/torch_forecasting_model.py` | 89.54% <100.00%> (-0.05%) ⬇️ |
| `darts/models/forecasting/transformer_model.py` | 100.00% <100.00%> (ø) |
| `darts/models/forecasting/block_rnn_model.py` | 98.24% <0.00%> (-0.04%) ⬇️ |
| `darts/models/forecasting/nhits.py` | 98.55% <0.00%> (-0.02%) ⬇️ |
| `darts/datasets/__init__.py` | 100.00% <0.00%> (ø) |
| ... and 2 more | |


    # overwrite the feed forward block
    def _ff_block(self, x):
        x = self.activation(x)
        return self.dropout2(x)
Contributor:

The glu variants have dropout built-in

Collaborator Author:

Yes, but I think we should add the second dropout following the original implementation (see here):

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

then per TransformerEncoderLayer:

    def forward(...):
        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
        x = self.norm2(x + self._ff_block(x))
        return x

This is equivalent to what they do in the Annotated Transformer with SublayerConnection, EncoderLayer, and Position-wise Feed-Forward Networks.

Our current FeedForward class does the following (changed a bit so it's easier to compare):

    def forward(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return x

So when we overwrite the PyTorch TransformerEncoderLayer, the last dropout would be missing.
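
Putting this together, the override quoted at the top of this thread keeps that last dropout by reusing the base layer's dropout2. A minimal, self-contained sketch along those lines (it assumes, as in the diff, that `activation` is passed a full feed forward module such as one of the GLU variants, which newer PyTorch versions accept as a callable):

    import torch.nn as nn
    from torch import Tensor


    class _CustomFeedForwardEncoderLayer(nn.TransformerEncoderLayer):
        """Encoder layer whose feed forward block is a full module (e.g. a GLU
        variant) passed via ``activation``, followed by the layer's dropout2."""

        # overwrite the feed forward block
        def _ff_block(self, x: Tensor) -> Tensor:
            # self.activation is assumed to be a complete feed forward sub-layer
            # (linear -> activation -> dropout -> linear), not a plain activation
            x = self.activation(x)
            return self.dropout2(x)

The diff below then instantiates this layer with one of the GLU-variant feed forward blocks.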

    # use glu variant feedforward layers
    self.activation = getattr(glu_variants, activation)(
        d_model=d_model, d_ff=dim_feedforward, dropout=dropout
    )
    encoder_layer = _CustomFeedForwardEncoderLayer(
Contributor:

The glu variants are full feedforward layers, so it wasn't fully correct before to set activation = glu_variants.
For this to be fully correct, I think there should be a separate arg for the ff_block and activation should be left alone.

This _CustomFeedForwardEncoderLayer makes more sense.
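
To illustrate what "full feedforward layers" means here: a GLU-variant block is a complete position-wise feed forward sub-layer with its own dropout, not an element-wise activation function. The following is a generic illustration of such a block (not Darts' glu_variants implementation; the class name is made up):

    import torch
    import torch.nn as nn
    from torch import Tensor


    class GLUFeedForwardSketch(nn.Module):
        """A position-wise feed forward sub-layer with a GLU-style gate and
        built-in dropout -- a full block, not just an activation."""

        def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
            super().__init__()
            self.w_gate = nn.Linear(d_model, d_ff)  # gating projection
            self.w_up = nn.Linear(d_model, d_ff)    # value projection
            self.w_down = nn.Linear(d_ff, d_model)  # project back to d_model
            self.dropout = nn.Dropout(dropout)

        def forward(self, x: Tensor) -> Tensor:
            # GLU gating: sigmoid(gate) multiplies the value path elementwise
            gated = torch.sigmoid(self.w_gate(x)) * self.w_up(x)
            return self.w_down(self.dropout(gated))

Passing such a module as activation works because the encoder layer simply calls it inside _ff_block, but a dedicated ff_block argument, as suggested above, would make the intent clearer.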

        dropout=dropout,
        activation=self.activation,
    )
    encoder_norm = nn.LayerNorm(d_model)
Contributor:

I can add the new norms when this gets merged

hrzn (Contributor) left a comment:

LGTM (though I didn't check the correctness in detail)
Thanks!

dennisbader merged commit 4e5f1e6 into master on Aug 7, 2022
dennisbader deleted the fix/ptl_update branch on Aug 7, 2022 at 14:49
Successfully merging this pull request may close these issues:

Unable to load a saved pytorch model due to trainer being required [BUG]