# SF graph learning + Graph VAR dynamics

In [1]:
import os, pathlib
# Run from the repository checkout under the user's home directory, so that the
# relative paths used in this notebook ("data/...", "logs/...") resolve against it.
# NOTE(review): presumably this is also what makes the `lib.*` imports importable — confirm.
ROOT = pathlib.Path.home().joinpath("sparse-graph-learning-use", "sparse-graph-learning-main")
os.chdir(ROOT)

import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor

from tsl.experiment import Experiment
from tsl.data import SpatioTemporalDataset, SpatioTemporalDataModule
from tsl.data.preprocessing import StandardScaler
from tsl.datasets import AirQuality
from tsl.metrics.torch import MaskedMAE, MaskedMAPE, MaskedMSE

from lib.nn.graph_module import GraphModule
from lib.predictors.latent_graph_predictor import SFGraphPredictor
from lib.datasets.graph_polynomial_var import GraphPolyVARFilter  
import tsl
from omegaconf import OmegaConf

## Data loader

In [2]:
def get_dataset(dataset_name):
    """Build the dataset object registered under ``dataset_name``.

    Only ``"aqi"`` is recognised. For it, an ``AirQuality`` dataset rooted at
    ``data/airquality`` (relative to the current working directory) is created
    with ``small=True``, loaded once so the array shapes can be logged, and
    then returned.

    Args:
        dataset_name: Identifier of the dataset to build.

    Returns:
        The ``AirQuality`` dataset object (not the loaded arrays).

    Raises:
        ValueError: If ``dataset_name`` is anything other than ``"aqi"``.
    """
    # Reject unknown names up front so the happy path reads without nesting.
    if dataset_name != "aqi":
        raise ValueError(f"Dataset {dataset_name} not supported.")

    # small=True: the 36-station variant, per the original inline note.
    dataset = AirQuality(root=os.path.join("data", "airquality"), small=True)

    # load() is called here only to report shapes; the arrays are discarded and
    # callers receive the dataset object itself.
    data, mask, _, dist = dataset.load()
    tsl.logger.info(
        f"Loaded AirQuality: data={tuple(data.shape)}, mask={tuple(mask.shape)}, dist={tuple(dist.shape)}"
    )
    return dataset

## Config

In [3]:
# Base experiment configuration. The trailing comments say where run_experiment()
# consumes each value; keys without a comment are passed through under the same name.
cfg = OmegaConf.create(dict(
    dataset=dict(
        name="aqi",                              # looked up by get_dataset()
        splits=dict(val_len=0.1, test_len=0.2),  # forwarded to dataset.get_splitter(...)
    ),
    window=24,             # look-back length; also used as the model's temporal_order
    horizon=1,             # GraphPolyVARFilter is used with H=1
    stride=1,
    batch_size=128,
    epochs=30,             # Trainer max_epochs
    patience=60,           # EarlyStopping patience; larger than `epochs`, so early
                           # stopping presumably never fires in a 30-epoch run
    batches_epoch=-1,      # negative -> limit_train_batches=1.0 (all batches)
    graph_mode="sf",       # GraphModule `mode`
    mc_samples=1,
    lam=None,              # passed as surrogate_lam
    sf_weight=1.0,
    use_baseline=True,
    variance_reduced=True,
    workers=0,             # data-module workers
    num_threads=1,         # torch.set_num_threads
    run=dict(dir="logs/aqi_sf_bes_graphvar"),   # checkpoint / trainer root directory
    clip_grad=False,       # True -> clip gradients by value at 0.5
    sampler=dict(
        temperature=0.5,   # GraphModule `tau`
        symmetric=True,    # NOTE(review): not read by run_experiment() as written
        hard=True,         # NOTE(review): not read by run_experiment() as written
    ),
    optim=dict(lr=1e-3),   # Adam learning rate
))

## Experiment Function

