In [122]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import yaml
import sys
import re

# Add the FLARE directory to path
sys.path.insert(0, '/home/vedantpu/.julia/dev/FLARE-dev.py/')

import pdebench
from pdebench.dataset.utils import load_dataset

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Project root plus the directories holding checkpoints and output figures.
PROJDIR = '/home/vedantpu/.julia/dev/FLARE-dev.py'
CHECKPOINT_BASE = '/home/vedantpu/.julia/dev/FLARE-dev.py/out/pdebench/vis_ckpts'
OUTPUT_BASE = '/home/vedantpu/.julia/dev/FLARE-dev.py/figs/vis_out'

# Datasets are read from the dedicated data disk when it exists; otherwise
# from the `data` folder inside the project.
if os.path.exists('/mnt/hdd1/vedantpu/data/'):
    DATADIR_BASE = '/mnt/hdd1/vedantpu/data/'
else:
    DATADIR_BASE = os.path.join(PROJDIR, 'data')

# Datasets to visualize ('airfoil_steady' is left out of the active list).
# datasets_to_visualize = ['elasticity', 'darcy', 'airfoil_steady', 'pipe']
datasets_to_visualize = ['elasticity', 'darcy', 'pipe']


Using device: cuda


In [123]:
#
import math
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange

__all__ = [
    "FLAREModel",
]

#======================================================================#
# Activation Functions
#======================================================================#
# Registry mapping an activation name to the module instance used for it.
# Lookups return the stored instance itself, so every layer asking for the
# same name shares one module object.
ACTIVATIONS = dict(
    gelu=nn.GELU(approximate='tanh'),
    silu=nn.SiLU(),
)

#======================================================================#
# Residual MLP Block
#======================================================================#

class ResidualMLP(nn.Module):
    """MLP whose hidden layers are wrapped in skip connections.

    For ``num_layers >= 0`` the network is ``fc1`` -> ``num_layers`` hidden
    linears -> ``fc2``. An activation follows ``fc1`` and each hidden linear,
    and each hidden linear always sits inside a skip connection. For
    ``num_layers == -1`` the module is a single ``nn.Linear``.

    Args:
        in_dim: Size of the last input dimension.
        hidden_dim: Width of the hidden layers.
        out_dim: Size of the last output dimension.
        num_layers: Number of ``hidden_dim -> hidden_dim`` layers, or ``-1``
            for the single-linear variant.
        act: Key into ``ACTIVATIONS``; a falsy value selects ``'gelu'``.
        input_residual: Ask for a skip connection around ``fc1``; only
            honoured when ``in_dim == hidden_dim``.
        output_residual: Ask for a skip connection around ``fc2``; only
            honoured when ``hidden_dim == out_dim``.

    In the single-linear variant the skip connection is used only when both
    flags are set and ``in_dim == out_dim``.
    """

    def __init__(
            self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int = 2,
            act: str = None, input_residual: bool = False, output_residual: bool = False,
        ):
        super().__init__()

        self.num_layers = num_layers
        assert self.num_layers >= -1, f"num_layers must be at least -1. Got {self.num_layers}."

        if self.num_layers == -1:
            # Single linear map: no activation and no hidden stack are built.
            self.fc = nn.Linear(in_dim, out_dim)
            self.residual = input_residual and output_residual and (in_dim == out_dim)
            return

        self.act = ACTIVATIONS[act or 'gelu']

        # Layers are created in the order fc1 -> hidden stack -> fc2.
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        hidden_stack = []
        for _ in range(num_layers):
            hidden_stack.append(nn.Linear(hidden_dim, hidden_dim))
        self.fcs = nn.ModuleList(hidden_stack)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

        # A requested skip is dropped silently when the shapes do not match.
        self.input_residual = input_residual and (in_dim == hidden_dim)
        self.output_residual = output_residual and (hidden_dim == out_dim)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        if self.num_layers == -1:
            projected = self.fc(x)
            if self.residual:
                return x + projected
            return projected

        hidden = self.act(self.fc1(x))
        if self.input_residual:
            hidden = x + hidden

        for layer in self.fcs:
            hidden = hidden + self.act(layer(hidden))

        out = self.fc2(hidden)
        if self.output_residual:
            out = hidden + out
        return out

#======================================================================#
# FLARE
#======================================================================#
class FLARE(nn.Module):
    def __init__(
        self,
        channel_dim: int,
        num_heads: int = 8,
        num_latents: int = 32,
        attn_scale: float = 1.0,
        act: str = None,
        num_layers_kv_proj: int = 3,
        kv_proj_mlp_ratio: float = 1.0,
    ):
        super().__init__()

        self.channel_dim = channel_dim
        self.num_latents = num_latents
        self.num_heads = channel_dim // 8 if num_heads is None else num_heads
        self.head_dim = self.channel_dim // self.num_heads

        assert self.channel_dim % self.num_heads == 0, f"channel_dim must be divisible by num_heads. Got {self.channel_dim} and {self.num_heads}."
        assert attn_scale > 0.0, f"attn_scale must be greater than 0. Got {attn_scale}."

        self.attn_scale = attn_scale

        self.latent_q = nn.Parameter(torch.empty(self.channel_dim, self.num_latents))
        nn.init.normal_(self.latent_q, mean=0.0, std=0.1)

        self.k_proj, self.v_proj = [
            ResidualMLP(
                in_dim=self.channel_dim,
                hidden_dim=int(self.channel_dim * kv_proj_mlp_ratio),
                out_dim=self.channel_dim,
                num_layers=num_layers_kv_proj,
                act=act,
                input_residual=True,
                output_residual=True,
            ) for _ in range(2)
        ]

        self.out_proj = nn.Linear(self.channel_dim, self.channel_dim)

    def forward(self, x, return_scores: bool = False):

        # x: [B N C]

        q = self.latent_q.view(self.num_heads, self.num_latents, self.head_dim) # [H M D]
        k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) # [B H N D]
        v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)

        #--------------------------------------------#
        if not return_scores:
            q = q.unsqueeze(0).expand(x.size(0), -1, -1, -1) # required for fused attention
            z = F.scaled_dot_product_attention(q, k, v, scale=self.attn_scale)
            y = F.scaled_dot_product_attention(k, q, z, scale=self.attn_scale)
            scores = None
        else:
            # (1) Compute projection weights
            scores = q @ k.transpose(-2, -1) # [B H M N]
            W_encode = F.softmax(scores, dim=-1)
            W_decode = F.softmax(scores.transpose(-2, -1), dim=-1)

            # (2) Project to latent sequence
            z = W_encode @ v # [B H M D]

            # (3) Project back to input space
            y = W_decode @ z # [B H N D]
        #--------------------------------------------#

        y = rearrange(y, 'b h n d -> b n (h d)')
        y = self.out_proj(y)

        return y, scores

