
Fix bug with hardcoded frequency in PytorchForecastingTransform #107

Merged · 8 commits · Sep 28, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Add more obvious Exception Error for forecasting with unfitted model ([#102](https://github.com/tinkoff-ai/etna-ts/pull/102))
- Fix bug with hardcoded frequency in PytorchForecastingTransform ([#107](https://github.com/tinkoff-ai/etna-ts/pull/107))

## [1.1.1] - 2021-09-23
### Fixed
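For context on what is being fixed: the previous implementation built its internal `TSDataset` with a hardcoded `"1d"` frequency and divided timestamp offsets by one day, so `time_idx` was effectively counted in days regardless of the actual sampling step. A minimal sketch of the difference, assuming a series sampled every two days (plain pandas, no etna objects):

```python
import pandas as pd

# A series sampled every two days.
timestamps = pd.date_range("2021-01-01", periods=5, freq="2D")
start = timestamps.min()

# Old behaviour: the step is hardcoded to one day, so time_idx has gaps.
old_time_idx = ((timestamps - start) / pd.Timedelta("1D")).astype(int)
print(list(old_time_idx))  # [0, 2, 4, 6, 8]

# New behaviour: the step is the inferred frequency, so time_idx is consecutive,
# which is what pytorch-forecasting's TimeSeriesDataSet expects.
freq = pd.infer_freq(timestamps)  # "2D"
new_time_idx = ((timestamps - start) / pd.Timedelta(freq)).astype(int)
print(list(new_time_idx))  # [0, 1, 2, 3, 4]
```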
19 changes: 16 additions & 3 deletions etna/models/nn/deepar.py
@@ -11,6 +11,7 @@
from etna.loggers import tslogger
from etna.models.base import Model
from etna.models.base import log_decorator
from etna.transforms import PytorchForecastingTransform


class DeepARModel(Model):
@@ -89,6 +90,16 @@ def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> DeepAR:
dropout=self.dropout,
)

@staticmethod
def _get_pf_transform(ts: TSDataset) -> PytorchForecastingTransform:
"""Get PytorchForecastingTransform from ts.transforms or raise exception if not found."""
if ts.transforms is not None and isinstance(ts.transforms[-1], PytorchForecastingTransform):
return ts.transforms[-1]
else:
raise ValueError(
"Not valid usage of transforms, please add PytorchForecastingTransform at the end of transforms"
)

@log_decorator
def fit(self, ts: TSDataset) -> "DeepARModel":
"""
@@ -103,7 +114,8 @@ def fit(self, ts: TSDataset) -> "DeepARModel":
-------
DeepARModel
"""
self.model = self._from_dataset(ts.transforms[-1].pf_dataset_train)
pf_transform = self._get_pf_transform(ts)
self.model = self._from_dataset(pf_transform.pf_dataset_train)

self.trainer = pl.Trainer(
logger=tslogger.pl_loggers,
@@ -113,7 +125,7 @@ def fit(self, ts: TSDataset) -> "DeepARModel":
gradient_clip_val=self.gradient_clip_val,
)

train_dataloader = ts.transforms[-1].pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)
train_dataloader = pf_transform.pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)

self.trainer.fit(self.model, train_dataloader)

@@ -134,7 +146,8 @@ def forecast(self, ts: TSDataset) -> TSDataset:
TSDataset
TSDataset with predictions.
"""
prediction_dataloader = ts.transforms[-1].pf_dataset_predict.to_dataloader(
pf_transform = self._get_pf_transform(ts)
prediction_dataloader = pf_transform.pf_dataset_predict.to_dataloader(
train=False, batch_size=self.batch_size * 2
)

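The new `_get_pf_transform` helper only inspects the last element of `ts.transforms`, so `PytorchForecastingTransform` has to be the final transform in the pipeline; with any other ordering, `fit`/`forecast` now fail fast with a `ValueError` instead of an obscure error on an unrelated transform. A sketch of the intended usage, reusing the parameters from the new test (the toy dataframe below is illustrative, not part of the PR); `TFTModel` follows the same pattern:

```python
import numpy as np
import pandas as pd
from pytorch_forecasting.data import GroupNormalizer

from etna.datasets.tsdataset import TSDataset
from etna.models.nn import DeepARModel
from etna.transforms import AddConstTransform
from etna.transforms import PytorchForecastingTransform

# Toy long-format data: two segments, daily frequency (illustrative only).
timestamps = pd.date_range("2021-01-01", periods=100, freq="D")
df_long = pd.DataFrame(
    {
        "timestamp": list(timestamps) * 2,
        "segment": ["a"] * 100 + ["b"] * 100,
        "target": np.random.uniform(size=200),
    }
)
ts = TSDataset(TSDataset.to_dataset(df_long), "D")

add_const = AddConstTransform(in_column="target", value=1.0)
pft = PytorchForecastingTransform(
    max_encoder_length=21,
    max_prediction_length=8,
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["target"],
    target_normalizer=GroupNormalizer(groups=["segment"]),
)

# Correct order: PytorchForecastingTransform comes last, because the model reads
# the prepared TimeSeriesDataSet from ts.transforms[-1].
ts.fit_transform([add_const, pft])

model = DeepARModel(max_epochs=10, learning_rate=[0.1])
model.fit(ts)
ts_future = ts.make_future(8)
ts_forecast = model.forecast(ts_future)

# Reversed order, e.g. ts.fit_transform([pft, add_const]), now raises ValueError
# asking to add PytorchForecastingTransform at the end of transforms.
```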
21 changes: 17 additions & 4 deletions etna/models/nn/tft.py
@@ -11,6 +11,7 @@
from etna.loggers import tslogger
from etna.models.base import Model
from etna.models.base import log_decorator
from etna.transforms import PytorchForecastingTransform


class TFTModel(Model):
@@ -97,6 +98,16 @@ def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> TemporalFusionTransfor
hidden_continuous_size=self.hidden_continuous_size,
)

@staticmethod
def _get_pf_transform(ts: TSDataset) -> PytorchForecastingTransform:
"""Get PytorchForecastingTransform from ts.transforms or raise exception if not found."""
if ts.transforms is not None and isinstance(ts.transforms[-1], PytorchForecastingTransform):
return ts.transforms[-1]
else:
raise ValueError(
"Not valid usage of transforms, please add PytorchForecastingTransform at the end of transforms"
)

@log_decorator
def fit(self, ts: TSDataset) -> "TFTModel":
"""
@@ -111,7 +122,8 @@ def fit(self, ts: TSDataset) -> "TFTModel":
-------
TFTModel
"""
self.model = self._from_dataset(ts.transforms[-1].pf_dataset_train)
pf_transform = self._get_pf_transform(ts)
self.model = self._from_dataset(pf_transform.pf_dataset_train)

self.trainer = pl.Trainer(
logger=tslogger.pl_loggers,
@@ -121,14 +133,14 @@ def fit(self, ts: TSDataset) -> "TFTModel":
gradient_clip_val=self.gradient_clip_val,
)

train_dataloader = ts.transforms[-1].pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)
train_dataloader = pf_transform.pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)

self.trainer.fit(self.model, train_dataloader)

return self

@log_decorator
def forecast(self, ts: TSDataset) -> pd.DataFrame:
def forecast(self, ts: TSDataset) -> TSDataset:
"""
Predict future.

@@ -142,7 +154,8 @@ def forecast(self, ts: TSDataset) -> pd.DataFrame:
TSDataset
TSDataset with predictions.
"""
prediction_dataloader = ts.transforms[-1].pf_dataset_predict.to_dataloader(
pf_transform = self._get_pf_transform(ts)
prediction_dataloader = pf_transform.pf_dataset_predict.to_dataloader(
train=False, batch_size=self.batch_size * 2
)

55 changes: 34 additions & 21 deletions etna/transforms/pytorch_forecasting.py
@@ -48,6 +48,10 @@ def __init__(
):
"""Parameters for TimeSeriesDataSet object.

Notes
-----
This transform should be added at the very end of the `transforms` parameter.

Contributor comment: Maybe we should raise Exception?
Reference
---------
https://github.com/jdb78/pytorch-forecasting/blob/v0.8.5/pytorch_forecasting/data/timeseries.py#L117
@@ -76,6 +80,14 @@ def __init__(
self.lags = lags
self.scalers = scalers

@staticmethod
def _calculate_freq_unit(freq: str) -> pd.Timedelta:
"""Calculate frequency unit by its string representation."""
if freq[0].isdigit():
return pd.Timedelta(freq)
else:
return pd.Timedelta(1, unit=freq)

def fit(self, df: pd.DataFrame) -> "PytorchForecastingTransform":
"""
Fit TimeSeriesDataSet.
@@ -89,22 +101,22 @@ def fit(self, df: pd.DataFrame) -> "PytorchForecastingTransform":
-------
PytorchForecastingTransform
"""
ts = TSDataset(df, "1d")
self.freq = ts.freq
ts = ts.to_pandas(flatten=True)
ts = ts.dropna()
self.min_timestamp = ts.timestamp.min()
self.freq = pd.infer_freq(df.index)
ts = TSDataset(df, self.freq)
df_flat = ts.to_pandas(flatten=True)
df_flat = df_flat.dropna()
self.min_timestamp = df_flat.timestamp.min()

if self.time_varying_known_categoricals:
for feature_name in self.time_varying_known_categoricals:
ts[feature_name] = ts[feature_name].astype(str)
df_flat[feature_name] = df_flat[feature_name].astype(str)

ts["time_idx"] = ts["timestamp"] - self.min_timestamp
ts["time_idx"] = ts["time_idx"].apply(lambda x: x / self.freq)
ts["time_idx"] = ts["time_idx"].astype(int)
freq_unit = self._calculate_freq_unit(self.freq)
df_flat["time_idx"] = (df_flat["timestamp"] - self.min_timestamp) / freq_unit
df_flat["time_idx"] = df_flat["time_idx"].astype(int)

pf_dataset = TimeSeriesDataSet(
ts,
df_flat,
time_idx="time_idx",
target="target",
group_ids=["segment"],
@@ -151,24 +163,25 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
We save the TimeSeriesDataSet in the instance to use it in the model.
It's not the right pattern of using Transforms and TSDataset.
"""
ts = TSDataset(df, "1d")
ts = ts.to_pandas(flatten=True)
ts = ts[ts.timestamp >= self.min_timestamp]
ts = ts.fillna(0)

ts["time_idx"] = ts["timestamp"] - self.min_timestamp
ts["time_idx"] = ts["time_idx"].apply(lambda x: x / self.freq)
ts["time_idx"] = ts["time_idx"].astype(int)
ts = TSDataset(df, self.freq)
df_flat = ts.to_pandas(flatten=True)
df_flat = df_flat[df_flat.timestamp >= self.min_timestamp]
df_flat = df_flat.fillna(0)

freq_unit = self._calculate_freq_unit(self.freq)
df_flat["time_idx"] = (df_flat["timestamp"] - self.min_timestamp) / freq_unit
df_flat["time_idx"] = df_flat["time_idx"].astype(int)

if self.time_varying_known_categoricals:
for feature_name in self.time_varying_known_categoricals:
ts[feature_name] = ts[feature_name].astype(str)
df_flat[feature_name] = df_flat[feature_name].astype(str)

if inspect.stack()[1].function == "make_future":
pf_dataset_predict = TimeSeriesDataSet.from_parameters(
self.pf_dataset_params, ts, predict=True, stop_randomization=True
self.pf_dataset_params, df_flat, predict=True, stop_randomization=True
)
self.pf_dataset_predict = pf_dataset_predict
else:
pf_dataset_train = TimeSeriesDataSet.from_parameters(self.pf_dataset_params, ts)
pf_dataset_train = TimeSeriesDataSet.from_parameters(self.pf_dataset_params, df_flat)
self.pf_dataset_train = pf_dataset_train
return df
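The digit check in `_calculate_freq_unit` is there because pandas parses the two kinds of frequency strings differently: aliases with a multiplier such as `"2D"` can be passed straight to `pd.Timedelta`, while plain aliases such as `"D"` or `"H"` need an explicit count. A small standalone sketch of the two branches (plain pandas, outside the transform):

```python
import pandas as pd

def calculate_freq_unit(freq: str) -> pd.Timedelta:
    """Standalone mirror of PytorchForecastingTransform._calculate_freq_unit."""
    if freq[0].isdigit():
        # Multiplied aliases like "2D" or "12H" parse directly.
        return pd.Timedelta(freq)
    # Plain aliases like "D" or "H" need an explicit count of one.
    return pd.Timedelta(1, unit=freq)

print(calculate_freq_unit("D"))    # 1 days 00:00:00
print(calculate_freq_unit("2D"))   # 2 days 00:00:00
print(calculate_freq_unit("12H"))  # 0 days 12:00:00
```

The frequency string that `fit` now derives via `pd.infer_freq(df.index)` feeds directly into this helper, and `time_idx` becomes the number of whole frequency units since the first timestamp.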
29 changes: 24 additions & 5 deletions tests/test_models/nn/test_deepar.py
@@ -8,10 +8,29 @@
from etna.datasets.tsdataset import TSDataset
from etna.metrics import MAE
from etna.models.nn import DeepARModel
from etna.transforms import AddConstTransform
from etna.transforms import DateFlagsTransform
from etna.transforms import PytorchForecastingTransform


def test_fit_wrong_order_transform(weekly_period_df):
ts = TSDataset(TSDataset.to_dataset(weekly_period_df), "D")
add_const = AddConstTransform(in_column="target", value=1.0)
pft = PytorchForecastingTransform(
max_encoder_length=21,
max_prediction_length=8,
time_varying_known_reals=["time_idx"],
time_varying_unknown_reals=["target"],
target_normalizer=GroupNormalizer(groups=["segment"]),
)

ts.fit_transform([pft, add_const])

model = DeepARModel(max_epochs=300, learning_rate=[0.1])
with pytest.raises(ValueError, match="add PytorchForecastingTransform"):
model.fit(ts)


@pytest.mark.long
@pytest.mark.parametrize("horizon", [8, 21])
def test_deepar_model_run_weekly_overfit(weekly_period_df, horizon):
@@ -31,8 +50,8 @@ def test_deepar_model_run_weekly_overfit(weekly_period_df, horizon):
weekly_period_df[lambda x: x.timestamp >= ts_start],
)

ts_train = TSDataset(TSDataset.to_dataset(train), "1d")
ts_test = TSDataset(TSDataset.to_dataset(test), "1d")
ts_train = TSDataset(TSDataset.to_dataset(train), "D")
ts_test = TSDataset(TSDataset.to_dataset(test), "D")
dft = DateFlagsTransform(day_number_in_week=True, day_number_in_month=False)
pft = PytorchForecastingTransform(
max_encoder_length=21,
Expand All @@ -45,10 +64,10 @@ def test_deepar_model_run_weekly_overfit(weekly_period_df, horizon):

ts_train.fit_transform([dft, pft])

tftmodel = DeepARModel(max_epochs=300, learning_rate=[0.1])
model = DeepARModel(max_epochs=300, learning_rate=[0.1])
ts_pred = ts_train.make_future(horizon)
tftmodel.fit(ts_train)
ts_pred = tftmodel.forecast(ts_pred)
model.fit(ts_train)
ts_pred = model.forecast(ts_pred)

mae = MAE("macro")

31 changes: 26 additions & 5 deletions tests/test_models/nn/test_tft.py
@@ -7,10 +7,31 @@
from etna.datasets.tsdataset import TSDataset
from etna.metrics import MAE
from etna.models.nn import TFTModel
from etna.transforms import AddConstTransform
from etna.transforms import DateFlagsTransform
from etna.transforms import PytorchForecastingTransform


def test_fit_wrong_order_transform(weekly_period_df):
ts = TSDataset(TSDataset.to_dataset(weekly_period_df), "D")
add_const = AddConstTransform(in_column="target", value=1.0)
pft = PytorchForecastingTransform(
max_encoder_length=21,
min_encoder_length=21,
max_prediction_length=8,
time_varying_known_reals=["time_idx"],
time_varying_unknown_reals=["target"],
static_categoricals=["segment"],
target_normalizer=None,
)

ts.fit_transform([pft, add_const])

model = TFTModel(max_epochs=300, learning_rate=[0.1])
with pytest.raises(ValueError, match="add PytorchForecastingTransform"):
model.fit(ts)


@pytest.mark.long
@pytest.mark.parametrize("horizon", [8, 21])
def test_tft_model_run_weekly_overfit(weekly_period_df, horizon):
@@ -31,8 +52,8 @@ def test_tft_model_run_weekly_overfit(weekly_period_df, horizon):
weekly_period_df[lambda x: x.timestamp >= ts_start],
)

ts_train = TSDataset(TSDataset.to_dataset(train), "1d")
ts_test = TSDataset(TSDataset.to_dataset(test), "1d")
ts_train = TSDataset(TSDataset.to_dataset(train), "D")
ts_test = TSDataset(TSDataset.to_dataset(test), "D")
dft = DateFlagsTransform(day_number_in_week=True, day_number_in_month=False)
pft = PytorchForecastingTransform(
max_encoder_length=21,
Expand All @@ -47,10 +68,10 @@ def test_tft_model_run_weekly_overfit(weekly_period_df, horizon):

ts_train.fit_transform([dft, pft])

tftmodel = TFTModel(max_epochs=300, learning_rate=[0.1])
model = TFTModel(max_epochs=300, learning_rate=[0.1])
ts_pred = ts_train.make_future(horizon)
tftmodel.fit(ts_train)
ts_pred = tftmodel.forecast(ts_pred)
model.fit(ts_train)
ts_pred = model.forecast(ts_pred)

mae = MAE("macro")
assert mae(ts_test, ts_pred) < 0.24
25 changes: 25 additions & 0 deletions tests/test_transforms/test_pytorch_forecasting_transform.py
@@ -0,0 +1,25 @@
import pytest

from etna.transforms import PytorchForecastingTransform


@pytest.mark.parametrize("days_offset", [1, 2, 5, 10])
def test_time_idx(days_offset, example_tsds):
"""Check that PytorchForecastingTransform works with different frequencies correctly."""
df = example_tsds.to_pandas()
new_df = df.loc[df.index[::days_offset]]

transform = PytorchForecastingTransform(
max_encoder_length=3,
min_encoder_length=3,
max_prediction_length=3,
time_varying_known_reals=["time_idx"],
time_varying_unknown_reals=["target"],
static_categoricals=["segment"],
)
transform.fit_transform(new_df)

time_idx = transform.pf_dataset_train.data["time"].tolist()
expected_len = new_df.shape[0]
expected_list = list(range(expected_len)) * len(example_tsds.segments)
assert time_idx == expected_list