In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import sys

# Prepend the parent of the current working directory to sys.path so that the
# local "processing_mstdb" package can be imported just below.
PROJECT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
print("Adding to sys.path:", PROJECT_PATH)
sys.path.insert(0, PROJECT_PATH)

# Project-local imports; they rely on the sys.path change above.
from processing_mstdb.processor import MSTDBProcessor
from processing_mstdb.trainer import AIModelTrainer

# Load and preprocess the data.
# To use the other processed dataset, point this at
# "data/mstdb_janz_processed.csv" instead.
data_path = os.path.abspath(os.path.join(PROJECT_PATH, "data/mstdb_processed.csv"))
print("Loading data from:", data_path)

processor = MSTDBProcessor.from_csv(data_path)
# Remove leading/trailing whitespace from the column names.
processor.df.columns = processor.df.columns.str.strip()

# Composition type handed to compute_composition: 'elements', 'compounds', or 'both'.
# NOTE(review): the example prediction later in this cell passes an
# elements-only composition while this is set to 'both' — confirm which of the
# two is intended.
composition_type = 'both'


def _row_composition(row):
    """Return the composition entry for one dataframe row."""
    return processor.compute_composition(row, composition_type=composition_type)


processor.df['Composition'] = processor.df.apply(_row_composition, axis=1)

# Step 2: Train models
trainer = AIModelTrainer(processor.df)
trainer.train_all()

# Step 3: Make prediction
# The example is an elemental composition. A compound composition would look
# like {'NaCl': 1.0}; with composition_type='both' the two can be combined,
# e.g. {'Na': 0.5, 'Cl': 0.5, 'NaCl': 1.0}.
example_composition = {'Na': 0.5, 'Cl': 0.5}
predicted_props = trainer.predict(example_composition)

# Step 4: Display results (one line per predicted quantity, 4 decimals)
print("\nPredicted Properties for example composition:")
for prop_name, prop_value in predicted_props.items():
    print(f"  {prop_name}: {prop_value:.4f}")

# Step 5: Show evaluation metrics, grouped by target and then by model
print("\nModel Performance Summary (R2 and MSE):")
metrics = trainer.get_metrics()
for target, scores in metrics.items():
    print(f"\nTarget: {target}")
    for model_name, result in scores.items():
        r2_value = result['R2']
        mse_value = result['MSE']
        print(f"  {model_name}: R2 = {r2_value:.4f}, MSE = {mse_value:.4f}")

# Step 6: Plot R2 scores — one bar chart per target, one bar per model
print("\nPlotting R2 scores")
plot_dir = "sklearn_prediction_plots"
os.makedirs(plot_dir, exist_ok=True)

for target, model_scores in metrics.items():
    model_names = list(model_scores.keys())
    r2_scores = [model_scores[name]['R2'] for name in model_names]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(model_names, r2_scores)
    ax.set_title(f"R² Scores for {target}")
    ax.set_ylabel("R²")
    plt.xticks(rotation=45)  # acts on the current axes, i.e. `ax`
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(plot_dir, f"r2_{target}.png"))
    plt.close(fig)  # release the figure before the next target

# Step 7: Plot Actual vs Predicted for best models
print("\nPlot actual vs. predicted for best models coefficients")
# NOTE(review): r2_score is imported here but not referenced anywhere in this cell.
from sklearn.metrics import r2_score

models_dir = os.path.join("..", "data", "trained_models")

for target in trainer.present_target_columns:
    # Targets without both a saved model and a saved scaler on disk are skipped.
    # The files serve only as a presence check: the estimator and scaler used
    # below are the in-memory ones held by `trainer`.
    required_files = (
        os.path.join(models_dir, f"{target}.joblib"),
        os.path.join(models_dir, f"{target}_scaler.joblib"),
    )
    if not all(os.path.exists(path) for path in required_files):
        continue

    # Rows that have a value for this target, expanded to one column per
    # composition key (keys absent from a row become 0).
    df_target = processor.df.dropna(subset=[target])
    features = pd.json_normalize(df_target['Composition']).fillna(0)
    features_scaled = trainer.scalers[target].transform(trainer.poly.transform(features))

    y_actual = pd.to_numeric(df_target[target], errors='coerce')
    y_pred = trainer.best_models[target].predict(features_scaled)

    # Parity plot with the y = x reference line spanning the actual range.
    diagonal = [y_actual.min(), y_actual.max()]
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(y_actual, y_pred, alpha=0.7)
    ax.plot(diagonal, diagonal, 'r--')
    ax.set_title(f"Actual vs Predicted: {target}")
    ax.set_xlabel("Actual")
    ax.set_ylabel("Predicted")
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(plot_dir, f"actual_vs_predicted_{target}.png"))
    plt.close(fig)

