## Transformer GNN Training Pipeline

### Drive Mount & Installation

We mount Google Drive to access the data and install the required **PyTorch Geometric** libraries for our Transformer GNN implementation.

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.5.0+cu121.html
!pip install torch-geometric

### Graph Transformer Configuration

This cell defines the hyperparameters specifically tuned for our **Graph Transformer** architecture.

**Key Differences from previous models:**
* **`HEADS = 4`**: We use Multi-Head Attention to allow the model to focus on different parts of the neighborhood simultaneously.
* **`EDGE_DIM = 4`**: The dimension of the edge features (3-component relative position + scalar distance) that feed the attention mechanism. It must match the `edge_attr` tensor computed in the training loop.
* **`HIDDEN_CHANNELS = 128`**: A wider hidden dimension than in the simpler models. It is split evenly across the attention heads (128 / 4 = 32 channels per head).
* **`NUM_EPOCHS = 100`**: A longer schedule, since attention-based models often need more training steps to converge to a stable solution.
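
The hyperparameters above depend on each other, and a mismatch only shows up later as a shape error inside the model. The following is a minimal sketch of a consistency check, not part of the original pipeline: `DemoConfig`, `check_config`, and the `pos_dim=3` default are hypothetical names and assumptions introduced for illustration.

```python
class DemoConfig:
    """Hypothetical stand-in that mirrors the relevant fields of Config."""
    HIDDEN_CHANNELS = 128
    HEADS = 4
    EDGE_DIM = 4
    NUM_LAYERS = 5


def check_config(cfg, pos_dim=3):
    """Return a list of problems; an empty list means the values are consistent.

    pos_dim is an assumption: the number of spatial coordinates per node.
    """
    problems = []
    # The model uses HIDDEN_CHANNELS // HEADS channels per head and then
    # normalizes over HIDDEN_CHANNELS, so the division must be exact.
    if cfg.HIDDEN_CHANNELS % cfg.HEADS != 0:
        problems.append("HIDDEN_CHANNELS must be divisible by HEADS")
    # Edge features are the relative position vector plus one distance value.
    if cfg.EDGE_DIM != pos_dim + 1:
        problems.append("EDGE_DIM must equal pos_dim + 1")
    # The stack always contains an input layer and an output layer.
    if cfg.NUM_LAYERS < 2:
        problems.append("NUM_LAYERS must be at least 2")
    return problems


print(check_config(DemoConfig))  # prints []
```

Running such a check before building the model turns a late tensor-shape failure into an early, readable message.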

In [None]:
%%writefile config.py
import torch

class Config:
    """
    Configuration for the Graph Transformer experiment.
    """
    
    # --- Data Paths ---
    TRAIN_DATA_PATH = '/content/gdrive/MyDrive/DataSetML4/data/train_data.pt'
    VAL_DATA_PATH = '/content/gdrive/MyDrive/DataSetML4/data/val_data.pt'

    # --- Training Settings ---
    TRAIN_SUBSET_RATIO = 1.0    # Fraction of training graphs to use (1.0 = full dataset)
    BATCH_SIZE = 32             # Reduced batch size due to higher memory usage of Transformers

    # --- Graph Transformer Architecture ---
    IN_CHANNELS = 12            # Input node features
    HIDDEN_CHANNELS = 128       # Larger hidden dimension for better expressivity
    OUT_CHANNELS = 1            # Target: Scalar stress value
    
    # Transformer-specific parameters:
    HEADS = 4                   # Number of attention heads (Multi-Head Attention)
    CONCAT = True               # Concatenate attention head outputs (vs. averaging)
    BETA = True                 # Enable the gated skip connection (learned blend of root and aggregated features)
    EDGE_DIM = 4                # Dimension of edge features (relative pos + distance)
    NUM_LAYERS = 5              # Deeper network to capture long-range dependencies

    # --- Optimization ---
    NUM_EPOCHS = 100            # Increased epochs as Transformers take longer to converge
    LEARNING_RATE = 0.0001
    WEIGHT_DECAY = 1e-5
    GRADIENT_CLIP = 1.0

    # --- Saving ---
    BEST_MODEL_PATH = 'best_model.pt'
    TRAINING_CURVE_PATH = 'training_curve.png'

    # --- Hardware ---
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Data Loading Compatibility

Although the dataset is already normalized, the saved `.pt` files contain instances of `UnitGaussianNormalizer`. We define this class in `utils.py` so that `torch.load` can deserialize the files without raising an `AttributeError`.
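
For readers unfamiliar with the transform, the `encode`/`decode` pair of the class below is a plain z-score round trip. The sketch here reproduces it with the standard library on a flat list; the function names and sample values are illustrative assumptions, and the real class operates on tensors along dimension 0.

```python
import statistics


def fit(values, eps=1e-5):
    """Compute mean and (sample) standard deviation, with eps added for stability."""
    mean = statistics.fmean(values)
    # statistics.stdev uses the n-1 denominator, like torch.std by default.
    std = statistics.stdev(values) + eps
    return mean, std


def encode(values, mean, std):
    """Normalize: subtract the mean, divide by the standard deviation."""
    return [(v - mean) / std for v in values]


def decode(values, mean, std):
    """Invert the normalization to recover the original units."""
    return [v * std + mean for v in values]


raw = [10.0, 12.0, 14.0, 20.0]        # made-up sample values
mean, std = fit(raw)
normalized = encode(raw, mean, std)
restored = decode(normalized, mean, std)
print(normalized)
print(restored)
```

Because `eps` is added once when the statistics are fitted, `decode` inverts `encode` up to floating-point rounding.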


In [None]:
%%writefile utils.py
import torch

class UnitGaussianNormalizer:
    """
    Applies Unit Gaussian Normalization (Z-score standardization) to tensors.
    """

    def __init__(self, x=None, eps=1e-5):
        """
        Calculates and stores the mean and standard deviation of the input data 'x'.
        """
        self.eps = eps
        if x is not None:
            # Calculate statistics across the 0-th dimension (samples)
            self.mean = x.mean(dim=0, keepdim=True)
            self.std = x.std(dim=0, keepdim=True) + eps
        else:
            self.mean = None
            self.std = None

    def encode(self, x):
        """
        Normalizes the input x (subtract mean, divide by std).
        """
        if self.mean is None:
            return x
        return (x - self.mean) / self.std

    def decode(self, x, sample_idx=None):
        """
        Un-normalizes x (multiply by std, add mean) to recover original units.
        """
        if self.mean is None:
            return x
        return x * self.std + self.mean

### Model Definition: Graph Transformer

This cell defines our custom `StressTransformer` architecture. Unlike the U-Net, which relies on pooling, this model stacks **`TransformerConv`** layers at full mesh resolution.

**Key Architectural Features:**
1.  **Edge-Conditioned Attention:** The `edge_dim` parameter allows the model to use the geometric edge features (distances, relative vectors) to calculate attention scores. This means the model learns to "pay attention" to nodes based on their physical spatial relationship, not just their connectivity.
2.  **Layer Normalization & GELU:** We use standard Transformer components (`LayerNorm`, `GELU`) which stabilize training for deeper networks.
3.  **Multi-Head Aggregation:** The input and hidden layers concatenate multiple attention heads to capture diverse features, while the final layer averages them (`concat=False`) to produce the single stress scalar.
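
To make the head arithmetic concrete, the sketch below lists the widths of each layer for the values in `Config`. It is a plain-Python walk-through of the shapes only (no tensors are created), and `layer_widths` is a hypothetical helper introduced for illustration.

```python
def layer_widths(in_channels, hidden, out_channels, heads, num_layers):
    """Return (input_width, per_head_width, output_width, concat) per layer."""
    per_head = hidden // heads
    # Input layer: heads are concatenated, so the output is heads * per_head.
    layers = [(in_channels, per_head, heads * per_head, True)]
    # Hidden layers keep the width constant at `hidden`.
    for _ in range(num_layers - 2):
        layers.append((hidden, per_head, heads * per_head, True))
    # Output layer: heads are averaged, so the width is just out_channels.
    layers.append((hidden, out_channels, out_channels, False))
    return layers


for row in layer_widths(in_channels=12, hidden=128, out_channels=1,
                        heads=4, num_layers=5):
    print(row)
# (12, 32, 128, True)
# (128, 32, 128, True)
# (128, 32, 128, True)
# (128, 32, 128, True)
# (128, 1, 1, False)
```

With `NUM_LAYERS = 5` the stack therefore holds one input layer, three hidden layers, and one output layer.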

In [None]:
%%writefile model.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, LayerNorm


class StressTransformer(torch.nn.Module):
    """
    Graph Transformer Network for Stress Prediction.
    Uses Multi-Head Attention to capture long-range dependencies and geometric relationships.
    """
    def __init__(self, config):
        super().__init__()

        self.num_layers = config.NUM_LAYERS
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        # --- Input Layer ---
        # Projects input features to hidden dimension.
        # Note: When concat=True, the output dimension is heads * out_channels.
        # So we divide config.HIDDEN_CHANNELS by config.HEADS to keep the total width constant.
        self.convs.append(TransformerConv(
            in_channels=config.IN_CHANNELS,
            out_channels=config.HIDDEN_CHANNELS // config.HEADS,
            heads=config.HEADS,
            concat=config.CONCAT,
            beta=config.BETA,
            edge_dim=config.EDGE_DIM # Crucial: Incorporates edge features (distance/relative pos) into attention
        ))
        self.norms.append(LayerNorm(config.HIDDEN_CHANNELS))


        # --- Hidden Layers ---
        # Stack of Transformer blocks
        for _ in range(config.NUM_LAYERS - 2):
            self.convs.append(TransformerConv(
                in_channels=config.HIDDEN_CHANNELS,
                out_channels=config.HIDDEN_CHANNELS // config.HEADS,
                heads=config.HEADS,
                concat=config.CONCAT,
                beta=config.BETA,
                edge_dim=config.EDGE_DIM
            ))
            self.norms.append(LayerNorm(config.HIDDEN_CHANNELS))


        # --- Output Layer ---
        # Projects back to scalar stress value.
        # We set concat=False to average the heads and get a single output per node.
        self.out_conv = TransformerConv(
            in_channels=config.HIDDEN_CHANNELS,
            out_channels=config.OUT_CHANNELS,      # Output size: 1 (Von Mises Stress)
            heads=config.HEADS,
            concat=False,       # Average the attention heads instead of concatenating
            edge_dim=config.EDGE_DIM,
            beta=config.BETA,
            bias=False
        )

    def forward(self, x, edge_index, edge_attr):
        """
        Forward pass.
        Args:
            x: Node features [Num_Nodes, In_Channels]
            edge_index: Graph connectivity [2, Num_Edges]
            edge_attr: Edge features [Num_Edges, Edge_Dim] (relative vector + distance)
        """
        # Input and hidden layers: attention conv -> LayerNorm -> GELU.
        # The only skip connection is the gated one inside TransformerConv
        # (root weight blended via beta); no extra residual is added here.
        # Note: PyG's LayerNorm defaults to mode='graph'; called without a batch
        # vector, it normalizes over all nodes of the mini-batch together.
        for i in range(self.num_layers - 1):
            x = self.convs[i](x, edge_index, edge_attr)
            x = self.norms[i](x)
            x = F.gelu(x)  # GELU is the usual Transformer activation (smoother than ReLU)

        # Final prediction layer (no activation for regression)
        x = self.out_conv(x, edge_index, edge_attr)

        return x


def create_model(config):
    """
    Factory function to instantiate the StressTransformer.
    """
    model = StressTransformer(config)
    model = model.to(config.DEVICE)
    return model

### Data Loading & Deserialization

This cell generates `data_loading.py`. The dataset is stored in PyTorch Geometric's collated format, a `(data, slices)` tuple of concatenated tensors plus slice indices, rather than as a simple list of `Data` objects.
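
The idea behind this storage format can be shown with plain lists. The sketch below is an illustration only: the toy `data` and `slices` values are made up, and `get_item` is a hypothetical list-based analogue of the tensor slicing done in `get_graph_from_tuple`.

```python
def get_item(data, slices, idx, column_sliced=("edge_index",)):
    """Cut graph number `idx` out of the concatenated storage."""
    item = {}
    for key, bounds in slices.items():
        start, end = bounds[idx], bounds[idx + 1]
        if key in column_sliced:
            # Shape [2, num_edges]: slice along the second axis.
            item[key] = [row[start:end] for row in data[key]]
        else:
            # Shape [num_nodes, ...]: slice along the first axis.
            item[key] = data[key][start:end]
    return item


# Two toy graphs stored back to back: 2 nodes / 2 edges, then 3 nodes / 3 edges.
data = {
    "x": [[0.1], [0.2], [0.3], [0.4], [0.5]],
    "edge_index": [[0, 1, 0, 1, 2],
                   [1, 0, 1, 2, 0]],
}
slices = {"x": [0, 2, 5], "edge_index": [0, 2, 5]}

num_graphs = len(slices["x"]) - 1
print(num_graphs)                  # 2
print(get_item(data, slices, 1))
# {'x': [[0.3], [0.4], [0.5]], 'edge_index': [[0, 1, 2], [1, 2, 0]]}
```

Each `slices` entry has one more element than there are graphs, which is why the loader computes the graph count as the slice length minus one.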

Key implementation details:
1.  **Module Patching:** We manually inject `UnitGaussianNormalizer` into `sys.modules`. This is a necessary hack to fix a serialization issue where `torch.load` fails to find the custom class definition used when the dataset was originally saved.
2.  **Reconstruction:** The `get_graph_from_tuple` function slices the massive concatenated tensors back into individual `Data` objects.
3.  **`weights_only=False`**: We explicitly allow full unpickling so that the custom normalizer objects embedded in the file can be restored. This executes arbitrary pickle code, so it should only be used on files from a trusted source.
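
The module patching in point 1 relies on how `pickle` locates classes: it stores only the module and class name and looks them up in `sys.modules` at load time. The sketch below reproduces the failure and the fix with the standard library alone. The module name `legacy_utils` and the simplified class body are hypothetical, chosen so the demo does not interfere with the real `utils` module.

```python
import pickle
import sys
import types


class UnitGaussianNormalizer:
    """Simplified stand-in for the real normalizer class."""
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std


# Simulate the original save: the class belonged to a module 'legacy_utils'.
UnitGaussianNormalizer.__module__ = "legacy_utils"
original_module = types.ModuleType("legacy_utils")
original_module.UnitGaussianNormalizer = UnitGaussianNormalizer
sys.modules["legacy_utils"] = original_module
payload = pickle.dumps(UnitGaussianNormalizer(mean=2.0, std=0.5))

# Simulate a new session in which the module exists but lacks the class.
sys.modules["legacy_utils"] = types.ModuleType("legacy_utils")
try:
    pickle.loads(payload)
    failed = False
except AttributeError as err:
    failed = True
    print(f"Load failed: {err}")

# The workaround: attach the class to the module the pickle refers to.
sys.modules["legacy_utils"].UnitGaussianNormalizer = UnitGaussianNormalizer
restored = pickle.loads(payload)
print(restored.mean, restored.std)  # 2.0 0.5
```

The same lookup happens inside `torch.load` when `weights_only=False`, which is why registering the class under the expected module name is sufficient.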

In [None]:
%%writefile data_loading.py
import torch
import sys
import types
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from utils import UnitGaussianNormalizer

# --- CRITICAL WORKAROUND ---
# The dataset was pickled with a reference to 'utils.UnitGaussianNormalizer'.
# To prevent an AttributeError during torch.load, we manually inject this class
# into sys.modules so the unpickler can find the definition.
sys.modules['utils'] = types.ModuleType('utils')
sys.modules['utils'].UnitGaussianNormalizer = UnitGaussianNormalizer

def get_graph_from_tuple(data, slices, idx):
    """
    Reconstructs a single Data object from the concatenated storage format.

    Args:
        data: The monolithic object containing features for all graphs.
        slices: Dictionary defining the start/end indices for each graph's features.
        idx: Index of the graph to retrieve.
    """
    data_dict = {}
    for key in slices.keys():
        start, end = slices[key][idx].item(), slices[key][idx + 1].item()

        # Handle 2D tensors (like edge_index and face) differently
        if key in ['edge_index', 'face'] and data[key].dim() == 2:
            data_dict[key] = data[key][:, start:end]
        else:
            data_dict[key] = data[key][start:end]
    return Data(**data_dict)

def load_data(config):
    """
    Loads training and validation data, reconstructs graph objects, and creates DataLoaders.
    """
    print("Loading data...")

    # Load Train Data
    # weights_only=False lets torch.load unpickle custom objects (the normalizer).
    # This runs arbitrary pickle code, so only use it on files from a trusted source.
    train_dataset_tuple = torch.load(config.TRAIN_DATA_PATH, weights_only=False)
    train_data, train_slices = train_dataset_tuple

    # --- Subset Logic ---
    # Reduces dataset size based on TRAIN_SUBSET_RATIO (useful for quick tests)
    num_graphs = train_slices['x'].size(0) - 1
    keep = int(num_graphs * config.TRAIN_SUBSET_RATIO)

    new_slices = {}
    for key in train_slices.keys():
        new_slices[key] = train_slices[key][:keep+1].clone()
    train_slices = new_slices

    # Load Validation Data
    test_dataset_tuple = torch.load(config.VAL_DATA_PATH, weights_only=False)
    test_data, test_slices = test_dataset_tuple

    # Reconstruct individual graph objects from the storage tensors
    # This loop converts the efficient storage format back into a list of Data objects
    train_graphs = [get_graph_from_tuple(train_data, train_slices, i)
                    for i in range(train_slices['x'].size(0) - 1)]
    val_graphs = [get_graph_from_tuple(test_data, test_slices, i)
                  for i in range(test_slices['x'].size(0) - 1)]

    # Create PyTorch Geometric DataLoaders
    # These handle the dynamic batching of graphs (diagonal adjacency stacking)
    train_loader = DataLoader(train_graphs, batch_size=config.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_graphs, batch_size=config.BATCH_SIZE, shuffle=False)

    return train_loader, val_loader, len(train_graphs), len(val_graphs), train_graphs, val_graphs

### Dynamic Training Loop with Geometric Features

This cell generates `training.py`.

**Important Difference for Transformers:**
`TransformerConv` can run on connectivity alone, but this model is built with `edge_dim` set, so every forward pass must receive geometric edge features (`edge_attr`), which enter the attention scores.
* **On-the-fly computation:** Inside `train_epoch` and `validate`, we dynamically calculate the **relative position vectors** and **Euclidean distances** between connected nodes.
* **Input:** These features are concatenated and passed to the `model` forward pass alongside node features (`batch.x`) and connectivity (`batch.edge_index`).
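
As a worked example of the edge features described above, the sketch below computes them for a three-node toy graph using plain Python lists in place of tensors. The coordinates and the helper name `edge_features` are made up for illustration; the training code performs the same arithmetic with `torch` on `batch.pos`.

```python
import math


def edge_features(pos, edge_index):
    """Return one [dx, dy, dz, distance] row per edge."""
    src, dst = edge_index
    rows = []
    for s, d in zip(src, dst):
        # Relative position: source coordinates minus destination coordinates.
        rel = [a - b for a, b in zip(pos[s], pos[d])]
        # Euclidean length of the relative vector.
        rows.append(rel + [math.hypot(*rel)])
    return rows


pos = [[0.0, 0.0, 0.0],
       [3.0, 4.0, 0.0],
       [3.0, 4.0, 12.0]]
edge_index = [[0, 1],   # source nodes
              [1, 2]]   # destination nodes

for row in edge_features(pos, edge_index):
    print(row)
# [-3.0, -4.0, 0.0, 5.0]
# [0.0, 0.0, -12.0, 12.0]
```

Each row has four entries, which is where `EDGE_DIM = 4` comes from. The relative vector is direction-dependent, so the reverse edge would carry the negated vector with the same distance.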

In [None]:
%%writefile training.py
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

def train_epoch(model, train_loader, optimizer, device, gradient_clip, num_train_graphs):
    """
    Performs one training epoch for the Transformer.
    Crucial: Computes edge features (relative pos + distance) on the fly.
    """
    model.train()
    train_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # --- Dynamic Edge Feature Computation ---
        # The Transformer needs geometric info on edges.
        # We calculate relative positions and Euclidean distances between connected nodes.
        src, dst = batch.edge_index
        rel_pos = batch.pos[src] - batch.pos[dst]            # Relative position vector (dx, dy, dz)
        dist = torch.norm(rel_pos, dim=1, keepdim=True)      # Euclidean distance (scalar)
        edge_attr = torch.cat([rel_pos, dist], dim=1)        # Concatenate -> [Edges, 4]

        # Forward pass with edge attributes
        out = model(batch.x, batch.edge_index, edge_attr)

        loss = F.huber_loss(out.view(-1), batch.y.view(-1), delta=0.5)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clip)
        optimizer.step()
        train_loss += loss.item() * batch.num_graphs

    return train_loss / num_train_graphs

def validate(model, val_loader, device, num_val_graphs):
    """
    Evaluates the Transformer model.
    Computes edge features dynamically for validation graphs.
    """
    model.eval()
    val_loss = 0

    all_preds_global = []
    all_targets_global = []
    graph_scores_with_indices = []
    global_idx_offset = 0

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)

            # --- Dynamic Edge Feature Computation (Same as training) ---
            src, dst = batch.edge_index
            rel_pos = batch.pos[src] - batch.pos[dst]
            dist = torch.norm(rel_pos, dim=1, keepdim=True)
            edge_attr = torch.cat([rel_pos, dist], dim=1)

            out = model(batch.x, batch.edge_index, edge_attr)

            loss = F.huber_loss(out.view(-1), batch.y.view(-1), delta=0.5)
            val_loss += loss.item() * batch.num_graphs

            preds_cpu = out.view(-1).cpu().numpy()
            targets_cpu = batch.y.view(-1).cpu().numpy()
            batch_indices = batch.batch.cpu().numpy()

            all_preds_global.append(preds_cpu)
            all_targets_global.append(targets_cpu)

            # R2 Calculation per graph
            unique_graphs = np.unique(batch_indices)
            for graph_id in unique_graphs:
                mask = (batch_indices == graph_id)
                graph_targets = targets_cpu[mask]
                graph_preds = preds_cpu[mask]

                if len(graph_targets) > 1:
                    score = r2_score(graph_targets, graph_preds)
                    real_global_index = global_idx_offset + graph_id
                    graph_scores_with_indices.append((real_global_index, score))

            global_idx_offset += batch.num_graphs

    val_loss /= num_val_graphs

    metrics = {}
    selected_indices = [None, None, None]

    if len(all_preds_global) > 0:
        final_preds = np.concatenate(all_preds_global)
        final_targets = np.concatenate(all_targets_global)

        metrics['RMSE'] = np.sqrt(mean_squared_error(final_targets, final_preds))
        metrics['MAE'] = mean_absolute_error(final_targets, final_preds)

        if len(graph_scores_with_indices) > 0:
            indices_arr = np.array([item[0] for item in graph_scores_with_indices])
            scores_arr = np.array([item[1] for item in graph_scores_with_indices])

            metrics['R2_90PCT'] = np.percentile(scores_arr, 90)
            metrics['R2_50PCT'] = np.percentile(scores_arr, 50)
            metrics['R2_10PCT'] = np.percentile(scores_arr, 10)

            idx_best = indices_arr[np.argmax(scores_arr)]
            idx_worst = indices_arr[np.argmin(scores_arr)]
            mean_score = np.mean(scores_arr)
            idx_closest_to_mean = indices_arr[(np.abs(scores_arr - mean_score)).argmin()]

            selected_indices = [int(idx_best), int(idx_closest_to_mean), int(idx_worst)]
        else:
            metrics['R2_90PCT'] = 0
            metrics['R2_50PCT'] = 0
            metrics['R2_10PCT'] = 0

    return val_loss, metrics, selected_indices

def train_model(model, train_loader, val_loader, optimizer, config, num_train_graphs, num_val_graphs):
    print(f"Starting training on {config.DEVICE}...")

    history = {
        'train_loss': [], 'val_loss': [],
        'RMSE': [], 'MAE': [],
        'R2_90PCT': [], 'R2_50PCT': [], 'R2_10PCT': []
    }

    best_val_loss = float('inf')

    for epoch in range(1, config.NUM_EPOCHS + 1):
        train_loss = train_epoch(model, train_loader, optimizer, config.DEVICE, config.GRADIENT_CLIP, num_train_graphs)
        val_loss, metrics, index = validate(model, val_loader, config.DEVICE, num_val_graphs)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['RMSE'].append(metrics['RMSE'])
        history['MAE'].append(metrics['MAE'])
        history['R2_90PCT'].append(metrics['R2_90PCT'])
        history['R2_50PCT'].append(metrics['R2_50PCT'])
        history['R2_10PCT'].append(metrics['R2_10PCT'])

        print(f"Ep {epoch:03d} | Val: {val_loss:.4f} | RMSE: {metrics['RMSE']:.3f} | "
              f"R2(10/50/90): {metrics['R2_10PCT']:.2f} / {metrics['R2_50PCT']:.2f} / {metrics['R2_90PCT']:.2f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), config.BEST_MODEL_PATH)

    # 'index' holds [best, closest-to-mean, worst] graph indices from the final epoch
    return history, best_val_loss, index

def plot_metrics(history, save_path):
    epochs = range(1, len(history['train_loss']) + 1)
    fig, axs = plt.subplots(3, 1, figsize=(10, 15), sharex=True)

    axs[0].plot(epochs, history['train_loss'], label="Train Loss", color='blue')
    axs[0].plot(epochs, history['val_loss'], label="Val Loss", color='red')
    axs[0].set_ylabel("Huber Loss")
    axs[0].set_title("Training & Validation Loss")
    axs[0].legend(); axs[0].grid(True, linestyle='--', alpha=0.6)

    axs[1].plot(epochs, history['RMSE'], label="RMSE", color='orange')
    axs[1].plot(epochs, history['MAE'], label="MAE", color='green')
    axs[1].set_ylabel("Error Units")
    axs[1].set_title("Global Error Metrics")
    axs[1].legend(); axs[1].grid(True, linestyle='--', alpha=0.6)

    axs[2].plot(epochs, history['R2_90PCT'], label="R2 Best (90%)", linestyle='--', color='purple')
    axs[2].plot(epochs, history['R2_50PCT'], label="R2 Median (50%)", linewidth=2, color='black')
    axs[2].plot(epochs, history['R2_10PCT'], label="R2 Worst (10%)", linestyle=':', color='brown')
    axs[2].axhline(y=0, color='gray', linestyle='-', alpha=0.3)
    axs[2].set_ylabel("R2 Score")
    axs[2].set_xlabel("Epochs")
    axs[2].set_title("R2 Score Distribution (Per Geometry)")
    axs[2].legend(loc='lower right'); axs[2].grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    plt.savefig(save_path)
    print(f"Learning curves saved to {save_path}")

### Main Execution: Graph Transformer

This cell generates `main.py`. It serves as the central entry point that integrates:
1.  The **Transformer Configuration** (defining heads, layers, etc.).
2.  The **`StressTransformer` model**.
3.  The **Dynamic Training Loop** (which handles edge feature computation).

When executed, it trains the model, saves the learning curves, and returns the indices of the best, closest-to-mean, and worst validation graphs (ranked by per-graph R²) for the final visualization.
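
The case selection performed inside `validate` can be summarized in a few lines. The sketch below is a plain-Python restatement with made-up scores; `select_cases` is a hypothetical name, and the real code does the same with NumPy arrays.

```python
def select_cases(scored):
    """Pick [best, closest-to-mean, worst] graph indices from (index, r2) pairs."""
    if not scored:
        return [None, None, None]
    mean_score = sum(score for _, score in scored) / len(scored)
    best = max(scored, key=lambda pair: pair[1])[0]
    worst = min(scored, key=lambda pair: pair[1])[0]
    # "Typical" case: the graph whose R2 is nearest to the mean R2.
    typical = min(scored, key=lambda pair: abs(pair[1] - mean_score))[0]
    return [best, typical, worst]


# Made-up per-graph R2 scores: (graph index, score)
scores = [(0, 0.95), (1, 0.40), (2, 0.70), (3, -0.20)]
print(select_cases(scores))  # [0, 1, 3]
```

Note that the middle entry is the graph closest to the mean score, which can differ from the median graph when the score distribution is skewed.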

In [None]:
%%writefile main.py
import torch
from config import Config
from data_loading import load_data
from model import create_model
from training import train_model, plot_metrics

def main():
    """
    Orchestrates the Graph Transformer training pipeline.
    Connects the configuration, data loader, custom Transformer model,
    and the training loop.
    """
    # 1. Initialize Transformer-specific configuration
    config = Config()

    # 2. Load Data
    # Returns loaders and raw graph lists (needed for visualization later)
    train_loader, val_loader, num_train, num_val, train_graph, val_graph = load_data(config)

    # 3. Build the Graph Transformer Model
    # Uses the hyperparameters (Heads, Layers, Edge Dim) defined in Config
    model = create_model(config)

    # 4. Setup Optimizer
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY
    )

    # 5. Execute Training
    # The 'train_model' function handles the epoch loop and the dynamic
    # computation of edge features required by the Transformer.
    history, best_val_loss, index = train_model(
        model, train_loader, val_loader, optimizer, config, num_train, num_val
    )

    # 6. Save Learning Curves
    plot_metrics(history, config.TRAINING_CURVE_PATH)

    # Return trained model and validation samples for the 3D visualization step
    return model, index, val_graph

### Execution

This cell runs the entire pipeline defined in `main.py`. It trains the model, plots the learning curves, and returns the best/worst case indices needed for the qualitative analysis in the next section.

In [None]:
from main import main


model, index, val_graph = main()

### Qualitative Analysis: 3D Stress Field Visualization

This cell performs the final visual inspection of the model's predictions. We iterate through the validation graphs identified earlier: the **best**, the one **closest to the mean**, and the **worst** by per-graph R².

**Key Feature: Independent Color Scaling**
The `visualize_side_by_side_independent` function allows the color map to scale dynamically for each plot:
* **Left (Target):** Ground truth from FEM.
* **Right (Prediction):** GNN output.

**Why independent scales?**
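Independent scales let us check whether the model has learned the correct **stress distribution patterns** (the location and shape of the hotspots), even when the absolute **magnitudes** are underestimated. Underestimated peaks are common in regression models, and the smoothing effect of stacked message-passing layers (often called "over-smoothing") can contribute to them.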
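
The effect of independent scaling can be illustrated numerically. In the sketch below, with made-up stress values, the prediction has the right pattern but only 40% of the true magnitude; `rescale` is a hypothetical helper that mimics what a per-plot `cmin`/`cmax` does to the color mapping.

```python
def rescale(values):
    """Map values to [0, 1] using their own min and max (per-plot color range)."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


target = [0.0, 50.0, 100.0, 25.0]   # made-up ground-truth stresses
pred = [0.0, 20.0, 40.0, 10.0]      # same pattern, damped magnitude

print(rescale(target))  # [0.0, 0.5, 1.0, 0.25]
print(rescale(pred))    # [0.0, 0.5, 1.0, 0.25]

# On a shared scale (the target's range) the prediction never exceeds 0.4,
# so its hotspot would be drawn in mid-range colors.
shared = [(v - min(target)) / (max(target) - min(target)) for v in pred]
print(max(shared))      # 0.4
```

The subplot titles still report each field's maximum, so the magnitude gap remains visible even though the colors are decoupled.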

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
import numpy as np
from config import Config
import sys
import importlib

# Import the main module and reload it to ensure latest changes are used
import main
importlib.reload(main)

# --- 1. SIDE-BY-SIDE VISUALIZATION FUNCTION (INDEPENDENT SCALES) ---
def visualize_side_by_side_independent(target_data, pred_data, index):
    """
    Displays Ground Truth and Prediction side-by-side.
    Uses INDEPENDENT color scales for each plot to visualize stress PATTERNS
    effectively, regardless of the difference in absolute magnitude.
    """

    # Check for mesh faces
    if not hasattr(target_data, 'face') or target_data.face is None:
        print("⚠️ No faces detected in mesh data.")
        return

    # --- Extract Mesh Data ---
    x = target_data.x[:, 0].cpu().numpy()
    y = target_data.x[:, 1].cpu().numpy()
    z = target_data.x[:, 2].cpu().numpy()
    faces = target_data.face.cpu().numpy()
    i, j, k = faces[0], faces[1], faces[2]

    val_target = target_data.y[:, 0].cpu().numpy()
    val_pred = pred_data.y[:, 0].cpu().numpy()

    # --- Create Subplots ---
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}]],
        subplot_titles=(
            f"GROUND TRUTH (Max: {val_target.max():.2f})",
            f"PREDICTION (Max: {val_pred.max():.2f})"
        ),
        horizontal_spacing=0.05
    )

    # --- Trace 1: Left (Ground Truth) ---
    fig.add_trace(
        go.Mesh3d(
            x=x, y=y, z=z, i=i, j=j, k=k,
            intensity=val_target,
            intensitymode='vertex',
            colorscale='Jet',
            # INDEPENDENT SCALE: Based on TARGET min/max
            cmin=val_target.min(),
            cmax=val_target.max(),
            opacity=1.0,
            name='GROUND TRUTH',
            colorbar=dict(title="Target Stress", x=0.45, len=0.8)
        ),
        row=1, col=1
    )

    # --- Trace 2: Right (Prediction) ---
    fig.add_trace(
        go.Mesh3d(
            x=x, y=y, z=z, i=i, j=j, k=k,
            intensity=val_pred,
            intensitymode='vertex',
            colorscale='Jet',
            # INDEPENDENT SCALE: Based on PREDICTION min/max
            cmin=val_pred.min(),
            cmax=val_pred.max(),
            opacity=1.0,
            name='PREDICTION',
            colorbar=dict(title="Pred Stress", x=1.0, len=0.8) # Second legend bar
        ),
        row=1, col=2
    )

    # --- Layout ---
    fig.update_layout(
        title_text=f"Von Mises Stress Comparison - Independent Scales (Mesh #{index})",
        height=600, width=1200,
        margin=dict(r=0, b=0, l=0, t=50),
        scene=dict(aspectmode='data'),
        scene2=dict(aspectmode='data')
    )

    fig.show()

# --- 2. EXECUTION ---
# 'index' holds the best, closest-to-mean, and worst validation graphs returned by main()

for i in index:
  print(f"--- Comparison (Decoupled Scales) for Part #{i} ---")

  # Prepare data
  target_graph = val_graph[i].clone().to(Config.DEVICE)
  model.eval()

  with torch.no_grad():
      # --- TRANSFORMER SPECIFIC STEP ---
      # We must re-compute edge attributes (Relative Position + Distance)
      # because the model expects them in the forward pass.
      src, dst = target_graph.edge_index
      rel_pos = target_graph.pos[src] - target_graph.pos[dst]
      dist = torch.norm(rel_pos, dim=1, keepdim=True)
      edge_attr = torch.cat([rel_pos, dist], dim=1)

      # Inference
      pred_tensor = model(target_graph.x, target_graph.edge_index, edge_attr)

  prediction_graph = target_graph.clone()
  prediction_graph.y = pred_tensor

  # Visualize
  visualize_side_by_side_independent(target_graph, prediction_graph, i)