-
Notifications
You must be signed in to change notification settings - Fork 830
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/multivariate on baseline models #1373
Conversation
…t8co/darts into feat/multivariate-on-baseline-models
@@ -35,16 +36,15 @@ def __str__(self): | |||
|
|||
def fit(self, series: TimeSeries): | |||
super().fit(series) | |||
self.mean_val = np.mean(series.univariate_values()) | |||
self.mean_val = np.mean(series.values(), axis=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.mean_val = np.mean(series.values(), axis=0) | |
self.mean_val = np.mean(series.values(copy=False), axis=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm thinking we may want to test that the series is deterministic (has only one sample) in all these models.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's done in the ForecastingModel
.
Good point. I misinterpreted "deterministic" at first.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I put a check in LocalForecastingModel
class.
@@ -1403,7 +1409,7 @@ def _get_encoders_n(self, n) -> int: | |||
return n | |||
|
|||
|
|||
class FutureCovariatesLocalForecastingModel(LocalForecastingModel, ABC): | |||
class FutureCovariatesLocalForecastingModel(LocalMultivariateForecastingModel, ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't like this class structure too much, because it means that local future covariates models must necessarily support multivariate series. Whereas we could in principles have some future covariates models working on univariate series only (although we maybe don't have any at the moment).
I think this one should still inherit from LocalForecastingModel
only, WDYT?
Also a cleaner approach might be to rethink most of our model super-classes and see them as "mix-in interfaces" more. But that would be beyond the scope of this PR :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought about the mixin approach but it seemed out of scope. I thought that requiring support for multivariate series was ok because that's how it was done previously. By default all models that inherit from LocalForecastingModel
were checked if they are fitted on univariate time series with explicit if statement that excluded FutureCovaraitesLocalForecastingModel
. The only difference with this approach is that I don't exclude that specific class but all models that are in LocalMultivaraiteForecastingModel
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the end I put checks directly in the models that don't support multivariate input. It is probably the most versatile solution because it can be changed per model basis and it doesn't unnecessarily increase complexity of class structure.
Codecov ReportBase: 94.03% // Head: 93.96% // Decreases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## master #1373 +/- ##
==========================================
- Coverage 94.03% 93.96% -0.07%
==========================================
Files 82 82
Lines 8905 8905
==========================================
- Hits 8374 8368 -6
- Misses 531 537 +6
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
Fixes #639
Summary
Extend baseline models to support multivariate time series.
Other Information
New class
LocalMultivaraiteForecastingModel
was added to differentiate 'local' models that can be trained on both uni- and multivariate time series.