## Summary
- Convert graphs to PyG datasets and split train/val/test.
- Train a classifier with optional clinical features.
- Save models, results, and datasets under `{DISEASE}_analysis_output/`.


In [None]:
import os
import sys
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import pickle
from pathlib import Path
import json
from datetime import datetime
from collections import defaultdict
from typing import List, Tuple, Dict, Optional

import networkx as nx

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    TORCH_AVAILABLE = True
    print(f" PyTorch {torch.__version__} available")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU devices: {torch.cuda.device_count()}")
except ImportError:
    TORCH_AVAILABLE = False
    print("PyTorch not available - please install PyTorch first")
    raise ImportError("PyTorch is required for this notebook")

try:
    import torch_geometric
    from torch_geometric.data import Data, Dataset
    from torch_geometric.utils import from_networkx
    from torch_geometric.loader import DataLoader
    from torch_geometric.nn import GCNConv, GINEConv, global_mean_pool
    PYGEOMETRIC_AVAILABLE = True
    print(f" PyTorch Geometric {torch_geometric.__version__} available")
except ImportError:
    PYGEOMETRIC_AVAILABLE = False
    print("PyTorch Geometric not available")
    print(f"Install with: pip install torch-geometric")
    raise ImportError("PyTorch Geometric is required for this notebook")

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

