In [None]:
"""
End-to-end PyTorch Geometric training script for a TopologicPy CSV dataset.

Expected files (same folder):
  - graphs.csv
  - nodes.csv
  - edges.csv
  - meta.yaml (optional; used only for metadata)

This script:
  1) Loads the dataset with the PyG helper class
  2) Builds a graph-level GNN model (graph classification by default)
  3) Trains with a holdout train/val/test split
  4) Evaluates on validation + test
  5) Visualizes learning curves + confusion matrix + prints metrics

Notes
-----
- Requires: torch, torch_geometric, pandas, pyyaml, numpy, plotly, scikit-learn
- Install (example):
    pip install torch pandas pyyaml numpy plotly scikit-learn
    # then install torch-geometric following their official instructions for your OS/CUDA.
"""
# Future statements must precede every other statement in the module (only the
# module docstring, comments and blank lines may come before them); otherwise
# Python raises a SyntaxError.
from __future__ import annotations

import json
import os
import sys
from pathlib import Path

import yaml

# This cell is not needed if you have pip installed topologicpy
# It must run before `topologicpy` is imported below so that a local source
# checkout is found on the module search path.
sys.path.append("C:/Users/sarwj/OneDrive - Cardiff University/Documents/GitHub/topologicpy/src")

# ---- If PyG_documented.py is in the same folder as this script, this works:
# from PyG_documented import PyG
#
# ---- If it's elsewhere, add its folder to sys.path, e.g.:
# import sys
# sys.path.append(str(Path(__file__).resolve().parent))
# from PyG_documented import PyG

from topologicpy.PyG import PyG


def pretty_print_metrics(title: str, metrics: dict) -> None:
    """Print a metrics mapping as a titled, key-sorted table on stdout.

    The table is framed by 80-character ``=`` rules, preceded by a blank line
    and followed by one. Keys are left-aligned in a 30-character column.
    Values that are ``float`` instances are shown with six decimal places;
    every other value is shown with its default formatting.

    Args:
        title: Heading printed between the two top rules.
        metrics: Mapping of metric name to value.
    """
    rule = "=" * 80

    # Header: blank line, rule, title, rule.
    print("\n" + rule)
    print(title)
    print(rule)

    # Body: one row per metric, ordered by key.
    for name, value in sorted(metrics.items(), key=lambda pair: pair[0]):
        rendered = f"{value:.6f}" if isinstance(value, float) else f"{value}"
        print(f"{name:30s}: {rendered}")

    # Footer: rule followed by a blank line.
    print(rule + "\n")


