In [1]:
import os, sys
import ipywidgets as widgets
from IPython.display import display, clear_output, FileLink
import numpy as np
import math
import matplotlib.pyplot as plt
from scipy.interpolate import LinearNDInterpolator
import matplotlib.tri as mtri

# ─────────────────────────────────────────────────────────────────────────────
#  Make local package importable
# ─────────────────────────────────────────────────────────────────────────────
# The notebook is expected to live one directory below the project root;
# put that root first on sys.path so `processing_mstdb` imports resolve.
PROJECT_PATH = os.path.abspath(
    os.path.join(os.getcwd(), os.pardir)
)
sys.path.insert(0, PROJECT_PATH)

from processing_mstdb.processor        import MSTDBProcessor
from processing_mstdb.trainer          import AIModelTrainer
from processing_mstdb.resnet_trainer   import ResNetMetaTrainer, TARGETS as RESNET_TARGETS, DERIVED_PROPS as RESNET_DERIVED_PROPS
from processing_mstdb.kan_trainer      import KANMetaTrainer, TARGETS as KAN_TARGETS, DERIVED_PROPS as KAN_DERIVED_PROPS
from processing_mstdb.snn_trainer      import SNNMetaTrainer, TARGETS as SNN_TARGETS, DERIVED_PROPS as SNN_DERIVED_PROPS

# ─────────────────────────────────────────────────────────────────────────────
#  Discover CSVs in data/
# ─────────────────────────────────────────────────────────────────────────────
# All `*.csv` files directly inside <project>/data, in sorted order; these
# become the options of the dataset dropdown.
data_dir = os.path.join(PROJECT_PATH, "data")
csv_files = [entry for entry in os.listdir(data_dir) if entry.endswith(".csv")]
csv_files.sort()

# ─────────────────────────────────────────────────────────────────────────────
#  Helper: derive thermo‐props from coefficient dict at temperature T
# ─────────────────────────────────────────────────────────────────────────────
R = 8.314  # gas constant used in the Arrhenius viscosity term


def derive_properties(coeffs, T):
    """Evaluate the thermophysical correlations for a coefficient dict at *T*.

    Correlations (any missing coefficient defaults to 0.0):
        rho = rho_a - rho_b * T
        muA = mu1_a * exp(mu1_b / (R * T))
        k   = k_a + k_b * T
        cp  = cp_a + cp_b * T + cp_c / T**2

    Args:
        coeffs: mapping of coefficient name (``"rho_a"``, ``"mu1_b"``, ...)
            to a number.
        T: temperature, either a scalar or a NumPy array (evaluated
            element-wise). Must be non-zero.

    Returns:
        dict with keys ``"rho"``, ``"muA"``, ``"k"``, ``"cp"``. Values are
        scalars for scalar *T* and arrays for array *T*. ``muA`` is ``inf``
        when the exponent overflows instead of raising ``OverflowError``.
    """
    out = {}
    # Density: ρ = a - b·T (missing a or b → zero)
    a_rho = coeffs.get("rho_a", 0.0)
    b_rho = coeffs.get("rho_b", 0.0)
    out["rho"] = a_rho - b_rho * T

    # Viscosity (Arrhenius): μA = a·exp(b/(R·T)).
    # np.exp (unlike math.exp) accepts arrays and returns inf on overflow,
    # so a large predicted activation term no longer crashes the plots.
    a1 = coeffs.get("mu1_a", 0.0)
    b1 = coeffs.get("mu1_b", 0.0)
    with np.errstate(over="ignore", invalid="ignore"):
        out["muA"] = a1 * np.exp(b1 / (R * T))

    # Thermal conductivity: k = a + b·T
    a_k = coeffs.get("k_a", 0.0)
    b_k = coeffs.get("k_b", 0.0)
    out["k"] = a_k + b_k * T

    # Heat capacity: cₚ = a + b·T + c/T²
    a_cp = coeffs.get("cp_a", 0.0)
    b_cp = coeffs.get("cp_b", 0.0)
    c_cp = coeffs.get("cp_c", 0.0)
    out["cp"] = a_cp + b_cp * T + c_cp / (T**2)

    return out

