# 1. Import Data

HINWEIS: Hier arbeite ich aktuell noch aktiv dran (siehe PDF, S. 15). 

In [30]:
# Standard libs
import os
import glob
import json
from dataclasses import dataclass
from __future__ import annotations

# Utils
import math
import random
import gc

# Data & Processing
import numpy as np
import pandas as pd
import geopandas as gpd

# Scikit-Learn
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_absolute_error, confusion_matrix
from sklearn.neighbors import BallTree, NearestNeighbors
from sklearn.model_selection import train_test_split

# TQDM
from tqdm.notebook import tqdm
from joblib import Parallel, delayed

# Plotting
import matplotlib.pyplot as plt

# PyTorch
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

# PyTorch Geometric
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv


# Show where the kernel is running so that relative paths can be interpreted
print(f"Current working directory: {os.getcwd()}")

Current working directory: /home/qusta100/STGNN


In [None]:
# Load data
# Read the prepared CSV. The column at position 9 is forced to string dtype
# (presumably to avoid mixed-type inference on that column; TODO confirm which
# column this is).
# NOTE(review): this is an absolute, cluster-specific path; the cell will not
# run on another machine without editing it.
df = pd.read_csv(
    "/gpfs/scratch/qusta100/STGNN/Data/Temp/final.csv",
    dtype={9: str}
)
# Parse the timestamp column; from here on "date" is a datetime column.
df["date"] = pd.to_datetime(df["date"])

In [10]:
# Filter for one specific day
# Because of limited computational resources, the dataset is restricted to a
# single day, which gives 96 quarter-hourly observations per station.
#
# "date" is converted to datetime in the load cell, so the `.str` accessor
# would raise AttributeError on a fresh "Restart & Run All". Comparing the
# normalised (midnight) timestamp works for a datetime column, and the
# pd.to_datetime() wrapper also accepts raw strings, so this cell is
# independent of whether the conversion has already been applied.
df = df[pd.to_datetime(df["date"]).dt.normalize() == pd.Timestamp("2025-04-01")]
df.head()

Unnamed: 0,date,station_uuid,diesel,e5,e10,uuid,name,brand,street,house_number,...,openingtimes_json,day,in_thuringia,in_region,last_seen,is_open,weekday,holiday,time,Brent_Price
0,2025-04-01 00:00:00,00060065-7890-4444-8888-acdc00000004,1.559,1.709,1.649,00060065-7890-4444-8888-acdc00000004,Georg Ultsch GmbH,Tankstelle Lichtenfels,Robert-Koch-Str.,18,...,{},2025-04-30,False,1,2025-04-30 21:45:00,True,Tuesday,0,00:00,74.959999
1,2025-04-01 00:15:00,00060065-7890-4444-8888-acdc00000004,1.559,1.709,1.649,00060065-7890-4444-8888-acdc00000004,Georg Ultsch GmbH,Tankstelle Lichtenfels,Robert-Koch-Str.,18,...,{},2025-04-30,False,1,2025-04-30 21:45:00,True,Tuesday,0,00:15,74.959999
2,2025-04-01 00:30:00,00060065-7890-4444-8888-acdc00000004,1.559,1.709,1.649,00060065-7890-4444-8888-acdc00000004,Georg Ultsch GmbH,Tankstelle Lichtenfels,Robert-Koch-Str.,18,...,{},2025-04-30,False,1,2025-04-30 21:45:00,True,Tuesday,0,00:30,74.959999
3,2025-04-01 00:45:00,00060065-7890-4444-8888-acdc00000004,1.559,1.709,1.649,00060065-7890-4444-8888-acdc00000004,Georg Ultsch GmbH,Tankstelle Lichtenfels,Robert-Koch-Str.,18,...,{},2025-04-30,False,1,2025-04-30 21:45:00,True,Tuesday,0,00:45,74.959999
4,2025-04-01 01:00:00,00060065-7890-4444-8888-acdc00000004,1.559,1.709,1.649,00060065-7890-4444-8888-acdc00000004,Georg Ultsch GmbH,Tankstelle Lichtenfels,Robert-Koch-Str.,18,...,{},2025-04-30,False,1,2025-04-30 21:45:00,True,Tuesday,0,01:00,74.959999


