In [1]:
import xarray as xr
import ocf_blosc2
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")
import numpy as np
from pytorch_forecasting.models import TemporalFusionTransformer
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
import torch

In [2]:
test_data = pd.read_csv("result_data/tft_data_36_test.csv")

In [3]:
test_data = test_data.drop(columns=['Unnamed: 0'])
test_data.columns

Index(['ss_id', 'init_time', 'step', 'pv_datetime', 'pv_hour', 'horizon',
       'generation', 'capacity', 'normalize_generation', 'lat', 'long', 'tilt',
       'orientation', 'dlwrf', 'dswrf', 'duvrs', 'hcc', 'lcc', 'mcc', 'sde',
       'sr', 't2m', 'tcc', 'u10', 'u100', 'v10', 'v100'],
      dtype='object')

In [4]:
best_model_path = 'lightning_logs/my_model/version_20/checkpoints/epoch=49-step=1500.ckpt'
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)


In [5]:
max_encoder_length = 36
max_prediction_length = 36

test_data.rename(columns={'pv_hour': 'day_hour'}, inplace=True)

test_data['ss_id'] = test_data['ss_id'].astype(int)

test_data['pv_datetime'] = pd.to_datetime(test_data['pv_datetime'])
test_data['date'] = test_data['pv_datetime'].dt.date
test_data['day_of_week'] = test_data['pv_datetime'].dt.dayofweek
test_data['month'] = test_data['pv_datetime'].dt.month

test_data['ss_id'] = test_data['ss_id'].astype(str)
test_data['day_of_week'] = test_data['day_of_week'].astype(str)
test_data['month'] = test_data['month'].astype(str)
test_data['day_hour'] = test_data['day_hour'].astype(str)

test_data['time_idx'] = test_data.index
test_data['time_idx'] = test_data['time_idx'].astype(int)

In [6]:
test_data.shape

(3600, 31)

In [7]:
new_data = TimeSeriesDataSet(
    test_data,
    time_idx="time_idx",
    target="normalize_generation",
    group_ids=["ss_id"],  # Grouping by ss_id to identify different PV sites
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_reals=["capacity", "lat", "long", "tilt", "orientation"],
    time_varying_known_categoricals=[ "month", "day_of_week", "day_hour"],
    time_varying_known_reals=["time_idx", "dlwrf", "dswrf", "duvrs", "hcc", "lcc", "mcc", "sde", "sr", "t2m", "tcc", "u10", "u100", "v10", "v100"],
    time_varying_unknown_reals=["normalize_generation"],
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=False,
)


In [8]:
new_data_loader = new_data.to_dataloader(train=False, batch_size=128, num_workers=2)

In [9]:
# lightning_logs/my_model/version_20/checkpoints/epoch=49-step=1500.ckpt
# actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
# predictions = best_tft.predict(val_dataloader)

In [10]:
actuals = torch.cat([y[0] for x, y in iter(new_data_loader)])

In [11]:
predictions = best_tft.predict(new_data_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [12]:
print("Checking for NaN values in actuals and predictions...")
print("Number of NaNs in actuals:", torch.isnan(actuals).sum().item())
print("Number of NaNs in predictions:", torch.isnan(predictions).sum().item())

Checking for NaN values in actuals and predictions...
Number of NaNs in actuals: 0
Number of NaNs in predictions: 35208


In [13]:

mae = (actuals - predictions).abs().mean().item()
print(f"Mean Absolute Error: {mae}")


average_p50_loss = (actuals - predictions).abs().mean(axis=1)
print(f"Average P50 Loss per Time Series: {average_p50_loss}")


Mean Absolute Error: nan
Average P50 Loss per Time Series: tensor([14.8395, 14.8610, 14.8481,  ..., 64.7494, 67.9491, 70.2907])


In [14]:
print("Actuals:", actuals.tolist())
print("Predictions:", predictions.tolist())

Actuals: [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 146.5113067626953, 251.7209930419922, 300.4230041503906, 317.4930114746094, 397.4580078125, 395.8500061035156, 353.5799865722656, 263.70361328125, 185.38619995117188, 98.48880004882812, 2.6986650482285768e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 146.5113067626953, 251.7209930419922, 300.4230041503906, 317.4930114746094, 397.4580078125, 395.8500061035156, 353.5799865722656, 263.70361328125, 185.38619995117188, 98.48880004882812, 2.6986650482285768e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 146.5113067626953, 251.7209930419922, 300.4230041503906, 317.4930114746094, 397.4580078125, 395.8500061035156, 353.5799865722656, 263.70361328125, 185.38619995117188, 98.48880004882812, 2.6986650482285768e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 

In [15]:
mask = ~torch.isnan(actuals) & ~torch.isnan(predictions)
actuals_filtered = actuals[mask]
predictions_filtered = predictions[mask]

# Calculate MAE
mae = (actuals_filtered - predictions_filtered).abs().mean().item()
print(f"Mean Absolute Error: {mae}")


Mean Absolute Error: 58.833858489990234
