Skip to content

Commit

Permalink
Create AbstractPipeline (#573)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hse-repository committed Feb 28, 2022
1 parent db5967d commit 09a7938
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 1 deletion.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Add option `season_number` to DateFlagsTransform ([#567](https://github.com/tinkoff-ai/etna/pull/567))
-

-
-
- Create `AbstractPipeline` ([#573](https://github.com/tinkoff-ai/etna/pull/573))
-
### Changed
- Change the way `ProphetModel` works with regressors ([#383](https://github.com/tinkoff-ai/etna/pull/383))
- Change the way `SARIMAXModel` works with regressors ([#380](https://github.com/tinkoff-ai/etna/pull/380))
Expand Down
82 changes: 82 additions & 0 deletions etna/pipeline/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,91 @@
import warnings
from abc import ABC
from abc import abstractmethod
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

import pandas as pd

from etna.core import BaseMixin
from etna.datasets import TSDataset
from etna.metrics import Metric


class AbstractPipeline(ABC):
    """Interface for pipeline.

    Declares the three operations every pipeline must provide: ``fit``,
    ``forecast`` and ``backtest``. All of them are abstract, so a concrete
    subclass has to override each one before it can be instantiated.
    """

    @abstractmethod
    def fit(self, ts: TSDataset) -> "AbstractPipeline":
        """Fit the Pipeline.

        Parameters
        ----------
        ts:
            Dataset with timeseries data

        Returns
        -------
        self:
            Fitted Pipeline instance
        """

    @abstractmethod
    def forecast(
        self,
        prediction_interval: bool = False,
        quantiles: Sequence[float] = (0.025, 0.975),
    ) -> TSDataset:
        """Make predictions.

        Parameters
        ----------
        prediction_interval:
            If True returns prediction interval for forecast
        quantiles:
            Levels of prediction distribution. By default 2.5% and 97.5% taken to form a 95% prediction interval

        Returns
        -------
        forecast:
            Dataset with predictions
        """

    @abstractmethod
    def backtest(
        self,
        ts: TSDataset,
        metrics: List[Metric],
        n_folds: int = 5,
        mode: str = "expand",
        aggregate_metrics: bool = False,
        n_jobs: int = 1,
        # None instead of a dict literal: a dict default would be created once
        # at definition time and shared (and mutable) across every call.
        joblib_params: Optional[Dict[str, Any]] = None,
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Run backtest with the pipeline.

        Parameters
        ----------
        ts:
            Dataset to fit models in backtest
        metrics:
            List of metrics to compute for each fold
        n_folds:
            Number of folds
        mode:
            One of 'expand', 'constant' -- train generation policy
        aggregate_metrics:
            If True aggregate metrics above folds, return raw metrics otherwise
        n_jobs:
            Number of jobs to run in parallel
        joblib_params:
            Additional parameters for joblib.Parallel. If None, the implementation
            is expected to fall back to its own defaults; this interface previously
            declared ``dict(verbose=11, backend="multiprocessing", mmap_mode="c")``

        Returns
        -------
        metrics_df, forecast_df, fold_info_df:
            Metrics dataframe, forecast dataframe and dataframe with information about folds
        """


class BasePipeline(ABC, BaseMixin):
Expand Down

0 comments on commit 09a7938

Please sign in to comment.