# Step 8: Compute actual vs predicted thermophysical properties
print("\nPlot actual vs. predicted for best model physical properties")
temperature = 900  # Kelvin
os.makedirs(plot_dir, exist_ok=True)  # no-op when the directory already exists

# Coerce every coefficient column that is present to numeric; entries that
# cannot be parsed become NaN.
coeff_cols = ['rho_a', 'rho_b', 'mu1_a', 'mu1_b',
              'mu2_a', 'mu2_b', 'mu2_c',
              'k_a', 'k_b', 'cp_a', 'cp_b', 'cp_c', 'cp_d']

for col in coeff_cols:
    if col in processor.df.columns:
        processor.df[col] = pd.to_numeric(processor.df[col], errors='coerce')

# Replace NaNs with 0.0.
# NOTE(review): this fills NaNs in every column of the frame, not only in
# coeff_cols — confirm that is intended.
processor.df.fillna(0.0, inplace=True)

# Actual properties, evaluated from each row's stored coefficients.
processor.df['Actual Properties'] = processor.df.apply(
    lambda row: processor.compute_actual_properties(row, temperature), axis=1
)

# Predicted coefficients from each row's composition, then the properties
# derived from those predictions at the same temperature.
processor.df['Predicted Coeffs'] = processor.df['Composition'].apply(trainer.predict)
processor.df['Predicted Properties'] = processor.df['Predicted Coeffs'].apply(
    lambda pred: processor.compute_actual_properties_from_predictions(pred, temperature)
)

# Define which properties to compare
properties_to_compare = [
    "Density", "Viscosity A", "Thermal Conductivity", "Heat Capacity of Liquid"
]

# Plot comparisons
for prop in properties_to_compare:
    actual_vals = []
    predicted_vals = []
    skipped = 0
    # Only two cells per row are needed, so walk the two columns in lockstep
    # rather than materialising a full Series per row with iterrows().
    for row_actual, row_predicted in zip(processor.df['Actual Properties'],
                                         processor.df['Predicted Properties']):
        actual = row_actual.get(prop)
        predicted = row_predicted.get(prop)
        # Skip rows with a missing value on either side, and rows whose
        # actual value is below 1e-6.
        if actual is None or predicted is None or actual < 1e-6:
            skipped += 1
            continue
        actual_vals.append(actual)
        predicted_vals.append(predicted)

    # The two lists are always appended together, so checking one suffices.
    if not actual_vals:
        print(f"Skipped {prop} — no valid data (all {skipped} rows filtered out).")
        continue

    diagonal = [min(actual_vals), max(actual_vals)]
    plt.figure(figsize=(6, 6))
    plt.scatter(actual_vals, predicted_vals, alpha=0.7)
    plt.plot(diagonal, diagonal, 'r--')  # y = x reference line
    plt.title(f"Actual vs Predicted {prop} at {temperature} K")
    plt.xlabel("Actual")
    plt.ylabel("Predicted")
    plt.grid(True)
    plt.tight_layout()
    fname = f"actual_vs_predicted_{prop.replace(' ', '_')}.png"
    plt.savefig(os.path.join(plot_dir, fname))
    plt.close()
    print(f"Saved: {fname} ({len(actual_vals)} points, {skipped} skipped)")

Adding to sys.path: /Users/mauriciotano/projects/jax-cfd-tests/notebooks/processing_mstdb
Loading data from: /Users/mauriciotano/projects/jax-cfd-tests/notebooks/processing_mstdb/data/mstdb_processed.csv






























































































































































































































































































































































































































































































































































































































































































































































































































  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


KeyboardInterrupt: 

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import sys

# Prepend the parent of the current working directory to sys.path so that the
# local "processing_mstdb" package can be imported just below.
PROJECT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
print("Adding to sys.path:", PROJECT_PATH)
sys.path.insert(0, PROJECT_PATH)

from processing_mstdb.processor import MSTDBProcessor
from processing_mstdb.resnet_trainer import ResNetMetaTrainer, TARGETS, DERIVED_PROPS

# Step 1: Load and preprocess the data
# (point this at "data/mstdb_processed.csv" to use the other processed dataset)
data_path = os.path.abspath(os.path.join(PROJECT_PATH, "data/mstdb_janz_processed.csv"))
print("Loading data from:", data_path)

processor = MSTDBProcessor.from_csv(data_path)
# Remove leading/trailing whitespace from the column names.
processor.df.columns = processor.df.columns.str.strip()

composition_type = 'elements'


def _row_composition(row):
    """Return the composition entry for one dataframe row."""
    return processor.compute_composition(row, composition_type=composition_type)


