Add normalization to BlockRNNModel #1748

Open · wants to merge 9 commits into base: master
Changes from all commits
85 changes: 71 additions & 14 deletions darts/models/forecasting/block_rnn_model.py
@@ -16,6 +16,7 @@
io_processor,
)
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
from darts.utils.torch import ExtractRnnOutput, TemporalBatchNorm1d

logger = get_logger(__name__)

@@ -30,6 +31,7 @@ def __init__(
nr_params: int,
num_layers_out_fc: Optional[List] = None,
dropout: float = 0.0,
normalization: str = None,
**kwargs,
):
"""This class allows to create custom block RNN modules that can later be used with Darts'
@@ -63,6 +65,8 @@ def __init__(
This network connects the last hidden layer of the PyTorch RNN module to the output.
dropout
The fraction of neurons that are dropped in all-but-last RNN layers.
normalization
The name of the normalization applied after the RNN and FC layers; either "batch" or "layer".
**kwargs
all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class.
"""
@@ -77,6 +81,7 @@ def __init__(
self.num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc
self.dropout = dropout
self.out_len = self.output_chunk_length
self.normalization = normalization

@io_processor
@abstractmethod
@@ -143,37 +148,34 @@ def __init__(
self.name = name

# Defining the RNN module
self.rnn = getattr(nn, self.name)(
self.rnn = self._rnn_sequence(
name,
self.input_size,
self.hidden_dim,
self.num_layers,
batch_first=True,
dropout=self.dropout,
self.dropout,
self.normalization,
)

# The RNN module is followed by a fully connected layer, which maps the last hidden layer
# to the output of desired length
last = self.hidden_dim
feats = []
for feature in self.num_layers_out_fc + [
self.out_len * self.target_size * self.nr_params
]:
feats.append(nn.Linear(last, feature))
last = feature
self.fc = nn.Sequential(*feats)
self.fc = self._fc_layer(
self.hidden_dim,
self.num_layers_out_fc,
self.target_size,
self.normalization,
)

@io_processor
def forward(self, x_in: Tuple):
x, _ = x_in
# data is of size (batch_size, input_chunk_length, input_size)
batch_size = x.size(0)

out, hidden = self.rnn(x)
hidden = self.rnn(x)

""" Here, we apply the FC network only on the last output point (at the last time step)
"""
if self.name == "LSTM":
hidden = hidden[0]
predictions = hidden[-1, :, :]
predictions = self.fc(predictions)
predictions = predictions.view(
@@ -183,6 +185,61 @@ def forward(self, x_in: Tuple):
# predictions is of size (batch_size, output_chunk_length, 1)
return predictions

def _rnn_sequence(
self,
name: str,
input_size: int,
hidden_dim: int,
num_layers: int,
dropout: float = 0.0,
normalization: str = None,
):

modules = []
is_lstm = self.name == "LSTM"
for i in range(num_layers):
input = input_size if (i == 0) else hidden_dim
is_last = i == num_layers - 1
rnn = getattr(nn, name)(input, hidden_dim, 1, batch_first=True)

modules.append(rnn)
modules.append(ExtractRnnOutput(not is_last, is_lstm))
if normalization:
modules.append(self._normalization_layer(normalization, hidden_dim))
if not is_last: # pytorch RNNs don't have dropout applied on the last layer
modules.append(nn.Dropout(dropout))

return nn.Sequential(*modules)
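
For readers skimming the change, here is a minimal stand-alone sketch (not part of the diff) of the pattern _rnn_sequence implements: stacking single-layer RNNs by hand so that normalization can be inserted between layers, which a single multi-layer nn.LSTM(..., num_layers=k) call does not allow. The class name and the choice of LayerNorm below are illustrative assumptions, not the PR's code.

```python
import torch
import torch.nn as nn


class StackedLSTMWithLayerNorm(nn.Module):  # hypothetical name, illustration only
    """Per-layer LSTM + LayerNorm stack, mirroring the idea behind _rnn_sequence."""

    def __init__(self, input_size: int, hidden_dim: int, num_layers: int, dropout: float = 0.0):
        super().__init__()
        self.lstms = nn.ModuleList(
            [nn.LSTM(input_size if i == 0 else hidden_dim, hidden_dim, 1, batch_first=True)
             for i in range(num_layers)]
        )
        self.norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = x
        for i, (lstm, norm) in enumerate(zip(self.lstms, self.norms)):
            out, (h, c) = lstm(out)      # each sub-LSTM returns (output, (h_n, c_n))
            out = norm(out)              # normalize the full output sequence
            if i < len(self.lstms) - 1:  # like PyTorch, no dropout after the last layer
                out = self.dropout(out)
        return out, (h, c)


x = torch.randn(8, 12, 3)  # (batch_size, input_chunk_length, input_size)
out, (h, c) = StackedLSTMWithLayerNorm(3, 16, num_layers=2, dropout=0.1)(x)
print(out.shape, h.shape)  # torch.Size([8, 12, 16]) torch.Size([1, 8, 16])
```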

def _fc_layer(
self,
input_size: int,
num_layers_out_fc: list[int],
target_size: int,
normalization: str = None,
):
if not num_layers_out_fc:

Review comment: I don't get this point here. num_layers_out_fc is a list of integers, correct? Suppose num_layers_out_fc = []; then not num_layers_out_fc is already True, so why reassign num_layers_out_fc = []?

num_layers_out_fc = []

last = input_size
feats = []
for feature in num_layers_out_fc + [

Review comment: I would rather use the extend method for lists here.

self.output_chunk_length * target_size * self.nr_params
]:
if normalization:
feats.append(self._normalization_layer(normalization, last))
feats.append(nn.Linear(last, feature))
last = feature
return nn.Sequential(*feats)
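
Picking up the two review comments above, one possible rewrite of this helper (a sketch meant to drop into the same module class, not the PR's code): num_layers_out_fc or [] handles both None and an empty list in one step, and the output size is added via extend rather than list concatenation.

```python
from typing import List, Optional

import torch.nn as nn


# Sketch only; assumes the surrounding class provides self.output_chunk_length,
# self.nr_params and self._normalization_layer, as in the diff above.
def _fc_layer(
    self,
    input_size: int,
    num_layers_out_fc: Optional[List[int]],
    target_size: int,
    normalization: Optional[str] = None,
):
    layer_sizes = list(num_layers_out_fc or [])  # covers both None and []
    layer_sizes.extend([self.output_chunk_length * target_size * self.nr_params])

    feats = []
    last = input_size
    for feature in layer_sizes:
        if normalization:
            feats.append(self._normalization_layer(normalization, last))
        feats.append(nn.Linear(last, feature))
        last = feature
    return nn.Sequential(*feats)
```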

def _normalization_layer(self, normalization: str, hidden_size: int):

Review comment: If normalization is anything other than "batch" or "layer", this method returns None. Is that intended?


if normalization == "batch":
return TemporalBatchNorm1d(hidden_size)
elif normalization == "layer":
return nn.LayerNorm(hidden_size)
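
On the reviewer's question above: as written, an unsupported value puts a None into the nn.Sequential, which only fails later when forward is called. A minimal sketch of a stricter variant, assuming a ValueError is the desired behaviour (validating the option once in __init__ would work as well):

```python
# Sketch only: validate the option instead of silently returning None.
def _normalization_layer(self, normalization: str, hidden_size: int):
    if normalization == "batch":
        return TemporalBatchNorm1d(hidden_size)
    elif normalization == "layer":
        return nn.LayerNorm(hidden_size)
    raise ValueError(
        f"Unknown normalization '{normalization}'; expected 'batch' or 'layer'."
    )
```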


class BlockRNNModel(PastCovariatesTorchModel):
def __init__(
(remaining file content collapsed)
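
For context, a usage sketch of the feature this PR adds, assuming BlockRNNModel.__init__ forwards a normalization argument to the module (that part of the diff is collapsed above); the argument does not exist in released versions of darts:

```python
from darts.datasets import AirPassengersDataset
from darts.models import BlockRNNModel

series = AirPassengersDataset().load()

# "normalization" is the option introduced by this PR ("batch" or "layer").
model = BlockRNNModel(
    input_chunk_length=24,
    output_chunk_length=12,
    model="LSTM",
    hidden_dim=32,
    n_rnn_layers=2,
    dropout=0.1,
    normalization="batch",
    n_epochs=5,
)
model.fit(series)
forecast = model.predict(n=12)
```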
27 changes: 27 additions & 0 deletions darts/utils/torch.py
@@ -112,3 +112,30 @@ def decorator(self, *args, **kwargs) -> T:
return decorated(self, *args, **kwargs)

return decorator


class TemporalBatchNorm1d(nn.Module):
def __init__(self, feature_size) -> None:
super().__init__()
self.norm = nn.BatchNorm1d(feature_size)

def forward(self, input):
input = input.swapaxes(1, 2)
input = self.norm(input)
input = input.swapaxes(1, 2)
return input if len(input) > 1 else input[0]


class ExtractRnnOutput(nn.Module):
def __init__(self, is_output, is_lstm) -> None:
self.is_output = is_output
self.is_lstm = is_lstm
super().__init__()

def forward(self, input):
output, hidden = input
if self.is_output:
return output
if self.is_lstm:
return hidden[0]
return hidden
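
To make the two helpers above easier to review, a small stand-alone sketch of the tensor plumbing they rely on: nn.BatchNorm1d expects (batch, features, time), hence the swapaxes in TemporalBatchNorm1d, and PyTorch RNNs return an (output, hidden) tuple that nn.Sequential would otherwise pass to the next module unchanged, which is what ExtractRnnOutput unpacks.

```python
import torch
import torch.nn as nn

x = torch.randn(8, 12, 16)  # (batch_size, timesteps, features)

# What TemporalBatchNorm1d does internally: BatchNorm1d normalizes over dim 1,
# so the feature axis is swapped into that position and back afterwards.
bn = nn.BatchNorm1d(16)
y = bn(x.swapaxes(1, 2)).swapaxes(1, 2)
print(y.shape)  # torch.Size([8, 12, 16])

# Why ExtractRnnOutput exists: an LSTM returns a tuple, which nn.Sequential
# cannot feed to the next module as-is.
lstm = nn.LSTM(16, 16, 1, batch_first=True)
output, (h_n, c_n) = lstm(x)
print(output.shape, h_n.shape)  # torch.Size([8, 12, 16]) torch.Size([1, 8, 16])
```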