Skip to content

Commit

Permalink
add torch.utils.data.DataLoader related parameters to fit() and `…
Browse files Browse the repository at this point in the history
…predict()` of `TorchForecastingModel`
  • Loading branch information
Bohdan Bilonoh authored and BohdanBilonoh committed Apr 8, 2024
1 parent 0d5c722 commit 23af346
Showing 1 changed file with 129 additions and 5 deletions.
134 changes: 129 additions & 5 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,10 @@ def fit(
verbose: Optional[bool] = None,
epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
pin_memory: bool = False,
num_loader_workers: int = 0,
prefetch_factor: Optional[int] = None,
persistent_workers: bool = False,
) -> "TorchForecastingModel":
"""Fit/train the model on one or multiple series.
Expand Down Expand Up @@ -714,11 +717,25 @@ def fit(
large number of training samples. This parameter upper-bounds the number of training samples per time
series (taking only the most recent samples in each series). Leaving to None does not apply any
upper bound.
pin_memory
Whether to use ``pin_memory`` in PyTorch ``DataLoader`` instances. This can speed up training on GPUs.
If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.
num_loader_workers
Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
both for the training and validation loaders (if any).
A larger number of workers can sometimes increase performance, but can also incur extra overheads
and increase memory usage, as more batches are loaded in parallel.
prefetch_factor
Optionally, an integer specifying the ``prefetch_factor`` to use in PyTorch ``DataLoader`` instances,
both for the training and validation loaders (if any).
Number of batches loaded in advance by each worker. 2 means there will be a total of
2 * num_loader_workers batches prefetched across all workers. If left to ``None``, the PyTorch default
applies: no prefetching when ``num_loader_workers=0``, and 2 when ``num_loader_workers > 0``.
persistent_workers
Optionally, a boolean specifying whether to use persistent workers in PyTorch ``DataLoader`` instances,
both for the training and validation loaders (if any).
If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
This keeps the workers' ``Dataset`` instances alive. (default: False)
Returns
-------
Expand All @@ -740,7 +757,10 @@ def fit(
verbose=verbose,
epochs=epochs,
max_samples_per_ts=max_samples_per_ts,
pin_memory=pin_memory,
num_loader_workers=num_loader_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
# call super fit only if user is actually fitting the model
super().fit(
Expand All @@ -762,7 +782,10 @@ def _setup_for_fit_from_dataset(
verbose: Optional[bool] = None,
epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
pin_memory: bool = False,
num_loader_workers: int = 0,
prefetch_factor: Optional[int] = None,
persistent_workers: bool = False,
) -> Tuple[
Tuple[
Sequence[TimeSeries],
Expand All @@ -775,7 +798,10 @@ def _setup_for_fit_from_dataset(
Optional[pl.Trainer],
Optional[bool],
int,
bool,
int,
Optional[int],
bool,
],
]:
"""This method acts on `TimeSeries` inputs. It performs sanity checks, and sets up / returns the datasets and
Expand Down Expand Up @@ -897,7 +923,10 @@ def _setup_for_fit_from_dataset(
trainer,
verbose,
epochs,
pin_memory,
num_loader_workers,
prefetch_factor,
persistent_workers,
)
return series_input, fit_from_ds_params

