Skip to content

Commit

Permalink
fixed an issue using fit_transform() with reconciliators
Browse files Browse the repository at this point in the history
  • Loading branch information
hrzn committed Aug 23, 2022
1 parent f953005 commit 85cab08
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
10 changes: 5 additions & 5 deletions darts/dataprocessing/transformers/reconciliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def get_projection_matrix(series):
return np.concatenate([np.zeros((m, n - m)), np.eye(m)], axis=1)

@staticmethod
def ts_transform(series: TimeSeries) -> TimeSeries:
def ts_transform(series: TimeSeries, *args, **kwargs) -> TimeSeries:
S = _get_summation_matrix(series)
G = BottomUpReconciliator.get_projection_matrix(series)
return _reconcile_from_S_and_G(series, S, G)
Expand All @@ -103,12 +103,12 @@ class TopDownReconciliator(FittableDataTransformer):
"""

@staticmethod
def ts_fit(series: TimeSeries) -> np.ndarray:
def ts_fit(series: TimeSeries, *args, **kwargs) -> np.ndarray:
G = TopDownReconciliator.get_projection_matrix(series)
return G

@staticmethod
def ts_transform(series: TimeSeries, G: np.ndarray) -> TimeSeries:
def ts_transform(series: TimeSeries, G: np.ndarray, *args, **kwargs) -> TimeSeries:
S = _get_summation_matrix(series)
return _reconcile_from_S_and_G(series, S, G)

Expand Down Expand Up @@ -189,12 +189,12 @@ def __init__(self, method="ols"):
self.method = method

@staticmethod
def ts_fit(series: TimeSeries, method: str) -> np.ndarray:
def ts_fit(series: TimeSeries, method: str, *args, **kwargs) -> np.ndarray:
S, G = MinTReconciliator.get_matrices(series, method)
return S, G

@staticmethod
def ts_transform(series: TimeSeries, S_and_G) -> TimeSeries:
def ts_transform(series: TimeSeries, S_and_G, *args, **kwargs) -> TimeSeries:
S, G = S_and_G
return _reconcile_from_S_and_G(series, S, G)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def test_top_down(self):
recon.fit(self.pred)
self._assert_reconciliation(recon)

# fit_transform() should also work
recon = TopDownReconciliator()
_ = recon.fit_transform(self.pred)

def test_mint(self):
# ols
recon = MinTReconciliator("ols")
Expand All @@ -131,6 +135,10 @@ def test_mint(self):
recon.fit(self.series)
self._assert_reconciliation(recon)

# fit_transform() should also work
recon = MinTReconciliator()
_ = recon.fit_transform(self.series)

def test_summation_matrix(self):
np.testing.assert_equal(
_get_summation_matrix(self.series_complex),
Expand Down

0 comments on commit 85cab08

Please sign in to comment.