# 2. STGNN Design

In [31]:
# =============================
# a) Parameters
# =============================

# Column names expected in the input DataFrame
STATION_COL = "station_uuid"   # station identifier; one graph node per station
LAT_COL = "latitude"           # converted with np.radians in the graph builder, so assumed to be degrees
LON_COL = "longitude"          # see LAT_COL
TARGET_COL = "e5"              # price column used both as input feature and as forecast target
TIME_COL = "date"              # timestamp column defining the time axis

# Stations closer than this great-circle distance are connected in the graph
RADIUS_KM = 10.0

# Model / optimisation hyperparameters
EMBED_DIM = 32                 # size of the learned per-node embedding
HIDDEN_DIM = 64                # width of the projection, graph-conv and GRU layers
LR = 1e-3                      # AdamW learning rate
WEIGHT_DECAY = 1e-4            # AdamW weight decay
EPOCHS = 50                    # maximum number of training epochs
PATIENCE = 10                  # epochs without validation-MAE improvement before early stopping
VAL_SPLIT = 0.15               # fraction of usable time indices held out for validation
TEST_SPLIT = 0.15              # fraction of usable time indices held out for testing
SEED = 42                      # default seed for the high-level training API

# Temporal setup (in time steps of the input data)
WINDOW_SIZE = 16               # number of past steps fed to the model
HORIZON_STEPS = 4              # number of future steps predicted at once


# =============================
# b) Helper functions
# =============================

def set_seed(seed: int = 42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU generator plus all CUDA devices) with the same value.

    Args:
        seed: Seed value applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def km_to_radians(km: float) -> float:
    """Convert a surface distance in kilometres to an angle in radians.

    The angle is the arc length divided by a fixed Earth radius of
    6371.0088 km, which is the unit a haversine-metric radius query expects.

    Args:
        km: Distance in kilometres.

    Returns:
        The corresponding central angle in radians.
    """
    radius_of_earth_km = 6371.0088
    angle_rad = km / radius_of_earth_km
    return angle_rad


def build_radius_graph(lat_lon_deg: np.ndarray, radius_km: float) -> np.ndarray:
    """Connect every pair of points that lie within ``radius_km`` of each other.

    Args:
        lat_lon_deg: Array of shape ``[N, 2]`` holding (latitude, longitude)
            in degrees, one row per node.
        radius_km: Great-circle distance threshold in kilometres.

    Returns:
        Integer array of shape ``[2, E]``. Row 0 holds source node indices and
        row 1 holds target node indices. For each target node ``i`` (in
        ascending order) its neighbours ``j != i`` are listed in the order
        returned by the tree query; self-loops are excluded. The dtype is
        always ``int64``, including when no edges exist (``E == 0``).
    """
    lat_lon_rad = np.radians(lat_lon_deg)
    tree = BallTree(lat_lon_rad, metric="haversine")
    rad = km_to_radians(radius_km)

    # One index array per query point; each contains the point itself.
    ind = tree.query_radius(lat_lon_rad, r=rad, return_distance=False)

    # Flatten the ragged neighbour lists without a Python double loop:
    # dst repeats each query index once per neighbour, src is the neighbours.
    counts = np.fromiter((len(neigh) for neigh in ind), dtype=np.int64, count=len(ind))
    dst = np.repeat(np.arange(len(ind), dtype=np.int64), counts)
    src = np.concatenate(list(ind)).astype(np.int64, copy=False)

    # Drop self-loops. Boolean masking keeps the integer dtype even when the
    # result is empty (stacking two empty Python lists would yield float64).
    keep = src != dst
    return np.vstack([src[keep], dst[keep]])


# =============================
# c) Window-based ST data (multi-step)
# =============================