processor.df['Composition'] = processor.df.apply(_row_composition, axis=1)

# Step 2: Train ResNet+Meta+Physics (base stage, meta stage, then joint stage)
trainer = ResNetMetaTrainer(processor.df, TARGETS, DERIVED_PROPS)
trainer.train_base()
trainer.train_meta()
trainer.train_joint()


def _print_values(header, values):
    """Print `header`, then one "name: value" line (4 decimals) per entry."""
    print(header)
    for name, value in values.items():
        print(f"{name}: {value:.4f}")


# Step 3: Predict for an example
example_composition = {'Na': 0.5, 'Cl': 0.5}
predicted_coeffs = trainer.predict(example_composition)
_print_values("\nPredicted coefficients for 50-50 NaCl:", predicted_coeffs)

# Step 4: Compute derived properties at 900K
derived_props = trainer.derived(predicted_coeffs, 900)
_print_values("\nDerived properties at 900K:", derived_props)

# Step 5: Plotting — make sure the output directory exists
print("\nPlotting results...")
plot_dir = "resnet_prediction_plots"
os.makedirs(plot_dir, exist_ok=True)

# Helper to predict batches
def predict_all(X_input):
    """Predict all present targets for a batch of feature rows.

    Each base net in ``trainer.base_nets`` is evaluated on ``X_input``; the
    outputs are stacked along dimension 1, the output of ``trainer.meta`` on
    that stack is added to it, and the sum is rescaled with ``trainer.σ`` and
    ``trainer.μ`` (presumably the target standard deviation and mean —
    confirm against the trainer).

    NOTE(review): ``torch.stack`` adds a new dimension, so this assumes every
    base net returns one value per row with shape (B,) — confirm.

    Args:
        X_input: array-like batch of input features, converted to a float32
            tensor on ``trainer.device``.

    Returns:
        NumPy array with the rescaled predictions.
    """
    import torch

    # Inference only: switch both stages to evaluation mode.
    trainer.base_nets.eval()
    trainer.meta.eval()

    with torch.no_grad():
        inputs = torch.tensor(X_input, dtype=torch.float32, device=trainer.device)
        per_target = [trainer.base_nets[name](inputs) for name in trainer.present_targets]
        stacked = torch.stack(per_target, 1)
        corrected = stacked + trainer.meta(stacked)
        scaled = corrected.cpu().numpy()
    return scaled * trainer.σ + trainer.μ

# Step 6: Actual vs Predicted - coefficients (one parity plot per target and split)
for split_name, idx_set in (("train", trainer.tr_idx), ("test", trainer.te_idx)):
    y_true = trainer.y_raw[idx_set]
    y_pred = predict_all(trainer.X[idx_set])

    for j, target in enumerate(trainer.present_targets):
        true_col = y_true[:, j]
        # Keep only rows whose actual value exceeds 1e-10.
        # NOTE(review): this also drops negative coefficients — confirm intended.
        mask = true_col > 1e-10
        if not np.any(mask):
            continue

        shown_true = true_col[mask]
        diagonal = [shown_true.min(), shown_true.max()]

        plt.figure(figsize=(6, 6))
        plt.scatter(shown_true, y_pred[mask, j], alpha=0.7)
        plt.plot(diagonal, diagonal, 'r--')  # y = x reference line
        plt.title(f"{target} ({split_name} set)")
        plt.xlabel("Actual")
        plt.ylabel("Predicted")
        plt.grid(True)
        plt.tight_layout()
        fname = f"actual_vs_predicted_coeff_{target}_{split_name}.png"
        plt.savefig(os.path.join(plot_dir, fname))
        plt.close()
        print(f"Saved: {fname}")

# Step 7: Actual vs Predicted - thermophysical properties
def compute_actual_properties_from_coeffs(row, temperature):
    """Evaluate ``trainer.derived`` on the coefficients stored in ``row``.

    Args:
        row: mapping-like record supporting ``.get(key, default)``; every name
            in ``trainer.present_targets`` is looked up, with 0.0 used when a
            name is absent.
        temperature: temperature forwarded unchanged to ``trainer.derived``.

    Returns:
        Whatever ``trainer.derived`` returns for those coefficients.
    """
    coeff = {}
    for col in trainer.present_targets:
        coeff[col] = row.get(col, 0.0)
    return trainer.derived(coeff, temperature)

print("\nPlotting actual vs. predicted thermophysical properties...")

temperature = 900  # Kelvin
properties_to_compare = ["rho", "muA", "muB", "k", "cp"]

