Skip to content

Commit

Permalink
Refactor/metrics (#2284)
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader committed Apr 4, 2024
1 parent 91c7087 commit 5c97c9b
Show file tree
Hide file tree
Showing 46 changed files with 6,504 additions and 2,087 deletions.
75 changes: 75 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,87 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

### For users of the library:
**Improved**
- 🚀🚀🚀 Improvements to metrics, historical forecasts, backtest, and residuals through major refactor. The refactor includes optimization of multiple process and improvemenets to consistency, reliability, and the documentation. Some of these necessary changes come at the cost of breaking changes. [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader).
- Metrics:
- Optimized all metrics, which now run >20 times faster than before for univariate series, and >>20 times for multivariate series. This boosts direct metric computations as well as backtesting and residuals computation!
- Added new metrics:
- Time aggregated metric `merr()` (Mean Error)
- Time aggregated scaled metrics `rmsse()`, and `msse()`: The (Root) Mean Squared Scaled Error.
- "Per time step" metrics that return a metric score per time step: `err()` (Error), `ae()` (Absolute Error), `se()` (Squared Error), `sle()` (Squared Log Error), `ase()` (Absolute Scaled Error), `sse` (Squared Scaled Error), `ape()` (Absolute Percentage Error), `sape()` (symmetric Absolute Percentage Error), `arre()` (Absolute Ranged Relative Error), `ql` (Quantile Loss)
- All scaled metrics now accept `insample` series that can be overlapping into `pred_series` (before that had to end exactly one step before `pred_series`). Darts will handle the correct time extraction for you.
- Improvements to the documentation:
- Added a summary list of all metrics to the [metrics documentation page](https://unit8co.github.io/darts/generated_api/darts.metrics.html)
- Standardized the documentation of each metric (added formula, improved return documentation, ...)
- 🔴 Improved metric output consistency based on the type of input `series`, and the applied reductions:
- `float`: A single metric score for:
- single univariate series
- single multivariate series with `component_reduction`
- sequence (list) of uni/multivariate series with `series_reduction` and `component_reduction` (and `time_reduction` for "per time step metrics")
- `np.ndarray`: A numpy array of metric scores. The array has shape (n time steps, n components) without time and component reductions. The time dimension is only available for "per time step" metrics. For:
- single multivariate series and at least `component_reduction=None` for time aggregated metrics.
- single uni/multivariate series and at least `time_reduction=None` for "per time step metrics"
- sequence of uni/multivariate series including `series_reduction` and at least one of `component_reduction=None` or `time_reduction=None` for "per time step metrics"
- `List[float]`: Same as for type `float` but for a sequence of series
- `List[np.ndarray]` Same as for type `np.ndarray` but for a sequence of series
- 🔴 Other breaking changes:
- `quantile_loss()`:
- renamed to `mql()` (Mean Quantile Loss)
- renamed quantile parameter `tau` to `q`
- the metric is now multiplied by a factor `2` to make the loss more interpretable (e.g. for `q=0.5` it is identical to the `MAE`)
- `rho_risk()`:
- renamed to `qr()` (Quantile Risk)
- renamed quantile parameter `rho` to `q`
- Renamed metric parameter `reduction` to `series_reduction`
- Renamed metric parameter `inter_reduction` to `component_reduction`
- Scaled metrics do not allow seasonality inference anymore with `m=None`.
- Custom metrics using decorators `multi_ts_support` and `multivariate_support` must now act on multivariate series (possibly containing missing values) instead of univariate series.
- `ForecastingModel.historical_forecasts()`:
- 🔴 Improved historical forecasts output consistency based on the type of input `series`: If `series` is a sequence, historical forecasts will always return a sequence/list of the same length (instead of trying to reduce to a `TimeSeries` object).
- `TimeSeries`: A single historical forecast for a single `series` and `last_points_only=True`: it contains only the predictions at step `forecast_horizon` from all historical forecasts.
- `List[TimeSeries]` A list of historical forecasts for:
- a sequence (list) of `series` and `last_points_only=True`: for each series, it contains only the predictions at step `forecast_horizon` from all historical forecasts.
- a single `series` and `last_points_only=False`: for each historical forecast, it contains the entire horizon `forecast_horizon`.
- `List[List[TimeSeries]]` A list of lists of historical forecasts for a sequence of `series` and `last_points_only=False`. For each series, and historical forecast, it contains the entire horizon `forecast_horizon`. The outer list is over the series provided in the input sequence, and the inner lists contain the historical forecasts for each series.
- `ForecastingModel.backtest()`:
- Metrics are now computed only once between all `series` and `historical_forecasts`, significantly speeding things up when using a large number of `series`.
- Added support for scaled metrics as `metric` (such as `ase`, `mase`, ...). No extra code required, backtest extracts the correct `insample` series for you.
- Added support for passing additional metric arguments with parameter `metric_kwargs`. This allows for example parallelization of the metric computation with `n_jobs`, customize the metric reduction with `*_reduction`, specify seasonality `m` for scaled metrics, etc..
- 🔴 Improved backtest output consistency based on the type of input `series`, `historical_forecast`, and the applied backtest reduction:
- `float`: A single backtest score for single uni/multivariate series, a single `metric` function and:
- `historical_forecasts` generated with `last_points_only=True`
- `historical_forecasts` generated with `last_points_only=False` and using a backtest `reduction`
- `np.ndarray`: An numpy array of backtest scores. For single series and one of:
- a single `metric` function, `historical_forecasts` generated with `last_points_only=False` and backtest `reduction=None`. The output has shape (n forecasts,).
- multiple `metric` functions and `historical_forecasts` generated with `last_points_only=False`. The output has shape (n metrics,) when using a backtest `reduction`, and (n metrics, n forecasts) when `reduction=None`
- multiple uni/multivariate series including `series_reduction` and at least one of `component_reduction=None` or `time_reduction=None` for "per time step metrics"
- `List[float]`: Same as for type `float` but for a sequence of series. The returned metric list has length `len(series)` with the `float` metric for each input `series`.
- `List[np.ndarray]` Same as for type `np.ndarray` but for a sequence of series. The returned metric list has length `len(series)` with the `np.ndarray` metrics for each input `series`.
- 🔴 Other breaking changes:
- `reduction` callable now acts on `axis=1` rather than `axis=0` to aggregate the metrics per series.
- backtest will now raise an error when user supplied `historical_forecasts` don't have the expected format based on input `series` and the `last_points_only` value.
- `ForecastingModel.residuals()`. While the default behavior of `residuals()` remains identical, the method is now very similar to `backtest()` but that it computes a "per time step" `metric` on `historical_forecasts`:
- Added support for multivariate `series`.
- Added support for all `historical_forecasts()` parameters to generate the historical forecasts for the residuals computation.
- Added support for pre-computed historical forecasts with parameter `historical_forecasts`.
- Added support for computing the residuals with any of Darts' "per time step" metric with parameter `metric` (e.g. `err()`, `ae()`, `ape()`, ...). By default uses `err()` (Error).
- Added support for parallelizing the metric computation across historical forecasts with parameter `n_jobs`.
- 🔴 Improved residuals output and consistency based on the type of input `series` and `historical_forecast`:
- `TimeSeries`: Residual `TimeSeries` for a single `series` and `historical_forecasts` generated with `last_points_only=True`.
- `List[TimeSeries]` A list of residual `TimeSeries` for a sequence (list) of `series` with `last_points_only=True`. The residual list has length `len(series)`.
- `List[List[TimeSeries]]` A list of lists of residual `TimeSeries` for a sequence of `series` with `last_points_only=False`. The outer residual list has length `len(series)`. The inner lists consist of the residuals from all possible series-specific historical forecasts.
- Improvements to `TimeSeries`: [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader).
- Performance boost for methods: `slice_intersect()`, `has_same_time_as()`
- New method `slice_intersect_values()`, which returns the sliced values of a series, where the time index has been intersected with another series.
- 🔴 Moved utils functions to clearly separate Darts-specific from non-Darts-specific logic: [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader).
- Moved function `generate_index()` from `darts.utils.timeseries_generation` to `darts.utils.utils`
- Moved functions `retain_period_common_to_all()`, `series2seq()`, `seq2series()`, `get_single_series()` from `darts.utils.utils` to `darts.utils.ts_utils`.
- Improvements to `ForecastingModel`: [#2269](https://github.com/unit8co/darts/pull/2269) by [Felix Divo](https://github.com/felixdivo).
- Renamed the private `_is_probabilistic` property to a public `supports_probabilistic_prediction`.
- Improvements to `DataTransformer`: [#2267](https://github.com/unit8co/darts/pull/2267) by [Alicja Krzeminska-Sciga](https://github.com/alicjakrzeminska).
- `InvertibleDataTransformer` now supports parallelized inverse transformation for `series` being a list of lists of `TimeSeries` (`Sequence[Sequence[TimeSeries]]`). This `series` type represents for example the output from `historical_forecasts()` when using multiple series.

**Fixed**
- fixed a bug in `quantile_loss`, where the loss was computed on all samples rather than only on the predicted quantiles. [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader).
- Fixed type hint warning "Unexpected argument" when calling `historical_forecasts()` caused by the `_with_sanity_checks` decorator. The type hinting is now properly configured to expect any input arguments and return the output type of the method for which the sanity checks are performed for. [#2286](https://github.com/unit8co/darts/pull/2286) by [Dennis Bader](https://github.com/dennisbader).

**Dependencies**
Expand Down
2 changes: 1 addition & 1 deletion darts/dataprocessing/encoders/encoder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from darts import TimeSeries
from darts.dataprocessing.transformers import FittableDataTransformer
from darts.logging import get_logger, raise_if, raise_log
from darts.utils.timeseries_generation import generate_index
from darts.utils.utils import generate_index

try:
from typing import Literal
Expand Down
8 changes: 3 additions & 5 deletions darts/dataprocessing/encoders/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,9 @@
from darts.dataprocessing.transformers import FittableDataTransformer
from darts.logging import get_logger, raise_if, raise_if_not
from darts.timeseries import DIMS
from darts.utils.timeseries_generation import (
datetime_attribute_timeseries,
generate_index,
)
from darts.utils.utils import seq2series, series2seq
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from darts.utils.ts_utils import seq2series, series2seq
from darts.utils.utils import generate_index

SupportedTimeSeries = Union[TimeSeries, Sequence[TimeSeries]]
logger = get_logger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion darts/dataprocessing/transformers/midas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from darts.logging import get_logger, raise_log
from darts.timeseries import _finite_rows_boundaries
from darts.utils.timeseries_generation import generate_index
from darts.utils.utils import generate_index

logger = get_logger(__name__)

Expand Down
5 changes: 4 additions & 1 deletion darts/dataprocessing/transformers/reconciliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
BaseDataTransformer,
FittableDataTransformer,
)
from darts.logging import get_logger, raise_if_not
from darts.timeseries import TimeSeries
from darts.utils.utils import raise_if_not

logger = get_logger(__name__)


def _get_summation_matrix(series: TimeSeries):
Expand All @@ -37,6 +39,7 @@ def _get_summation_matrix(series: TimeSeries):
raise_if_not(
series.has_hierarchy,
"The provided series must have a hierarchy defined for reconciliation to be performed.",
logger=logger,
)
hierarchy = series.hierarchy
components_seq = list(series.components)
Expand Down
2 changes: 1 addition & 1 deletion darts/explainability/tft_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from darts.explainability.explainability import _ForecastingModelExplainer
from darts.logging import get_logger, raise_log
from darts.models import TFTModel
from darts.utils.timeseries_generation import generate_index
from darts.utils.utils import generate_index

try:
from typing import Literal
Expand Down
2 changes: 1 addition & 1 deletion darts/explainability/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.utils.statistics import stationarity_tests
from darts.utils.utils import series2seq
from darts.utils.ts_utils import series2seq

logger = get_logger(__name__)

Expand Down
5 changes: 4 additions & 1 deletion darts/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import time
import warnings
from typing import NoReturn


def get_logger(name):
Expand Down Expand Up @@ -104,7 +105,9 @@ def raise_if(
raise_if_not(not condition, message, logger)


def raise_log(exception: Exception, logger: logging.Logger = get_logger("main_logger")):
def raise_log(
exception: Exception, logger: logging.Logger = get_logger("main_logger")
) -> NoReturn:
"""
Can be used to replace "raise" when throwing an exception to ensure the logging
of the exception. After logging it, the exception is raised.
Expand Down
63 changes: 61 additions & 2 deletions darts/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,80 @@
"""
Metrics
-------
For deterministic forecasts (point predictions with `num_samples == 1`):
- Aggregated over time:
Absolute metrics:
- :func:`MERR <darts.metrics.metrics.merr>`: Mean Error
- :func:`MAE <darts.metrics.metrics.mae>`: Mean Absolute Error
- :func:`MSE <darts.metrics.metrics.mse>`: Mean Squared Error
- :func:`RMSE <darts.metrics.metrics.rmse>`: Root Mean Squared Error
- :func:`RMSLE <darts.metrics.metrics.rmsle>`: Root Mean Squared Log Error
Relative metrics:
- :func:`MASE <darts.metrics.metrics.mase>`: Mean Absolute Scaled Error
- :func:`MSSE <darts.metrics.metrics.msse>`: Mean Squared Scaled Error
- :func:`RMSSE <darts.metrics.metrics.rmsse>`: Root Mean Squared Scaled Error
- :func:`MAPE <darts.metrics.metrics.mape>`: Mean Absolute Percentage Error
- :func:`sMAPE <darts.metrics.metrics.smape>`: symmetric Mean Absolute Percentage Error
- :func:`OPE <darts.metrics.metrics.ope>`: Overall Percentage Error
- :func:`MARRE <darts.metrics.metrics.marre>`: Mean Absolute Ranged Relative Error
Other metrics:
- :func:`R2 <darts.metrics.metrics.r2_score>`: Coefficient of Determination
- :func:`CV <darts.metrics.metrics.coefficient_of_variation>`: Coefficient of Variation
- Per time step:
Absolute metrics:
- :func:`ERR <darts.metrics.metrics.err>`: Error
- :func:`AE <darts.metrics.metrics.ae>`: Absolute Error
- :func:`SE <darts.metrics.metrics.se>`: Squared Error
- :func:`SLE <darts.metrics.metrics.sle>`: Squared Log Error
Relative metrics:
- :func:`ASE <darts.metrics.metrics.ase>`: Absolute Scaled Error
- :func:`SSE <darts.metrics.metrics.sse>`: Squared Scaled Error
- :func:`APE <darts.metrics.metrics.ape>`: Absolute Percentage Error
- :func:`sAPE <darts.metrics.metrics.sape>`: symmetric Absolute Percentage Error
- :func:`ARRE <darts.metrics.metrics.arre>`: Absolute Ranged Relative Error
For probabilistic forecasts (storchastic predictions with `num_samples >> 1`):
- Aggregated over time:
- :func:`MQL <darts.metrics.metrics.mql>`: Mean Quantile Loss
- :func:`QR <darts.metrics.metrics.qr>`: Quantile Risk
- Per time step:
- :func:`QL <darts.metrics.metrics.ql>`: Quantile Loss
For Dynamic Time Warping (DTW) (aggregated over time):
- :func:`DTW <darts.metrics.metrics.dtw_metric>`: Dynamic Time Warping Metric
"""

from .metrics import (
ae,
ape,
arre,
ase,
coefficient_of_variation,
dtw_metric,
err,
mae,
mape,
marre,
mase,
merr,
mql,
mse,
msse,
ope,
quantile_loss,
ql,
qr,
r2_score,
rho_risk,
rmse,
rmsle,
rmsse,
sape,
se,
sle,
smape,
sse,
)

0 comments on commit 5c97c9b

Please sign in to comment.