# ─────────────────────────────────────────────────────────────────────────────
#  Widget definitions
# ─────────────────────────────────────────────────────────────────────────────
# --- dataset loading & component selection ----------------------------------
file_dropdown = widgets.Dropdown(options=csv_files, description="Dataset:")
load_button = widgets.Button(description="Load CSV", button_style="primary")
comp_type = widgets.ToggleButtons(
    options=["elements", "compounds", "both"],
    description="Comp Type:",
)
component_selector = widgets.SelectMultiple(
    options=[], description="Include:", rows=6
)
plot_comp_selector = widgets.SelectMultiple(
    options=[], description="Plot Comps:", rows=6
)

# --- optional dimensionality reduction before training ----------------------
# (label, value) pairs; the value is forwarded to the trainers as
# `embedding_method`.
embedder_dropdown = widgets.Dropdown(
    options=[
        ("None", "none"),
        ("PCA (Principal Component Analysis)", "pca"),
        ("Feature Hashing", "feature_hashing"),
        ("t-SNE (t-distributed Stochastic Neighbor Embedding)", "tsne"),
        ("Low Variance Filter", "low_variance"),
        ("NMF (Non-negative Matrix Factorization)", "nmf"),
        ("SVD (Singular Value Decomposition)", "svd"),
    ],
    value="none",
    description="Embedder:",
    layout=widgets.Layout(width="350px"),
)
n_components_slider = widgets.IntSlider(
    value=10, min=1, max=100, step=1,
    description="Components:",
    continuous_update=False,
)

# --- filtering / export -----------------------------------------------------
filter_button = widgets.Button(description="Filter", button_style="warning")
export_button = widgets.Button(description="Export Filtered", button_style="info")

# --- model choice & training ------------------------------------------------
# NOTE(review): "KAN" usually denotes a Kolmogorov–Arnold Network; confirm the
# label below matches what processing_mstdb.kan_trainer implements.
model_dropdown = widgets.Dropdown(
    options=[
        ("Scikit-Learn Polynomial Regressor", "sklearn"),
        ("ResNet + Meta-Learning + Physics", "resnet"),
        ("Kernel-Approximation Network (KAN) + Meta", "kan"),
        ("Spiking Neural Network (SNN) + Meta", "snn"),
    ],
    description="Model:",
    layout=widgets.Layout(width="350px"),
)
train_button = widgets.Button(description="Train Model", button_style="success")

# --- single-point prediction ------------------------------------------------
prediction_inputs = widgets.VBox()  # children rebuilt after each training run
temp_slider = widgets.FloatSlider(
    value=900, min=300, max=1500, step=10, description="Temp (K):"
)
predict_button = widgets.Button(description="Predict", button_style="info")
output = widgets.Output()  # shared log / plot area for all callbacks

# --- plotting ---------------------------------------------------------------
temp_range = widgets.FloatRangeSlider(
    value=(300, 1500),
    min=300,
    max=1500,
    step=10,
    description='Temp Range (K):',
    continuous_update=False,
)
plot_button = widgets.Button(description='Plot Range', button_style='primary')
comp_plot_button = widgets.Button(
    description='Plot vs Composition', button_style='info'
)
comp_plot_button = widgets.Button(description='Plot vs Composition', button_style='info')