for split_name, idx_set in zip(["train", "test"], [trainer.tr_idx, trainer.te_idx]):
    actual_vals_dict = {prop: [] for prop in properties_to_compare}
    predicted_vals_dict = {prop: [] for prop in properties_to_compare}

    # Predict the whole split in one batched forward pass instead of calling
    # the model once per row; row i of the result belongs to idx_set[i].
    # NOTE(review): relies on the nets giving batch-independent results in
    # evaluation mode, which predict_all switches to — confirm for custom layers.
    split_predictions = predict_all(trainer.X[idx_set]) if len(idx_set) else []

    for idx, pred_row in zip(idx_set, split_predictions):
        # Ground-truth coefficients -> derived properties (names missing from
        # the row default to 0.0 inside the helper).
        actual_props = compute_actual_properties_from_coeffs(trainer.df.iloc[idx], temperature)

        # Predicted coefficients -> derived properties.
        pred_coeffs = dict(zip(trainer.present_targets, pred_row))
        pred_props = trainer.derived(pred_coeffs, temperature)

        for prop in properties_to_compare:
            a = actual_props.get(prop, None)
            p = pred_props.get(prop, None)
            # Keep a point only when both values exist and the actual value
            # is above 1e-6.
            if a is not None and p is not None and a > 1e-6:
                actual_vals_dict[prop].append(a)
                predicted_vals_dict[prop].append(p)

    for prop in properties_to_compare:
        actual_vals = actual_vals_dict[prop]
        predicted_vals = predicted_vals_dict[prop]
        # Both lists are appended together, so checking one is sufficient.
        if not actual_vals:
            print(f"Skipped {prop} ({split_name}) — no valid data.")
            continue

        diagonal = [min(actual_vals), max(actual_vals)]
        plt.figure(figsize=(6, 6))
        plt.scatter(actual_vals, predicted_vals, alpha=0.7)
        plt.plot(diagonal, diagonal, 'r--')  # y = x reference line
        plt.title(f"{prop} at {temperature} K ({split_name} set)")
        plt.xlabel("Actual")
        plt.ylabel("Predicted")
        plt.grid(True)
        plt.tight_layout()
        fname = f"actual_vs_predicted_property_{prop}_{split_name}.png"
        plt.savefig(os.path.join(plot_dir, fname))
        plt.close()
        print(f"Saved: {fname}")

print("\nAll plots saved in", plot_dir)


Adding to sys.path: /Users/mauriciotano/projects/jax-cfd-tests/notebooks/processing_mstdb
Loading data from: /Users/mauriciotano/projects/jax-cfd-tests/notebooks/processing_mstdb/data/mstdb_janz_processed.csv
 • Training base net for Melt(K)
 ⇢ Early stopping for Melt(K)
 • Training base net for Boil(K)
 ⇢ Early stopping for Boil(K)
 • Training base net for rho_a
 ⇢ Early stopping for rho_a
 • Training base net for rho_b
 ⇢ Early stopping for rho_b
 • Training base net for mu1_a
 ⇢ Early stopping for mu1_a
 • Training base net for mu1_b
 ⇢ Early stopping for mu1_b
 • Training base net for mu2_a
 ⇢ Early stopping for mu2_a
 • Training base net for mu2_b
 ⇢ Early stopping for mu2_b
 • Training base net for mu2_c
 ⇢ Early stopping for mu2_c
 • Training base net for k_a
 ⇢ Early stopping for k_a
 • Training base net for k_b
 ⇢ Early stopping for k_b
 • Training base net for cp_a
 ⇢ Early stopping for cp_a
 • Training base net for cp_b
 ⇢ Early stopping for cp_b
 • Training base net for cp_

In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import sys

# Prepend the parent of the current working directory to sys.path so that the
# local "processing_mstdb" package can be imported just below.
PROJECT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
print("Adding to sys.path:", PROJECT_PATH)
sys.path.insert(0, PROJECT_PATH)

from processing_mstdb.processor import MSTDBProcessor
from processing_mstdb.kan_trainer import KANMetaTrainer, TARGETS, DERIVED_PROPS

# Step 1: Load and preprocess the data
# (point this at "data/mstdb_janz_processed.csv" to use the other processed dataset)
data_path = os.path.abspath(os.path.join(PROJECT_PATH, "data/mstdb_processed.csv"))
print("Loading data from:", data_path)

processor = MSTDBProcessor.from_csv(data_path)
# Remove leading/trailing whitespace from the column names.
processor.df.columns = processor.df.columns.str.strip()

composition_type = 'elements'


def _row_composition(row):
    """Return the composition entry for one dataframe row."""
    return processor.compute_composition(row, composition_type=composition_type)


processor.df['Composition'] = processor.df.apply(_row_composition, axis=1)