#======================================================================#
# FLARE Block
#======================================================================#
class FLAREBlock(nn.Module):
    """Pre-norm residual block: FLARE attention followed by a residual MLP.

    Computes ``x = x + att(ln1(x))`` and then ``x = x + mlp(ln2(x))``.

    Args:
        channel_dim: Feature width of the tokens.
        num_heads: Forwarded unchanged to ``FLARE``.
        num_latents: Forwarded unchanged to ``FLARE``.
        attn_scale: Forwarded unchanged to ``FLARE``.
        act: Activation name used by the attention projections and the MLP.
        rmsnorm: Use ``nn.RMSNorm`` for both norms instead of ``nn.LayerNorm``.
        num_layers_kv_proj: ``num_layers`` of the attention's key/value MLPs.
        num_layers_ffn: ``num_layers`` of the feed-forward MLP.
        kv_proj_mlp_ratio: Hidden-width multiplier of the key/value MLPs.
        ffn_mlp_ratio: Hidden-width multiplier of the feed-forward MLP.
    """

    def __init__(
        self,
        channel_dim: int,
        num_heads: int = None,
        num_latents: int = None,
        attn_scale: float = 1.0,
        act: str = None,
        rmsnorm: bool = False,
        num_layers_kv_proj: int = 3,
        num_layers_ffn: int = 3,
        kv_proj_mlp_ratio: float = 1.0,
        ffn_mlp_ratio: float = 1.0,
    ):
        super().__init__()

        norm_cls = nn.RMSNorm if rmsnorm else nn.LayerNorm
        self.ln1 = norm_cls(channel_dim)
        self.ln2 = norm_cls(channel_dim)

        attention_kwargs = dict(
            channel_dim=channel_dim,
            num_heads=num_heads,
            num_latents=num_latents,
            attn_scale=attn_scale,
            act=act,
            num_layers_kv_proj=num_layers_kv_proj,
            kv_proj_mlp_ratio=kv_proj_mlp_ratio,
        )
        self.att = FLARE(**attention_kwargs)

        ffn_kwargs = dict(
            in_dim=channel_dim,
            hidden_dim=int(channel_dim * ffn_mlp_ratio),
            out_dim=channel_dim,
            num_layers=num_layers_ffn,
            act=act,
            input_residual=True,
            output_residual=True,
        )
        self.mlp = ResidualMLP(**ffn_kwargs)

    def forward(self, x, return_scores: bool = False):
        """Apply the block to ``x`` of shape [B, N, C].

        Returns:
            ``(x, scores)``: the updated tokens and whatever the attention
            layer returned as its second output.
        """
        # Attention sub-layer with its skip connection.
        attn_out, scores = self.att(self.ln1(x), return_scores=return_scores)
        x = x + attn_out

        # Feed-forward sub-layer with its skip connection.
        ffn_out = self.mlp(self.ln2(x))
        x = x + ffn_out

        return x, scores