# ─────────────────────────────────────────────────────────────────────────────
#  Debuggable helper: populate Include‐box
# ─────────────────────────────────────────────────────────────────────────────
def update_component_options(*_):
    """Refill the Include and Plot Comps selectors from the loaded processor.

    The option list depends on ``comp_type``: the processor's
    ``predefined_elements``, ``predefined_compounds``, or their union
    (``|``) for "both", always sorted. Does nothing except print a notice
    when no ``processor`` global exists yet. Extra positional arguments are
    accepted and ignored so this can also be used as a widget observer.
    """
    print("🔄 update_component_options fired")
    try:
        elements = processor.predefined_elements
        compounds = processor.predefined_compounds
        print(f"   processor has {len(elements)} elements and {len(compounds)} compounds")
    except NameError:
        print("   ⚠️ processor not defined yet")
        return

    mode = comp_type.value
    if mode == "elements":
        pool = elements
    elif mode == "compounds":
        pool = compounds
    else:
        pool = elements | compounds
    choices = sorted(pool)

    component_selector.options = choices
    shown = choices[:5]
    if len(choices) > 5:
        shown = shown + ["…"]
    print(f"   → set component_selector.options to {shown}")

    plot_comp_selector.options = choices

# ─────────────────────────────────────────────────────────────────────────────
#  Debuggable callback: Load the selected CSV
# ─────────────────────────────────────────────────────────────────────────────
def on_load_clicked(_):
    """Load the CSV chosen in ``file_dropdown`` and publish it as ``processor``.

    Builds an ``MSTDBProcessor`` from ``data_dir/<selected file>``, computes
    its ``Composition`` column for the current ``comp_type``, then stores it
    in the ``processor`` global and discards any previous ``filtered_proc``.
    Finally repopulates the component selectors.

    If no file is selected (e.g. ``data/`` had no CSVs, so the dropdown is
    empty), a warning is printed and nothing changes. The new processor is
    only published after its Composition column is built, so a failure
    part-way leaves the previously loaded dataset untouched.
    """
    print("▶️ on_load_clicked fired")
    with output:
        clear_output()
        if file_dropdown.value is None:
            print(f"⚠️  No dataset selected (no CSV files found in `{data_dir}`).")
            return

        fname = os.path.join(data_dir, file_dropdown.value)
        p = MSTDBProcessor.from_csv(fname)
        print(f"   ✅ Loaded `{file_dropdown.value}` ({len(p.df)} rows)")

        # Compute Composition on the new processor before publishing it.
        print(f"🔄 Computing Composition column (type={comp_type.value})")
        p.df["Composition"] = p.df.apply(
            lambda row: p.compute_composition(row, composition_type=comp_type.value),
            axis=1
        )

        globals()["processor"] = p
        # reset any previous filter (it belonged to the old dataset)
        globals().pop("filtered_proc", None)

        # repopulate Include‐box (reads the `processor` global set above)
        update_component_options()

# ─────────────────────────────────────────────────────────────────────────────
#  Callback: comp_type change
# ─────────────────────────────────────────────────────────────────────────────
def on_comp_type_changed(change):
    """Observer for ``comp_type``: rebuild Composition and the selector options.

    When a dataset is loaded, its ``Composition`` column is recomputed with
    the newly selected composition type. The Include/Plot selectors are
    refreshed either way. *change* (the traitlets change dict) is not used;
    the current value is read from the widget.
    """
    mode = comp_type.value
    print(f"▶️ comp_type changed to {mode}")
    if "processor" in globals():
        print(f"🔄 Recomputing Composition column (type={mode})")
        processor.df["Composition"] = processor.df.apply(
            lambda r: processor.compute_composition(r, composition_type=mode),
            axis=1,
        )
    update_component_options()

# ─────────────────────────────────────────────────────────────────────────────
#  NEW: Filter callback
# ─────────────────────────────────────────────────────────────────────────────
def on_filter_clicked(_):
    """Filter the loaded dataset to rows that include the selected components.

    The result of ``processor.filter_by_components`` is stored in the
    ``filtered_proc`` global, and the Plot Comps selector is narrowed to the
    chosen components. Prints a warning instead if no dataset is loaded or
    nothing is selected.
    """
    print("▶️ on_filter_clicked fired")
    with output:
        clear_output()
        selected = component_selector.value
        if "processor" not in globals():
            print("⚠️  Load a dataset first.")
        elif not selected:
            print("⚠️  Select at least one element/compound to filter.")
        else:
            criteria = {"include": {comp_type.value: list(selected)}}
            narrowed = processor.filter_by_components(criteria)
            globals()["filtered_proc"] = narrowed
            plot_comp_selector.options = list(selected)
            print(f"   ✅ Filter applied: {len(narrowed.df)} rows remain")

