In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import json
import numpy as np
from collections import defaultdict
import re

# Import model building functions
from train_ff import build_model, INPUT_DIMS, EXP_OUTPUT_DIMS
from probabilistic import predict

In [2]:
# Configuration
OUTPUTS_DIR = Path("outputs_final") 
OUTPUTS_PROB = Path("outputs") 
DATA_DIR = Path("../Data") 

# Convert to absolute paths for clarity
OUTPUTS_DIR = OUTPUTS_DIR.resolve()
DATA_DIR = DATA_DIR.resolve()

# Dataset configuration (from train_ff.py)
DATASET_TO_FILE = {
    "sphere": "sphere_dataset.pt",
    "disk": "disk_dataset.pt",
    "so3": "so3_dataset.pt",
    "cs": "cs_dataset.pt",
    "protein": "protein_dataset.pt",
}

INPUT_DIMS = {
    "sphere": 3,
    "disk": 2,
    "so3": 9,
    "cs": 9,
    "protein": 16,
}

EXP_OUTPUT_DIMS = {
    "sphere": 3,
    "so3": 3,
    "cs": 3,
    "protein": 6,
}

In [23]:
import torch

def _flatten_batch(X: torch.Tensor, d_last: int):
    """
    X: [B,d] or [B,S,d]
    Returns X_flat: [N,d], where N=B or B*S
    """
    if X.dim() == 3:
        B, S, D = X.shape
        assert D == d_last, f"Expected last dim {d_last}, got {D}"
        return X.reshape(B * S, D)
    else:
        B, D = X.shape
        assert D == d_last, f"Expected last dim {d_last}, got {D}"
        return X


def sphere_distance_stats(X: torch.Tensor):
    """
    Distance to unit sphere: d(x,S^2) = | ||x|| - 1 |
    X: [B,3] or [B,S,3]
    Returns: (mean_dist, max_dist)
    """
    X_flat = _flatten_batch(X, d_last=3)
    dist = torch.abs(torch.linalg.norm(X_flat, dim=-1) - 1.0)
    return dist.mean().item(), dist.max().item()


def disk_distance_stats(X: torch.Tensor):
    """
    Distance to closed unit disk: d(x,D) = max(0, ||x|| - 1)
    X: [B,2] or [B,S,2]
    Returns: (mean_dist, max_dist)
    """
    X_flat = _flatten_batch(X, d_last=2)
    dist = torch.clamp(torch.linalg.norm(X_flat, dim=-1) - 1.0, min=0.0)
    return dist.mean().item(), dist.max().item()


def so3_distance_stats(R9: torch.Tensor):
    """
    "Distance-like" constraint violations for SO(3), reporting two metrics:
      - orth_dist = || R R^T - I ||_F
      - det_dist  = |det(R) - 1|
    R9: [B,9] or [B,S,9] (row-major flatten)
    Returns:
      (orth_mean, orth_max), (det_mean, det_max)
    """
    R9_flat = _flatten_batch(R9, d_last=9)
    R = R9_flat.reshape(-1, 3, 3)

    I = torch.eye(3, device=R.device, dtype=R.dtype).expand(R.shape[0], 3, 3)
    orth_dist = torch.linalg.norm(R @ R.transpose(-1, -2) - I, dim=(-2, -1))
    det_dist = torch.abs(torch.linalg.det(R) - 1.0)

    sum_dist = orth_dist + det_dist

    return ((sum_dist.mean().item(), sum_dist.max().item()), 
            (orth_dist.mean().item(), orth_dist.max().item()), 
            (det_dist.mean().item(), det_dist.max().item()),)


def se3_distance_stats(G16: torch.Tensor):
    """
    "Distance-like" constraint violations for SE(3), reporting:
      - orth_dist = || R R^T - I ||_F   (R is top-left 3x3)
      - det_dist  = |det(R) - 1|
      - last_row_dist = || last_row - [0,0,0,1] ||_inf
    G16: [B,16] (row-major flatten)
    Returns:
      (orth_mean, orth_max), (det_mean, det_max), (last_mean, last_max)
    """
    assert G16.dim() == 2 and G16.shape[1] == 16, "Expected G16 shaped [B,16]"
    B = G16.shape[0]
    G = G16.reshape(B, 4, 4)

    R = G[:, :3, :3]
    I = torch.eye(3, device=G.device, dtype=G.dtype).expand(B, 3, 3)
    orth_dist = torch.linalg.norm(R @ R.transpose(-1, -2) - I, dim=(-2, -1))
    det_dist = torch.abs(torch.linalg.det(R) - 1.0)

    last = G[:, 3, :]
    target = torch.tensor([0., 0., 0., 1.], device=G.device, dtype=G.dtype).expand_as(last)
    last_dist = torch.max(torch.abs(last - target), dim=-1).values  # L_inf

    sum_dist = orth_dist + det_dist + last_dist
    return ((sum_dist.mean().item(), sum_dist.max().item()),
            (orth_dist.mean().item(), orth_dist.max().item()),
            (det_dist.mean().item(), det_dist.max().item()),
            (last_dist.mean().item(), last_dist.max().item()))

import torch

