<a href="https://colab.research.google.com/github/zbovaird/UHG-Models/blob/main/uhg_intrusion_detection5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install PyTorch Geometric and its compiled companion packages, choosing the
# wheel index that matches the torch build already present in this runtime.
# The `!pip` lines are IPython shell escapes, so this cell only runs inside a
# notebook kernel (e.g. Colab), not as a plain Python script.
!pip -q install --upgrade pip
import torch
# Strip any local build suffix: "2.x.y+cu121" -> "2.x.y".
pt = torch.__version__.split('+')[0]
# CUDA version torch was built against (falsy on builds without CUDA).
cuda = torch.version.cuda
if torch.cuda.is_available() and cuda:
  # e.g. CUDA "12.1" -> "cu121"
  idx = f"https://data.pyg.org/whl/torch-{pt}+cu{cuda.replace('.','')}.html"
else:
  idx = f"https://data.pyg.org/whl/torch-{pt}+cpu.html"
# NOTE(review): this builds the URL from the full patch version of torch;
# confirm data.pyg.org publishes an index page for that exact version,
# otherwise pip may fall back to (slow) source builds.

# `-f {idx}` adds the PyG wheel index as an extra find-links location.
!pip -q install torch_scatter torch_sparse torch_cluster torch_spline_conv -f {idx}
!pip -q install torch_geometric scikit-learn scipy pandas tqdm

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━[0m [32m1.1/1.8 MB[0m [31m32.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m36.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:

"""
Intrusion Detection using Universal Hyperbolic Geometry (UHG) - GPT5 v2
Based on original GPT5 training approach with comprehensive v4.5 metrics added.

Configuration (Matched to v4.5 for Direct Comparison):
- 10% data sampling (matching v4.5)
- k=2 neighbors (matching v4.5)
- No class weights (vs v4.5 with class weights)
- No PCA (vs v4.5 with PCA 77→20 dims)
- Standard cross-entropy loss (vs v4.5 with weighted loss)
- n_jobs=-1 for KNN (vs v4.5 with n_jobs=4)

This configuration isolates the impact of:
1. Class weighting
2. PCA dimensionality reduction

v2 Additions:
- Comprehensive timing instrumentation and bottleneck analysis
- GPU hardware detection and memory tracking
- Detailed per-class evaluation with classification report
- UHG constraint verification (initial and post-training)
- Enhanced progress reporting
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import kneighbors_graph
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from scipy.sparse import coo_matrix
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch_geometric.data import Data
from typing import Tuple
import os
import sys
import time
import json
import traceback
import platform
from datetime import datetime

# Optional: Drive mount (only in Colab).
# This stays best-effort (a failure never aborts the run), but the reason is
# reported instead of being silently discarded, so a later "file not found"
# on the Drive paths can be traced back to its cause.
try:
    from google.colab import drive
except ImportError:
    # Not running inside Colab: there is nothing to mount.
    print("google.colab not available - skipping Google Drive mount.")
else:
    print("Mounting Google Drive...")
    try:
        drive.mount('/content/drive')
    except Exception as mount_error:
        print(f"Google Drive mount failed: {mount_error}")

# Resolve the compute device once and print a short hardware banner.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

print(f"\n{'=' * 80}")
print("🖥️  HARDWARE CONFIGURATION")
print('=' * 80)

if device.type == 'cuda':
    # All queries target GPU index 0.
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3  # GB
    cuda_version = torch.version.cuda
    gpu_capability = torch.cuda.get_device_capability(0)

    # One print call, one line per entry.
    print("\n".join([
        "✅ GPU Detected:",
        f"   • Model: {gpu_name}",
        f"   • Memory: {gpu_memory:.1f} GB",
        f"   • CUDA Version: {cuda_version}",
        "   • Compute Capability: {}.{}".format(gpu_capability[0], gpu_capability[1]),
        "   • Device: cuda:0",
    ]))
else:
    print("⚠️  No GPU available - using CPU")
    print("   • This will be significantly slower for training")

print(f"{'=' * 80}\n")

# File paths (same as base)
# All three locations live under the Google Drive mount point used above.
FILE_PATH = '/content/drive/MyDrive/CIC_data.csv'  # input CSV read by load_and_preprocess_data
MODEL_SAVE_PATH = '/content/drive/MyDrive/uhg_ids_model_gpt5_v2.pth'  # best-checkpoint state_dict
RESULTS_PATH = '/content/drive/MyDrive/uhg_ids_results'  # directory for run outputs
# Create the results directory at import time; a no-op if it already exists.
os.makedirs(RESULTS_PATH, exist_ok=True)

# ========================
# Metrics/Env helpers
# ========================

def get_env_info() -> dict:
    """Collect a snapshot of the runtime environment for the metrics record.

    Returns:
        A dict with the Python version, platform string, torch version and
        CUDA/cuDNN flags. When CUDA is available it also carries the name,
        compute capability and CUDA version reported for GPU 0.
    """
    has_cuda = torch.cuda.is_available()

    env = {}
    env['python'] = sys.version.split()[0]
    env['platform'] = platform.platform()
    env['torch'] = torch.__version__
    env['cuda_available'] = has_cuda
    env['cudnn_enabled'] = torch.backends.cudnn.enabled

    if not has_cuda:
        return env

    # GPU details are only queryable when a CUDA device is present.
    env['gpu_name'] = torch.cuda.get_device_name(0)
    env['cuda_capability'] = torch.cuda.get_device_capability(0)
    env['cuda_version'] = torch.version.cuda
    return env

def save_json(obj: dict, path: str) -> None:
    """Write *obj* to *path* as indented JSON, reporting (not raising) failures.

    The object is serialized to a string before the file is opened. If *obj*
    contains something JSON cannot encode, the failure happens before the
    destination is truncated, so an existing file at *path* is left intact
    rather than replaced by a partial document.

    Args:
        obj: JSON-serializable mapping to persist.
        path: Destination file path.
    """
    try:
        payload = json.dumps(obj, indent=2)
        with open(path, 'w') as f:
            f.write(payload)
        print(f"Saved metrics to: {path}")
    except Exception as e:
        # Deliberately best-effort: saving metrics must never abort the run.
        print(f"Failed to save metrics JSON: {e}")

# ===============
# UHG primitives
# ===============

def minkowski_dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Minkowski inner product over the last axis.

    Every component except the last contributes with a plus sign and the
    last component with a minus sign:
    <x, y> = sum_{i < D} x_i * y_i  -  x_D * y_D

    The last axis is reduced; leading axes are kept.
    """
    space_term = torch.sum(x[..., :-1] * y[..., :-1], dim=-1)
    time_term = x[..., -1] * y[..., -1]
    return space_term - time_term

def projective_normalize(points: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Project points onto the hyperboloid <p, p> = -1.

    The incoming last coordinate is discarded and replaced by
    sqrt(1 + ||spatial||^2), where ``spatial`` is every coordinate but the
    last. The argument of the square root is clamped below by *eps*.
    """
    space = points[..., :-1]
    squared_norm = (space * space).sum(dim=-1, keepdim=True)
    time_coord = (1.0 + squared_norm).clamp(min=eps).sqrt()
    return torch.cat((space, time_coord), dim=-1)

def uhg_quadrance_vectorized(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Batched projective quadrance q = 1 - <x,y>^2 / (<x,x><y,y>).

    The denominator is clamped below by *eps* and the result is clamped to
    the interval [0, 1].

    NOTE(review): when both inputs satisfy <p,p> = -1 (the form produced by
    projective_normalize), the reversed Cauchy-Schwarz inequality gives
    <x,y>^2 >= 1, so the unclamped value is <= 0 and this returns 0 up to
    rounding. Confirm the [0, 1] clamp is what callers want.
    """
    norm_x = minkowski_dot(x, x)
    norm_y = minkowski_dot(y, y)
    cross = minkowski_dot(x, y)

    safe_denominator = (norm_x * norm_y).clamp(min=eps)
    quadrance = 1.0 - (cross * cross) / safe_denominator
    return quadrance.clamp(min=0.0, max=1.0)

def uhg_spread_vectorized(L: torch.Tensor, M: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Batched spread between lines: S = 1 - <L,M>^2 / (<L,L><M,M>).

    Uses the same Minkowski inner product as the quadrance. The denominator
    is clamped below by *eps* and the result is clamped to [0, 1].
    """
    self_l = minkowski_dot(L, L)
    self_m = minkowski_dot(M, M)
    cross = minkowski_dot(L, M)

    safe_denominator = (self_l * self_m).clamp(min=eps)
    spread = 1.0 - (cross * cross) / safe_denominator
    return spread.clamp(min=0.0, max=1.0)

def verify_uhg_constraints(x: torch.Tensor, eps: float = 1e-3, name: str = "features") -> bool:
    """Check that every row of *x* has Minkowski norm -1 and print a summary.

    For each row the norm is (sum of squares of all but the last column)
    minus (square of the last column). The violation is its absolute
    distance from -1.

    Args:
        x: 2-D tensor, one point per row, time-like coordinate last.
        eps: Largest tolerated violation.
        name: Label used in the printed report.

    Returns:
        False if the largest violation exceeds *eps*, otherwise True.
    """
    space_sq = x[:, :-1].pow(2).sum(dim=-1)
    time_sq = x[:, -1].pow(2)
    deviation = ((space_sq - time_sq) + 1.0).abs()

    max_violation = deviation.max().item()
    mean_violation = deviation.mean().item()

    print(f"UHG Constraint Check ({name}):")
    print(f"  Max violation: {max_violation:.6f}")
    print(f"  Mean violation: {mean_violation:.6f}")

    if max_violation > eps:
        print(f"  ⚠️ WARNING: Constraints violated!")
        return False

    print(f"  ✅ Constraints satisfied")
    return True

# ==============================
# Data loading / preprocessing
# ==============================

def _impute_non_finite(frame: pd.DataFrame) -> pd.DataFrame:
    """Return *frame* with NaN -> column mean, +inf -> column max, -inf -> column min.

    All statistics are computed over the finite entries only, so one
    infinite value cannot turn a column's mean into inf/NaN. Anything that
    is still missing afterwards (a column with no finite value) becomes 0.
    """
    finite_only = frame.replace([np.inf, -np.inf], np.nan)
    col_mean = finite_only.mean()
    col_max = finite_only.max()
    col_min = finite_only.min()

    # Fill real NaNs first, while the infinities are still distinguishable.
    filled = frame.fillna(col_mean)
    filled = filled.replace(np.inf, np.nan).fillna(col_max)
    filled = filled.replace(-np.inf, np.nan).fillna(col_min)
    return filled.fillna(0)


def load_and_preprocess_data(file_path: str = FILE_PATH, sample_frac: float = 0.10) -> Tuple[torch.Tensor, torch.Tensor, dict, dict]:
    """Load the CSV, sample it, clean and scale the features, and time each step.

    Steps: read the CSV; strip whitespace from column names and from the
    'Label' column; draw a random sample (random_state=42); coerce feature
    columns to numeric; impute missing and infinite values; standardize with
    StandardScaler; convert to tensors.

    Args:
        file_path: CSV file to read. It must contain a 'Label' column.
        sample_frac: Fraction of rows to keep.

    Returns:
        (node_features, labels_tensor, label_mapping, timings) where
        node_features is a float32 tensor of scaled features, labels_tensor
        is a long tensor of class indices, label_mapping maps each label
        string found in the full (unsampled) data to its index, and timings
        holds per-step wall-clock seconds plus their 'total'.
    """
    timings = {}

    print(f"\nLoading data from: {file_path}")
    t0 = time.perf_counter()
    data = pd.read_csv(file_path, low_memory=False)
    timings['csv_read'] = time.perf_counter() - t0
    print(f"  ⏱️  CSV read: {timings['csv_read']:.2f}s")

    t0 = time.perf_counter()
    data.columns = data.columns.str.strip()
    data['Label'] = data['Label'].str.strip()
    timings['column_cleanup'] = time.perf_counter() - t0

    # The mapping is built from the full dataset so class indices do not
    # depend on which rows the sample happens to contain.
    unique_labels = data['Label'].unique()
    print(f"\nUnique labels in the dataset: {unique_labels}")
    label_counts = data['Label'].value_counts()
    print("\nLabel distribution in the dataset:")
    print(label_counts)

    # Sample data (10% - matching v4.5 for comparison)
    print(f"\nApplying random sampling (frac={sample_frac})...")
    t0 = time.perf_counter()
    data_sampled = data.sample(frac=sample_frac, random_state=42)
    timings['sampling'] = time.perf_counter() - t0
    print(f"  ⏱️  Sampling: {timings['sampling']:.2f}s")

    print(f"\nSampled label distribution:")
    sampled_label_counts = data_sampled['Label'].value_counts()
    print(sampled_label_counts)

    labels = data_sampled['Label']

    # Convert feature columns to numeric. The label column is split off
    # first: coercing it would only produce a throwaway all-NaN column.
    t0 = time.perf_counter()
    features_numeric = data_sampled.drop(columns=['Label']).apply(pd.to_numeric, errors='coerce')
    timings['to_numeric'] = time.perf_counter() - t0
    print(f"  ⏱️  Convert to numeric: {timings['to_numeric']:.2f}s")

    # Fill NaN and inf using statistics of the finite values only.
    t0 = time.perf_counter()
    features = _impute_non_finite(features_numeric)
    timings['fillna'] = time.perf_counter() - t0
    print(f"  ⏱️  Fill NaN/inf: {timings['fillna']:.2f}s")

    t0 = time.perf_counter()
    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features)
    timings['scaling'] = time.perf_counter() - t0
    print(f"  ⏱️  Scaling: {timings['scaling']:.2f}s")

    t0 = time.perf_counter()
    node_features = torch.tensor(features_scaled, dtype=torch.float32)
    label_mapping = {label: idx for idx, label in enumerate(unique_labels)}
    labels_numeric = labels.map(label_mapping).values
    labels_tensor = torch.tensor(labels_numeric, dtype=torch.long)
    timings['to_tensor'] = time.perf_counter() - t0
    print(f"  ⏱️  Convert to tensors: {timings['to_tensor']:.2f}s")

    print("\nPreprocessing complete.")
    print(f"Feature shape: {node_features.shape}")
    print(f"Number of unique labels: {len(unique_labels)}")

    timings['total'] = sum(timings.values())
    print(f"\n⏱️  Total data loading time: {timings['total']:.2f}s")

    return node_features, labels_tensor, label_mapping, timings

# =========================
# Graph construction (KNN)
# =========================

def create_graph_data(node_features: torch.Tensor, labels: torch.Tensor, k: int = 2) -> Tuple[Data, dict]:
    """Build the KNN graph, lift features to UHG coordinates and split the nodes.

    Original GPT5 configuration: k=2, no PCA, n_jobs=-1.

    Steps: compute a k-nearest-neighbour connectivity graph over the raw
    features (self-loops excluded); turn it into an edge index on the
    global ``device``; append a homogeneous coordinate and project onto the
    hyperboloid with ``projective_normalize``; verify the constraint; draw a
    random 70/15/15 train/val/test split as boolean masks.

    Args:
        node_features: 2-D feature tensor, one row per node.
        labels: Class index per node.
        k: Number of neighbours per node.

    Returns:
        (graph, timings): a PyG ``Data`` object on ``device`` with x,
        edge_index, y and the three masks, plus per-step wall-clock seconds
        including their 'total'.
    """
    timings = {}

    print("\nCreating graph structure...")
    t0 = time.perf_counter()
    features_np = node_features.cpu().numpy()
    timings['to_numpy'] = time.perf_counter() - t0

    print(f"\nComputing KNN graph with k={k}...")
    print(f"  • Input shape: {features_np.shape}")
    print(f"  • Number of samples: {features_np.shape[0]:,}")
    print(f"  • Number of features: {features_np.shape[1]}")
    print(f"  • Using n_jobs=-1 (all CPU cores)")
    # Report the real feature count; the input width depends on the CSV.
    print(f"  • No PCA ({features_np.shape[1]} features - vs v4.5 with 20 features)")
    print(f"  • Expected: Slower KNN than v4.5 due to higher dimensionality")

    # Flush so the progress lines appear before the long KNN call starts.
    sys.stdout.flush()

    t0 = time.perf_counter()
    knn_graph = kneighbors_graph(
        features_np,
        k,
        mode='connectivity',
        include_self=False,
        n_jobs=-1
    )
    timings['knn_computation'] = time.perf_counter() - t0
    print(f"  ✅ KNN computation: {timings['knn_computation']:.2f}s")

    t0 = time.perf_counter()
    knn_graph_coo = coo_matrix(knn_graph)
    edge_index = torch.from_numpy(
        np.vstack((knn_graph_coo.row, knn_graph_coo.col))
    ).long().to(device)
    timings['edge_index_creation'] = time.perf_counter() - t0
    print(f"  ⏱️  Edge index creation: {timings['edge_index_creation']:.2f}s")

    print(f"Edge index shape: {edge_index.shape}")

    # Add homogeneous coordinate (projective)
    t0 = time.perf_counter()
    node_features_uhg = torch.cat([
        node_features.to(device),
        torch.ones(node_features.size(0), 1, device=device)
    ], dim=1)
    timings['add_homogeneous'] = time.perf_counter() - t0

    # The appended ones are placeholders: projective_normalize recomputes
    # the last coordinate from the spatial part.
    t0 = time.perf_counter()
    node_features_uhg = projective_normalize(node_features_uhg)
    timings['projective_normalize'] = time.perf_counter() - t0
    print(f"  ⏱️  UHG projection: {timings['projective_normalize']:.2f}s")

    print(f"Feature shape with homogeneous coordinate: {node_features_uhg.shape}")

    # Verify UHG constraints
    t0 = time.perf_counter()
    verify_uhg_constraints(node_features_uhg, name="initial features")
    timings['constraint_verification'] = time.perf_counter() - t0

    # Random 70/15/15 split; the test set takes whatever rounding leaves over.
    # NOTE(review): the permutation uses torch's global RNG with no explicit
    # seed here, so the split differs between runs unless seeded elsewhere.
    t0 = time.perf_counter()
    total_samples = len(node_features_uhg)
    indices = torch.randperm(total_samples)
    train_size = int(0.7 * total_samples)
    val_size = int(0.15 * total_samples)

    train_mask = torch.zeros(total_samples, dtype=torch.bool, device=device)
    val_mask = torch.zeros(total_samples, dtype=torch.bool, device=device)
    test_mask = torch.zeros(total_samples, dtype=torch.bool, device=device)

    train_mask[indices[:train_size]] = True
    val_mask[indices[train_size:train_size+val_size]] = True
    test_mask[indices[train_size+val_size:]] = True
    timings['split_creation'] = time.perf_counter() - t0

    print(f"\nTrain size: {train_mask.sum()}, Val size: {val_mask.sum()}, Test size: {test_mask.sum()}")

    timings['total'] = sum(timings.values())
    print(f"\n⏱️  Total graph construction time: {timings['total']:.2f}s")

    return Data(
        x=node_features_uhg,
        edge_index=edge_index,
        y=labels.to(device),
        train_mask=train_mask,
        val_mask=val_mask,
        test_mask=test_mask
    ).to(device), timings

# ==============================
# UHG GraphSAGE Message Passing
# ==============================

from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing

class UHGMessagePassing(MessagePassing):
    """One message-passing layer operating on projective (UHG) node vectors.

    Node vectors carry a time-like coordinate in their last column. Linear
    maps act on the spatial columns only; the full vectors are used just to
    weight each edge by exp(-quadrance). The output is re-projected with
    ``projective_normalize`` so its last column is consistent again.
    """

    def __init__(self, in_features: int, out_features: int):
        """Create the layer.

        Args:
            in_features: Number of spatial input columns (excluding the
                time-like coordinate).
            out_features: Number of spatial output columns.
        """
        super().__init__(aggr='add')
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty is the explicit form of "allocate, initialize later";
        # reset_parameters() fills both matrices immediately below.
        self.weight_msg = nn.Parameter(torch.empty(in_features, out_features))
        self.weight_node = nn.Parameter(torch.empty(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize both weight matrices with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.weight_msg)
        nn.init.xavier_uniform_(self.weight_node)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Apply the layer.

        Args:
            x: Node vectors; the last column is the time-like coordinate.
            edge_index: Edge list of shape (2, num_edges).

        Returns:
            Node vectors with ``out_features`` spatial columns plus a
            recomputed time-like column.
        """
        # x includes homogeneous coord
        # Transform node features (spatial only)
        features = x[:, :-1]
        z = x[:, -1:]
        transformed_features = features @ self.weight_node
        # Propagate using full projective vectors for weight computation
        out = self.propagate(edge_index, x=x, size=None)
        # Combine neighbour aggregate with the node's own transform
        out = out + transformed_features
        # z is only a placeholder column: projective_normalize recomputes
        # the time-like coordinate to restore Minkowski norm -1.
        out_full = torch.cat([out, z], dim=1)
        out_full = projective_normalize(out_full)
        return out_full

    def message(self, x_i: torch.Tensor, x_j: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Build per-edge messages: transformed neighbour features times exp(-quadrance)."""
        # x_i, x_j are full projective vectors
        weights = torch.exp(-uhg_quadrance_vectorized(x_i, x_j))
        # Transform neighbor features (spatial only)
        messages = (x_j[:, :-1]) @ self.weight_msg
        return messages * weights.view(-1, 1)

    def aggregate(self, inputs: torch.Tensor, index: torch.Tensor, dim_size=None) -> torch.Tensor:
        """Average the incoming messages per destination node.

        Args:
            inputs: One message row per edge.
            index: Destination node index per edge.
            dim_size: Number of destination nodes (an int, or None to infer
                it from ``index``). Without it the scatter output is only as
                tall as the largest index seen, so a graph whose
                highest-numbered nodes receive no edges would yield fewer
                rows than there are nodes and break the addition in
                ``forward``.

        Returns:
            One row per destination node; nodes with no incoming message
            get zeros.
        """
        # Sum messages per destination
        numerator = scatter_add(inputs, index, dim=0, dim_size=dim_size)
        # Count messages per destination (the divisor is the message count,
        # not the sum of edge weights, so this is a plain mean).
        message_count = scatter_add(torch.ones_like(inputs), index, dim=0, dim_size=dim_size)
        return numerator / torch.clamp(message_count, min=1e-6)

class UHGGraphSAGE(nn.Module):
    """Stack of UHGMessagePassing layers producing per-node class logits."""

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int, dropout: float = 0.2):
        """Build the layer stack.

        Args:
            in_channels: Width of the input vectors *including* the
                time-like coordinate.
            hidden_channels: Spatial width of the hidden layers.
            out_channels: Number of output logits per node.
            num_layers: Total number of layers. An input and an output layer
                are always created, with ``num_layers - 2`` hidden layers
                between them.
            dropout: Dropout probability applied between layers.
        """
        super().__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)
        # in_channels includes homogeneous coord
        actual_in = in_channels - 1
        self.layers.append(UHGMessagePassing(actual_in, hidden_channels))
        for _ in range(num_layers - 2):
            self.layers.append(UHGMessagePassing(hidden_channels, hidden_channels))
        self.layers.append(UHGMessagePassing(hidden_channels, out_channels))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run all layers and return the spatial part of the last one as logits.

        Between layers, ReLU and dropout act on the spatial coordinates only
        and the time-like coordinate is then recomputed. Applying dropout to
        the whole vector would zero or rescale the time-like coordinate at
        random, and ReLU alone leaves it stale; either way the next layer's
        quadrance weights would be computed on vectors that no longer have
        Minkowski norm -1.
        """
        h = x
        for layer in self.layers[:-1]:
            h = layer(h, edge_index)
            # Non-linearity and dropout on the spatial part only
            spatial = self.dropout(F.relu(h[:, :-1]))
            # Recompute the time-like coordinate for the modified spatial part
            h = projective_normalize(torch.cat([spatial, h[:, -1:]], dim=1))
        h = self.layers[-1](h, edge_index)
        return h[:, :-1]  # logits on spatial part

# =====================
# Training / Evaluation
# =====================

def train_epoch(model: nn.Module, graph_data: Data, optimizer: torch.optim.Optimizer, criterion: nn.Module, detailed_timing: bool = False) -> Tuple[float, dict]:
    """Run one full-batch training step on the static graph.

    Args:
        model: Network called as ``model(x, edge_index)``.
        graph_data: Graph providing x, edge_index, y and train_mask.
        optimizer: Optimizer over the model parameters.
        criterion: Loss applied to the training-mask rows.
        detailed_timing: When True, record per-phase wall-clock seconds.

    Returns:
        (loss, timings): the scalar loss, and a dict of per-phase timings
        (empty unless *detailed_timing* is set).

    Raises:
        Exception: any failure is printed with its traceback and re-raised.
    """
    model.train()
    timings = {}

    # CUDA kernels run asynchronously: without a synchronize, the clock
    # would capture launch overhead instead of execution time. Only sync
    # when detailed timings were requested, so normal epochs are unaffected.
    sync_for_timing = detailed_timing and torch.cuda.is_available()

    def _tick() -> float:
        if sync_for_timing:
            torch.cuda.synchronize()
        return time.perf_counter()

    try:
        t0 = _tick()
        optimizer.zero_grad(set_to_none=True)
        if detailed_timing:
            timings['zero_grad'] = _tick() - t0

        # Single full-batch forward/backward on the static graph
        t0 = _tick()
        out = model(graph_data.x, graph_data.edge_index)
        if detailed_timing:
            timings['forward_pass'] = _tick() - t0

        t0 = _tick()
        loss = criterion(out[graph_data.train_mask], graph_data.y[graph_data.train_mask])
        if detailed_timing:
            timings['loss_computation'] = _tick() - t0

        t0 = _tick()
        loss.backward()
        if detailed_timing:
            timings['backward_pass'] = _tick() - t0

        t0 = _tick()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        if detailed_timing:
            timings['grad_clipping'] = _tick() - t0

        t0 = _tick()
        optimizer.step()
        if detailed_timing:
            timings['optimizer_step'] = _tick() - t0
            timings['total'] = sum(timings.values())

        return float(loss.item()), timings
    except Exception as e:
        print(f"Train step failure: {e}")
        traceback.print_exc()
        raise

@torch.no_grad()
def evaluate(model: nn.Module, graph_data: Data, mask: torch.Tensor, detailed_timing: bool = False) -> Tuple[float, dict]:
    """Compute accuracy on the nodes selected by *mask*.

    Args:
        model: Network called as ``model(x, edge_index)``.
        graph_data: Graph providing x, edge_index and y.
        mask: Boolean node mask selecting the rows to score.
        detailed_timing: When True, record per-phase wall-clock seconds.

    Returns:
        (accuracy, timings): fraction of masked nodes whose argmax
        prediction equals the label, and a dict of per-phase timings (empty
        unless *detailed_timing* is set).
    """
    timings = {}
    model.eval()

    # Synchronize around timed regions so GPU execution time, not just
    # kernel launch time, is measured. Skipped when timings are not wanted.
    sync_for_timing = detailed_timing and torch.cuda.is_available()

    def _tick() -> float:
        if sync_for_timing:
            torch.cuda.synchronize()
        return time.perf_counter()

    t0 = _tick()
    out = model(graph_data.x, graph_data.edge_index)
    if detailed_timing:
        timings['forward_pass'] = _tick() - t0

    t0 = _tick()
    pred = out[mask].argmax(dim=1)
    acc = (pred == graph_data.y[mask]).float().mean().item()
    if detailed_timing:
        timings['prediction'] = _tick() - t0
        timings['total'] = sum(timings.values())

    return acc, timings

@torch.no_grad()
def evaluate_detailed(model: nn.Module, graph_data: Data, mask: torch.Tensor, label_mapping: dict, phase: str = "Test") -> dict:
    """Print a per-class report for the masked nodes and return summary metrics.

    Args:
        model: Network called as ``model(x, edge_index)``.
        graph_data: Graph providing x, edge_index and y.
        mask: Boolean node mask selecting the rows to score.
        label_mapping: Maps label name -> class index.
        phase: Name of the split, used in the printed headings.

    Returns:
        Dict with 'accuracy', 'f1_macro', 'f1_weighted' and the
        'confusion_matrix' as nested lists.
    """
    model.eval()
    logits = model(graph_data.x, graph_data.edge_index)
    pred = logits[mask].argmax(dim=1).cpu().numpy()
    true = graph_data.y[mask].cpu().numpy()

    # Class index -> label name
    idx_to_label = {idx: label for label, idx in label_mapping.items()}

    # Classes that occur either as a true label or as a prediction
    unique_classes = np.unique(np.concatenate([true, pred]))
    seen = set(unique_classes.tolist())
    target_names = [idx_to_label[cls] for cls in unique_classes.tolist()]

    # Classes known to the mapping but absent from both arrays (ascending)
    missing_classes = [cls for cls in range(len(label_mapping)) if cls not in seen]

    rule = '=' * 80
    print(f"\n{rule}")
    print(f"{phase} Set - Detailed Performance Report")
    print(f"{rule}")

    if missing_classes:
        print(f"\n⚠️  WARNING: {len(missing_classes)} classes not present in {phase.lower()} set:")
        for class_idx in missing_classes:
            print(f"  • {idx_to_label[class_idx]}")
        print(f"  (This is normal with small sample sizes and rare classes)")

    # Overall accuracy
    overall_acc = (pred == true).mean()
    print(f"\nOverall Accuracy: {overall_acc:.4f}")
    print(f"Classes evaluated: {len(unique_classes)}/{len(label_mapping)}")

    # Per-class precision / recall / F1, restricted to the classes seen
    print("\nPer-Class Classification Report:")
    report = classification_report(
        true,
        pred,
        labels=unique_classes,
        target_names=target_names,
        zero_division=0,
        digits=4,
    )
    print(report)

    # Confusion matrix; its rows follow the same sorted class order as
    # target_names, so the two can be walked together.
    cm = confusion_matrix(true, pred)
    print("\nPer-Class Accuracy:")
    for row, label in enumerate(target_names):
        class_samples = cm[row].sum()
        if class_samples > 0:
            class_acc = cm[row, row] / class_samples
        else:
            class_acc = 0.0
        print(f"  {label:30s}: {class_acc:.4f} ({int(class_samples)} samples)")

    # Macro and weighted F1
    f1_macro = f1_score(true, pred, average='macro', zero_division=0)
    f1_weighted = f1_score(true, pred, average='weighted', zero_division=0)
    print(f"\nF1 Score (Macro):    {f1_macro:.4f}")
    print(f"F1 Score (Weighted): {f1_weighted:.4f}")

    summary = {}
    summary['accuracy'] = float(overall_acc)
    summary['f1_macro'] = float(f1_macro)
    summary['f1_weighted'] = float(f1_weighted)
    summary['confusion_matrix'] = cm.tolist()
    return summary

def main():
    run_started = time.perf_counter()
    run_id = datetime.utcnow().strftime('%Y%m%dT%H%M%SZ')
    metrics = {
        'version': 'GPT5 v2',
        'run_id': run_id,
        'env': get_env_info(),
        'paths': {
            'file_path': FILE_PATH,
            'model_save_path': MODEL_SAVE_PATH,
            'results_path': RESULTS_PATH,
        },
        'configuration': [
            'Original GPT5 training approach (modified for comparison)',
            '10% data sampling (matching v4.5)',
            'k=2 neighbors (matching v4.5)',
            'No PCA dimensionality reduction',
            'Standard cross-entropy loss (no class weighting)',
            'n_jobs=-1 for KNN',
        ],
        'v2_additions': [
            'Comprehensive timing instrumentation and bottleneck analysis',
            'GPU hardware detection and memory tracking',
            'Detailed per-class evaluation with classification report',
            'UHG constraint verification (initial and post-training)',
            'Enhanced progress reporting',
        ],
        'data': {},
        'graph': {},
        'model': {},
        'train': {
            'epochs': [],
            'best_val': 0.0,
            'best_epoch': None,
        },
        'errors': None,
        'timing': {},
        'gpu_memory': {},
    }

    try:
        # Data loading with detailed timing
        node_features, labels, label_mapping, data_timings = load_and_preprocess_data(FILE_PATH, sample_frac=0.10)

        metrics['data'] = {
            'num_nodes': int(node_features.size(0)),
            'num_features': int(node_features.size(1)),
            'num_classes': int(len(label_mapping)),
            'sample_fraction': 0.10,
        }
        metrics['timing']['data_load'] = data_timings

        # Graph construction with detailed timing
        graph_data, graph_timings = create_graph_data(node_features, labels, k=2)

        metrics['timing']['graph_build'] = graph_timings
        metrics['graph'] = {
            'num_nodes': int(graph_data.x.size(0)),
            'num_edges': int(graph_data.edge_index.size(1)),
            'k_neighbors': 2,
            'pca_enabled': False,
            'train_nodes': int(graph_data.train_mask.sum().item()),
            'val_nodes': int(graph_data.val_mask.sum().item()),
            'test_nodes': int(graph_data.test_mask.sum().item()),
        }

        in_channels = graph_data.x.size(1)
        hidden_channels = 64
        out_channels = len(label_mapping)
        num_layers = 2

        model = UHGGraphSAGE(in_channels, hidden_channels, out_channels, num_layers).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

        # Standard cross-entropy loss (no class weighting - original GPT5)
        criterion = nn.CrossEntropyLoss()
        print(f"\n✅ Using standard CrossEntropyLoss (no class weighting)")

        n_params = sum(p.numel() for p in model.parameters())
        metrics['model'] = {
            'hidden_channels': hidden_channels,
            'out_channels': out_channels,
            'num_layers': num_layers,
            'num_parameters': int(n_params),
            'class_weighted_loss': False,
        }

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            metrics['gpu_memory']['before_allocated'] = int(torch.cuda.memory_allocated())
            metrics['gpu_memory']['before_reserved'] = int(torch.cuda.memory_reserved())
            print(f"\n💾 GPU Memory (before training):")
            print(f"   • Allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
            print(f"   • Reserved:  {torch.cuda.memory_reserved()/1024**3:.2f} GB")

        best_val = 0.0
        patience = 20
        no_improve = 0
        num_epochs = 200

        print("\nStarting training...")
        train_start = time.perf_counter()

        # Store detailed timing for specific epochs
        detailed_timing_epochs = [1, 2, 50, 100]
        epoch_timings_detailed = {}

        for epoch in range(1, num_epochs + 1):
            epoch_t0 = time.perf_counter()

            # Enable detailed timing for specific epochs
            detailed = epoch in detailed_timing_epochs

            loss, train_timings = train_epoch(model, graph_data, optimizer, criterion, detailed_timing=detailed)
            val_acc, val_timings = evaluate(model, graph_data, graph_data.val_mask, detailed_timing=detailed)
            test_acc, test_timings = evaluate(model, graph_data, graph_data.test_mask, detailed_timing=detailed)

            scheduler.step(val_acc)
            lr = optimizer.param_groups[0]['lr']
            epoch_time = time.perf_counter() - epoch_t0

            epoch_metrics = {
                'epoch': epoch,
                'loss': float(loss),
                'val_acc': float(val_acc),
                'test_acc': float(test_acc),
                'lr': float(lr),
                'time_s': float(epoch_time),
            }

            # Store detailed timing breakdowns
            if detailed:
                epoch_timings_detailed[f'epoch_{epoch}'] = {
                    'train': train_timings,
                    'val': val_timings,
                    'test': test_timings,
                    'total': epoch_time,
                }
                print(f"\n⏱️  Epoch {epoch} Detailed Timing:")
                print(f"    Train: Forward={train_timings.get('forward_pass', 0):.3f}s, Backward={train_timings.get('backward_pass', 0):.3f}s, Optimizer={train_timings.get('optimizer_step', 0):.3f}s")
                print(f"    Val:   Forward={val_timings.get('forward_pass', 0):.3f}s")
                print(f"    Test:  Forward={test_timings.get('forward_pass', 0):.3f}s")

            metrics['train']['epochs'].append(epoch_metrics)

            if val_acc > best_val:
                best_val = val_acc
                no_improve = 0
                metrics['train']['best_val'] = float(best_val)
                metrics['train']['best_epoch'] = int(epoch)
                torch.save(model.state_dict(), MODEL_SAVE_PATH)
                print(f"Epoch {epoch:03d} | Loss {loss:.4f} | Val {val_acc:.4f} | Test {test_acc:.4f} | LR {lr:.5f} | {epoch_time:.2f}s | (saved)")
            else:
                no_improve += 1
                if epoch % 10 == 0:
                    print(f"Epoch {epoch:03d} | Loss {loss:.4f} | Val {val_acc:.4f} | Test {test_acc:.4f} | LR {lr:.5f} | {epoch_time:.2f}s")
                if no_improve >= patience:
                    print("Early stopping.")
                    break

        train_total_time = time.perf_counter() - train_start
        metrics['timing']['train_total_s'] = train_total_time
        metrics['timing']['epoch_details'] = epoch_timings_detailed

        # Calculate average epoch time
        avg_epoch_time = train_total_time / epoch
        print(f"\n⏱️  Average epoch time: {avg_epoch_time:.2f}s")

        if torch.cuda.is_available():
            metrics['gpu_memory']['after_allocated'] = int(torch.cuda.memory_allocated())
            metrics['gpu_memory']['after_reserved'] = int(torch.cuda.memory_reserved())
            metrics['gpu_memory']['peak_allocated'] = int(torch.cuda.max_memory_allocated())

            print(f"\n💾 GPU Memory Usage Summary:")
            print(f"   • Peak Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB")
            print(f"   • Final Allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
            print(f"   • Final Reserved: {torch.cuda.memory_reserved()/1024**3:.2f} GB")

        if os.path.exists(MODEL_SAVE_PATH):
            print("\nLoading best model for final evaluation...")
            model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))

            # Verify UHG constraints after training
            print("\n" + "="*80)
            print("POST-TRAINING UHG CONSTRAINT VERIFICATION")
            print("="*80)
            with torch.no_grad():
                # Check features after one forward pass
                h = model.layers[0](graph_data.x, graph_data.edge_index)
                verify_uhg_constraints(h, name="after layer 1")

            # Detailed evaluation
            final_metrics = evaluate_detailed(model, graph_data, graph_data.test_mask, label_mapping, phase="Test")
            metrics['train']['final_test_metrics'] = final_metrics
            print(f"\nFinal Test Accuracy: {final_metrics['accuracy']:.4f}")

    except Exception as e:
        tb = traceback.format_exc()
        print("\nFATAL ERROR during run:\n", tb)
        metrics['errors'] = {
            'message': str(e),
            'traceback': tb,
        }
    finally:
        total_runtime = time.perf_counter() - run_started
        metrics['timing']['total_s'] = total_runtime

        # ===== COMPREHENSIVE TIMING SUMMARY =====
        print("\n" + "="*80)
        print("⏱️  COMPREHENSIVE TIMING BREAKDOWN")
        print("="*80)

        if 'data_load' in metrics['timing']:
            data_t = metrics['timing']['data_load']
            print(f"\n📊 DATA LOADING ({data_t.get('total', 0):.2f}s total):")
            print(f"  • CSV Read:          {data_t.get('csv_read', 0):6.2f}s ({data_t.get('csv_read', 0)/total_runtime*100:5.1f}%)")
            print(f"  • Sampling:          {data_t.get('sampling', 0):6.2f}s ({data_t.get('sampling', 0)/total_runtime*100:5.1f}%)")
            print(f"  • To Numeric:        {data_t.get('to_numeric', 0):6.2f}s ({data_t.get('to_numeric', 0)/total_runtime*100:5.1f}%)")
            print(f"  • Fill NaN/Inf:      {data_t.get('fillna', 0):6.2f}s ({data_t.get('fillna', 0)/total_runtime*100:5.1f}%)")
            print(f"  • Scaling:           {data_t.get('scaling', 0):6.2f}s ({data_t.get('scaling', 0)/total_runtime*100:5.1f}%)")
            print(f"  • To Tensors:        {data_t.get('to_tensor', 0):6.2f}s ({data_t.get('to_tensor', 0)/total_runtime*100:5.1f}%)")

        if 'graph_build' in metrics['timing']:
            graph_t = metrics['timing']['graph_build']
            print(f"\n🕸️  GRAPH CONSTRUCTION ({graph_t.get('total', 0):.2f}s total):")
            print(f"  • KNN Computation:   {graph_t.get('knn_computation', 0):6.2f}s ({graph_t.get('knn_computation', 0)/total_runtime*100:5.1f}%)")
            print(f"  • Edge Index:        {graph_t.get('edge_index_creation', 0):6.2f}s ({graph_t.get('edge_index_creation', 0)/total_runtime*100:5.1f}%)")
            print(f"  • UHG Projection:    {graph_t.get('projective_normalize', 0):6.2f}s ({graph_t.get('projective_normalize', 0)/total_runtime*100:5.1f}%)")
            print(f"  • Constraint Check:  {graph_t.get('constraint_verification', 0):6.2f}s ({graph_t.get('constraint_verification', 0)/total_runtime*100:5.1f}%)")

        if 'train_total_s' in metrics['timing']:
            train_t = metrics['timing']['train_total_s']
            print(f"\n🎓 TRAINING ({train_t:.2f}s total, {train_t/total_runtime*100:.1f}% of runtime):")
            avg_epoch = train_t / metrics['train'].get('best_epoch', 1) if 'train' in metrics else 0
            print(f"  • Avg Epoch Time:    {avg_epoch:6.2f}s")
            print(f"  • Total Epochs:      {metrics['train'].get('best_epoch', 0)}")

            # Show detailed breakdown from epoch 1
            if 'epoch_details' in metrics['timing'] and 'epoch_1' in metrics['timing']['epoch_details']:
                ep1 = metrics['timing']['epoch_details']['epoch_1']
                train_detail = ep1.get('train', {})
                print(f"\n  Epoch Breakdown (Epoch 1):")
                print(f"    - Forward Pass:    {train_detail.get('forward_pass', 0):6.3f}s ({train_detail.get('forward_pass', 0)/ep1['total']*100:5.1f}%)")
                print(f"    - Backward Pass:   {train_detail.get('backward_pass', 0):6.3f}s ({train_detail.get('backward_pass', 0)/ep1['total']*100:5.1f}%)")
                print(f"    - Optimizer Step:  {train_detail.get('optimizer_step', 0):6.3f}s ({train_detail.get('optimizer_step', 0)/ep1['total']*100:5.1f}%)")
                print(f"    - Loss Compute:    {train_detail.get('loss_computation', 0):6.3f}s ({train_detail.get('loss_computation', 0)/ep1['total']*100:5.1f}%)")
                print(f"    - Val Eval:        {ep1.get('val', {}).get('forward_pass', 0):6.3f}s ({ep1.get('val', {}).get('forward_pass', 0)/ep1['total']*100:5.1f}%)")
                print(f"    - Test Eval:       {ep1.get('test', {}).get('forward_pass', 0):6.3f}s ({ep1.get('test', {}).get('forward_pass', 0)/ep1['total']*100:5.1f}%)")

        # High-level summary
        print(f"\n📈 HIGH-LEVEL SUMMARY:")
        data_pct = metrics['timing'].get('data_load', {}).get('total', 0) / total_runtime * 100
        graph_pct = metrics['timing'].get('graph_build', {}).get('total', 0) / total_runtime * 100
        train_pct = metrics['timing'].get('train_total_s', 0) / total_runtime * 100

        print(f"  • Data Loading:      {data_pct:5.1f}% of total runtime")
        print(f"  • Graph Building:    {graph_pct:5.1f}% of total runtime")
        print(f"  • Training:          {train_pct:5.1f}% of total runtime")
        print(f"  • Total Runtime:     {total_runtime:.2f}s")

        # GPU summary if available
        if torch.cuda.is_available() and 'gpu_memory' in metrics:
            peak_gb = metrics['gpu_memory'].get('peak_allocated', 0) / 1024**3
            print(f"  • Peak GPU Memory:   {peak_gb:.2f} GB")

        # Identify bottlenecks
        print(f"\n🔍 BOTTLENECK ANALYSIS:")
        bottlenecks = []

        if 'data_load' in metrics['timing']:
            data_t = metrics['timing']['data_load']
            for key, val in data_t.items():
                if key != 'total' and val > 1.0:  # More than 1 second
                    bottlenecks.append((f"Data: {key}", val, val/total_runtime*100))

        if 'graph_build' in metrics['timing']:
            graph_t = metrics['timing']['graph_build']
            for key, val in graph_t.items():
                if key != 'total' and val > 1.0:
                    bottlenecks.append((f"Graph: {key}", val, val/total_runtime*100))

        if bottlenecks:
            bottlenecks.sort(key=lambda x: x[1], reverse=True)
            for i, (name, time_s, pct) in enumerate(bottlenecks[:5], 1):
                print(f"  {i}. {name:30s} {time_s:6.2f}s ({pct:5.1f}%)")
        else:
            print("  No major bottlenecks detected (all operations < 1s)")

        out_path = os.path.join(RESULTS_PATH, f"metrics_gpt5_v2_{run_id}.json")
        save_json(metrics, out_path)

        print("\n" + "="*80)
        print("UHG IDS Model GPT5 v2 - Training Complete")
        print("="*80)
        print(f"Results saved to: {out_path}")

# Entry point: call main() only when this module is executed directly
# (e.g. as a script or notebook cell), not when it is imported.
if __name__ == "__main__":
    main()



Mounting Google Drive...
Mounted at /content/drive

🖥️  HARDWARE CONFIGURATION
✅ GPU Detected:
   • Model: NVIDIA L4
   • Memory: 22.2 GB
   • CUDA Version: 12.4
   • Compute Capability: 8.9
   • Device: cuda:0


Loading data from: /content/drive/MyDrive/CIC_data.csv
  ⏱️  CSV read: 35.43s

Unique labels in the dataset: ['BENIGN' 'DDoS' 'PortScan' 'Bot' 'Infiltration'
 'Web Attack � Brute Force' 'Web Attack � XSS'
 'Web Attack � Sql Injection' 'FTP-Patator' 'SSH-Patator' 'DoS slowloris'
 'DoS Slowhttptest' 'DoS Hulk' 'DoS GoldenEye' 'Heartbleed']

Label distribution in the dataset:
Label
BENIGN                        2273097
DoS Hulk                       231073
PortScan                       158930
DDoS                           128027
DoS GoldenEye                   10293
FTP-Patator                      7938
SSH-Patator                      5897
DoS slowloris                    5796
DoS Slowhttptest                 5499
Bot                              1966
Web Attack � Brute Force 