# ─────────────────────────────────────────────────────────────────────────────
#  NEW: Export callback
# ─────────────────────────────────────────────────────────────────────────────
def on_export_clicked(_):
    """Write ``filtered_proc.df`` to ``filtered_data.csv`` in the working dir.

    Shows a download link for the written file. If no filter has been
    applied yet, only a warning is printed.
    """
    print("▶️ on_export_clicked fired")
    with output:
        clear_output()
        if "filtered_proc" in globals():
            target = "filtered_data.csv"
            filtered_proc.df.to_csv(target, index=False)
            print(f"   ✅ Exported filtered data to `{target}`")
            display(FileLink(target))
        else:
            print("⚠️  Nothing to export. Run **Filter** first.")

# ─────────────────────────────────────────────────────────────────────────────
#  Callback: Train the chosen model on the filtered data
# ─────────────────────────────────────────────────────────────────────────────
def on_train_clicked(_):
    """Filter by the selected components and train the chosen model on it.

    Steps: re-apply the Include filter (overwriting ``filtered_proc``), then
    build and train the trainer selected in ``model_dropdown`` with the
    chosen embedder / component count. On success the trainer is published
    as the ``trainer`` global and its key as ``current_model``, and one
    fraction input (0–1) per included component is created for Predict.

    Each meta-learning trainer receives the TARGETS / DERIVED_PROPS lists
    exported by its own module (previously KAN and SNN were handed the ResNet
    lists). Globals are only updated after training finishes, so a failed
    run cannot leave an untrained trainer behind under a new model tag.
    """
    print("▶️ on_train_clicked fired")
    with output:
        clear_output()

        # 1️⃣ Make sure we have a base processor
        if "processor" not in globals():
            print("⚠️  Load a dataset first.")
            return

        # 2️⃣ Grab the user’s include list
        includes = list(component_selector.value)
        if not includes:
            print("⚠️  Select at least one element/compound.")
            return

        # 3️⃣ Apply the filter right now (overwrite filtered_proc)
        print(f"🔄 Applying filter: include {comp_type.value} = {includes}")
        flt = {"include": {comp_type.value: includes}}
        fp = processor.filter_by_components(flt)
        globals()["filtered_proc"] = fp
        print(f"   ✅ Filtered down to {len(fp.df)} rows")

        ds_df = fp.df
        if len(ds_df) == 0:
            print("⚠️  No rows match the selected components; nothing to train on.")
            return

        # 4️⃣ Now train on the filtered rows
        emethod = embedder_dropdown.value
        ncomp = n_components_slider.value
        cm = model_dropdown.value

        print(f"🛠 Training {cm} with embedder={emethod}, n_components={ncomp} on {len(ds_df)} samples…")
        if cm == "sklearn":
            tr = AIModelTrainer(ds_df, embedding_method=emethod, embedding_params={"n_components": ncomp})
            tr.train_all()
        else:
            # model key → (trainer class, its targets, its derived properties)
            meta_trainers = {
                "resnet": (ResNetMetaTrainer, RESNET_TARGETS, RESNET_DERIVED_PROPS),
                "kan": (KANMetaTrainer, KAN_TARGETS, KAN_DERIVED_PROPS),
                "snn": (SNNMetaTrainer, SNN_TARGETS, SNN_DERIVED_PROPS),
            }
            trainer_cls, targets, derived_props = meta_trainers[cm]
            tr = trainer_cls(
                ds_df,
                targets,
                derived_props,
                embedding_method=emethod,
                n_components=ncomp
            )
            tr.train_base()
            tr.train_meta()
            tr.train_joint()

        # Publish only once training has succeeded.
        globals()["trainer"] = tr
        globals()["current_model"] = cm
        print(f"   ✅ {cm} training complete.")

        # 5️⃣ Rebuild the prediction widgets
        prediction_inputs.children = [
            widgets.BoundedFloatText(value=0.0, min=0.0, max=1.0, step=0.01, description=c)
            for c in includes
        ]