# Startup banner. DISEASE is only resolved in a later cell, so the title is
# kept generic rather than printing an un-interpolated "{DISEASE}" placeholder.
print("\n Multi-Disease PyTorch Geometric Dataset Construction Pipeline")
print("=" * 50)
print(f" Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f" Python version: {sys.version}")
print(f" Working directory: {os.getcwd()}")


In [None]:
import yaml
from pathlib import Path
import os

# Locate the experiment YAML. Priority: the EXPERIMENT_CONFIG_PATH environment
# variable, then ./config.yaml, then the baseline experiment config.
config_path = None
if 'EXPERIMENT_CONFIG_PATH' in os.environ:
    config_path = Path(os.environ['EXPERIMENT_CONFIG_PATH'])
    print(f" Using config from environment: {config_path}")
elif Path("config.yaml").exists():
    config_path = Path("config.yaml")
    print(f" Found config in current directory")
elif Path("../../experiments/baseline/config.yaml").exists():
    config_path = Path("../../experiments/baseline/config.yaml")
    print(f"  Using fallback config: {config_path}")

# Defaults used when no usable experiment config is found.
EXPERIMENT_CONFIG = None
exp_disease = None

if config_path and config_path.exists():
    print(f" Loading experiment configuration from: {config_path}")
    with open(config_path, 'r') as f:
        loaded_config = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document and may return a
    # scalar or list for malformed configs; only a mapping is usable here.
    if isinstance(loaded_config, dict):
        EXPERIMENT_CONFIG = loaded_config
        print(f" Loaded configuration for experiment")

        # `or {}` guards against sections that are present but null in YAML.
        data_extraction_cfg = EXPERIMENT_CONFIG.get('data_extraction') or {}

        # The disease may be declared at the top level, under data_extraction,
        # or under data_extraction.disease_criteria; the first truthy one wins.
        exp_disease = (
            EXPERIMENT_CONFIG.get('disease') or
            data_extraction_cfg.get('disease') or
            (data_extraction_cfg.get('disease_criteria') or {}).get('disease')
        )

        mt_config = EXPERIMENT_CONFIG.get('model_training') or {}
        arch_config = mt_config.get('architecture') or {}
        hidden_dim = arch_config.get('hidden_dim', 'N/A')
        num_layers = arch_config.get('num_layers', 'N/A')

        print(f"Disease from config: {exp_disease}")
        print(f"Model: hidden_dim={hidden_dim}, layers={num_layers}")
    else:
        print(" Experiment config is empty or not a mapping - using default parameters from notebook")
else:
    print(" No experiment config found - using default parameters from notebook")

# Clinical-feature defaults; a later cell may overwrite them.
use_clinical = False
clinical_dim = 0
clinical_df = None

# Disease label used to build input/output file names; defaults to IBD.
DISEASE = "IBD"
if exp_disease:
    # str() keeps a non-string YAML value (e.g. a number) from crashing .upper().
    DISEASE = str(exp_disease).upper()
    print(f"DISEASE set to: {DISEASE}")


# Multi-Disease PyTorch Geometric Dataset and Training Pipeline

## Overview

This notebook converts the NetworkX graphs from `02_graph_construction.ipynb`
into PyTorch Geometric datasets and trains a classifier for the selected
disease. The disease and hyperparameters are loaded from the experiment config
(`EXPERIMENT_CONFIG_PATH`) or `pipeline_config.json` when available.

## Pipeline Architecture

Input Data (from 02_graph_construction.ipynb)
- NetworkX graphs (nx_graphs_{DISEASE}.pkl)
- Labels (labels_{DISEASE}.npy)
- Graph metadata (graph_metadata_{DISEASE}.csv)
- Graph config (graph_config_{DISEASE}.json)

Dataset/Training Pipeline:
1. Configuration and data loading
2. NetworkX to PyG conversion
3. Train/val/test split
4. Model training and evaluation
5. Outputs and artifacts

Output Files
- pytorch_geometric/{disease}_dataset.pt
- pytorch_geometric/train_dataset.pt
- pytorch_geometric/val_dataset.pt
- pytorch_geometric/test_dataset.pt
- pytorch_geometric/dataset_statistics.json
- results/training_history.json
- models/best_model.pt
- results/evaluation_results.json
- visualizations/training_curves.png
- visualizations/confusion_matrix_roc.png

## Key Features

- PyTorch Geometric conversion from NetworkX graphs
- Stratified splitting for balanced train/val/test sets
- Optional clinical feature integration via `src/feature_loader.py`
- Binary or multi-class training based on config
- Saved artifacts for reproducibility

## Prerequisites

- Completed execution of `02_graph_construction.ipynb`
- Generated files in `{DISEASE}_analysis_output/graphs/` directory
- PyTorch and PyTorch Geometric installed


## Configuration and Setup

### Pipeline Configuration


In [None]:
# Candidate locations of the pipeline config written by earlier notebooks.
# DISEASE at this point holds the value resolved from the experiment config
# (or the "IBD" default) by the previous cell.
possible_config_paths = [
    Path(f"{DISEASE}_analysis_output/config/pipeline_config.json"),
    Path(f"../../notebooks/{DISEASE}_analysis_output/config/pipeline_config.json"),
]

config = None
for config_file in possible_config_paths:
    if config_file.exists():
        with open(config_file, "r") as f:
            config = json.load(f)
        print(f" Loaded configuration from: {config_file}")
        break

if config is None:
    print("No configuration file found - using default settings")
    # Fall back to the already-resolved disease. A hard-coded "IBD" here would
    # silently discard a disease taken from the experiment config.
    config = {
        "disease": DISEASE,
        "output_dir": f"{DISEASE}_analysis_output"
    }

# A pipeline config without a "disease" key keeps the current value.
DISEASE = config.get("disease", DISEASE)

# An explicit data_extraction.disease in the experiment config takes
# precedence over the pipeline config.
if 'EXPERIMENT_CONFIG' in globals() and EXPERIMENT_CONFIG:
    exp_disease = EXPERIMENT_CONFIG.get('data_extraction', {}).get('disease', None)
    if exp_disease and exp_disease != 'N/A':
        DISEASE = exp_disease
        print(f'Using disease from experiment config: {DISEASE}')

# Locations where the disease-specific analysis output may already exist.
possible_output_dirs = [
    Path(f"{DISEASE}_analysis_output"),
    Path(f"../../notebooks/{DISEASE}_analysis_output"),
]

# Take the first existing candidate; otherwise fall back to the configured one.
output_dir = next(
    (candidate for candidate in possible_output_dirs if candidate.exists()),
    None,
)
if output_dir is not None:
    print(f" Found output directory: {output_dir}")
else:
    output_dir = Path(config.get("output_dir", f"{DISEASE}_analysis_output"))
    print(f" Using fallback output directory: {output_dir}")

# Input artifacts expected under <output_dir>/graphs.
graphs_dir = output_dir / "graphs"
GRAPHS_PKL = graphs_dir / f"nx_graphs_{DISEASE}.pkl"
LABELS_NPY = graphs_dir / f"labels_{DISEASE}.npy"
METADATA_CSV = graphs_dir / f"graph_metadata_{DISEASE}.csv"
GRAPH_CONFIG_JSON = graphs_dir / f"graph_config_{DISEASE}.json"

# An experiment config may redirect outputs to its own directory.
experiment_output_dir = None
if EXPERIMENT_CONFIG is not None:
    exp_output = EXPERIMENT_CONFIG.get('output_dir')
    if exp_output:
        experiment_output_dir = Path(exp_output)
        print(f"Using experiment output directory: {experiment_output_dir}")

# The experiment directory is only used when it already exists on disk.
if experiment_output_dir is not None and experiment_output_dir.exists():
    base_output = experiment_output_dir
else:
    base_output = output_dir

pytorch_output_dir, models_output_dir, results_output_dir, visualizations_dir = (
    base_output / subdir
    for subdir in ("pytorch_geometric", "models", "results", "visualizations")
)

for dir_path in (pytorch_output_dir, models_output_dir, results_output_dir, visualizations_dir):
    dir_path.mkdir(parents=True, exist_ok=True)

print(f"\n Disease focus: {DISEASE}")
print(f" Input directory: {output_dir}")
print(f" PyTorch Geometric output: {pytorch_output_dir}")
print(f" Models output: {models_output_dir}")
print(f" Results output: {results_output_dir}")
print(f" Visualizations directory: {visualizations_dir}")

# Extend sys.path so the project-local `src` package can be imported.
# With __file__ defined, the file's grandparent directory is added; otherwise
# (e.g. in a notebook kernel) the parent of the current working directory is.
# The path insertion must run before the `src.feature_loader` import below.
import sys
sys.path.insert(0, str(Path(__file__).parent.parent) if '__file__' in globals() else str(Path.cwd().parent))
from src.feature_loader import load_features

if EXPERIMENT_CONFIG is not None:
    print("\n Checking for additional features...")
    # load_features returns a mapping; the keys read here are 'use_clinical',
    # 'clinical_dim' and 'clinical'.
    # NOTE(review): 'clinical' is presumably a pandas DataFrame (its .shape is
    # printed below) - confirm against src/feature_loader.py.
    features = load_features(EXPERIMENT_CONFIG, output_dir)
    use_clinical = features['use_clinical']
    clinical_dim = features['clinical_dim']
    clinical_df = features['clinical']
    
    if use_clinical:
        print(f" Clinical features enabled: {clinical_dim} features")
        print(f" Clinical data shape: {clinical_df.shape}")
    else:
        print(f" Microbiome only (no clinical features)")
else:
    # Without an experiment config, the defaults set earlier stay in effect.
    print("\n Using microbiome features only (no config provided)")


In [None]:
# Default dataset settings; an experiment config may override them later.
DATASET_PARAMS = dict(
    # Split settings (val_size is applied to what remains after the test cut).
    test_size=0.25,
    val_size=0.25,
    random_seed=42,
    stratify=True,
    # Feature settings.
    node_feature_type="weight",
    edge_feature_type="weight",
    normalize_features=True,
    # Augmentation.
    enable_augmentation=False,
)

print("Dataset Configuration:")
for key in DATASET_PARAMS:
    print(f"{key}: {DATASET_PARAMS[key]}")


In [None]:
# Default model/training hyperparameters; an experiment config may override them.
MODEL_PARAMS = dict(
    # Architecture.
    model_type="GCN",
    hidden_dim=160,
    num_layers=2,
    dropout=0.3,
    pooling="mean",
    # Optimisation.
    batch_size=16,
    num_epochs=200,
    learning_rate=1e-4,
    weight_decay=5e-6,
    patience=30,
)

print("Model Configuration:")
for key in MODEL_PARAMS:
    print(f"{key}: {MODEL_PARAMS[key]}")


In [None]:
# Plot settings for the figures this notebook produces.
VIZ_PARAMS = dict(
    figsize=(12, 8),
    dpi=150,
    save_plots=True,
    show_plots=False,
)

print("Visualization Configuration:")
for key in VIZ_PARAMS:
    print(f"{key}: {VIZ_PARAMS[key]}")


In [None]:
# Layer experiment-config values over the notebook defaults. Only keys that
# already exist in the default dictionaries are overridden.
if EXPERIMENT_CONFIG is None:
    print("\n Using default parameters from notebook")
else:
    print("\n Overriding parameters with experiment configuration...")

    if "dataset_splitting" in EXPERIMENT_CONFIG:
        for key, value in EXPERIMENT_CONFIG["dataset_splitting"].items():
            if key not in DATASET_PARAMS:
                continue
            DATASET_PARAMS[key] = value
            print(f" DATASET_PARAMS['{key}'] = {value}")

    if "model_training" in EXPERIMENT_CONFIG:
        mt_config = EXPERIMENT_CONFIG["model_training"]

        # The three sub-sections all feed MODEL_PARAMS and are applied in this
        # order, so a later section wins on a duplicated key.
        for section in ("architecture", "training", "optimizer"):
            if section not in mt_config:
                continue
            for key, value in mt_config[section].items():
                if key not in MODEL_PARAMS:
                    continue
                MODEL_PARAMS[key] = value
                print(f" MODEL_PARAMS['{key}'] = {value}")

    print(f" Parameters overridden with experiment config")

# Select the compute device once and keep it with the model parameters.
import torch
if torch.cuda.is_available():
    MODEL_PARAMS["device"] = "cuda"
else:
    MODEL_PARAMS["device"] = "cpu"
print(f"\n Device: {MODEL_PARAMS['device']}")


## Data Loading and Validation
### Load graphs and labels


In [None]:
# Check that every required input artifact exists before loading anything.
print("Validating input files...")

required_files = [
    (GRAPHS_PKL, "NetworkX graphs pickle"),
    (LABELS_NPY, "Labels array"),
    (METADATA_CSV, "Graph metadata")
]

missing_files = []
for file_path, description in required_files:
    if file_path.exists():
        size_mb = file_path.stat().st_size / (1024*1024)
        print(f" {description}: {file_path} ({size_mb:.1f} MB)")
    else:
        print(f" {description}: {file_path} - MISSING")
        missing_files.append(file_path)

if missing_files:
    print(f"\n Missing required files: {len(missing_files)}")
    # These artifacts are written by the graph-construction notebook, not by
    # the data-extraction notebook.
    print("Please run notebook 02_graph_construction.ipynb first to generate these files.")
    raise FileNotFoundError("Required input files are missing")
else:
    print(f"\n All required files found!")

print("\n Loading data...")

# Graphs: a pickled list of NetworkX graphs.
# NOTE: pickle.load executes arbitrary code; only load files you produced.
print(f"Loading NetworkX graphs...")
with open(GRAPHS_PKL, "rb") as f:
    GRAPHS = pickle.load(f)
print(f" Loaded {len(GRAPHS):,} graphs")

# Labels: one entry per graph.
print(f"Loading labels...")
LABELS = np.load(LABELS_NPY)
print(f" Loaded {len(LABELS):,} labels")

# Optional control experiment: permute labels to break the graph/label link.
shuffle_labels = (
    EXPERIMENT_CONFIG.get('data_extraction', {}).get('shuffle_labels', False)
    if EXPERIMENT_CONFIG is not None
    else False
)

if shuffle_labels:
    print(" SHUFFLING LABELS (control experiment)")
    np.random.seed(42)  # fixed seed so the permutation is reproducible
    LABELS = np.random.permutation(LABELS)
    print(f" Labels shuffled - this should give ~50% accuracy if no data leakage")

# Metadata: supplies the sample IDs, one row per graph.
print(f"Loading metadata...")
metadata_df = pd.read_csv(METADATA_CSV)
print(f" Loaded metadata: {metadata_df.shape}")

sample_ids = metadata_df['sample_id'].tolist()
print(f" Extracted {len(sample_ids)} sample IDs")

# Graphs, labels and metadata must line up one-to-one.
print("\n Validating data consistency...")
assert len(GRAPHS) == len(LABELS), "Number of graphs and labels don't match!"
assert len(GRAPHS) == len(metadata_df), "Number of graphs and metadata rows don't match!"

n_labels = len(LABELS)
n_cases = sum(LABELS)

print(f" Data validation passed")
print(f" Total samples: {len(GRAPHS):,}")
print(f" {DISEASE} cases (label=1): {n_cases:,}")
print(f" Controls (label=0): {n_labels - n_cases:,}")
print(f"  Class balance: {n_cases/n_labels*100:.1f}% {DISEASE} cases")

print("\n Sample metadata:")
display(metadata_df.head())


## NetworkX to PyTorch Geometric Conversion
### Prepare PyG graphs


In [None]:
print("Preparing graphs for PyTorch Geometric conversion...")

def prepare_graph_for_pyg(G):
    """
    Return a copy of ``G`` annotated for PyTorch Geometric conversion.

    The copy gets a positional integer ID on every node (``int_id``), the
    endpoint IDs on every edge (``source_int_id`` / ``dest_int_id``), and a
    scalar float ``weight`` on every node. The input graph is not modified.

    Parameters:
    -----------
    G : networkx.Graph
        Input NetworkX graph

    Returns:
    --------
    networkx.Graph
        Annotated copy of the input graph
    """
    prepared = G.copy()

    # Number nodes by iteration order and store the number on each node.
    index_of = {}
    for position, name in enumerate(prepared.nodes()):
        index_of[name] = position
        prepared.nodes[name]['int_id'] = position

    # Record both endpoint numbers on every edge.
    for src, dst in prepared.edges():
        edge_attrs = prepared.edges[src, dst]
        edge_attrs['source_int_id'] = index_of[src]
        edge_attrs['dest_int_id'] = index_of[dst]

    # Reduce list-valued node weights to a scalar: an empty list becomes 1.0,
    # a single-element list is unwrapped (repeatedly if nested), and a longer
    # list contributes its first element without further unwrapping.
    for name in prepared.nodes():
        node_attrs = prepared.nodes[name]
        raw = node_attrs.get('weight', 1.0)
        while isinstance(raw, list):
            if not raw:
                raw = 1.0
                break
            had_several = len(raw) > 1
            raw = raw[0]
            if had_several:
                break
        node_attrs['weight'] = float(raw)

    return prepared

print(f"Processing graphs...")
GRAPHS_PREPARED = []
for i, G in enumerate(tqdm(GRAPHS, desc="Preparing graphs")):
    # Best effort: a graph that fails preparation is kept unmodified so that
    # graphs, labels and sample IDs stay aligned.
    try:
        G_prepared = prepare_graph_for_pyg(G)
    except Exception as e:
        print(f" Error preparing graph {i}: {e}")
        G_prepared = G
    GRAPHS_PREPARED.append(G_prepared)

print(f" Prepared {len(GRAPHS_PREPARED):,} graphs")

# Show the first prepared graph as a quick sanity check.
sample_graph = GRAPHS_PREPARED[0]
first_node = list(sample_graph.nodes(data=True))[0]
print(f"\n Sample prepared graph:")
print(f"Nodes: {sample_graph.number_of_nodes()}")
print(f"Edges: {sample_graph.number_of_edges()}")
print(f"Node attributes: {first_node}")
if sample_graph.number_of_edges() > 0:
    first_edge = list(sample_graph.edges(data=True))[0]
    print(f"Edge attributes: {first_edge}")


### PyTorch Geometric Dataset Class


In [None]:
class MicrobiomeGraphDataset(Dataset):
    """
    PyTorch Geometric dataset for disease graph classification.
    
    Converts NetworkX graphs to PyTorch Geometric Data objects with:
    - Node features (normalized weights)
    - Edge features (normalized weights)
    - Graph-level labels (Disease vs Control)
    - Optional per-sample clinical feature vector (``data.clinical``)
    """
    
    def __init__(self, graphs, labels, clinical_df=None, sample_ids=None, normalize=True):
        """
        Initialize dataset.
        
        Parameters:
        -----------
        graphs : list of networkx.Graph
            List of prepared NetworkX graphs
        labels : np.ndarray or list
            Labels, one per graph (1=Disease, 0=Control in the binary case)
        clinical_df : pd.DataFrame, optional
            Clinical features, one row per sample. Rows are looked up by
            sample ID either through the DataFrame index or through a
            '#SampleID' column.
        sample_ids : list, optional
            List of sample IDs corresponding to each graph
        normalize : bool
            Whether to normalize node and edge features
        """
        super().__init__()
        self.graphs = graphs
        self.labels = labels if isinstance(labels, np.ndarray) else np.array(labels)
        self.clinical_df = clinical_df
        self.sample_ids = sample_ids
        self.normalize = normalize
    
    def get_clinical_for_sample(self, sample_id):
        """
        Get clinical features for a specific sample.
        
        Parameters:
        -----------
        sample_id : hashable
            ID of the sample to look up
            
        Returns:
        --------
        np.ndarray or None
            float32 feature vector of the first matching row, or None when no
            clinical data is attached or the sample is not found
        """
        if self.clinical_df is None:
            return None
        
        try:
            if sample_id in self.clinical_df.index:
                sample_row = self.clinical_df.loc[[sample_id]]
            elif '#SampleID' in self.clinical_df.columns:
                matches = self.clinical_df['#SampleID'] == sample_id
                # Drop the ID column itself: it is a lookup key, not a feature,
                # and a string ID would make the float conversion below fail.
                # NOTE(review): a clinical_dim derived from clinical_df.shape[1]
                # would still count this column - confirm against the caller.
                sample_row = self.clinical_df[matches].drop(columns='#SampleID')
            else:
                return None
        except (KeyError, TypeError):
            return None
        
        if len(sample_row) == 0:
            return None
        
        # Only the first matching row is used if the ID occurs more than once.
        features = sample_row.values[0]
        
        return features.astype(np.float32)
        
    def len(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)
    
    def get(self, idx):
        """
        Get a single graph as PyTorch Geometric Data object.
        
        Parameters:
        -----------
        idx : int
            Index of the graph
            
        Returns:
        --------
        Data
            PyTorch Geometric Data object with ``x`` (one weight per node),
            ``edge_index``, ``edge_attr`` (one weight per directed edge),
            ``y`` (float label of shape [1]) and, when available, ``clinical``
        """
        graph = self.graphs[idx]
        label = self.labels[idx]
        
        # Node features: a single scalar weight per node. Node order defines
        # the integer indices used in edge_index.
        node_features = []
        node_to_idx = {}
        for i, (node, node_attrs) in enumerate(graph.nodes(data=True)):
            node_to_idx[node] = i
            node_features.append([node_attrs.get('weight', 1.0)])
        
        node_features = torch.tensor(node_features, dtype=torch.float)
        
        if self.normalize and node_features.numel() > 0:
            # Per-graph scaling by the largest absolute weight.
            max_val = node_features.abs().max()
            if max_val > 0:
                node_features = node_features / max_val
        
        # Every edge is emitted in both directions with the same weight.
        edge_list = []
        edge_features = []
        
        for u, v, edge_attrs in graph.edges(data=True):
            u_idx = node_to_idx[u]
            v_idx = node_to_idx[v]
            
            edge_list.append([u_idx, v_idx])
            edge_list.append([v_idx, u_idx])
            
            weight = edge_attrs.get('weight', 1.0)
            edge_features.append([weight])
            edge_features.append([weight])
        
        if len(edge_list) > 0:
            edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(edge_features, dtype=torch.float)
            
            if self.normalize and edge_attr.numel() > 0:
                max_edge = edge_attr.abs().max()
                if max_edge > 0:
                    edge_attr = edge_attr / max_edge
        else:
            # Edgeless graph: keep well-formed empty tensors so batching works.
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0, 1), dtype=torch.float)
        
        data = Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=torch.tensor([label], dtype=torch.float)
        )
        
        # Attach clinical features only when both the table and IDs are given;
        # samples without a matching row get no `clinical` attribute.
        if self.clinical_df is not None and self.sample_ids is not None:
            sample_id = self.sample_ids[idx]
            clinical_features = self.get_clinical_for_sample(sample_id)
            if clinical_features is not None:
                clinical_tensor = torch.tensor(clinical_features, dtype=torch.float)
                if clinical_tensor.dim() == 0:
                    clinical_tensor = clinical_tensor.unsqueeze(0)
                data.clinical = clinical_tensor.flatten()
        
        return data

