Skip to content

Commit

Permalink
&yarnabrina [BUG] correct ForecastX behaviour in case of multivaria…
Browse files Browse the repository at this point in the history
…te `y` (#4719)

For multivariate `y`, the compositor `ForecastX` would always broadcast
over variables of `y`, even if wrapped forecasters were multivariate.

However, `ForecastX` itself should never broadcast over `y`, that should
be left to component forecasters.

Fixed here, this also fixes #4718

Includes a slightly extended test case from the bug report #4718 by
@yarnabrina
  • Loading branch information
fkiraly committed Jun 20, 2023
1 parent 77f305a commit d9dc6d3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
1 change: 1 addition & 0 deletions sktime/forecasting/compose/_pipeline.py
Expand Up @@ -1229,6 +1229,7 @@ class ForecastX(BaseForecaster):
_tags = {
"X_inner_mtype": SUPPORTED_MTYPES,
"y_inner_mtype": SUPPORTED_MTYPES,
"scitype:y": "both",
"X-y-must-have-same-index": False,
"fit_is_empty": False,
"ignores-exogeneous-X": False,
Expand Down
42 changes: 42 additions & 0 deletions sktime/forecasting/compose/tests/test_pipeline.py
Expand Up @@ -480,3 +480,45 @@ def test_forecastx_logic():

# compare that test and comparison case results are equal
assert np.allclose(y_pred, y_pred_manual)


def test_forecastx_attrib_broadcast():
"""Test ForecastX broadcasting and forecaster attributes."""
from sktime.forecasting.compose import ForecastX
from sktime.forecasting.naive import NaiveForecaster

df = pd.DataFrame(
{
"a": ["series_1", "series_1", "series_1"],
"b": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
"c": [1, 2, 3],
"d": [4, 5, 6],
"e": [7, 8, 9],
}
)
df = df.set_index(["a", "b"])

model = ForecastX(NaiveForecaster(), NaiveForecaster())

model_1 = model.clone()
model_1.fit(df[["c"]], X=df[["d", "e"]], fh=[1, 2, 3])

assert hasattr(model_1, "forecaster_X_")
assert isinstance(model_1.forecaster_X_, NaiveForecaster)
assert model_1.forecaster_X_.is_fitted

assert hasattr(model_1, "forecaster_y_")
assert isinstance(model_1.forecaster_y_, NaiveForecaster)
assert model_1.forecaster_y_.is_fitted

model_2 = model.clone()
model_2.fit(df[["c", "d"]], X=df[["e"]], fh=[1, 2, 3])
assert hasattr(model_2, "forecaster_X_")

assert hasattr(model_2, "forecaster_X_")
assert isinstance(model_2.forecaster_X_, NaiveForecaster)
assert model_2.forecaster_X_.is_fitted

assert hasattr(model_2, "forecaster_y_")
assert isinstance(model_2.forecaster_y_, NaiveForecaster)
assert model_2.forecaster_y_.is_fitted

0 comments on commit d9dc6d3

Please sign in to comment.