Feat/sample weight torch #2410

Merged: 18 commits, merged on Jun 17, 2024. Showing changes from 14 commits.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -9,8 +9,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

### For users of the library:
**Improved**
- 🚀🚀 Improvements to `GlobalForecastingModel` (`RegressionModel` and `TorchForecastingModel`): [#2404](https://github.com/unit8co/darts/pull/2404) and [#2410](https://github.com/unit8co/darts/pull/2410) by [Dennis Bader](https://github.com/dennisbader).
  - Added parameters `sample_weight` and `val_sample_weight` to `fit()` to apply weights to each observation, output step, and target component in the training and evaluation sets. Both deterministic and probabilistic models are supported. The sample weights can either be `TimeSeries` themselves or one of the built-in weight generators `"linear_decay"` and `"exponential_decay"`. When passed as `TimeSeries`, they are handled in the same way as covariates (e.g., pass multiple weight series with multiple target series; the relevant time frames are extracted automatically). See the usage sketch after this list.
- Improvements to the Anomaly Detection Module through a major refactor, including significant performance optimizations for most processes and improvements to the API, consistency, reliability, and documentation. Some of these necessary changes come at the cost of breaking changes: [#1477](https://github.com/unit8co/darts/pull/1477) by [Dennis Bader](https://github.com/dennisbader), [Samuele Giuliano Piazzetta](https://github.com/piaz97), [Antoine Madrona](https://github.com/madtoinou), [Julien Herzen](https://github.com/hrzn), [Julien Adda](https://github.com/julien12234).
  - 🚀 Added an example notebook that showcases how to use Darts for Time Series Anomaly Detection
  - 🚀🚀 Added an example notebook that showcases how to use Darts for Time Series Anomaly Detection
  - Added a new dataset for anomaly detection with the number of taxi passengers in New York from 2014 to 2015.
  - `FittableWindowScorer` (KMeans, PyOD, and Wasserstein Scorers) now accept any of Darts' "per-time-step" metrics as difference function `diff_fn`.
  - `ForecastingAnomalyModel` is now much faster thanks to optimized historical forecasts that generate the prediction input for the scorers. We also added more control over the historical forecast generation through additional parameters in all model methods.
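To make the new `fit()` parameters concrete, below is a minimal usage sketch. It assumes the API exactly as described in the entry above; the model choice (`DLinearModel`) and the toy data are illustrative only.

```python
import numpy as np

from darts import TimeSeries
from darts.models import DLinearModel  # stands in for any TorchForecastingModel

# toy target series of 100 observations
series = TimeSeries.from_values(np.random.randn(100).astype(np.float32))

model = DLinearModel(input_chunk_length=12, output_chunk_length=6, n_epochs=1)

# option 1: a built-in weight generator (name taken from the entry above)
model.fit(series, sample_weight="linear_decay")

# option 2: explicit weights as a TimeSeries, handled like a covariate;
# here, recent observations are weighted more heavily than old ones
weights = TimeSeries.from_values(np.linspace(0.1, 1.0, 100).astype(np.float32))
model.fit(series, sample_weight=weights)
```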
2 changes: 2 additions & 0 deletions darts/models/forecasting/global_baseline_models.py
@@ -253,6 +253,7 @@ def _build_train_dataset(
target: Sequence[TimeSeries],
past_covariates: Optional[Sequence[TimeSeries]],
future_covariates: Optional[Sequence[TimeSeries]],
sample_weight: Optional[Sequence[TimeSeries]],
max_samples_per_ts: Optional[int],
) -> MixedCovariatesTrainingDataset:
return MixedCovariatesSequentialDataset(
@@ -264,6 +265,7 @@ def _build_train_dataset(
output_chunk_shift=self.output_chunk_shift,
max_samples_per_ts=max_samples_per_ts,
use_static_covariates=self.uses_static_covariates,
sample_weight=sample_weight,
)


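The `sample_weight` sequence accepted above is forwarded to the training dataset, which then emits the weights with every sample it generates. As the `training_step`/`validation_step` changes below rely on it, each batch now ends with the sample weights followed by the future target. A schematic sketch of that convention (names hypothetical):

```python
from typing import Sequence, Tuple

import torch

# hypothetical batch layout: (model inputs ..., sample_weight, future_target)
Batch = Tuple[torch.Tensor, ...]


def split_batch(
    batch: Batch,
) -> Tuple[Sequence[torch.Tensor], torch.Tensor, torch.Tensor]:
    """By convention, the last two elements are sample weights and future target."""
    model_inputs = batch[:-2]  # everything the forward pass consumes
    sample_weight = batch[-2]  # one weight per observation/step/component
    future_target = batch[-1]  # the labels the loss is computed against
    return model_inputs, sample_weight, future_target
```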
44 changes: 33 additions & 11 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -2,6 +2,7 @@
This file contains abstract classes for deterministic and probabilistic PyTorch Lightning Modules
"""

import copy
from abc import ABC, abstractmethod
from functools import wraps
from typing import Any, Dict, Optional, Sequence, Tuple, Union
@@ -161,6 +162,13 @@ def __init__(

# define the loss function
self.criterion = loss_fn
self.train_criterion = copy.deepcopy(loss_fn)
self.val_criterion = copy.deepcopy(loss_fn)
# reduction will be set to `"none"` when calling `TFM.fit()` with sample weights;
# reset the actual criterion in method `on_fit_end()`
self.train_criterion_reduction: Optional[str] = None
self.val_criterion_reduction: Optional[str] = None

# by default models are deterministic (i.e. not probabilistic)
self.likelihood = likelihood

@@ -212,11 +220,11 @@ def forward(self, *args, **kwargs) -> Any:

def training_step(self, train_batch, batch_idx) -> torch.Tensor:
"""performs the training step"""
output = self._produce_train_output(train_batch[:-1])
target = train_batch[
-1
] # By convention target is always the last element returned by datasets
loss = self._compute_loss(output, target)
# by convention, the last two elements are sample weights and future target
output = self._produce_train_output(train_batch[:-2])
sample_weight = train_batch[-2]
target = train_batch[-1]
loss = self._compute_loss(output, target, self.train_criterion, sample_weight)
self.log(
"train_loss",
loss,
@@ -229,9 +237,11 @@ def training_step(self, train_batch, batch_idx) -> torch.Tensor:

def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
"""performs the validation step"""
output = self._produce_train_output(val_batch[:-1])
# the last two elements are sample weights and future target
output = self._produce_train_output(val_batch[:-2])
sample_weight = val_batch[-2]
target = val_batch[-1]
loss = self._compute_loss(output, target)
loss = self._compute_loss(output, target, self.val_criterion, sample_weight)
self.log(
"val_loss",
loss,
@@ -242,6 +252,15 @@ def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
self._update_metrics(output, target, self.val_metrics)
return loss

def on_fit_end(self) -> None:
# revert the loss function reduction change when sample weights were used
if self.train_criterion_reduction is not None:
self.train_criterion.reduction = self.train_criterion_reduction
self.train_criterion_reduction = None
if self.val_criterion_reduction is not None:
self.val_criterion.reduction = self.val_criterion_reduction
self.val_criterion_reduction = None

def on_train_epoch_end(self):
self._compute_metrics(self.train_metrics)

@@ -364,14 +383,17 @@ def set_predict_parameters(
self.predict_likelihood_parameters = predict_likelihood_parameters
self.pred_mc_dropout = mc_dropout

def _compute_loss(self, output, target):
def _compute_loss(self, output, target, criterion, sample_weight):
# output is of shape (batch_size, n_timesteps, n_components, n_params)
if self.likelihood:
return self.likelihood.compute_loss(output, target)
loss = self.likelihood.compute_loss(output, target, sample_weight)
else:
# If there's no likelihood, nr_params=1, and we need to squeeze out the
# last dimension of model output, for properly computing the loss.
return self.criterion(output.squeeze(dim=-1), target)
loss = criterion(output.squeeze(dim=-1), target)
if sample_weight is not None:
loss = (loss * sample_weight).mean()
return loss

def _update_metrics(self, output, target, metrics):
if not len(metrics):
@@ -511,7 +533,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint["train_sample_shape"] = self.train_sample_shape
# we must save the loss to properly restore it when resuming training
checkpoint["loss_fn"] = self.criterion
# we must save the metrics to continue outputing them when resuming training
# we must save the metrics to continue logging them when resuming training
checkpoint["torch_metrics_train"] = self.train_metrics
checkpoint["torch_metrics_val"] = self.val_metrics

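The weighted loss in `_compute_loss()` above only works if the criterion returns per-element losses, which is why fitting with sample weights switches the criterion's reduction off and `on_fit_end()` restores it afterwards. A standalone sketch of that pattern with a standard PyTorch loss (not the exact library code):

```python
import copy

import torch

loss_fn = torch.nn.MSELoss()  # default reduction is "mean"

# keep a separate training copy so the user-supplied criterion is not mutated
train_criterion = copy.deepcopy(loss_fn)
train_criterion_reduction = train_criterion.reduction  # remember "mean"
train_criterion.reduction = "none"  # per-element losses, required for weighting

output = torch.randn(4, 6, 2)        # (batch_size, n_timesteps, n_components)
target = torch.randn(4, 6, 2)
sample_weight = torch.rand(4, 6, 2)  # one weight per observation

elementwise = train_criterion(output, target)  # same shape as target
loss = (elementwise * sample_weight).mean()    # weighted scalar loss

# the equivalent of `on_fit_end()`: restore the original reduction
train_criterion.reduction = train_criterion_reduction
```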
2 changes: 2 additions & 0 deletions darts/models/forecasting/rnn_model.py
@@ -563,6 +563,7 @@ def _build_train_dataset(
target: Sequence[TimeSeries],
past_covariates: Optional[Sequence[TimeSeries]],
future_covariates: Optional[Sequence[TimeSeries]],
sample_weight: Optional[Sequence[TimeSeries]],
max_samples_per_ts: Optional[int],
) -> DualCovariatesShiftedDataset:
return DualCovariatesShiftedDataset(
@@ -572,6 +573,7 @@ def _build_train_dataset(
shift=1,
max_samples_per_ts=max_samples_per_ts,
use_static_covariates=self.uses_static_covariates,
sample_weight=sample_weight,
)

def _verify_train_dataset_type(self, train_dataset: TrainingDataset):
2 changes: 2 additions & 0 deletions darts/models/forecasting/tcn_model.py
@@ -535,6 +535,7 @@ def _build_train_dataset(
target: Sequence[TimeSeries],
past_covariates: Optional[Sequence[TimeSeries]],
future_covariates: Optional[Sequence[TimeSeries]],
sample_weight: Optional[Sequence[TimeSeries]],
max_samples_per_ts: Optional[int],
) -> PastCovariatesShiftedDataset:
return PastCovariatesShiftedDataset(
@@ -544,4 +545,5 @@ def _build_train_dataset(
shift=self.output_chunk_length + self.output_chunk_shift,
max_samples_per_ts=max_samples_per_ts,
use_static_covariates=self.uses_static_covariates,
sample_weight=sample_weight,
)
2 changes: 2 additions & 0 deletions darts/models/forecasting/tft_model.py
@@ -1159,6 +1159,7 @@ def _build_train_dataset(
target: Sequence[TimeSeries],
past_covariates: Optional[Sequence[TimeSeries]],
future_covariates: Optional[Sequence[TimeSeries]],
sample_weight: Optional[Sequence[TimeSeries]],
max_samples_per_ts: Optional[int],
) -> MixedCovariatesSequentialDataset:
raise_if(
@@ -1179,6 +1180,7 @@ def _build_train_dataset(
output_chunk_shift=self.output_chunk_shift,
max_samples_per_ts=max_samples_per_ts,
use_static_covariates=self.uses_static_covariates,
sample_weight=sample_weight,
)

def _verify_train_dataset_type(self, train_dataset: TrainingDataset):