# Summary of what the dataset class provides.
print("MicrobiomeGraphDataset class defined")
print("Features:")
print("- Node features: Normalized node weights")
print("- Edge features: Normalized edge weights")
print("- Labels: Binary classification (" + f"{DISEASE}" + " vs Control)")


## Dataset Splitting and DataLoader Creation


In [None]:
print("Creating train/validation/test splits...")

test_size, val_size, random_seed, stratify = (
    DATASET_PARAMS[name]
    for name in ("test_size", "val_size", "random_seed", "stratify")
)

# First cut: hold out the test set from the full data.
stratify_labels = LABELS if stratify else None
(train_val_graphs, test_graphs,
 train_val_labels, test_labels,
 train_val_ids, test_ids) = train_test_split(
    GRAPHS_PREPARED, LABELS, sample_ids,
    test_size=test_size,
    random_state=random_seed,
    stratify=stratify_labels,
)

# Second cut: split the remainder into train and validation. val_size is a
# fraction of this remainder, not of the full dataset.
stratify_train_val = train_val_labels if stratify else None
(train_graphs, val_graphs,
 train_labels, val_labels,
 train_ids, val_ids) = train_test_split(
    train_val_graphs, train_val_labels, train_val_ids,
    test_size=val_size,
    random_state=random_seed,
    stratify=stratify_train_val,
)

