Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/sample weight torch #2410

Merged
merged 18 commits into from
Jun 17, 2024
Merged

Feat/sample weight torch #2410

merged 18 commits into from
Jun 17, 2024

Conversation

dennisbader
Copy link
Collaborator

@dennisbader dennisbader commented Jun 13, 2024

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Fixes #1175, fixes #2107.

(this is a continuation of #2404)

Summary

  • adds support for training torch forecasting models with sample weights.
  • fit() now has two new parameters sample_weight and val_sample_weight which allow to pass sample weights for training and validation
sample_weight
            Optionally, some sample weights to apply to the target `series` labels.
            They are applied per observation, per label (each step in `output_chunk_length`), and per component.
            If a string, then the weights are generated using built-in weighting functions. The available options are
            `"linear_decay"` or `"exponential_decay"`. The weights are only computed the longest series in `series`,
            and then applied globally to all `series` to have a common time weighting.
            If a `TimeSeries` or `Sequence[TimeSeries]`, then those weights are used. The number of series must
            match the number of target `series` and each series must contain at least all time steps from the
            corresponding target `series`. If the weight series only have a single component / column, then the weights
            are applied globally to all components in `series`. Otherwise, for component-specific weights, the number
            of components must match those of `series`.
val_sample_weight
            Same as for `sample_weight` but for the evaluation dataset.

Copy link

codecov bot commented Jun 14, 2024

Codecov Report

Attention: Patch coverage is 88.62559% with 24 lines in your changes missing coverage. Please review.

Project coverage is 93.73%. Comparing base (05f6ddf) to head (a1fd48c).
Report is 2 commits behind head on master.

Current head a1fd48c differs from pull request most recent head c31dbc9

Please upload reports for the commit c31dbc9 to get more accurate results.

Files Patch % Lines
darts/utils/data/shifted_dataset.py 82.45% 10 Missing ⚠️
darts/utils/data/horizon_based_dataset.py 76.00% 6 Missing ⚠️
darts/utils/data/utils.py 84.61% 4 Missing ⚠️
...arts/models/forecasting/torch_forecasting_model.py 92.30% 2 Missing ⚠️
darts/utils/likelihood_models.py 91.30% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2410      +/-   ##
==========================================
- Coverage   93.77%   93.73%   -0.04%     
==========================================
  Files         138      138              
  Lines       14384    14632     +248     
==========================================
+ Hits        13488    13716     +228     
- Misses        896      916      +20     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@madtoinou madtoinou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! 🚀

darts/models/forecasting/torch_forecasting_model.py Outdated Show resolved Hide resolved
darts/utils/data/shifted_dataset.py Show resolved Hide resolved
darts/utils/data/training_dataset.py Outdated Show resolved Hide resolved
@dennisbader dennisbader merged commit b532a80 into master Jun 17, 2024
7 of 9 checks passed
@dennisbader dennisbader deleted the feat/sample_weight_torch branch June 17, 2024 13:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Applying sample weights to the training process Easy Sample Weights
2 participants