Feat/sample weights #2404

Merged
merged 42 commits into from
Jun 17, 2024

Conversation

Collaborator

@dennisbader dennisbader commented Jun 5, 2024

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Fixes #1175, fixes #2107.

(this is a continuation of #2362)

Summary

  • adds support for training regression models with sample weights.
  • fit() now has two new parameters, sample_weight and val_sample_weight, for passing sample weights for training and validation (val_sample_weight only for models that support validation sets)
sample_weight
            Optionally, some sample weights to apply to the target `series` labels.
            They are applied per observation, per label (each step in `output_chunk_length`), and per component.
            If a string, then the weights are generated using built-in weighting functions. The available options are
            `"linear_decay"` or `"exponential_decay"`. The weights are computed only for the longest series in `series`,
            and then applied globally to all `series` to have a common time weighting.
            If a `TimeSeries` or `Sequence[TimeSeries]`, then those weights are used. The number of series must
            match the number of target `series` and each series must contain at least all time steps from the
            corresponding target `series`. If the weight series only have a single component / column, then the weights
            are applied globally to all components in `series`. Otherwise, for component-specific weights, the number
            of components must match those of `series`.
val_sample_weight
            Same as for `sample_weight` but for the evaluation dataset.
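
To make the string options concrete, here is a minimal pure-Python sketch of how decay weights could be computed once for the longest series and then reused for the others. It is not the code from this PR: the helper names (`linear_decay`, `exponential_decay`, `weights_for_all`), the decay `rate`, and the alignment of series at their end are all assumptions made for illustration.

```python
import math
from typing import List


def linear_decay(n: int) -> List[float]:
    """Weights rising linearly from 0 (oldest step) to 1 (most recent step)."""
    if n == 1:
        return [1.0]
    return [i / (n - 1) for i in range(n)]


def exponential_decay(n: int, rate: float = 5.0) -> List[float]:
    """Weights shrinking exponentially with distance from the most recent step.

    The `rate` value is an arbitrary choice for this sketch.
    """
    if n == 1:
        return [1.0]
    return [math.exp(-rate * (1 - i / (n - 1))) for i in range(n)]


def weights_for_all(lengths: List[int], kind: str = "linear_decay") -> List[List[float]]:
    """Compute weights once on the longest series, then reuse them for every series."""
    makers = {"linear_decay": linear_decay, "exponential_decay": exponential_decay}
    if kind not in makers:
        raise ValueError(f"Unknown weighting {kind!r}; expected one of {sorted(makers)}")
    full = makers[kind](max(lengths))
    # Assumption: all series end at the same time step, so a shorter series
    # simply takes the last `n` weights of the longest one.
    return [full[-n:] for n in lengths]


# Two hypothetical target series with 5 and 3 time steps.
weights = weights_for_all([5, 3])
```

With the parameters described in the summary, the corresponding call on a model would look like `model.fit(series, sample_weight="linear_decay")`; the sketch only shows why all series end up sharing one time weighting.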

f"Possible values are: equal, linear_decay, exponential_decay.",
)
elif isinstance(sample_weight, TimeSeries):
# The error is caught later, should we still verify it here?

Why is this an error if the docstring says "If a TimeSeries is passed, then those weights are used"?

Collaborator

This is a placeholder; we are still deciding whether it is worth checking, for example, that the time index of the passed series matches the index of the training series.
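
For illustration only, the kind of check being discussed could look like the sketch below. It uses plain `datetime.date` lists instead of `TimeSeries` objects, and `missing_weight_steps` is a hypothetical helper, not a function from darts or from this PR.

```python
from datetime import date, timedelta
from typing import List


def missing_weight_steps(target_index: List[date], weight_index: List[date]) -> List[date]:
    """Return the target time steps that have no corresponding weight."""
    available = set(weight_index)
    return [t for t in target_index if t not in available]


start = date(2024, 1, 1)
target_index = [start + timedelta(days=i) for i in range(5)]
# Hypothetical weight series that starts one day too late.
weight_index = target_index[1:]

missing = missing_weight_steps(target_index, weight_index)
if missing:
    print(f"weights are missing for {len(missing)} target step(s), first: {missing[0]}")
```

A check like this would surface the mismatch early, at the cost of one extra pass over the index; whether that is worth it is exactly the open question above.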


codecov bot commented Jun 5, 2024

Codecov Report

Attention: Patch coverage is 96.29630% with 7 lines in your changes missing coverage. Please review.

Project coverage is 93.78%. Comparing base (05f6ddf) to head (ae06f47).
Report is 1 commits behind head on master.

Current head ae06f47 differs from pull request most recent head 0b9b9cc

Please upload reports for the commit 0b9b9cc to get more accurate results.

| Files | Patch % | Lines |
|---|---|---|
| darts/utils/multioutput.py | 77.77% | 4 Missing ⚠️ |
| darts/utils/data/tabularization.py | 97.05% | 3 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2404      +/-   ##
==========================================
+ Coverage   93.77%   93.78%   +0.01%     
==========================================
  Files         138      138              
  Lines       14384    14492     +108     
==========================================
+ Hits        13488    13592     +104     
- Misses        896      900       +4     


Collaborator

@madtoinou madtoinou left a comment

Great work!

One minor comment about the index for the generated weights: maybe we should consider the two most extreme index values rather than the length of the longest series? WDYT?
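
A small sketch of the difference between the two options, using integer time indexes. The `(first, last)` span representation and both helper names are assumptions made for this example, not code from the PR.

```python
from typing import List, Tuple

# Each series is described by its (first, last) integer time index, inclusive.
Span = Tuple[int, int]


def span_of_longest(spans: List[Span]) -> Span:
    """Index range of the single longest series."""
    return max(spans, key=lambda s: s[1] - s[0])


def span_of_extremes(spans: List[Span]) -> Span:
    """Index range from the earliest start to the latest end across all series."""
    return min(s[0] for s in spans), max(s[1] for s in spans)


# Two hypothetical series that overlap only partially.
spans = [(0, 9), (5, 12)]
longest = span_of_longest(spans)
extremes = span_of_extremes(spans)

# Time steps that exist in some series but fall outside the longest one.
uncovered = [
    t for t in range(extremes[0], extremes[1] + 1)
    if not longest[0] <= t <= longest[1]
]
```

When the series overlap only partially, weights built on the longest series alone leave the steps in `uncovered` without a weight, whereas the range spanned by the extreme index values covers every series.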

@dennisbader dennisbader merged commit 6835c36 into master Jun 17, 2024
9 checks passed
@dennisbader dennisbader deleted the feat/sample_weights branch June 17, 2024 12:27
Successfully merging this pull request may close these issues.

  • Applying sample weights to the training process
  • Easy Sample Weights
4 participants