print(f" Dataset split complete")
print(f"\n Split statistics:")
n_all = len(GRAPHS_PREPARED)
split_report = (
    ("Training set", train_graphs, train_labels),
    ("\n   Validation set", val_graphs, val_labels),
    ("\n   Test set", test_graphs, test_labels),
)
for heading, split_graphs, split_labels in split_report:
    n_split = len(split_graphs)
    n_split_labels = len(split_labels)
    n_cases = sum(split_labels)
    n_controls = n_split_labels - n_cases
    print(f"{heading}: {n_split:,} samples ({n_split/n_all*100:.1f}%)")
    print(f"   - {DISEASE} cases: {n_cases:,} ({n_cases/n_split_labels*100:.1f}%)")
    print(f"   - Controls: {n_controls:,} ({n_controls/n_split_labels*100:.1f}%)")

print("\n Creating PyTorch Geometric datasets...")
# Clinical data and sample IDs are passed only when clinical features are on.
shared_dataset_kwargs = dict(
    clinical_df=clinical_df if use_clinical else None,
    normalize=DATASET_PARAMS["normalize_features"],
)
train_dataset = MicrobiomeGraphDataset(
    train_graphs, train_labels,
    sample_ids=train_ids if use_clinical else None,
    **shared_dataset_kwargs,
)
val_dataset = MicrobiomeGraphDataset(
    val_graphs, val_labels,
    sample_ids=val_ids if use_clinical else None,
    **shared_dataset_kwargs,
)
test_dataset = MicrobiomeGraphDataset(
    test_graphs, test_labels,
    sample_ids=test_ids if use_clinical else None,
    **shared_dataset_kwargs,
)

print(f" Created datasets")
for split_name, split_dataset in (
    ("Training", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{split_name} dataset: {len(split_dataset)} samples")

# Smoke test: materialise one sample and show its shapes.
print("\n Testing dataset loading...")
sample_data = train_dataset[0]
print(f"Sample Data object:")
print(f"- Node features shape: {sample_data.x.shape}")
print(f"- Edge index shape: {sample_data.edge_index.shape}")
print(f"- Edge features shape: {sample_data.edge_attr.shape}")
print(f"- Label: {sample_data.y.item()}")
print(f"- Number of nodes: {sample_data.num_nodes}")
print(f"- Number of edges: {sample_data.num_edges}")


### Create DataLoaders


In [None]:
print("Creating DataLoaders...")

batch_size = MODEL_PARAMS["batch_size"]

# Seeded generator so the training shuffle order is reproducible.
g = torch.Generator()
g.manual_seed(random_seed)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, generator=g
)
val_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    for eval_dataset in (val_dataset, test_dataset)
)

print(f" DataLoaders created")
print(f"Batch size: {batch_size}")
for loader_name, loader_obj in (
    ("Training", train_loader),
    ("Validation", val_loader),
    ("Test", test_loader),
):
    print(f"{loader_name} batches: {len(loader_obj)}")

# Smoke test: pull the first training batch and show its shapes.
print("\n Testing batch loading...")
for batch in train_loader:
    print(f"Sample batch:")
    print(f"- Batch size: {batch.num_graphs}")
    print(f"- Node features shape: {batch.x.shape}")
    print(f"- Edge index shape: {batch.edge_index.shape}")
    print(f"- Labels shape: {batch.y.shape}")
    print(f"- Batch vector shape: {batch.batch.shape}")
    break


