Fix/mc dropout #2312

Merged · 7 commits · Apr 11, 2024
Changes from 2 commits

11 changes: 11 additions & 0 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -197,6 +197,7 @@ def __init__(
self.pred_batch_size: Optional[int] = None
self.pred_n_jobs: Optional[int] = None
self.predict_likelihood_parameters: Optional[bool] = None
self.pred_mc_dropout: Optional[bool] = None

@property
def first_prediction_index(self) -> int:
@@ -241,6 +242,14 @@ def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
self._calculate_metrics(output, target, self.val_metrics)
return loss

def on_predict_start(self) -> None:
# optionally, activate monte carlo dropout for prediction
self.set_mc_dropout(active=self.pred_mc_dropout)

def on_predict_end(self) -> None:
# deactivate monte carlo dropout for any downstream task
self.set_mc_dropout(active=False)

def predict_step(
self, batch: Tuple, batch_idx: int, dataloader_idx: Optional[int] = None
) -> Sequence[TimeSeries]:
@@ -339,6 +348,7 @@ def set_predict_parameters(
batch_size: int,
n_jobs: int,
predict_likelihood_parameters: bool,
mc_dropout: bool,
) -> None:
"""to be set from TorchForecastingModel before calling trainer.predict() and reset at self.on_predict_end()"""
self.pred_n = n
@@ -347,6 +357,7 @@
self.pred_batch_size = batch_size
self.pred_n_jobs = n_jobs
self.predict_likelihood_parameters = predict_likelihood_parameters
self.pred_mc_dropout = mc_dropout

def _compute_loss(self, output, target):
# output is of shape (batch_size, n_timesteps, n_components, n_params)
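
The `on_predict_start()` / `on_predict_end()` hooks above call a `set_mc_dropout()` helper that is not part of this diff. Below is a minimal sketch of what it plausibly does, assuming it simply walks the module tree and flips the `mc_dropout_enabled` flag on every `MonteCarloDropout` layer (the method name and `active` parameter come from the calls above; the body is an assumption):

```python
from darts.utils.torch import MonteCarloDropout

def set_mc_dropout(self, active: bool) -> None:
    # assumed behavior: toggle the flag on every MonteCarloDropout submodule
    for module in self.modules():
        if isinstance(module, MonteCarloDropout):
            module.mc_dropout_enabled = active
```

Routing the flag through `set_predict_parameters()` and these two hooks means dropout sampling is switched on for exactly the duration of `trainer.predict()` and reliably switched off again for any downstream task.
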
32 changes: 16 additions & 16 deletions darts/models/forecasting/tcn_model.py
@@ -29,7 +29,7 @@ def __init__(
num_filters: int,
kernel_size: int,
dilation_base: int,
dropout_fn,
dropout: float,
weight_norm: bool,
nr_blocks_below: int,
num_layers: int,
@@ -46,8 +46,8 @@
The size of every kernel in a convolutional layer.
dilation_base
The base of the exponent that will determine the dilation on every level.
dropout_fn
The dropout function to be applied to every convolutional layer.
dropout
The dropout rate to be applied to every convolutional layer.
weight_norm
Boolean value indicating whether to use weight normalization.
nr_blocks_below
@@ -77,7 +77,8 @@ def __init__(

self.dilation_base = dilation_base
self.kernel_size = kernel_size
self.dropout_fn = dropout_fn
self.dropout1 = MonteCarloDropout(dropout)
self.dropout2 = MonteCarloDropout(dropout)
self.num_layers = num_layers
self.nr_blocks_below = nr_blocks_below

@@ -111,14 +112,14 @@ def forward(self, x):
self.kernel_size - 1
)
x = F.pad(x, (left_padding, 0))
x = self.dropout_fn(F.relu(self.conv1(x)))
x = self.dropout1(F.relu(self.conv1(x)))

# second step
x = F.pad(x, (left_padding, 0))
x = self.conv2(x)
if self.nr_blocks_below < self.num_layers - 1:
x = F.relu(x)
x = self.dropout_fn(x)
x = self.dropout2(x)

# add residual
if self.conv1.in_channels != self.conv2.out_channels:
@@ -195,7 +196,6 @@ def __init__(
self.target_size = target_size
self.nr_params = nr_params
self.dilation_base = dilation_base
self.dropout = MonteCarloDropout(p=dropout)

# If num_layers is not passed, compute number of layers needed for full history coverage
if num_layers is None and dilation_base > 1:
@@ -221,15 +221,15 @@
self.res_blocks_list = []
for i in range(num_layers):
res_block = _ResidualBlock(
num_filters,
kernel_size,
dilation_base,
self.dropout,
weight_norm,
i,
num_layers,
self.input_size,
target_size * nr_params,
num_filters=num_filters,
kernel_size=kernel_size,
dilation_base=dilation_base,
dropout=dropout,
weight_norm=weight_norm,
nr_blocks_below=i,
num_layers=num_layers,
input_size=self.input_size,
target_size=target_size * nr_params,
)
self.res_blocks_list.append(res_block)
self.res_blocks = nn.ModuleList(self.res_blocks_list)
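
With each `_ResidualBlock` now owning its own `MonteCarloDropout` layers, a deterministic TCN can sample its epistemic uncertainty at prediction time. A hedged usage sketch — it assumes the `mc_dropout` flag of `predict()` wired up in this PR, and uses `AirPassengersDataset` purely for illustration:

```python
from darts.datasets import AirPassengersDataset
from darts.models import TCNModel

series = AirPassengersDataset().load().astype("float32")

# dropout must be > 0, otherwise MC dropout sampling has no effect
model = TCNModel(
    input_chunk_length=24,
    output_chunk_length=12,
    dropout=0.2,
    n_epochs=5,
)
model.fit(series)

# each of the 100 samples is a forward pass with dropout kept active,
# so the spread across samples reflects model (epistemic) uncertainty
pred = model.predict(n=12, num_samples=100, mc_dropout=True)
```
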
3 changes: 2 additions & 1 deletion darts/models/forecasting/tide_model.py
@@ -14,6 +14,7 @@
io_processor,
)
from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel
from darts.utils.torch import MonteCarloDropout

MixedCovariatesTrainTensorType = Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
@@ -40,7 +41,7 @@ def __init__(
nn.Linear(input_dim, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, output_dim),
nn.Dropout(dropout),
MonteCarloDropout(dropout),
)

# linear skip connection from input to output of self.dense
4 changes: 1 addition & 3 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -1522,6 +1522,7 @@ def predict_from_dataset(
batch_size=batch_size,
n_jobs=n_jobs,
predict_likelihood_parameters=predict_likelihood_parameters,
mc_dropout=mc_dropout,
)

pred_loader = DataLoader(
@@ -1534,9 +1535,6 @@
collate_fn=self._batch_collate_fn,
)

# set mc_dropout rate
self.model.set_mc_dropout(mc_dropout)

# set up trainer. use user supplied trainer or create a new trainer from scratch
self.trainer = self._setup_trainer(
trainer=trainer, model=self.model, verbose=verbose, epochs=self.n_epochs
3 changes: 2 additions & 1 deletion darts/models/forecasting/transformer_model.py
@@ -21,6 +21,7 @@
io_processor,
)
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
from darts.utils.torch import MonteCarloDropout

logger = get_logger(__name__)

@@ -99,7 +100,7 @@ def __init__(self, d_model, dropout=0.1, max_len=500):
Tensor containing the embedded time series enhanced with positional encoding.
"""
super().__init__()
self.dropout = nn.Dropout(p=dropout)
self.dropout = MonteCarloDropout(p=dropout)

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
24 changes: 6 additions & 18 deletions darts/utils/torch.py
@@ -37,29 +37,17 @@ class MonteCarloDropout(nn.Dropout):
often improves its performance.
"""

# We need to init it to False as some models may start by
# a validation round, in which case MC dropout is disabled.
mc_dropout_enabled: bool = False

def train(self, mode: bool = True):
# NOTE: we could use the line below if self.mc_dropout_rate represented
# a rate to be applied at inference time, and self.applied_rate the
# actual rate to be used in self.forward(). However, the original paper
# considers the same rate for training and inference; we also stick to this.

# self.applied_rate = self.p if mode else self.mc_dropout_rate

if mode: # in train mode, keep dropout as is
self.mc_dropout_enabled = True
# in eval mode, bank on the mc_dropout_enabled flag
# mc_dropout_enabled is set equal to "mc_dropout" param given to predict()
# mc dropout is only activated in `PLForecastingModule.on_predict_start()`;
# otherwise, dropout is applied based on the `model.training` flag.
mc_dropout_enabled = False

def forward(self, input: Tensor) -> Tensor:
# NOTE: we could use the following line in case a different rate
# is used for inference:
# return F.dropout(input, self.applied_rate, True, self.inplace)

return F.dropout(input, self.p, self.mc_dropout_enabled, self.inplace)
return F.dropout(
input, self.p, self.mc_dropout_enabled or self.training, self.inplace
)
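
To make the new control flow concrete, here is a small standalone check of the three states the rewritten `forward()` distinguishes; setting `mc_dropout_enabled` by hand stands in for what `on_predict_start()` does via `set_mc_dropout()`:

```python
import torch

from darts.utils.torch import MonteCarloDropout

layer = MonteCarloDropout(p=0.5)
x = torch.ones(4, 4)

layer.eval()
assert torch.equal(layer(x), x)  # eval mode, flag off: dropout is a no-op

layer.mc_dropout_enabled = True  # what on_predict_start() effectively sets
assert not torch.equal(layer(x), x)  # eval mode, flag on: stochastic output

layer.train()
layer.mc_dropout_enabled = False
assert not torch.equal(layer(x), x)  # train mode: follows self.training
```
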


def _is_method(func: Callable[..., Any]) -> bool: