-
Notifications
You must be signed in to change notification settings - Fork 124
/
estimator.py
490 lines (460 loc) · 19.3 KB
/
estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
# Copyright 2024 Arjun Ashok
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterable, Optional
import pytorch_lightning as pl
import torch
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.dataset.stat import calculate_dataset_statistics
from gluonts.itertools import Cyclic
from gluonts.time_feature import (
get_lags_for_frequency,
time_features_from_frequency_str,
)
from gluonts.torch.distributions import StudentTOutput, NegativeBinomialOutput
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.transform import (
AddObservedValuesIndicator,
AddTimeFeatures,
Chain,
DummyValueImputation,
ExpectedNumInstanceSampler,
InstanceSampler,
InstanceSplitter,
TestSplitSampler,
Transformation,
ValidationSplitSampler,
)
from gluon_utils.gluon_ts_distributions.implicit_quantile_network import (
ImplicitQuantileNetworkOutput,
)
from lag_llama.gluon.lightning_module import LagLlamaLightningModule
# Batch fields the predictor feeds to the network at inference time.
PREDICTION_INPUT_NAMES = ["past_target", "past_observed_values"]

# Training/validation additionally need the ground-truth future window.
TRAINING_INPUT_NAMES = [
    *PREDICTION_INPUT_NAMES,
    "future_target",
    "future_observed_values",
]
class LagLlamaEstimator(PyTorchLightningEstimator):
    """
    An estimator training a Lag-Llama model for forecasting.

    The network is wrapped into a ``LagLlamaLightningModule`` (either freshly
    initialized or restored from ``ckpt_path``); training is performed using
    PyTorch Lightning's ``pl.Trainer`` class.

    Parameters
    ----------
    prediction_length
        Length of the prediction horizon.
    context_length
        Number of time steps prior to prediction time that the model
        takes as inputs (default: ``10 * prediction_length``).
    input_size
        Forwarded to the model as ``input_size`` (default: 1).
    n_layer, n_embd_per_head, n_head
        Forwarded to the model (defaults: 1, 32, 4).
    max_context_length
        Forwarded to the model (default: 2048).
    rope_scaling
        Forwarded to the model (default: None).
    scaling
        Forwarded to the model (default: ``"mean"``).
    lr
        Learning rate (default: ``1e-3``).
    weight_decay
        Weight decay regularization parameter (default: ``1e-8``).
    aug_prob, freq_mask_rate, freq_mixing_rate, jitter_prob, jitter_sigma,
    scaling_prob, scaling_sigma, rotation_prob, permutation_prob,
    permutation_max_segments, permutation_seg_mode, magnitude_warp_prob,
    magnitude_warp_sigma, magnitude_warp_knot, time_warp_prob,
    time_warp_sigma, time_warp_knot, window_slice_prob,
    window_slice_reduce_ratio, window_warp_prob, window_warp_window_ratio,
    window_warp_scales
        Data-augmentation hyperparameters, forwarded unchanged to the
        Lightning module.
    distr_output
        Name of the output distribution: ``"studentT"`` (default),
        ``"neg_bin"`` or ``"iqn"``. Any other string raises ``ValueError``.
    loss
        Loss to be optimized during training
        (default: ``NegativeLogLikelihood()``).
    num_parallel_samples
        Forwarded to the model (default: 100).
    batch_size
        The size of the batches used for training, validation and
        prediction (default: 32).
    num_batches_per_epoch
        Number of batches to be processed in each training epoch
        (default: 50).
    trainer_kwargs
        Additional arguments to provide to ``pl.Trainer`` for construction;
        they are merged over ``{"max_epochs": 100}``.
    train_sampler
        Controls the sampling of windows during training (default: an
        ``ExpectedNumInstanceSampler`` with ``min_future=prediction_length``).
    validation_sampler
        Controls the sampling of windows during validation (default: a
        ``ValidationSplitSampler`` with ``min_future=prediction_length``).
    time_feat
        If True, add time features (derived from the ``"S"`` frequency) and
        feed ``past_time_feat`` / ``future_time_feat`` to the network.
    dropout
        Forwarded to the model (default: 0.0).
    lags_seq
        Frequency strings whose lags (from ``get_lags_for_frequency``) are
        merged, deduplicated, sorted and shifted to zero-based offsets.
    data_id_to_name_map, use_cosine_annealing_lr, cosine_annealing_lr_args,
    track_loss_per_series, nonnegative_pred_samples
        Forwarded unchanged to the Lightning module.
    ckpt_path
        If given, the Lightning module is restored from this checkpoint
        (with ``strict=False``) instead of being freshly initialized.
    use_single_pass_sampling
        If True, use a single forward pass and sample N times from the saved
        distribution, which is much more efficient. If False, perform N
        forward passes and maintain N parallel prediction paths
        (default: False). NOTE(review): this value is stored on the
        estimator but not forwarded to the Lightning module here; confirm
        where it is consumed.
    device
        ``map_location`` used when restoring from ``ckpt_path``
        (default: ``torch.device("cuda")``). The predictor does not use it:
        it runs on CUDA when available and on the CPU otherwise.
    """

    @validated()
    def __init__(
        self,
        prediction_length: int,
        context_length: Optional[int] = None,
        input_size: int = 1,
        n_layer: int = 1,
        n_embd_per_head: int = 32,
        n_head: int = 4,
        max_context_length: int = 2048,
        rope_scaling=None,
        scaling: Optional[str] = "mean",
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        # Augmentations arguments
        aug_prob: float = 0.1,
        freq_mask_rate: float = 0.1,
        freq_mixing_rate: float = 0.1,
        jitter_prob: float = 0.0,
        jitter_sigma: float = 0.03,
        scaling_prob: float = 0.0,
        scaling_sigma: float = 0.1,
        rotation_prob: float = 0.0,
        permutation_prob: float = 0.0,
        permutation_max_segments: int = 5,
        permutation_seg_mode: str = "equal",
        magnitude_warp_prob: float = 0.0,
        magnitude_warp_sigma: float = 0.2,
        magnitude_warp_knot: int = 4,
        time_warp_prob: float = 0.0,
        time_warp_sigma: float = 0.2,
        time_warp_knot: int = 4,
        window_slice_prob: float = 0.0,
        window_slice_reduce_ratio: float = 0.9,
        window_warp_prob: float = 0.0,
        window_warp_window_ratio: float = 0.1,
        window_warp_scales: list = [0.5, 2.0],
        # Continuning model arguments
        distr_output: str = "studentT",
        loss: DistributionLoss = NegativeLogLikelihood(),
        num_parallel_samples: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
        time_feat: bool = False,
        dropout: float = 0.0,
        lags_seq: list = ["Q", "M", "W", "D", "H", "T", "S"],
        data_id_to_name_map: dict = {},
        use_cosine_annealing_lr: bool = False,
        cosine_annealing_lr_args: dict = {},
        track_loss_per_series: bool = False,
        ckpt_path: Optional[str] = None,
        nonnegative_pred_samples: bool = False,
        use_single_pass_sampling: bool = False,
        device: torch.device = torch.device("cuda")
    ) -> None:
        # User-supplied trainer settings override the default epoch budget.
        default_trainer_kwargs = {"max_epochs": 100}
        if trainer_kwargs is not None:
            default_trainer_kwargs.update(trainer_kwargs)
        super().__init__(trainer_kwargs=default_trainer_kwargs)

        self.scaling = scaling
        self.input_size = input_size
        self.prediction_length = prediction_length
        # Apply the documented default; keeping ``None`` would make the
        # instance splitter fail on ``None + int``.
        self.context_length = (
            context_length if context_length is not None else 10 * prediction_length
        )
        self.max_context_length = max_context_length

        # Collect the lags of every requested frequency, deduplicate and sort
        # them, then shift to zero-based offsets (lag 1 -> offset 0).
        lag_indices = []
        for freq in lags_seq:
            lag_indices.extend(
                get_lags_for_frequency(freq_str=freq, num_default_lags=1)
            )
        self.lags_seq = [lag_index - 1 for lag_index in sorted(set(lag_indices))]

        self.n_head = n_head
        self.n_layer = n_layer
        self.n_embd_per_head = n_embd_per_head
        self.rope_scaling = rope_scaling
        self.lr = lr
        self.weight_decay = weight_decay

        # Resolve the distribution name into an output object. An unknown
        # name would otherwise be stored as a plain string and only fail much
        # later (e.g. when ``value_in_support`` is read), so reject it here.
        distr_output_factories = {
            "studentT": StudentTOutput,
            "neg_bin": NegativeBinomialOutput,
            "iqn": ImplicitQuantileNetworkOutput,
        }
        if isinstance(distr_output, str):
            if distr_output not in distr_output_factories:
                raise ValueError(
                    f"Unknown distr_output {distr_output!r}; expected one of "
                    f"{sorted(distr_output_factories)}"
                )
            distr_output = distr_output_factories[distr_output]()
        self.distr_output = distr_output

        self.num_parallel_samples = num_parallel_samples
        self.loss = loss
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch
        self.nonnegative_pred_samples = nonnegative_pred_samples
        # NOTE(review): stored but not passed to LagLlamaLightningModule in
        # this file; confirm whether the module is expected to receive it.
        self.use_single_pass_sampling = use_single_pass_sampling

        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0,
            min_future=prediction_length,
            min_instances=1,
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )

        # Augmentation hyperparameters (forwarded to the Lightning module).
        self.aug_prob = aug_prob
        self.freq_mask_rate = freq_mask_rate
        self.freq_mixing_rate = freq_mixing_rate
        self.jitter_prob = jitter_prob
        self.jitter_sigma = jitter_sigma
        self.scaling_prob = scaling_prob
        self.scaling_sigma = scaling_sigma
        self.rotation_prob = rotation_prob
        self.permutation_prob = permutation_prob
        self.permutation_max_segments = permutation_max_segments
        self.permutation_seg_mode = permutation_seg_mode
        self.magnitude_warp_prob = magnitude_warp_prob
        self.magnitude_warp_sigma = magnitude_warp_sigma
        self.magnitude_warp_knot = magnitude_warp_knot
        self.time_warp_prob = time_warp_prob
        self.time_warp_sigma = time_warp_sigma
        self.time_warp_knot = time_warp_knot
        self.window_slice_prob = window_slice_prob
        self.window_slice_reduce_ratio = window_slice_reduce_ratio
        self.window_warp_prob = window_warp_prob
        self.window_warp_window_ratio = window_warp_window_ratio
        self.window_warp_scales = window_warp_scales

        self.track_loss_per_series = track_loss_per_series
        self.time_feat = time_feat
        self.dropout = dropout
        self.data_id_to_name_map = data_id_to_name_map
        self.ckpt_path = ckpt_path
        self.use_cosine_annealing_lr = use_cosine_annealing_lr
        self.cosine_annealing_lr_args = cosine_annealing_lr_args
        self.device = device

    @classmethod
    def derive_auto_fields(cls, train_iter):
        """Compute dataset-derived fields from ``train_iter`` statistics.

        Returns a dict with ``num_feat_dynamic_real``, ``num_feat_static_cat``
        and ``cardinality`` (one entry per static categorical feature).
        """
        stats = calculate_dataset_statistics(train_iter)
        return {
            "num_feat_dynamic_real": stats.num_feat_dynamic_real,
            "num_feat_static_cat": len(stats.feat_static_cat),
            "cardinality": [len(cats) for cats in stats.feat_static_cat],
        }

    def _time_feat_input_names(self) -> list:
        """Return the extra time-feature field names, or ``[]`` without time features."""
        return ["past_time_feat", "future_time_feat"] if self.time_feat else []

    def create_transformation(self) -> Transformation:
        """Build the per-series transformation chain.

        Optionally adds time features (computed for the ``"S"`` frequency),
        then always adds an observed-values indicator, imputing missing
        targets with ``0.0``.
        """
        transforms = []
        if self.time_feat:
            transforms.append(
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=time_features_from_frequency_str("S"),
                    pred_length=self.prediction_length,
                )
            )
        transforms.append(
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
                imputation_method=DummyValueImputation(0.0),
            )
        )
        return Chain(transforms)

    def create_lightning_module(self, use_kv_cache: bool = False) -> pl.LightningModule:
        """Create the ``LagLlamaLightningModule`` for this estimator.

        Parameters
        ----------
        use_kv_cache
            Forwarded to the Lightning module (default: False).

        Returns
        -------
        The module restored from ``self.ckpt_path`` (non-strict, mapped to
        ``self.device``) when a checkpoint path is set; otherwise a freshly
        initialized module. Both receive the same hyperparameters.
        """
        model_kwargs = {
            "input_size": self.input_size,
            "context_length": self.context_length,
            "max_context_length": self.max_context_length,
            "lags_seq": self.lags_seq,
            "n_layer": self.n_layer,
            "n_embd_per_head": self.n_embd_per_head,
            "n_head": self.n_head,
            "scaling": self.scaling,
            "distr_output": self.distr_output,
            "num_parallel_samples": self.num_parallel_samples,
            "rope_scaling": self.rope_scaling,
            "time_feat": self.time_feat,
            "dropout": self.dropout,
        }
        # Shared by both construction paths so they cannot drift apart.
        module_kwargs = dict(
            loss=self.loss,
            lr=self.lr,
            weight_decay=self.weight_decay,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            model_kwargs=model_kwargs,
            # Augmentations
            aug_prob=self.aug_prob,
            freq_mask_rate=self.freq_mask_rate,
            freq_mixing_rate=self.freq_mixing_rate,
            jitter_prob=self.jitter_prob,
            jitter_sigma=self.jitter_sigma,
            scaling_prob=self.scaling_prob,
            scaling_sigma=self.scaling_sigma,
            rotation_prob=self.rotation_prob,
            permutation_prob=self.permutation_prob,
            permutation_max_segments=self.permutation_max_segments,
            permutation_seg_mode=self.permutation_seg_mode,
            magnitude_warp_prob=self.magnitude_warp_prob,
            magnitude_warp_sigma=self.magnitude_warp_sigma,
            magnitude_warp_knot=self.magnitude_warp_knot,
            time_warp_prob=self.time_warp_prob,
            time_warp_sigma=self.time_warp_sigma,
            time_warp_knot=self.time_warp_knot,
            window_slice_prob=self.window_slice_prob,
            window_slice_reduce_ratio=self.window_slice_reduce_ratio,
            window_warp_prob=self.window_warp_prob,
            window_warp_window_ratio=self.window_warp_window_ratio,
            window_warp_scales=self.window_warp_scales,
            use_kv_cache=use_kv_cache,
            data_id_to_name_map=self.data_id_to_name_map,
            use_cosine_annealing_lr=self.use_cosine_annealing_lr,
            cosine_annealing_lr_args=self.cosine_annealing_lr_args,
            track_loss_per_series=self.track_loss_per_series,
            nonnegative_pred_samples=self.nonnegative_pred_samples,
        )
        if self.ckpt_path is not None:
            return LagLlamaLightningModule.load_from_checkpoint(
                checkpoint_path=self.ckpt_path,
                map_location=self.device,
                strict=False,
                **module_kwargs,
            )
        return LagLlamaLightningModule(**module_kwargs)

    def _create_instance_splitter(self, module: LagLlamaLightningModule, mode: str):
        """Return an ``InstanceSplitter`` for ``mode`` (training/validation/test).

        The past window covers the context plus the largest lag offset, so
        every lag of the first context step can be looked up.
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            # ``default=0`` keeps an empty ``lags_seq`` from raising.
            past_length=self.context_length + max(self.lags_seq, default=0),
            future_length=self.prediction_length,
            time_series_fields=[FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES]
            if self.time_feat
            else [FieldName.OBSERVED_VALUES],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: LagLlamaLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        """Yield training batches sampled from an endlessly cycled ``data``.

        Each epoch is capped at ``num_batches_per_epoch`` batches.
        """
        stream = Cyclic(data).stream()
        instances = self._create_instance_splitter(module, "training").apply(
            stream, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            shuffle_buffer_length=shuffle_buffer_length,
            field_names=TRAINING_INPUT_NAMES + self._time_feat_input_names(),
            output_type=torch.tensor,
            num_batches_per_epoch=self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: LagLlamaLightningModule,
        **kwargs,
    ) -> Iterable:
        """Yield validation batches from a single pass over ``data``."""
        instances = self._create_instance_splitter(module, "validation").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            field_names=TRAINING_INPUT_NAMES + self._time_feat_input_names(),
            output_type=torch.tensor,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module,
    ) -> PyTorchPredictor:
        """Wrap ``module`` into a ``PyTorchPredictor`` using the test splitter."""
        prediction_splitter = self._create_instance_splitter(module, "test")
        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES + self._time_feat_input_names(),
            prediction_net=module,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            # NOTE(review): ignores ``self.device``; kept for compatibility.
            device="cuda" if torch.cuda.is_available() else "cpu",
        )