## GNN Model Definition

### Graph Neural Network Architecture


In [None]:
print("Creating GNN model...")

device = torch.device(MODEL_PARAMS["device"])
print(f"Using device: {device}")

model_type = MODEL_PARAMS.get('model_type', 'GINEConv')

# Architecture options come from the experiment config when one is loaded.
if EXPERIMENT_CONFIG:
    architecture_cfg = EXPERIMENT_CONFIG.get('model_training', {}).get('architecture', {})
    num_classes = architecture_cfg.get('num_classes', 2)
    use_clinical = architecture_cfg.get('use_clinical_features', False)
else:
    num_classes = 2
    use_clinical = False
clinical_dim = clinical_df.shape[1] if use_clinical and 'clinical_df' in dir() and clinical_df is not None else 0

from src.models import (
    GNN_GCN, GNN_GINEConv, GNN_GAT, GNN_GraphSAGE, 
    EdgeCentricRGCN, MLP_Baseline, CNN_Baseline, get_model,
    GNN_Clinical
)


def _build_standard_model(model_cls):
    """Instantiate a model that takes the common constructor arguments."""
    return model_cls(
        input_dim=1,
        hidden_dim=MODEL_PARAMS["hidden_dim"],
        num_layers=MODEL_PARAMS["num_layers"],
        dropout=MODEL_PARAMS["dropout"],
        pooling=MODEL_PARAMS["pooling"],
        num_classes=num_classes
    ).to(device)


# Models sharing the common constructor signature, with their status message.
_STANDARD_MODELS = {
    'MLP': (MLP_Baseline, " CONTROL EXPERIMENT: MLP Baseline (NO graph structure)"),
    'GCN': (GNN_GCN, " GCN model created (simple GCN, NO edge features)"),
    'GINEConv': (GNN_GINEConv, " GINEConv model created (uses edge features)"),
    'GAT': (GNN_GAT, " GAT model created (attention mechanism)"),
    'GraphSAGE': (GNN_GraphSAGE, " GraphSAGE model created"),
}

if model_type in _STANDARD_MODELS:
    _model_cls, _created_message = _STANDARD_MODELS[model_type]
    model = _build_standard_model(_model_cls)
    print(_created_message)

elif model_type == 'EdgeCentricRGCN':
    model = EdgeCentricRGCN(
        hidden_dim=MODEL_PARAMS["hidden_dim"],
        num_classes=num_classes,
        dropout=MODEL_PARAMS["dropout"],
        pooling=MODEL_PARAMS["pooling"]
    ).to(device)
    print(f" EdgeCentricRGCN model created (Professor's reference)")

elif model_type == 'CNN':
    # The CNN baseline is sized by the largest graph in the training set.
    max_nodes = max(data.num_nodes for data in train_dataset)
    print(f"CONTROL EXPERIMENT: CNN Baseline (ignores graph structure)")
    print(f"   Max nodes in graphs: {max_nodes}")
    model = CNN_Baseline(
        input_dim=max_nodes,
        hidden_dim=MODEL_PARAMS["hidden_dim"],
        num_classes=num_classes,
        dropout=MODEL_PARAMS["dropout"]
    ).to(device)
    print(f" CNN_Baseline model created")

else:
    print(f" Unknown model type '{model_type}', falling back to GINEConv")
    model = _build_standard_model(GNN_GINEConv)

# Report the chosen configuration and the parameter counts.
n_total_params = sum(p.numel() for p in model.parameters())
n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n   Model type: {model_type}")
print(f"Hidden dimension: {MODEL_PARAMS['hidden_dim']}")
print(f"Number of layers: {MODEL_PARAMS['num_layers']}")
print(f"Dropout: {MODEL_PARAMS['dropout']}")
print(f"Pooling: {MODEL_PARAMS['pooling']}")
print(f"\n   Total parameters: {n_total_params:,}")
print(f"Trainable parameters: {n_trainable_params:,}")


## Model Training

### Training Loop with Early Stopping


In [None]:
def train_epoch(model, loader, optimizer, criterion, device, num_classes=2):
    """
    Train for one epoch (supports binary and multi-class).

    Parameters:
    -----------
    model : torch.nn.Module
        Model called as ``model(batch)``. In the binary case it must return
        probabilities (predictions are thresholded at 0.5), shaped either
        ``[B]`` or ``[B, 1]``; otherwise it must return ``[B, num_classes]``
        scores.
    loader : iterable
        Yields batches exposing ``.to(device)``, ``.y`` and ``.num_graphs``
    optimizer : torch.optim.Optimizer
        Optimizer stepped once per batch
    criterion : callable
        Loss taking ``(output, target)``
    device : torch.device
        Device the batches are moved to
    num_classes : int
        2 selects the binary path, anything else the multi-class path

    Returns:
    --------
    tuple of (float, float)
        Graph-weighted mean loss and accuracy; ``(0.0, 0.0)`` for an empty
        loader
    """
    model.train()
    is_binary = num_classes == 2
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        out = model(batch)

        if is_binary:
            # Flatten both sides so a [B, 1] model output lines up with [B]
            # labels; otherwise the loss rejects the shapes and `pred == y`
            # would broadcast to a [B, B] comparison.
            out = out.reshape(-1)
            target = batch.y.reshape(-1)
            pred = (out > 0.5).float()
        else:
            target = batch.y.long()
            pred = out.argmax(dim=1)

        loss = criterion(out, target)

        loss.backward()
        # Clip the gradient norm to keep updates bounded.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight the batch loss by its size so the epoch mean is per graph.
        total_loss += loss.item() * batch.num_graphs
        correct += (pred == target).sum().item()
        total += batch.num_graphs

    avg_loss = total_loss / total if total > 0 else 0.0
    accuracy = correct / total if total > 0 else 0.0

    return avg_loss, accuracy

@torch.no_grad()
def evaluate(model, loader, criterion, device, num_classes=2):
    """
    Evaluate model (supports binary and multi-class).

    Parameters:
    -----------
    model : torch.nn.Module
        Model called as ``model(batch)``. In the binary case it must return
        probabilities (predictions are thresholded at 0.5), shaped either
        ``[B]`` or ``[B, 1]``; otherwise it must return ``[B, num_classes]``
        scores.
    loader : iterable
        Yields batches exposing ``.to(device)``, ``.y`` and ``.num_graphs``
    criterion : callable
        Loss taking ``(output, target)``
    device : torch.device
        Device the batches are moved to
    num_classes : int
        2 selects the binary path, anything else the multi-class path

    Returns:
    --------
    tuple
        ``(avg_loss, accuracy, preds, labels, probs)``. Loss and accuracy are
        graph-weighted means (0.0 for an empty loader); the last three are
        numpy arrays. ``probs`` holds the raw model output in the binary case
        and softmax probabilities otherwise.
    """
    model.eval()
    is_binary = num_classes == 2
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    all_probs = []

    for batch in loader:
        batch = batch.to(device)

        out = model(batch)

        if is_binary:
            # Flatten both sides so a [B, 1] model output lines up with [B]
            # labels; otherwise the loss rejects the shapes and `pred == y`
            # would broadcast to a [B, B] comparison.
            out = out.reshape(-1)
            target = batch.y.reshape(-1)
            pred = (out > 0.5).float()
            probs = out
        else:
            target = batch.y.long()
            pred = out.argmax(dim=1)
            probs = F.softmax(out, dim=1)

        loss = criterion(out, target)

        # Weight the batch loss by its size so the overall mean is per graph.
        total_loss += loss.item() * batch.num_graphs
        correct += (pred == target).sum().item()
        total += batch.num_graphs

        all_probs.extend(probs.cpu().numpy())
        all_preds.extend(pred.cpu().numpy())
        all_labels.extend(batch.y.cpu().numpy())

    avg_loss = total_loss / total if total > 0 else 0.0
    accuracy = correct / total if total > 0 else 0.0

    return avg_loss, accuracy, np.array(all_preds), np.array(all_labels), np.array(all_probs)