# ─────────────────────────────────────────────────────────────────────────────
#  Callback: Predict either direct properties (sklearn) or coeff→properties
# ─────────────────────────────────────────────────────────────────────────────
def on_predict_clicked(_):
    """Predict with the trained model for the composition typed into the inputs.

    The fraction boxes in ``prediction_inputs`` form a one-row DataFrame.
    It is aligned to ``trainer.composition_df.columns`` (missing columns are
    0), expanded with ``trainer.poly`` and scaled per target before calling
    each model in ``trainer.best_models``. Any target whose name contains
    the letter "a" is floored at 1e-10. The result is stored in the
    ``last_preds`` global for plotting.

    For ``current_model == "sklearn"`` the predictions are printed as
    properties. Otherwise they are printed as fit coefficients, followed by
    ``trainer.derived(preds, T)`` at the temperature slider's value.
    """
    print("▶️ on_predict_clicked fired")
    with output:
        clear_output()
        if "trainer" not in globals():
            print("⚠️  Train a model first.")
            return

        import pandas as _pd

        comp_dict = {}
        for box in prediction_inputs.children:
            comp_dict[box.description] = box.value
        print(f"🔍 Composition dict: {comp_dict}")

        # Same column order the model was trained on; absent components → 0.
        feature_cols = trainer.composition_df.columns
        input_df = _pd.DataFrame([comp_dict]).reindex(columns=feature_cols, fill_value=0.0)
        print(f"   → Input columns aligned to {list(feature_cols)}")

        X_poly = trainer.poly.transform(input_df)
        print(f"   → X_poly shape: {X_poly.shape}")

        preds = {}
        for target, model in trainer.best_models.items():
            value = model.predict(trainer.scalers[target].transform(X_poly))[0]
            # NOTE(review): substring test — floors every target whose name
            # contains "a" (rho_a, k_a, cp_a, mu1_a, …), not just densities.
            if "a" in target:
                value = max(value, 1e-10)
            preds[target] = value
            print(f"   • Raw prediction for {target}: {value:.4e}")

        globals()['last_preds'] = preds
        print("🔖 Coefficients saved for plotting")

        direct = current_model == "sklearn"
        print("\n✅ Predicted Properties:" if direct else "\n✅ Predicted Fit Coefficients:")
        for key, val in preds.items():
            print(f"   {key}: {val:.4f}")

        if not direct:
            T = temp_slider.value
            derived = trainer.derived(preds, T)
            print(f"\n   → Derived properties at {T} K:")
            for key, val in derived.items():
                print(f"     {key}: {val:.4f}")


