Refactor/metrics #2284

Merged
merged 40 commits into from
Apr 4, 2024
Changes from 38 commits
630bb44
metrics: avoid series intersection if they share the same time index
dennisbader Mar 4, 2024
4f18d0e
Merge branch 'master' into refactor/metrics
dennisbader Mar 5, 2024
18522f8
move timerseries utils to ts_utils.py
dennisbader Mar 6, 2024
0350f52
move around helper functions to import utils from timeseries
dennisbader Mar 6, 2024
eba33c9
improve time series slicing
dennisbader Mar 11, 2024
69c2ef2
add unit tests for slice intersect
dennisbader Mar 11, 2024
216f3cd
add test for multivariate series
dennisbader Mar 11, 2024
451c57d
add tests for slice intersect values
dennisbader Mar 11, 2024
bfdbece
replace raise_if_not in favor of raise_log
dennisbader Mar 11, 2024
d7a7b58
update metrics docs
dennisbader Mar 12, 2024
bc519d0
update all metrics logic
dennisbader Mar 12, 2024
89a9a5c
fix quantile_loss
dennisbader Mar 13, 2024
88c212f
fix notebooks
dennisbader Mar 14, 2024
d7a1244
update changelog
dennisbader Mar 14, 2024
7376fec
Merge branch 'master' into refactor/metrics
dennisbader Mar 14, 2024
ce5affd
Merge branch 'master' into refactor/metrics
dennisbader Mar 16, 2024
de20ec3
refactor backtest
dennisbader Mar 17, 2024
8ac0ca6
update docs
dennisbader Mar 17, 2024
86f5ed8
update returns docs for metrics
dennisbader Mar 17, 2024
317eb10
refactor backtest
dennisbader Mar 18, 2024
61cd3fa
refactor wrappers to support metrics with insample
dennisbader Mar 18, 2024
37518b1
update changelog
dennisbader Mar 18, 2024
d9cd67a
add more tests for metric output shapes
dennisbader Mar 19, 2024
f67484a
sanity checks for reduction functions
dennisbader Mar 20, 2024
6107317
add new metrics
dennisbader Mar 20, 2024
59dd7b6
update docs
dennisbader Mar 21, 2024
7c105ec
make ForecastingModel.residuals work with multivariate series, custom…
dennisbader Mar 22, 2024
7e28997
make hist fc and backtest output consistent
dennisbader Mar 22, 2024
9652b9a
add historical forecasts tests for output type
dennisbader Mar 24, 2024
f8a599e
improve backtest
dennisbader Mar 25, 2024
6971d50
update changelog
dennisbader Mar 26, 2024
404a2f1
scaled metrics now accept that are overlapping into
dennisbader Mar 26, 2024
f803f33
added support for scaled metrics to backtest and metric_kwargs
dennisbader Mar 26, 2024
1417a8f
update changelog
dennisbader Mar 26, 2024
545c86d
add unit tests for ts utils
dennisbader Mar 26, 2024
6fb1a72
multiply quantile loss by 2
dennisbader Mar 27, 2024
9cb7dbb
fix notebooks p1
dennisbader Mar 27, 2024
6a56c31
last changes for PR review
dennisbader Mar 27, 2024
45d5a00
apply suggestions from pr review
dennisbader Apr 4, 2024
0c433e8
udpate quickstart notebook
dennisbader Apr 4, 2024
75 changes: 75 additions & 0 deletions CHANGELOG.md
@@ -9,12 +9,87 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

### For users of the library:
**Improved**
- 🚀🚀🚀 Improvements to metrics, historical forecasts, backtest, and residuals through major refactor. The refactor includes optimization of multiple processes and improvements to consistency, reliability, and the documentation. Some of these necessary changes come at the cost of breaking changes. [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader).
  - Metrics:
    - Optimized all metrics, which now run >20 times faster than before for univariate series, and >>20 times for multivariate series. This boosts direct metric computations as well as backtesting and residuals computation!
    - Added new metrics:
      - Time aggregated metric `merr()` (Mean Error)
      - Time aggregated scaled metrics `rmsse()`, and `msse()`: The (Root) Mean Squared Scaled Error.
      - "Per time step" metrics that return a metric score per time step: `err()` (Error), `ae()` (Absolute Error), `se()` (Squared Error), `sle()` (Squared Log Error), `ase()` (Absolute Scaled Error), `sse()` (Squared Scaled Error), `ape()` (Absolute Percentage Error), `sape()` (symmetric Absolute Percentage Error), `arre()` (Absolute Ranged Relative Error), `ql()` (Quantile Loss)
    - All scaled metrics now accept `insample` series that can be overlapping into `pred_series` (previously, `insample` had to end exactly one step before `pred_series`). Darts will handle the correct time extraction for you.
    - Improvements to the documentation:
      - Added a summary list of all metrics to the [metrics documentation page](https://unit8co.github.io/darts/generated_api/darts.metrics.html)
      - Standardized the documentation of each metric (added formula, improved return documentation, ...)
    - 🔴 Improved metric output consistency based on the type of input `series`, and the applied reductions:
      - `float`: A single metric score for:
        - single univariate series
        - single multivariate series with `component_reduction`
        - sequence (list) of uni/multivariate series with `series_reduction` and `component_reduction` (and `time_reduction` for "per time step metrics")
      - `np.ndarray`: A numpy array of metric scores. The array has shape (n time steps, n components) without time and component reductions. The time dimension is only available for "per time step" metrics. For:
        - single multivariate series and at least `component_reduction=None` for time aggregated metrics.
        - single uni/multivariate series and at least `time_reduction=None` for "per time step metrics"
        - sequence of uni/multivariate series including `series_reduction` and at least one of `component_reduction=None` or `time_reduction=None` for "per time step metrics"
      - `List[float]`: Same as for type `float` but for a sequence of series
      - `List[np.ndarray]`: Same as for type `np.ndarray` but for a sequence of series
    - 🔴 Other breaking changes:
      - `quantile_loss()`:
        - renamed to `mql()` (Mean Quantile Loss)
        - renamed quantile parameter `tau` to `q`
        - the metric is now multiplied by a factor `2` to make the loss more interpretable (e.g. for `q=0.5` it is identical to the `MAE`)
      - `rho_risk()`:
        - renamed to `qr()` (Quantile Risk)
        - renamed quantile parameter `rho` to `q`
      - Renamed metric parameter `reduction` to `series_reduction`
      - Renamed metric parameter `inter_reduction` to `component_reduction`
      - Scaled metrics do not allow seasonality inference anymore with `m=None`.
      - Custom metrics using decorators `multi_ts_support` and `multivariate_support` must now act on multivariate series (possibly containing missing values) instead of univariate series.
  - `ForecastingModel.historical_forecasts()`:
    - 🔴 Improved historical forecasts output consistency based on the type of input `series`: If `series` is a sequence, historical forecasts will always return a sequence/list of the same length (instead of trying to reduce to a `TimeSeries` object).
      - `TimeSeries`: A single historical forecast for a single `series` and `last_points_only=True`: it contains only the predictions at step `forecast_horizon` from all historical forecasts.
      - `List[TimeSeries]`: A list of historical forecasts for:
        - a sequence (list) of `series` and `last_points_only=True`: for each series, it contains only the predictions at step `forecast_horizon` from all historical forecasts.
        - a single `series` and `last_points_only=False`: for each historical forecast, it contains the entire horizon `forecast_horizon`.
      - `List[List[TimeSeries]]`: A list of lists of historical forecasts for a sequence of `series` and `last_points_only=False`. For each series, and historical forecast, it contains the entire horizon `forecast_horizon`. The outer list is over the series provided in the input sequence, and the inner lists contain the historical forecasts for each series.
  - `ForecastingModel.backtest()`:
    - Metrics are now computed only once between all `series` and `historical_forecasts`, significantly speeding things up when using a large number of `series`.
    - Added support for scaled metrics as `metric` (such as `ase`, `mase`, ...). No extra code required, backtest extracts the correct `insample` series for you.
    - Added support for passing additional metric arguments with parameter `metric_kwargs`. This allows, for example, parallelizing the metric computation with `n_jobs`, customizing the metric reduction with `*_reduction`, specifying the seasonality `m` for scaled metrics, etc.
    - 🔴 Improved backtest output consistency based on the type of input `series`, `historical_forecasts`, and the applied backtest reduction:
      - `float`: A single backtest score for single uni/multivariate series, a single `metric` function and:
        - `historical_forecasts` generated with `last_points_only=True`
        - `historical_forecasts` generated with `last_points_only=False` and using a backtest `reduction`
      - `np.ndarray`: A numpy array of backtest scores. For single series and one of:
        - a single `metric` function, `historical_forecasts` generated with `last_points_only=False` and backtest `reduction=None`. The output has shape (n forecasts,).
        - multiple `metric` functions and `historical_forecasts` generated with `last_points_only=False`. The output has shape (n metrics,) when using a backtest `reduction`, and (n metrics, n forecasts) when `reduction=None`
        - multiple uni/multivariate series including `series_reduction` and at least one of `component_reduction=None` or `time_reduction=None` for "per time step metrics"
      - `List[float]`: Same as for type `float` but for a sequence of series. The returned metric list has length `len(series)` with the `float` metric for each input `series`.
      - `List[np.ndarray]`: Same as for type `np.ndarray` but for a sequence of series. The returned metric list has length `len(series)` with the `np.ndarray` metrics for each input `series`.
    - 🔴 Other breaking changes:
      - `reduction` callable now acts on `axis=1` rather than `axis=0` to aggregate the metrics per series.
      - backtest will now raise an error when user-supplied `historical_forecasts` don't have the expected format based on input `series` and the `last_points_only` value.
  - `ForecastingModel.residuals()`. While the default behavior of `residuals()` remains identical, the method is now very similar to `backtest()`, except that it computes a "per time step" `metric` on `historical_forecasts`:
    - Added support for multivariate `series`.
    - Added support for all `historical_forecasts()` parameters to generate the historical forecasts for the residuals computation.
    - Added support for pre-computed historical forecasts with parameter `historical_forecasts`.
    - Added support for computing the residuals with any of Darts' "per time step" metrics with parameter `metric` (e.g. `err()`, `ae()`, `ape()`, ...). By default uses `err()` (Error).
    - Added support for parallelizing the metric computation across historical forecasts with parameter `n_jobs`.
    - 🔴 Improved residuals output and consistency based on the type of input `series` and `historical_forecasts`:
      - `TimeSeries`: Residual `TimeSeries` for a single `series` and `historical_forecasts` generated with `last_points_only=True`.
      - `List[TimeSeries]`: A list of residual `TimeSeries` for a sequence (list) of `series` with `last_points_only=True`. The residual list has length `len(series)`.
      - `List[List[TimeSeries]]`: A list of lists of residual `TimeSeries` for a sequence of `series` with `last_points_only=False`. The outer residual list has length `len(series)`. The inner lists consist of the residuals from all possible series-specific historical forecasts.
- Improvements to `TimeSeries`: [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader).
  - Performance boost for methods: `slice_intersect()`, `has_same_time_as()`
  - New method `slice_intersect_values()`, which returns the sliced values of a series, where the time index has been intersected with another series.
- 🔴 Moved utils functions to clearly separate Darts-specific from non-Darts-specific logic: [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader).
  - Moved function `generate_index()` from `darts.utils.timeseries_generation` to `darts.utils.utils`
  - Moved functions `retain_period_common_to_all()`, `series2seq()`, `seq2series()`, `get_single_series()` from `darts.utils.utils` to `darts.utils.ts_utils`.
- Improvements to `ForecastingModel`: [#2269](https://github.com/unit8co/darts/pull/2269) by [Felix Divo](https://github.com/felixdivo).
  - Renamed the private `_is_probabilistic` property to a public `supports_probabilistic_prediction`.
- Improvements to `DataTransformer`: [#2267](https://github.com/unit8co/darts/pull/2267) by [Alicja Krzeminska-Sciga](https://github.com/alicjakrzeminska).
  - `InvertibleDataTransformer` now supports parallelized inverse transformation for `series` being a list of lists of `TimeSeries` (`Sequence[Sequence[TimeSeries]]`). This `series` type represents for example the output from `historical_forecasts()` when using multiple series.
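
The scaled-metrics entry in the list above says that `insample` may now overlap `pred_series` and that the correct time extraction is handled automatically. The following is a minimal pure-Python sketch of that idea, not Darts' implementation: series are modeled as plain dicts mapping an integer time step to a value, and `mase_sketch` with its arguments is a hypothetical name introduced only for illustration.

```python
from typing import Dict


def mase_sketch(
    actual: Dict[int, float],
    pred: Dict[int, float],
    insample: Dict[int, float],
    m: int = 1,
) -> float:
    """Mean Absolute Scaled Error on plain dicts (time step -> value)."""
    pred_start = min(pred)
    # Keep only in-sample points strictly before the first predicted step,
    # so an overlapping `insample` cannot leak forecast-period values.
    train = [insample[t] for t in sorted(insample) if t < pred_start]
    if len(train) <= m:
        raise ValueError("`insample` needs more than `m` points before `pred`")
    # Scale: mean absolute error of the seasonal naive forecast on the train part.
    scale = sum(abs(train[i] - train[i - m]) for i in range(m, len(train))) / (
        len(train) - m
    )
    # Forecast error on the time steps shared by `actual` and `pred`.
    common = sorted(set(actual) & set(pred))
    mae = sum(abs(actual[t] - pred[t]) for t in common) / len(common)
    return mae / scale


full = {t: float(v) for t, v in enumerate([1, 3, 2, 5, 4, 6, 8, 7])}
pred = {5: 5.0, 6: 9.0, 7: 9.0}
train_only = {t: v for t, v in full.items() if t < 5}

# Passing the full (overlapping) series or the exact training part gives the
# same score, because the overlap is trimmed before the scale is computed.
score_overlap = mase_sketch(full, pred, insample=full)
score_exact = mase_sketch(full, pred, insample=train_only)
```

The seasonality `m` is an explicit argument here, in line with the note that scaled metrics no longer infer it from `m=None`.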

**Fixed**
- Fixed a bug in `quantile_loss`, where the loss was computed on all samples rather than only on the predicted quantiles. [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader).
- Fixed type hint warning "Unexpected argument" when calling `historical_forecasts()` caused by the `_with_sanity_checks` decorator. The type hinting is now properly configured to expect any input arguments and return the output type of the method for which the sanity checks are performed. [#2286](https://github.com/unit8co/darts/pull/2286) by [Dennis Bader](https://github.com/dennisbader).
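
The `quantile_loss` fix and the factor `2` mentioned in this changelog can be illustrated with a small pure-Python sketch. It assumes the loss is evaluated on the `q`-quantile taken over the samples of each time step (with linear interpolation between order statistics); `sample_quantile` and `mql_sketch` are hypothetical helpers, not Darts functions.

```python
from typing import List, Sequence


def sample_quantile(samples: Sequence[float], q: float) -> float:
    """Quantile with linear interpolation between order statistics."""
    ordered = sorted(samples)
    pos = q * (len(ordered) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (pos - lo) * (ordered[hi] - ordered[lo])


def mql_sketch(
    actual: Sequence[float], pred_samples: Sequence[Sequence[float]], q: float
) -> float:
    """Mean quantile (pinball) loss, scaled by 2 as described in the changelog."""
    losses: List[float] = []
    for y, samples in zip(actual, pred_samples):
        # Evaluate the loss on the predicted quantile, not on every sample.
        diff = y - sample_quantile(samples, q)
        losses.append(2.0 * max(q * diff, (q - 1.0) * diff))
    return sum(losses) / len(losses)


actual = [10.0, 12.0, 9.0]
pred_samples = [[8.0, 9.0, 13.0], [11.0, 12.0, 16.0], [9.0, 11.0, 14.0]]

# For q=0.5 the scaled loss equals the MAE against the per-step median.
medians = [sample_quantile(s, 0.5) for s in pred_samples]
mae_vs_median = sum(abs(y - m) for y, m in zip(actual, medians)) / len(actual)
loss_median = mql_sketch(actual, pred_samples, q=0.5)
```

Without the factor `2`, the `q=0.5` loss would be half the MAE, which is why the scaled version is easier to interpret.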

**Dependencies**
2 changes: 1 addition & 1 deletion darts/dataprocessing/encoders/encoder_base.py
@@ -12,7 +12,7 @@
 from darts import TimeSeries
 from darts.dataprocessing.transformers import FittableDataTransformer
 from darts.logging import get_logger, raise_if, raise_log
-from darts.utils.timeseries_generation import generate_index
+from darts.utils.utils import generate_index

 try:
     from typing import Literal
8 changes: 3 additions & 5 deletions darts/dataprocessing/encoders/encoders.py
@@ -176,11 +176,9 @@
 from darts.dataprocessing.transformers import FittableDataTransformer
 from darts.logging import get_logger, raise_if, raise_if_not
 from darts.timeseries import DIMS
-from darts.utils.timeseries_generation import (
-    datetime_attribute_timeseries,
-    generate_index,
-)
-from darts.utils.utils import seq2series, series2seq
+from darts.utils.timeseries_generation import datetime_attribute_timeseries
+from darts.utils.ts_utils import seq2series, series2seq
+from darts.utils.utils import generate_index

 SupportedTimeSeries = Union[TimeSeries, Sequence[TimeSeries]]
 logger = get_logger(__name__)
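
Because this PR moves `generate_index` to `darts.utils.utils` and `series2seq`/`seq2series` to `darts.utils.ts_utils`, downstream code that has to run against releases from before and after the move may want a fallback import. The helper below is one possible sketch using only `importlib`; `import_first` is a hypothetical name and not part of Darts.

```python
import importlib
from typing import Any, Sequence


def import_first(name: str, modules: Sequence[str]) -> Any:
    """Return attribute `name` from the first module in `modules` that provides it."""
    for module_name in modules:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # module not available in this environment or version
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"`{name}` not found in any of: {', '.join(modules)}")


# Usage with the locations from this PR (new location first, old one as fallback):
# generate_index = import_first(
#     "generate_index", ["darts.utils.utils", "darts.utils.timeseries_generation"]
# )
# series2seq = import_first("series2seq", ["darts.utils.ts_utils", "darts.utils.utils"])
```

Listing the new location first keeps the fallback path inactive on current releases.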
2 changes: 1 addition & 1 deletion darts/dataprocessing/transformers/midas.py
@@ -15,7 +15,7 @@
 )
 from darts.logging import get_logger, raise_log
 from darts.timeseries import _finite_rows_boundaries
-from darts.utils.timeseries_generation import generate_index
+from darts.utils.utils import generate_index

 logger = get_logger(__name__)

5 changes: 4 additions & 1 deletion darts/dataprocessing/transformers/reconciliation.py
@@ -17,8 +17,10 @@
     BaseDataTransformer,
     FittableDataTransformer,
 )
+from darts.logging import get_logger, raise_if_not
 from darts.timeseries import TimeSeries
-from darts.utils.utils import raise_if_not
+
+logger = get_logger(__name__)


 def _get_summation_matrix(series: TimeSeries):
@@ -37,6 +39,7 @@ def _get_summation_matrix(series: TimeSeries):
     raise_if_not(
         series.has_hierarchy,
         "The provided series must have a hierarchy defined for reconciliation to be performed.",
+        logger=logger,
     )
     hierarchy = series.hierarchy
     components_seq = list(series.components)
2 changes: 1 addition & 1 deletion darts/explainability/tft_explainer.py
@@ -35,7 +35,7 @@
 from darts.explainability.explainability import _ForecastingModelExplainer
 from darts.logging import get_logger, raise_log
 from darts.models import TFTModel
-from darts.utils.timeseries_generation import generate_index
+from darts.utils.utils import generate_index

 try:
     from typing import Literal
2 changes: 1 addition & 1 deletion darts/explainability/utils.py
@@ -4,7 +4,7 @@
 from darts.logging import get_logger, raise_if, raise_if_not, raise_log
 from darts.models.forecasting.forecasting_model import ForecastingModel
 from darts.utils.statistics import stationarity_tests
-from darts.utils.utils import series2seq
+from darts.utils.ts_utils import series2seq

 logger = get_logger(__name__)

5 changes: 4 additions & 1 deletion darts/logging.py
@@ -2,6 +2,7 @@
 import os
 import time
 import warnings
+from typing import NoReturn


 def get_logger(name):
@@ -104,7 +105,9 @@ def raise_if(
     raise_if_not(not condition, message, logger)


-def raise_log(exception: Exception, logger: logging.Logger = get_logger("main_logger")):
+def raise_log(
+    exception: Exception, logger: logging.Logger = get_logger("main_logger")
+) -> NoReturn:
     """
     Can be used to replace "raise" when throwing an exception to ensure the logging
     of the exception. After logging it, the exception is raised.
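
The `NoReturn` annotation added to `raise_log` tells static type checkers that control never continues past the call, which allows type narrowing in the code that follows. The sketch below shows the effect with a simplified stand-in; `raise_log_sketch` and `first_char` are hypothetical names, and the body is an assumption about the general "log, then raise" pattern rather than a copy of Darts' function.

```python
import logging
from typing import NoReturn, Optional

logger = logging.getLogger("sketch")


def raise_log_sketch(exception: Exception, log: logging.Logger = logger) -> NoReturn:
    """Log the exception, then raise it (simplified stand-in for `raise_log`)."""
    log.error(f"{type(exception).__name__}: {exception}")
    raise exception


def first_char(text: Optional[str]) -> str:
    if text is None:
        raise_log_sketch(ValueError("`text` must not be None"))
    # The call above never returns, so a type checker can treat `text`
    # as `str` here instead of `Optional[str]`.
    return text[0]
```

Without the annotation, a checker would typically assume the function returns `None` and report `text[0]` as a possible error on an optional value.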
63 changes: 61 additions & 2 deletions darts/metrics/__init__.py
@@ -1,21 +1,80 @@
 """
 Metrics
 -------
+
+For deterministic forecasts (point predictions with `num_samples == 1`):
+    - Aggregated over time:
+        Absolute metrics:
+            - :func:`MERR <darts.metrics.metrics.merr>`: Mean Error
+            - :func:`MAE <darts.metrics.metrics.mae>`: Mean Absolute Error
+            - :func:`MSE <darts.metrics.metrics.mse>`: Mean Squared Error
+            - :func:`RMSE <darts.metrics.metrics.rmse>`: Root Mean Squared Error
+            - :func:`RMSLE <darts.metrics.metrics.rmsle>`: Root Mean Squared Log Error
+
+        Relative metrics:
+            - :func:`MASE <darts.metrics.metrics.mase>`: Mean Absolute Scaled Error
+            - :func:`MSSE <darts.metrics.metrics.msse>`: Mean Squared Scaled Error
+            - :func:`RMSSE <darts.metrics.metrics.rmsse>`: Root Mean Squared Scaled Error
+            - :func:`MAPE <darts.metrics.metrics.mape>`: Mean Absolute Percentage Error
+            - :func:`sMAPE <darts.metrics.metrics.smape>`: symmetric Mean Absolute Percentage Error
+            - :func:`OPE <darts.metrics.metrics.ope>`: Overall Percentage Error
+            - :func:`MARRE <darts.metrics.metrics.marre>`: Mean Absolute Ranged Relative Error
+
+        Other metrics:
+            - :func:`R2 <darts.metrics.metrics.r2_score>`: Coefficient of Determination
+            - :func:`CV <darts.metrics.metrics.coefficient_of_variation>`: Coefficient of Variation
+
+    - Per time step:
+        Absolute metrics:
+            - :func:`ERR <darts.metrics.metrics.err>`: Error
+            - :func:`AE <darts.metrics.metrics.ae>`: Absolute Error
+            - :func:`SE <darts.metrics.metrics.se>`: Squared Error
+            - :func:`SLE <darts.metrics.metrics.sle>`: Squared Log Error
+
+        Relative metrics:
+            - :func:`ASE <darts.metrics.metrics.ase>`: Absolute Scaled Error
+            - :func:`SSE <darts.metrics.metrics.sse>`: Squared Scaled Error
+            - :func:`APE <darts.metrics.metrics.ape>`: Absolute Percentage Error
+            - :func:`sAPE <darts.metrics.metrics.sape>`: symmetric Absolute Percentage Error
+            - :func:`ARRE <darts.metrics.metrics.arre>`: Absolute Ranged Relative Error
+
+For probabilistic forecasts (stochastic predictions with `num_samples >> 1`):
+    - Aggregated over time:
+        - :func:`MQL <darts.metrics.metrics.mql>`: Mean Quantile Loss
+        - :func:`QR <darts.metrics.metrics.qr>`: Quantile Risk
+    - Per time step:
+        - :func:`QL <darts.metrics.metrics.ql>`: Quantile Loss
+
+For Dynamic Time Warping (DTW) (aggregated over time):
+    - :func:`DTW <darts.metrics.metrics.dtw_metric>`: Dynamic Time Warping Metric
 """

 from .metrics import (
+    ae,
+    ape,
+    arre,
+    ase,
     coefficient_of_variation,
     dtw_metric,
+    err,
     mae,
     mape,
     marre,
     mase,
+    merr,
+    mql,
     mse,
+    msse,
     ope,
-    quantile_loss,
+    ql,
+    qr,
     r2_score,
-    rho_risk,
     rmse,
     rmsle,
+    rmsse,
+    sape,
+    se,
+    sle,
     smape,
+    sse,
 )
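
The module docstring above separates "per time step" metrics (such as `ae`) from metrics "aggregated over time" (such as `mae`). The relationship between the two groups can be sketched in pure Python as follows, assuming a plain mean as the time aggregation and as the default component reduction. `ae_sketch` and `mae_sketch` are hypothetical stand-ins operating on nested lists of shape (n time steps, n components), not the Darts functions.

```python
from statistics import mean
from typing import Callable, List, Optional, Sequence, Union

Matrix = List[List[float]]  # shape: (n time steps, n components)


def ae_sketch(actual: Matrix, pred: Matrix) -> Matrix:
    """Absolute error per time step and component (a "per time step" metric)."""
    return [
        [abs(a - p) for a, p in zip(row_a, row_p)]
        for row_a, row_p in zip(actual, pred)
    ]


def mae_sketch(
    actual: Matrix,
    pred: Matrix,
    component_reduction: Optional[Callable[[Sequence[float]], float]] = mean,
) -> Union[float, List[float]]:
    """Mean absolute error: the per-step errors aggregated over time."""
    errors = ae_sketch(actual, pred)
    per_component = [mean(column) for column in zip(*errors)]  # reduce time axis
    if component_reduction is None:
        return per_component  # one score per component
    return component_reduction(per_component)  # single score


actual = [[1.0, 10.0], [3.0, 30.0]]
pred = [[2.0, 12.0], [3.0, 34.0]]

per_step = ae_sketch(actual, pred)
per_component = mae_sketch(actual, pred, component_reduction=None)
single_score = mae_sketch(actual, pred)
```

This mirrors the output rules in the changelog: keeping a dimension (here by passing `component_reduction=None`) yields an array-like result, while reducing every dimension yields a single float.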