# Confirm that the two training helpers are now defined in this session.
print(
    "Training functions defined",
    "- train_epoch: Train for one epoch",
    "- evaluate: Evaluate model performance",
    sep="\n",
)


In [None]:
# Training cell: build the optimizer and loss, then train with early stopping on the
# validation loss and restore the weights of the best epoch afterwards.
import copy

print("Setting up training...")

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=MODEL_PARAMS["learning_rate"],
    weight_decay=MODEL_PARAMS["weight_decay"]
)

# Number of target classes: taken from the experiment config when one is loaded, else 2.
num_classes = EXPERIMENT_CONFIG.get('model_training', {}).get('architecture', {}).get('num_classes', 2) if EXPERIMENT_CONFIG else 2

if num_classes == 2:
    # NOTE(review): nn.BCELoss expects probabilities in [0, 1]; presumably the model's
    # binary head ends in a sigmoid. Confirm against the model definition.
    criterion = nn.BCELoss()
    print(f"Loss function: Binary Cross-Entropy (2 classes)")
else:
    criterion = nn.CrossEntropyLoss()
    print(f"Loss function: Cross-Entropy ({num_classes} classes)")

# Per-epoch metrics, consumed later by the plotting and saving cells.
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

best_val_loss = float('inf')
best_val_acc = 0.0
best_epoch = None          # 1-based index of the epoch that produced the best validation loss
best_model_state = None    # snapshot of the weights at that epoch
patience_counter = 0
patience = MODEL_PARAMS["patience"]

# The loss function was already reported above for the branch actually selected, so it
# is not repeated here (the former unconditional "Binary Cross-Entropy" line was wrong
# for multi-class runs).
print(f"Optimizer: Adam")
print(f"Learning rate: {MODEL_PARAMS['learning_rate']}")
print(f"Weight decay: {MODEL_PARAMS['weight_decay']}")
print(f"Early stopping patience: {patience} epochs")

print(f"\n Starting training for {MODEL_PARAMS['num_epochs']} epochs...")
print("=" * 80)

num_epochs = MODEL_PARAMS["num_epochs"]

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device, num_classes)

    val_loss, val_acc, val_preds, val_labels, val_probs = evaluate(model, val_loader, criterion, device, num_classes)

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Log the first epoch and every tenth epoch.
    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_acc = val_acc
        best_epoch = epoch + 1
        patience_counter = 0
        # state_dict() returns references to the live parameter tensors, which keep
        # changing as training continues. Deep-copy them so this really is a frozen
        # snapshot of the best epoch.
        best_model_state = copy.deepcopy(model.state_dict())
        torch.save({
            'epoch': epoch,
            'model_state_dict': best_model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc,
        }, models_output_dir / 'best_model.pt')
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\n  Early stopping triggered at epoch {epoch+1}")
        print(f"Best validation loss: {best_val_loss:.4f}")
        print(f"Best validation accuracy: {best_val_acc:.4f}")
        break

print("\n" + "=" * 80)
print(f" Training complete!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Best validation accuracy: {best_val_acc:.4f}")
# Count recorded epochs rather than reading the loop variable, which does not exist
# when num_epochs is 0.
print(f"Total epochs: {len(history['train_loss'])}")

if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f" Loaded best model from epoch {best_epoch}")
else:
    # No epoch improved on the initial loss (e.g. zero epochs or a NaN validation loss).
    print(" No improving epoch was recorded; keeping the final model weights")


## Model Evaluation and Visualization

### Test Set Evaluation


In [None]:
# Final evaluation on the held-out test set: loss/accuracy from evaluate(), then
# scikit-learn metrics, a classification report and a confusion matrix, all written
# to evaluation_results.json.
test_loss, test_acc, test_preds, test_labels, test_probs = evaluate(
    model, test_loader, criterion, device, num_classes
)

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

# Import additional metrics
from sklearn.metrics import (
    accuracy_score, 
    roc_auc_score, 
    classification_report, 
    confusion_matrix,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score
)

# Calculate AUC-ROC. This is best-effort: the score is undefined for some inputs
# (for example a test set containing a single class), and the rest of the evaluation
# should still run. Catch Exception rather than using a bare except so that
# KeyboardInterrupt/SystemExit still propagate, and report the actual reason.
try:
    if num_classes == 2:
        test_auc = roc_auc_score(test_labels, test_probs)
    else:
        test_auc = roc_auc_score(test_labels, test_probs, multi_class='ovr')
except Exception as exc:
    test_auc = None
    print(f"AUC-ROC: N/A ({exc})")
else:
    # NOTE(review): some scikit-learn releases appear to return NaN (with a warning)
    # instead of raising when the score is undefined. Treat that as unavailable too,
    # so NaN is never written into the JSON results.
    if np.isnan(test_auc):
        test_auc = None
        print("AUC-ROC: N/A (score is undefined for this test set)")
    else:
        print(f"Test AUC-ROC: {test_auc:.4f}")

# Calculate additional metrics (weighted averages; zero_division=0 is used for all
# three so classes without predictions score 0 consistently).
test_balanced_acc = balanced_accuracy_score(test_labels, test_preds)
test_f1 = f1_score(test_labels, test_preds, average='weighted', zero_division=0)
test_precision = precision_score(test_labels, test_preds, average='weighted', zero_division=0)
test_recall = recall_score(test_labels, test_preds, average='weighted', zero_division=0)

print(f"Test Balanced Accuracy: {test_balanced_acc:.4f}")
print(f"Test F1 Score: {test_f1:.4f}")
print(f"Test Precision: {test_precision:.4f}")
print(f"Test Recall: {test_recall:.4f}")

# Classification report: generic class names, or Control/<DISEASE> for the binary task.
target_names = [f'Class {i}' for i in range(num_classes)]
if num_classes == 2:
    target_names = ['Control', DISEASE]

# Dict form is stored in the JSON results; the text form below is for the notebook output.
report = classification_report(
    test_labels, 
    test_preds,
    target_names=target_names,
    output_dict=True,
    zero_division=0
)

print("\nClassification Report:")
print(classification_report(
    test_labels, 
    test_preds, 
    target_names=target_names,
    zero_division=0
))

# Confusion matrix
conf_matrix = confusion_matrix(test_labels, test_preds)
print("\nConfusion Matrix:")
print(conf_matrix)

# Save comprehensive evaluation results (values converted to plain Python types for JSON).
evaluation_results = {
    'test_loss': float(test_loss),
    'test_accuracy': float(test_acc),
    'test_balanced_accuracy': float(test_balanced_acc),
    'test_f1_score': float(test_f1),
    'test_precision': float(test_precision),
    'test_recall': float(test_recall),
    'test_auc_roc': float(test_auc) if test_auc is not None else None,
    'classification_report': report,
    'confusion_matrix': conf_matrix.tolist(),
    'num_test_samples': len(test_labels),
    'timestamp': datetime.now().isoformat()
}

