Description
- PyTorch-Forecasting version: 1.0.0
- PyTorch version: 2.0.1
- Python version: 3.10
- Operating System: Ubuntu 22.04
Expected behavior
Training a TFT on the UCI electricity dataset, set up following the examples and using Lightning, should run without exhausting memory: the train/val datasets fit well within system memory.
Actual behavior
Throughout training, CPU RAM consumption grows steadily until it exhausts the system's 96 GB, at which point the process runs out of memory and the kernel crashes. I have tried disabling logging everywhere I know of in PyTorch Forecasting and Lightning, to no avail. GPU memory usage stays steady at ~1.2 GB.
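To narrow down where the growth happens, it may help to log resident memory of the main process and of the DataLoader worker processes over time (with num_workers=8, batch preparation runs in child processes). The sketch below reads RSS from /proc, so it is Linux-only (the report is on Ubuntu 22.04). The helper names rss_mib, child_pids and total_rss_mib are hypothetical and are not part of PyTorch Forecasting or Lightning.

```python
import os


def rss_mib(pid):
    """Resident set size of one process in MiB, read from /proc (Linux only)."""
    with open(f"/proc/{pid}/status") as f:
        for line in f:
            if line.startswith("VmRSS:"):
                return int(line.split()[1]) / 1024  # value is reported in kB
    return 0.0  # zombies and kernel threads have no VmRSS line


def child_pids(parent):
    """PIDs whose parent is `parent`, e.g. DataLoader worker processes."""
    pids = []
    for entry in os.listdir("/proc"):
        if not entry.isdigit():
            continue
        try:
            with open(f"/proc/{entry}/status") as f:
                for line in f:
                    if line.startswith("PPid:"):
                        if int(line.split()[1]) == parent:
                            pids.append(int(entry))
                        break
        except (FileNotFoundError, ProcessLookupError, PermissionError):
            continue  # process exited (or is unreadable) while scanning
    return pids


def total_rss_mib():
    """RSS of this process plus its direct children.

    Copy-on-write pages shared with forked workers are counted once per
    process, so the absolute number overestimates real usage; the trend
    over time is what matters for spotting a leak.
    """
    me = os.getpid()
    total = rss_mib(me)
    for pid in child_pids(me):
        try:
            total += rss_mib(pid)
        except (FileNotFoundError, ProcessLookupError):
            pass  # worker exited between listing and reading
    return total
```

Sampling total_rss_mib() every few hundred steps (for example from a Lightning callback's on_train_batch_end hook) and comparing a run with num_workers=8 against one with num_workers=0 would indicate whether the growth is in the workers or in the main process.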
Code to reproduce the problem
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer

# df, output_size and quantile_loss are defined elsewhere in the original script (not shown)
max_prediction_length = 24
max_encoder_length = 7 * 24  # one week of hourly history
batch_size = 64
num_workers = 8
split_time_idx = 30000

# Fixed-length windows: min and max lengths are equal and randomization is off
train_data = TimeSeriesDataSet(
    data=df[lambda x: x.time_idx < split_time_idx],
    time_idx='time_idx',
    target='demand',
    group_ids=['group_id'],
    min_encoder_length=max_encoder_length,
    max_encoder_length=max_encoder_length,
    min_prediction_length=max_prediction_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=['group_id'],
    time_varying_known_reals=['time_idx', 'hour', 'weekday', 'day', 'month'],
    time_varying_unknown_reals=['demand'],
    target_normalizer=GroupNormalizer(groups=['group_id'], transformation='softplus'),
    add_relative_time_idx=True,
    add_target_scales=True,
    randomize_length=False,
)

# Validation set reuses the training set's parameters and fitted normalizers
val_data = TimeSeriesDataSet.from_dataset(
    train_data,
    df[lambda x: x.time_idx >= split_time_idx],
    stop_randomization=True,
    predict=False,
)

train_dataloader = train_data.to_dataloader(train=True, batch_size=batch_size, num_workers=num_workers)
val_dataloader = val_data.to_dataloader(train=False, batch_size=batch_size, num_workers=num_workers)

tft = TemporalFusionTransformer.from_dataset(
    train_data,
    learning_rate=0.001,
    hidden_size=160,
    hidden_continuous_size=160,
    attention_head_size=4,
    dropout=0.1,
    output_size=output_size,
    loss=quantile_loss,
    log_interval=-1,  # disables prediction-plot logging
    reduce_on_plateau_patience=4,
)
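For a sense of scale: with min and max lengths equal and randomize_length=False, each contiguous series should contribute roughly one sample per start position that fits a full window, i.e. series length minus (encoder length + prediction length) plus one. A minimal sketch of that arithmetic, assuming gap-free series whose time_idx starts at 0; the function names are hypothetical, and the figure of 370 series is an assumption about df that the report does not state.

```python
import math


def windows_per_series(series_length, encoder_length, prediction_length):
    """Samples produced by one gap-free series when every start position
    that fits a full encoder + prediction window is used."""
    window = encoder_length + prediction_length
    return max(series_length - window + 1, 0)


def epoch_size(series_lengths, encoder_length, prediction_length, batch_size):
    """Return (total samples, batches per epoch if the last partial batch is kept)."""
    samples = sum(
        windows_per_series(n, encoder_length, prediction_length) for n in series_lengths
    )
    return samples, math.ceil(samples / batch_size)


# Training split from the report: time_idx < 30000, i.e. 30000 steps per series
# if time_idx starts at 0 (assumption). 370 series is also an assumption.
samples, batches = epoch_size([30000] * 370, encoder_length=7 * 24,
                              prediction_length=24, batch_size=64)
print(samples, batches)  # 11029330 172334
```

Under these assumptions an epoch is on the order of 170k batches, so memory that grows by even a small amount per batch would accumulate to a large total well within one epoch.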