Feat/mc dropout (#1013)
* Added Monte Carlo Dropout support

* Added MC Dropout to models that can support it

* resolve issue

* extend user guide with MC Dropout

* Solve an issue in TCN tests

* change paper reference

* correct typo
hrzn committed Jun 21, 2022
1 parent bb94236 commit a1328fa
Showing 12 changed files with 149 additions and 18 deletions.
7 changes: 5 additions & 2 deletions darts/models/components/feed_forward.py
@@ -55,6 +55,8 @@
import torch
from torch import nn as nn

from darts.utils.torch import MonteCarloDropout


class FeedForward(nn.Module):
"""
@@ -78,7 +80,8 @@ def __init__(
"""
* `d_model` is the number of features in a token embedding
* `d_ff` is the number of features in the hidden layer of the FFN
- * `dropout` is dropout probability for the hidden layer
+ * `dropout` is dropout probability for the hidden layer,
+   compatible with Monte Carlo dropout at inference time
* `is_gated` specifies whether the hidden layer is gated
* `bias1` specifies whether the first fully connected layer should have a learnable bias
* `bias2` specifies whether the second fully connected layer should have a learnable bias
@@ -90,7 +93,7 @@ def __init__(
# Layer one parameterized by weight $W_1$ and bias $b_1$
self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
# Hidden layer dropout
- self.dropout = nn.Dropout(dropout)
+ self.dropout = MonteCarloDropout(dropout)
# Activation function $f$
self.activation = activation
# Whether there is a gate
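A hedged sketch of this component in isolation (keyword arguments inferred from the docstring above; shapes are illustrative):

import torch

from darts.models.components.feed_forward import FeedForward

# The hidden-layer dropout inside FeedForward is now a MonteCarloDropout,
# so it can be kept active at inference time when MC dropout is enabled.
ffn = FeedForward(d_model=16, d_ff=64, dropout=0.1)
out = ffn(torch.randn(2, 10, 16))  # (batch, sequence, d_model) -> same shape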
7 changes: 5 additions & 2 deletions darts/models/forecasting/nbeats.py
@@ -13,6 +13,7 @@
from darts.logging import get_logger, raise_if_not, raise_log
from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
from darts.utils.torch import MonteCarloDropout

logger = get_logger(__name__)

@@ -168,7 +169,7 @@ def __init__(
)

if self.dropout > 0:
- self.linear_layer_stack_list.append(nn.Dropout(p=self.dropout))
+ self.linear_layer_stack_list.append(MonteCarloDropout(p=self.dropout))

self.fc_stack = nn.ModuleList(self.linear_layer_stack_list)

@@ -586,7 +587,9 @@ def __init__(
The degree of the polynomial used as waveform generator in trend stacks. Only used if
`generic_architecture` is set to `False`.
dropout
- The dropout probability to be used in the fully connected layers (default=0.0).
+ The dropout probability to be used in fully connected layers. This is compatible with Monte Carlo dropout
+ at inference time for model uncertainty estimation (enabled with ``mc_dropout=True`` at
+ prediction time).
activation
The activation function of encoder/decoder intermediate layer (default='ReLU').
Supported activations: ['ReLU','RReLU', 'PReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid']
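For context, a minimal sketch of the training side (assuming the darts API at this commit; the sine series is a stand-in for real data):

from darts.models import NBEATSModel
from darts.utils.timeseries_generation import sine_timeseries

# With dropout > 0, the fully connected stacks now contain
# MonteCarloDropout layers instead of plain nn.Dropout.
series = sine_timeseries(length=200)
model = NBEATSModel(
    input_chunk_length=24,
    output_chunk_length=12,
    dropout=0.1,
    n_epochs=1,  # kept tiny, purely illustrative
)
model.fit(series)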
7 changes: 5 additions & 2 deletions darts/models/forecasting/nhits.py
@@ -13,6 +13,7 @@
from darts.logging import get_logger, raise_if_not
from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
from darts.utils.torch import MonteCarloDropout

logger = get_logger(__name__)

@@ -153,7 +154,7 @@ def __init__(
layers.append(nn.BatchNorm1d(num_features=self.layer_widths[i + 1]))

if self.dropout > 0:
- layers.append(nn.Dropout(p=self.dropout))
+ layers.append(MonteCarloDropout(p=self.dropout))

self.layers = nn.Sequential(*layers)

@@ -520,7 +521,9 @@ def __init__(
downsampling factors before interpolation, for each block in each stack.
If left to ``None``, some default values will be used based on ``output_chunk_length``.
dropout
- Fraction of neurons affected by Dropout (default=0.1).
+ The dropout probability to be used in fully connected layers. This is compatible with Monte Carlo dropout
+ at inference time for model uncertainty estimation (enabled with ``mc_dropout=True`` at
+ prediction time).
activation
The activation function of encoder/decoder intermediate layer (default='ReLU').
Supported activations: ['ReLU','RReLU', 'PReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid']
17 changes: 16 additions & 1 deletion darts/models/forecasting/pl_forecasting_module.py
@@ -15,6 +15,7 @@
from darts.timeseries import TimeSeries
from darts.utils.likelihood_models import Likelihood
from darts.utils.timeseries_generation import _build_forecast_series
from darts.utils.torch import MonteCarloDropout

logger = get_logger(__name__)

@@ -342,8 +343,22 @@ def _sample_tiling(input_data_tuple, batch_sample_size):
tiled_input_data.append(None)
return tuple(tiled_input_data)

def _get_mc_dropout_modules(self) -> set:
def recurse_children(children, acc):
for module in children:
if isinstance(module, MonteCarloDropout):
acc.add(module)
acc = recurse_children(module.children(), acc)
return acc

return recurse_children(self.children(), set())

def set_mc_dropout(self, active: bool):
for module in self._get_mc_dropout_modules():
module.mc_dropout_enabled = active

def _is_probabilistic(self) -> bool:
- return self.likelihood is not None
+ return self.likelihood is not None or len(self._get_mc_dropout_modules()) > 0

def _produce_predict_output(self, x: Tuple):
if self.likelihood:
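A minimal, self-contained sketch of what ``set_mc_dropout`` does: it recursively collects every ``MonteCarloDropout`` submodule and flips its flag. Here ``nn.Module.modules()`` stands in for the recursion above, and the toy network is illustrative:

from torch import nn

from darts.utils.torch import MonteCarloDropout

net = nn.Sequential(
    nn.Linear(8, 8),
    MonteCarloDropout(p=0.2),
    nn.Linear(8, 1),
)
net.eval()  # a plain nn.Dropout would now be inactive

# Equivalent effect to PLForecastingModule.set_mc_dropout(True):
for module in net.modules():
    if isinstance(module, MonteCarloDropout):
        module.mc_dropout_enabled = True  # stays active in eval mode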
7 changes: 5 additions & 2 deletions darts/models/forecasting/tcn_model.py
@@ -15,6 +15,7 @@
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
from darts.timeseries import TimeSeries
from darts.utils.data import PastCovariatesShiftedDataset
from darts.utils.torch import MonteCarloDropout

logger = get_logger(__name__)

@@ -191,7 +192,7 @@ def __init__(
self.target_size = target_size
self.nr_params = nr_params
self.dilation_base = dilation_base
- self.dropout = nn.Dropout(p=dropout)
+ self.dropout = MonteCarloDropout(p=dropout)

# If num_layers is not passed, compute number of layers needed for full history coverage
if num_layers is None and dilation_base > 1:
@@ -288,7 +289,9 @@ def __init__(
num_layers
The number of convolutional layers.
dropout
- The dropout rate for every convolutional layer.
+ The dropout rate for every convolutional layer. This is compatible with Monte Carlo dropout
+ at inference time for model uncertainty estimation (enabled with ``mc_dropout=True`` at
+ prediction time).
**kwargs
Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
Darts' :class:`TorchForecastingModel`.
4 changes: 3 additions & 1 deletion darts/models/forecasting/tft_model.py
@@ -617,7 +617,9 @@ def __init__(
or the TFT original FeedForward Network.
["GatedResidualNetwork"]
dropout : float
- Fraction of neurons affected by Dropout.
+ Fraction of neurons affected by dropout. This is compatible with Monte Carlo dropout
+ at inference time for model uncertainty estimation (enabled with ``mc_dropout=True`` at
+ prediction time).
hidden_continuous_size : int
Default for hidden size for processing continuous variables
add_relative_index : bool
7 changes: 4 additions & 3 deletions darts/models/forecasting/tft_submodels.py
@@ -27,6 +27,7 @@
import torch.nn.functional as F

from darts.logging import get_logger
from darts.utils.torch import MonteCarloDropout

logger = get_logger(__name__)

@@ -188,7 +189,7 @@ def __init__(self, input_size: int, hidden_size: int = None, dropout: float = None):
super().__init__()

if dropout is not None:
- self.dropout = nn.Dropout(dropout)
+ self.dropout = MonteCarloDropout(dropout)
else:
self.dropout = dropout
self.hidden_size = hidden_size or input_size
@@ -500,7 +501,7 @@ class _ScaledDotProductAttention(nn.Module):
def __init__(self, dropout: float = None, scale: bool = True):
super().__init__()
if dropout is not None:
- self.dropout = nn.Dropout(p=dropout)
+ self.dropout = MonteCarloDropout(p=dropout)
else:
self.dropout = dropout
self.softmax = nn.Softmax(dim=2)
@@ -530,7 +531,7 @@ def __init__(self, n_head: int, d_model: int, dropout: float = 0.0):
self.n_head = n_head
self.d_model = d_model
self.d_k = self.d_q = self.d_v = d_model // n_head
- self.dropout = nn.Dropout(p=dropout)
+ self.dropout = MonteCarloDropout(p=dropout)

self.v_layer = nn.Linear(self.d_model, self.d_v)
self.q_layers = nn.ModuleList(
15 changes: 14 additions & 1 deletion darts/models/forecasting/torch_forecasting_model.py
@@ -953,6 +953,7 @@ def predict(
roll_size: Optional[int] = None,
num_samples: int = 1,
num_loader_workers: int = 0,
mc_dropout: bool = False,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
"""Predict the ``n`` time step following the end of the training series, or of the specified ``series``.
@@ -1015,6 +1016,9 @@
for the inference/prediction dataset loaders (if any).
A larger number of workers can sometimes increase performance, but can also incur extra overheads
and increase memory usage, as more batches are loaded in parallel.
mc_dropout
Optionally, enable Monte Carlo dropout for predictions using neural network based models.
This allows Bayesian approximation by specifying an implicit prior over learned models.
Returns
-------
@@ -1077,6 +1081,8 @@
n_jobs=n_jobs,
roll_size=roll_size,
num_samples=num_samples,
num_loader_workers=num_loader_workers,
mc_dropout=mc_dropout,
)

return predictions[0] if called_with_single_series else predictions
@@ -1093,6 +1099,7 @@ def predict_from_dataset(
roll_size: Optional[int] = None,
num_samples: int = 1,
num_loader_workers: int = 0,
mc_dropout: bool = False,
) -> Sequence[TimeSeries]:

"""
@@ -1136,6 +1143,9 @@
for the inference/prediction dataset loaders (if any).
A larger number of workers can sometimes increase performance, but can also incur extra overheads
and increase memory usage, as more batches are loaded in parallel.
mc_dropout
Optionally, enable Monte Carlo dropout for predictions using neural network based models.
This allows Bayesian approximation by specifying an implicit prior over learned models.
Returns
-------
@@ -1184,6 +1194,9 @@
collate_fn=self._batch_collate_fn,
)

# Enable or disable MC Dropout for this prediction (boolean flag, not a rate)
self.model.set_mc_dropout(mc_dropout)

# setup trainer. will only be re-instantiated if both `trainer` and `self.trainer` are `None`
trainer = trainer if trainer is not None else self.trainer
self._setup_trainer(trainer=trainer, verbose=verbose, epochs=self.n_epochs)
@@ -1428,7 +1441,7 @@ def _is_probabilistic(self) -> bool:
return (
self.model._is_probabilistic()
if self.model_created
- else self.likelihood is not None
+ else True  # all torch models can be probabilistic (via Dropout)
)


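Continuing the training sketch from the N-BEATS section above, the prediction side would look as follows (the sample count and quantile levels are arbitrary choices):

# Each sample is one stochastic forward pass with dropout kept active,
# so the returned TimeSeries is probabilistic.
pred = model.predict(n=12, num_samples=100, mc_dropout=True)
low = pred.quantile_timeseries(0.05)   # 5% quantile series
high = pred.quantile_timeseries(0.95)  # 95% quantile series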
8 changes: 8 additions & 0 deletions darts/tests/models/forecasting/test_TCN.py
@@ -103,6 +103,10 @@ def test_coverage(self):
)

model.model.eval()

# also disable MC Dropout:
model.model.set_mc_dropout(False)

input_tensor = torch.zeros(
[1, input_chunk_length, 1], dtype=torch.float64
)
@@ -146,6 +150,10 @@ def test_coverage(self):
)

model_2.model.eval()

# also disable MC Dropout:
model_2.model.set_mc_dropout(False)

input_tensor = torch.zeros(
[1, input_chunk_length, 1], dtype=torch.float64
)
42 changes: 42 additions & 0 deletions darts/utils/torch.py
@@ -7,8 +7,11 @@
from inspect import signature
from typing import Any, Callable, TypeVar

import torch.nn as nn
import torch.nn.functional as F
from numpy.random import randint
from sklearn.utils import check_random_state
from torch import Tensor
from torch.random import fork_rng, manual_seed

from darts.logging import get_logger, raise_if_not
@@ -20,6 +23,45 @@
MAX_NUMPY_SEED_VALUE = (1 << 31) - 1


class MonteCarloDropout(nn.Dropout):
"""
Defines Monte Carlo dropout Module as defined
in the paper https://arxiv.org/pdf/1506.02142.pdf.
In summary, This technique uses the regular dropout
which can be interpreted as a Bayesian approximation of
a well-known probabilistic model: the Gaussian process.
We can treat the many different networks
(with different neurons dropped out) as Monte Carlo samples
from the space of all available models. This provides mathematical
grounds to reason about the model’s uncertainty and, as it turns out,
often improves its performance.
"""

# We need to init it to False, as some models may start with
# a validation round, in which case MC dropout is disabled.
mc_dropout_enabled: bool = False

def train(self, mode: bool = True):
# NOTE: we could use the line below if self.mc_dropout_rate represented
# a rate to be applied at inference time, and self.applied_rate the
# actual rate to be used in self.forward(). However, the original paper
# considers the same rate for training and inference; we also stick to this.

# self.applied_rate = self.p if mode else self.mc_dropout_rate

if mode: # in train mode, keep dropout as is
self.mc_dropout_enabled = True
# in eval mode, rely on the mc_dropout_enabled flag, which is
# set to the "mc_dropout" param given to predict()

def forward(self, input: Tensor) -> Tensor:
# NOTE: we could use the following line in case a different rate
# is used for inference:
# return F.dropout(input, self.applied_rate, True, self.inplace)

return F.dropout(input, self.p, self.mc_dropout_enabled, self.inplace)


def _is_method(func: Callable[..., Any]) -> bool:
"""Check if the specified function is a method.
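A small sketch of the flag semantics above (``train()`` sets the flag, ``eval()`` deliberately leaves it untouched, and ``set_mc_dropout`` is what clears it per module):

import torch

from darts.utils.torch import MonteCarloDropout

mc = MonteCarloDropout(p=0.5)
x = torch.ones(4)

mc.train()    # sets mc_dropout_enabled = True
mc.eval()     # train(False) does not touch the flag
print(mc(x))  # still stochastic: dropout remains active in eval mode

mc.mc_dropout_enabled = False  # what set_mc_dropout(False) does per module
print(mc(x))  # identity: tensor([1., 1., 1., 1.])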