def constraint_satisfaction(Y_pred: torch.Tensor, dataset: str):
    """
    Returns constraint violation stats as plain Python floats so you can log them.

    Args
    ----
    Y_pred:
        Model outputs.
        Expected shapes by dataset:
          - "sphere":   [B,3] or [B,S,3]
          - "disk":     [B,2] or [B,S,2]
          - "so3":      [B,9] or [B,S,9]   (row-major 3x3)
          - "cs":       [B,9] or [B,S,9]   (row-major 3x3)  (SO(3))
          - "protein":  [B,16]             (row-major 4x4)  (SE(3))

    dataset:
        One of {"sphere","so3","disk","protein","cs"}.

    Returns
    -------
    const: dict
        A small dict of mean/max distances for the relevant constraints.
        Example keys:
          - sphere/disk: {"mean_dist": ..., "max_dist": ...}
          - so3/cs:      {"mean_orth_dist": ..., "max_orth_dist": ..., "mean_det_dist": ..., "max_det_dist": ...}
          - protein:     adds last-row stats as well.
    """
    ds = dataset.lower().strip()

    if ds == "sphere":
        mean_d, max_d = sphere_distance_stats(Y_pred)
        return {"mean_dist": mean_d, "max_dist": max_d}

    if ds == "disk":
        mean_d, max_d = disk_distance_stats(Y_pred)
        return {"mean_dist": mean_d, "max_dist": max_d}

    if ds in ("so3", "cs"):
        (m_sum, M_sum), (m_orth, M_orth), (m_det, M_det) = so3_distance_stats(Y_pred)
        return {
            "mean_sum_dist": m_sum,
            "max_sum_dist": M_sum,
            "mean_orth_dist": m_orth,
            "max_orth_dist": M_orth,
            "mean_det_dist": m_det,
            "max_det_dist": M_det,
        }

    if ds == "protein":
        (m_sum, M_sum), (m_orth, M_orth), (m_det, M_det), (m_last, M_last) = se3_distance_stats(Y_pred)
        return {
            "mean_sum_dist": m_sum,
            "max_sum_dist": M_sum,
            "mean_orth_dist": m_orth,
            "max_orth_dist": M_orth,
            "mean_det_dist": m_det,
            "max_det_dist": M_det,
            "mean_last_row_dist": m_last,
            "max_last_row_dist": M_last,
        }

    raise ValueError(f"Unknown dataset='{dataset}'. Expected one of: sphere, so3, disk, protein, cs.")

def _format_constraints(const):
    """
    const: dict returned by constraint_satisfaction(...)
    Returns a pretty multi-line string.
    """
    if const is None:
        return "  Constraints: (none)"

    # sphere/disk
    if "mean_dist" in const and "max_dist" in const:
        return (
            "  Constraints:\n"
            f"    dist-to-manifold  mean={const['mean_dist']:.3e}  max={const['max_dist']:.3e}"
        )

    lines = ["  Constraints:"]
    # SO(3) / CS
    if "mean_orth_dist" in const:
        lines.append(
            f"    ||RR^T-I||_F      mean={const['mean_orth_dist']:.3e}  max={const['max_orth_dist']:.3e}"
        )
    if "mean_det_dist" in const:
        lines.append(
            f"    |det(R)-1|        mean={const['mean_det_dist']:.3e}  max={const['max_det_dist']:.3e}"
        )
    # SE(3) (protein)
    if "mean_last_row_dist" in const:
        lines.append(
            f"    last-row (L_inf)  mean={const['mean_last_row_dist']:.3e}  max={const['max_last_row_dist']:.3e}"
        )
    if "mean_sum_dist" in const:
        lines.append(
            f" SUM (all terms) mean={const['mean_sum_dist']:.3e} max={const['max_sum_dist']:.3e}"
        )

    return "\n".join(lines)


def _print_eval_summary(dataset, model_type, best_val_loss, best_dir, eval_results):
    """
    eval_results is expected to contain:
      - "test_loss": float
      - "constraints": dict  (from constraint_satisfaction)
    """
    test_loss = eval_results.get("test_loss", None)
    const = eval_results.get("constraints", None)

    print(f"\n{'-'*60}")
    print(f"RESULT: {dataset} / {model_type}")
    print(f"{'-'*60}")
    print(f"  Best val loss: {best_val_loss:.6e}")
    print(f"  Run dir:       {best_dir}")
    if test_loss is None:
        print("  Test loss:     (missing)")
    else:
        print(f"  Test loss:     {test_loss:.6e}")
    print(_format_constraints(const))
    print(f"{'-'*60}\n")


In [4]:
def get_best_val_loss(run_dir: Path) -> Optional[float]:
    """
    Extract best validation loss from meta.pt or meta.json.
    
    Returns the best validation loss (lowest = best performance).
    """
    meta_pt = run_dir / "meta.pt"
    meta_json = run_dir / "meta.json"
    
    # if meta_pt.exists():
    #     try:
    meta = torch.load(str(meta_pt), map_location="cpu", weights_only=False)
    logs = meta.get("logs", {})
    best_val = logs.get("best_val")
    if best_val is not None:
        return float(best_val)
    
    return None


def find_best_run(dataset: str, model_type: str, outputs_dir: Path, verbose: bool = True) -> Optional[Tuple[Path, float]]:
    """
    Find the run directory with the best validation loss for a given dataset/model_type.
    
    Handles various directory structures, including:
    - outputs/{dataset}/{model_type}/depth{}/out{}/lr{}_wd{}/seed{}/
    - outputs/{dataset}/{model_type}/internalTrue/depth{}/out{}/lr{}_wd{}/seed{}/
    - outputs/{dataset}/probabilistic/depth{}/out{}/anchors{}/lr{}_wd{}/seed{}/
    
    Uses recursive search to find all model.pt files regardless of intermediate directory structure.
    """
    dataset_dir = outputs_dir / dataset / model_type
    
    if not dataset_dir.exists():
        if verbose:
            print(f"  Directory does not exist: {dataset_dir}")
        return None
    
    best_val = None
    best_dir = None
    all_runs = []  # Track all runs for debugging

    # Only search depth* directories that are direct children of dataset_dir
    depth_dirs = sorted([p for p in dataset_dir.iterdir() if p.is_dir() and p.name.startswith("depth")])

    
    # Recursively search for all run directories (those containing model.pt)
    for depth_dir in depth_dirs:
        for model_pt_path in depth_dir.rglob("model.pt"):
            run_dir = model_pt_path.parent
            val_loss = get_best_val_loss(run_dir)

            if val_loss is not None:
                all_runs.append((val_loss, run_dir))
                if best_val is None or val_loss < best_val:
                    best_val = val_loss
                    best_dir = run_dir
    
    if best_dir is None:
        return None
    
    return best_dir, best_val