results_path = results_output_dir / 'evaluation_results.json'
with open(results_path, 'w') as f:
    json.dump(evaluation_results, f, indent=2)

print(f"\nEvaluation results saved to: {results_path}")

## Optional: Cross-Validation Evaluation
### k-fold runs


In [None]:
# Optional k-fold cross-validation. It is opt-in through the experiment config
# (model_training.evaluation.enable_cross_validation) and disabled without a config.
enable_cv = False
cv_results = None

if EXPERIMENT_CONFIG is not None:
    eval_config = EXPERIMENT_CONFIG.get('model_training', {}).get('evaluation', {})
    enable_cv = eval_config.get('enable_cross_validation', False)
    n_folds = eval_config.get('n_folds', 5)
    num_runs_per_fold = eval_config.get('num_runs_per_experiment', 1)

if not enable_cv:
    print("\n Cross-validation disabled (set enable_cross_validation: true in config to enable)")
else:
    print(f"\n{'='*60}")
    print("Running Cross-Validation")
    print(f"{'='*60}")
    print(f"Folds: {n_folds}")
    print(f"Runs per fold: {num_runs_per_fold}")

    # Make the parent directory importable so the `src` package can be found.
    import sys
    sys.path.insert(0, str(Path.cwd().parent))
    try:
        from src.cross_validation import run_kfold_experiment, save_cv_results

        all_graphs = GRAPHS_PREPARED
        all_labels = LABELS

        # Pick the model class; the clinical variant takes two extra constructor arguments.
        with_clinical = use_clinical and clinical_dim > 0
        model_class = GNN_Clinical if with_clinical else GNN_GINEConv

        model_kwargs = {
            'input_dim': 1,
            'hidden_dim': MODEL_PARAMS['hidden_dim'],
            'num_layers': MODEL_PARAMS['num_layers'],
            'dropout': MODEL_PARAMS['dropout'],
            'pooling': MODEL_PARAMS['pooling'],
        }
        if with_clinical:
            model_kwargs['clinical_dim'] = clinical_dim
            model_kwargs['clinical_hidden_dim'] = MODEL_PARAMS.get('clinical_hidden_dim', 32)
        model_kwargs['num_classes'] = num_classes

        training_kwargs = dict(
            batch_size=MODEL_PARAMS['batch_size'],
            learning_rate=MODEL_PARAMS['learning_rate'],
            weight_decay=MODEL_PARAMS['weight_decay'],
            num_epochs=MODEL_PARAMS['num_epochs'],
            early_stopping_patience=MODEL_PARAMS.get('patience', 30),
        )

        # Clinical inputs are only forwarded when clinical features are enabled.
        cv_clinical_df = clinical_df if use_clinical else None
        cv_sample_ids = sample_ids if use_clinical else None

        cv_results = run_kfold_experiment(
            graphs=all_graphs,
            labels=all_labels,
            model_class=model_class,
            model_params=model_kwargs,
            training_params=training_kwargs,
            n_folds=n_folds,
            num_runs_per_fold=num_runs_per_fold,
            random_seed=random_seed,
            device=device,
            num_classes=num_classes,
            output_dir=results_output_dir / 'cross_validation',
            verbose=True,
            dataset_class=MicrobiomeGraphDataset,
            clinical_df=cv_clinical_df,
            sample_ids=cv_sample_ids
        )

        cv_path = results_output_dir / 'cv_results.json'
        save_cv_results(cv_results, cv_path)
        print(f"\n Cross-validation results saved to: {cv_path}")

        # Attach the aggregated CV metrics to the test-set results and rewrite that file.
        evaluation_results['cross_validation'] = {
            'enabled': True,
            'n_folds': n_folds,
            'num_runs_per_fold': num_runs_per_fold,
            'aggregated': cv_results.get('aggregated', {})
        }

        with open(results_path, 'w') as f:
            json.dump(evaluation_results, f, indent=2)
        print(f" Updated evaluation results with cross-validation metrics")

    except ImportError as e:
        # A missing helper module is not fatal: the notebook continues without CV.
        print(f"  Cross-validation module not available: {e}")
        print(f"Skipping cross-validation...")
    except Exception as e:
        # Any other failure is reported with a traceback but does not stop the notebook.
        print(f" Cross-validation failed: {e}")
        import traceback
        traceback.print_exc()


### Training Curves Visualization


In [None]:
print("Generating training visualization...")


def _draw_history_panel(ax, train_key, val_key, train_label, val_label, y_label, title):
    """Plot one train/validation pair from ``history`` onto ``ax``, one point per epoch."""
    ax.plot(history[train_key], label=train_label, linewidth=2)
    ax.plot(history[val_key], label=val_label, linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)


# Left panel: loss curves. Right panel: accuracy curves.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=VIZ_PARAMS["figsize"])

_draw_history_panel(
    ax1, 'train_loss', 'val_loss',
    'Train Loss', 'Val Loss',
    'Loss', 'Training and Validation Loss',
)
_draw_history_panel(
    ax2, 'train_acc', 'val_acc',
    'Train Accuracy', 'Val Accuracy',
    'Accuracy', 'Training and Validation Accuracy',
)

plt.tight_layout()

# Saving and displaying are controlled independently through VIZ_PARAMS.
if VIZ_PARAMS["save_plots"]:
    plot_path = visualizations_dir / "training_curves.png"
    plt.savefig(plot_path, dpi=VIZ_PARAMS["dpi"], bbox_inches='tight')
    print(f" Saved training curves to: {plot_path}")

if not VIZ_PARAMS["show_plots"]:
    # Close the figure when it is not displayed.
    plt.close()
else:
    plt.show()

print(f" Training curves generated")


### Confusion Matrix and ROC Curve


In [None]:
# Evaluation figure: confusion-matrix heatmap on the left, ROC information on the right.
print("Generating evaluation visualizations...")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=VIZ_PARAMS["figsize"])

# The confusion matrix has one row/column per class that occurs in the test labels or
# predictions, so the tick labels must match that count. The two named labels are used
# only when the matrix really is 2x2 for the binary task; otherwise the labels are
# derived from the classes actually present (multi-class runs, or a degenerate
# single-class test set).
present_classes = np.unique(np.concatenate([np.ravel(test_labels), np.ravel(test_preds)]))
if num_classes == 2 and len(present_classes) == 2:
    cm_labels = ['Control', disease_label] if 'disease_label' in dir() else ['Control', 'Case']
else:
    cm_labels = [f'Class {int(c)}' for c in present_classes]

sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', 
            xticklabels=cm_labels, 
            yticklabels=cm_labels,
            ax=ax1, cbar_kws={'label': 'Count'})
ax1.set_xlabel('Predicted', fontsize=12)
ax1.set_ylabel('Actual', fontsize=12)
ax1.set_title('Confusion Matrix', fontsize=14, fontweight='bold')

if test_auc is not None and num_classes == 2:
    # roc_curve only supports binary targets, so the curve is drawn for the binary task only.
    from sklearn.metrics import roc_curve
    fpr, tpr, thresholds = roc_curve(test_labels, test_probs)
    
    ax2.plot(fpr, tpr, linewidth=2, label=f'ROC Curve (AUC = {test_auc:.4f})')
    ax2.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random Classifier')
    ax2.set_xlabel('False Positive Rate', fontsize=12)
    ax2.set_ylabel('True Positive Rate', fontsize=12)
    ax2.set_title('ROC Curve', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)