# ─────────────────────────────────────────────────────────────────────────────
#  Callback: Plot predicted properties
# ─────────────────────────────────────────────────────────────────────────────
def on_plot_clicked(_):
    """Plot each derived property of the last prediction across ``temp_range``.

    Takes the coefficient dict that Predict stored in ``last_preds``,
    evaluates ``derive_properties`` at 50 evenly spaced temperatures, and
    draws one figure per property with the fitted equation as its legend.
    Coefficients missing from ``last_preds`` are shown as 0, matching the
    ``.get(key, 0.0)`` defaults of ``derive_properties``. Previously the
    legend indexed them directly and raised ``KeyError``. If ``last_preds``
    contains none of the coefficient keys (e.g. it holds sklearn property
    predictions), a warning is printed instead of plotting all-zero curves.
    """
    with output:
        clear_output()
        if "last_preds" not in globals():
            print("⚠️  Please run Predict first.")
            return

        coeffs = globals()["last_preds"]
        coeff_keys = ("rho_a", "rho_b", "mu1_a", "mu1_b",
                      "k_a", "k_b", "cp_a", "cp_b", "cp_c")
        if not any(key in coeffs for key in coeff_keys):
            print("⚠️  Last prediction holds no fit coefficients "
                  f"({', '.join(coeff_keys)}); nothing to plot.")
            return

        def coef(key):
            # Same default as derive_properties, so the legend matches the curve.
            return coeffs.get(key, 0.0)

        Tmin, Tmax = temp_range.value
        Ts = np.linspace(Tmin, Tmax, 50)

        # map property keys to nice labels
        prop_labels = {
            "rho": "Density ρ (kg/m³)",
            "muA": "Viscosity μ (Arrhenius)",
            "k":   "Thermal Conductivity k (W/m·K)",
            "cp":  "Heat Capacity cₚ (J/kg·K)"
        }
        # legend text per property
        equations = {
            "rho": f"ρ = {coef('rho_a'):.2e} – {coef('rho_b'):.2e}·T",
            "muA": f"μ = {coef('mu1_a'):.2e}·exp({coef('mu1_b'):.2e}/(R·T))",
            "k":   f"k = {coef('k_a'):.2e} + {coef('k_b'):.2e}·T",
            "cp":  (f"cₚ = {coef('cp_a'):.2e} + {coef('cp_b'):.2e}·T + "
                    f"{coef('cp_c'):.2e}/T²"),
        }

        # collect curves
        curves = {}
        for T in Ts:
            for name, val in derive_properties(coeffs, T).items():
                curve = curves.setdefault(name, {"x": [], "y": []})
                curve["x"].append(T)
                curve["y"].append(val)

        for name, data in curves.items():
            if len(data["x"]) < 2:
                continue

            label = prop_labels.get(name, name)
            ys = np.asarray(data["y"], dtype=float)

            # publication-quality plot
            plt.figure(figsize=(6,4))
            plt.plot(data["x"], data["y"], linewidth=2, marker='o', markersize=4)
            plt.grid(True, linestyle="--", alpha=0.6)
            plt.xlabel("Temperature (K)", fontsize=12)
            plt.ylabel(label, fontsize=12)
            # A log axis is only meaningful when every value is finite and > 0.
            if name == "muA" and np.all(np.isfinite(ys) & (ys > 0)):
                plt.yscale("log")
            plt.title(f"{label} vs Temperature", fontsize=14)
            plt.legend([equations.get(name, label)], fontsize=10, loc="best")
            plt.tight_layout()
            plt.show()

# ─────────────────────────────────────────────────────────────────────────────
#  New: Plot vs Composition callback
# ─────────────────────────────────────────────────────────────────────────────
def _comp_predict_coeffs(model_trainer, comp_dict):
    """Run *model_trainer*'s per-target models on one composition; return {target: value}.

    Uses the same feature preparation as the Predict callback: align to
    ``composition_df`` columns (missing → 0), polynomial-expand, scale per
    target, predict, and floor any target whose name contains "a" at 1e-10.
    """
    import pandas as _pd

    df = _pd.DataFrame([comp_dict]).reindex(
        columns=model_trainer.composition_df.columns, fill_value=0.0
    )
    X_poly = model_trainer.poly.transform(df)
    preds = {}
    for target, model in model_trainer.best_models.items():
        p = model.predict(model_trainer.scalers[target].transform(X_poly))[0]
        # NOTE(review): substring test kept for parity with Predict — floors
        # every target containing the letter "a" (rho_a, k_a, cp_a, mu1_a, …).
        if "a" in target:
            p = max(p, 1e-10)
        preds[target] = p
    return preds


def _comp_use_log(name, values):
    """Return True if *name* is the viscosity and all *values* are finite and > 0."""
    arr = np.asarray(values, dtype=float)
    return name == "muA" and arr.size > 0 and bool(np.all(np.isfinite(arr) & (arr > 0)))