In [4]:
def run_experiment(cfg):
    """Train an SF graph-learning predictor with Graph-VAR dynamics and test it.

    The function loads the dataset named in the config, wraps it in a TSL
    dataset / data module, builds an ``SFGraphPredictor`` around
    ``GraphPolyVARFilter``, fits it with PyTorch Lightning, restores the
    checkpoint with the lowest ``val_mae`` and evaluates it on the test split.

    Args:
        cfg: OmegaConf config. Required keys: ``num_threads``, ``dataset.name``,
            ``dataset.splits``, ``window``, ``horizon``, ``stride``,
            ``batch_size``, ``workers``, ``graph_mode``, ``sampler.temperature``,
            ``optim.lr``, ``sf_weight``, ``use_baseline``, ``mc_samples``,
            ``variance_reduced``, ``lam``, ``run.dir``, ``patience``,
            ``clip_grad``, ``epochs``, ``batches_epoch``.
            Optional keys: ``seed`` (int; when present all RNGs are seeded with
            ``pl.seed_everything``) and ``spatial_order`` (int; defaults to 3).

    Returns:
        dict with ``"best_ckpt"`` (path of the best checkpoint) and ``"test"``
        (the list returned by ``trainer.test``).

    Raises:
        ValueError: From ``get_dataset`` when the dataset name is unsupported.
        FileNotFoundError: If fitting finished without any checkpoint being saved.
    """
    torch.set_num_threads(int(cfg.num_threads))

    # Opt-in reproducibility: nothing changes unless the config carries a seed.
    seed = cfg.get("seed")
    if seed is not None:
        pl.seed_everything(int(seed), workers=True)

    # Polynomial order over the graph; 3 unless the config overrides it.
    spatial_order = cfg.get("spatial_order")
    spatial_order = 3 if spatial_order is None else int(spatial_order)

    dataset = get_dataset(cfg.dataset.name)
    # The third value returned by load() (the eval mask) is not needed here.
    data, mask, _, dist = dataset.load()

    data = np.asarray(data, dtype=np.float32)
    mask = np.asarray(mask, dtype=bool)
    # Exponential kernel on the pairwise distances, scaled by their standard deviation.
    adj  = np.exp(-dist / dist.std()).astype(np.float32)
    time_index = dataset.index

    # TSL Dataset / DataModule
    torch_dataset = SpatioTemporalDataset(
        data,
        index=time_index,
        mask=mask,
        connectivity=adj,
        horizon=cfg.horizon,
        window=cfg.window,
        stride=cfg.stride,
    )

    dm = SpatioTemporalDataModule(
        dataset=torch_dataset,
        scalers={'target': StandardScaler(axis=(0, 1))},
        splitter=dataset.get_splitter(**cfg.dataset.splits),
        batch_size=cfg.batch_size,
        workers=cfg.workers,
    )
    dm.setup()

    # GraphModule kwargs
    gm_kwargs = dict(
        n_nodes=torch_dataset.n_nodes,
        mode=cfg.graph_mode,
        sampler='bes',
        tau=float(cfg.sampler.temperature),
        dummy_nodes=0,
    )

    # Loss / metrics
    loss_fn = MaskedMAE()
    metrics = {
        'mae': MaskedMAE(),
        'mse': MaskedMSE(),
        'mape': MaskedMAPE(),
    }

    # Predictor: use GraphPolyVARFilter as the forecasting model
    predictor = SFGraphPredictor(
        model_class=GraphPolyVARFilter,
        model_kwargs=dict(
            temporal_order=cfg.window,  # temporal order is tied to the input window length
            spatial_order=spatial_order,
            node_feature_dim=1,
            horizon=cfg.horizon,
            activation="tanh",
        ),
        optim_class=torch.optim.Adam,
        optim_kwargs=dict(lr=float(cfg.optim.lr)),
        loss_fn=loss_fn,
        metrics=metrics,
        graph_module_class=GraphModule,
        graph_module_kwargs=gm_kwargs,
        sf_weight=cfg.sf_weight,
        use_baseline=cfg.use_baseline,
        mc_samples=cfg.mc_samples,
        variance_reduced=cfg.variance_reduced,
        surrogate_lam=cfg.lam,
        scale_target=False
    )

    # Trainer / callbacks
    os.makedirs(cfg.run.dir, exist_ok=True)
    early_stop = EarlyStopping(monitor='val_mae', patience=cfg.patience, mode='min')
    ckpt = ModelCheckpoint(dirpath=cfg.run.dir, save_top_k=1, monitor='val_mae', mode='min')
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Gradient clipping is by value (0.5) and only when explicitly enabled.
    clip_algo = 'value' if cfg.clip_grad else None
    clip_val  = 0.5 if cfg.clip_grad else 0.0

    trainer = pl.Trainer(
        max_epochs=cfg.epochs,
        # A negative batches_epoch means "use every training batch".
        limit_train_batches=(1.0 if cfg.batches_epoch < 0 else cfg.batches_epoch),
        default_root_dir=cfg.run.dir,
        accelerator=('gpu' if torch.cuda.is_available() else 'cpu'),
        callbacks=[early_stop, ckpt, lr_monitor],
        gradient_clip_algorithm=clip_algo,
        gradient_clip_val=clip_val,
        log_every_n_steps=1
    )

    # Fit & Test
    trainer.fit(predictor, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader())

    best_path = ckpt.best_model_path
    if not best_path:
        # Fail with an explicit message instead of the opaque error torch.load('') gives.
        raise FileNotFoundError(
            "No checkpoint was saved during fit (is 'val_mae' being logged?); "
            "cannot restore the best model for testing."
        )

    # The checkpoint was written by this run, so full unpickling is acceptable.
    # weights_only is spelled out because its default changes across PyTorch
    # releases and a Lightning checkpoint may hold more than plain tensors.
    state = torch.load(best_path, map_location='cpu', weights_only=False)['state_dict']
    predictor.load_state_dict(state)
    predictor.freeze()
    test_out = trainer.test(predictor, dataloaders=dm.test_dataloader())

    return {"best_ckpt": best_path, "test": test_out}