@dataclass
class STWindowDataMulti:
    """Container for everything the windowed multi-step STGNN training needs.

    Holds the dense time series, the multi-horizon targets, the static station
    graph, the time indices at which train/validation/test windows end, and
    bookkeeping information about nodes and timestamps.
    """
    x: torch.Tensor             # [T, N, F] full input time series (float32)
    y: torch.Tensor             # [T, N, H] targets for H horizons at each t; NaN where no target exists
    edge_index: torch.Tensor    # [2, E] long tensor of (source, target) node indices
    train_end_idx: np.ndarray   # time indices at which training windows end
    val_end_idx: np.ndarray     # time indices at which validation windows end
    test_end_idx: np.ndarray    # time indices at which test windows end
    valid_nodes: np.ndarray     # node indices used for loss and metrics (stations with in_thuringia == 1)
    times: np.ndarray           # sorted unique timestamps; position = time index
    meta: pd.DataFrame          # station metadata (uuid, lat, lon, node_id)


def make_window_data_multi_from_df(
    df: pd.DataFrame,
    window_size: int = WINDOW_SIZE,
    horizon_steps: int = HORIZON_STEPS
) -> STWindowDataMulti:
    """Turn a long-format price table into dense tensors, a station graph and time splits.

    Steps:
      1. Assign one node id per station and one time id per unique timestamp.
      2. Connect stations within ``RADIUS_KM`` of each other.
      3. Scatter the (station, time) mean prices into a dense ``[T, N]`` grid.
      4. Derive the input ``x`` (missing prices encoded as 0.0) and the
         targets ``y[t, n, h] = price[t + h + 1, n]`` (NaN where that price is
         missing, so the loss/metric masks can skip it).
      5. Split the usable time indices chronologically into train/val/test
         and keep only indices that have a full input window behind them.

    Targets are aligned on the global time axis, not on each station's own
    sequence of observed rows; a station with gaps therefore gets NaN targets
    at the gaps instead of later prices shifted into earlier horizons.

    Args:
        df: Long-format frame containing ``STATION_COL``, ``LAT_COL``,
            ``LON_COL``, ``TARGET_COL``, ``TIME_COL`` and ``"in_thuringia"``.
        window_size: Number of past time steps per input window.
        horizon_steps: Number of future steps predicted per window.

    Returns:
        A populated ``STWindowDataMulti``.

    Raises:
        ValueError: If required columns are missing, if ``window_size`` or
            ``horizon_steps`` is smaller than 1, if there are too few time
            steps for the requested window and horizon, or if no station has
            ``in_thuringia == 1``.
    """
    needed = [STATION_COL, LAT_COL, LON_COL, TARGET_COL, TIME_COL, "in_thuringia"]
    missing = set(needed) - set(df.columns)
    if missing:
        raise ValueError(f"Missing columns: {missing}")
    if window_size < 1 or horizon_steps < 1:
        raise ValueError(
            f"window_size and horizon_steps must be >= 1 (got {window_size} and {horizon_steps})."
        )

    # Stations + node ids. De-duplicate on the station id only, so a station
    # that appears with slightly different coordinates still maps to one node.
    stations = (
        df[[STATION_COL, LAT_COL, LON_COL]]
        .drop_duplicates(subset=[STATION_COL])
        .reset_index(drop=True)
    )
    stations["node_id"] = np.arange(len(stations))
    station2id = stations.set_index(STATION_COL)["node_id"].to_dict()

    # Time axis
    times = np.sort(df[TIME_COL].unique())
    time2id = {t: i for i, t in enumerate(times)}

    N = len(stations)
    T = len(times)
    H = horizon_steps

    # Time indices where the full horizon lies inside the time axis
    effective_T = T - H
    if effective_T < window_size:
        raise ValueError(f"Too few timesteps ({effective_T}) for window size {window_size} and horizon {H}.")

    # Spatial graph
    lat_lon = stations[[LAT_COL, LON_COL]].to_numpy(dtype=float)
    edge_index_np = build_radius_graph(lat_lon, RADIUS_KM)
    edge_index = torch.tensor(edge_index_np, dtype=torch.long)

    # Aggregate per (station, time) and drop NaN prices
    df_idx = (
        df[[STATION_COL, TIME_COL, TARGET_COL]]
        .dropna(subset=[TARGET_COL])
        .groupby([STATION_COL, TIME_COL], as_index=False)
        .mean()
        .sort_values(TIME_COL)
    )
    node_ids = torch.tensor(df_idx[STATION_COL].map(station2id).to_numpy(dtype=np.int64), dtype=torch.long)
    time_ids = torch.tensor(df_idx[TIME_COL].map(time2id).to_numpy(dtype=np.int64), dtype=torch.long)
    prices = torch.tensor(df_idx[TARGET_COL].to_numpy(dtype=np.float32), dtype=torch.float32)

    # Dense price grid with NaN marking "no observation". The (time, node)
    # pairs are unique after the groupby, so a single scatter assignment is safe.
    price_grid = torch.full((T, N), float("nan"), dtype=torch.float32)
    price_grid[time_ids, node_ids] = prices

    # x[t, n, 0] = price at time t; missing observations are fed as 0.0
    x = torch.where(torch.isnan(price_grid), torch.zeros_like(price_grid), price_grid).unsqueeze(-1)

    # y[t, :, h] = price at t + h + 1, only for t with a complete horizon
    y = torch.full((T, N, H), float("nan"), dtype=torch.float32)
    for h in range(H):
        y[:effective_T, :, h] = price_grid[h + 1:effective_T + h + 1]

    # Time-based split over base indices 0..effective_T-1
    time_indices = np.arange(effective_T)
    num_total = len(time_indices)
    num_train = int((1.0 - VAL_SPLIT - TEST_SPLIT) * num_total)
    num_val = int(VAL_SPLIT * num_total)

    train_t = time_indices[:num_train]
    val_t = time_indices[num_train:num_train + num_val]
    test_t = time_indices[num_train + num_val:]

    # Window end indices must satisfy t >= window_size - 1
    min_end = window_size - 1
    train_end_idx = train_t[train_t >= min_end]
    val_end_idx = val_t[val_t >= min_end]
    test_end_idx = test_t[test_t >= min_end]

    # Thuringia flag per station (first occurrence), used to pick loss nodes
    th_flag = (
        df[[STATION_COL, "in_thuringia"]]
        .drop_duplicates(subset=[STATION_COL])
        .set_index(STATION_COL)["in_thuringia"]
    )
    stations = stations.join(th_flag, on=STATION_COL)
    valid_nodes = stations.loc[stations["in_thuringia"] == 1, "node_id"].to_numpy()

    if valid_nodes.size == 0:
        raise ValueError("No nodes with in_thuringia == 1 found.")

    meta = stations[[STATION_COL, LAT_COL, LON_COL, "node_id"]].copy()

    return STWindowDataMulti(
        x=x,
        y=y,
        edge_index=edge_index,
        train_end_idx=train_end_idx,
        val_end_idx=val_end_idx,
        test_end_idx=test_end_idx,
        valid_nodes=valid_nodes,
        times=times,
        meta=meta
    )