In [5]:
find_best_run("sphere", "projected/internalFalse", OUTPUTS_DIR) 

(PosixPath('/projects/gtml/Constrained Networks/src/Models/outputs_final/sphere/projected/internalFalse/depth8/out3/lr0.001_wd0/seed0'),
 0.004208169380823771)

In [6]:
def load_model_and_meta(run_dir: Path, dataset: str, model_type: str, device: str = "cpu"):
    """Load model and metadata from a run directory"""
    model_path = run_dir / "model.pt"
    meta_path = run_dir / "meta.pt"
    
    if not model_path.exists():
        raise FileNotFoundError(f"Model not found: {model_path}")
    
    # Load metadata
    if meta_path.exists():
        meta = torch.load(str(meta_path), map_location="cpu", weights_only=False)
    else:
        meta = {}
    
    hparams = meta.get("hparams", {})
    
    # Extract hyperparameters from saved metadata
    depth = hparams.get("depth", 2)
    dropout = hparams.get("dropout", 0.0)
    residual = hparams.get("residual", True)
    dt = hparams.get("dt", 1.0)
    use_internal = hparams.get("use_internal", False)
    
    # For CS dataset transformers, nhead and d_hid are needed but might not be in HParams
    # Check if they're in the metadata directly (they might be saved separately)
    nhead = meta.get("nhead") or hparams.get("nhead", 3)
    d_hid = meta.get("d_hid") or hparams.get("d_hid", 2048)
    
    # Print hyperparameters being used (for verification)
    # print(f"  Building model with: depth={depth}, dropout={dropout}, residual={residual}, "
    #       f"dt={dt}, use_internal={use_internal}, nhead={nhead}, d_hid={d_hid}")
    
    # Determine output dimension
    if model_type == "exponential":
        d_out = EXP_OUTPUT_DIMS.get(dataset, INPUT_DIMS[dataset])
    else:
        d_out = INPUT_DIMS[dataset]
    
    # Build model
    model = build_model(
        model_type=model_type,
        dataset=dataset,
        depth=depth,
        d_out=d_out,
        dropout=dropout,
        residual=residual,
        dt=dt,
        use_internal=use_internal,
        outputsflow_dir=None,
        nhead=nhead,
        d_hid=d_hid,
    )
    
    # Handle probabilistic model - need to set the final linear layer size
    if model_type == "probabilistic":
        num_anchors = meta.get("num_anchors", 100)
        if isinstance(model, nn.Sequential) and len(model) > 1:
            # Replace the final linear layer
            hidden_dim = model[0].output_dim if hasattr(model[0], 'output_dim') else d_out
            model[-1] = nn.Linear(hidden_dim, num_anchors)
    
    # Load state dict
    checkpoint = torch.load(str(model_path), map_location=device, weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    
    return model, meta

In [7]:
def normalize_se3_translation(X16: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
    """
    X16: [N,16] row-major flatten of 4x4.
    Normalizes translation column (rows 0..2, col 3) by tau.
    """
    G = X16.view(-1, 4, 4).clone()
    G[:, :3, 3] = G[:, :3, 3] / tau
    return G.view(-1, 16)


def load_test_data(dataset: str, data_dir: Path, normalize_protein: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load test data for a dataset.
    
    For protein dataset, applies the same normalization as during training.
    """
    data_file = data_dir / DATASET_TO_FILE[dataset]
    
    if not data_file.exists():
        raise FileNotFoundError(f"Data file not found: {data_file}")
    
    loaded = torch.load(str(data_file), map_location="cpu", weights_only=False)
    
    X_test = torch.tensor(loaded["X_test"], dtype=torch.float32)
    Y_test = torch.tensor(loaded["Y_test"], dtype=torch.float32)
    
    # Apply protein normalization if needed (same as in train_ff.py)
    if dataset == "protein" and normalize_protein:
        # Compute tau from training data (same as during training)
        X_train = torch.tensor(loaded["X_train"], dtype=torch.float32)
        t = X_train.view(-1, 4, 4)[:, :3, 3]  # Extract translation column
        tau = t.std().clamp_min(1e-8)
        
        # Normalize test data with the same tau
        X_test = normalize_se3_translation(X_test, tau)
        Y_test = normalize_se3_translation(Y_test, tau)
        
        print(f"  Applied protein normalization with tau={tau.item():.6g}")
    
    return X_test, Y_test

In [13]:
def evaluate_model(
    model: nn.Module,
    X_test: torch.Tensor,
    Y_test: torch.Tensor,
    dataset: str,
    model_type: str,
    meta: Dict,
    device: str = "cpu",
    batch_size: int = 256
) -> Dict:
    """Evaluate model on test data and compute loss + constraint satisfaction"""
    model.eval()
    
    # Handle probabilistic models specially
    if model_type == "probabilistic":
        print(meta.keys())
        # Load anchors from metadata or recreate them
        anchors_shape = meta.get("anchors_shape")
        num_anchors = meta.get("num_anchors", 100)
        anchors = torch.tensor(meta.get("anchors")).to(device)
    
    # Move data to device
    X_test = X_test.to(device)
    Y_test = Y_test.to(device)
    
    # Compute predictions in batches
    all_preds = []

    if model_type == "probabilistic":
        print(X_test.shape)
        Y_pred = predict(model, X_test, anchors)
        print(Y_pred.shape)
    else:
        Y_pred = model(X_test)
    print(Y_test.shape)
    loss = F.mse_loss(Y_pred, Y_test)
    const = constraint_satisfaction(Y_pred, dataset)

    return {"test_loss":loss, "constraints":const}

In [26]:
# Main extraction loop
results = {}

# Define all model types and datasets to check
MODEL_TYPES = ["regular", "exponential", "probabilistic", "flow_matching", "projected/internalFalse", "exponential/internalFalse", "flow_matching/internalFalse", "flow_matching/internaTrue"]
# MODEL_TYPES = ["projected/depth8"]

DATASETS = ["cs"]

device = "cuda" if torch.cuda.is_available() else "cpu"

device = "cpu"
print(f"Using device: {device}")
print(f"\nSearching for models in: {OUTPUTS_DIR}")
print(f"Loading test data from: {DATA_DIR}")

for dataset in DATASETS:
    results[dataset] = {}
    
    for model_type in MODEL_TYPES:
        # print(f"\n{'='*60}")
        # print(f"Processing: {dataset} / {model_type}")
        # print(f"{'='*60}")
        
        # Find best run
        if model_type == "probabilistic":
            best_result = find_best_run(dataset, model_type, OUTPUTS_PROB, verbose=True)
        else:
            best_result = find_best_run(dataset, model_type, OUTPUTS_DIR, verbose=True)
        
        if best_result is None:
            print(f"No runs found for {dataset}/{model_type}")
            results[dataset][model_type] = None
            continue
        
        best_dir, best_val_loss = best_result
        # print(f"\nSelected best run:")
        # print(f"  Validation loss: {best_val_loss:.6e}")
        # print(f"  Directory: {best_dir}")
        

        # Load model and metadata
        base = model_type.split("/", 1)[0]
        model, meta = load_model_and_meta(best_dir, dataset, base, device)
        # print(f"Model loaded successfully")
        
        # Load test data (with normalization for protein dataset)
        X_test, Y_test = load_test_data(dataset, DATA_DIR, normalize_protein=True)
        # print(f"Test data loaded: X_test.shape={X_test.shape}, Y_test.shape={Y_test.shape}")
        
        # Evaluate model
        eval_results = evaluate_model(model, X_test, Y_test, dataset, model_type, meta, device)
        _print_eval_summary(dataset, model_type, best_val_loss, best_dir, eval_results)
        results[dataset][model_type] = eval_results
        # except:
        #     print("error for", dataset, model_type)
        #     results[dataset][model_type] = None

Using device: cpu

Searching for models in: /projects/gtml/Constrained Networks/src/Models/outputs_final
Loading test data from: /projects/gtml/Constrained Networks/src/Data
torch.Size([1200, 10, 9])

------------------------------------------------------------
RESULT: cs / regular
------------------------------------------------------------
  Best val loss: 3.323884e-01
  Run dir:       /projects/gtml/Constrained Networks/src/Models/outputs_final/cs/regular/depth4/out9/lr0.001_wd0.0001/seed0
  Test loss:     3.326826e-01
  Constraints:
    ||RR^T-I||_F      mean=1.722e+00  max=1.732e+00
    |det(R)-1|        mean=9.998e-01  max=1.003e+00
 SUM (all terms) mean=2.721e+00 max=2.732e+00
------------------------------------------------------------

torch.Size([1200, 10, 9])

------------------------------------------------------------
RESULT: cs / exponential
------------------------------------------------------------
  Best val loss: 6.357073e-01
  Run dir:       /projects/gtml/Constrain

In [None]:
latex = results_to_latex_table(results, DATASETS, MODEL_TYPES,
                               caption="Test loss and constraint violation (mean / max).",
                               label="tab:outputs-final")
print(latex)

In [None]:
from typing import Any, Dict, Optional
import math

def _escape_latex(s: str) -> str:
    # minimal escaping for your strings
    return (
        s.replace("\\", "\\textbackslash{}")
         .replace("_", "\\_")
         .replace("%", "\\%")
         .replace("&", "\\&")
         .replace("#", "\\#")
         .replace("{", "\\{")
         .replace("}", "\\}")
         .replace("^", "\\^{}")
         .replace("~", "\\~{}")
    )

def _sci_latex(x: Optional[float], sig: int = 3) -> str:
    """Format float as LaTeX scientific notation like 1.23\\times10^{-4}."""
    if x is None or (isinstance(x, float) and (math.isnan(x) or math.isinf(x))):
        return "--"
    if x == 0:
        return "0"
    exp = int(math.floor(math.log10(abs(x))))
    mant = x / (10 ** exp)
    return f"{mant:.{sig-1}f}\\times 10^{{{exp}}}"

def results_to_latex_table(
    results: Dict[str, Dict[str, Any]],
    datasets,
    model_types,
    caption: str = "Test loss and constraint violation (mean / max).",
    label: str = "tab:constraints",
) -> str:
    """
    Expects results[dataset][model_type] to be either:
      - None (missing / not found / error), OR
      - dict with keys: "test_loss" (float) and "constraints" (dict).
    Produces one combined LaTeX table.
    """
    lines = []
    lines.append("\\begin{table}[t]")
    lines.append("\\centering")
    lines.append("\\small")
    lines.append("\\setlength{\\tabcolsep}{4pt}")
    lines.append("\\renewcommand{\\arraystretch}{1.15}")
    lines.append("\\begin{tabular}{llrcccccc}")
    lines.append("\\toprule")
    lines.append(
        "Dataset & Model & Test loss & "
        "\\multicolumn{2}{c}{Dist} & "
        "\\multicolumn{2}{c}{$\\|RR^\\top-I\\|_F$} & "
        "\\multicolumn{2}{c}{$|\\det(R)-1|$} \\\\"
    )
    lines.append(
        " & & & mean & max & mean & max & mean & max \\\\"
    )
    lines.append("\\midrule")

    for ds in datasets:
        first_row = True
        for mt in model_types:
            entry = results.get(ds, {}).get(mt, None)

            if entry is None:
                test = dist_mean = dist_max = orth_mean = orth_max = det_mean = det_max = "--"
            else:
                test_loss = entry.get("test_loss", None)
                const = entry.get("constraints", None) or {}

                test = _sci_latex(float(test_loss)) if test_loss is not None else "--"

                # Defaults
                dist_mean = dist_max = "--"
                orth_mean = orth_max = "--"
                det_mean = det_max = "--"

                # sphere/disk
                if "mean_dist" in const:
                    dist_mean = _sci_latex(const.get("mean_dist"))
                    dist_max  = _sci_latex(const.get("max_dist"))

                # so3/cs/protein
                if "mean_orth_dist" in const:
                    orth_mean = _sci_latex(const.get("mean_orth_dist"))
                    orth_max  = _sci_latex(const.get("max_orth_dist"))
                if "mean_det_dist" in const:
                    det_mean = _sci_latex(const.get("mean_det_dist"))
                    det_max  = _sci_latex(const.get("max_det_dist"))

            ds_cell = _escape_latex(ds) if first_row else ""
            mt_cell = "\\texttt{" + _escape_latex(mt) + "}"
            lines.append(
                f"{ds_cell} & {mt_cell} & {test} & "
                f"{dist_mean} & {dist_max} & "
                f"{orth_mean} & {orth_max} & "
                f"{det_mean} & {det_max} \\\\"
            )
            first_row = False

        lines.append("\\midrule")

    # remove last midrule and replace with bottomrule
    if lines[-1] == "\\midrule":
        lines[-1] = "\\bottomrule"
    else:
        lines.append("\\bottomrule")

    lines.append("\\end{tabular}")
    lines.append(f"\\caption{{{_escape_latex(caption)}}}")
    lines.append(f"\\label{{{_escape_latex(label)}}}")
    lines.append("\\end{table}")
    return "\n".join(lines)


In [None]:
# Display results summary
print("\n" + "="*80)
print("RESULTS SUMMARY")
print("="*80)

for dataset in DATASETS:
    print(f"\n{dataset.upper()}:")
    print("-" * 80)
    
    for model_type in MODEL_TYPES:
        result = results.get(dataset, {}).get(model_type)
        
        if result is None:
            print(f"  {model_type:20s} - No runs found")
        elif "error" in result:
            print(f"  {model_type:20s} - ERROR: {result['error']}")
        else:
            val_loss = result.get("best_val_loss", "N/A")
            test_loss = result.get("test_loss", "N/A")
            constraint = result.get("constraint_satisfaction", {})
            sat_rate = constraint.get("satisfaction_rate", "N/A")
            
            print(f"  {model_type:20s} - Val: {val_loss:.6e}, Test: {test_loss:.6e}, "
                  f"Constraint: {sat_rate:.4f}")

In [30]:
# Debug Flow Matching Models
# This cell tests flow matching models in detail to diagnose issues

print("=" * 80)
print("FLOW MATCHING MODEL DEBUGGING")
print("=" * 80)

# Test on a few datasets that had issues
test_datasets = ["so3", "cs", "protein"]

for dataset in test_datasets:
    print(f"\n{'='*80}")
    print(f"DEBUGGING FLOW MATCHING FOR DATASET: {dataset.upper()}")
    print(f"{'='*80}\n")
    
    try:
        # Find the best run
        best_run = find_best_run(dataset, "flow_matching", OUTPUTS_DIR, verbose=False)
        if best_run is None:
            print(f"  ✗ No flow matching model found for {dataset}")
            continue
        
        run_dir, val_loss = best_run
        print(f"  Best run directory: {run_dir}")
        print(f"  Best validation loss: {val_loss:.6e}\n")
        
        # Load model and meta
        print("  Loading model...")
        model, meta = load_model_and_meta(run_dir, dataset, "flow_matching", device="cpu")
        
        # Check if projection function is set
        # ProjectedTransformer uses 'end_proj_func', ProjectedFeedForward uses 'final_proj_func'
        print("\n  Checking projection function...")
        print(f"    Model type: {type(model).__name__}")
        
        proj_func = None
        proj_attr_name = None
        
        if hasattr(model, 'end_proj_func'):
            proj_func = model.end_proj_func
            proj_attr_name = 'end_proj_func'
            print(f"    Found end_proj_func (ProjectedTransformer)")
        elif hasattr(model, 'final_proj_func'):
            proj_func = model.final_proj_func
            proj_attr_name = 'final_proj_func'
            print(f"    Found final_proj_func (ProjectedFeedForward)")
        else:
            print(f"    ✗ Model does not have end_proj_func or final_proj_func attribute!")
            print(f"    Model attributes: {[attr for attr in dir(model) if 'proj' in attr.lower()]}")
        
        if proj_func is not None:
            print(f"  ✓ {proj_attr_name} is set")
                
                # Test projection function directly
            print("\n  Testing projection function directly...")
            if dataset == "cs":
                    test_input = torch.randn(2, 10, INPUT_DIMS[dataset])
            else:
                    test_input = torch.randn(2, INPUT_DIMS[dataset])
                
            print(f"    Input shape: {test_input.shape}")
                
            with torch.no_grad():
                    try:
                        test_output = proj_func(test_input)
                        print(f"    Output shape: {test_output.shape}")
                        
                        # Check if projection changes values
                        diff = torch.norm(test_input - test_output).item()
                        print(f"    L2 difference: {diff:.6e}")
                        
                        if diff < 1e-6:
                            print(f"    ⚠ WARNING: Projection appears to be identity!")
                        else:
                            print(f"    ✓ Projection changes values")
                        
                        # Check for NaN/Inf
                        if torch.isnan(test_output).any():
                            print(f"    ✗ Output contains NaN!")
                        if torch.isinf(test_output).any():
                            print(f"    ✗ Output contains Inf!")
                        
                    except Exception as e:
                        print(f"    ✗ Projection function failed: {e}")
                        import traceback
                        traceback.print_exc()
        else:
            print(f"  ✗ {proj_attr_name} is None!")
        
        # Test full forward pass
        print("\n  Testing full forward pass...")
        X_test, Y_test = load_test_data(dataset, DATA_DIR)
        
        # Take a small batch
        batch_size = 5
        X_batch = X_test[:batch_size]
        Y_batch = Y_test[:batch_size]
        
        if dataset == "cs":
            print(f"    Input shape: {X_batch.shape} (should be [batch, seq, features])")
        else:
            print(f"    Input shape: {X_batch.shape} (should be [batch, features])")
        
        with torch.no_grad():
            try:
                pred = model(X_batch)
                print(f"    Prediction shape: {pred.shape}")
                
                # Check for NaN/Inf
                if torch.isnan(pred).any():
                    print(f"    ✗ Predictions contain NaN!")
                if torch.isinf(pred).any():
                    print(f"    ✗ Predictions contain Inf!")
                
                # For SO3/CS, check constraint satisfaction on this small batch
                if dataset in ("so3", "cs"):
                    # Extract last timestep if sequential
                    if pred.dim() == 3:
                        pred_check = pred[:, -1, :]
                    else:
                        pred_check = pred
                    
                    # Reshape to 3x3 matrices
                    if len(pred_check[0]) == 9:
                        R = pred_check[0].reshape(3, 3)
                        orth_err = torch.norm(R.T @ R - torch.eye(3)).item()
                        det_err = abs(torch.det(R).item() - 1.0)
                        print(f"\n    Constraint check on first prediction:")
                        print(f"      Orthogonality error: {orth_err:.6e}")
                        print(f"      Determinant error: {det_err:.6e}")
                        
                        if orth_err > 0.1 or det_err > 0.1:
                            print(f"      ✗ Prediction is far from SO(3) - projection may not be working!")
                        else:
                            print(f"      ✓ Prediction is close to SO(3)")
                
                # Check if projection was actually applied
                # We can't easily intercept, but we can check if predictions satisfy constraints
                print(f"\n    Computing constraint satisfaction on batch...")
                if dataset in ("so3", "cs"):
                    ok, max_orth, max_det = check_so3_flat(pred_check)
                    print(f"      Samples satisfying constraints: {ok.sum().item()}/{len(ok)}")
                    print(f"      Max orthogonality error: {max_orth:.6e}")
                    print(f"      Max determinant error: {max_det:.6e}")
                elif dataset == "protein":
                    # For protein, pred should be [batch, 16] (flattened 4x4 matrices)
                    if pred.dim() == 2 and pred.shape[1] == 16:
                        pred_check = pred
                    elif pred.dim() == 3:
                        # If sequential, take last timestep
                        pred_check = pred[:, -1, :]
                    else:
                        print(f"      ⚠ Unexpected prediction shape: {pred.shape}")
                        pred_check = pred
                    
                    print(f"      Using prediction shape: {pred_check.shape}")
                    if pred_check.shape[1] != 16:
                        print(f"      ✗ ERROR: Expected 16 features for SE(3), got {pred_check.shape[1]}")
                    else:
                        results = check_se3_flat(pred_check)
                        print(f"      Samples satisfying constraints: {results['ok'].sum().item()}/{len(results['ok'])}")
                        print(f"      Max orthogonality error: {results['max_orth_err']:.6e}")
                        print(f"      Max determinant error: {results['max_det_err']:.6e}")
                
            except Exception as e:
                print(f"    ✗ Forward pass failed: {e}")
                import traceback
                traceback.print_exc()
        
        # Test projection on known manifold points
        print("\n  Testing projection on known manifold points...")
        if dataset in ("so3", "cs"):
            # Create a valid SO(3) rotation matrix
            from scipy.spatial.transform import Rotation
            R_valid = Rotation.random().as_matrix()
            R9_valid = torch.tensor(R_valid.flatten(), dtype=torch.float32).unsqueeze(0)  # [1, 9]
            
            if dataset == "cs":
                # For CS, need [batch, seq, features]
                R9_valid = R9_valid.unsqueeze(1).expand(1, 10, 9)  # [1, 10, 9]
            
            print(f"    Created valid SO(3) point, shape: {R9_valid.shape}")
            
            # Check it satisfies constraints
            ok_before, _, _ = check_so3_flat(R9_valid)
            print(f"    Before projection: {ok_before.sum().item()}/{len(ok_before)} satisfy constraints")
            
            # Apply projection
            with torch.no_grad():
                R9_proj = proj_func(R9_valid)
                ok_after, max_orth, max_det = check_so3_flat(R9_proj)
                print(f"    After projection: {ok_after.sum().item()}/{len(ok_after)} satisfy constraints")
                print(f"    Max orthogonality error: {max_orth:.6e}")
                print(f"    Max determinant error: {max_det:.6e}")
                
                if ok_after.sum().item() < len(ok_after):
                    print(f"    ✗ WARNING: Projection moved point OFF the manifold!")
                    print(f"    This suggests the flow matching projection is not working correctly.")
                else:
                    print(f"    ✓ Projection preserves manifold structure")
        
        elif dataset == "protein":
            # Create a valid SE(3) transformation
            from scipy.spatial.transform import Rotation
            R_valid = Rotation.random().as_matrix()
            t_valid = np.random.randn(3)
            G_valid = np.eye(4)
            G_valid[:3, :3] = R_valid
            G_valid[:3, 3] = t_valid
            G16_valid = torch.tensor(G_valid.flatten(), dtype=torch.float32).unsqueeze(0)  # [1, 16]
            
            print(f"    Created valid SE(3) point, shape: {G16_valid.shape}")
            
            # Check it satisfies constraints
            results_before = check_se3_flat(G16_valid)
            print(f"    Before projection: {results_before['ok'].sum().item()}/{len(results_before['ok'])} satisfy constraints")
            
            # Apply projection
            with torch.no_grad():
                G16_proj = proj_func(G16_valid)
                results_after = check_se3_flat(G16_proj)
                print(f"    After projection: {results_after['ok'].sum().item()}/{len(results_after['ok'])} satisfy constraints")
                print(f"    Max orthogonality error: {results_after['max_orth_err']:.6e}")
                print(f"    Max determinant error: {results_after['max_det_err']:.6e}")
                
                if results_after['ok'].sum().item() < len(results_after['ok']):
                    print(f"    ✗ WARNING: Projection moved point OFF the manifold!")
                    print(f"    This suggests the flow matching projection is not working correctly.")
                else:
                    print(f"    ✓ Projection preserves manifold structure")
        
        # Check what flow matching model was loaded
        print("\n  Checking flow matching model source...")
        # The flow_model is captured in the closure, so we can't easily inspect it
        # But we can check if outputsflow_dir was used correctly
        hparams = meta.get("hparams", {})
        saved_outputsflow = hparams.get("outputsflow_dir") or meta.get("outputsflow_dir")
        if saved_outputsflow:
            print(f"    Saved outputsflow_dir: {saved_outputsflow}")
        else:
            print(f"    ⚠ outputsflow_dir not saved in metadata (we had to search for it)")
        
        # Check projection parameters
        print("\n  Flow matching projection parameters:")
        print(f"    T=2.0, num_steps=40, differentiable=True")
        print(f"    ⚠ Using differentiable=True (Euler integration) may be less accurate than differentiable=False (scipy solver)")
        
    except Exception as e:
        print(f"  ✗ Error debugging {dataset}: {e}")
        import traceback
        traceback.print_exc()

print("\n" + "=" * 80)
print("SUMMARY AND DIAGNOSIS")
print("=" * 80)
print("=" * 80)
print("DEBUGGING COMPLETE")
print("=" * 80)

FLOW MATCHING MODEL DEBUGGING

DEBUGGING FLOW MATCHING FOR DATASET: SO3

  Best run directory: /projects/gtml/Constrained Networks/src/Models/outputs/so3/flow_matching/depth8/out9/lr0.001_wd0/seed0
  Best validation loss: 1.017315e-01

  Loading model...
  Building model with: depth=8, dropout=0.0, residual=True, dt=1.0, use_internal=True, nhead=3, d_hid=2048
  Found outputsflow directory: /projects/gtml/Constrained Networks/src/Models/outputsflow
  Flow matching model exists: /projects/gtml/Constrained Networks/src/Models/outputsflow/so3_dataset/BEST/model.pt
Loaded flow matching model from /projects/gtml/Constrained Networks/src/Models/outputsflow/so3_dataset/BEST/model.pt (input_dim=9)

  Checking projection function...
    Model type: ProjectedFeedForward
    Found final_proj_func (ProjectedFeedForward)
  ✓ final_proj_func is set

  Testing projection function directly...
    Input shape: torch.Size([2, 9])
    Output shape: torch.Size([2, 9])
    L2 difference: 1.215075e+00
    ✓ 

In [29]:
# Check constraint satisfaction on training data
print("=" * 80)
print("CHECKING CONSTRAINT SATISFACTION ON TRAINING DATA")
print("=" * 80)

# Define load_train_data if not already defined (should be in cell 6)
if 'load_train_data' not in globals():
    def load_train_data(dataset: str, data_dir: Path, normalize_protein: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load training data for a dataset.
        
        For protein dataset, applies the same normalization as during training.
        """
        data_file = data_dir / DATASET_TO_FILE[dataset]
        
        if not data_file.exists():
            raise FileNotFoundError(f"Data file not found: {data_file}")
        
        loaded = torch.load(str(data_file), map_location="cpu", weights_only=False)
        
        X_train = torch.tensor(loaded["X_train"], dtype=torch.float32)
        Y_train = torch.tensor(loaded["Y_train"], dtype=torch.float32)
        
        # Apply protein normalization if needed (same as in train_ff.py)
        if dataset == "protein" and normalize_protein:
            # Compute tau from training data (same as during training)
            t = X_train.view(-1, 4, 4)[:, :3, 3]  # Extract translation column
            tau = t.std().clamp_min(1e-8)
            
            # Normalize training data with the same tau
            X_train = normalize_se3_translation(X_train, tau)
            Y_train = normalize_se3_translation(Y_train, tau)
            
            print(f"  Applied protein normalization with tau={tau.item():.6g}")
        
        return X_train, Y_train

# Test on all datasets
for dataset in DATASETS:
    print(f"\n{'='*80}")
    print(f"DATASET: {dataset.upper()}")
    print(f"{'='*80}\n")
    
    try:
        # Load training data
        print("  Loading training data...")
        X_train, Y_train = load_train_data(dataset, DATA_DIR)
        print(f"  Training data shape: X={X_train.shape}, Y={Y_train.shape}")
        
        # Check constraint satisfaction on ground truth Y_train
        print("\n  Checking constraint satisfaction on ground truth Y_train...")
        if dataset in ("so3", "cs"):
            # For CS, check last timestep
            if Y_train.dim() == 3:
                Y_check = Y_train[:, -1, :]
            else:
                Y_check = Y_train
            
            ok, max_orth, max_det = check_so3_flat(Y_check)
            sat_rate = ok.float().mean().item()
            print(f"    Samples satisfying constraints: {ok.sum().item()}/{len(ok)} ({sat_rate*100:.2f}%)")
            print(f"    Max orthogonality error: {max_orth:.6e}")
            print(f"    Max determinant error: {max_det:.6e}")
            
            if sat_rate < 0.99:
                print(f"    ⚠ WARNING: Training data does not fully satisfy SO(3) constraints!")
            else:
                print(f"    ✓ Training data satisfies SO(3) constraints")
                
        elif dataset == "protein":
            results = check_se3_flat(Y_train)
            sat_rate = results["ok"].float().mean().item()
            print(f"    Samples satisfying constraints: {results['ok'].sum().item()}/{len(results['ok'])} ({sat_rate*100:.2f}%)")
            print(f"    Max orthogonality error: {results['max_orth_err']:.6e}")
            print(f"    Max determinant error: {results['max_det_err']:.6e}")
            
            if sat_rate < 0.99:
                print(f"    ⚠ WARNING: Training data does not fully satisfy SE(3) constraints!")
            else:
                print(f"    ✓ Training data satisfies SE(3) constraints")
        
        # Load best model and check predictions on training data
        print("\n  Loading best model and checking predictions on training data...")
        for model_type in MODEL_TYPES:
            if model_type == "probabilistic":
                print(f"    {model_type}: Skipping (requires anchors)")
                continue
            
            try:
                best_run = find_best_run(dataset, model_type, OUTPUTS_DIR, verbose=False)
                if best_run is None:
                    print(f"    {model_type}: No model found")
                    continue
                
                run_dir, val_loss = best_run
                model, meta = load_model_and_meta(run_dir, dataset, model_type, device="cpu")
                
                # Evaluate on a sample of training data (to avoid memory issues)
                sample_size = min(100, len(X_train))
                indices = torch.randperm(len(X_train))[:sample_size]
                X_sample = X_train[indices]
                Y_sample = Y_train[indices]
                
                # Get predictions
                with torch.no_grad():
                    pred = model(X_sample)
                    
                    # For sequential data, extract last timestep
                    if dataset == "cs" and pred.dim() == 3:
                        pred_check = pred[:, -1, :]
                        Y_check = Y_sample[:, -1, :] if Y_sample.dim() == 3 else Y_sample
                    else:
                        pred_check = pred
                        Y_check = Y_sample
                
                # Check constraint satisfaction on predictions
                if dataset in ("so3", "cs"):
                    ok, max_orth, max_det = check_so3_flat(pred_check)
                    sat_rate = ok.float().mean().item()
                    print(f"    {model_type:20s}: {ok.sum().item()}/{len(ok)} ({sat_rate*100:.2f}%) satisfy constraints")
                    print(f"      Max errors: orth={max_orth:.6e}, det={max_det:.6e}")
                elif dataset == "protein":
                    results = check_se3_flat(pred_check)
                    sat_rate = results["ok"].float().mean().item()
                    print(f"    {model_type:20s}: {results['ok'].sum().item()}/{len(results['ok'])} ({sat_rate*100:.2f}%) satisfy constraints")
                    print(f"      Max errors: orth={results['max_orth_err']:.6e}, det={results['max_det_err']:.6e}")
                    
            except Exception as e:
                print(f"    {model_type}: Error - {e}")
        
    except Exception as e:
        print(f"  ✗ Error processing {dataset}: {e}")
        import traceback
        traceback.print_exc()

print("\n" + "=" * 80)
print("TRAINING DATA CONSTRAINT CHECK COMPLETE")
print("=" * 80)

CHECKING CONSTRAINT SATISFACTION ON TRAINING DATA

DATASET: SPHERE

  Loading training data...
  Training data shape: X=torch.Size([4000, 3]), Y=torch.Size([4000, 3])

  Checking constraint satisfaction on ground truth Y_train...

  Loading best model and checking predictions on training data...
  Building model with: depth=6, dropout=0.0, residual=True, dt=1.0, use_internal=True, nhead=3, d_hid=2048
  Building model with: depth=8, dropout=0.0, residual=True, dt=1.0, use_internal=False, nhead=3, d_hid=2048
  Building model with: depth=8, dropout=0.0, residual=True, dt=1.0, use_internal=True, nhead=3, d_hid=2048
<function sphere at 0x15544641f250> 3 3 8 0.0 True
    probabilistic: Skipping (requires anchors)
  Building model with: depth=8, dropout=0.0, residual=True, dt=1.0, use_internal=False, nhead=3, d_hid=2048
  Found outputsflow directory: /projects/gtml/Constrained Networks/src/Models/outputsflow
  Flow matching model exists: /projects/gtml/Constrained Networks/src/Models/outputsf

  X_train = torch.tensor(loaded["X_train"], dtype=torch.float32)


    projected           : 100/100 (100.00%) satisfy constraints
      Max errors: orth=1.256743e-06, det=1.072884e-06
  Building model with: depth=4, dropout=0.0, residual=True, dt=1.0, use_internal=False, nhead=3, d_hid=2048
<function so3 at 0x15544641d6c0> 9 3 4 0.0 False
    exponential         : 100/100 (100.00%) satisfy constraints
      Max errors: orth=4.466861e-06, det=3.218651e-06
    probabilistic: Skipping (requires anchors)
  Building model with: depth=8, dropout=0.0, residual=True, dt=1.0, use_internal=True, nhead=3, d_hid=2048
  Found outputsflow directory: /projects/gtml/Constrained Networks/src/Models/outputsflow
  Flow matching model exists: /projects/gtml/Constrained Networks/src/Models/outputsflow/so3_dataset/BEST/model.pt
Loaded flow matching model from /projects/gtml/Constrained Networks/src/Models/outputsflow/so3_dataset/BEST/model.pt (input_dim=9)
    flow_matching       : 0/100 (0.00%) satisfy constraints
      Max errors: orth=1.413097e+00, det=1.001283e+00

DA