# Step 2: Train KAN+Meta+Physics (base stage, meta stage, then joint stage)
trainer = KANMetaTrainer(processor.df, TARGETS, DERIVED_PROPS)
trainer.train_base()
trainer.train_meta()
trainer.train_joint()


def _print_values(header, values):
    """Print `header`, then one "name: value" line (4 decimals) per entry."""
    print(header)
    for name, value in values.items():
        print(f"{name}: {value:.4f}")


# Step 3: Predict for an example
example_composition = {'Na': 0.5, 'Cl': 0.5}
predicted_coeffs = trainer.predict(example_composition)
_print_values("\nPredicted coefficients for 50-50 NaCl:", predicted_coeffs)

# Step 4: Compute derived properties at 900K
derived_props = trainer.derived(predicted_coeffs, 900)
_print_values("\nDerived properties at 900K:", derived_props)

# Step 5: Plotting — make sure the output directory exists
print("\nPlotting results...")
plot_dir = "kan_prediction_plots"
os.makedirs(plot_dir, exist_ok=True)

# Helper to predict batches
def predict_all(X_input):
    """Predict all present targets for a batch of feature rows.

    Each base net in ``trainer.base_nets`` is evaluated on ``X_input``; the
    outputs are stacked along dimension 1, the output of ``trainer.meta`` on
    that stack is added to it, and the sum is rescaled with ``trainer.σ`` and
    ``trainer.μ`` (presumably the target standard deviation and mean —
    confirm against the trainer).

    NOTE(review): ``torch.stack`` adds a new dimension, so this assumes every
    base net returns one value per row with shape (B,) — confirm.

    Args:
        X_input: array-like batch of input features, converted to a float32
            tensor on ``trainer.device``.

    Returns:
        NumPy array with the rescaled predictions.
    """
    import torch

    # Inference only: switch both stages to evaluation mode.
    trainer.base_nets.eval()
    trainer.meta.eval()

    with torch.no_grad():
        inputs = torch.tensor(X_input, dtype=torch.float32, device=trainer.device)
        per_target = [trainer.base_nets[name](inputs) for name in trainer.present_targets]
        stacked = torch.stack(per_target, 1)
        corrected = stacked + trainer.meta(stacked)
        scaled = corrected.cpu().numpy()
    return scaled * trainer.σ + trainer.μ

# Step 6: Actual vs Predicted - coefficients (one parity plot per target and split)
for split_name, idx_set in (("train", trainer.tr_idx), ("test", trainer.te_idx)):
    y_true = trainer.y_raw[idx_set]
    y_pred = predict_all(trainer.X[idx_set])

    for j, target in enumerate(trainer.present_targets):
        true_col = y_true[:, j]
        # Keep only rows whose actual value exceeds 1e-10.
        # NOTE(review): this also drops negative coefficients — confirm intended.
        mask = true_col > 1e-10
        if not np.any(mask):
            continue

        shown_true = true_col[mask]
        diagonal = [shown_true.min(), shown_true.max()]

        plt.figure(figsize=(6, 6))
        plt.scatter(shown_true, y_pred[mask, j], alpha=0.7)
        plt.plot(diagonal, diagonal, 'r--')  # y = x reference line
        plt.title(f"{target} ({split_name} set)")
        plt.xlabel("Actual")
        plt.ylabel("Predicted")
        plt.grid(True)
        plt.tight_layout()
        fname = f"actual_vs_predicted_coeff_{target}_{split_name}.png"
        plt.savefig(os.path.join(plot_dir, fname))
        plt.close()
        print(f"Saved: {fname}")

# Step 7: Actual vs Predicted - thermophysical properties
def compute_actual_properties_from_coeffs(row, temperature):
    """Evaluate ``trainer.derived`` on the coefficients stored in ``row``.

    Args:
        row: mapping-like record supporting ``.get(key, default)``; every name
            in ``trainer.present_targets`` is looked up, with 0.0 used when a
            name is absent.
        temperature: temperature forwarded unchanged to ``trainer.derived``.

    Returns:
        Whatever ``trainer.derived`` returns for those coefficients.
    """
    coeff = {}
    for col in trainer.present_targets:
        coeff[col] = row.get(col, 0.0)
    return trainer.derived(coeff, temperature)

print("\nPlotting actual vs. predicted thermophysical properties...")

temperature = 900  # Kelvin
properties_to_compare = ["rho", "muA", "muB", "k", "cp"]