def _simplex_grid(n):
    """Return arrays (f1, f2, f3) of lattice fractions i/n, j/n, (n-i-j)/n covering the simplex."""
    points = [
        (i / n, j / n, (n - i - j) / n)
        for i in range(n + 1)
        for j in range(n + 1 - i)
    ]
    f1, f2, f3 = (np.array(col) for col in zip(*points))
    return f1, f2, f3


def _ternary_xy(f2, f3):
    """Map fractions to triangle coords: pure 1→(0,0), pure 2→(1,0), pure 3→(0.5, √3/2)."""
    return 0.5 * (2 * f2 + f3), (np.sqrt(3) / 2) * f3


def _plot_line_sweep(xgrid, comp_dicts, T, predict, xlabel, title_for):
    """Plot every derived property at *T* along a 1-D sequence of compositions."""
    series = {}
    for comp_dict in comp_dicts:
        for name, val in derive_properties(predict(comp_dict), T).items():
            series.setdefault(name, []).append(val)

    for name, vals in series.items():
        plt.figure(figsize=(6, 4))
        plt.plot(xgrid, vals, '-o', linewidth=2, markersize=4)
        if _comp_use_log(name, vals):
            plt.yscale("log")
        plt.xlabel(xlabel, fontsize=12)
        plt.ylabel(name, fontsize=12)
        plt.title(title_for(name), fontsize=14)
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.show()


def _plot_ternary(labels, T, predict, n_coarse=20, n_fine=100):
    """Draw a filled-contour ternary map of each derived property over three components."""
    f1c, f2c, f3c = _simplex_grid(n_coarse)
    xc, yc = _ternary_xy(f2c, f3c)
    # Predict once per lattice point and reuse the result for all four
    # properties (the model used to be re-run over the lattice per property).
    coarse_props = [
        derive_properties(predict({labels[0]: a, labels[1]: b, labels[2]: c}), T)
        for a, b, c in zip(f1c.tolist(), f2c.tolist(), f3c.tolist())
    ]

    # The fine sampling grid is identical for every property; build it once.
    _, f2f, f3f = _simplex_grid(n_fine)
    x_fine, y_fine = _ternary_xy(f2f, f3f)

    for name in ["rho", "muA", "k", "cp"]:
        # 1) coarse samples with a finite value for this property
        keep = [
            i for i, props in enumerate(coarse_props)
            if name in props and np.isfinite(props[name])
        ]
        if len(keep) < 3:  # LinearNDInterpolator needs a 2-D point cloud
            continue
        values = np.array([coarse_props[i][name] for i in keep], dtype=float)

        # 2) fit interpolator and 3) evaluate on the fine grid
        interp = LinearNDInterpolator(np.column_stack((xc[keep], yc[keep])), values)
        z_fine = interp(x_fine, y_fine)

        # 4) drop NaN (outside hull) and ±inf samples
        mask = np.isfinite(z_fine)
        if np.count_nonzero(mask) < 3:  # Triangulation needs ≥ 3 points
            continue
        x_f, y_f, z_f = x_fine[mask], y_fine[mask], z_fine[mask]

        # 5) triangulate and contour
        tri = mtri.Triangulation(x_f, y_f)
        plt.figure(figsize=(6, 6))
        if _comp_use_log(name, z_f):
            from matplotlib.colors import LogNorm
            cf = plt.tricontourf(
                tri, z_f,
                levels=20,
                cmap='viridis',
                norm=LogNorm(),
                alpha=0.9
            )
        else:
            cf = plt.tricontourf(tri, z_f, levels=20, cmap='viridis', alpha=0.9)

        plt.colorbar(cf, label=name, shrink=0.8)

        # 6) border & vertex labels
        verts_x = [0, 1, 0.5, 0]
        verts_y = [0, 0, np.sqrt(3)/2, 0]
        plt.plot(verts_x, verts_y, 'k-', lw=1)
        plt.text(0, -0.05,  labels[0], ha='center', va='top', fontsize=12)
        plt.text(1, -0.05,  labels[1], ha='center', va='top', fontsize=12)
        plt.text(0.5, np.sqrt(3)/2+0.03, labels[2], ha='center', va='bottom', fontsize=12)

        plt.title(f"{name} ternary at {T} K", fontsize=14)
        plt.axis('off')
        plt.tight_layout()
        plt.show()