In [5]:
# Train and evaluate with the base configuration defined above.
res = run_experiment(cfg)
# Bare expression so the notebook shows the returned dict as the cell output.
res

  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/wangxc1117/miniconda3/envs/sgl/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]`

                                                                           

/home/wangxc1117/miniconda3/envs/sgl/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/wangxc1117/miniconda3/envs/sgl/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Epoch 29: 100%|██████████| 40/40 [00:01<00:00, 30.26it/s, v_num=2, val_mae=45.90, val_mape=0.766, val_mse=4.15e+3, train_mae=50.00, train_mape=0.864, train_mse=4.65e+3]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 40/40 [00:01<00:00, 30.20it/s, v_num=2, val_mae=45.90, val_mape=0.766, val_mse=4.15e+3, train_mae=50.00, train_mape=0.864, train_mse=4.65e+3]


  state = torch.load(ckpt.best_model_path, map_location='cpu')['state_dict']
/home/wangxc1117/miniconda3/envs/sgl/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 81.50it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss            44.85068893432617
        test_mae             44.99418640136719
        test_mape           0.8099589347839355
        test_mse             3851.104736328125
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


{'best_ckpt': '/home/wangxc1117/sparse-graph-learning-use/sparse-graph-learning-main/logs/aqi_sf_bes_graphvar/epoch=27-step=1120-v1.ckpt',
 'test': [{'test_mae': 44.99418640136719,
   'test_mape': 0.8099589347839355,
   'test_mse': 3851.104736328125,
   'test_loss': 44.85068893432617}]}

## Experiment with the selected (best) parameter combination

In [None]:
# Selected hyper-parameters; they are applied on top of the base config below.
BEST_L        = 2    # spatial_order passed to GraphPolyVARFilter
BEST_TAU      = 0.3  # written to cfg.sampler.temperature
BEST_SFWEIGHT = 1.0  # written to cfg.sf_weight
BEST_window   = 72   # written to cfg.window
BEST_epochs   = 80   # written to cfg.epochs

# SF graph learning + Graph VAR dynamics
import os, pathlib
# Run from the repository checkout under the user's home directory, so that the
# relative paths used in this cell ("data/...", "logs/...") resolve against it.
# NOTE(review): presumably this is also what makes the `lib.*` imports importable — confirm.
ROOT = pathlib.Path.home().joinpath("sparse-graph-learning-use", "sparse-graph-learning-main")
os.chdir(ROOT)

import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor

from tsl.experiment import Experiment
from tsl.data import SpatioTemporalDataset, SpatioTemporalDataModule
from tsl.data.preprocessing import StandardScaler
from tsl.datasets import AirQuality
from tsl.metrics.torch import MaskedMAE, MaskedMAPE, MaskedMSE

from lib.nn.graph_module import GraphModule
from lib.predictors.latent_graph_predictor import SFGraphPredictor
from lib.datasets.graph_polynomial_var import GraphPolyVARFilter
import tsl
from omegaconf import OmegaConf

##  Data loader
def get_dataset(dataset_name):
    """Build the dataset object registered under ``dataset_name``.

    Only ``"aqi"`` is recognised. For it, an ``AirQuality`` dataset rooted at
    ``data/airquality`` (relative to the current working directory) is created
    with ``small=True``, loaded once so the array shapes can be logged, and
    then returned.

    NOTE(review): this repeats a definition made earlier in the notebook and
    shadows it; consider defining the function only once.

    Args:
        dataset_name: Identifier of the dataset to build.

    Returns:
        The ``AirQuality`` dataset object (not the loaded arrays).

    Raises:
        ValueError: If ``dataset_name`` is anything other than ``"aqi"``.
    """
    # Reject unknown names up front so the happy path reads without nesting.
    if dataset_name != "aqi":
        raise ValueError(f"Dataset {dataset_name} not supported.")

    # small=True: the 36-station variant, per the original inline note.
    dataset = AirQuality(root=os.path.join("data", "airquality"), small=True)

    # load() is called here only to report shapes; the arrays are discarded and
    # callers receive the dataset object itself.
    data, mask, _, dist = dataset.load()
    tsl.logger.info(
        f"Loaded AirQuality: data={tuple(data.shape)}, mask={tuple(mask.shape)}, dist={tuple(dist.shape)}"
    )
    return dataset

## Config (base)
# Base configuration; the BEST_* overrides further down replace window, epochs,
# sf_weight, sampler.temperature, patience and run.dir before the experiment runs.
cfg = OmegaConf.create(dict(
    dataset=dict(
        name="aqi",                              # looked up by get_dataset()
        splits=dict(val_len=0.1, test_len=0.2),  # forwarded to dataset.get_splitter(...)
    ),
    window=24,             # overridden below by BEST_window
    horizon=1,
    stride=1,
    batch_size=128,
    epochs=30,             # overridden below by BEST_epochs
    patience=60,           # capped below to at most epochs - 1
    batches_epoch=-1,      # negative -> limit_train_batches=1.0 (all batches)
    graph_mode="sf",       # GraphModule `mode`
    mc_samples=1,
    lam=None,              # passed as surrogate_lam
    sf_weight=1.0,         # overridden below by BEST_SFWEIGHT
    use_baseline=True,
    variance_reduced=True,
    workers=0,             # data-module workers
    num_threads=1,         # torch.set_num_threads
    run=dict(dir="logs/aqi_sf_bes_graphvar"),   # replaced below by a per-run directory
    clip_grad=False,       # True -> clip gradients by value at 0.5
    sampler=dict(
        temperature=0.5,   # overridden below by BEST_TAU
        symmetric=True,    # NOTE(review): not read by run_experiment() as written
        hard=True,         # NOTE(review): not read by run_experiment() as written
    ),
    optim=dict(lr=1e-3),   # Adam learning rate
))

# Overlay the selected hyper-parameters on the base configuration.
cfg.window = int(BEST_window)
cfg.epochs = int(BEST_epochs)
cfg.sf_weight = float(BEST_SFWEIGHT)
cfg.sampler.temperature = float(BEST_TAU)

# Keep the early-stopping patience strictly below the epoch budget
# (and zero when there is at most a single epoch).
if cfg.epochs > 1:
    cfg.patience = int(min(int(cfg.patience), cfg.epochs - 1))
else:
    cfg.patience = 0

# One log/checkpoint directory per hyper-parameter combination.
cfg.run.dir = f"logs/aqi_sf_bes_graphvar_W{cfg.window}_E{cfg.epochs}_L{BEST_L}_T{cfg.sampler.temperature}_SF{cfg.sf_weight}"

## Experiment Function
def run_experiment(cfg):
    """Train an SF graph-learning predictor with Graph-VAR dynamics and test it.

    The function loads the dataset named in the config, wraps it in a TSL
    dataset / data module, builds an ``SFGraphPredictor`` around
    ``GraphPolyVARFilter``, fits it with PyTorch Lightning, restores the
    checkpoint with the lowest ``val_mae`` and evaluates it on the test split.

    Args:
        cfg: OmegaConf config. Required keys: ``num_threads``, ``dataset.name``,
            ``dataset.splits``, ``window``, ``horizon``, ``stride``,
            ``batch_size``, ``workers``, ``graph_mode``, ``sampler.temperature``,
            ``optim.lr``, ``sf_weight``, ``use_baseline``, ``mc_samples``,
            ``variance_reduced``, ``lam``, ``run.dir``, ``patience``,
            ``clip_grad``, ``epochs``, ``batches_epoch``.
            Optional keys: ``seed`` (int; when present all RNGs are seeded with
            ``pl.seed_everything``) and ``spatial_order`` (int; when absent the
            module-level ``BEST_L`` is used, as before).

    Returns:
        dict with ``"best_ckpt"`` (path of the best checkpoint) and ``"test"``
        (the list returned by ``trainer.test``).

    Raises:
        ValueError: From ``get_dataset`` when the dataset name is unsupported.
        FileNotFoundError: If fitting finished without any checkpoint being saved.
    """
    torch.set_num_threads(int(cfg.num_threads))

    # Opt-in reproducibility: nothing changes unless the config carries a seed.
    seed = cfg.get("seed")
    if seed is not None:
        pl.seed_everything(int(seed), workers=True)

    # Prefer an explicit config value over the notebook-global BEST_L so the
    # function does not depend on hidden kernel state when the key is provided.
    spatial_order = cfg.get("spatial_order")
    if spatial_order is None:
        spatial_order = BEST_L
    spatial_order = int(spatial_order)

    dataset = get_dataset(cfg.dataset.name)
    # The third value returned by load() (the eval mask) is not needed here.
    data, mask, _, dist = dataset.load()

    data = np.asarray(data, dtype=np.float32)
    mask = np.asarray(mask, dtype=bool)
    # Exponential kernel on the pairwise distances, scaled by their standard deviation.
    adj  = np.exp(-dist / dist.std()).astype(np.float32)
    time_index = dataset.index

    # TSL Dataset / DataModule
    torch_dataset = SpatioTemporalDataset(
        data,
        index=time_index,
        mask=mask,
        connectivity=adj,
        horizon=cfg.horizon,
        window=cfg.window,
        stride=cfg.stride,
    )

    dm = SpatioTemporalDataModule(
        dataset=torch_dataset,
        scalers={'target': StandardScaler(axis=(0, 1))},
        splitter=dataset.get_splitter(**cfg.dataset.splits),
        batch_size=cfg.batch_size,
        workers=cfg.workers,
    )
    dm.setup()

    # GraphModule kwargs
    gm_kwargs = dict(
        n_nodes=torch_dataset.n_nodes,
        mode=cfg.graph_mode,
        sampler='bes',
        tau=float(cfg.sampler.temperature),
        dummy_nodes=0,
    )

    # Loss / metrics
    loss_fn = MaskedMAE()
    metrics = {
        'mae': MaskedMAE(),
        'mse': MaskedMSE(),
        'mape': MaskedMAPE(),
    }

    # Predictor: use GraphPolyVARFilter as the forecasting model
    predictor = SFGraphPredictor(
        model_class=GraphPolyVARFilter,
        model_kwargs=dict(
            temporal_order=cfg.window,  # temporal order is tied to the input window length
            spatial_order=spatial_order,
            node_feature_dim=1,
            horizon=cfg.horizon,
            activation="tanh",
        ),
        optim_class=torch.optim.Adam,
        optim_kwargs=dict(lr=float(cfg.optim.lr)),
        loss_fn=loss_fn,
        metrics=metrics,
        graph_module_class=GraphModule,
        graph_module_kwargs=gm_kwargs,
        sf_weight=cfg.sf_weight,
        use_baseline=cfg.use_baseline,
        mc_samples=cfg.mc_samples,
        variance_reduced=cfg.variance_reduced,
        surrogate_lam=cfg.lam,
        scale_target=False
    )

    # Trainer / callbacks
    os.makedirs(cfg.run.dir, exist_ok=True)
    early_stop = EarlyStopping(monitor='val_mae', patience=cfg.patience, mode='min')
    ckpt = ModelCheckpoint(dirpath=cfg.run.dir, save_top_k=1, monitor='val_mae', mode='min')
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Gradient clipping is by value (0.5) and only when explicitly enabled.
    clip_algo = 'value' if cfg.clip_grad else None
    clip_val  = 0.5 if cfg.clip_grad else 0.0

    trainer = pl.Trainer(
        max_epochs=cfg.epochs,
        # A negative batches_epoch means "use every training batch".
        limit_train_batches=(1.0 if cfg.batches_epoch < 0 else cfg.batches_epoch),
        default_root_dir=cfg.run.dir,
        accelerator=('gpu' if torch.cuda.is_available() else 'cpu'),
        callbacks=[early_stop, ckpt, lr_monitor],
        gradient_clip_algorithm=clip_algo,
        gradient_clip_val=clip_val,
        log_every_n_steps=1
    )

    # Fit & Test
    trainer.fit(predictor, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader())

    best_path = ckpt.best_model_path
    if not best_path:
        # Fail with an explicit message instead of the opaque error torch.load('') gives.
        raise FileNotFoundError(
            "No checkpoint was saved during fit (is 'val_mae' being logged?); "
            "cannot restore the best model for testing."
        )

    # The checkpoint was written by this run, so full unpickling is acceptable.
    # weights_only is spelled out because its default changes across PyTorch
    # releases and a Lightning checkpoint may hold more than plain tensors.
    state = torch.load(best_path, map_location='cpu', weights_only=False)['state_dict']
    predictor.load_state_dict(state)
    predictor.freeze()
    test_out = trainer.test(predictor, dataloaders=dm.test_dataloader())

    return {"best_ckpt": best_path, "test": test_out}


# Train and evaluate with the configuration after the BEST_* overrides were applied.
res = run_experiment(cfg)
# Bare expression so the notebook shows the returned dict as the cell output.
res


  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
  nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
/home/wangxc1117/miniconda3/envs/sgl/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /home/wangxc1117/sparse-graph-learning-use/sparse-graph-learning-main/logs/aqi_sf_bes_graphvar_W72_E80_L2_T0.3_SF1.0 exists and is not empty.

  | Name          | Type                   | Params | Mode 
----------------------------------------------------

                                                                           

/home/wangxc1117/miniconda3/envs/sgl/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/wangxc1117/miniconda3/envs/sgl/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Epoch 79: 100%|██████████| 40/40 [00:01<00:00, 36.62it/s, v_num=0, val_mae=21.50, val_mape=0.416, val_mse=1.91e+3, train_mae=24.50, train_mape=0.538, train_mse=2.36e+3]

`Trainer.fit` stopped: `max_epochs=80` reached.


Epoch 79: 100%|██████████| 40/40 [00:01<00:00, 36.50it/s, v_num=0, val_mae=21.50, val_mape=0.416, val_mse=1.91e+3, train_mae=24.50, train_mape=0.538, train_mse=2.36e+3]


  state = torch.load(ckpt.best_model_path, map_location='cpu')['state_dict']
/home/wangxc1117/miniconda3/envs/sgl/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 78.99it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           17.593307495117188
        test_mae            17.691139221191406
        test_mape           0.5590103268623352
        test_mse              1280.744140625
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


{'best_ckpt': '/home/wangxc1117/sparse-graph-learning-use/sparse-graph-learning-main/logs/aqi_sf_bes_graphvar_W72_E80_L2_T0.3_SF1.0/epoch=75-step=3040.ckpt',
 'test': [{'test_mae': 17.691139221191406,
   'test_mape': 0.5590103268623352,
   'test_mse': 1280.744140625,
   'test_loss': 17.593307495117188}]}