Issue with forecaster.update on multi indexed data frame #5833
I am trying to update a forecaster, but it doesn't seem to work on my multi-indexed data frame. Python code:
import pandas as pd
from sktime.forecasting.var import VAR
from sktime.split import temporal_train_test_split

# three patients, four daily observations each, two measurements
df = pd.DataFrame(
    [['A', '2024-01-01', 1, 11], ['A', '2024-01-02', 2, 12], ['A', '2024-01-03', 3, 13], ['A', '2024-01-04', 4, 14],
     ['B', '2024-01-01', 101, 1], ['B', '2024-01-02', 102, 2], ['B', '2024-01-03', 103, 3], ['B', '2024-01-04', 104, 4],
     ['C', '2024-01-01', 10000, 10], ['C', '2024-01-02', 20000, 20], ['C', '2024-01-03', 30000, 30], ['C', '2024-01-04', 40000, 43]],
    columns=['Patients', 'Dates', 'M1', 'M2'])
df['Dates'] = pd.to_datetime(df['Dates'])
df = df.reset_index(drop=True)
df = df.set_index(["Patients", "Dates"])
# hold out the last date of each patient
y_train, y_test = temporal_train_test_split(df, test_size=1)
fh = 1
forecaster = VAR()
forecaster.fit(y_train)
y_pred = forecaster.predict(fh=fh)
# update with the rows of the most recent date only
Day = y_test.index.get_level_values("Dates")[-1]
y_update = y_test[y_test.index.get_level_values("Dates") == Day]
forecaster.update(y_update)
y_pred_updated = forecaster.predict(fh)
The code might look overcomplicated because it was extracted from code that iterates over dates, and because of my limited knowledge of Python. It fails with this message:
This looks like a genuine bug. I am moving it to the bug tracker.
@ManuB68 This is a bug/issue for sure, but it is not caused by the multi-index. I use multi-index data with update in our production work as well, and it works fine. The error is caused because pandas raises "Need at least 3 dates to infer frequency" when the data passed in has fewer than 3 dates, as the traceback in the example shows. Whether it's a bug that we can handle is something we can discuss in the issue @fkiraly created, but to work around your problem, you can consider using a range index instead of a datetime index. I use that to handle some other month/week seasonality issues as well, and it does not face this problem. Here's an example:
>>>
>>> import pandas
>>> from sktime.forecasting.naive import NaiveForecaster
>>> from sktime.split import temporal_train_test_split
>>>
>>> model = NaiveForecaster()
>>>
>>> # datetime index
>>> sample_data = pandas.DataFrame(
... data={"M1": [1, 2, 3, 4], "M2": [11, 12, 13, 14]},
... index=pandas.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"]),
... )
>>>
>>> y_train, y_test = temporal_train_test_split(sample_data, test_size=1)
>>>
>>> # fails even with non-multi-index data
>>> model_1 = model.clone()
>>>
>>> model_1.fit(y_train)
NaiveForecaster()
>>> model_1.update(y_test)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/anirban/sktime-fork/sktime/forecasting/base/_base.py", line 886, in update
self._update_y_X(y_inner, X_inner)
File "/home/anirban/sktime-fork/sktime/forecasting/base/_base.py", line 1588, in _update_y_X
self._set_cutoff_from_y(y)
File "/home/anirban/sktime-fork/sktime/forecasting/base/_base.py", line 1643, in _set_cutoff_from_y
cutoff_idx = get_cutoff(y, self.cutoff, return_index=True)
File "/home/anirban/sktime-fork/sktime/datatypes/_utilities.py", line 296, in get_cutoff
return sub_idx(obj.index, ix) if return_index else obj.index[ix]
File "/home/anirban/sktime-fork/sktime/datatypes/_utilities.py", line 282, in sub_idx
res.freq = pd.infer_freq(idx)
File "/home/anirban/conda-environments/sktime/lib/python3.10/site-packages/pandas/tseries/frequencies.py", line 155, in infer_freq
inferer = _FrequencyInferer(index)
File "/home/anirban/conda-environments/sktime/lib/python3.10/site-packages/pandas/tseries/frequencies.py", line 189, in __init__
raise ValueError("Need at least 3 dates to infer frequency")
ValueError: Need at least 3 dates to infer frequency
>>>
>>> # fails even during fit
>>> model_2 = model.clone()
>>>
>>> model_2.fit(y_test)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/anirban/sktime-fork/sktime/forecasting/base/_base.py", line 369, in fit
self._update_y_X(y_inner, X_inner)
File "/home/anirban/sktime-fork/sktime/forecasting/base/_base.py", line 1588, in _update_y_X
self._set_cutoff_from_y(y)
File "/home/anirban/sktime-fork/sktime/forecasting/base/_base.py", line 1643, in _set_cutoff_from_y
cutoff_idx = get_cutoff(y, self.cutoff, return_index=True)
File "/home/anirban/sktime-fork/sktime/datatypes/_utilities.py", line 296, in get_cutoff
return sub_idx(obj.index, ix) if return_index else obj.index[ix]
File "/home/anirban/sktime-fork/sktime/datatypes/_utilities.py", line 282, in sub_idx
res.freq = pd.infer_freq(idx)
File "/home/anirban/conda-environments/sktime/lib/python3.10/site-packages/pandas/tseries/frequencies.py", line 155, in infer_freq
inferer = _FrequencyInferer(index)
File "/home/anirban/conda-environments/sktime/lib/python3.10/site-packages/pandas/tseries/frequencies.py", line 189, in __init__
raise ValueError("Need at least 3 dates to infer frequency")
ValueError: Need at least 3 dates to infer frequency
>>>
>>> # range index
>>> sample_data = pandas.DataFrame(data={"M1": [1, 2, 3, 4], "M2": [11, 12, 13, 14]})
>>>
>>> y_train, y_test = temporal_train_test_split(sample_data, test_size=1)
>>>
>>> # works
>>> model_3 = model.clone()
>>>
>>> model_3.fit(y_train)
NaiveForecaster()
>>> model_3.update(y_test)
/home/anirban/sktime-fork/sktime/forecasting/base/_base.py:1928: UserWarning: NotImplementedWarning: NaiveForecaster does not have a custom `update` method implemented. NaiveForecaster will be refit each time `update` is called with update_params=True. To refit less often, use the wrappers in the forecasting.stream module, e.g., UpdateEvery.
warn(
/home/anirban/sktime-fork/sktime/forecasting/base/_base.py:1928: UserWarning: NotImplementedWarning: NaiveForecaster does not have a custom `update` method implemented. NaiveForecaster will be refit each time `update` is called with update_params=True. To refit less often, use the wrappers in the forecasting.stream module, e.g., UpdateEvery.
warn(
NaiveForecaster()
>>>
>>> # works
>>> model_4 = model.clone()
>>>
>>> model_4.fit(y_test)
NaiveForecaster()
>>>
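As a rough sketch of how the range-index workaround above might carry over to the multi-indexed data from the question, the datetime level can be replaced by an integer position per patient using only pandas. The helper name `to_integer_time_level` is hypothetical and not part of sktime, and whether a particular forecaster accepts the resulting index is an assumption to verify on your own data.

```python
import pandas as pd


def to_integer_time_level(df, instance_level="Patients", time_level="Dates"):
    """Hypothetical helper: swap the datetime level for 0, 1, 2, ... per instance."""
    df = df.sort_index()
    # Position of each row within its instance (patient), in date order.
    position = df.groupby(level=instance_level).cumcount().to_numpy()
    out = df.copy()
    out.index = pd.MultiIndex.from_arrays(
        [df.index.get_level_values(instance_level), position],
        names=[instance_level, time_level],
    )
    return out


# Small stand-in for the data frame in the question (two patients, two dates).
rows = [
    ["A", "2024-01-01", 1, 11], ["A", "2024-01-02", 2, 12],
    ["B", "2024-01-01", 101, 1], ["B", "2024-01-02", 102, 2],
]
df = pd.DataFrame(rows, columns=["Patients", "Dates", "M1", "M2"])
df["Dates"] = pd.to_datetime(df["Dates"])
df = df.set_index(["Patients", "Dates"])

df_int = to_integer_time_level(df)
print(df_int.index.tolist())
```

The conversion should be applied to the full data before splitting, so that the training rows and the later update rows share one integer timeline. Converting `y_update` on its own would restart the positions at 0.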
Hi @ManuB68, in your example, `y_test` has only 1 row for each patient (coming from test_size=1). That is fewer than the 3 dates that pandas needs for frequency inference. `sktime` fits the model separately for each patient (for each lowest level of the hierarchy), so the relevant count is 1 and not the number of rows in the entire `y_test`.
I also use a similar format for my own office work, with hundreds of series identifiers (e.g. store number), and update works. If you find a specific example where it works normally but just adding a multi-index causes failure, please provide a minimal reproducible example with dummy values in the issue #5853.
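The row counts described above can be checked with pandas alone. The following is a minimal sketch, not sktime's internal logic: it rebuilds data shaped like the question's (with dummy values), counts the rows per patient in the update slice, and calls `pd.infer_freq` directly on four dates and on a single date. The variable names are illustrative.

```python
import pandas as pd

dates = pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"])
index = pd.MultiIndex.from_product(
    [["A", "B", "C"], dates], names=["Patients", "Dates"]
)
df = pd.DataFrame({"M1": range(12), "M2": range(12)}, index=index)

# Same slicing as in the question: keep only the rows of the last date.
last_day = df.index.get_level_values("Dates").max()
y_update = df[df.index.get_level_values("Dates") == last_day]

# Three rows overall, but only one row per patient.
rows_per_patient = y_update.groupby(level="Patients").size()
print(rows_per_patient.to_dict())

# Frequency inference works on the four dates of one patient ...
full_freq = pd.infer_freq(dates)

# ... but raises on the single date that the update slice holds per patient.
try:
    pd.infer_freq(dates[-1:])
    single_date_error = None
except ValueError as err:
    single_date_error = str(err)

print(full_freq, "|", single_date_error)
```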