-
Notifications
You must be signed in to change notification settings - Fork 110
/
lightning_module.py
484 lines (447 loc) · 19.4 KB
/
lightning_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
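# The `data.*`, `gluon_utils.*` and `lag_llama.*` imports below are project-local (e.g. the augmentation helpers); they assume the project root (the parent directory) is importable, e.g. via `sys.path`.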
import random
import numpy as np
from lightning import LightningModule
import torch
import torch.nn.functional as F
from gluonts.core.component import validated
from gluonts.itertools import prod
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import repeat_along_dim, take_last
from data.augmentations.freq_mask import freq_mask
from data.augmentations.freq_mix import freq_mix
from data.augmentations.gluonts_augmentations import (
ApplyAugmentations,
Jitter,
MagnitudeWarp,
Permutation,
Rotation,
Scaling,
TimeWarp,
WindowSlice,
WindowWarp,
)
from gluon_utils.gluon_ts_distributions.implicit_quantile_network import (
ImplicitQuantileNetworkOutput,
)
from lag_llama.model.module import LagLlamaModel
class LagLlamaLightningModule(LightningModule):
    """
    A ``LightningModule`` that trains and samples from a ``LagLlamaModel``.

    This is a thin layer around a ``LagLlamaModel`` (built from
    ``model_kwargs``) that exposes the training and validation loss, optional
    data augmentation of training batches, and autoregressive sampling in
    :meth:`forward`.

    Parameters
    ----------
    model_kwargs
        Keyword arguments forwarded to ``LagLlamaModel``; must contain a
        ``"time_feat"`` entry (read to decide whether time features are used).
    context_length
        Context window length; the loss covers its last ``context_length - 1``
        steps followed by the forecast horizon.
    prediction_length
        Number of future steps sampled by :meth:`forward`.
    loss
        Loss applied to the predicted distribution (not used when the model's
        ``distr_output`` is an ``ImplicitQuantileNetworkOutput``),
        default: ``NegativeLogLikelihood()``.
    lr
        Adam learning rate, default: ``1e-3``.
    weight_decay
        Adam weight decay, default: ``1e-8``.
    aug_prob
        Probability that a training batch is augmented at all, default: ``0.1``.
    freq_mask_rate, freq_mixing_rate
        ``rate`` passed to ``freq_mask`` / ``freq_mix``; a non-positive value
        disables the corresponding augmentation.
    jitter_prob ... window_warp_scales
        Per-augmentation probabilities and parameters. An augmentation is only
        instantiated when its probability is ``> 0``.
    data_id_to_name_map
        Maps data ids (keys of the per-dataset loss dictionaries) to the dataset
        names used in logged metric keys.
    use_cosine_annealing_lr, cosine_annealing_lr_args
        Enable ``CosineAnnealingLR`` and the keyword arguments passed to it.
    track_loss_per_series
        Also log per-series losses at epoch end.
    nonnegative_pred_samples
        Clamp sampled predictions at zero with ``relu``.
    use_kv_cache
        Passed to the model at every sampling step in :meth:`forward`.
    """

    # NOTE: the list/dict defaults below are shared between instances; this
    # class only reads them and never mutates them.
    @validated()
    def __init__(
        self,
        model_kwargs: dict,
        context_length: int,
        prediction_length: int,
        loss: DistributionLoss = NegativeLogLikelihood(),
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        aug_prob: float = 0.1,
        freq_mask_rate: float = 0.1,
        freq_mixing_rate: float = 0.1,
        jitter_prob: float = 0.0,
        jitter_sigma: float = 0.03,
        scaling_prob: float = 0.0,
        scaling_sigma: float = 0.1,
        rotation_prob: float = 0.0,
        permutation_prob: float = 0.0,
        permutation_max_segments: int = 5,
        permutation_seg_mode: str = "equal",
        magnitude_warp_prob: float = 0.0,
        magnitude_warp_sigma: float = 0.2,
        magnitude_warp_knot: int = 4,
        time_warp_prob: float = 0.0,
        time_warp_sigma: float = 0.2,
        time_warp_knot: int = 4,
        window_slice_prob: float = 0.0,
        window_slice_reduce_ratio: float = 0.9,
        window_warp_prob: float = 0.0,
        window_warp_window_ratio: float = 0.1,
        window_warp_scales: list = [0.5, 2.0],
        data_id_to_name_map: dict = {},
        use_cosine_annealing_lr: bool = False,
        cosine_annealing_lr_args: dict = {},
        track_loss_per_series: bool = False,
        nonnegative_pred_samples: bool = False,
        use_kv_cache: bool = True,
    ):
        super().__init__()
        self.save_hyperparameters()
        hp = self.hparams

        self.context_length = hp.context_length
        self.prediction_length = hp.prediction_length
        self.model = LagLlamaModel(**hp.model_kwargs)
        self.loss = hp.loss
        self.lr = hp.lr
        self.weight_decay = hp.weight_decay

        # Augmentation settings.
        self.aug_prob = hp.aug_prob
        self.freq_mask_rate = hp.freq_mask_rate
        self.freq_mixing_rate = hp.freq_mixing_rate
        self.jitter_prob = hp.jitter_prob
        self.jitter_sigma = hp.jitter_sigma
        self.scaling_prob = hp.scaling_prob
        self.scaling_sigma = hp.scaling_sigma
        self.rotation_prob = hp.rotation_prob
        self.permutation_prob = hp.permutation_prob
        self.permutation_max_segments = hp.permutation_max_segments
        self.permutation_seg_mode = hp.permutation_seg_mode
        self.magnitude_warp_prob = hp.magnitude_warp_prob
        self.magnitude_warp_sigma = hp.magnitude_warp_sigma
        self.magnitude_warp_knot = hp.magnitude_warp_knot
        self.time_warp_prob = hp.time_warp_prob
        self.time_warp_sigma = hp.time_warp_sigma
        self.time_warp_knot = hp.time_warp_knot
        self.window_slice_prob = hp.window_slice_prob
        self.window_slice_reduce_ratio = hp.window_slice_reduce_ratio
        self.window_warp_prob = hp.window_warp_prob
        self.window_warp_window_ratio = hp.window_warp_window_ratio
        self.window_warp_scales = hp.window_warp_scales

        # Logging / optimisation / sampling settings.
        self.data_id_to_name_map = hp.data_id_to_name_map
        self.use_cosine_annealing_lr = hp.use_cosine_annealing_lr
        self.cosine_annealing_lr_args = hp.cosine_annealing_lr_args
        self.track_loss_per_series = hp.track_loss_per_series
        self.nonnegative_pred_samples = hp.nonnegative_pred_samples
        self.time_feat = hp.model_kwargs["time_feat"]
        self.use_kv_cache = hp.use_kv_cache

        # data_id based losses (lists of values, averaged at epoch end).
        self.train_loss_dict = {}
        self.val_loss_dict = {}
        # item_id based losses - to be used only in single-dataset mode.
        self.train_loss_dict_per_series = {}
        self.val_loss_dict_per_series = {}

        self.transforms = self._build_transforms()
        self.augmentations = ApplyAugmentations(self.transforms)

    def _build_transforms(self):
        """Instantiate, in a fixed order, every augmentation whose probability is > 0."""
        candidates = [
            (self.jitter_prob, lambda p: Jitter(p, self.jitter_sigma)),
            (self.scaling_prob, lambda p: Scaling(p, self.scaling_sigma)),
            (self.rotation_prob, lambda p: Rotation(p)),
            (
                self.permutation_prob,
                lambda p: Permutation(
                    p, self.permutation_max_segments, self.permutation_seg_mode
                ),
            ),
            (
                self.magnitude_warp_prob,
                lambda p: MagnitudeWarp(
                    p, self.magnitude_warp_sigma, self.magnitude_warp_knot
                ),
            ),
            (
                self.time_warp_prob,
                lambda p: TimeWarp(p, self.time_warp_sigma, self.time_warp_knot),
            ),
            (
                self.window_slice_prob,
                lambda p: WindowSlice(p, self.window_slice_reduce_ratio),
            ),
            (
                self.window_warp_prob,
                lambda p: WindowWarp(
                    p, self.window_warp_window_ratio, self.window_warp_scales
                ),
            ),
        ]
        return [build(prob) for prob, build in candidates if prob > 0]

    def forward(self, *args, **kwargs):
        """
        Autoregressively sample ``prediction_length`` future values.

        Every input series is repeated ``self.model.num_parallel_samples``
        times. At each step the model is run on the current history, one value
        is sampled from the distribution predicted for the last position, and
        that sample is appended to the history (marked as observed) before the
        next step.

        Expected keyword arguments are ``past_target`` and
        ``past_observed_values``, plus ``past_time_feat`` and
        ``future_time_feat`` when time features are enabled. Positional
        ``args`` are forwarded to the model unchanged.

        Returns
        -------
        Tensor of shape
        ``(-1, num_parallel_samples, prediction_length) + event_shape``.
        """
        num_samples = self.model.num_parallel_samples
        # Per the original annotations: (bsz, model.context_length+max(model.lags_seq))
        past_target = kwargs["past_target"]
        past_observed_values = kwargs["past_observed_values"]

        if self.time_feat:
            repeated_past_time_feat = kwargs["past_time_feat"].repeat_interleave(
                num_samples, 0
            )
            repeated_future_time_feat = kwargs["future_time_feat"].repeat_interleave(
                num_samples, 0
            )
        else:
            repeated_past_time_feat = None
            repeated_future_time_feat = None

        # repeat_interleave keeps the copies of one series adjacent, which the
        # final reshape to (-1, num_samples, ...) relies on.
        repeated_past_target = past_target.repeat_interleave(num_samples, 0)
        repeated_past_observed_values = past_observed_values.repeat_interleave(
            num_samples, 0
        )

        future_samples = []
        try:
            for t in range(self.prediction_length):
                # Future time features up to and including the step being predicted.
                step_future_time_feat = (
                    repeated_future_time_feat[..., : t + 1, :]
                    if self.time_feat
                    else None
                )
                params, loc, scale = self.model(
                    *args,
                    past_time_feat=repeated_past_time_feat,
                    future_time_feat=step_future_time_feat,
                    past_target=repeated_past_target,
                    past_observed_values=repeated_past_observed_values,
                    use_kv_cache=self.use_kv_cache,
                )
                # Keep only the parameters predicted for the last position.
                sliced_params = [p[:, -1:] for p in params]
                distr = self.model.distr_output.distribution(sliced_params, loc, scale)
                sample = distr.sample()  # (bsz * num_parallel_samples, 1)
                if self.nonnegative_pred_samples:
                    sample = F.relu(sample)
                future_samples.append(sample)
                # Feed the sample back as an observed value for the next step.
                repeated_past_target = torch.cat((repeated_past_target, sample), dim=1)
                repeated_past_observed_values = torch.cat(
                    (repeated_past_observed_values, torch.ones_like(sample)), dim=1
                )
        finally:
            # Reset the model cache even if sampling fails part-way, so a later
            # call does not start from a partially filled cache.
            self.model.reset_cache()

        concat_future_samples = torch.cat(future_samples, dim=-1)
        return concat_future_samples.reshape(
            (-1, num_samples, self.prediction_length)
            + self.model.distr_output.event_shape,
        )

    def _compute_loss(self, batch, do_not_average=False, return_observed_values=False):
        """
        Compute the masked loss over the predictable context plus the horizon.

        The target is the last ``context_length - 1`` values of ``past_target``
        (the first context value cannot be predicted) followed by
        ``future_target``. Each element's loss is multiplied by the matching
        observed-values mask, so positions with mask 0 contribute nothing.

        Parameters
        ----------
        batch
            Mapping with ``past_target``, ``past_observed_values``,
            ``future_target`` and ``future_observed_values``, plus
            ``past_time_feat`` / ``future_time_feat`` when time features are
            enabled.
        do_not_average
            If ``True``, return the masked per-element loss tensor. Otherwise
            return its sum divided by the number of observed values (clamped to
            at least 1).
        return_observed_values
            If ``True``, also return the observed-values mask aligned with the
            loss.

        Returns
        -------
        ``loss`` or ``(loss, observed_values)``.
        """
        past_target = batch["past_target"]
        past_observed_values = batch["past_observed_values"]
        future_target = batch["future_target"]
        future_observed_values = batch["future_observed_values"]
        if self.time_feat:
            past_time_feat = batch["past_time_feat"]
            future_time_feat = batch["future_time_feat"]
        else:
            past_time_feat = None
            future_time_feat = None

        # Leading dims that future_target has beyond past_target (usually none):
        # repeat the past tensors to match and flatten the future tensors.
        extra_dims = len(future_target.shape) - len(past_target.shape)
        extra_shape = future_target.shape[:extra_dims]
        repeats = prod(extra_shape)
        past_target = repeat_along_dim(past_target, 0, repeats)
        past_observed_values = repeat_along_dim(past_observed_values, 0, repeats)
        future_target_reshaped = future_target.reshape(
            -1, *future_target.shape[extra_dims + 1 :]
        )
        future_observed_reshaped = future_observed_values.reshape(
            -1, *future_observed_values.shape[extra_dims + 1 :]
        )

        distr_args, loc, scale = self.model(
            past_target=past_target,
            past_observed_values=past_observed_values,
            past_time_feat=past_time_feat,
            future_time_feat=future_time_feat,
            future_target=future_target_reshaped,
        )

        # Values the model can predict: context minus its first step, then horizon.
        context_target = take_last(past_target, dim=-1, num=self.context_length - 1)
        target = torch.cat((context_target, future_target_reshaped), dim=1)
        context_observed = take_last(
            past_observed_values, dim=-1, num=self.context_length - 1
        )
        observed_values = torch.cat((context_observed, future_observed_reshaped), dim=1)

        if isinstance(self.model.distr_output, ImplicitQuantileNetworkOutput):
            # The IQN output computes its loss directly from the raw arguments.
            loss_per_element = self.model.distr_output.loss(
                target, distr_args, loc, scale
            )
        else:
            distr = self.model.distr_output.distribution(
                distr_args, loc=loc, scale=scale
            )
            loss_per_element = self.loss(distr, target)

        loss = loss_per_element * observed_values
        if not do_not_average:
            loss = loss.sum() / observed_values.sum().clamp_min(1.0)

        if return_observed_values:
            return loss, observed_values
        return loss

    def _augment_batch(self, batch):
        """Apply the enabled augmentations to ``batch``'s targets, in place."""
        # Freq mask and freq mix have separate functions.
        if self.freq_mask_rate > 0:
            batch["past_target"], batch["future_target"] = freq_mask(
                batch["past_target"],
                batch["future_target"],
                rate=self.freq_mask_rate,
            )
        if self.freq_mixing_rate > 0:
            batch["past_target"], batch["future_target"] = freq_mix(
                batch["past_target"],
                batch["future_target"],
                rate=self.freq_mixing_rate,
            )
        # Other (gluonts-style) augmentations.
        if self.transforms:
            batch["past_target"], batch["future_target"] = self.augmentations(
                batch["past_target"], batch["future_target"]
            )

    def training_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute a training step.

        With probability ``aug_prob`` the batch is augmented first. The masked
        average loss is logged as ``train_loss`` (per epoch) and returned.
        """
        if random.random() < self.aug_prob:
            self._augment_batch(batch)
        train_loss_avg = self._compute_loss(batch)
        self.log(
            "train_loss", train_loss_avg, on_epoch=True, on_step=False, prog_bar=False
        )
        return train_loss_avg

    def _log_average_losses(self, loss_dict, metric_name):
        """Log ``np.mean`` of each value in ``loss_dict`` under ``metric_name(key)``."""
        for key, values in loss_dict.items():
            self.log(
                metric_name(key),
                np.mean(values),
                on_epoch=True,
                on_step=False,
                prog_bar=False,
            )

    def on_train_epoch_end(self):
        """Log the epoch's accumulated training losses, then reset the accumulators."""
        # NOTE(review): nothing in this class writes to train_loss_dict or
        # train_loss_dict_per_series, so these only log if other code fills them.
        self._log_average_losses(
            self.train_loss_dict,
            lambda key: f"train_loss_avg_per_train_dataset/{self.data_id_to_name_map[key]}",
        )
        if self.track_loss_per_series:
            self._log_average_losses(
                self.train_loss_dict_per_series,
                lambda key: f"train_loss_avg_per_train_series/{key}",
            )
        self.train_loss_dict = {}
        self.train_loss_dict_per_series = {}

    def validation_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute a validation step.

        The masked average loss is logged as ``val_loss`` (per epoch) and
        returned. No augmentation is applied.
        """
        val_loss_avg = self._compute_loss(batch)
        self.log("val_loss", val_loss_avg, on_epoch=True, on_step=False, prog_bar=False)
        return val_loss_avg

    def on_validation_epoch_end(self):
        """Log the epoch's accumulated validation losses, then reset the accumulators."""

        def dataset_metric(key):
            # Non-negative data ids are logged under the train-dataset prefix,
            # negative ids under the test-dataset prefix.
            split = "train" if key >= 0 else "test"
            return f"val_loss_avg_per_{split}_dataset/{self.data_id_to_name_map[key]}"

        # NOTE(review): nothing in this class writes to val_loss_dict or
        # val_loss_dict_per_series, so these only log if other code fills them.
        self._log_average_losses(self.val_loss_dict, dataset_metric)
        if self.track_loss_per_series:
            self._log_average_losses(
                self.val_loss_dict_per_series,
                lambda key: f"val_loss_avg_per_train_series/{key}",
            )
        self.val_loss_dict = {}
        self.val_loss_dict_per_series = {}

    def configure_optimizers(self):
        """
        Return an Adam optimizer over the model parameters.

        When ``use_cosine_annealing_lr`` is set, a ``CosineAnnealingLR``
        scheduler built from ``cosine_annealing_lr_args`` is returned alongside
        it in Lightning's ``{"optimizer": ..., "lr_scheduler": ...}`` form.
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        if not self.use_cosine_annealing_lr:
            return optimizer
        # verbose=True stays the default, but a "verbose" entry in
        # cosine_annealing_lr_args now overrides it instead of raising
        # "got multiple values for keyword argument 'verbose'".
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases.
        scheduler_kwargs = {"verbose": True, **self.cosine_annealing_lr_args}
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, **scheduler_kwargs
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}