# =============================
# d) STGNN with windowing (multi-step output)
# =============================

class STGNNWindowMultiRegressor(nn.Module):
    """Graph-conv + GRU regressor predicting several future steps per node.

    Every time step of the input window is encoded independently: the node
    features are concatenated with a learned per-node embedding, projected to
    the hidden width, and passed through one SAGEConv layer. The resulting
    sequence of node states is summarised by a GRU (nodes act as the batch
    dimension), and a linear head maps the final GRU state of each node to
    ``horizon_steps`` outputs.
    """

    def __init__(self, num_nodes: int, embed_dim: int, hidden_dim: int, horizon_steps: int, in_dim: int = 1):
        """Create the layers.

        Args:
            num_nodes: Number of graph nodes (rows of the embedding table).
            embed_dim: Size of the learned node embedding.
            hidden_dim: Width of the projection, graph-conv and GRU layers.
            horizon_steps: Number of future steps predicted per node.
            in_dim: Number of input features per node and time step.
        """
        super().__init__()
        self.horizon_steps = horizon_steps
        # Layers are created in a fixed order because their random
        # initialisation consumes the global torch RNG sequentially.
        self.emb = nn.Embedding(num_nodes, embed_dim)
        self.proj = nn.Linear(in_dim + embed_dim, hidden_dim)
        self.conv = SAGEConv(hidden_dim, hidden_dim)
        self.act = nn.ReLU()
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=False)
        self.head = nn.Linear(hidden_dim, horizon_steps)

    def _encode_step(self, frame: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Spatially encode the node features of a single time step ([N, F] -> [N, hidden])."""
        with_identity = torch.cat([frame, self.emb.weight], dim=1)
        projected = self.act(self.proj(with_identity))
        return self.act(self.conv(projected, edge_index))

    def forward(self, x_window: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Predict all horizons for every node from one input window.

        Args:
            x_window: Tensor of shape ``[W, N, F]`` (window steps, nodes, features).
            edge_index: Graph connectivity passed to the SAGEConv layer.

        Returns:
            Tensor of shape ``[N, horizon_steps]``.
        """
        num_steps, _, _ = x_window.shape

        encoded_steps = [
            self._encode_step(x_window[step], edge_index)
            for step in range(num_steps)
        ]
        sequence = torch.stack(encoded_steps, dim=0)

        # Only the final hidden state of the (single-layer) GRU is needed.
        _, final_state = self.gru(sequence)
        return self.head(final_state[-1])


# =============================
# e) Training and evaluation (multi-step)
# =============================

@dataclass
class STTrainConfig:
    """Training hyperparameters for the windowed multi-step STGNN.

    Defaults are taken from the module-level constants at class-definition time.
    """
    lr: float = LR                        # AdamW learning rate
    weight_decay: float = WEIGHT_DECAY    # AdamW weight decay
    epochs: int = EPOCHS                  # maximum number of training epochs
    patience: int = PATIENCE              # epochs without validation improvement before stopping
    window_size: int = WINDOW_SIZE        # number of past steps per input window
    horizon_steps: int = HORIZON_STEPS    # number of future steps predicted per window


def _compute_window_metrics_multi(
    model: nn.Module,
    data: STWindowDataMulti,
    end_indices: np.ndarray,
    device: torch.device,
    loss_nodes: torch.Tensor,
) -> dict:

    x_full = data.x.to(device)
    y_full = data.y.to(device)
    edge_index = data.edge_index.to(device)
    W = WINDOW_SIZE
    H = y_full.shape[-1]

    preds_all = []
    trues_all = []

    model.eval()
    with torch.no_grad():
        for end_t in end_indices:
            start_t = end_t - W + 1
            x_win = x_full[start_t:end_t+1]
            y_target = y_full[end_t]

            pred = model(x_win, edge_index)

            y_nodes = y_target[loss_nodes]
            p_nodes = pred[loss_nodes]

            mask = torch.isfinite(y_nodes)
            if mask.sum() == 0:
                continue

            preds_all.append(p_nodes[mask].cpu().numpy())
            trues_all.append(y_nodes[mask].cpu().numpy())

    if len(preds_all) == 0:
        return {"MAE": float("nan"), "RMSE": float("nan")}

    preds_concat = np.concatenate(preds_all)
    trues_concat = np.concatenate(trues_all)

    mae = np.mean(np.abs(preds_concat - trues_concat))
    mse = np.mean((preds_concat - trues_concat) ** 2)
    rmse = math.sqrt(mse)

    return {"MAE": float(mae), "RMSE": float(rmse)}


def train_stgnn_window_multi(data: STWindowDataMulti, cfg: STTrainConfig) -> dict:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    x_full = data.x.to(device)
    y_full = data.y.to(device)
    edge_index = data.edge_index.to(device)
    W = cfg.window_size
    H = cfg.horizon_steps

    model = STGNNWindowMultiRegressor(
        num_nodes=x_full.shape[1],
        embed_dim=EMBED_DIM,
        hidden_dim=HIDDEN_DIM,
        horizon_steps=H,
        in_dim=x_full.shape[2]
    ).to(device)

    opt = AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    loss_fn = nn.SmoothL1Loss()

    loss_nodes = torch.tensor(data.valid_nodes, dtype=torch.long, device=device)

    best_state = None
    best_val = float("inf")
    wait = 0

    for epoch in range(cfg.epochs):
        model.train()
        epoch_loss = 0.0
        count = 0

        for end_t in data.train_end_idx:
            start_t = end_t - W + 1

            x_win = x_full[start_t:end_t+1]
            y_target = y_full[end_t]

            pred = model(x_win, edge_index)

            y_nodes = y_target[loss_nodes]
            p_nodes = pred[loss_nodes]

            mask = torch.isfinite(y_nodes)
            if mask.sum() == 0:
                continue

            loss = loss_fn(p_nodes[mask], y_nodes[mask])

            opt.zero_grad()
            loss.backward()
            opt.step()

            epoch_loss += loss.item()
            count += 1

        avg_train_loss = epoch_loss / max(count, 1)

        val_metrics = _compute_window_metrics_multi(model, data, data.val_end_idx, device, loss_nodes)
        val_mae = val_metrics["MAE"]

        print(f"Epoch {epoch+1}: train_loss={avg_train_loss:.4f}, val_MAE={val_mae:.4f}")

        if val_mae < best_val:
            best_val = val_mae
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1
            if wait >= cfg.patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    test_metrics = _compute_window_metrics_multi(model, data, data.test_end_idx, device, loss_nodes)
    val_metrics = _compute_window_metrics_multi(model, data, data.val_end_idx, device, loss_nodes)

    return {
        "model": model,
        "metrics": {
            "val": val_metrics,
            "test": test_metrics,
        },
    }


# =============================
# f) High-level API
# =============================

def train_stgnn_window_from_df(
    df: pd.DataFrame,
    *,
    seed: int = SEED,
    window_size: int = WINDOW_SIZE,
    horizon_steps: int = HORIZON_STEPS
) -> dict:
    """Seed, prepare data and train the windowed STGNN in one call.

    Args:
        df: Long-format input frame passed to the data preparation step.
        seed: Seed applied to all random number generators before anything else.
        window_size: Number of past steps per input window.
        horizon_steps: Number of future steps predicted per window.

    Returns:
        Dict with the trained ``"model"``, its ``"metrics"``, the prepared
        ``"data"`` object and the ``"window_size"`` / ``"horizon_steps"`` used.
    """
    set_seed(seed)

    # The same two settings drive both data preparation and training config.
    shape_kwargs = {"window_size": window_size, "horizon_steps": horizon_steps}

    data = make_window_data_multi_from_df(df, **shape_kwargs)
    trained = train_stgnn_window_multi(data, STTrainConfig(**shape_kwargs))

    return {
        "model": trained["model"],
        "metrics": trained["metrics"],
        "data": data,
        **shape_kwargs,
    }


In [32]:
# Example run: 16-step input window, 4-step forecast horizon
res = train_stgnn_window_from_df(df, window_size=16, horizon_steps=4)

# Report the metrics of both held-out splits
for label, split in (("Val", "val"), ("Test", "test")):
    print(f"{label}:", res["metrics"][split])

Epoch 1: train_loss=0.2343, val_MAE=0.0663
Epoch 2: train_loss=0.0034, val_MAE=0.0422
Epoch 3: train_loss=0.0027, val_MAE=0.0375
Epoch 4: train_loss=0.0026, val_MAE=0.0347
Epoch 5: train_loss=0.0025, val_MAE=0.0340
Epoch 6: train_loss=0.0024, val_MAE=0.0337
Epoch 7: train_loss=0.0023, val_MAE=0.0328
Epoch 8: train_loss=0.0022, val_MAE=0.0314
Epoch 9: train_loss=0.0021, val_MAE=0.0299
Epoch 10: train_loss=0.0020, val_MAE=0.0286
Epoch 11: train_loss=0.0019, val_MAE=0.0275
Epoch 12: train_loss=0.0018, val_MAE=0.0266
Epoch 13: train_loss=0.0017, val_MAE=0.0258
Epoch 14: train_loss=0.0016, val_MAE=0.0251
Epoch 15: train_loss=0.0015, val_MAE=0.0243
Epoch 16: train_loss=0.0014, val_MAE=0.0235
Epoch 17: train_loss=0.0013, val_MAE=0.0226
Epoch 18: train_loss=0.0013, val_MAE=0.0217
Epoch 19: train_loss=0.0012, val_MAE=0.0208
Epoch 20: train_loss=0.0011, val_MAE=0.0200
Epoch 21: train_loss=0.0011, val_MAE=0.0192
Epoch 22: train_loss=0.0010, val_MAE=0.0186
Epoch 23: train_loss=0.0010, val_MAE=0.01