Expand All @@ -909,7 +938,10 @@ def fit_from_dataset(
trainer: Optional[pl.Trainer] = None,
verbose: Optional[bool] = None,
epochs: int = 0,
pin_memory: bool = False,
num_loader_workers: int = 0,
prefetch_factor: Optional[int] = None,
persistent_workers: bool = False,
) -> "TorchForecastingModel":
"""
Train the model with a specific :class:`darts.utils.data.TrainingDataset` instance.
Expand Down Expand Up @@ -942,11 +974,25 @@ def fit_from_dataset(
epochs
If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs``
was provided to the model constructor.
pin_memory
Whether to use ``pin_memory`` in PyTorch ``DataLoader`` instances. This can speed up training on GPUs.
If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.
num_loader_workers
Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
both for the training and validation loaders (if any).
A larger number of workers can sometimes increase performance, but can also incur extra overheads
and increase memory usage, as more batches are loaded in parallel.
prefetch_factor
Optionally, an integer specifying the ``prefetch_factor`` to use in PyTorch ``DataLoader`` instances,
both for the training and validation loaders (if any).
Number of batches loaded in advance by each worker. 2 means there will be a total of
2 * num_loader_workers batches prefetched across all workers. If left to ``None``, the PyTorch default
applies: no prefetching when ``num_loader_workers=0``, and 2 when ``num_loader_workers > 0``.
persistent_workers
Optionally, a boolean specifying whether to use persistent workers in PyTorch ``DataLoader`` instances,
both for the training and validation loaders (if any).
If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
This keeps the workers' ``Dataset`` instances alive. (default: False)
Returns
-------
Expand All @@ -960,7 +1006,10 @@ def fit_from_dataset(
trainer=trainer,
verbose=verbose,
epochs=epochs,
pin_memory=pin_memory,
num_loader_workers=num_loader_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
)
return self
Expand All @@ -972,7 +1021,10 @@ def _setup_for_train(
trainer: Optional[pl.Trainer] = None,
verbose: Optional[bool] = None,
epochs: int = 0,
pin_memory: bool = False,
num_loader_workers: int = 0,
prefetch_factor: Optional[int] = None,
persistent_workers: bool = False,
) -> Tuple[pl.Trainer, PLForecastingModule, DataLoader, Optional[DataLoader]]:
"""This method acts on `TrainingDataset` inputs. It performs sanity checks, and sets up / returns the trainer,
model, and dataset loaders required for training the model with `_train()`.
Expand Down Expand Up @@ -1039,7 +1091,9 @@ def _setup_for_train(
batch_size=self.batch_size,
shuffle=True,
num_workers=num_loader_workers,
pin_memory=True,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=pin_memory,
drop_last=False,
collate_fn=self._batch_collate_fn,
)
Expand All @@ -1053,7 +1107,9 @@ def _setup_for_train(
batch_size=self.batch_size,
shuffle=False,
num_workers=num_loader_workers,
pin_memory=True,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=pin_memory,
drop_last=False,
collate_fn=self._batch_collate_fn,
)
Expand Down Expand Up @@ -1120,7 +1176,10 @@ def lr_find(
verbose: Optional[bool] = None,
epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
pin_memory: bool = False,
num_loader_workers: int = 0,
prefetch_factor: Optional[int] = None,
persistent_workers: bool = False,
min_lr: float = 1e-08,
max_lr: float = 1,
num_training: int = 100,
Expand Down Expand Up @@ -1192,11 +1251,25 @@ def lr_find(
large number of training samples. This parameter upper-bounds the number of training samples per time
series (taking only the most recent samples in each series). Leaving to None does not apply any
upper bound.
pin_memory
Whether to use ``pin_memory`` in PyTorch ``DataLoader`` instances. This can speed up training on GPUs.
If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.
num_loader_workers
Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
both for the training and validation loaders (if any).
A larger number of workers can sometimes increase performance, but can also incur extra overheads
and increase memory usage, as more batches are loaded in parallel.
prefetch_factor
Optionally, an integer specifying the ``prefetch_factor`` to use in PyTorch ``DataLoader`` instances,
both for the training and validation loaders (if any).
Number of batches loaded in advance by each worker. 2 means there will be a total of
2 * num_loader_workers batches prefetched across all workers. If left to ``None``, the PyTorch default
applies: no prefetching when ``num_loader_workers=0``, and 2 when ``num_loader_workers > 0``.
persistent_workers
Optionally, a boolean specifying whether to use persistent workers in PyTorch ``DataLoader`` instances,
both for the training and validation loaders (if any).
If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
This keeps the workers' ``Dataset`` instances alive. (default: False)
min_lr
minimum learning rate to investigate
max_lr
Expand Down Expand Up @@ -1228,7 +1301,10 @@ def lr_find(
verbose=verbose,
epochs=epochs,
max_samples_per_ts=max_samples_per_ts,
pin_memory=pin_memory,
num_loader_workers=num_loader_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
trainer, model, train_loader, val_loader = self._setup_for_train(*params)
return Tuner(trainer).lr_find(
Expand Down Expand Up @@ -1257,7 +1333,10 @@ def predict(
n_jobs: int = 1,
roll_size: Optional[int] = None,
num_samples: int = 1,
pin_memory: bool = False,
num_loader_workers: int = 0,
prefetch_factor: Optional[int] = None,
persistent_workers: bool = False,
mc_dropout: bool = False,
predict_likelihood_parameters: bool = False,
show_warnings: bool = True,
Expand Down Expand Up @@ -1319,11 +1398,25 @@ def predict(
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1
for deterministic models.
pin_memory
Whether to use ``pin_memory`` in PyTorch ``DataLoader`` instances. This can speed up inference on GPUs.
If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.
num_loader_workers
Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
for the inference/prediction dataset loaders (if any).
A larger number of workers can sometimes increase performance, but can also incur extra overheads
and increase memory usage, as more batches are loaded in parallel.
prefetch_factor
Optionally, an integer specifying the ``prefetch_factor`` to use in PyTorch ``DataLoader`` instances,
for the inference/prediction dataset loaders (if any).
Number of batches loaded in advance by each worker. 2 means there will be a total of
2 * num_loader_workers batches prefetched across all workers. If left to ``None``, the PyTorch default
applies: no prefetching when ``num_loader_workers=0``, and 2 when ``num_loader_workers > 0``.
persistent_workers
Optionally, a boolean specifying whether to use persistent workers in PyTorch ``DataLoader`` instances,
for the inference/prediction dataset loaders (if any).
If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
This keeps the workers' ``Dataset`` instances alive. (default: False)
mc_dropout
Optionally, enable monte carlo dropout for predictions using neural network based models.
This allows bayesian approximation by specifying an implicit prior over learned models.
Expand Down Expand Up @@ -1407,7 +1500,10 @@ def predict(
n_jobs=n_jobs,
roll_size=roll_size,
num_samples=num_samples,
pin_memory=pin_memory,
num_loader_workers=num_loader_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
mc_dropout=mc_dropout,
predict_likelihood_parameters=predict_likelihood_parameters,
)
Expand All @@ -1425,7 +1521,10 @@ def predict_from_dataset(
n_jobs: int = 1,
roll_size: Optional[int] = None,
num_samples: int = 1,
pin_memory: bool = False,
num_loader_workers: int = 0,
prefetch_factor: Optional[int] = None,
persistent_workers: bool = False,
mc_dropout: bool = False,
predict_likelihood_parameters: bool = False,
) -> Sequence[TimeSeries]:
Expand Down Expand Up @@ -1466,11 +1565,25 @@ def predict_from_dataset(
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1
for deterministic models.
pin_memory
Whether to use ``pin_memory`` in PyTorch ``DataLoader`` instances. This can speed up inference on GPUs.
If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.
num_loader_workers
Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
for the inference/prediction dataset loaders (if any).
A larger number of workers can sometimes increase performance, but can also incur extra overheads
and increase memory usage, as more batches are loaded in parallel.
prefetch_factor
Optionally, an integer specifying the ``prefetch_factor`` to use in PyTorch ``DataLoader`` instances,
for the inference/prediction dataset loaders (if any).
Number of batches loaded in advance by each worker. 2 means there will be a total of
2 * num_loader_workers batches prefetched across all workers. If left to ``None``, the PyTorch default
applies: no prefetching when ``num_loader_workers=0``, and 2 when ``num_loader_workers > 0``.
persistent_workers
Optionally, a boolean specifying whether to use persistent workers in PyTorch ``DataLoader`` instances,
for the inference/prediction dataset loaders (if any).
If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
This keeps the workers' ``Dataset`` instances alive. (default: False)
mc_dropout
Optionally, enable monte carlo dropout for predictions using neural network based models.
This allows bayesian approximation by specifying an implicit prior over learned models.
Expand Down Expand Up @@ -1529,7 +1642,9 @@ def predict_from_dataset(
batch_size=batch_size,
shuffle=False,
num_workers=num_loader_workers,
pin_memory=True,
pin_memory=pin_memory,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
drop_last=False,
collate_fn=self._batch_collate_fn,
)
Expand Down Expand Up @@ -2800,7 +2915,10 @@ def predict(
n_jobs: int = 1,
roll_size: Optional[int] = None,
num_samples: int = 1,
pin_memory: bool = False,
num_loader_workers: int = 0,
prefetch_factor: Optional[int] = None,
persistent_workers: bool = False,
mc_dropout: bool = False,
predict_likelihood_parameters: bool = False,
show_warnings: bool = True,
Expand All @@ -2821,7 +2939,10 @@ def predict(
n_jobs=n_jobs,
roll_size=roll_size,
num_samples=num_samples,
pin_memory=pin_memory,
num_loader_workers=num_loader_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
mc_dropout=mc_dropout,
predict_likelihood_parameters=predict_likelihood_parameters,
show_warnings=show_warnings,
Expand All @@ -2838,7 +2959,10 @@ def predict(
n_jobs=n_jobs,
roll_size=roll_size,
num_samples=num_samples,
pin_memory=pin_memory,
num_loader_workers=num_loader_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
mc_dropout=mc_dropout,
predict_likelihood_parameters=predict_likelihood_parameters,
show_warnings=show_warnings,
Expand Down

0 comments on commit 23af346

Please sign in to comment.