for split_name, idx_set in zip(["train", "test"], [trainer.tr_idx, trainer.te_idx]):
    actual_vals_dict = {prop: [] for prop in properties_to_compare}
    predicted_vals_dict = {prop: [] for prop in properties_to_compare}

    # Predict the whole split in one batched forward pass instead of calling
    # the model once per row; row i of the result belongs to idx_set[i].
    # NOTE(review): relies on the nets giving batch-independent results in
    # evaluation mode, which predict_all switches to — confirm for custom layers.
    split_predictions = predict_all(trainer.X[idx_set]) if len(idx_set) else []

    for idx, pred_row in zip(idx_set, split_predictions):
        # Ground-truth coefficients -> derived properties (names missing from
        # the row default to 0.0 inside the helper).
        actual_props = compute_actual_properties_from_coeffs(trainer.df.iloc[idx], temperature)

        # Predicted coefficients -> derived properties.
        pred_coeffs = dict(zip(trainer.present_targets, pred_row))
        pred_props = trainer.derived(pred_coeffs, temperature)

        for prop in properties_to_compare:
            a = actual_props.get(prop, None)
            p = pred_props.get(prop, None)
            # Keep a point only when both values exist and the actual value
            # is above 1e-6.
            if a is not None and p is not None and a > 1e-6:
                actual_vals_dict[prop].append(a)
                predicted_vals_dict[prop].append(p)

    for prop in properties_to_compare:
        actual_vals = actual_vals_dict[prop]
        predicted_vals = predicted_vals_dict[prop]
        # Both lists are appended together, so checking one is sufficient.
        if not actual_vals:
            print(f"Skipped {prop} ({split_name}) — no valid data.")
            continue

        diagonal = [min(actual_vals), max(actual_vals)]
        plt.figure(figsize=(6, 6))
        plt.scatter(actual_vals, predicted_vals, alpha=0.7)
        plt.plot(diagonal, diagonal, 'r--')  # y = x reference line
        plt.title(f"{prop} at {temperature} K ({split_name} set)")
        plt.xlabel("Actual")
        plt.ylabel("Predicted")
        plt.grid(True)
        plt.tight_layout()
        fname = f"actual_vs_predicted_property_{prop}_{split_name}.png"
        plt.savefig(os.path.join(plot_dir, fname))
        plt.close()
        print(f"Saved: {fname}")

print("\nAll plots saved in", plot_dir)


Adding to sys.path: /Users/mauriciotano/projects/jax-cfd-tests/notebooks/processing_mstdb
Loading data from: /Users/mauriciotano/projects/jax-cfd-tests/notebooks/processing_mstdb/data/mstdb_processed.csv

Stage-1: Training base KANs...
 • Training base net for Melt(K)
 • Training base net for Boil(K)
 ⇢ Early stopping for Boil(K)
 • Training base net for rho_a
 ⇢ Early stopping for rho_a
 • Training base net for rho_b
 ⇢ Early stopping for rho_b
 • Training base net for mu1_a
 • Training base net for mu1_b
 • Training base net for mu2_a
 ⇢ Early stopping for mu2_a
 • Training base net for mu2_b
 ⇢ Early stopping for mu2_b
 • Training base net for mu2_c
 ⇢ Early stopping for mu2_c
 • Training base net for k_a
 ⇢ Early stopping for k_a
 • Training base net for k_b
 ⇢ Early stopping for k_b
 • Training base net for cp_a
 ⇢ Early stopping for cp_a
 • Training base net for cp_b
 ⇢ Early stopping for cp_b
 • Training base net for cp_c

Stage-2: Training meta net with physics regularization..

In [None]:
"""
train_and_predict_snn.py  —  end-to-end smoke-test for SNNMetaTrainer
--------------------------------------------------------------------
✓ trains base SNNs
✓ trains meta network with physics regularisation
✓ predicts coefficients & derived props for 50-50 NaCl
✓ makes parity plots for train / test splits (coeffs + thermo-props)
"""
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch

# ────────────────────────────────────────────────────────────────
#  Make local package importable
# ────────────────────────────────────────────────────────────────
# The parent of the current working directory goes to the front of sys.path.
PROJECT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
print("Adding to sys.path:", PROJECT_PATH)
sys.path.insert(0, PROJECT_PATH)

from processing_mstdb.processor import MSTDBProcessor
from processing_mstdb.snn_trainer  import SNNMetaTrainer, TARGETS, DERIVED_PROPS

# ────────────────────────────────────────────────────────────────
#  1. Load & preprocess database
# ────────────────────────────────────────────────────────────────
# (point this at "data/mstdb_processed.csv" to use the other processed dataset)
data_path = os.path.abspath(os.path.join(PROJECT_PATH, "data/mstdb_janz_processed.csv"))
print("Loading data from:", data_path)

processor = MSTDBProcessor.from_csv(data_path)
# Remove leading/trailing whitespace from the column names.
processor.df.columns = processor.df.columns.str.strip()

composition_type = 'elements'


def _row_composition(row):
    """Return the composition entry for one dataframe row."""
    return processor.compute_composition(row, composition_type=composition_type)