def main(dataset_dir: "str | Path | None" = None) -> None:
    """Train, evaluate and visualise a graph-level classifier on a CSV dataset.

    Steps, in order: locate the dataset folder, print ``meta.yaml`` if present,
    build the ``PyG`` helper from the folder, set hyperparameters, train, compute
    validation and test metrics, then show the learning-curve and test
    confusion-matrix figures.

    Args:
        dataset_dir: Folder containing the dataset files. When ``None`` (the
            default), the original hard-coded location is used, so calling
            ``main()`` with no arguments behaves as before. Pass ``"."`` to use
            the current working directory.

    Raises:
        FileNotFoundError: If the dataset folder does not exist or is not a
            directory.
    """
    # ---------------------------------------------------------------------
    # 0) Locate dataset folder
    # ---------------------------------------------------------------------
    if dataset_dir is None:
        # Original machine-specific default, kept for backward compatibility.
        dataset_dir = r"C:\Users\sarwj\OneDrive - Cardiff University\Documents\GitHub\topologicpy\assets\MachineLearning\training_dataset"
    dataset_dir = Path(dataset_dir).resolve()

    # Fail early with a clear message instead of a less obvious error later on.
    if not dataset_dir.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {dataset_dir}")

    # Optional: read meta.yaml (purely informational)
    meta_path = dataset_dir / "meta.yaml"
    if meta_path.exists():
        meta = yaml.safe_load(meta_path.read_text(encoding="utf-8"))
        print("Loaded meta.yaml:")
        # default=str: YAML can yield values JSON cannot encode (e.g. dates);
        # this dump is display-only, so stringify them rather than crash.
        print(json.dumps(meta, indent=2, default=str))
    else:
        print("meta.yaml not found (this is fine).")

    # ---------------------------------------------------------------------
    # 1) Create the helper and load/build Data objects
    # ---------------------------------------------------------------------
    # This dataset has graphs.csv with a categorical 'label' column -> graph classification.
    pyg = PyG.ByCSVPath(
        path=str(dataset_dir),
        level="graph",              # "graph" | "node" | "edge" | "link"
        task="classification",      # "classification" | "regression" | "link_prediction"
        graphLabelType="categorical",
        nodeLabelType="categorical",
        edgeLabelType="categorical",
        # If your headers differ, override here (your attached CSVs match defaults):
        # graphIDHeader="graph_id", graphLabelHeader="label",
        # nodeIDHeader="node_id", nodeLabelHeader="label",
        # edgeSRCHeader="src_id", edgeDSTHeader="dst_id", edgeLabelHeader="label",
    )

    # ---------------------------------------------------------------------
    # 2) Set hyperparameters / model architecture
    # ---------------------------------------------------------------------
    # You can tweak these aggressively; these are sane starting points for imbalanced 5-class data.
    pyg.SetHyperparameters(
        # splitting / determinism
        cv="holdout",
        split=(0.80, 0.10, 0.10),   # train/val/test
        random_state=42,
        shuffle=True,

        # training
        epochs=20,
        batch_size=64,
        lr=1e-3,
        weight_decay=1e-4,
        optimizer="adamw",
        gradient_clip_norm=1.0,
        early_stopping=True,
        early_stopping_patience=12,
        use_gpu=True,              # will use CUDA if available

        # model
        conv="sage",               # "sage" | "gcn" | "gatv2"
        hidden_dims=(128, 128),    # depth = len(hidden_dims)
        activation="relu",         # "relu" | "gelu" | "elu"
        dropout=0.20,
        batch_norm=True,
        residual=True,
        pooling="mean",            # "mean" | "max" | "add" (graph-level only)
    )

    # Print a compact summary of the current config and inferred dims/classes
    print("PyG config summary:")
    print(pyg.Summary())

    # ---------------------------------------------------------------------
    # 3) Train
    # ---------------------------------------------------------------------
    history = pyg.Train()  # returns dict of per-epoch curves (loss + metrics when available)

    # ---------------------------------------------------------------------
    # 4) Validate + Test
    # ---------------------------------------------------------------------
    val_metrics = pyg.Validate()
    test_metrics = pyg.Test()

    pretty_print_metrics("Validation metrics", val_metrics)
    pretty_print_metrics("Test metrics", test_metrics)

    # ---------------------------------------------------------------------
    # 5) Visualize learning curves and performance metrics
    # ---------------------------------------------------------------------
    # The helper provides Plotly figures (nice in notebooks; can also be saved to disk).
    fig_hist = pyg.PlotHistory()
    fig_hist.show()

    # Confusion matrix only valid for classification tasks
    fig_cm = pyg.PlotConfusionMatrix(split="test")
    fig_cm.show()

    # Optional: save figures (requires kaleido: pip install -U kaleido)
    # out_dir = dataset_dir / "outputs"
    # out_dir.mkdir(parents=True, exist_ok=True)
    # try:
    #     fig_hist.write_image(str(out_dir / "learning_curves.png"))
    #     fig_cm.write_image(str(out_dir / "confusion_matrix_test.png"))
    #     print(f"Saved figures to: {out_dir}")
    # except Exception as e:
    #     print(
    #         "Could not save Plotly images (this is usually because kaleido is missing).\n"
    #         "Install it via: pip install -U kaleido\n"
    #         f"Error: {e}"
    #     )

    # If you want to inspect the raw history keys:
    print("History keys:", list(history.keys()))
    # Common keys include: train_loss, val_loss, and sometimes train_acc/val_acc, etc.

if __name__ == "__main__":
    main()

Loaded meta.yaml:
{
  "dataset_name": "topologic_training_dataset",
  "edge_data": [
    {
      "file_name": "edges.csv"
    }
  ],
  "node_data": [
    {
      "file_name": "nodes.csv"
    }
  ],
  "graph_data": {
    "file_name": "graphs.csv"
  }
}
PyG config summary:
{'level': 'graph', 'task': 'classification', 'graph_label_type': 'categorical', 'node_label_type': 'categorical', 'edge_label_type': 'categorical', 'cv': 'holdout', 'split': (0.8, 0.1, 0.1), 'k_folds': 5, 'conv': 'sage', 'hidden_dims': (128, 128), 'activation': 'relu', 'dropout': 0.2, 'batch_norm': True, 'residual': True, 'pooling': 'mean', 'epochs': 20, 'batch_size': 64, 'lr': 0.001, 'weight_decay': 0.0001, 'optimizer': 'adamw', 'gradient_clip_norm': 1.0, 'early_stopping': True, 'early_stopping_patience': 12, 'device': 'cuda:0', 'num_graphs': 1496, 'num_outputs': 5}

Validation metrics
val_accuracy                  : 0.993289
val_f1                        : 0.993078
val_precision                 : 0.993475
val_recall 

Saved figures to: C:\Users\sarwj\OneDrive - Cardiff University\Documents\GitHub\topologicpy\assets\MachineLearning\training_dataset\outputs
History keys: ['train_loss', 'val_loss']
