In [1]:
import os
# Expose only GPU 0 to this process. This is set in the first cell, before torch
# is imported, so the restriction is in place before CUDA is initialised.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [2]:
import sys
from pathlib import Path
# Make the directory two levels above the working directory importable, so the
# local modules used below (dataset, utils, models, ...) resolve. The membership
# check keeps re-runs of this cell from appending duplicate sys.path entries.
_PROJECT_ROOT = Path(os.getcwd()).parent.parent.as_posix()
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

In [3]:
import dgl
import json

import optuna
import pickle

import torch
import torch.nn as nn

from functools import partial
from itertools import product
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from dataset import get_datasets, ETTDataset

from utils import seed_everything
from models.gcn import GCNModel

from constructor import construct_ess, construct_vanilla, construct_complete
from graph_features import spectral_features, deepwalk_features

from train import train_step, evaluation_step

import warnings
warnings.simplefilter("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Run on the GPU when one is visible to this process, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Project helper from utils; presumably seeds the RNGs used below — see utils.seed_everything.
seed_everything()

In [5]:
class DatasetAdapter(Dataset):
    """Pre-build one DGL graph per sample of an ``ETTDataset``.

    Each sample ``(x_data, time_data, y_data)`` becomes a graph produced by
    ``graph_construction_fn(x_data)``. Its node features ``ndata["h"]`` are the
    concatenation along dim 1 of ``x_data.T``, then ``graph_features_fn(graph)``
    (only when a features function is given), then ``time_data`` repeated once
    per node. ``y_data`` is stored as the target paired with that graph.
    """

    def __init__(self, dataset: ETTDataset, graph_construction_fn, graph_features_fn=None):
        super().__init__()
        self.graphs: list[dgl.DGLGraph] = []
        self.targets: list[torch.Tensor] = []
        for sample_idx in tqdm(range(len(dataset)), desc="Building graphs"):
            window, time_feats, target = dataset[sample_idx]
            g = graph_construction_fn(window)

            # Base per-node features: the transposed sample, optionally widened
            # with features computed from the freshly built graph.
            g.ndata["h"] = (
                torch.cat([window.T, graph_features_fn(g)], dim=1)
                if graph_features_fn
                else window.T
            )
            # Append the sample's time features to every node's row.
            n_nodes = g.number_of_nodes()
            g.ndata["h"] = torch.cat([g.ndata["h"], time_feats.repeat(n_nodes, 1)], dim=1)

            self.targets.append(target)
            self.graphs.append(g)

    def __len__(self):
        """Number of pre-built graphs."""
        return len(self.graphs)

    def __getitem__(self, idx) -> tuple[dgl.DGLGraph, torch.Tensor]:
        """Return the ``(graph, target)`` pair stored at position ``idx``."""
        graph = self.graphs[idx]
        return graph, self.targets[idx]

In [6]:
def graph_collate_fn(batch):
    """
    Custom collate function for batching DGL graphs.

    :param batch: iterable of ``(graph, target)`` pairs, e.g. as returned by
        ``DatasetAdapter.__getitem__``
    :returns: ``(batched_graph, targets)``, where ``batched_graph`` is
        ``dgl.batch`` over the graphs in order and ``targets`` is the targets
        stacked along a new leading dimension (``torch.stack(..., dim=0)``)
    :raises ValueError: if ``batch`` contains no samples
    """
    samples = list(batch)
    if not samples:
        # zip(*[]) would otherwise fail with an opaque unpacking error.
        raise ValueError("graph_collate_fn received an empty batch")
    graphs, targets = zip(*samples)
    targets_tensor = torch.stack(targets, dim=0)
    return dgl.batch(graphs), targets_tensor

In [7]:
class GraphTSModel(nn.Module):
    """GCN backbone followed by a linear forecasting head on one node per graph.

    ``GCNModel`` maps the node features of a (batched) graph to per-node
    embeddings. The embedding of node ``target_node_idx`` in every graph is then
    projected by ``nn.Linear(hidden_dim, horizon_size)``.

    Target rows are selected by the strided slice
    ``[target_node_idx::num_nodes]``. This assumes every graph in the batch has
    exactly ``num_nodes`` nodes, stored contiguously and in order (as
    ``dgl.batch`` does). The defaults (7 nodes, target index 6) reproduce the
    original hard-coded ``[6::7]`` selection of the "OT" node.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        horizon_size: int,
        activation_fn: nn.Module,
        dropout: float = 0,
        num_nodes: int = 7,
        target_node_idx: int = 6,
    ) -> None:
        """
        :param input_dim: node feature width, forwarded to ``GCNModel``
        :param hidden_dim: backbone embedding width, also the head's input width
        :param num_layers: forwarded to ``GCNModel``
        :param horizon_size: number of outputs the head produces per graph
        :param activation_fn: forwarded unchanged to ``GCNModel`` (this notebook
            passes the class ``nn.ReLU``, not an instance)
        :param dropout: forwarded unchanged to ``GCNModel``
        :param num_nodes: number of nodes in every graph of a batch
        :param target_node_idx: position, within each graph, of the node whose
            embedding is fed to the head
        :raises ValueError: if ``num_nodes`` is not positive or
            ``target_node_idx`` is outside ``[0, num_nodes)``
        """
        if num_nodes <= 0:
            raise ValueError(f"num_nodes must be positive, got {num_nodes}")
        if not 0 <= target_node_idx < num_nodes:
            raise ValueError(
                f"target_node_idx must be in [0, {num_nodes}), got {target_node_idx}"
            )
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.activation_fn = activation_fn
        self.dropout = dropout
        self.horizon_size = horizon_size
        self.num_nodes = num_nodes
        self.target_node_idx = target_node_idx

        self.backbone = GCNModel(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            num_layers=self.num_layers,
            activation_fn=self.activation_fn,
            dropout=self.dropout
        )

        self.head = nn.Linear(self.hidden_dim, self.horizon_size)

    def forward(self, graph, features):
        """Forecast ``horizon_size`` values for the target node of each graph.

        :param graph: (batched) graph passed to the backbone
        :param features: node features passed to the backbone
        :returns: head output for the target node rows; one row per graph,
            assuming the backbone returns one embedding row per node
        :raises ValueError: if the number of node embeddings is not a multiple of
            ``num_nodes``, in which case strided slicing would pick wrong rows
        """
        outputs = self.backbone(graph, features)
        if outputs.shape[0] % self.num_nodes != 0:
            raise ValueError(
                f"backbone returned {outputs.shape[0]} node rows, which is not a "
                f"multiple of num_nodes={self.num_nodes}"
            )
        # Every graph contributes num_nodes consecutive rows; take its target one.
        tgt_emb = outputs[self.target_node_idx::self.num_nodes]
        return self.head(tgt_emb)

In [8]:
# Experiment grid: a fixed lookback window, every forecast horizon, and each
# graph-construction strategy as a (callable, short name used in result keys) pair.
LOOKBACK_SIZE = 96
HORIZON_SIZES = [24, 48, 96, 168, 192, 336, 720]
ACTIVATION_FN = nn.ReLU  # the class itself; GraphTSModel forwards it to GCNModel
GRAPH_FEATURES_FN = partial(spectral_features, embed_size=7)
GRAPH_CONSTRUCTION_FNS = [
    (construct_complete, "complete"),
    (partial(construct_ess, alpha=0.05), "ess"),
    (partial(construct_vanilla, alpha=0.05), "vanilla"),
]

# One (horizon, (construction_fn, name)) entry per run, horizon-major order.
SETUPS = [
    (horizon, construction)
    for horizon in HORIZON_SIZES
    for construction in GRAPH_CONSTRUCTION_FNS
]

# ETTh1

In [9]:
# Optuna studies from the hyper-parameter search used for ETTh1, keyed
# "<horizon>_<graph name>". NOTE: unpickling can execute arbitrary code —
# only load study files you produced yourself.
gcn_spectral_h1 = pickle.loads(Path("./spectral_gcn_study_results.pkl").read_bytes())

gcn_spectral_h1

{'24_complete': <optuna.study.study.Study at 0x7199355171d0>,
 '24_ess': <optuna.study.study.Study at 0x7198315afe50>,
 '24_vanilla': <optuna.study.study.Study at 0x7197fc172250>,
 '48_complete': <optuna.study.study.Study at 0x7197fb750cd0>,
 '48_ess': <optuna.study.study.Study at 0x7197fb7af790>,
 '48_vanilla': <optuna.study.study.Study at 0x7197fad09f50>,
 '168_complete': <optuna.study.study.Study at 0x7197fad6a150>,
 '168_ess': <optuna.study.study.Study at 0x7197fadc6550>,
 '168_vanilla': <optuna.study.study.Study at 0x7197fa32a950>,
 '336_complete': <optuna.study.study.Study at 0x7197fa382d50>,
 '336_ess': <optuna.study.study.Study at 0x7197fa3e7150>,
 '336_vanilla': <optuna.study.study.Study at 0x7197f9943550>,
 '720_complete': <optuna.study.study.Study at 0x7197f99a3990>,
 '720_ess': <optuna.study.study.Study at 0x7197f99ffd90>,
 '720_vanilla': <optuna.study.study.Study at 0x7197f8f741d0>}

In [10]:
# ETTh1 run: collects the final test (MSE, MAE) of every setup, keyed by setup.
ETTH1_RESULTS = dict()
DATASET_NAME = "ETTh1.csv"

In [11]:
# Train one GraphTSModel per (horizon, graph construction) setup on ETTh1 and
# store the final-epoch test (MSE, MAE) under "<horizon>_<graph>_spectral".
pbar = tqdm(SETUPS)
for HORIZON_SIZE, (GRAPH_CONSTRUCTION_FN, NAMING) in pbar:
    pbar.set_description(f"Training for horizon size {HORIZON_SIZE} and {NAMING} graph")
    train_ds, val_ds, test_ds = get_datasets(
        dataset_name=DATASET_NAME,
        lookback_size=LOOKBACK_SIZE,
        horizon_size=HORIZON_SIZE
    )

    # Build graph datasets for each split, in train -> val -> test order.
    train_adapter_ds, val_adapter_ds, test_adapter_ds = (
        DatasetAdapter(
            dataset=split_ds,
            graph_construction_fn=GRAPH_CONSTRUCTION_FN,
            graph_features_fn=GRAPH_FEATURES_FN
        )
        for split_ds in (train_ds, val_ds, test_ds)
    )

    # The loaded studies have no entries for horizons 96 and 192, so those runs
    # reuse the best hyper-parameters found for horizon 168.
    study_horizon = 168 if HORIZON_SIZE in (96, 192) else HORIZON_SIZE
    params = gcn_spectral_h1[f"{study_horizon}_{NAMING}"].best_params

    # Identical loader settings for every split; only training data is shuffled.
    train_loader, val_loader, test_loader = (
        DataLoader(
            dataset=adapter_ds,
            batch_size=params["batch_size"],
            shuffle=shuffle,
            num_workers=4,
            collate_fn=graph_collate_fn
        )
        for adapter_ds, shuffle in (
            (train_adapter_ds, True),
            (val_adapter_ds, False),
            (test_adapter_ds, False),
        )
    )

    # Width of the node features built by DatasetAdapter.
    INPUT_DIM = train_adapter_ds[0][0].ndata["h"].shape[1]

    model = GraphTSModel(
        input_dim=INPUT_DIM,
        hidden_dim=params["hidden_dim"],
        num_layers=params["num_layers"],
        horizon_size=HORIZON_SIZE,
        activation_fn=ACTIVATION_FN,
        dropout=params["dropout"]
    ).to(device)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=params["lr"], weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.33, patience=2)

    pbar2 = tqdm(range(10), desc="Training")
    for _ in pbar2:
        # train_step's return value is not used by this loop.
        train_step(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            loss_fn=loss_fn,
            device=device
        )
        val_loss = evaluation_step(model=model, loader=val_loader, device=device)
        # Test metrics are shown every epoch for monitoring only; the value
        # recorded is the one after the last epoch (no selection on test data).
        test_loss = evaluation_step(model=model, loader=test_loader, device=device)

        pbar2.set_postfix_str(
            f"[valid] mse = {val_loss['mse']:.4f} "
            f"[valid] mae = {val_loss['mae']:.4f} "
            f"[test]  mse = {test_loss['mse']:.4f} "
            f"[test]  mae = {test_loss['mae']:.4f}"
        )
        # Decay the learning rate when validation MSE stops improving.
        scheduler.step(val_loss["mse"])

    ETTH1_RESULTS[f"{HORIZON_SIZE}_{NAMING}_spectral"] = (test_loss["mse"], test_loss["mae"])

Building graphs: 100%|██████████| 8522/8522 [00:03<00:00, 2362.32it/s] [00:00<?, ?it/s]
Building graphs: 100%|██████████| 2762/2762 [00:01<00:00, 2597.18it/s]
Building graphs: 100%|██████████| 2762/2762 [00:01<00:00, 2242.15it/s]
Training: 100%|██████████| 10/10 [00:24<00:00,  2.50s/it, [valid] mse = 0.5977 [valid] mae = 0.3346 [test]  mse = 0.4622 [test]  mae = 0.3050]
Building graphs: 100%|██████████| 8522/8522 [01:11<00:00, 120.03it/s]:31<10:26, 31.35s/it]     
Building graphs: 100%|██████████| 2762/2762 [00:22<00:00, 121.07it/s]
Building graphs: 100%|██████████| 2762/2762 [00:22<00:00, 121.24it/s]
Training: 100%|██████████| 10/10 [00:22<00:00,  2.29s/it, [valid] mse = 0.7379 [valid] mae = 0.3824 [test]  mse = 0.4727 [test]  mae = 0.3019]
Building graphs: 100%|██████████| 8522/8522 [00:57<00:00, 147.59it/s] [02:51<30:06, 95.10s/it]
Building graphs: 100%|██████████| 2762/2762 [00:18<00:00, 147.60it/s]
Building graphs: 100%|██████████| 2762/2762 [00:18<00:00, 146.21it/s]
Training: 100

In [1]:
# model = model.cpu()

# ETTh2

In [13]:
# Optuna studies from the hyper-parameter search used for ETTh2, keyed
# "<horizon>_<graph name>". NOTE: unpickling can execute arbitrary code —
# only load study files you produced yourself.
gcn_spectral_h2 = pickle.loads(Path("./spectral_gcn_study_results_h2.pkl").read_bytes())

gcn_spectral_h2

{'24_complete': <optuna.study.study.Study at 0x7197fb718bd0>,
 '24_ess': <optuna.study.study.Study at 0x7198589062d0>,
 '24_vanilla': <optuna.study.study.Study at 0x719857b4a890>,
 '48_complete': <optuna.study.study.Study at 0x71984d6cb310>,
 '48_ess': <optuna.study.study.Study at 0x719857e5a890>,
 '48_vanilla': <optuna.study.study.Study at 0x71984d756210>,
 '168_complete': <optuna.study.study.Study at 0x719854e3d6d0>,
 '168_ess': <optuna.study.study.Study at 0x719854e38250>,
 '168_vanilla': <optuna.study.study.Study at 0x71984fbd0e10>,
 '336_complete': <optuna.study.study.Study at 0x71985254ab50>,
 '336_ess': <optuna.study.study.Study at 0x71984fd22ad0>,
 '336_vanilla': <optuna.study.study.Study at 0x7197ed539b90>,
 '720_complete': <optuna.study.study.Study at 0x719857ea61d0>,
 '720_ess': <optuna.study.study.Study at 0x71985526ea50>,
 '720_vanilla': <optuna.study.study.Study at 0x71985885cbd0>}

In [14]:
# ETTh2 run: collects the final test (MSE, MAE) of every setup, keyed by setup.
ETTH2_RESULTS = dict()
DATASET_NAME = "ETTh2.csv"

In [15]:
# Train one GraphTSModel per (horizon, graph construction) setup on ETTh2 and
# store the final-epoch test (MSE, MAE) under "<horizon>_<graph>_spectral".
pbar = tqdm(SETUPS)
for HORIZON_SIZE, (GRAPH_CONSTRUCTION_FN, NAMING) in pbar:
    pbar.set_description(f"Training for horizon size {HORIZON_SIZE} and {NAMING} graph")
    train_ds, val_ds, test_ds = get_datasets(
        dataset_name=DATASET_NAME,
        lookback_size=LOOKBACK_SIZE,
        horizon_size=HORIZON_SIZE
    )

    # Build graph datasets for each split, in train -> val -> test order.
    train_adapter_ds, val_adapter_ds, test_adapter_ds = (
        DatasetAdapter(
            dataset=split_ds,
            graph_construction_fn=GRAPH_CONSTRUCTION_FN,
            graph_features_fn=GRAPH_FEATURES_FN
        )
        for split_ds in (train_ds, val_ds, test_ds)
    )

    # The loaded studies have no entries for horizons 96 and 192, so those runs
    # reuse the best hyper-parameters found for horizon 168.
    study_horizon = 168 if HORIZON_SIZE in (96, 192) else HORIZON_SIZE
    params = gcn_spectral_h2[f"{study_horizon}_{NAMING}"].best_params

    # Identical loader settings for every split; only training data is shuffled.
    train_loader, val_loader, test_loader = (
        DataLoader(
            dataset=adapter_ds,
            batch_size=params["batch_size"],
            shuffle=shuffle,
            num_workers=4,
            collate_fn=graph_collate_fn
        )
        for adapter_ds, shuffle in (
            (train_adapter_ds, True),
            (val_adapter_ds, False),
            (test_adapter_ds, False),
        )
    )

    # Width of the node features built by DatasetAdapter.
    INPUT_DIM = train_adapter_ds[0][0].ndata["h"].shape[1]

    model = GraphTSModel(
        input_dim=INPUT_DIM,
        hidden_dim=params["hidden_dim"],
        num_layers=params["num_layers"],
        horizon_size=HORIZON_SIZE,
        activation_fn=ACTIVATION_FN,
        dropout=params["dropout"]
    ).to(device)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=params["lr"], weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.33, patience=2)

    pbar2 = tqdm(range(10), desc="Training")
    for _ in pbar2:
        # train_step's return value is not used by this loop.
        train_step(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            loss_fn=loss_fn,
            device=device
        )
        val_loss = evaluation_step(model=model, loader=val_loader, device=device)
        # Test metrics are shown every epoch for monitoring only; the value
        # recorded is the one after the last epoch (no selection on test data).
        test_loss = evaluation_step(model=model, loader=test_loader, device=device)

        pbar2.set_postfix_str(
            f"[valid] mse = {val_loss['mse']:.4f} "
            f"[valid] mae = {val_loss['mae']:.4f} "
            f"[test]  mse = {test_loss['mse']:.4f} "
            f"[test]  mae = {test_loss['mae']:.4f}"
        )
        # Decay the learning rate when validation MSE stops improving.
        scheduler.step(val_loss["mse"])

    ETTH2_RESULTS[f"{HORIZON_SIZE}_{NAMING}_spectral"] = (test_loss["mse"], test_loss["mae"])

Building graphs: 100%|██████████| 8522/8522 [00:04<00:00, 2063.83it/s] [00:00<?, ?it/s]
Building graphs: 100%|██████████| 2762/2762 [00:01<00:00, 2395.73it/s]
Building graphs: 100%|██████████| 2762/2762 [00:01<00:00, 2374.12it/s]
Training: 100%|██████████| 10/10 [00:26<00:00,  2.60s/it, [valid] mse = 1.3666 [valid] mae = 0.4957 [test]  mse = 1.6066 [test]  mae = 0.5757]
Building graphs: 100%|██████████| 8522/8522 [01:12<00:00, 116.74it/s]:32<10:53, 32.69s/it]     
Building graphs: 100%|██████████| 2762/2762 [00:23<00:00, 116.53it/s]
Building graphs: 100%|██████████| 2762/2762 [00:23<00:00, 118.61it/s]
Training: 100%|██████████| 10/10 [00:25<00:00,  2.51s/it, [valid] mse = 1.3172 [valid] mae = 0.4848 [test]  mse = 1.0778 [test]  mae = 0.4546]
Building graphs: 100%|██████████| 8522/8522 [00:58<00:00, 144.78it/s] [02:57<31:19, 98.92s/it]
Building graphs: 100%|██████████| 2762/2762 [00:18<00:00, 148.17it/s]
Building graphs: 100%|██████████| 2762/2762 [00:18<00:00, 148.94it/s]
Training: 100

In [16]:
# model = model.cpu()

{'24_complete_spectral': (1.6066245036455509, 0.5756825710398163),
 '24_ess_spectral': (1.077795967950989, 0.4545523525239124),
 '24_vanilla_spectral': (1.154165656964272, 0.47754708999648193),
 '48_complete_spectral': (2.3193355169266074, 0.6980898265812028),
 '48_ess_spectral': (1.444478478256385, 0.5338329618574602),
 '48_vanilla_spectral': (1.6461360525498585, 0.5752044366323623),
 '96_complete_spectral': (3.3559493988301816, 0.8482037454759589),
 '96_ess_spectral': (1.961970027654144, 0.6379904564482223),
 '96_vanilla_spectral': (2.0873398326320722, 0.652703737140951),
 '168_complete_spectral': (3.9593959956141553, 0.9119381445598642),
 '168_ess_spectral': (2.692933164920798, 0.7544599229315954),
 '168_vanilla_spectral': (3.1117744318963383, 0.8024391082298102),
 '192_complete_spectral': (4.697204345530578, 1.007387195434841),
 '192_ess_spectral': (2.3886356342975485, 0.7151937426140633),
 '192_vanilla_spectral': (2.83669327829283, 0.7626932452586801),
 '336_complete_spectral': (4

# ETTm1

In [17]:
# Optuna studies keyed "<horizon>_<graph name>" for the ETTm1 runs.
# NOTE(review): this reads the same file as the ETTh1 cell, so ETTm1 reuses the
# ETTh1 hyper-parameter search — confirm that is intended.
# NOTE: unpickling can execute arbitrary code — only load files you produced.
gcn_spectral_m1 = pickle.loads(Path("./spectral_gcn_study_results.pkl").read_bytes())

gcn_spectral_m1

{'24_complete': <optuna.study.study.Study at 0x71985289c950>,
 '24_ess': <optuna.study.study.Study at 0x719854f7ded0>,
 '24_vanilla': <optuna.study.study.Study at 0x719855439190>,
 '48_complete': <optuna.study.study.Study at 0x7198589d4650>,
 '48_ess': <optuna.study.study.Study at 0x7197ed51c310>,
 '48_vanilla': <optuna.study.study.Study at 0x7197f3f1ea90>,
 '168_complete': <optuna.study.study.Study at 0x7197f255b7d0>,
 '168_ess': <optuna.study.study.Study at 0x719852719250>,
 '168_vanilla': <optuna.study.study.Study at 0x71984d7b3b50>,
 '336_complete': <optuna.study.study.Study at 0x7197eeaaa850>,
 '336_ess': <optuna.study.study.Study at 0x71985508ea10>,
 '336_vanilla': <optuna.study.study.Study at 0x7197f159cfd0>,
 '720_complete': <optuna.study.study.Study at 0x7197e3f74510>,
 '720_ess': <optuna.study.study.Study at 0x7197e1d4c690>,
 '720_vanilla': <optuna.study.study.Study at 0x7197f1fad150>}

In [18]:
# ETTm1 run: collects the final test (MSE, MAE) of every setup, keyed by setup.
ETTM1_RESULTS = dict()
DATASET_NAME = "ETTm1.csv"

In [None]:
# Train one GraphTSModel per (horizon, graph construction) setup on ETTm1 and
# store the final-epoch test (MSE, MAE) under "<horizon>_<graph>_spectral".
pbar = tqdm(SETUPS)
for HORIZON_SIZE, (GRAPH_CONSTRUCTION_FN, NAMING) in pbar:
    pbar.set_description(f"Training for horizon size {HORIZON_SIZE} and {NAMING} graph")
    train_ds, val_ds, test_ds = get_datasets(
        dataset_name=DATASET_NAME,
        lookback_size=LOOKBACK_SIZE,
        horizon_size=HORIZON_SIZE
    )

    # Build graph datasets for each split, in train -> val -> test order.
    train_adapter_ds, val_adapter_ds, test_adapter_ds = (
        DatasetAdapter(
            dataset=split_ds,
            graph_construction_fn=GRAPH_CONSTRUCTION_FN,
            graph_features_fn=GRAPH_FEATURES_FN
        )
        for split_ds in (train_ds, val_ds, test_ds)
    )

    # Hyper-parameters come from gcn_spectral_m1, the study loaded for ETTm1.
    # The loaded studies have no entries for horizons 96 and 192, so those runs
    # reuse the best hyper-parameters found for horizon 168.
    study_horizon = 168 if HORIZON_SIZE in (96, 192) else HORIZON_SIZE
    params = gcn_spectral_m1[f"{study_horizon}_{NAMING}"].best_params

    # Identical loader settings for every split; only training data is shuffled.
    train_loader, val_loader, test_loader = (
        DataLoader(
            dataset=adapter_ds,
            batch_size=params["batch_size"],
            shuffle=shuffle,
            num_workers=4,
            collate_fn=graph_collate_fn
        )
        for adapter_ds, shuffle in (
            (train_adapter_ds, True),
            (val_adapter_ds, False),
            (test_adapter_ds, False),
        )
    )

    # Width of the node features built by DatasetAdapter.
    INPUT_DIM = train_adapter_ds[0][0].ndata["h"].shape[1]

    model = GraphTSModel(
        input_dim=INPUT_DIM,
        hidden_dim=params["hidden_dim"],
        num_layers=params["num_layers"],
        horizon_size=HORIZON_SIZE,
        activation_fn=ACTIVATION_FN,
        dropout=params["dropout"]
    ).to(device)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=params["lr"], weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.33, patience=2)

    pbar2 = tqdm(range(10), desc="Training")
    for _ in pbar2:
        # train_step's return value is not used by this loop.
        train_step(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            loss_fn=loss_fn,
            device=device
        )
        val_loss = evaluation_step(model=model, loader=val_loader, device=device)
        # Test metrics are shown every epoch for monitoring only; the value
        # recorded is the one after the last epoch (no selection on test data).
        test_loss = evaluation_step(model=model, loader=test_loader, device=device)

        pbar2.set_postfix_str(
            f"[valid] mse = {val_loss['mse']:.4f} "
            f"[valid] mae = {val_loss['mae']:.4f} "
            f"[test]  mse = {test_loss['mse']:.4f} "
            f"[test]  mae = {test_loss['mae']:.4f}"
        )
        # Decay the learning rate when validation MSE stops improving.
        scheduler.step(val_loss["mse"])

    ETTM1_RESULTS[f"{HORIZON_SIZE}_{NAMING}_spectral"] = (test_loss["mse"], test_loss["mae"])

Building graphs: 100%|██████████| 34442/34442 [00:16<00:00, 2151.69it/s]00:00<?, ?it/s]
Building graphs: 100%|██████████| 11402/11402 [00:05<00:00, 2155.61it/s]
Building graphs: 100%|██████████| 11402/11402 [00:04<00:00, 2410.40it/s]
Training: 100%|██████████| 10/10 [01:29<00:00,  8.99s/it, [valid] mse = 0.2944 [valid] mae = 0.2272 [test]  mse = 0.2095 [test]  mae = 0.2002]
Building graphs: 100%|██████████| 34442/34442 [04:55<00:00, 116.43it/s]6<38:46, 116.32s/it]     
Building graphs: 100%|██████████| 11402/11402 [01:37<00:00, 116.90it/s]
Building graphs: 100%|██████████| 11402/11402 [01:36<00:00, 117.55it/s]
Training: 100%|██████████| 10/10 [01:32<00:00,  9.23s/it, [valid] mse = 0.2959 [valid] mae = 0.2273 [test]  mse = 0.1873 [test]  mae = 0.1925]
Building graphs: 100%|██████████| 34442/34442 [04:00<00:00, 143.23it/s]11:39<2:03:50, 391.10s/it]
Building graphs: 100%|██████████| 11402/11402 [01:17<00:00, 147.69it/s]
Building graphs: 100%|██████████| 11402/11402 [01:19<00:00, 143.78it/

In [None]:
# Final-epoch test (MSE, MAE) for every ETTm1 setup, keyed "<horizon>_<graph>_spectral".
ETTM1_RESULTS

# ETTm2

In [18]:
# Optuna studies keyed "<horizon>_<graph name>" for the ETTm2 runs.
# NOTE(review): this reads the same file as the ETTh2 cell, so ETTm2 reuses the
# ETTh2 hyper-parameter search — confirm that is intended.
# NOTE: unpickling can execute arbitrary code — only load files you produced.
gcn_spectral_m2 = pickle.loads(Path("./spectral_gcn_study_results_h2.pkl").read_bytes())

gcn_spectral_m2

{'24_complete': <optuna.study.study.Study at 0x707738d79090>,
 '24_ess': <optuna.study.study.Study at 0x7076d27295d0>,
 '24_vanilla': <optuna.study.study.Study at 0x707738559210>,
 '48_complete': <optuna.study.study.Study at 0x70773595e0d0>,
 '48_ess': <optuna.study.study.Study at 0x70773842d8d0>,
 '48_vanilla': <optuna.study.study.Study at 0x7077359a6c10>,
 '168_complete': <optuna.study.study.Study at 0x7076d9306ed0>,
 '168_ess': <optuna.study.study.Study at 0x70772dfdea10>,
 '168_vanilla': <optuna.study.study.Study at 0x70773315cf90>,
 '336_complete': <optuna.study.study.Study at 0x7076d0103050>,
 '336_ess': <optuna.study.study.Study at 0x7076d4362550>,
 '336_vanilla': <optuna.study.study.Study at 0x70772df53fd0>,
 '720_complete': <optuna.study.study.Study at 0x7076d3353a90>,
 '720_ess': <optuna.study.study.Study at 0x7077356403d0>,
 '720_vanilla': <optuna.study.study.Study at 0x707738260050>}

In [None]:
# ETTm2 run: collects the final test (MSE, MAE) of every setup, keyed by setup.
ETTM2_RESULTS = dict()
DATASET_NAME = "ETTm2.csv"

In [None]:
# Train one GraphTSModel per (horizon, graph construction) setup on ETTm2 and
# store the final-epoch test (MSE, MAE) under "<horizon>_<graph>_spectral".
pbar = tqdm(SETUPS)
for HORIZON_SIZE, (GRAPH_CONSTRUCTION_FN, NAMING) in pbar:
    pbar.set_description(f"Training for horizon size {HORIZON_SIZE} and {NAMING} graph")
    train_ds, val_ds, test_ds = get_datasets(
        dataset_name=DATASET_NAME,
        lookback_size=LOOKBACK_SIZE,
        horizon_size=HORIZON_SIZE
    )

    # Build graph datasets for each split, in train -> val -> test order.
    train_adapter_ds, val_adapter_ds, test_adapter_ds = (
        DatasetAdapter(
            dataset=split_ds,
            graph_construction_fn=GRAPH_CONSTRUCTION_FN,
            graph_features_fn=GRAPH_FEATURES_FN
        )
        for split_ds in (train_ds, val_ds, test_ds)
    )

    # The loaded studies have no entries for horizons 96 and 192, so those runs
    # reuse the best hyper-parameters found for horizon 168.
    study_horizon = 168 if HORIZON_SIZE in (96, 192) else HORIZON_SIZE
    params = gcn_spectral_m2[f"{study_horizon}_{NAMING}"].best_params

    # Identical loader settings for every split; only training data is shuffled.
    train_loader, val_loader, test_loader = (
        DataLoader(
            dataset=adapter_ds,
            batch_size=params["batch_size"],
            shuffle=shuffle,
            num_workers=4,
            collate_fn=graph_collate_fn
        )
        for adapter_ds, shuffle in (
            (train_adapter_ds, True),
            (val_adapter_ds, False),
            (test_adapter_ds, False),
        )
    )

    # Width of the node features built by DatasetAdapter.
    INPUT_DIM = train_adapter_ds[0][0].ndata["h"].shape[1]

    model = GraphTSModel(
        input_dim=INPUT_DIM,
        hidden_dim=params["hidden_dim"],
        num_layers=params["num_layers"],
        horizon_size=HORIZON_SIZE,
        activation_fn=ACTIVATION_FN,
        dropout=params["dropout"]
    ).to(device)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=params["lr"], weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.33, patience=2)

    pbar2 = tqdm(range(10), desc="Training")
    for _ in pbar2:
        # train_step's return value is not used by this loop.
        train_step(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            loss_fn=loss_fn,
            device=device
        )
        val_loss = evaluation_step(model=model, loader=val_loader, device=device)
        # Test metrics are shown every epoch for monitoring only; the value
        # recorded is the one after the last epoch (no selection on test data).
        test_loss = evaluation_step(model=model, loader=test_loader, device=device)

        pbar2.set_postfix_str(
            f"[valid] mse = {val_loss['mse']:.4f} "
            f"[valid] mae = {val_loss['mae']:.4f} "
            f"[test]  mse = {test_loss['mse']:.4f} "
            f"[test]  mae = {test_loss['mae']:.4f}"
        )
        # Decay the learning rate when validation MSE stops improving.
        scheduler.step(val_loss["mse"])

    ETTM2_RESULTS[f"{HORIZON_SIZE}_{NAMING}_spectral"] = (test_loss["mse"], test_loss["mae"])

In [None]:
# Final-epoch test (MSE, MAE) for every ETTm2 setup, keyed "<horizon>_<graph>_spectral".
ETTM2_RESULTS

# Results

In [None]:
# Hard-coded snapshot of ETTh1 test metrics, (MSE, MAE) per
# "<horizon>_<graph>_spectral" setup.
# NOTE(review): this reassignment replaces whatever the ETTh1 training loop
# stored in ETTH1_RESULTS during the current session.
ETTH1_RESULTS = {
    "24_complete_spectral": (0.4622007438591531, 0.30501101441959144),
    "24_ess_spectral": (0.4727049479759634, 0.288319775855892),
    "24_vanilla_spectral": (0.4270755704750394, 0.30187430578586777),
    "48_complete_spectral": (0.759679712203114, 0.38969624580545276),
    "48_ess_spectral": (0.7753516732906234, 0.387353978282752),
    "48_vanilla_spectral": (0.7614292181928515, 0.3865646487949807),
    "96_complete_spectral": (0.7555281199250821, 0.3834116530603718),
    "96_ess_spectral": (1.1851305673970316, 0.40980954498431965),
    "96_vanilla_spectral": (0.8634562760628437, 0.48936872681049093),
    "168_complete_spectral": (0.9021050270250146, 0.4157375992348772),
    "168_ess_spectral": (0.834167979119915, 0.402345235634612),
    "168_vanilla_spectral": (0.8442587385406931, 0.40681322386101915),
    "192_complete_spectral": (0.8630211608914167, 0.45783084713681016),
    "192_ess_spectral": (0.8199182786159665, 0.4342503595720934),
    "192_vanilla_spectral": (0.8208010355827588, 0.43982778328383786),
    "336_complete_spectral": (0.8071296438915969, 0.3970712783642271),
    "336_ess_spectral": (1.426521624121189, 0.5495055765178436),
    "336_vanilla_spectral": (0.8546986619005856, 0.4184643910577625),
    "720_complete_spectral": (1.0712664136448589, 0.4734361688490667),
    "720_ess_spectral": (1.6110754871152995, 0.6035877630675096),
    "720_vanilla_spectral": (0.9605864774087068, 0.44913590352031946),
}