
Commit 21a46a2

[ENH] Add Samformer model for PTF v2 from DSIPTS (#1952)
Reference Issues/PRs: Closes #1940
1 parent da9fe90 commit 21a46a2

File tree

6 files changed, +420 −0 lines changed


pytorch_forecasting/layers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
     Encoder,
     EncoderLayer,
 )
+from pytorch_forecasting.layers._normalization import RevIN
 from pytorch_forecasting.layers._output._flatten_head import (
     FlattenHead,
 )
@@ -50,6 +51,7 @@
     "sLSTMLayer",
     "sLSTMNetwork",
     "SeriesDecomposition",
+    "RevIN",
     "ResidualBlock",
     "embedding_cat_variables",
 ]
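
With these two lines, RevIN becomes part of the public layers namespace, so downstream code does not need to reach into the private module. A minimal, illustrative import (the channel count is an arbitrary example value):

    from pytorch_forecasting.layers import RevIN

    revin = RevIN(num_features=8)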
pytorch_forecasting/layers/_normalization/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+"""
+RevIN: Reversible Instance Normalization
+"""
+
+from pytorch_forecasting.layers._normalization._revin import RevIN
+
+__all__ = ["RevIN"]
pytorch_forecasting/layers/_normalization/_revin.py

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+"""
+Reversible Instance Normalization (RevIN) layer.
+------------------------------------------------
+"""
+
+import torch
+import torch.nn as nn
+
+
+class RevIN(nn.Module):
+    def __init__(self, num_features, eps=1e-5, affine=True, subtract_last=False):
+        """
+        Reversible Instance Normalization (RevIN) layer.
+
+        Parameters
+        ----------
+        num_features : int
+            Number of input features.
+        eps : float, optional
+            A small value added to the denominator for numerical stability (default: 1e-5).
+        affine : bool, optional
+            If True, the layer will have learnable affine parameters (default: True).
+        subtract_last : bool, optional
+            If True, the last time step is subtracted and restored instead of the mean (default: False).
+        """  # noqa: E501
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.affine = affine
+        self.subtract_last = subtract_last
+
+        if self.affine:
+            self._init_params()
+
+    def forward(self, x, mode: str):
+        if mode == "norm":
+            self._get_statistics(x)
+            x = self._normalize(x)
+        elif mode == "denorm":
+            x = self._denormalize(x)
+        else:
+            raise NotImplementedError
+        return x
+
+    def _init_params(self):
+        """Initialize learnable parameters if affine is True."""
+        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
+        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
+
+    def _get_statistics(self, x):
+        dim2reduce = tuple(range(1, x.ndim - 1))
+        if self.subtract_last:
+            self.last = x[:, -1, :].unsqueeze(1)
+        else:
+            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
+        self.stdev = torch.sqrt(
+            torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps
+        ).detach()
+
+    def _normalize(self, x):
+        if self.subtract_last:
+            x = x - self.last
+        else:
+            x = x - self.mean
+        x = x / self.stdev
+        if self.affine:
+            x = x * self.affine_weight
+            x = x + self.affine_bias
+        return x
+
+    def _denormalize(self, x):
+        if self.affine:
+            x = x - self.affine_bias
+            x = x / (self.affine_weight + self.eps * self.eps)
+        x = x * self.stdev
+        if self.subtract_last:
+            x = x + self.last
+        else:
+            x = x + self.mean
+        return x
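
The layer is stateful across calls: a "norm" pass stores the per-instance mean and standard deviation, and a later "denorm" pass reuses them to map outputs back to the original scale. A minimal round-trip sketch (shapes and values are illustrative, not from the commit; with freshly initialized affine parameters the reconstruction is exact up to eps):

    import torch

    from pytorch_forecasting.layers import RevIN

    revin = RevIN(num_features=3)
    x = torch.randn(4, 24, 3)  # (batch, time steps, features)

    x_norm = revin(x, mode="norm")          # stores mean/stdev, returns normalized series
    x_back = revin(x_norm, mode="denorm")   # reapplies the stored statistics

    assert torch.allclose(x, x_back, atol=1e-4)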
pytorch_forecasting/models/samformer/__init__.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+"""
+DSIPTS Implementation of Samformer for V2
+-----------------------------------------
+"""
+
+from pytorch_forecasting.models.samformer._samformer_v2 import Samformer
+from pytorch_forecasting.models.samformer._samformer_v2_pkg import Samformer_pkg_v2
+
+__all__ = ["Samformer", "Samformer_pkg_v2"]
pytorch_forecasting/models/samformer/_samformer_v2.py

Lines changed: 186 additions & 0 deletions

@@ -0,0 +1,186 @@
+"""
+Samformer Model from DSIPTS for PyTorch Forecasting
+---------------------------------------------------
+"""
+
+import math
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.optim import Optimizer
+
+from pytorch_forecasting.layers import RevIN
+from pytorch_forecasting.models.base._base_model_v2 import BaseModel
+
+
+class Samformer(BaseModel):
+    """
+    Samformer: Unlocking the Potential of Transformers in Time Series Forecasting
+    with Sharpness-Aware Minimization and Channel-Wise Attention.
+
+    Parameters
+    ----------
+    out_channels : int, optional
+        Number of variables to be predicted. Default is 1.
+    hidden_size : int, optional
+        First embedding size of the model ('r' in the paper). Default is 512.
+    use_revin : bool, optional
+        Whether to use Reversible Instance Normalization. Default is True.
+    persistence_weight : float, optional
+        Weight for the persistence baseline. Default is 0.0.
+    """
+
+    @classmethod
+    def _pkg(cls):
+        """Return the package class for this model."""
+        from pytorch_forecasting.models.samformer._samformer_v2_pkg import (
+            Samformer_pkg_v2,
+        )
+
+        return Samformer_pkg_v2
+
+    def __init__(
+        self,
+        loss: nn.Module,
+        # specific params
+        hidden_size: int,
+        use_revin: bool,
+        # out_channels has to be 1, due to lack of MultiLoss support in v2.
+        out_channels: Optional[Union[int, list[int]]] = 1,
+        persistence_weight: float = 0.0,
+        logging_metrics: Optional[list[nn.Module]] = None,
+        optimizer: Optional[Union[Optimizer, str]] = "adam",
+        optimizer_params: Optional[dict] = None,
+        lr_scheduler: Optional[str] = None,
+        lr_scheduler_params: Optional[dict] = None,
+        metadata: Optional[dict] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            loss=loss,
+            logging_metrics=logging_metrics,
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            lr_scheduler=lr_scheduler,
+            lr_scheduler_params=lr_scheduler_params,
+        )
+
+        self.save_hyperparameters(ignore=["loss", "logging_metrics", "optimizer"])
+        self.metadata = metadata
+        self.n_quantiles = 1
+
+        if hasattr(loss, "quantiles") and loss.quantiles is not None:
+            self.n_quantiles = len(loss.quantiles)
+
+        self.max_encoder_length = self.metadata["max_encoder_length"]
+        self.max_prediction_length = self.metadata["max_prediction_length"]
+        self.encoder_cont = self.metadata["encoder_cont"]
+        self.encoder_input_dim = self.encoder_cont + 1  # +1 for target variable input.
+
+        self.hidden_size = hidden_size
+        if out_channels != 1:
+            raise ValueError(
+                "out_channels has to be 1 for Samformer, "
+                "due to lack of MultiLoss support in v2."
+            )
+        self.out_channels = out_channels
+        self.use_revin = use_revin
+        self.persistence_weight = persistence_weight
+
+        if self.use_revin:
+            self.revin = RevIN(num_features=self.encoder_input_dim)
+
+        self.compute_keys = nn.Linear(self.max_encoder_length, self.hidden_size)
+        self.compute_queries = nn.Linear(self.max_encoder_length, self.hidden_size)
+        self.compute_values = nn.Linear(
+            self.max_encoder_length, self.max_encoder_length
+        )
+        self.linear_forecaster = nn.Linear(
+            self.max_encoder_length, self.max_prediction_length
+        )
+
+    def _scaled_dot_product_attention(
+        self,
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=None,
+        enable_gqa=False,
+    ) -> torch.Tensor:
+        L, S = query.size(-2), key.size(-2)
+        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+        if is_causal:
+            assert attn_mask is None
+            temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias = attn_bias.to(query.dtype)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias = attn_mask + attn_bias
+
+        if enable_gqa:
+            key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+            value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+        return attn_weight @ value
+
+    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        """
+        Forward pass of the model.
+
+        Parameters
+        ----------
+        x : dict[str, torch.Tensor]
+            Input data containing past and future sequences.
+
+        Returns
+        -------
+        dict[str, torch.Tensor]
+            Output predictions.
+        """
+        encoder_cont = x["encoder_cont"]
+        target = x["target_past"]
+        input_tensor = torch.cat((encoder_cont, target), dim=-1)
+
+        if self.use_revin:
+            x_norm = self.revin(input_tensor, mode="norm").transpose(1, 2)
+        else:
+            x_norm = input_tensor.transpose(1, 2)
+
+        queries = self.compute_queries(x_norm)
+        keys = self.compute_keys(x_norm)
+        values = self.compute_values(x_norm)
+
+        att_score = self._scaled_dot_product_attention(queries, keys, values)
+
+        out = x_norm + att_score
+        out = self.linear_forecaster(out)
+
+        out = out.transpose(1, 2)
+
+        target_predictions = out[:, :, -1]  # (batch_size, max_prediction_length)
+
+        if target_predictions.ndim == 1:
+            target_predictions = target_predictions.unsqueeze(0)
+
+        if self.n_quantiles > 1:
+            target_predictions = target_predictions.unsqueeze(-1).expand(
+                -1, -1, self.n_quantiles
+            )
+        elif self.n_quantiles == 1:
+            target_predictions = target_predictions.unsqueeze(-1)
+        return {"prediction": target_predictions}
