Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/mc dropout #2312

Merged
merged 7 commits into from Apr 11, 2024
Merged

Fix/mc dropout #2312

merged 7 commits into from Apr 11, 2024

Conversation

dennisbader
Copy link
Collaborator

@dennisbader dennisbader commented Apr 10, 2024

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Fixes #2300.

Summary

  • fixes Monte Carlo Dropout for all TorchForecastingModels. MC dropout training mode was not getting activated properly in the different model training/val/pred/.. stages.
  • add MonteCarloDropout to all TorchForecastingModels that were not using them before.

Additional Information

Below an example that it works properly now.

import numpy as np
import matplotlib.pyplot as plt

from darts import concatenate
from darts.models import TCNModel, TSMixerModel
from darts.datasets import AirPassengersDataset

series = AirPassengersDataset().load().astype(np.float32)
preds = []

dropouts = [0., 0.5, 0.99]
for dropout in dropouts:
    model = TCNModel(13, 12, dropout=dropout, random_state=42)
    model.fit(series, epochs=100)
    preds.append(
        concatenate(
            model.historical_forecasts(
                series=series,
                retrain=False,
                forecast_horizon=12,
                stride=12,
                last_points_only=False,
            ),
            axis=0
        )
    )

series.plot()
for pred, dropout in zip(preds, dropouts):
    pred.plot(label=f"d={dropout}")
plt.show()
image

Copy link

codecov bot commented Apr 10, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 94.01%. Comparing base (883e35e) to head (bb8f15c).

❗ Current head bb8f15c differs from pull request most recent head b95337b. Consider uploading reports for the commit b95337b to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2312      +/-   ##
==========================================
- Coverage   94.01%   94.01%   -0.01%     
==========================================
  Files         138      138              
  Lines       14145    14138       -7     
==========================================
- Hits        13299    13292       -7     
  Misses        846      846              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@dennisbader dennisbader merged commit bd5340f into master Apr 11, 2024
9 checks passed
@dennisbader dennisbader deleted the fix/mc_dropout branch April 11, 2024 14:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] Monte Carlo dropout does not work with pytorch-lightning >=2.2.0
2 participants