Fix/mc dropout #2312

Merged · 7 commits · Apr 11, 2024
5 changes: 5 additions & 0 deletions .github/codecov.yml
@@ -0,0 +1,5 @@
comment: false
coverage:
status:
project: off
patch: off
3 changes: 1 addition & 2 deletions .github/workflows/develop.yml
@@ -85,10 +85,9 @@ jobs:

- name: "8. Codecov upload"
if: ${{ matrix.flavour == 'all' }}
uses: codecov/codecov-action@v2
uses: codecov/codecov-action@v4
with:
fail_ci_if_error: true
files: ./coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}

docs:
3 changes: 1 addition & 2 deletions .github/workflows/merge.yml
@@ -78,10 +78,9 @@ jobs:

- name: "7. Codecov upload"
if: ${{ matrix.flavour == 'all' }}
uses: codecov/codecov-action@v2
uses: codecov/codecov-action@v4
with:
fail_ci_if_error: true
files: ./coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}

check-examples:
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -94,6 +94,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Fixed a bug in `quantile_loss`, where the loss was computed on all samples rather than only on the predicted quantiles. [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader).
- Fixed type hint warning "Unexpected argument" when calling `historical_forecasts()` caused by the `_with_sanity_checks` decorator. The type hinting is now properly configured to expect any input arguments and return the output type of the method for which the sanity checks are performed for. [#2286](https://github.com/unit8co/darts/pull/2286) by [Dennis Bader](https://github.com/dennisbader).
- Fixed a segmentation fault that some users were facing when importing a `LightGBMModel`. [#2304](https://github.com/unit8co/darts/pull/2304) by [Dennis Bader](https://github.com/dennisbader).
- Fixed a bug when using dropout with a `TorchForecastingModel` and PyTorch Lightning versions >= 2.2.0, where the dropout was not properly activated during training. [#2312](https://github.com/unit8co/darts/pull/2312) by [Dennis Bader](https://github.com/dennisbader).

**Dependencies**

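Note on the changelog entry above: Monte Carlo dropout in Darts is requested through the `mc_dropout` argument of `predict()`, as exercised by the tests added in this PR. A minimal usage sketch (the synthetic series and the `TiDEModel` hyperparameters are illustrative only, not taken from this PR):

```python
import numpy as np

from darts import TimeSeries
from darts.models import TiDEModel

# toy univariate series, long enough for the chunk lengths below
series = TimeSeries.from_values(np.sin(np.linspace(0, 10, 100)))

# dropout=0.1 is applied through the MonteCarloDropout layers inside the network
model = TiDEModel(
    input_chunk_length=10, output_chunk_length=10, dropout=0.1, random_state=42
)
model.fit(series, epochs=1)

# mc_dropout=True keeps dropout active at inference, so the num_samples
# forward passes produce a sampled (probabilistic) forecast
pred = model.predict(n=10, mc_dropout=True, num_samples=100)
print(pred.n_samples)  # -> 100
```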
14 changes: 13 additions & 1 deletion darts/models/forecasting/pl_forecasting_module.py
@@ -197,6 +197,7 @@ def __init__(
self.pred_batch_size: Optional[int] = None
self.pred_n_jobs: Optional[int] = None
self.predict_likelihood_parameters: Optional[bool] = None
self.pred_mc_dropout: Optional[bool] = None

@property
def first_prediction_index(self) -> int:
@@ -241,6 +242,14 @@ def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
self._calculate_metrics(output, target, self.val_metrics)
return loss

def on_predict_start(self) -> None:
# optionally, activate monte carlo dropout for prediction
self.set_mc_dropout(active=self.pred_mc_dropout)

def on_predict_end(self) -> None:
# deactivate monte carlo dropout for any downstream task
self.set_mc_dropout(active=False)

def predict_step(
self, batch: Tuple, batch_idx: int, dataloader_idx: Optional[int] = None
) -> Sequence[TimeSeries]:
@@ -339,6 +348,7 @@ def set_predict_parameters(
batch_size: int,
n_jobs: int,
predict_likelihood_parameters: bool,
mc_dropout: bool,
) -> None:
"""to be set from TorchForecastingModel before calling trainer.predict() and reset at self.on_predict_end()"""
self.pred_n = n
@@ -347,6 +357,7 @@
self.pred_batch_size = batch_size
self.pred_n_jobs = n_jobs
self.predict_likelihood_parameters = predict_likelihood_parameters
self.pred_mc_dropout = mc_dropout

def _compute_loss(self, output, target):
# output is of shape (batch_size, n_timesteps, n_components, n_params)
@@ -464,8 +475,9 @@ def recurse_children(children, acc):
return recurse_children(self.children(), set())

def set_mc_dropout(self, active: bool):
# optionally, activate dropout in all MonteCarloDropout modules
for module in self._get_mc_dropout_modules():
module.mc_dropout_enabled = active
module._mc_dropout_enabled = active

@property
def supports_probabilistic_prediction(self) -> bool:
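The hooks added in this file are the core of the fix: the `mc_dropout` request stored before `trainer.predict()` is now applied only from inside Lightning's prediction loop (`on_predict_start`) and cleared again when prediction ends (`on_predict_end`). A condensed, self-contained sketch of that hook pattern (illustrative class names, not the Darts classes):

```python
import pytorch_lightning as pl
import torch.nn.functional as F
from torch import nn


class InferenceDropout(nn.Dropout):
    """Sketch of a dropout layer that can be forced on at inference time."""

    force_active: bool = False  # off by default, like `_mc_dropout_enabled`

    def forward(self, x):
        # drop units whenever training OR the inference flag is set
        return F.dropout(x, self.p, self.training or self.force_active, self.inplace)


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(4, 8), nn.ReLU(), InferenceDropout(0.5), nn.Linear(8, 1)
        )
        self.pred_mc_dropout = False  # set by the caller before trainer.predict()

    def forward(self, x):
        return self.net(x)

    def _set_mc_dropout(self, active: bool) -> None:
        for module in self.modules():
            if isinstance(module, InferenceDropout):
                module.force_active = active

    def on_predict_start(self) -> None:
        # activate MC dropout only inside the prediction loop
        self._set_mc_dropout(self.pred_mc_dropout)

    def on_predict_end(self) -> None:
        # deactivate it again for any downstream task
        self._set_mc_dropout(False)

    def predict_step(self, batch, batch_idx):
        return self(batch)
```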
32 changes: 16 additions & 16 deletions darts/models/forecasting/tcn_model.py
@@ -29,7 +29,7 @@ def __init__(
num_filters: int,
kernel_size: int,
dilation_base: int,
dropout_fn,
dropout: float,
weight_norm: bool,
nr_blocks_below: int,
num_layers: int,
@@ -46,8 +46,8 @@ def __init__(
The size of every kernel in a convolutional layer.
dilation_base
The base of the exponent that will determine the dilation on every level.
dropout_fn
The dropout function to be applied to every convolutional layer.
dropout
The dropout to be applied to every convolutional layer.
weight_norm
Boolean value indicating whether to use weight normalization.
nr_blocks_below
@@ -77,7 +77,8 @@ def __init__(

self.dilation_base = dilation_base
self.kernel_size = kernel_size
self.dropout_fn = dropout_fn
self.dropout1 = MonteCarloDropout(dropout)
self.dropout2 = MonteCarloDropout(dropout)
self.num_layers = num_layers
self.nr_blocks_below = nr_blocks_below

@@ -111,14 +112,14 @@ def forward(self, x):
self.kernel_size - 1
)
x = F.pad(x, (left_padding, 0))
x = self.dropout_fn(F.relu(self.conv1(x)))
x = self.dropout1(F.relu(self.conv1(x)))

# second step
x = F.pad(x, (left_padding, 0))
x = self.conv2(x)
if self.nr_blocks_below < self.num_layers - 1:
x = F.relu(x)
x = self.dropout_fn(x)
x = self.dropout2(x)

# add residual
if self.conv1.in_channels != self.conv2.out_channels:
@@ -195,7 +196,6 @@ def __init__(
self.target_size = target_size
self.nr_params = nr_params
self.dilation_base = dilation_base
self.dropout = MonteCarloDropout(p=dropout)

# If num_layers is not passed, compute number of layers needed for full history coverage
if num_layers is None and dilation_base > 1:
@@ -221,15 +221,15 @@ def __init__(
self.res_blocks_list = []
for i in range(num_layers):
res_block = _ResidualBlock(
num_filters,
kernel_size,
dilation_base,
self.dropout,
weight_norm,
i,
num_layers,
self.input_size,
target_size * nr_params,
num_filters=num_filters,
kernel_size=kernel_size,
dilation_base=dilation_base,
dropout=dropout,
weight_norm=weight_norm,
nr_blocks_below=i,
num_layers=num_layers,
input_size=self.input_size,
target_size=target_size * nr_params,
)
self.res_blocks_list.append(res_block)
self.res_blocks = nn.ModuleList(self.res_blocks_list)
3 changes: 2 additions & 1 deletion darts/models/forecasting/tide_model.py
@@ -14,6 +14,7 @@
io_processor,
)
from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel
from darts.utils.torch import MonteCarloDropout

MixedCovariatesTrainTensorType = Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
@@ -40,7 +41,7 @@ def __init__(
nn.Linear(input_dim, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, output_dim),
nn.Dropout(dropout),
MonteCarloDropout(dropout),
)

# linear skip connection from input to output of self.dense
4 changes: 1 addition & 3 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -1522,6 +1522,7 @@ def predict_from_dataset(
batch_size=batch_size,
n_jobs=n_jobs,
predict_likelihood_parameters=predict_likelihood_parameters,
mc_dropout=mc_dropout,
)

pred_loader = DataLoader(
@@ -1534,9 +1535,6 @@
collate_fn=self._batch_collate_fn,
)

# set mc_dropout rate
self.model.set_mc_dropout(mc_dropout)

# set up trainer. use user supplied trainer or create a new trainer from scratch
self.trainer = self._setup_trainer(
trainer=trainer, model=self.model, verbose=verbose, epochs=self.n_epochs
3 changes: 2 additions & 1 deletion darts/models/forecasting/transformer_model.py
@@ -21,6 +21,7 @@
io_processor,
)
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
from darts.utils.torch import MonteCarloDropout

logger = get_logger(__name__)

@@ -99,7 +100,7 @@ def __init__(self, d_model, dropout=0.1, max_len=500):
Tensor containing the embedded time series enhanced with positional encoding.
"""
super().__init__()
self.dropout = nn.Dropout(p=dropout)
self.dropout = MonteCarloDropout(p=dropout)

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
66 changes: 66 additions & 0 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -1,3 +1,4 @@
import copy
import itertools
import os
from typing import Any, Dict, Optional
@@ -6,6 +7,7 @@
import numpy as np
import pandas as pd
import pytest
from pytorch_lightning.callbacks import Callback

import darts.utils.timeseries_generation as tg
from darts import TimeSeries
@@ -1466,6 +1468,70 @@ def test_rin(self, model_config):
assert isinstance(model_rin_mv.model.rin, RINorm)
assert model_rin_mv.model.rin.input_dim == self.multivariate_series.n_components

@pytest.mark.parametrize("use_mc_dropout", [False, True])
def test_mc_dropout_active(self, use_mc_dropout):
"""Test that model activates dropout ."""

class CheckMCDropout(Callback):
def __init__(self, activate_mc_dropout):
self.use_mc_dropout = activate_mc_dropout

@staticmethod
def _check_dropout_activity(pl_module, expected_active: bool):
dropouts = pl_module._get_mc_dropout_modules()
assert all(
[
dropout.mc_dropout_enabled is expected_active
for dropout in dropouts
]
)

def on_train_batch_start(self, *args, **kwargs) -> None:
self._check_dropout_activity(args[1], expected_active=True)

def on_validation_batch_start(self, *args, **kwargs) -> None:
self._check_dropout_activity(args[1], expected_active=False)

def on_predict_batch_start(self, *args, **kwargs) -> None:
self._check_dropout_activity(
args[1], expected_active=self.use_mc_dropout
)

series = self.series[:20]
pl_trainer_kwargs = copy.deepcopy(tfm_kwargs)
pl_trainer_kwargs["pl_trainer_kwargs"]["callbacks"] = [
CheckMCDropout(activate_mc_dropout=use_mc_dropout)
]
model = TiDEModel(10, 10, dropout=0.1, random_state=42, **pl_trainer_kwargs)
model.fit(series, val_series=series, epochs=1)

num_samples = 1 if not use_mc_dropout else 10
preds = model.predict(
n=10, series=series, mc_dropout=use_mc_dropout, num_samples=num_samples
)
assert preds.n_samples == num_samples

@pytest.mark.parametrize("use_mc_dropout", [False, True])
def test_dropout_output(self, use_mc_dropout):
"""Test that model without dropout generates different results than one which uses near-full dropout."""
series = self.series[:20]
num_samples = 1 if not use_mc_dropout else 10

# dropouts for overfit and underfit
preds = []
for dropout in [0.0, 0.99]:
model = TiDEModel(10, 10, dropout=dropout, random_state=42, **tfm_kwargs)
model.fit(series, val_series=series, epochs=1)
preds.append(
model.predict(
n=10,
series=series,
mc_dropout=use_mc_dropout,
num_samples=num_samples,
).all_values()
)
assert not np.array_equal(preds[0], preds[1])

@pytest.mark.parametrize(
"config",
itertools.product(
25 changes: 8 additions & 17 deletions darts/utils/torch.py
@@ -37,30 +37,21 @@ class MonteCarloDropout(nn.Dropout):
often improves its performance.
"""

# We need to init it to False as some models may start by
# a validation round, in which case MC dropout is disabled.
mc_dropout_enabled: bool = False

def train(self, mode: bool = True):
# NOTE: we could use the line below if self.mc_dropout_rate represented
# a rate to be applied at inference time, and self.applied_rate the
# actual rate to be used in self.forward(). However, the original paper
# considers the same rate for training and inference; we also stick to this.

# self.applied_rate = self.p if mode else self.mc_dropout_rate

if mode: # in train mode, keep dropout as is
self.mc_dropout_enabled = True
# in eval mode, bank on the mc_dropout_enabled flag
# mc_dropout_enabled is set equal to "mc_dropout" param given to predict()
# mc dropout is deactivated at init; see `MonteCarloDropout.mc_dropout_enabled` for more info
_mc_dropout_enabled = False

def forward(self, input: Tensor) -> Tensor:
# NOTE: we could use the following line in case a different rate
# is used for inference:
# return F.dropout(input, self.applied_rate, True, self.inplace)

return F.dropout(input, self.p, self.mc_dropout_enabled, self.inplace)

@property
def mc_dropout_enabled(self) -> bool:
# mc dropout is only activated on `PLForecastingModule.on_predict_start()`
# otherwise, it is activated based on the `model.training` flag.
return self._mc_dropout_enabled or self.training


def _is_method(func: Callable[..., Any]) -> bool:
"""Check if the specified function is a method.
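For intuition on the `mc_dropout_enabled` property above: `F.dropout` applies dropout whenever its `training` argument is true, regardless of the module's train/eval mode, which is what makes Monte Carlo sampling at inference possible. A quick standalone check (plain PyTorch, not Darts code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.ones(1, 5)

# training=False: dropout is a no-op, so every pass returns the same tensor
evals = [F.dropout(x, p=0.5, training=False) for _ in range(3)]
assert all(torch.equal(evals[0], t) for t in evals)

# training=True (what the property returns once MC dropout is enabled): a new
# mask is sampled on every call, so repeated passes over the same input differ
samples = [F.dropout(x, p=0.5, training=True) for _ in range(3)]
print(samples[0])
print(samples[1])  # almost surely a different mask
```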