processor.df['Composition'] = processor.df.apply(_row_composition, axis=1)

# ────────────────────────────────────────────────────────────────
#  2. Train SNN + meta + physics regularisation
# ────────────────────────────────────────────────────────────────
trainer = SNNMetaTrainer(processor.df, TARGETS, DERIVED_PROPS)
trainer.train_base()
trainer.train_meta()
trainer.train_joint()

# Training is finished: put every base net and the meta net in evaluation
# mode once, so the prediction helper does not need to do it per call.
for module in (*trainer.base_nets.values(), trainer.meta):
    module.eval()

# ────────────────────────────────────────────────────────────────
#  3. Quick demo prediction
# ────────────────────────────────────────────────────────────────
example = {"Na": 0.5, "Cl": 0.5}
coeff = trainer.predict(example)

print("\nPredicted coefficients for 50-50 NaCl:")
for coeff_name, coeff_value in coeff.items():
    print(f"{coeff_name:7s}: {coeff_value:11.4f}")

print("\nDerived properties at 900 K:")
derived = trainer.derived(coeff, 900)
for prop_name, prop_value in derived.items():
    print(f"{prop_name:4s}: {prop_value:11.4f}")

# ────────────────────────────────────────────────────────────────
#  4. Helper: batch prediction (physical units)
# ────────────────────────────────────────────────────────────────
def predict_all(X_input: np.ndarray) -> np.ndarray:
    """Predict all present targets for a batch of feature rows.

    Each base net in ``trainer.base_nets`` is evaluated on ``X_input`` and the
    outputs are concatenated along dimension 1. The output of ``trainer.meta``
    on that tensor is added to it, and the sum is rescaled with ``trainer.σ``
    and ``trainer.μ`` (presumably the target standard deviation and mean —
    confirm against the trainer).

    The nets are not switched to evaluation mode here; the calling script
    does that once after training.

    Args:
        X_input: batch of input features, converted to a float32 tensor on
            ``trainer.device``.

    Returns:
        NumPy array with the rescaled predictions.
    """
    inputs = torch.tensor(X_input, dtype=torch.float32, device=trainer.device)
    with torch.no_grad():
        # Concatenate rather than stack: assumes each net returns a (B, 1)
        # column, so the result is (B, n_targets) — TODO confirm.
        per_target = [trainer.base_nets[name](inputs) for name in trainer.present_targets]
        base_std = torch.cat(per_target, dim=1)
        pred_std = base_std + trainer.meta(base_std)
    scaled = pred_std.cpu().numpy()
    return scaled * trainer.σ + trainer.μ

# ────────────────────────────────────────────────────────────────
#  5. Parity plots — coefficients
# ────────────────────────────────────────────────────────────────
print("\nPlotting results …")
plot_dir = "snn_prediction_plots"
os.makedirs(plot_dir, exist_ok=True)

for split, idx_set in (("train", trainer.tr_idx), ("test", trainer.te_idx)):
    y_true = trainer.y_raw[idx_set]
    y_pred = predict_all(trainer.X[idx_set])

    for j, tgt in enumerate(trainer.present_targets):
        true_col = y_true[:, j]
        # Keep only rows whose actual value exceeds 1e-10.
        # NOTE(review): this also drops negative coefficients — confirm intended.
        mask = true_col > 1e-10
        if not np.any(mask):
            continue

        shown_true = true_col[mask]
        lims = [shown_true.min(), shown_true.max()]

        plt.figure(figsize=(6, 6))
        plt.scatter(shown_true, y_pred[mask, j], alpha=0.7)
        plt.plot(lims, lims, "r--")  # y = x reference line
        plt.title(f"{tgt} ({split} set)")
        plt.xlabel("Actual")
        plt.ylabel("Predicted")
        plt.grid(True)
        plt.tight_layout()
        fname = f"actual_vs_predicted_coeff_{tgt}_{split}.png"
        plt.savefig(os.path.join(plot_dir, fname))
        plt.close()
        print(f"Saved: {fname}")

# ────────────────────────────────────────────────────────────────
#  6. Parity plots — derived thermo-physical properties
# ────────────────────────────────────────────────────────────────
print("\nPlotting actual vs. predicted thermo-physical properties …")

T_plot  = 900                                     # K
props   = ["rho", "muA", "muB", "k", "cp"]

