Finish TODOs in NHiTs and NBEATs #955

Merged: 16 commits, May 18, 2022
60 changes: 56 additions & 4 deletions darts/models/forecasting/nbeats.py
@@ -79,6 +79,8 @@ def __init__(
input_chunk_length: int,
target_length: int,
g_type: GTypes,
batch_norm: bool,
dropout: float,
):
"""PyTorch module implementing the basic building block of the N-BEATS architecture.

@@ -104,6 +106,10 @@ def __init__(
The length of the forecast of the model.
g_type
The type of function that is implemented by the waveform generator.
batch_norm
Whether to use batch norm
dropout
Dropout probability

Inputs
------
@@ -127,12 +133,22 @@ def __init__(
self.nr_params = nr_params
self.g_type = g_type
self.relu = nn.ReLU()
self.dropout = dropout
self.batch_norm = batch_norm

# fully connected stack before fork
self.linear_layer_stack_list = [nn.Linear(input_chunk_length, layer_width)]
self.linear_layer_stack_list += [
nn.Linear(layer_width, layer_width) for _ in range(num_layers - 1)
]
for _ in range(num_layers - 1):
self.linear_layer_stack_list += [nn.Linear(layer_width, layer_width)]

if self.batch_norm:
self.linear_layer_stack_list.append(
nn.BatchNorm1d(num_features=self.layer_width)
)

if self.dropout > 0:
self.linear_layer_stack_list.append(nn.Dropout(p=self.dropout))

self.fc_stack = nn.ModuleList(self.linear_layer_stack_list)

# Fully connected layer producing forecast/backcast expansion coefficients (waveform generator parameters).
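For context, a minimal standalone sketch (not part of this diff; sizes are illustrative) of the layer ordering the construction above produces when batch norm and dropout are enabled:

import torch
import torch.nn as nn

input_chunk_length, layer_width, num_layers = 24, 64, 4  # illustrative sizes
batch_norm, dropout = True, 0.1

layers = [nn.Linear(input_chunk_length, layer_width)]
for _ in range(num_layers - 1):
    layers.append(nn.Linear(layer_width, layer_width))
if batch_norm:
    layers.append(nn.BatchNorm1d(num_features=layer_width))
if dropout > 0:
    layers.append(nn.Dropout(p=dropout))

fc_stack = nn.ModuleList(layers)
x = torch.rand(32, input_chunk_length)  # (batch, input_chunk_length)
for layer in fc_stack:
    # illustrative choice: ReLU after Linear layers, BatchNorm/Dropout outputs passed through unchanged
    x = torch.relu(layer(x)) if isinstance(layer, nn.Linear) else layer(x)
print(x.shape)  # torch.Size([32, 64])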
@@ -202,6 +218,8 @@ def __init__(
input_chunk_length: int,
target_length: int,
g_type: GTypes,
batch_norm: bool,
dropout: float,
):
"""PyTorch module implementing one stack of the N-BEATS architecture that comprises multiple basic blocks.

@@ -223,6 +241,10 @@ def __init__(
The length of the forecast of the model.
g_type
The function that is implemented by the waveform generators in each block.
batch_norm
Whether to apply batch norm on the first block of this stack
dropout
Dropout probability

Inputs
------
@@ -254,8 +276,12 @@ def __init__(
input_chunk_length,
target_length,
g_type,
batch_norm=(
batch_norm and i == 0
), # batch norm only on first block of first stack
dropout=dropout,
)
for _ in range(num_blocks)
for i in range(num_blocks)
]
else:
# same block instance is used for weight sharing
@@ -267,6 +293,8 @@ def __init__(
input_chunk_length,
target_length,
g_type,
batch_norm=batch_norm,
dropout=dropout,
)
self.blocks_list = [interpretable_block] * num_blocks

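The interpretable path shares weights simply by repeating the same block instance in the list; a quick standalone check of that behaviour (stand-in module, not the actual block):

import torch.nn as nn

block = nn.Linear(8, 8)        # stand-in for the interpretable block
blocks_list = [block] * 3      # same instance repeated -> parameters are shared
module_list = nn.ModuleList(blocks_list)

print(blocks_list[0] is blocks_list[2])                   # True
print(sum(p.numel() for p in module_list.parameters()))   # 72: 8*8 weights + 8 biases, counted once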
@@ -310,6 +338,8 @@ def __init__(
layer_widths: List[int],
expansion_coefficient_dim: int,
trend_polynomial_degree: int,
batch_norm: bool,
dropout: float,
**kwargs
):
"""PyTorch module implementing the N-BEATS architecture.
@@ -342,6 +372,10 @@ def __init__(
trend_polynomial_degree
The degree of the polynomial used as waveform generator in trend stacks. Only used if
`generic_architecture` is set to `False`.
batch_norm
Whether to apply batch norm on the first block of the first stack
dropout
Dropout probability
**kwargs
all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class.

@@ -375,6 +409,10 @@ def __init__(
self.input_chunk_length_multi,
self.target_length,
_GType.GENERIC,
batch_norm=(
batch_norm and i == 0
), # batch norm only on first block of first stack
dropout=dropout,
)
for i in range(num_stacks)
]
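The `batch_norm and i == 0` guard appears twice, once over stacks here and once over blocks inside `_Stack`, so at most the first block of the first stack uses batch norm. A tiny sketch of the combined gating, with hypothetical counts:

num_stacks, num_blocks, batch_norm = 3, 2, True  # hypothetical counts

for stack_idx in range(num_stacks):
    stack_batch_norm = batch_norm and stack_idx == 0             # gate at the stack level
    for block_idx in range(num_blocks):
        block_batch_norm = stack_batch_norm and block_idx == 0   # gate at the block level
        print(stack_idx, block_idx, block_batch_norm)
# only (0, 0) prints True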
@@ -389,6 +427,8 @@ def __init__(
self.input_chunk_length_multi,
self.target_length,
_GType.TREND,
batch_norm=batch_norm,
dropout=dropout,
)
seasonality_stack = _Stack(
num_blocks,
@@ -399,6 +439,8 @@ def __init__(
self.input_chunk_length_multi,
self.target_length,
_GType.SEASONALITY,
batch_norm=batch_norm,
dropout=dropout,
)
self.stacks_list = [trend_stack, seasonality_stack]

@@ -460,6 +502,7 @@ def __init__(
layer_widths: Union[int, List[int]] = 256,
expansion_coefficient_dim: int = 5,
trend_polynomial_degree: int = 2,
dropout: float = 0.0,
**kwargs
):
"""Neural Basis Expansion Analysis Time Series Forecasting (N-BEATS).
@@ -502,6 +545,8 @@ def __init__(
trend_polynomial_degree
The degree of the polynomial used as waveform generator in trend stacks. Only used if
`generic_architecture` is set to `False`.
dropout
The dropout probability to be used in the fully connected layers.
**kwargs
Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
Darts' :class:`TorchForecastingModel`.
@@ -656,6 +701,11 @@ def __init__(
self.expansion_coefficient_dim = expansion_coefficient_dim
self.trend_polynomial_degree = trend_polynomial_degree

# Currently batch norm is not an option as it seems to perform badly
self.batch_norm = False

self.dropout = dropout

if not generic_architecture:
self.num_stacks = 2

@@ -681,5 +731,7 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
layer_widths=self.layer_widths,
expansion_coefficient_dim=self.expansion_coefficient_dim,
trend_polynomial_degree=self.trend_polynomial_degree,
batch_norm=self.batch_norm,
dropout=self.dropout,
**self.pl_module_params,
)
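For reference, a usage sketch of the user-facing change on the N-BEATS side (the new `dropout` argument); the series and hyperparameters below are purely illustrative:

import numpy as np
from darts import TimeSeries
from darts.models import NBEATSModel

series = TimeSeries.from_values(np.sin(np.arange(200) / 10.0).astype(np.float32))  # toy series

model = NBEATSModel(
    input_chunk_length=24,
    output_chunk_length=12,
    dropout=0.1,   # new parameter introduced by this PR
    n_epochs=5,
)
model.fit(series)
forecast = model.predict(n=12)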
27 changes: 22 additions & 5 deletions darts/models/forecasting/nhits.py
@@ -29,6 +29,7 @@ def __init__(
n_freq_downsample: int,
batch_norm: bool,
dropout: float,
activation: str,
):
"""PyTorch module implementing the basic building block of the N-HiTS architecture.

@@ -56,7 +57,8 @@ def __init__(
Whether to use batch norm
dropout
Dropout probability

activation
The activation function of the encoder/decoder intermediate layer, 'relu' or 'gelu'.
Inputs
------
x of shape `(batch_size, input_chunk_length)`
@@ -83,7 +85,10 @@ def __init__(
self.batch_norm = batch_norm
self.dropout = dropout

self.activation = nn.ReLU() # TODO: make configurable?
if activation == "relu":
self.activation = nn.ReLU()
else:
self.activation = nn.GELU()

# number of parameters theta for backcast and forecast
"""
@@ -127,7 +132,6 @@ def __init__(
)
layers.append(self.activation)

# TODO: also add these two for NBEATS?
if self.batch_norm:
layers.append(nn.BatchNorm1d(num_features=self.layer_widths[i + 1]))

@@ -195,6 +199,7 @@ def __init__(
n_freq_downsample: Tuple[int],
batch_norm: bool,
dropout: float,
activation: str,
):
"""PyTorch module implementing one stack of the N-BEATS architecture that comprises multiple basic blocks.

@@ -220,6 +225,8 @@ def __init__(
Whether to apply batch norm on the first block of this stack
dropout
Dropout probability
activation
The activation function of the encoder/decoder intermediate layer, 'relu' or 'gelu'.

Inputs
------
@@ -255,6 +262,7 @@ def __init__(
batch_norm and i == 0
), # batch norm only on first block of first stack
dropout=dropout,
activation=activation,
)
for i in range(num_blocks)
]
@@ -299,6 +307,7 @@ def __init__(
n_freq_downsample: Tuple[Tuple[int]],
batch_norm: bool,
dropout: float,
activation: str,
**kwargs,
):
"""PyTorch module implementing the N-HiTS architecture.
@@ -331,6 +340,8 @@ def __init__(
Whether to apply batch norm on the first block of the first stack
dropout
Dropout probability
activation
The activation function of the encoder/decoder intermediate layer, 'relu' or 'gelu'.
**kwargs
all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class.

@@ -369,6 +380,7 @@ def __init__(
batch_norm and i == 0
), # batch norm only on first block of first stack
dropout=dropout,
activation=activation,
)
for i in range(num_stacks)
]
@@ -429,7 +441,8 @@ def __init__(
layer_widths: Union[int, List[int]] = 512,
pooling_kernel_sizes: Optional[Tuple[Tuple[int]]] = None,
n_freq_downsample: Optional[Tuple[Tuple[int]]] = None,
dropout: float = 0.0,
dropout: float = 0.1,
activation: str = "relu",
**kwargs,
):
"""An implementation of the N-HiTS model, as presented in [1]_.
@@ -480,7 +493,9 @@ def __init__(
downsampling factors before interpolation, for each block in each stack.
If left to ``None``, some default values will be used based on ``output_chunk_length``.
dropout
The dropout probability to be used in the fully connected layers.
Fraction of neurons affected by Dropout (default=0.1).
activation
The activation function of the encoder/decoder intermediate layer, 'relu' or 'gelu' (default='relu').
**kwargs
Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
Darts' :class:`TorchForecastingModel`.
@@ -632,6 +647,7 @@ def __init__(
self.num_blocks = num_blocks
self.num_layers = num_layers
self.layer_widths = layer_widths
self.activation = activation

# Currently batch norm is not an option as it seems to perform badly
self.batch_norm = False
@@ -730,5 +746,6 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
n_freq_downsample=self.n_freq_downsample,
batch_norm=self.batch_norm,
dropout=self.dropout,
activation=self.activation,
**self.pl_module_params,
)
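And the corresponding sketch on the N-HiTS side, exercising the new `activation` argument and the 0.1 dropout default (again with an illustrative series and hyperparameters):

import numpy as np
from darts import TimeSeries
from darts.models import NHiTSModel

series = TimeSeries.from_values(np.sin(np.arange(200) / 10.0).astype(np.float32))  # toy series

model = NHiTSModel(
    input_chunk_length=24,
    output_chunk_length=12,
    dropout=0.1,         # now the default, per this diff
    activation="gelu",   # new parameter; 'relu' or 'gelu'
    n_epochs=5,
)
model.fit(series)
forecast = model.predict(n=12)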