-
Notifications
You must be signed in to change notification settings - Fork 749
Description
Describe the bug
Calling `best_tft.plot_prediction(x, raw_predictions, ...)` raises `AttributeError: 'list' object has no attribute 'size'`. The failure happens inside `TemporalFusionTransformer.interpret_output`, where `out["encoder_attention"]` is a list of tensors rather than a single stacked tensor, so `encoder_attention.size(3)` fails.
To Reproduce
Build TimeSeriesDataSet
# Reproduction snippet: build an inference TimeSeriesDataSet mirroring the
# training configuration, create a dataloader, then call plot_prediction on
# the fitted TFT (this last call is what raises the AttributeError below).
# NOTE(review): `pred`, `max_encoder_length`, `max_prediction_length`,
# `static_cats`, `static_reals`, `time_varying_cats`, `categorical_encoders`,
# `best_tft`, `x` and `raw_predictions` are defined elsewhere in the notebook
# — not shown in this report.
predicting = TimeSeriesDataSet(
    pred,  # presumably the inference DataFrame — defined earlier in the notebook
    time_idx="time_idx",
    target="DEPARTURE_DELAY",
    group_ids=['IATA_CODE'],
    min_encoder_length=max_prediction_length,  # length to encode (can be far longer than the decoder length but does not have to be)
    max_encoder_length=max_encoder_length,  # maximum history window fed to the encoder
    max_prediction_length=max_prediction_length,
    static_categoricals=static_cats,
    static_reals=static_reals,
    time_varying_known_categoricals=time_varying_cats,
    # Features known in advance at prediction time (schedule-derived columns).
    time_varying_known_reals=[
        "SCHEDULED_DEPARTURE", "DISTANCE", "SCHEDULED_ARRIVAL",
        "SCHEDULED_TIME", "ARRIVAL_TIME", "time_idx"
    ],
    # Features only observed after the fact (actuals, weather, delay breakdowns).
    # NOTE(review): the target "DEPARTURE_DELAY" is also listed here.
    time_varying_unknown_reals=[
        "DEPARTURE_TIME", "ARRIVAL_DELAY", "SCHEDULED_TIME",
        "WIND_GUST_SPEED", "VIS_DIST", "TMP_CELSIUS", "SLP_PRESSURE",
        "DEW_CELSIUS", "CEILING_QUALITY", "CIG_HEIGHT", "WIND_DIRECTION",
        "SECURITY_DELAY", "AIRLINE_DELAY", "LATE_AIRCRAFT_DELAY",
        "WEATHER_DELAY", "AIR_SYSTEM_DELAY", "DIVERTED", "CANCELLED",
        "TAXI_IN", "TAXI_OUT", "WHEELS_OFF",
        "WHEELS_ON", "ELAPSED_TIME", "AIR_TIME", "DEPARTURE_DELAY"
    ],
    target_normalizer=GroupNormalizer(groups=['IATA_CODE'], transformation="softplus"),  # use softplus and normalize by group
    categorical_encoders=categorical_encoders,  # Apply NaN encoders
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)
# NOTE(review): train=True on a prediction dataloader shuffles/samples like
# training — typically train=False is used for inference; verify intent.
pred_dataloader = predicting.to_dataloader(train=True, batch_size=128, num_workers=4)
# The failing call: raises AttributeError (traceback below).
best_tft.plot_prediction(x, raw_predictions, idx=0, show_future_observed=False);
AttributeError Traceback (most recent call last)
Cell In[231], line 1
----> 1 best_tft.plot_prediction(x, raw_predictions, idx=0, show_future_observed=False);
File c:\Users\qj771f\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_forecasting\models\temporal_fusion_transformer\__init__.py:719, in TemporalFusionTransformer.plot_prediction(self, x, out, idx, plot_attention, add_loss_to_title, show_future_observed, ax, **kwargs)
717 # add attention on secondary axis
718 if plot_attention:
--> 719 interpretation = self.interpret_output(out.iget(slice(idx, idx + 1)))
720 for f in to_list(fig):
721 ax = f.axes[0]
File c:\Users\qj771f\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_forecasting\models\temporal_fusion_transformer\__init__.py:597, in TemporalFusionTransformer.interpret_output(self, out, reduction, attention_prediction_horizon)
595 # roll encoder attention (so start last encoder value is on the right)
596 encoder_attention = out["encoder_attention"]
--> 597 shifts = encoder_attention.size(3) - out["encoder_lengths"]
598 new_index = (
599 torch.arange(encoder_attention.size(3), device=encoder_attention.device)[None, None, None].expand_as(
600 encoder_attention
601 )
602 - shifts[:, None, None, None]
603 ) % encoder_attention.size(3)
604 encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)
AttributeError: 'list' object has no attribute 'size'
**Expected behavior**
`plot_prediction` should return the prediction figure without raising. In particular, `interpret_output` should either receive (or internally stack) `encoder_attention` as a tensor so that `encoder_attention.size(3)` succeeds, instead of failing on a list of tensors.
**Additional context**
<!--
Add any other context about the problem here.
-->
**Versions**
<details>
<!--
Please run the following code snippet and paste the output here:
from sktime import show_versions; show_versions()
-->
</details>
<!-- Thanks for contributing! -->
Metadata
Metadata
Assignees
Labels
Type
Projects
Status