for split, idx_set in zip(["train", "test"], [trainer.tr_idx, trainer.te_idx]):

    actual, pred = {p: [] for p in props}, {p: [] for p in props}

    # Predict the whole split in one batched forward pass instead of calling
    # the model once per row; row i of the result belongs to idx_set[i].
    # NOTE(review): relies on the nets giving batch-independent results in
    # evaluation mode (set once after training) — confirm for custom layers.
    split_pred = predict_all(trainer.X[idx_set]) if len(idx_set) else []

    for idx, pred_row in zip(idx_set, split_pred):
        # ground-truth coefficients → properties
        # (fetch the row once instead of once per coefficient)
        row = trainer.df.iloc[idx]
        row_coeff = {c: row[c] for c in trainer.present_targets}
        a_props   = trainer.derived(row_coeff, T_plot)

        # model coefficients → properties
        m_coeff   = dict(zip(trainer.present_targets, pred_row))
        p_props   = trainer.derived(m_coeff, T_plot)

        for p in props:
            a = a_props.get(p)
            pr = p_props.get(p)
            # Keep a point only when both values exist and the actual value
            # is above 1e-6.
            if a is not None and pr is not None and a > 1e-6:
                actual[p].append(a)
                pred[p].append(pr)

    for p in props:
        if not actual[p]:
            print(f"Skipped {p} ({split}) — not enough data")
            continue

        lims = [min(actual[p]), max(actual[p])]
        plt.figure(figsize=(6, 6))
        plt.scatter(actual[p], pred[p], alpha=0.7)
        plt.plot(lims, lims, "r--")  # y = x reference line
        plt.title(f"{p} at {T_plot} K ({split} set)")
        plt.xlabel("Actual")
        plt.ylabel("Predicted")
        plt.grid(True)
        plt.tight_layout()
        fname = f"actual_vs_predicted_property_{p}_{split}.png"
        plt.savefig(os.path.join(plot_dir, fname))
        plt.close()
        print(f"Saved: {fname}")

print("\nAll plots saved in", plot_dir)


Adding to sys.path: /Users/mauriciotano/projects/jax-cfd-tests/notebooks/processing_mstdb
Loading data from: /Users/mauriciotano/projects/jax-cfd-tests/notebooks/processing_mstdb/data/mstdb_processed.csv

Joint Training: Optimizing base and meta SNN networks together...
Epoch   0 | Train: 95.2864 | Val: 2.4491
Epoch   1 | Train: 70.0288 | Val: 1.9255
Epoch   2 | Train: 61.6117 | Val: 1.6369
Epoch   3 | Train: 51.5643 | Val: 1.4401
Epoch   4 | Train: 48.6882 | Val: 1.2414
Epoch   5 | Train: 1.8614 | Val: 1.0559
Epoch   6 | Train: 1.8059 | Val: 0.9244
Epoch   7 | Train: 1.1984 | Val: 0.9088
Epoch   8 | Train: 1.1521 | Val: 0.9152
Epoch   9 | Train: 1.8519 | Val: 0.9066
Epoch  10 | Train: 0.8276 | Val: 0.9794
Epoch  11 | Train: 1.2641 | Val: 1.0037
Epoch  12 | Train: 2.9461 | Val: 1.0065
Epoch  13 | Train: 0.9738 | Val: 0.9724
Epoch  14 | Train: 0.8338 | Val: 0.9641
Epoch  15 | Train: 0.8498 | Val: 0.9338
Epoch  16 | Train: 0.7873 | Val: 0.9613
Epoch  17 | Train: 1.0963 | Val: 1.0498
Epoc



Saved: actual_vs_predicted_coeff_Melt(K)_train.png
Saved: actual_vs_predicted_coeff_Boil(K)_train.png
Saved: actual_vs_predicted_coeff_rho_a_train.png
Saved: actual_vs_predicted_coeff_rho_b_train.png
Saved: actual_vs_predicted_coeff_mu1_a_train.png
Saved: actual_vs_predicted_coeff_mu1_b_train.png
Saved: actual_vs_predicted_coeff_mu2_a_train.png
Saved: actual_vs_predicted_coeff_mu2_b_train.png
Saved: actual_vs_predicted_coeff_mu2_c_train.png
Saved: actual_vs_predicted_coeff_k_a_train.png
Saved: actual_vs_predicted_coeff_cp_a_train.png
Saved: actual_vs_predicted_coeff_cp_b_train.png
Saved: actual_vs_predicted_coeff_Melt(K)_test.png
Saved: actual_vs_predicted_coeff_Boil(K)_test.png
Saved: actual_vs_predicted_coeff_rho_a_test.png
Saved: actual_vs_predicted_coeff_rho_b_test.png
Saved: actual_vs_predicted_coeff_mu1_a_test.png
Saved: actual_vs_predicted_coeff_mu1_b_test.png
Saved: actual_vs_predicted_coeff_mu2_a_test.png
Saved: actual_vs_predicted_coeff_mu2_b_test.png
Saved: actual_vs_predict