## Graph U-Net Training Pipeline

### Drive Mount & Installation

We mount Google Drive to access the data and install the required **PyTorch Geometric** libraries for our U-Net GNN implementation.

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

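# The wheel index below must match the runtime's PyTorch/CUDA build (here torch 2.5.0 + CUDA 12.1).
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.5.0+cu121.html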
!pip install torch-geometric

### Configuration & Hyperparameters

This cell generates the `config.py` file, which centralizes all experiment parameters. Experiments can then be adjusted in one place, without touching the model or training code.

Key definitions include:
* **Data Paths:** Absolute paths to the dataset (mounted via Google Drive).
* **GNN Architecture:** Hidden channel dimensions (`HIDDEN_CHANNELS`), depth (`DEPTH`), and pooling ratios.
* **Optimization:** Batch size, learning rate, and number of epochs.
* **Hardware:** Automatic selection of GPU (CUDA) if available.
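
As a minimal sketch of how a centralized class like this can be varied for a quick debugging run, one option is to subclass it and override only the fields that change. The `DebugConfig` name and the `describe` helper are hypothetical, and the `Config` below is a trimmed stand-in (paths and device selection omitted) so the snippet runs on its own.

```python
class Config:
    """Trimmed stand-in for the notebook's Config (paths and device omitted)."""
    TRAIN_SUBSET_RATIO = 1.0
    BATCH_SIZE = 64
    NUM_EPOCHS = 2
    LEARNING_RATE = 0.0001


class DebugConfig(Config):
    """Hypothetical variant for smoke tests: 10% of the data, one epoch."""
    TRAIN_SUBSET_RATIO = 0.1
    NUM_EPOCHS = 1


def describe(config):
    """Collect the upper-case settings of a config class into a dict."""
    return {name: getattr(config, name) for name in dir(config) if name.isupper()}


settings = describe(DebugConfig)
print(settings)
```

Fields that are not overridden (here `BATCH_SIZE` and `LEARNING_RATE`) are inherited unchanged, so the debug run differs from the full run only where stated.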

In [None]:
%%writefile config.py
import torch

class Config:
    """
    Central Configuration for the experiment.
    All hyperparameters and paths are defined here for easy tuning.
    """

    # --- Data Paths (Google Drive) ---
    TRAIN_DATA_PATH = '/content/gdrive/MyDrive/DataSetML4/data/train_data.pt'
    VAL_DATA_PATH = '/content/gdrive/MyDrive/DataSetML4/data/val_data.pt'

    # --- Training Settings ---
    TRAIN_SUBSET_RATIO = 1.0    # Fraction of training graphs used: 1.0 for the full dataset, lower (e.g., 0.1) for debugging
    BATCH_SIZE = 64             # Number of graphs processed in parallel

    # --- Model Architecture (Graph U-Net) ---
    IN_CHANNELS = 12            # Input features (position, normals, boundary flags, etc.)
    HIDDEN_CHANNELS = 32        # Size of the hidden layers (model capacity)
    OUT_CHANNELS = 1            # Output: scalar Von Mises stress
    DEPTH = 3                   # Number of U-Net levels (Downsampling/Upsampling steps)
    POOL_RATIOS = 0.8           # Keep top 80% of nodes at each pooling step

    # --- Optimization ---
    NUM_EPOCHS = 2              # Number of full passes over the training set
    LEARNING_RATE = 0.0001      # Step size for the optimizer
    WEIGHT_DECAY = 1e-5         # L2 Regularization to prevent overfitting
    GRADIENT_CLIP = 1.0         # Max norm for gradients (stabilizes training)

    # --- Saving ---
    BEST_MODEL_PATH = 'best_model.pt'
    TRAINING_CURVE_PATH = 'training_curve.png'

    # --- Hardware ---
    # Automatically select GPU if available for faster computation
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Data Loading Compatibility

Although the dataset is already normalized, the saved `.pt` files contain instances of `UnitGaussianNormalizer`. We explicitly define this class here to ensure `torch.load` can deserialize the data correctly without raising `AttributeError` or warnings.
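
The reason the class must be importable can be illustrated with the standard `pickle` module, which `torch.load` relies on for arbitrary Python objects. The sketch below is only an illustration: the module name `legacy_utils_demo` and the `Normalizer` class are hypothetical stand-ins, not the project's real `utils.UnitGaussianNormalizer`.

```python
import pickle
import sys
import types

# Hypothetical module name, used only for this demonstration.
MODULE_NAME = "legacy_utils_demo"


class Normalizer:
    """Stand-in for a normalizer object stored inside a saved dataset."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std


# Pretend the class lives in MODULE_NAME, as it did when the data was saved.
Normalizer.__module__ = MODULE_NAME
Normalizer.__qualname__ = "Normalizer"
module = types.ModuleType(MODULE_NAME)
module.Normalizer = Normalizer
sys.modules[MODULE_NAME] = module

# A pickle stores only the module path and class name, not the class code.
payload = pickle.dumps(Normalizer(mean=0.0, std=1.0))

# Simulate a fresh session in which that module does not exist.
del sys.modules[MODULE_NAME]
try:
    pickle.loads(payload)
    load_failed = False
except (ModuleNotFoundError, AttributeError):
    load_failed = True

# Registering a module that exposes the class makes loading work again.
sys.modules[MODULE_NAME] = module
restored = pickle.loads(payload)
print(load_failed, restored.mean, restored.std)
```

Because the pickle stores only a reference to the class, the loader must be able to resolve that reference at load time, even if the normalizer is never used afterwards.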

In [None]:
%%writefile utils.py
import torch

class UnitGaussianNormalizer:
    """
    Applies Unit Gaussian Normalization (Z-score standardization) to tensors.
    """

    def __init__(self, x=None, eps=1e-5):
        """
        Calculates and stores the mean and standard deviation of the input data 'x'.
        """
        self.eps = eps
        if x is not None:
            # Calculate statistics across the 0-th dimension (samples)
            self.mean = x.mean(dim=0, keepdim=True)
            self.std = x.std(dim=0, keepdim=True) + eps
        else:
            self.mean = None
            self.std = None

    def encode(self, x):
        """
        Normalizes the input x (subtract mean, divide by std).
        """
        if self.mean is None:
            return x
        return (x - self.mean) / self.std

    def decode(self, x, sample_idx=None):
        """
        Un-normalizes x (multiply by std, add mean) to recover original units.
        """
        if self.mean is None:
            return x
        return x * self.std + self.mean

### Model Definition: Graph U-Net

This cell generates `model.py`. We utilize the standard **GraphUNet** implementation from PyTorch Geometric.

The `create_model` function acts as a factory:
1.  It instantiates the U-Net architecture using the hyperparameters defined in `Config` (depth, hidden channels, pooling ratios).
2.  It automatically moves the model to the appropriate computational device (GPU/CPU).
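
To get a feel for what the depth and pooling ratio mean for graph size, the sketch below estimates how many nodes survive at each U-Net level. It assumes that each pooling step keeps `ceil(ratio * N)` nodes, which is how top-k pooling with a fractional ratio is commonly defined; the helper name `nodes_per_level` and the 1000-node mesh are illustrative only.

```python
import math


def nodes_per_level(num_nodes, depth, pool_ratio):
    """Estimate the node count at each level, assuming ceil(ratio * N) nodes are kept."""
    counts = [num_nodes]
    for _ in range(depth):
        counts.append(math.ceil(pool_ratio * counts[-1]))
    return counts


# Values matching the configuration above: DEPTH = 3, POOL_RATIOS = 0.8.
levels = nodes_per_level(num_nodes=1000, depth=3, pool_ratio=0.8)
print(levels)
```

Under this assumption, a ratio of 0.8 with three levels still leaves roughly half of the nodes at the bottleneck, so the coarsest level remains fairly detailed.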


In [None]:
%%writefile model.py
from torch_geometric.nn import GraphUNet

def create_model(config):
    """
    Instantiates the Graph U-Net architecture.
    """
    # We use the standard GraphUNet implementation from PyTorch Geometric.
    # This architecture uses a U-shaped design with Graph Pooling (gPool)
    # and Unpooling (gUnpool) operations to capture hierarchical features.
    model = GraphUNet(
        in_channels=config.IN_CHANNELS,         # Input features per node (12)
        hidden_channels=config.HIDDEN_CHANNELS, # Width of the hidden layers
        out_channels=config.OUT_CHANNELS,       # Output target (1 scalar: von Mises stress)
        depth=config.DEPTH,                     # Number of pooling/unpooling steps
        pool_ratios=config.POOL_RATIOS,         # Ratio of nodes kept after each pooling layer
        sum_res=True,                           # Merge skip connections by summation (False would concatenate instead)
        act='relu'                              # Activation function
    )

    # Move the entire model to the selected hardware (GPU/CPU)
    model = model.to(config.DEVICE)
    return model

### Data Loading & Reconstruction

This cell generates `data_loading.py`. Because our dataset is stored in a raw PyTorch Geometric format (a tuple of concatenated tensors and slice indices) rather than a list of objects, we need a custom loading pipeline.

Key steps include:
1.  **Module Patching:** We manually inject `UnitGaussianNormalizer` into `sys.modules` to prevent unpickling errors during `torch.load`.
2.  **Graph Reconstruction:** The `get_graph_from_tuple` function manually slices the large concatenated tensors back into individual `Data` objects using the stored indices.
3.  **DataLoaders:** Finally, we organize the graphs into `DataLoader` instances to handle efficient mini-batching for the GPU.
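
The reconstruction step can be sketched with plain Python lists. The toy storage below is an assumption made for illustration: two graphs are concatenated into shared arrays, a `slices` dictionary records where each graph starts and ends, and edge indices are stored per graph (not offset by the position of the graph in the storage).

```python
# Toy storage: two graphs concatenated into shared arrays.
node_features = [[0.0], [0.1], [0.2], [1.0], [1.1]]  # graph 0: 3 nodes, graph 1: 2 nodes
edge_index = [
    [0, 1, 0],  # source nodes of all edges, both graphs back to back
    [1, 2, 1],  # target nodes
]
slices = {
    "x": [0, 3, 5],           # graph i owns rows slices["x"][i] : slices["x"][i + 1]
    "edge_index": [0, 2, 3],  # graph i owns edge columns in the same way
}


def get_graph(idx):
    """Cut one graph out of the concatenated storage."""
    x_start, x_end = slices["x"][idx], slices["x"][idx + 1]
    e_start, e_end = slices["edge_index"][idx], slices["edge_index"][idx + 1]
    return {
        "x": node_features[x_start:x_end],                          # slice rows
        "edge_index": [row[e_start:e_end] for row in edge_index],   # slice columns
    }


num_graphs = len(slices["x"]) - 1
graphs = [get_graph(i) for i in range(num_graphs)]
print(graphs)
```

Node features are sliced along rows while the edge list is sliced along columns, which is why the loader treats `edge_index` and `face` separately from the other attributes.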

In [None]:
%%writefile data_loading.py
import torch
import sys
import types
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from utils import UnitGaussianNormalizer

# --- CRITICAL WORKAROUND ---
# The dataset was saved (pickled) with a reference to 'utils.UnitGaussianNormalizer'.
# During loading, PyTorch looks for this exact module path. Since we are running in a
# notebook/script environment where 'utils' might be defined differently, we manually
# inject the class into sys.modules to prevent an AttributeError during torch.load.
sys.modules['utils'] = types.ModuleType('utils')
sys.modules['utils'].UnitGaussianNormalizer = UnitGaussianNormalizer

def get_graph_from_tuple(data, slices, idx):
    """
    Reconstructs a single PyTorch Geometric Data object from the large concatenated
    tensors using slice indices.

    Args:
        data: The huge object containing concatenated attributes (x, edge_index, etc.) for all graphs.
        slices: A dictionary storing the start/end indices for each attribute.
        idx: The index of the specific graph to retrieve.
    """
    data_dict = {}
    for key in slices.keys():
        start, end = slices[key][idx].item(), slices[key][idx + 1].item()

        # edge_index and face are stored as [2, E] / [3, F], so slice along the last dimension;
        # all other attributes (x, y, ...) are sliced along the first dimension
        if key in ['edge_index', 'face'] and data[key].dim() == 2:
            data_dict[key] = data[key][:, start:end]
        else:
            data_dict[key] = data[key][start:end]

    return Data(**data_dict)

def load_data(config):
    """
    Loads training and validation datasets from .pt files, applies subsetting,
    and returns DataLoaders for the GNN.
    """
    print("Loading data...")

    # Load Training Data
    # weights_only=False is required here because the file contains complex objects
    # (like the Normalizer class instance), not just state dictionaries.
    # Only use it for files from a trusted source: unpickling can execute arbitrary code.
    train_dataset_tuple = torch.load(config.TRAIN_DATA_PATH, weights_only=False)
    train_data, train_slices = train_dataset_tuple

    # --- Subset Training Data (Optional) ---
    # Reduces the dataset size based on TRAIN_SUBSET_RATIO (useful for faster debugging)
    num_graphs = train_slices['x'].size(0) - 1
    keep = int(num_graphs * config.TRAIN_SUBSET_RATIO)

    new_slices = {}
    for key in train_slices.keys():
        new_slices[key] = train_slices[key][:keep+1].clone()
    train_slices = new_slices

    # Load Validation Data
    test_dataset_tuple = torch.load(config.VAL_DATA_PATH, weights_only=False)
    test_data, test_slices = test_dataset_tuple

    # Reconstruct individual Data objects list from the massive storage tensors
    # This might take a few seconds but makes iteration easier.
    train_graphs = [get_graph_from_tuple(train_data, train_slices, i)
                    for i in range(train_slices['x'].size(0) - 1)]
    val_graphs = [get_graph_from_tuple(test_data, test_slices, i)
                  for i in range(test_slices['x'].size(0) - 1)]

    # Create PyTorch Geometric DataLoaders
    # These handle the dynamic batching (diagonal stacking of adjacency matrices)
    train_loader = DataLoader(train_graphs, batch_size=config.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_graphs, batch_size=config.BATCH_SIZE, shuffle=False)

    return train_loader, val_loader, len(train_graphs), len(val_graphs), train_graphs, val_graphs

### Training Loop & Metric Tracking

This block defines the core training and evaluation logic:

1.  **`validate`**: This function goes beyond simple loss calculation. It computes the **$R^2$ score for each individual geometry** in the validation set. It then calculates the 10th, 50th, and 90th percentiles to measure model robustness and identifies the indices of the best, median, and worst performing graphs for later visualization.
2.  **`train_model`**: The main loop that iterates through epochs, updates weights, and saves the model checkpoint (`best_model.pt`) whenever validation loss improves.
3.  **`plot_metrics`**: A utility to generate comprehensive learning curves, allowing us to visualize convergence and the stability of predictions across different percentiles.
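
For reference, the per-graph score follows the standard definition $R^2 = 1 - SS_{res} / SS_{tot}$, which is also what `sklearn.metrics.r2_score` computes. The pure-Python sketch below uses a hypothetical helper `r2_per_graph` and made-up values to show why scores can drop to zero or below; it assumes the targets are not all identical, since $SS_{tot}$ would then be zero.

```python
def r2_per_graph(targets, preds):
    """R2 = 1 - SS_res / SS_tot for the nodes of a single graph."""
    mean_target = sum(targets) / len(targets)
    ss_res = sum((t - p) ** 2 for t, p in zip(targets, preds))
    ss_tot = sum((t - mean_target) ** 2 for t in targets)
    return 1.0 - ss_res / ss_tot


targets = [1.0, 2.0, 3.0, 4.0]

perfect = r2_per_graph(targets, [1.0, 2.0, 3.0, 4.0])    # exact prediction
mean_only = r2_per_graph(targets, [2.5, 2.5, 2.5, 2.5])  # predicts the mean everywhere
inverted = r2_per_graph(targets, [4.0, 3.0, 2.0, 1.0])   # pattern is reversed

print(perfect, mean_only, inverted)
```

A score of 1 means a perfect fit, 0 means the prediction is no better than the per-graph mean, and negative values mean it is worse than that baseline.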

In [None]:
%%writefile training.py
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

def train_epoch(model, train_loader, optimizer, device, gradient_clip, num_train_graphs):
    """Performs one training epoch."""
    model.train()
    train_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = F.huber_loss(out.view(-1), batch.y.view(-1), delta=0.5)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clip)
        optimizer.step()
        train_loss += loss.item() * batch.num_graphs
    return train_loss / num_train_graphs

def validate(model, val_loader, device, num_val_graphs):
    """
    Evaluates the model on the validation set.
    Calculates global metrics (RMSE, MAE) and per-graph R2 scores.
    Returns the indices of the Best, Median, and Worst performing graphs.
    """
    model.eval()
    val_loss = 0

    # Lists for global RMSE/MAE calculation
    all_preds_global = []
    all_targets_global = []

    # List to store (global_index, r2_score) tuples
    # Allows tracking which graph corresponds to which score
    graph_scores_with_indices = []

    # Offset to track the real global index across batches
    global_idx_offset = 0

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)

            # Loss Calculation
            loss = F.huber_loss(out.view(-1), batch.y.view(-1), delta=0.5)
            val_loss += loss.item() * batch.num_graphs

            # Data Prep (Move to CPU for sklearn metrics)
            preds_cpu = out.view(-1).cpu().numpy()
            targets_cpu = batch.y.view(-1).cpu().numpy()
            batch_indices = batch.batch.cpu().numpy()

            all_preds_global.append(preds_cpu)
            all_targets_global.append(targets_cpu)

            # Calculate R2 per individual graph
            unique_graphs = np.unique(batch_indices)

            for graph_id in unique_graphs:
                mask = (batch_indices == graph_id)
                graph_targets = targets_cpu[mask]
                graph_preds = preds_cpu[mask]

                # Compute R2 only if graph has more than 1 node
                if len(graph_targets) > 1:
                    score = r2_score(graph_targets, graph_preds)

                    # Calculate absolute index in the validation dataset
                    real_global_index = global_idx_offset + graph_id
                    graph_scores_with_indices.append((real_global_index, score))

            # Update offset for next batch
            global_idx_offset += batch.num_graphs

    val_loss /= num_val_graphs

    metrics = {}
    selected_indices = [None, None, None] # [Best, Median, Worst]

    if len(all_preds_global) > 0:
        final_preds = np.concatenate(all_preds_global)
        final_targets = np.concatenate(all_targets_global)

        metrics['RMSE'] = np.sqrt(mean_squared_error(final_targets, final_preds))
        metrics['MAE'] = mean_absolute_error(final_targets, final_preds)

        if len(graph_scores_with_indices) > 0:
            # Separate indices and scores
            indices_arr = np.array([item[0] for item in graph_scores_with_indices])
            scores_arr = np.array([item[1] for item in graph_scores_with_indices])

            # 1. Calculate R2 percentiles (10th, 50th, 90th)
            metrics['R2_90PCT'] = np.percentile(scores_arr, 90)
            metrics['R2_50PCT'] = np.percentile(scores_arr, 50)
            metrics['R2_10PCT'] = np.percentile(scores_arr, 10)

            # 2. Identify key graph indices for visualization

            # A. Best Case (Max R2)
            idx_best = indices_arr[np.argmax(scores_arr)]

            # B. Worst Case (Min R2)
            idx_worst = indices_arr[np.argmin(scores_arr)]

            # C. Median Case (R2 closest to the median)
            median_score = np.median(scores_arr)
            idx_closest_to_median = indices_arr[(np.abs(scores_arr - median_score)).argmin()]

            # Return order: [Best, Median, Worst]
            selected_indices = [int(idx_best), int(idx_closest_to_median), int(idx_worst)]

        else:
            metrics['R2_90PCT'] = 0
            metrics['R2_50PCT'] = 0
            metrics['R2_10PCT'] = 0

    return val_loss, metrics, selected_indices

def train_model(model, train_loader, val_loader, optimizer, config, num_train_graphs, num_val_graphs):
    """
    Main training loop.
    Iterates over epochs, runs validation, saves the best model, and logs history.
    """
    print(f"Starting training on {config.DEVICE}...")

    # Initialize history dictionary
    history = {
        'train_loss': [], 'val_loss': [],
        'RMSE': [], 'MAE': [],
        'R2_90PCT': [], 'R2_50PCT': [], 'R2_10PCT': []
    }

    best_val_loss = float('inf')

    for epoch in range(1, config.NUM_EPOCHS + 1):
        train_loss = train_epoch(model, train_loader, optimizer, config.DEVICE, config.GRADIENT_CLIP, num_train_graphs)
        val_loss, metrics, index = validate(model, val_loader, config.DEVICE, num_val_graphs)

        # Update History
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['RMSE'].append(metrics['RMSE'])
        history['MAE'].append(metrics['MAE'])
        history['R2_90PCT'].append(metrics['R2_90PCT'])
        history['R2_50PCT'].append(metrics['R2_50PCT'])
        history['R2_10PCT'].append(metrics['R2_10PCT'])

        print(f"Ep {epoch:03d} | Val: {val_loss:.4f} | RMSE: {metrics['RMSE']:.3f} | "
              f"R2(10/50/90): {metrics['R2_10PCT']:.2f} / {metrics['R2_50PCT']:.2f} / {metrics['R2_90PCT']:.2f}")

        # Checkpoint: Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), config.BEST_MODEL_PATH)

    return history, best_val_loss, index

def plot_metrics(history, save_path):
    """
    Generates and saves 3 plots:
    1. Training vs Validation Loss
    2. Global Error Metrics (RMSE, MAE)
    3. R2 Score Distribution (10th, 50th, 90th percentiles)
    """
    epochs = range(1, len(history['train_loss']) + 1)

    fig, axs = plt.subplots(3, 1, figsize=(10, 15), sharex=True)

    # 1. Loss Curves
    axs[0].plot(epochs, history['train_loss'], label="Train Loss", color='blue')
    axs[0].plot(epochs, history['val_loss'], label="Val Loss", color='red')
    axs[0].set_ylabel("Huber Loss")
    axs[0].set_title("Training & Validation Loss")
    axs[0].legend()
    axs[0].grid(True, linestyle='--', alpha=0.6)

    # 2. Error Metrics
    axs[1].plot(epochs, history['RMSE'], label="RMSE", color='orange')
    axs[1].plot(epochs, history['MAE'], label="MAE", color='green')
    axs[1].set_ylabel("Error Units")
    axs[1].set_title("Global Error Metrics")
    axs[1].legend()
    axs[1].grid(True, linestyle='--', alpha=0.6)

    # 3. R2 Percentiles
    axs[2].plot(epochs, history['R2_90PCT'], label="R2 90th percentile", linestyle='--', color='purple')
    axs[2].plot(epochs, history['R2_50PCT'], label="R2 Median (50th percentile)", linewidth=2, color='black')
    axs[2].plot(epochs, history['R2_10PCT'], label="R2 10th percentile", linestyle=':', color='brown')

    # Add zero line to highlight negative scores
    axs[2].axhline(y=0, color='gray', linestyle='-', alpha=0.3)

    axs[2].set_ylabel("R2 Score")
    axs[2].set_xlabel("Epochs")
    axs[2].set_title("R2 Score Distribution (Per Geometry)")
    axs[2].legend(loc='lower right')
    axs[2].grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    plt.savefig(save_path)
    print(f"Learning curves saved to {save_path}")

### Main Execution Pipeline

This cell generates `main.py`, which serves as the orchestrator for our experiment. It integrates the modular components defined previously (Config, Data, Model, Training).

The `main()` function performs the following steps:
1.  **Initialization:** Instantiates the configuration and loads the dataset.
2.  **Setup:** Builds the model and initializes the **Adam optimizer**.
3.  **Execution:** Runs the training loop and generates the learning curves.
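4.  **Returns:** It returns the trained `model` (with its final-epoch weights; the best checkpoint is saved separately to `best_model.pt`), the `val_graph` list of validation graphs, and the `index` list holding the indices of the best, median, and worst validation predictions from the last epoch. These are needed for the **qualitative analysis** in the next section.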

In [None]:
%%writefile main.py
import torch
from config import Config
from data_loading import load_data
from model import create_model
from training import train_model, plot_metrics

def main():
    """
    Orchestrates the entire machine learning pipeline:
    1. Configuration setup
    2. Data loading
    3. Model initialization
    4. Training loop execution
    5. Result visualization
    """
    # 1. Initialize configuration parameters
    config = Config()

    # 2. Load datasets and create DataLoaders
    # Returns loaders for batching and raw lists of graphs for analysis
    train_loader, val_loader, num_train, num_val, train_graph, val_graph = load_data(config)

    # 3. Initialize the Graph Neural Network model
    model = create_model(config)

    # 4. Define the Optimizer (Adam)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY
    )

    # 5. Execute the training loop
    # Returns:
    # - history: Dictionary containing loss and metrics over epochs
    # - best_val_loss: The lowest validation loss achieved
    # - index: List of indices for [Best, Median, Worst] validation cases
    history, best_val_loss, index = train_model(
        model, train_loader, val_loader, optimizer, config, num_train, num_val
    )

    # 6. Generate and save training curves (Loss, RMSE, R2)
    plot_metrics(history, config.TRAINING_CURVE_PATH)

    # Return key objects for the subsequent visualization steps in the notebook
    return model, index, val_graph

### Execution

This cell runs the entire pipeline defined in `main.py`. It trains the model, plots the learning curves, and returns the best, median, and worst case indices needed for the qualitative analysis in the next section.

In [None]:
from main import main


model, index, val_graph = main()

### Qualitative Analysis: 3D Stress Field Visualization

This cell performs the final visual inspection of the model's predictions. We iterate through the specific validation cases identified earlier: **Best, Median, and Worst** performance.

**Key Feature: Independent Color Scaling**
The `visualize_side_by_side_independent` function allows the color map to scale dynamically for each plot:
* **Left (Target):** Ground truth from FEM.
* **Right (Prediction):** GNN output.

**Why independent scales?**
This lets us check whether the model has learned the correct **stress distribution patterns** (the location and shape of the hotspots), even if the absolute **magnitudes** are underestimated (regression models tend to smooth out sharp peaks). Comparing magnitudes directly would instead require a shared color scale.
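
The effect of independent scaling can be sketched with a simple min-max normalization, which is roughly what a color map does when `cmin` and `cmax` are taken from each field separately. The helper `to_unit_range` and the stress values are made up for illustration, and the sketch assumes each field contains at least two distinct values.

```python
def to_unit_range(values):
    """Map values linearly onto [0, 1], as a color map does between cmin and cmax."""
    lowest, highest = min(values), max(values)
    return [(v - lowest) / (highest - lowest) for v in values]


# Hypothetical nodal stresses: the prediction has the right pattern at half the magnitude.
target = [10.0, 20.0, 40.0]
prediction = [5.0, 10.0, 20.0]

target_colors = to_unit_range(target)
prediction_colors = to_unit_range(prediction)

print(target_colors)
print(prediction_colors)
```

Both fields map to the same positions on the color scale, so the two plots would look identical even though the predicted peak is only half the true one. The maxima printed in the subplot titles are what reveal the difference in magnitude.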

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
from config import Config
import sys
import importlib

# Import the main module and reload it to ensure latest changes are used
import main
importlib.reload(main)

# --- 1. SIDE-BY-SIDE VISUALIZATION FUNCTION (AUTO-SCALE) ---
def visualize_side_by_side_independent(target_data, pred_data, index):
    """
    Displays the ground truth and the prediction side by side.
    Each plot uses its OWN color scale to visualize patterns effectively,
    independently of the value magnitude.
    """

    # Face verification
    if not hasattr(target_data, 'face') or target_data.face is None:
        print("⚠️ No faces detected.")
        return

    # --- Data Extraction ---
    x = target_data.x[:, 0].cpu().numpy()
    y = target_data.x[:, 1].cpu().numpy()
    z = target_data.x[:, 2].cpu().numpy()
    faces = target_data.face.cpu().numpy()
    i, j, k = faces[0], faces[1], faces[2]

    val_target = target_data.y[:, 0].cpu().numpy()
    val_pred = pred_data.y[:, 0].cpu().numpy()

    # --- Figure Creation ---
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}]],
        subplot_titles=(
            f"TARGET (Max: {val_target.max():.2f})",
            f"PREDICTION (Max: {val_pred.max():.2f})"
        ),
        horizontal_spacing=0.05
    )

    # --- Trace 1: Left (Target) ---
    fig.add_trace(
        go.Mesh3d(
            x=x, y=y, z=z, i=i, j=j, k=k,
            intensity=val_target,
            intensitymode='vertex',
            colorscale='Jet',
            # HERE: Force scale based on TARGET data only
            cmin=val_target.min(),
            cmax=val_target.max(),
            opacity=1.0,
            name='Target',
            colorbar=dict(title="Target", x=0.45, len=0.8)
        ),
        row=1, col=1
    )

    # --- Trace 2: Right (Prediction) ---
    fig.add_trace(
        go.Mesh3d(
            x=x, y=y, z=z, i=i, j=j, k=k,
            intensity=val_pred,
            intensitymode='vertex',
            colorscale='Jet',
            # HERE: Force scale based on PREDICTION data only
            cmin=val_pred.min(),
            cmax=val_pred.max(),
            opacity=1.0,
            name='Prediction',
            colorbar=dict(title="Pred", x=1.0, len=0.8) # Second colorbar, placed at the right edge
        ),
        row=1, col=2
    )

    # --- Layout ---
    fig.update_layout(
        title_text=f"Stress Pattern Comparison (Part {index}) - Independent Scales",
        height=600, width=1200,
        margin=dict(r=0, b=0, l=0, t=50),
        scene=dict(aspectmode='data'),
        scene2=dict(aspectmode='data')
    )

    fig.show()

# --- 2. EXECUTION ---

model.eval()  # Inference mode; set once before the loop
for i in index:
    print(f"--- Comparison (Decoupled Scales) for Part #{i} ---")

    # Fetch the validation graph and run the model on it
    target_graph = val_graph[i].clone().to(Config.DEVICE)
    with torch.no_grad():
        pred_tensor = model(target_graph.x, target_graph.edge_index, batch=None)

    # Copy the graph and store the prediction in place of the target values
    prediction_graph = target_graph.clone()
    prediction_graph.y = pred_tensor

    # Display target and prediction side by side
    visualize_side_by_side_independent(target_graph, prediction_graph, i)