Merge branch 'master' into fix/unit8co#1101
hrzn committed Aug 7, 2022
2 parents 4f53c86 + eb18103 commit a6356d4
Showing 5 changed files with 146 additions and 12 deletions.
56 changes: 56 additions & 0 deletions darts/models/components/transformer.py
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn

from darts.utils.torch import MonteCarloDropout


class CustomFeedForwardEncoderLayer(nn.TransformerEncoderLayer):
"""Overwrites the PyTorch TransformerEncoderLayer to use Darts' Position-wise Feed-Forward variants."""

def __init__(self, ffn: nn.Module, dropout: float, *args, **kwargs):
"""
Parameters
----------
ffn
One of Darts' Position-wise Feed-Forward Network variants from darts.models.components.glu_variants
dropout
Fraction of neurons affected by Dropout (default=0.1).
args
positional arguments from torch.nn.TransformerEncoderLayer.
kwargs
keyword arguments from torch.nn.TransformerEncoderLayer. `activation` will have no effect.
"""
super().__init__(*args, **kwargs)
self.ffn = ffn
self.dropout = MonteCarloDropout(dropout)

# overwrite the feed forward block
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
x = self.ffn(x)
return self.dropout(x)


class CustomFeedForwardDecoderLayer(nn.TransformerDecoderLayer):
"""Overwrites the PyTorch TransformerDecoderLayer to use Darts' custom Position Wise Feed Forward Layers."""

def __init__(self, ffn: nn.Module, dropout: float, *args, **kwargs):
"""
Parameters
----------
ffn
One of Darts' Position-wise Feed-Forward Network variants from darts.models.components.glu_variants
dropout
Fraction of neurons affected by Dropout (default=0.1).
args
positional arguments from torch.nn.TransformerDecoderLayer.
kwargs
keyword arguments from torch.nn.TransformerDecoderLayer. `activation` will have no effect.
"""
super().__init__(*args, **kwargs)
self.ffn = ffn
self.dropout = MonteCarloDropout(dropout)

# overwrite the feed forward block
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
x = self.ffn(x)
return self.dropout(x)
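
For orientation, here is a minimal sketch of assembling one of these layers by hand; it mirrors the `_generate_coder` helper added later in this diff. The `SwiGLU` variant and its `d_model`/`d_ff`/`dropout` arguments follow how the feed-forward classes are called elsewhere in this commit, and the sizes (`d_model=16`, `nhead=4`, etc.) are placeholders:

```python
import torch
import torch.nn as nn

from darts.models.components.glu_variants import SwiGLU
from darts.models.components.transformer import CustomFeedForwardEncoderLayer

# One encoder layer whose feed-forward block is a SwiGLU variant
# (ffn signature assumed from how _generate_coder calls it in this diff).
layer = CustomFeedForwardEncoderLayer(
    ffn=SwiGLU(d_model=16, d_ff=64, dropout=0.1),
    dropout=0.1,
    d_model=16,
    nhead=4,
    dim_feedforward=64,
)
encoder = nn.TransformerEncoder(layer, num_layers=2, norm=nn.LayerNorm(16))

x = torch.randn(10, 32, 16)  # (sequence, batch, d_model)
out = encoder(x)
print(out.shape)  # torch.Size([10, 32, 16])
```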
2 changes: 1 addition & 1 deletion darts/models/forecasting/torch_forecasting_model.py
@@ -1355,7 +1355,7 @@ def load_model(path: str) -> "TorchForecastingModel":
path_ptl_ckpt = base_path + "_ptl-ckpt.pth.tar"
if os.path.exists(path_ptl_ckpt):
model.model = model.model.__class__.load_from_checkpoint(path_ptl_ckpt)
-model.trainer = model.model.trainer
+model.trainer = None

return model
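
A brief, hedged sketch of the load path this one-line change touches. The `save_model` counterpart, the file name, and the claim that a fresh PyTorch Lightning trainer is created on the next `fit()`/`predict()` call are assumptions for illustration, not confirmed by this diff:

```python
from darts.models import TransformerModel
from darts.utils.timeseries_generation import linear_timeseries

series = linear_timeseries(length=50)

model = TransformerModel(input_chunk_length=12, output_chunk_length=6)
model.fit(series, epochs=1)
model.save_model("transformer_model.pth.tar")  # counterpart save; name assumed

loaded = TransformerModel.load_model("transformer_model.pth.tar")
# After this change the reloaded model holds no stale trainer (loaded.trainer is None);
# presumably a new trainer is set up when the model is trained or queried again.
forecast = loaded.predict(n=6)
```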

74 changes: 66 additions & 8 deletions darts/models/forecasting/transformer_model.py
@@ -9,9 +9,13 @@
import torch
import torch.nn as nn

-from darts.logging import get_logger, raise_if_not
+from darts.logging import get_logger, raise_if, raise_if_not
from darts.models.components import glu_variants
from darts.models.components.glu_variants import GLU_FFN
from darts.models.components.transformer import (
CustomFeedForwardDecoderLayer,
CustomFeedForwardEncoderLayer,
)
from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel

@@ -22,6 +26,34 @@
FFN = GLU_FFN + BUILT_IN


def _generate_coder(
d_model, dim_ff, dropout, nhead, num_layers, coder_cls, layer_cls, ffn_cls
):
"""Generates an Encoder or Decoder with one of Darts' Feed-forward Network variants.
Parameters
----------
coder_cls
Either `torch.nn.TransformerEncoder` or `...TransformerDecoder`
layer_cls
Either `darts.models.components.transformer.CustomFeedForwardEncoderLayer` or
`...CustomFeedForwardDecoderLayer`
ffn_cls
One of Darts' Position-wise Feed-Forward Network variants from `darts.models.components.glu_variants`
"""
layer = layer_cls(
ffn=ffn_cls(d_model=d_model, d_ff=dim_ff, dropout=dropout),
dropout=dropout,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_ff,
)
return coder_cls(
layer,
num_layers=num_layers,
norm=nn.LayerNorm(d_model),
)


# This implementation of positional encoding is taken from the PyTorch documentation:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class _PositionalEncoding(nn.Module):
@@ -142,13 +174,39 @@ def __init__(

raise_if_not(activation in FFN, f"'{activation}' is not in {FFN}")
if activation in GLU_FFN:
-# use glu variant feedforward layers
-self.activation = getattr(glu_variants, activation)(
-    d_model=d_model, d_ff=dim_feedforward, dropout=dropout
raise_if(
custom_encoder is not None or custom_decoder is not None,
"Cannot use `custom_encoder` or `custom_decoder` along with an `activation` from "
f"{GLU_FFN}",
logger=logger,
)
# use glu variant feed-forward layers
ffn_cls = getattr(glu_variants, activation)

# custom feed-forward layers have activation built-in. reset activation
activation = None

custom_encoder = _generate_coder(
d_model,
dim_feedforward,
dropout,
nhead,
num_encoder_layers,
nn.TransformerEncoder,
CustomFeedForwardEncoderLayer,
ffn_cls,
)

custom_decoder = _generate_coder(
d_model,
dim_feedforward,
dropout,
nhead,
num_decoder_layers,
nn.TransformerDecoder,
CustomFeedForwardDecoderLayer,
ffn_cls,
)
-else:
-    # use nn.Transformer built in feedforward layers
-    self.activation = activation

# Defining the Transformer module
self.transformer = nn.Transformer(
@@ -158,7 +216,7 @@ def __init__(
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
-activation=self.activation,
+activation=activation,
custom_encoder=custom_encoder,
custom_decoder=custom_decoder,
)
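At the user level, the effect of this block is that a GLU-variant `activation` now makes Darts build the custom encoder/decoder stacks automatically, while combining it with `custom_encoder`/`custom_decoder` is rejected. A minimal sketch, with the series construction, chunk lengths, and epoch count as placeholders:

```python
from darts.models import TransformerModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=100)

# A GLU-variant activation makes the model use CustomFeedForward{Encoder,Decoder}Layer internally.
model = TransformerModel(
    input_chunk_length=12,
    output_chunk_length=6,
    activation="SwiGLU",
)
model.fit(series, epochs=1)
forecast = model.predict(n=6)

# Passing `custom_encoder` or `custom_decoder` together with a GLU-variant
# activation is rejected by the new raise_if check above.
```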
24 changes: 22 additions & 2 deletions darts/tests/models/forecasting/test_transformer_model.py
@@ -11,6 +11,12 @@
logger = get_logger(__name__)

try:
import torch.nn as nn

from darts.models.components.transformer import (
CustomFeedForwardDecoderLayer,
CustomFeedForwardEncoderLayer,
)
from darts.models.forecasting.transformer_model import (
TransformerModel,
_TransformerModule,
@@ -118,14 +124,28 @@ def test_activations(self):
)
model1.fit(self.series, epochs=1)

-# internal activation function
+# internal activation function uses PyTorch TransformerEncoderLayer
model2 = TransformerModel(
input_chunk_length=1, output_chunk_length=1, activation="gelu"
)
model2.fit(self.series, epochs=1)
assert isinstance(
model2.model.transformer.encoder.layers[0], nn.TransformerEncoderLayer
)
assert isinstance(
model2.model.transformer.decoder.layers[0], nn.TransformerDecoderLayer
)

-# glue variant FFN
+# glu variant FFN uses our custom CustomFeedForwardEncoderLayer
model3 = TransformerModel(
input_chunk_length=1, output_chunk_length=1, activation="SwiGLU"
)
model3.fit(self.series, epochs=1)
assert isinstance(
model3.model.transformer.encoder.layers[0],
CustomFeedForwardEncoderLayer,
)
assert isinstance(
model3.model.transformer.decoder.layers[0],
CustomFeedForwardDecoderLayer,
)
2 changes: 1 addition & 1 deletion darts/timeseries.py
@@ -649,7 +649,7 @@ def from_dataframe(
else:
raise_if_not(
isinstance(df.index, VALID_INDEX_TYPES),
"If time_col is not specified, the DataFrame must be indexed either with"
"If time_col is not specified, the DataFrame must be indexed either with "
"a DatetimeIndex, or with a RangeIndex.",
logger,
)
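The check touched here: `from_dataframe` without a `time_col` requires the DataFrame to be indexed by a `DatetimeIndex` or a `RangeIndex`. A small sketch of both the accepted and rejected cases; column names and values are illustrative:

```python
import pandas as pd

from darts import TimeSeries

# Accepted: the DataFrame is indexed with a DatetimeIndex (a RangeIndex would also work).
df_ok = pd.DataFrame(
    {"value": range(10)},
    index=pd.date_range("2022-01-01", periods=10, freq="D"),
)
ts = TimeSeries.from_dataframe(df_ok)

# Rejected: a string index with no `time_col` triggers the error message fixed here.
df_bad = pd.DataFrame({"value": range(10)}, index=[str(i) for i in range(10)])
# TimeSeries.from_dataframe(df_bad)  # raises ValueError
```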