elif test_auc is not None:
    # Multi-class: a one-vs-rest AUC exists, but there is no single ROC curve to draw.
    ax2.text(0.5, 0.5,
             f'Multi-class AUC (one-vs-rest) = {test_auc:.4f}\nPer-class ROC curves are not plotted',
             ha='center', va='center', fontsize=14)
    ax2.set_title('ROC Curve', fontsize=14, fontweight='bold')
else:
    ax2.text(0.5, 0.5, 'ROC Curve\nCould not be computed', 
             ha='center', va='center', fontsize=14)
    ax2.set_title('ROC Curve', fontsize=14, fontweight='bold')

plt.tight_layout()

# Saving and displaying are controlled independently through VIZ_PARAMS.
if VIZ_PARAMS["save_plots"]:
    plot_path = visualizations_dir / "confusion_matrix_roc.png"
    plt.savefig(plot_path, dpi=VIZ_PARAMS["dpi"], bbox_inches='tight')
    print(f" Saved confusion matrix and ROC curve to: {plot_path}")

if VIZ_PARAMS["show_plots"]:
    plt.show()
else:
    plt.close()

print(f" Evaluation visualizations generated")


In [None]:
print("Saving processed datasets...")


def _label_mean(labels):
    """Return ``sum(labels) / len(labels)`` as a plain float."""
    return float(sum(labels) / len(labels))


# One bundle holding the datasets, the source graphs and the labels of every split.
# NOTE(review): `test_labels` was rebound by the test-evaluation cell to the array
# returned by evaluate(); confirm that this is the object meant to be persisted here.
combined_dataset_path = pytorch_output_dir / f'{DISEASE.lower()}_dataset.pt'
dataset_bundle = {
    'train_dataset': train_dataset,
    'val_dataset': val_dataset,
    'test_dataset': test_dataset,
    'train_graphs': train_graphs,
    'val_graphs': val_graphs,
    'test_graphs': test_graphs,
    'train_labels': train_labels,
    'val_labels': val_labels,
    'test_labels': test_labels
}
torch.save(dataset_bundle, combined_dataset_path)

print(f" Complete dataset saved to: {combined_dataset_path}")

# Each split is additionally written to its own file.
torch.save(train_dataset, pytorch_output_dir / 'train_dataset.pt')
torch.save(val_dataset, pytorch_output_dir / 'val_dataset.pt')
torch.save(test_dataset, pytorch_output_dir / 'test_dataset.pt')

print(f" Individual splits saved")

# Summary statistics: split sizes, per-split label means and average graph size.
node_counts = [G.number_of_nodes() for G in GRAPHS_PREPARED]
edge_counts = [G.number_of_edges() for G in GRAPHS_PREPARED]

dataset_stats = {
    'total_samples': len(GRAPHS_PREPARED),
    'train_samples': len(train_dataset),
    'val_samples': len(val_dataset),
    'test_samples': len(test_dataset),
    'train_case_ratio': _label_mean(train_labels),
    'val_case_ratio': _label_mean(val_labels),
    'test_case_ratio': _label_mean(test_labels),
    'avg_nodes': float(np.mean(node_counts)),
    'avg_edges': float(np.mean(edge_counts)),
    'timestamp': datetime.now().isoformat()
}

stats_path = pytorch_output_dir / 'dataset_statistics.json'
with open(stats_path, 'w') as f:
    json.dump(dataset_stats, f, indent=2)

print(f" Dataset statistics saved to: {stats_path}")

# Per-epoch training/validation curves recorded by the training loop.
history_path = results_output_dir / 'training_history.json'
with open(history_path, 'w') as f:
    json.dump(history, f, indent=2)

print(f" Training history saved to: {history_path}")

print(f"\n All outputs saved successfully!")


### Pipeline Summary


In [None]:
# Final human-readable recap of the run: dataset sizes, model, metrics and output layout.
# The banner needs the f-prefix, otherwise the literal text "{DISEASE}" is printed.
print(f"{DISEASE} PyTorch Geometric Pipeline Complete!")
print("=" * 80)

print(f"\n Dataset Summary:")
print(f"Disease: {DISEASE}")
print(f"Total graphs: {len(GRAPHS_PREPARED):,}")
print(f"Training set: {len(train_dataset):,} ({len(train_dataset)/len(GRAPHS_PREPARED)*100:.1f}%)")
print(f"Validation set: {len(val_dataset):,} ({len(val_dataset)/len(GRAPHS_PREPARED)*100:.1f}%)")
print(f"Test set: {len(test_dataset):,} ({len(test_dataset)/len(GRAPHS_PREPARED)*100:.1f}%)")

print(f"\n Graph Statistics:")
print(f"Average nodes per graph: {dataset_stats['avg_nodes']:.1f}")
print(f"Average edges per graph: {dataset_stats['avg_edges']:.1f}")
print(f"Node feature dimension: 1 (normalized weights)")
print(f"Edge feature dimension: 1 (normalized weights)")

print(f"\n Model Summary:")
# Report the class of the model that was actually built instead of a fixed
# "GINEConv" string, which is wrong whenever another model type was selected.
print(f"Architecture: {type(model).__name__}")
print(f"Hidden dimension: {MODEL_PARAMS['hidden_dim']}")
print(f"Number of layers: {MODEL_PARAMS['num_layers']}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

print(f"\n Training Results:")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Best validation accuracy: {best_val_acc:.4f}")
print(f"Total epochs: {len(history['train_loss'])}")

print(f"\n Test Set Performance:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
if test_auc is not None:
    print(f"Test AUC-ROC: {test_auc:.4f}")

print(f"\n Output Directory Structure:")
print(f"{output_dir}/")
print(f" pytorch_geometric/")
print(f"{DISEASE.lower()}_dataset.pt")
print(f"train_dataset.pt")
print(f"val_dataset.pt")
print(f"test_dataset.pt")
print(f"dataset_statistics.json")
print(f"dataset_config.json")
print(f" models/")
print(f"best_model.pt")
print(f" results/")
print(f"evaluation_results.json")
print(f"training_history.json")
print(f" visualizations/")
print(f" dataset_analysis/")
print(f"     training_curves.png")
print(f"     confusion_matrix_roc.png")

print(f"\n Next Steps:")
print(f"1. Analyze the trained model performance")
print(f"2. Use the saved model for inference on new samples")
print(f"3. Experiment with different model architectures")
print(f"4. Fine-tune hyperparameters for better performance")
print(f"5. Explore feature importance and graph structure analysis")

print(f"\n Model Loading:")
print(f"To load the best model:")
print(f"```python")
print(f"checkpoint = torch.load('{models_output_dir}/best_model.pt')")
print(f"model.load_state_dict(checkpoint['model_state_dict'])")
print(f"```")

print(f"\n Dataset Loading:")
print(f"To load the dataset:")
print(f"```python")
print(f"data = torch.load('{pytorch_output_dir}/{DISEASE.lower()}_dataset.pt')")
print(f"train_dataset = data['train_dataset']")
print(f"```")

print(f"\n Pipeline completed successfully!")
print(f" Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("=" * 80)