#======================================================================#
# MODEL
#======================================================================#
class FLAREModel(nn.Module):
    """Stack of ``FLAREBlock``s between an input and an output projection.

    Pipeline: ``in_proj`` (ResidualMLP, ``in_dim -> channel_dim``), then
    ``num_blocks`` FLARE blocks, then ``out_proj['ln']`` (a norm or identity),
    then ``out_proj['mlp']`` (ResidualMLP, ``channel_dim -> out_dim``).

    Args:
        in_dim: Number of input features per token.
        out_dim: Number of output features per token.
        channel_dim: Internal feature width.
        num_blocks: Number of FLARE blocks.
        num_heads: Forwarded to every block.
        act: Activation name used throughout.
        rmsnorm: Use ``nn.RMSNorm`` instead of ``nn.LayerNorm`` everywhere.
        out_proj_norm: Normalise before the output MLP; otherwise identity.
        num_layers_in_out_proj: ``num_layers`` of the input and output MLPs.
        attn_scale: Forwarded to every block.
        num_latents: Forwarded to every block.
        num_layers_kv_proj: Forwarded to every block.
        kv_proj_mlp_ratio: Forwarded to every block.
        num_layers_ffn: Forwarded to every block.
        ffn_mlp_ratio: Forwarded to every block.

    After construction every ``nn.Linear`` and norm layer is re-initialised by
    ``initialize_weights``; parameters held directly by other module types are
    left as their owners initialised them.
    """

    def __init__(self,
        in_dim: int,
        out_dim: int,
        channel_dim: int = 64,
        num_blocks: int = 8,
        num_heads: int = None,
        act: str = None,
        rmsnorm: bool = False,
        out_proj_norm: bool = True,
        num_layers_in_out_proj: int = 2,
        #
        attn_scale: float = 1.0,
        num_latents: int = None,
        num_layers_kv_proj: int = 3,
        kv_proj_mlp_ratio: float = 1.0,
        num_layers_ffn: int = 3,
        ffn_mlp_ratio: float = 1.0,
        #
    ):
        super().__init__()

        # Lift the raw features to the internal width.
        self.in_proj = ResidualMLP(
            in_dim=in_dim,
            hidden_dim=channel_dim,
            out_dim=channel_dim,
            num_layers=num_layers_in_out_proj,
            act=act,
            input_residual=False,
            output_residual=True,
        )

        Norm = nn.RMSNorm if rmsnorm else nn.LayerNorm

        # Output head: (optional) norm followed by an MLP down to out_dim.
        self.out_proj = nn.ModuleDict()
        self.out_proj['ln'] = Norm(channel_dim) if out_proj_norm else nn.Identity()
        self.out_proj['mlp'] = ResidualMLP(
            in_dim=channel_dim,
            hidden_dim=channel_dim,
            out_dim=out_dim,
            num_layers=num_layers_in_out_proj,
            act=act,
            input_residual=True,
            output_residual=False,
        )

        block_kwargs = dict(
            channel_dim=channel_dim,
            num_heads=num_heads,
            act=act,
            rmsnorm=rmsnorm,
            attn_scale=attn_scale,
            num_latents=num_latents,
            num_layers_kv_proj=num_layers_kv_proj,
            num_layers_ffn=num_layers_ffn,
            kv_proj_mlp_ratio=kv_proj_mlp_ratio,
            ffn_mlp_ratio=ffn_mlp_ratio,
        )
        block_list = []
        for _ in range(num_blocks):
            block_list.append(FLAREBlock(**block_kwargs))
        self.blocks = nn.ModuleList(block_list)

        self.initialize_weights()

    def initialize_weights(self):
        """Run ``_init_weights`` on this module and every submodule."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one module: linears ~ N(0, 0.02) with zero bias, norms to identity."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
            return

        if isinstance(m, (nn.LayerNorm, nn.RMSNorm)):
            # A norm may lack a bias (or any affine parameters), so each
            # parameter is set only when it is present.
            for attr, fill in (('weight', 1.), ('bias', 0.)):
                param = getattr(m, attr, None)
                if param is not None:
                    nn.init.constant_(param, fill)

    def forward(self, x, return_scores: bool = False):
        """Run the model on ``x`` of shape [B, N, C].

        Returns:
            The output tensor alone, or ``(output, scores)`` when
            ``return_scores`` is set, where ``scores`` is a list holding one
            entry per block in order.
        """
        collected = [] if return_scores else None

        x = self.in_proj(x)
        for block in self.blocks:
            x, block_scores = block(x, return_scores=return_scores)
            if collected is not None:
                collected.append(block_scores)

        x = self.out_proj['mlp'](self.out_proj['ln'](x))

        if return_scores:
            return x, collected
        return x

#======================================================================#
#

In [124]:
def load_model_from_checkpoint(checkpoint_dir, metadata, device):
    """Build a FLAREModel from a run directory and load its newest weights.

    Args:
        checkpoint_dir: Run directory containing ``config.yaml`` and one or
            more ``ckpt<number>`` sub-directories, each holding ``model.pt``.
        metadata: Mapping providing ``'c_in'`` and ``'c_out'``, used as the
            model's input and output widths.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        ``(model, latest_ckpt)``: the model in eval mode on ``device`` and the
        name of the checkpoint sub-directory that was loaded.

    Raises:
        ValueError: If ``checkpoint_dir`` has no ``ckpt<number>`` entry.
    """
    # Load config
    config_path = os.path.join(checkpoint_dir, 'config.yaml')
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Find latest checkpoint. Only entries named exactly 'ckpt<digits>' count;
    # anything else starting with 'ckpt' (e.g. 'ckpt_best') is ignored rather
    # than crashing the integer parse.
    numbered_ckpts = []
    for entry in os.listdir(checkpoint_dir):
        match = re.fullmatch(r'ckpt(\d+)', entry)
        if match is not None:
            numbered_ckpts.append((int(match.group(1)), entry))
    if not numbered_ckpts:
        raise ValueError(f"No checkpoint directories found in {checkpoint_dir}")
    latest_ckpt = max(numbered_ckpts)[1]
    model_path = os.path.join(checkpoint_dir, latest_ckpt, 'model.pt')

    print(f"Loading model from: {model_path}")

    # Create model
    model = FLAREModel(
        in_dim=metadata['c_in'],
        out_dim=metadata['c_out'],
        channel_dim=config['channel_dim'],
        num_blocks=config['num_blocks'],
        num_latents=config['num_latents'],
        num_heads=config['num_heads'],
        act=config.get('act', None),
        num_layers_kv_proj=config.get('num_layers_kv_proj', 3),
        num_layers_ffn=config.get('num_layers_mlp', 3),
        num_layers_in_out_proj=config.get('num_layers_in_out_proj', 2),
        ffn_mlp_ratio=config.get('mlp_ratio', 1.0),
        kv_proj_mlp_ratio=config.get('kv_proj_ratio', 1.0),
        # in_out_proj_ratio=config.get('in_out_proj_ratio', 1.0),
        out_proj_norm=config.get('out_proj_ln', True),
    )

    # Load checkpoint
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints you trust, or pass weights_only=True if the checkpoint
    # contents allow it.
    checkpoint = torch.load(model_path, map_location=device)

    # Extract model state - check if checkpoint has 'model_state' key
    if 'model_state' in checkpoint:
        model_state = checkpoint['model_state']
    else:
        # If no 'model_state' key, assume the checkpoint itself is the state dict
        model_state = checkpoint

    # Undo wrapper prefixes so the keys match the bare model built above:
    # '_orig_mod.' is inserted by torch.compile, 'module.' by
    # DataParallel / DistributedDataParallel.
    model_state = {k.replace('_orig_mod.', ''): v for k, v in model_state.items()}
    if model_state and all(k.startswith('module.') for k in model_state):
        model_state = {k[len('module.'):]: v for k, v in model_state.items()}

    model.load_state_dict(model_state)
    model.to(device)
    model.eval()

    print(f"Model loaded successfully! Parameters: {sum(p.numel() for p in model.parameters()):,}")
    return model, latest_ckpt


In [125]:
def calculate_rel_l2_error(model, test_data, metadata, device, batch_size=1):
    """Evaluate the model on ``test_data`` and return its mean RelL2 error.

    Predictions and targets are both passed through
    ``metadata['y_normalizer'].decode`` before the loss is evaluated. Each
    batch's loss value is weighted by the number of samples in that batch and
    the total is divided by the number of samples seen.

    Args:
        model: Module mapping a batch of inputs to predictions.
        test_data: Dataset whose items provide the input at index 0 and the
            target at index 1.
        metadata: Mapping that provides ``'y_normalizer'``.
        device: Device the batches and the normalizer are moved to.
        batch_size: Batch size of the evaluation loader.

    Returns:
        The sample-weighted mean error as a Python float.
    """
    model.eval()
    criterion = pdebench.RelL2Loss()
    normalizer = metadata['y_normalizer'].to(device)

    loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

    weighted_sum = 0.0
    samples_seen = 0

    print(f"Calculating RelL2 error on {len(test_data)} test samples...")
    with torch.no_grad():
        for step, batch in enumerate(loader, start=1):
            inputs = batch[0].to(device)
            targets = batch[1].to(device)

            # Compare in the original (decoded) output space.
            prediction = normalizer.decode(model(inputs))
            reference = normalizer.decode(targets)
            batch_error = criterion(prediction, reference)

            count = inputs.shape[0]
            weighted_sum += batch_error.item() * count
            samples_seen += count

            if step % 10 == 0:
                print(f"  Processed {step}/{len(loader)} batches...")

    mean_error = weighted_sum / samples_seen
    print(f"\nRelL2 Error on test set: {mean_error:.6e}")
    return mean_error


In [126]:
def visualize_elasticity(x, y_true, y_pred, sample_idx=0):
    """Plot ground truth, prediction and signed error for one elasticity sample.

    Draws three side-by-side scatter plots over the sample's points. The first
    two panels share one colour range and one colorbar; the error panel
    (``y_true - y_pred``) has its own colorbar.

    Args:
        x: Batched tensor; columns 0 and 1 of ``x[sample_idx]`` are used as
            the scatter coordinates.
        y_true: Batched tensor of reference values.
        y_pred: Batched tensor of predicted values.
        sample_idx: Index of the sample to draw.

    Returns:
        The matplotlib figure.

    Note:
        Three ``plt.rcParams`` entries are modified globally.
    """
    from matplotlib.ticker import FuncFormatter

    # Use matplotlib's built-in mathtext (STIX) with a serif font; external
    # LaTeX rendering stays disabled.
    for rc_key, rc_value in (
        ('text.usetex', False),
        ('mathtext.fontset', 'stix'),
        ('font.family', 'serif'),
    ):
        plt.rcParams[rc_key] = rc_value

    # Pull the requested sample out as numpy arrays.
    points = x[sample_idx].cpu().numpy()  # [N, 2]
    truth = y_true[sample_idx].cpu().numpy().squeeze()  # [N]
    pred = y_pred[sample_idx].cpu().numpy().squeeze()  # [N]
    error = truth - pred

    # Common colour range for the ground-truth and prediction panels.
    shared_vmin = min(truth.min(), pred.min())
    shared_vmax = max(truth.max(), pred.max())

    # Three equal plot axes; colorbar axes are added by hand further down.
    fig, (ax_truth, ax_pred, ax_err) = plt.subplots(1, 3, figsize=(12, 4))
    fig.subplots_adjust(wspace=0.35)

    # Narrow ONLY the gap between panels 1 and 2 by widening both into it;
    # the spacing between panels 2 and 3 stays as it is.
    box_truth = ax_truth.get_position()
    box_pred = ax_pred.get_position()
    gap = box_pred.x0 - box_truth.x1
    keep_fraction = 0.6  # keep 60% of original gap
    grow = max(0.0, (gap - gap * keep_fraction) / 2.0)
    if grow > 0:
        ax_truth.set_position([box_truth.x0, box_truth.y0, box_truth.width + grow, box_truth.height])
        ax_pred.set_position([box_pred.x0 - grow, box_pred.y0, box_pred.width + grow, box_pred.height])

    shared_limits = {'vmin': shared_vmin, 'vmax': shared_vmax}
    panels = (
        (ax_truth, truth, 'Ground Truth', shared_limits),
        (ax_pred, pred, 'Prediction', shared_limits),
        (ax_err, error, 'Error', {}),  # error panel autoscales its colours
    )
    drawn = []
    for ax, values, title, limits in panels:
        drawn.append(ax.scatter(points[:, 0], points[:, 1], c=values,
                                cmap='RdBu_r', s=75, **limits))
        ax.set_title(title, fontsize=24, fontfamily='serif')
        ax.axis('off')
        ax.set_aspect('equal')
    truth_scatter, _, error_scatter = drawn

    def _tick_label(value, pos):
        # Fixed-point with one decimal; never print a negative zero.
        if abs(value) < 1e-12:
            value = 0.0
        text = f"{value:.1f}"
        return "0.0" if text == "-0.0" else text

    tick_formatter = FuncFormatter(_tick_label)

    def _attach_colorbar(anchor_ax, mappable, *, x_max=None):
        # Colorbar axis as tall as `anchor_ax`, placed just to its right,
        # clamped to `x_max` (if given) and to the figure's right edge.
        pad, width = 0.004, 0.014
        box = anchor_ax.get_position()
        left = box.x1 + pad
        if x_max is not None:
            left = min(left, x_max)
        left = min(left, 0.985 - width)
        cax = fig.add_axes([left, box.y0, width, box.height])
        cbar = fig.colorbar(mappable, cax=cax)
        cbar.ax.yaxis.set_major_formatter(tick_formatter)
        cbar.ax.yaxis.get_offset_text().set_visible(False)
        cbar.ax.tick_params(labelsize=10)
        cbar.update_ticks()
        return cbar

    # Shared colorbar right of panel 2, kept clear of panel 3.
    box_err = ax_err.get_position()
    _attach_colorbar(ax_pred, truth_scatter, x_max=box_err.x0 - 0.004 - 0.014)

    # Error colorbar right of panel 3.
    _attach_colorbar(ax_err, error_scatter)

    return fig


In [127]:
def visualize_structured_2d(x, y_true, y_pred, metadata, sample_idx=0):
    """Plot ground truth, prediction and signed error for one grid sample (darcy).

    The per-point values are reshaped to ``(metadata['H'], metadata['W'])``
    and shown as three images. The first two panels share one colour range
    and one colorbar; the error panel (``y_true - y_pred``) has its own.

    Args:
        x: Batched tensor of inputs. When ``x[sample_idx]`` has at least two
            columns, columns 0 and 1 supply the image extent; otherwise the
            unit square is used.
        y_true: Batched tensor of reference values.
        y_pred: Batched tensor of predicted values.
        metadata: Mapping providing the grid size under ``'H'`` and ``'W'``.
        sample_idx: Index of the sample to draw.

    Returns:
        The matplotlib figure.

    Note:
        Three ``plt.rcParams`` entries are modified globally.
    """
    from matplotlib.ticker import FuncFormatter

    # Use matplotlib's built-in mathtext (STIX) with a serif font; external
    # LaTeX rendering stays disabled.
    for rc_key, rc_value in (
        ('text.usetex', False),
        ('mathtext.fontset', 'stix'),
        ('font.family', 'serif'),
    ):
        plt.rcParams[rc_key] = rc_value

    n_rows = metadata['H']
    n_cols = metadata['W']

    # Pull the requested sample out as numpy arrays.
    features = x[sample_idx].cpu().numpy()  # [N, c_in]
    truth = y_true[sample_idx].cpu().numpy().squeeze()  # [N]
    pred = y_pred[sample_idx].cpu().numpy().squeeze()  # [N]
    error = truth - pred

    # Per-point values -> 2D grids.
    truth_grid = truth.reshape(n_rows, n_cols)
    pred_grid = pred.reshape(n_rows, n_cols)
    error_grid = error.reshape(n_rows, n_cols)

    # Image extent: taken from the first two input columns when they exist
    # (for darcy the input is [pos_x, pos_y, coeff]), else the unit square.
    if features.shape[1] >= 2:
        grid_x = features[:, 0].reshape(n_rows, n_cols)
        grid_y = features[:, 1].reshape(n_rows, n_cols)
        bounds = (grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max())
    else:
        bounds = (0, 1, 0, 1)

    # Three equal plot axes; colorbar axes are added by hand further down.
    fig, (ax_truth, ax_pred, ax_err) = plt.subplots(1, 3, figsize=(12, 4))
    fig.subplots_adjust(wspace=0.35)

    # Narrow ONLY the gap between panels 1 and 2 by widening both into it;
    # the spacing between panels 2 and 3 stays as it is.
    box_truth = ax_truth.get_position()
    box_pred = ax_pred.get_position()
    gap = box_pred.x0 - box_truth.x1
    keep_fraction = 0.6  # keep 60% of original gap
    grow = max(0.0, (gap - gap * keep_fraction) / 2.0)
    if grow > 0:
        ax_truth.set_position([box_truth.x0, box_truth.y0, box_truth.width + grow, box_truth.height])
        ax_pred.set_position([box_pred.x0 - grow, box_pred.y0, box_pred.width + grow, box_pred.height])

    # Common colour range for the ground-truth and prediction panels.
    shared_limits = {
        'vmin': min(truth_grid.min(), pred_grid.min()),
        'vmax': max(truth_grid.max(), pred_grid.max()),
    }

    panels = (
        (ax_truth, truth_grid, 'Ground Truth', shared_limits),
        (ax_pred, pred_grid, 'Prediction', shared_limits),
        (ax_err, error_grid, 'Error', {}),  # error panel autoscales its colours
    )
    images = []
    for ax, grid, title, limits in panels:
        images.append(ax.imshow(grid, cmap='RdBu_r', origin='lower',
                                extent=list(bounds), aspect='auto', **limits))
        ax.set_title(title, fontsize=24, fontfamily='serif')
        ax.axis('off')
    truth_image, _, error_image = images

    def _tick_label(value, pos):
        # Fixed-point with one decimal; never print a negative zero.
        if abs(value) < 1e-12:
            value = 0.0
        text = f"{value:.1f}"
        return "0.0" if text == "-0.0" else text

    tick_formatter = FuncFormatter(_tick_label)

    def _attach_colorbar(anchor_ax, mappable, *, x_max=None):
        # Colorbar axis as tall as `anchor_ax`, placed just to its right,
        # clamped to `x_max` (if given) and to the figure's right edge.
        pad, width = 0.004, 0.014
        box = anchor_ax.get_position()
        left = box.x1 + pad
        if x_max is not None:
            left = min(left, x_max)
        left = min(left, 0.985 - width)
        cax = fig.add_axes([left, box.y0, width, box.height])
        cbar = fig.colorbar(mappable, cax=cax)
        cbar.ax.yaxis.set_major_formatter(tick_formatter)
        cbar.ax.yaxis.get_offset_text().set_visible(False)
        cbar.ax.tick_params(labelsize=10)
        cbar.update_ticks()
        return cbar

    # Shared colorbar right of panel 2, kept clear of panel 3.
    box_err = ax_err.get_position()
    _attach_colorbar(ax_pred, truth_image, x_max=box_err.x0 - 0.004 - 0.014)

    # Error colorbar right of panel 3.
    _attach_colorbar(ax_err, error_image)

    return fig

def visualize_original_coords(x, y_true, y_pred, metadata, sample_idx=0, device='cpu'):
    """Plot ground truth, prediction and signed error at decoded coordinates (pipe).

    The sample's input features are decoded with ``metadata['x_normalizer']``
    when that entry exists and exposes ``decode`` (and is not an
    ``IdentityNormalizer``). The first two decoded columns are used as the
    scatter coordinates. The first two panels share one colour range and one
    colorbar; the error panel (``y_true - y_pred``) has its own.

    Args:
        x: Tensor or numpy array of inputs, either [batch, N, c_in] or
            [N, c_in].
        y_true: Batched tensor of reference values.
        y_pred: Batched tensor of predicted values.
        metadata: Mapping; ``'x_normalizer'`` and ``'c_in'`` are read when
            present.
        sample_idx: Index of the sample to draw.
        device: Device used while decoding the coordinates.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If the coordinates cannot be brought to a 2D
            ``[N, c_in]`` array or their length differs from the targets'.

    Note:
        Three ``plt.rcParams`` entries are modified globally.
    """
    # Use matplotlib's built-in mathtext (STIX) with a serif font; external
    # LaTeX rendering stays disabled.
    plt.rcParams['text.usetex'] = False
    plt.rcParams['mathtext.fontset'] = 'stix'
    plt.rcParams['font.family'] = 'serif'

    # Get sample data - x is [batch, N, c_in] where batch=1
    # Remove batch dimension to get [N, c_in]
    if isinstance(x, torch.Tensor):
        if x.dim() == 3:
            x_sample = x[sample_idx].cpu().numpy()  # [N, c_in]
        else:
            x_sample = x.cpu().numpy()  # [N, c_in] (already no batch dim)
    else:
        # Already numpy array
        if x.ndim == 3:
            x_sample = x[sample_idx]  # [N, c_in]
        else:
            x_sample = x  # [N, c_in]

    y_true_sample = y_true[sample_idx].cpu().numpy().squeeze()  # [N]
    y_pred_sample = y_pred[sample_idx].cpu().numpy().squeeze()  # [N]
    error_sample = y_true_sample - y_pred_sample

    # Ensure x_sample has correct shape [N, c_in] - remove any extra batch dimensions
    while x_sample.ndim > 2:
        x_sample = x_sample.squeeze(0)

    if x_sample.ndim != 2:
        raise ValueError(f"x_sample should be 2D [N, c_in], got shape: {x_sample.shape}")

    # Extract original coordinates
    # For pipe: x_normalizer is UnitGaussianNormalizer, so we need to decode
    if 'x_normalizer' in metadata and metadata['x_normalizer'] is not None:
        x_normalizer = metadata['x_normalizer']

        # Check if it's an IdentityNormalizer (no decoding needed)
        if type(x_normalizer).__name__ == 'IdentityNormalizer' or not hasattr(x_normalizer, 'decode'):
            # Identity normalizer - coordinates are already in original space
            x_sample_decoded = x_sample
        else:
            # UnitGaussianNormalizer or similar - need to decode
            # Move normalizer to device if it has parameters
            if hasattr(x_normalizer, 'mean') or hasattr(x_normalizer, 'to'):
                try:
                    x_normalizer = x_normalizer.to(device)
                except Exception:
                    # Best effort: a normalizer that cannot be moved is used
                    # where it is. Only ordinary errors are swallowed, so
                    # KeyboardInterrupt / SystemExit still propagate.
                    pass

            # Decode coordinates if they were normalized
            # x_sample is [N, c_in], convert to tensor
            x_sample_tensor = torch.tensor(x_sample, dtype=torch.float32).to(device)
            x_sample_decoded = x_normalizer.decode(x_sample_tensor).cpu().numpy()

            # Remove any extra dimensions
            while x_sample_decoded.ndim > 2:
                x_sample_decoded = x_sample_decoded.squeeze(0)

            # Ensure shape is [N, c_in]
            if x_sample_decoded.ndim == 1:
                x_sample_decoded = x_sample_decoded.reshape(-1, x_sample.shape[1])
    else:
        x_sample_decoded = x_sample

    # Ensure x_sample_decoded has shape [N, c_in] - remove any batch dimensions
    while x_sample_decoded.ndim > 2:
        x_sample_decoded = x_sample_decoded.squeeze(0)

    if x_sample_decoded.ndim != 2:
        raise ValueError(f"x_sample_decoded should be 2D [N, c_in], got shape: {x_sample_decoded.shape}")

    # Check if shape is [N, c_in] or [c_in, N] and fix if needed
    if x_sample_decoded.shape[0] < x_sample_decoded.shape[1] and x_sample_decoded.shape[0] == metadata.get('c_in', 2):
        # Likely transposed: [c_in, N] -> transpose to [N, c_in]
        x_sample_decoded = x_sample_decoded.T

    # Extract X, Y coordinates (first 2 dimensions)
    # x_sample_decoded is [N, c_in], so [:, 0] gives [N] for X, [:, 1] gives [N] for Y
    x_coords = x_sample_decoded[:, 0].flatten()
    y_coords = x_sample_decoded[:, 1].flatten()

    # Ensure shapes match
    if len(x_coords) != len(y_true_sample):
        raise ValueError(f"Coordinate shape mismatch: x_coords={x_coords.shape} (len={len(x_coords)}), y_coords={y_coords.shape} (len={len(y_coords)}), y_true={y_true_sample.shape} (len={len(y_true_sample)}), x_sample_decoded shape={x_sample_decoded.shape}")

    # Determine colorbar limits for first two plots (Ground Truth and Prediction)
    vmin_gt_pred = min(y_true_sample.min(), y_pred_sample.min())
    vmax_gt_pred = max(y_true_sample.max(), y_pred_sample.max())

    # Layout: 3 equal plot axes; add dedicated colorbar axes manually
    # Make the figure taller so each plot area can be square (adds vertical whitespace)
    fig, axes = plt.subplots(1, 3, figsize=(12, 6))
    fig.subplots_adjust(wspace=0.35)
    ax0, ax1, ax2 = axes

    # Tighten ONLY gap between plot 1 and 2 (keep 2↔3 spacing unchanged)
    p0 = ax0.get_position()
    p1 = ax1.get_position()
    g12 = p1.x0 - p0.x1
    scale12 = 0.6  # keep 60% of original gap
    d = max(0.0, (g12 - g12 * scale12) / 2.0)
    if d > 0:
        ax0.set_position([p0.x0, p0.y0, p0.width + d, p0.height])
        ax1.set_position([p1.x0 - d, p1.y0, p1.width + d, p1.height])

    # Ensure each axes box is square
    for ax in (ax0, ax1, ax2):
        ax.set_aspect('equal', adjustable='box')
        try:
            ax.set_box_aspect(1)
        except Exception:
            # Best effort: keep going if this call is unavailable or fails.
            pass

    # Ground Truth
    scatter1 = ax0.scatter(x_coords, y_coords, c=y_true_sample,
                           cmap='RdBu_r', s=1, vmin=vmin_gt_pred, vmax=vmax_gt_pred)
    ax0.set_title('Ground Truth', fontsize=24, fontfamily='serif')
    ax0.axis('off')
    ax0.set_aspect('equal')

    # Prediction (same colour range as the ground truth, so it needs no
    # colorbar of its own)
    ax1.scatter(x_coords, y_coords, c=y_pred_sample,
                cmap='RdBu_r', s=1, vmin=vmin_gt_pred, vmax=vmax_gt_pred)
    ax1.set_title('Prediction', fontsize=24, fontfamily='serif')
    ax1.axis('off')
    ax1.set_aspect('equal')

    # Error
    scatter3 = ax2.scatter(x_coords, y_coords, c=error_sample,
                           cmap='RdBu_r', s=1)
    ax2.set_title('Error', fontsize=24, fontfamily='serif')
    ax2.axis('off')
    ax2.set_aspect('equal')

    # Tick formatter: fixed-point, one decimal (e.g. 0.0)
    from matplotlib.ticker import FuncFormatter

    def format_func(value, pos):
        v = 0.0 if abs(value) < 1e-12 else value
        s = f"{v:.1f}"
        # Never print a negative zero.
        return "0.0" if s == "-0.0" else s

    formatter = FuncFormatter(format_func)

    def add_cbar_next_to(ax_left, mappable, *, x_max=None):
        # Colorbar axis as tall as `ax_left`, placed just to its right,
        # clamped to `x_max` (if given) and to the figure's right edge.
        pad = 0.004
        width = 0.014
        pos = ax_left.get_position()
        x = pos.x1 + pad
        if x_max is not None:
            x = min(x, x_max)
        x = min(x, 0.985 - width)
        cax = fig.add_axes([x, pos.y0, width, pos.height])
        cbar = fig.colorbar(mappable, cax=cax)
        cbar.ax.yaxis.set_major_formatter(formatter)
        cbar.ax.yaxis.get_offset_text().set_visible(False)
        cbar.ax.tick_params(labelsize=10)
        cbar.update_ticks()
        return cbar

    # Shared cbar tight to the right of plot 2, but don't overlap plot 3
    pos2 = ax2.get_position()
    x_max_shared = pos2.x0 - 0.004 - 0.014
    _ = add_cbar_next_to(ax1, scatter1, x_max=x_max_shared)

    # Error cbar tight to the right of plot 3
    _ = add_cbar_next_to(ax2, scatter3)

    return fig


In [128]:
# Main visualization loop: for every dataset, load its checkpoint, report the
# test-set RelL2 error and save ground-truth / prediction / error figures.
results = {}

for dataset_name in datasets_to_visualize:
    print("\n" + "="*80)
    print(f"Processing dataset: {dataset_name}")
    print("="*80)

    # Load dataset
    print(f"\nLoading dataset: {dataset_name}")
    train_data, test_data, metadata = load_dataset(
        dataset_name, DATADIR_BASE, PROJDIR, mesh=False
    )
    print(f"Train samples: {len(train_data)}, Test samples: {len(test_data)}")
    print(f"Input dim: {metadata['c_in']}, Output dim: {metadata['c_out']}")

    # Find checkpoint directory. os.listdir() returns entries in arbitrary
    # order, so sort them: when several runs share the prefix the same one is
    # picked on every execution.
    ckpt_prefix = f'model_2_{dataset_name}'
    matching_run = next(
        (item for item in sorted(os.listdir(CHECKPOINT_BASE)) if item.startswith(ckpt_prefix)),
        None,
    )
    checkpoint_dir = None if matching_run is None else os.path.join(CHECKPOINT_BASE, matching_run)

    if checkpoint_dir is None:
        print(f"Warning: No checkpoint found for {dataset_name}, skipping...")
        continue

    print(f"Checkpoint directory: {checkpoint_dir}")

    # Load model
    model, ckpt_name = load_model_from_checkpoint(checkpoint_dir, metadata, device)

    # Calculate RelL2 error on test set
    rel_l2_error = calculate_rel_l2_error(model, test_data, metadata, device, batch_size=1)
    results[dataset_name] = {'rel_l2_error': rel_l2_error}

    # The output normalizer is the same for every sample of this dataset.
    y_normalizer = metadata['y_normalizer'].to(device)

    # For elasticity, save two samples (0 and 1). Others: just sample 0.
    sample_indices = [0, 1] if dataset_name == 'elasticity' else [0]

    for sample_idx in sample_indices:
        # Get a sample for visualization
        x_sample_orig, y_true_sample = test_data[sample_idx]  # Original coordinates from dataset
        x_sample = x_sample_orig.unsqueeze(0).to(device)  # Add batch dimension for model

        # Get prediction
        with torch.no_grad():
            y_pred_normalized = model(x_sample)

        # Decode from normalized space
        y_true_decoded = y_normalizer.decode(y_true_sample.unsqueeze(0).to(device))
        y_pred_decoded = y_normalizer.decode(y_pred_normalized)

        # Visualize. The tensors passed on hold a single sample, hence the
        # fixed sample_idx=0 in every call below.
        print(f"\nVisualizing sample {sample_idx}...")
        if dataset_name == 'elasticity':
            fig = visualize_elasticity(x_sample, y_true_decoded, y_pred_decoded, sample_idx=0)
        elif dataset_name == 'darcy':
            fig = visualize_structured_2d(x_sample, y_true_decoded, y_pred_decoded, metadata, sample_idx=0)
        elif dataset_name == 'airfoil_steady':
            print(f"Skipping visualization for {dataset_name} (not implemented yet)")
            continue
        else:  # pipe - use original coordinates
            fig = visualize_original_coords(
                x_sample_orig.unsqueeze(0),
                y_true_decoded,
                y_pred_decoded,
                metadata,
                sample_idx=0,
                device=device,
            )

        # Save figure
        os.makedirs(OUTPUT_BASE, exist_ok=True)
        if dataset_name == 'elasticity' and sample_idx == 1:
            out_name = 'elasticity1_visualization.png'
        else:
            out_name = f'{dataset_name}_visualization.png'
        fig_path = os.path.join(OUTPUT_BASE, out_name)
        fig.savefig(fig_path, dpi=150, bbox_inches='tight')
        print(f"Saved visualization to: {fig_path}")
        plt.close(fig)

    # Clean up: drop the model and release cached GPU memory before the next
    # dataset is processed.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n" + "="*80)
print("Summary of RelL2 Errors:")
print("="*80)
for dataset_name, result in results.items():
    print(f"{dataset_name:20s}: {result['rel_l2_error']:.6e}")



Processing dataset: elasticity

Loading dataset: elasticity
Train samples: 1000, Test samples: 200
Input dim: 2, Output dim: 1
Checkpoint directory: /home/vedantpu/.julia/dev/FLARE-dev.py/out/pdebench/vis_ckpts/model_2_elasticity_B_8_C_64_M_64_H_8
Loading model from: /home/vedantpu/.julia/dev/FLARE-dev.py/out/pdebench/vis_ckpts/model_2_elasticity_B_8_C_64_M_64_H_8/ckpt10/model.pt
Model loaded successfully! Parameters: 592,641
Calculating RelL2 error on 200 test samples...
  Processed 10/200 batches...
  Processed 20/200 batches...
  Processed 30/200 batches...
  Processed 40/200 batches...
  Processed 50/200 batches...
  Processed 60/200 batches...
  Processed 70/200 batches...
  Processed 80/200 batches...
  Processed 90/200 batches...
  Processed 100/200 batches...
  Processed 110/200 batches...
  Processed 120/200 batches...
  Processed 130/200 batches...
  Processed 140/200 batches...
  Processed 150/200 batches...
  Processed 160/200 batches...
  Processed 170/200 batches...
  Pr