def on_comp_plot_clicked(_):
    """Plot derived properties versus composition at the slider temperature.

    The chosen Plot Comps decide the plot type:
      * one component  → fraction swept 0→1, all other components 0;
      * two components → binary sweep c1 = f, c2 = 1 - f;
      * three or more  → ternary contour maps of the first three (a notice is
        printed when extra components are ignored).
    Coefficients come from the trained ``trainer`` (same feature pipeline as
    Predict) and are turned into properties with ``derive_properties``.
    """
    print("▶️ on_comp_plot_clicked fired")
    with output:
        clear_output()
        if "trainer" not in globals():
            print("⚠️  Train a model first.")
            return

        comps = list(plot_comp_selector.value)
        if not comps:
            print("⚠️  Select at least one component.")
            return

        T = temp_slider.value
        grid = np.linspace(0, 1, 50)
        model_trainer = globals()["trainer"]

        def predict(comp_dict):
            return _comp_predict_coeffs(model_trainer, comp_dict)

        if len(comps) == 1:
            comp = comps[0]
            _plot_line_sweep(
                grid, [{comp: f} for f in grid], T, predict,
                xlabel=f"{comp} Fraction",
                title_for=lambda name: f"{name} vs {comp} at {T} K",
            )
        elif len(comps) == 2:
            c1, c2 = comps
            _plot_line_sweep(
                grid, [{c1: f, c2: 1 - f} for f in grid], T, predict,
                xlabel=f"{c1} Fraction",
                title_for=lambda name: f"{name} over {c1}/{c2} at {T} K",
            )
        else:
            if len(comps) > 3:
                print(f"ℹ️  Ternary plot uses only the first three components: {comps[:3]}")
            _plot_ternary(comps[:3], T, predict)


# ─────────────────────────────────────────────────────────────────────────────
#  Wire callbacks and display the UI
# ─────────────────────────────────────────────────────────────────────────────
# Attach handlers (registration order kept: load, comp-type observer, rest).
load_button.on_click(on_load_clicked)
comp_type.observe(on_comp_type_changed, names="value")
for _btn, _handler in (
    (filter_button, on_filter_clicked),
    (export_button, on_export_clicked),
    (train_button, on_train_clicked),
    (predict_button, on_predict_clicked),
    (plot_button, on_plot_clicked),
    (comp_plot_button, on_comp_plot_clicked),
):
    _btn.on_click(_handler)
del _btn, _handler

# Top-to-bottom layout: load → filter → embed/train → predict → plots → log.
ui = widgets.VBox(children=[
    widgets.HBox([file_dropdown, load_button]),
    widgets.HBox([comp_type, component_selector]),
    widgets.HBox([filter_button, export_button]),
    widgets.HBox([embedder_dropdown, n_components_slider]),
    widgets.HBox([model_dropdown, train_button]),
    widgets.Label("Prediction Inputs:"),
    prediction_inputs,
    temp_slider,
    predict_button,
    temp_range,
    plot_button,
    widgets.Label("Composition Evolution:"),
    widgets.Label("Choose comps for composition‐plot:"),
    plot_comp_selector,
    comp_plot_button,
    output,
])

display(ui)




VBox(children=(HBox(children=(Dropdown(description='Dataset:', options=('mstdb_janz_processed.csv', 'mstdb_pro…

▶️ on_load_clicked fired
▶️ comp_type changed to compounds
🔄 Recomputing Composition column (type=compounds)
🔄 update_component_options fired
   processor has 19 elements and 34 compounds
   → set component_selector.options to ['AlCl3', 'BeCl2', 'BeF2', 'CaCl2', 'CaF2', '…']
▶️ on_filter_clicked fired
▶️ on_train_clicked fired
▶️ on_train_clicked fired
