Skip to content

Commit

Permalink
Propagate static covs and hierarchy in missing value filler (#1076)
Browse files Browse the repository at this point in the history
* Propagate static covs and hierarchy in missing value filler

* Fix typo
  • Loading branch information
hrzn committed Jul 18, 2022
1 parent 1fadc58 commit 7f32ef6
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@
class MissingValuesFillerTestCase(unittest.TestCase):

time = pd.date_range("20130101", "20130130")
static_covariate = pd.DataFrame({"0": [1]})

const_series = TimeSeries.from_times_and_values(time, np.array([2.0] * len(time)))
const_series = TimeSeries.from_times_and_values(
time, np.array([2.0] * len(time)), static_covariates=static_covariate
)
const_series_with_holes = TimeSeries.from_times_and_values(
time, np.array([2.0] * 10 + [np.nan] * 5 + [2.0] * 10 + [np.nan] * 5)
time,
np.array([2.0] * 10 + [np.nan] * 5 + [2.0] * 10 + [np.nan] * 5),
static_covariates=static_covariate,
)

lin = [float(i) for i in range(len(time))]
Expand All @@ -36,3 +41,11 @@ def test_fill_lin_series_with_auto_value(self):
auto_transformer = MissingValuesFiller()
transformed = auto_transformer.transform(self.lin_series_with_holes)
self.assertEqual(self.lin_series, transformed)

def test_fill_static_covariates_preserved(self):
const_transformer = MissingValuesFiller(fill=2.0)
transformed = const_transformer.transform(self.const_series_with_holes)
self.assertEqual(
self.const_series.static_covariates.values,
transformed.static_covariates.values,
)
9 changes: 8 additions & 1 deletion darts/utils/missing_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def _const_fill(series: TimeSeries, fill: float = 0) -> TimeSeries:
series.pd_dataframe().fillna(value=fill),
freq=series.freq,
columns=series.columns,
static_covariates=series.static_covariates,
hierarchy=series.hierarchy,
)


Expand Down Expand Up @@ -160,4 +162,9 @@ def _auto_fill(series: TimeSeries, **interpolate_kwargs) -> TimeSeries:
interpolate_kwargs["limit_direction"] = "both"
interpolate_kwargs["inplace"] = True
series_temp.interpolate(**interpolate_kwargs)
return TimeSeries.from_dataframe(series_temp, freq=series.freq)
return TimeSeries.from_dataframe(
series_temp,
freq=series.freq,
static_covariates=series.static_covariates,
hierarchy=series.hierarchy,
)

0 comments on commit 7f32ef6

Please sign in to comment.