# In [17]:  (Jupyter cell prompt, kept as a comment so the file parses as Python)
#!/usr/bin/env python
"""
Unified EIS Training + Inference + Dynamic RUL (v8.8 – Robust SoC HGB + dual-space OOD + hf/10Hz autoscale)
+ Gradio UI (Jupyter-friendly)
"""

from __future__ import annotations
import re, json, math, random, warnings, joblib, hashlib, uuid, io, sys, os
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.io import loadmat
from scipy import linalg
from scipy.interpolate import interp1d

from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.decomposition import PCA
from sklearn.isotonic import IsotonicRegression
from sklearn.covariance import LedoitWolf
from sklearn.exceptions import ConvergenceWarning

import matplotlib.pyplot as plt
from PIL import Image

# Quiet down some noisy warnings from GPR optimization (optional).
# The filters are scoped by warning category and by the module that raises them,
# so warnings from other packages remain visible.
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn.gaussian_process")
warnings.filterwarnings("ignore", category=ConvergenceWarning, module="sklearn")

# =========================
# Helpers: environment detection
# =========================
def _running_in_notebook() -> bool:
    """Return True when the active IPython shell class is a notebook-style one.

    Any failure (IPython not installed, no active shell, ...) is reported as
    ``False`` rather than raised.
    """
    notebook_shells = ("ZMQInteractiveShell", "Shell")  # Jupyter/Lab/VSCode
    try:
        from IPython import get_ipython  # noqa
        shell_name = get_ipython().__class__.__name__
    except Exception:
        return False
    return shell_name in notebook_shells

# =========================
# 1. CONFIGURATION
# =========================
@dataclass
class Config:
    """Central, mutable settings object for the EIS training/inference pipeline.

    A single instance is created at module level (``cfg``) and read directly by
    the functions in this file. Field defaults are the values used unless a
    caller mutates the instance. ``config_signature`` hashes every field except
    ``EIS_TEST_FILES``.
    """
    # Training data directories (update if needed)
    EIS_DIR: Path = Path(r"C:\Users\tgondal0\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\EIS_Test")
    CAP_DIR: Path = Path(r"C:\Users\tgondal0\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\Capacity_Check")
    MODEL_DIR: Path = Path("models_eis_phase2_phys")

    # Test files
    # None is a sentinel (a mutable list cannot be a dataclass default); the
    # module fills in a default list right after instantiation.
    EIS_TEST_FILES: Optional[List[Path]] = None  # assigned after instantiation

    # Frequency interpolation grid
    F_MIN: float = 1e-2
    F_MAX: float = 1e4
    N_FREQ: int = 60

    # Uncertainty control
    SOH_STD_MAX_OOD: float = 2.0

    # Train / split settings
    TEST_FRAC: float = 0.2
    GROUP_KFOLDS: int = 0
    RANDOM_STATE: int = 42

    # PCA toggles
    USE_PCA_SOC: bool = False  # <-- tree model (HGB) doesn't need PCA; default off
    USE_PCA_SOH: bool = False
    PCA_SOC_COMPONENTS: int = 25
    PCA_SOH_COMPONENTS: int = 30

    # Feature group toggles
    INCLUDE_RAW_RE_IM: bool = True
    INCLUDE_BASICS: bool = True
    INCLUDE_F_FEATS: bool = True
    INCLUDE_PHYSICAL: bool = True
    INCLUDE_DRT: bool = True
    INCLUDE_BAND_STATS: bool = True
    INCLUDE_DIFF_SLOPES: bool = True

    # DRT params
    DRT_POINTS: int = 60
    DRT_TAU_MIN: float = 1e-4
    DRT_TAU_MAX: float = 1e4
    DRT_LAMBDA: float = 1e-2

    # Capacity-based refinement
    REFINE_SOH_WITH_CAPACITY: bool = True

    # SoH modeling
    MAX_GPR_TRAIN_SAMPLES: int = 3500
    INCLUDE_NORMALIZED_SHAPE_MODEL: bool = True
    ENSEMBLE_SOH: bool = True
    NORMALIZE_SHAPE_BY_HF_RE: bool = True

    # SoC modeling (no GPR)
    SOC_INCLUDE_SHAPE_MODEL: bool = True
    SOC_MAX_GPR_TRAIN_SAMPLES: int = 3500  # kept for symmetry; unused now

    # RUL parameters
    DECISION_SOH_PERCENT: float = 50.0
    ILLUSTRATIVE_MIN_SOH: float = 40.0
    CPP_ROLLING_WINDOW: int = 5
    CPP_MIN_POINTS: int = 6
    CPP_FALLBACK: float = 20.0

    # Inference extras
    TEST_TEMPERATURE_OVERRIDE: Optional[float] = 25.0
    FORCE_RETRAIN: bool = False  # retrain anyway (overrides signature check) if True

    # Saving / logging
    SAVE_FEATURE_TABLE: bool = True
    VERBOSE: bool = True
    FEATURE_VERSION: int = 88  # bump: robust SoC HGB, dual-space OOD, hf/10Hz autoscale

    # OOD thresholds (SoH)
    MAHAL_THRESHOLD: float = 10.0
    GP_ARD_NORM_THRESHOLD: float = 6.0

    # Projection curve shape exponent
    PLOT_EXPONENT: float = 1.25

    # Thresholds to report/plot
    TARGET_SOH_THRESHOLDS: Tuple[float, ...] = (80.0, 50.0, 40.0)

    # --- SoC OOD / prior settings (no fixed 50 prior) ---
    OOD_SOC_ENABLE: bool = True
    OOD_SOC_Q: float = 0.999          # quantile for distance threshold
    OOD_SOC_SHRINK_SCALE: float = 10.0
    OOD_SOC_W_MIN: float = 0.45       # keep decent weight on raw model
    SOC_OOD_K: int = 30               # neighbors for KNN prior
    SOC_CALIBRATE_ON_OOD: bool = False  # avoid isotonic pull on OOD

    # Std caps to keep SoC std realistic
    SOC_STD_CAP_MULT_IN: float = 1.00   # in-domain cap ~ val RMSE
    SOC_STD_CAP_MULT_OOD: float = 1.20  # OOD cap slightly larger

    # Optional: de-discretize labels
    SOC_LABEL_JITTER: float = 0.0

# Module-level configuration instance read by the functions in this file.
cfg = Config()
if cfg.EIS_TEST_FILES is None:
    # Default inference input when no test files have been supplied.
    cfg.EIS_TEST_FILES = [Path("Mazda-Battery-Cell9.xlsx")]
# Import-time side effect: make sure the model output directory exists.
cfg.MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Helper to tweak paths easily in notebooks
def set_paths(eis_dir: str | Path, cap_dir: str | Path, model_dir: str | Path | None = None):
    """Point the global ``cfg`` at new data directories.

    ``eis_dir`` and ``cap_dir`` always replace ``cfg.EIS_DIR`` / ``cfg.CAP_DIR``.
    ``model_dir`` is optional; when given it replaces ``cfg.MODEL_DIR`` and the
    directory is created if it does not exist yet.
    """
    cfg.EIS_DIR = Path(eis_dir)
    cfg.CAP_DIR = Path(cap_dir)
    if model_dir is None:
        return
    cfg.MODEL_DIR = Path(model_dir)
    cfg.MODEL_DIR.mkdir(parents=True, exist_ok=True)

# =========================
# 2. UTILITIES
# =========================
def set_seed(seed: int):
    """Seed both Python's ``random`` module and NumPy's legacy global RNG."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
# Seed the global RNGs once at import time from the configured random state.
set_seed(cfg.RANDOM_STATE)

def to_jsonable(x):
    """Recursively convert ``Path`` objects to strings inside dicts/lists/tuples.

    Dict values (not keys) are converted; lists and tuples both come back as
    lists; anything else is returned unchanged.
    """
    if isinstance(x, Path):
        return str(x)
    if isinstance(x, dict):
        converted = {}
        for key, value in x.items():
            converted[key] = to_jsonable(value)
        return converted
    if isinstance(x, (list, tuple)):
        return list(map(to_jsonable, x))
    return x

def config_signature(cfg: Config) -> str:
    """Return a SHA-256 hex digest identifying the given configuration.

    The dataclass is flattened with ``asdict``, the three directory fields are
    stringified, ``EIS_TEST_FILES`` is dropped, and the remainder is serialised
    as key-sorted JSON before hashing.
    """
    fields = dict(asdict(cfg))
    for dir_key in ("EIS_DIR", "CAP_DIR", "MODEL_DIR"):
        fields[dir_key] = str(fields[dir_key])
    fields.pop("EIS_TEST_FILES", None)
    payload = json.dumps(fields, sort_keys=True).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()

# Canonical log-spaced frequency grid, ordered from F_MAX down to F_MIN.
# It is computed once at import time, so later changes to cfg.F_MIN / F_MAX /
# N_FREQ do not update it.
CANON_FREQ = np.geomspace(cfg.F_MAX, cfg.F_MIN, cfg.N_FREQ)

# =========================
# 3. REGEX
# =========================
# Filename pattern for EIS spectra. Named groups: CellID, SOH (one of
# 80/85/90/95/100), Temp, SOC and RealSOH, all captured as digit strings.
EIS_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_(?P<Temp>\d+)degC_(?P<SOC>\d+)SOC_(?P<RealSOH>\d+)"
)
# Filename pattern for capacity-check files. Named groups: CellID, SOH, Temp
# and Cycle.
CAP_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_Capacity_Check_(?P<Temp>\d+)degC_(?P<Cycle>\d+)cycle"
)

# =========================
# 4. PARSERS
# =========================
def parse_eis_metadata(stem: str) -> Optional[Dict[str, Any]]:
    """Extract EIS metadata from a file stem using ``EIS_META_PATTERN``.

    Returns ``None`` when the stem does not match. Otherwise returns a dict
    with ``CellID`` (prefixed with "Cell"), ``SOH_stage``, ``SOC``, ``Temp`` and
    ``RealSOH_file`` (the trailing integer divided by 100).
    """
    match = EIS_META_PATTERN.search(stem)
    if match is None:
        return None
    cell, stage, temp, soc, real_soh = match.group("CellID", "SOH", "Temp", "SOC", "RealSOH")
    return {
        "CellID": f"Cell{cell}",
        "SOH_stage": int(stage),
        "SOC": float(soc),
        "Temp": int(temp),
        "RealSOH_file": int(real_soh)/100.0
    }

def parse_cap_metadata(stem: str) -> Optional[Dict[str, Any]]:
    """Extract capacity-check metadata from a file stem using ``CAP_META_PATTERN``.

    Returns ``None`` when the stem does not match. Otherwise returns a dict
    with ``CellID`` (prefixed with "Cell"), ``SOH_stage``, ``Temp`` and
    ``CycleIndex``.
    """
    match = CAP_META_PATTERN.search(stem)
    if match is None:
        return None
    cell, stage, temp, cycle = match.group("CellID", "SOH", "Temp", "Cycle")
    return {
        "CellID": f"Cell{cell}",
        "SOH_stage": int(stage),
        "Temp": int(temp),
        "CycleIndex": int(cycle)
    }

# =========================
# 5. LOW-LEVEL LOADERS / INTERPOLATION
# =========================
def _find_matrix(mat_dict: dict):
    """Return the first 2-D ndarray value with >= 10 rows and >= 3 columns.

    Values are scanned in the dict's iteration order; ``None`` is returned when
    nothing qualifies.
    """
    candidates = (
        value for value in mat_dict.values()
        if isinstance(value, np.ndarray)
        and value.ndim == 2
        and value.shape[1] >= 3
        and value.shape[0] >= 10
    )
    return next(candidates, None)

def _interp_channel(freq_raw, y_raw, freq_target):
    """Linearly interpolate one impedance channel onto ``freq_target``.

    Parameters
    ----------
    freq_raw, y_raw : array-like
        Measured frequencies and the matching channel values (same length).
        Either sweep direction is accepted.
    freq_target : array-like
        Frequencies at which the channel is evaluated.

    Returns
    -------
    numpy.ndarray
        Interpolated values. Targets below the lowest measured frequency take
        the value at the lowest measured frequency, and targets above the
        highest take the value at the highest.

    Raises
    ------
    ValueError
        If no finite (frequency, value) pair is available.

    Notes
    -----
    The previous ``interp1d`` call passed ``fill_value=(y_raw[0], y_raw[-1])``
    after forcing descending order. Because ``interp1d`` sorts x ascending and
    reads the tuple as (below, above), the two edge values were swapped.
    """
    freq_raw = np.asarray(freq_raw).astype(float)
    y_raw = np.asarray(y_raw).astype(float)

    # Drop samples that cannot take part in interpolation (e.g. coerced NaNs).
    finite = np.isfinite(freq_raw) & np.isfinite(y_raw)
    freq_raw = freq_raw[finite]
    y_raw = y_raw[finite]
    if freq_raw.size == 0:
        raise ValueError("No finite (frequency, value) samples to interpolate.")

    # Normalise to a high->low sweep so that, for duplicated frequencies, the
    # same "first occurrence" is kept as before.
    if freq_raw[0] < freq_raw[-1]:
        freq_raw = freq_raw[::-1]
        y_raw = y_raw[::-1]

    # np.unique returns ascending unique frequencies plus first-occurrence
    # indices, which is exactly the ascending xp/fp pair np.interp requires.
    uniq, first_idx = np.unique(freq_raw, return_index=True)

    # np.interp clamps to fp[0] on the left and fp[-1] on the right.
    return np.asarray(np.interp(freq_target, uniq, y_raw[first_idx]))

# Candidate column names for the frequency, real-impedance and imaginary-impedance
# columns of tabular inputs. They are matched case-insensitively by
# _select_column: first as exact names, then as substrings, in list order.
FREQ_CANDS = ["frequency","freq","f","hz","frequency(hz)","Frequency(Hz)"]
RE_CANDS   = ["zreal","re(z)","re","real","z_re","zreal(ohm)","re (ohm)","re(z) (ohm)","Zreal","Zreal (ohm)","Zreal(ohm)"]
IM_CANDS   = ["-zimag","zimag","im(z)","im","imag","imaginary","z_im","zimg","z_imag"," -Zimag (ohm)"," -Zimag(ohm)","-Zimag","Zimag","Zimag (ohm)"]

def _select_column(df: pd.DataFrame, cands: List[str]) -> Optional[str]:
    """Pick the DataFrame column that best matches a list of candidate names.

    Matching is case-insensitive and runs in two passes, each in ``cands``
    order: first an exact name match, then a substring match (the first column
    containing the candidate wins). Returns the original column label, or
    ``None`` if nothing matches.

    Column labels are compared through ``str(label)``, so tables with
    non-string headers (e.g. integer labels) no longer raise ``AttributeError``.
    """
    labelled = [(col, str(col).lower()) for col in df.columns]
    wanted = [cand.lower() for cand in cands]

    # Exact pass. For duplicate lowercase names the later column wins, as it
    # did with the original dict comprehension.
    exact = {low: col for col, low in labelled}
    for name in wanted:
        if name in exact:
            return exact[name]

    # Substring pass.
    for name in wanted:
        for col, low in labelled:
            if name in low:
                return col
    return None

def load_mat_eis(path: Path):
    """Load a ``.mat`` spectrum and return its first three columns as floats.

    The matrix is located with ``_find_matrix``; columns 0, 1 and 2 are
    returned as a 3-tuple of 1-D float arrays. Raises ``ValueError`` when the
    file holds no suitable matrix.
    """
    matrix = _find_matrix(loadmat(path))
    if matrix is None:
        raise ValueError(f"No valid EIS matrix in {path.name}")
    col0, col1, col2 = (matrix[:, j].astype(float) for j in range(3))
    return col0, col1, col2

def load_table_eis(path: Path):
    """Load a CSV/Excel spectrum table.

    Returns ``(freq, re, im, used_freq_from_file)``. When no frequency column
    is found, a log-spaced grid from ``cfg.F_MAX`` to ``cfg.F_MIN`` is
    synthesised and the flag is ``False``. Non-numeric cells become NaN. If the
    NaN-ignoring mean of the imaginary column is positive, the column is
    negated.

    Raises ``ValueError`` for an empty table or when the Re/Im columns cannot
    be identified.
    """
    reader = pd.read_csv if path.suffix.lower() == ".csv" else pd.read_excel
    df = reader(path)
    if df.empty:
        raise ValueError("Empty table.")

    fcol = _select_column(df, FREQ_CANDS)
    recol = _select_column(df, RE_CANDS)
    imcol = _select_column(df, IM_CANDS)
    if recol is None or imcol is None:
        raise ValueError(f"Missing Re/Im columns in {path.name}")

    def _as_numeric(column):
        return pd.to_numeric(df[column], errors="coerce").to_numpy()

    re_vals = _as_numeric(recol)
    im_vals = _as_numeric(imcol)

    used_freq = fcol is not None
    if used_freq:
        freq_vals = _as_numeric(fcol)
    else:
        freq_vals = np.geomspace(cfg.F_MAX, cfg.F_MIN, min(len(re_vals), len(im_vals)))

    # Trim all three channels to a common length.
    n = min(len(freq_vals), len(re_vals), len(im_vals))
    freq_vals, re_vals, im_vals = freq_vals[:n], re_vals[:n], im_vals[:n]

    if np.nanmean(im_vals) > 0:
        im_vals = -im_vals
    return freq_vals.astype(float), re_vals.astype(float), im_vals.astype(float), used_freq

def load_any_inference(path: Path):
    """Load an inference spectrum from ``.mat``, ``.csv``, ``.xls`` or ``.xlsx``.

    Returns ``(freq, re, im, used_freq_from_file)``; the flag is always
    ``True`` for ``.mat`` input. Raises ``ValueError`` for any other extension.
    """
    suf = path.suffix.lower()
    if suf == ".mat":
        freq, re_z, im_z = load_mat_eis(path)
        return freq, re_z, im_z, True
    if suf in (".csv", ".xls", ".xlsx"):
        freq, re_z, im_z, used = load_table_eis(path)
        return freq, re_z, im_z, used
    raise ValueError(f"Unsupported test file extension: {suf}")

# =========================
# 6. FEATURE ENGINEERING
# =========================
def compute_F_features(freq, re_i, im_i):
    """Compute the seven landmark features F1..F7 of a spectrum.

    F1/F3 are the first/last real values and F2 is the real value at the peak
    of ``-im``. F4 is the real value linearly interpolated at the first sign
    change of ``im`` (NaN if there is none). F5 is F2 - F1 (NaN when the peak is
    the first sample). F6 is the minimum of ``im``. F7 is the real value at the
    grid point nearest 10 (in the units of ``freq``).
    """
    peak = int(np.argmax(-im_i))
    first_re, peak_re, last_re = re_i[0], re_i[peak], re_i[-1]

    signs = np.sign(im_i)
    crossings = np.flatnonzero(signs[:-1] != signs[1:])
    if crossings.size:
        k = crossings[0]
        y0, y1 = im_i[k], im_i[k+1]
        frac = -y0/(y1 - y0 + 1e-12)  # epsilon guards against a zero denominator
        crossing_re = re_i[k] + frac*(re_i[k+1]-re_i[k])
    else:
        crossing_re = np.nan

    peak_offset = (peak_re - first_re) if peak > 0 else np.nan
    min_im = np.min(im_i)

    nearest = int(np.argmin(np.abs(freq - 10.0)))
    return [first_re, peak_re, last_re, crossing_re, peak_offset, min_im, re_i[nearest]]

# Labels for the values returned by physical_features, in the same order.
PHYSICAL_FEATURE_NAMES = [
    "Rs","Rct","tau_peak","warburg_sigma","arc_quality",
    "phase_mean_mid","phase_std_mid","phase_min","lf_slope_negIm","norm_arc"
]

def physical_features(freq, re_i, im_i):
    """Compute ten physically motivated descriptors of a spectrum.

    The returned list follows ``PHYSICAL_FEATURE_NAMES``: Rs, Rct, tau_peak,
    warburg_sigma, arc_quality, phase_mean_mid, phase_std_mid, phase_min,
    lf_slope_negIm, norm_arc. Entries whose preconditions are not met (too few
    points in a band, non-positive peak frequency, ...) are NaN.
    """
    freq, re_i, im_i = (np.asarray(arr) for arr in (freq, re_i, im_i))
    neg_im = -im_i
    peak = int(np.argmax(neg_im))

    # Resistive landmarks: first sample, -Im peak and last sample.
    Rs = float(re_i[0])
    r_peak = float(re_i[peak])
    r_last = float(re_i[-1])
    Rct = max(r_peak - Rs, 0.0)
    norm_arc = (r_last - Rs) / (Rs + 1e-9)

    # Time constant associated with the -Im peak.
    f_peak = float(freq[peak])
    if f_peak > 0:
        tau_peak = 1.0/(2*math.pi*f_peak)
    else:
        tau_peak = np.nan

    # Slope of Re versus omega^-0.5 over the last few samples.
    warburg_sigma = np.nan
    tail = min(10, len(freq)//3)
    if tail >= 4:
        inv_sqrt_w = (2*np.pi*freq[-tail:])**(-0.5)
        if len(np.unique(inv_sqrt_w)) > 2:
            warburg_sigma = float(np.polyfit(inv_sqrt_w, re_i[-tail:], 1)[0])

    # Phase statistics; the "mid" band is 1..100 in the units of freq.
    phase = np.arctan2(neg_im, re_i)
    in_mid = (freq >= 1) & (freq <= 100)
    if in_mid.sum() > 2:
        mid_phase = phase[in_mid]
        phase_mean_mid = float(mid_phase.mean())
        phase_std_mid = float(mid_phase.std())
    else:
        phase_mean_mid = phase_std_mid = np.nan
    phase_min = float(phase.min())

    # Slope of -Im versus log10(f) for f <= 1.
    in_lf = (freq <= 1.0)
    if in_lf.sum() >= 4:
        lf_slope = np.polyfit(np.log10(freq[in_lf]+1e-12), neg_im[in_lf], 1)[0]
    else:
        lf_slope = np.nan

    arc_quality = (neg_im.max() - neg_im.min())/(abs(neg_im.mean())+1e-9)
    return [Rs, Rct, tau_peak, warburg_sigma, arc_quality,
            phase_mean_mid, phase_std_mid, phase_min, lf_slope, norm_arc]

# Frequency bands as (upper, lower) bounds; both bounds are inclusive in band_stats.
BANDS = [(1e4,1e3),(1e3,1e2),(1e2,10),(10,1),(1,1e-1),(1e-1,1e-2)]
def band_stats(freq, re_i, im_i):
    """Return mean and std of |Z| for each band in ``BANDS``.

    The result is a flat list ``[mean0, std0, mean1, std1, ...]``. A band with
    fewer than two samples contributes ``[nan, nan]``.
    """
    freq = np.asarray(freq)
    feats = []
    for upper, lower in BANDS:
        in_band = (freq >= lower) & (freq <= upper)
        if in_band.sum() <= 1:
            feats.extend((np.nan, np.nan))
            continue
        magnitude = np.hypot(re_i[in_band], im_i[in_band])
        feats.extend((magnitude.mean(), magnitude.std()))
    return feats

def diff_slopes(freq, re_i, im_i, segments=5):
    """Fit per-segment slopes of Re and -Im against log10(frequency).

    The log-frequency range is cut into ``segments`` equal-width pieces with
    inclusive edges. Each piece contributes ``[slope_re, slope_neg_im]``, or
    ``[nan, nan]`` when it holds fewer than three samples.
    """
    logf = np.log10(freq)
    edges = np.linspace(logf.min(), logf.max(), segments+1)
    neg_im = -im_i
    slopes = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        chosen = (logf >= lo) & (logf <= hi)
        if chosen.sum() < 3:
            slopes.extend([np.nan, np.nan])
            continue
        x = logf[chosen]
        slopes.append(np.polyfit(x, re_i[chosen], 1)[0])
        slopes.append(np.polyfit(x, neg_im[chosen], 1)[0])
    return slopes

# Labels for the seven values returned by drt_features, in the same order.
DRT_FEATURE_NAMES = [
    "drt_sum","drt_mean_logtau","drt_var_logtau","drt_peak_tau",
    "drt_peak_gamma","drt_frac_low_tau","drt_frac_high_tau"
]

def compute_drt(freq,re_i,im_i,tau_min,tau_max,n_tau,lam):
    """Estimate a distribution of relaxation times by ridge-regularised least squares.

    A kernel with real part ``1/(1+(w*tau)^2)`` and imaginary part
    ``-w*tau/(1+(w*tau)^2)`` is fitted jointly to ``re_i - re_i[0]`` and
    ``im_i``, using ``lam`` as the ridge weight. Negative solution entries are
    clipped to zero.

    Returns ``(tau, gamma)``, where ``tau`` is log-spaced from ``tau_max`` down
    to ``tau_min`` with ``n_tau`` points.
    """
    omega = 2*np.pi*freq
    tau = np.geomspace(tau_max, tau_min, n_tau)

    wt = omega[:, None]*tau[None, :]
    denom = 1+wt**2
    kernel = np.vstack([1.0/denom, -wt/denom])

    # Subtract the first real sample as the series-resistance offset.
    target = np.concatenate([re_i - re_i[0], im_i])

    # Normal equations with a Tikhonov (identity) penalty.
    lhs = kernel.T @ kernel + lam*np.eye(n_tau)
    rhs = kernel.T @ target
    gamma = np.clip(linalg.solve(lhs, rhs, assume_a='pos'), 0, None)
    return tau, gamma

def drt_features(freq,re_i,im_i):
    """Summarise the DRT of a spectrum as seven scalars.

    The values follow ``DRT_FEATURE_NAMES``: total gamma (plus a 1e-12
    epsilon), gamma-weighted mean and variance of log10(tau), tau and gamma at
    the gamma peak, and the share of weight at or below / above the median
    log10(tau). Any exception yields seven NaNs.
    """
    try:
        tau, gamma = compute_drt(
            freq, re_i, im_i,
            cfg.DRT_TAU_MIN, cfg.DRT_TAU_MAX,
            cfg.DRT_POINTS, cfg.DRT_LAMBDA,
        )
        log_tau = np.log10(tau)
        total = gamma.sum()+1e-12
        weights = gamma/total
        centroid = float((weights*log_tau).sum())
        spread = float((weights*(log_tau-centroid)**2).sum())
        peak = int(np.argmax(gamma))
        low_share = float(weights[log_tau <= np.median(log_tau)].sum())
        return [total, centroid, spread,
                float(tau[peak]), float(gamma[peak]),
                low_share, 1-low_share]
    except Exception:
        return [np.nan]*7

def build_feature_vector(re_i, im_i, temp, freq, include_names=False):
    """Assemble the flat feature vector for one interpolated spectrum.

    Feature groups are appended in a fixed order (raw Re/Im, basics, F-features,
    physical, band stats, segment slopes, DRT), each gated by its ``cfg``
    toggle. The temperature is always the last entry. NaN and +/-inf are
    replaced by 0.0.

    Returns the vector, or ``(vector, names)`` when ``include_names`` is true.
    """
    chunks = []
    labels = []

    def _push(values, value_labels):
        chunks.append(np.array(values))
        labels.extend(value_labels)

    if cfg.INCLUDE_RAW_RE_IM:
        chunks.extend([re_i, im_i])
        labels.extend(f"Re_{i}" for i in range(len(re_i)))
        labels.extend(f"Im_{i}" for i in range(len(im_i)))
    if cfg.INCLUDE_BASICS:
        mag = np.hypot(re_i, im_i)
        _push([re_i[0], re_i[-1], re_i[-1]-re_i[0], mag.max(), mag.mean(), mag.std()],
              ["hf_re","lf_re","arc_diam","zmag_max","zmag_mean","zmag_std"])
    if cfg.INCLUDE_F_FEATS:
        _push(compute_F_features(freq,re_i,im_i), [f"F{i}" for i in range(1,8)])
    if cfg.INCLUDE_PHYSICAL:
        _push(physical_features(freq,re_i,im_i), PHYSICAL_FEATURE_NAMES)
    if cfg.INCLUDE_BAND_STATS:
        band_labels = []
        for bi in range(len(BANDS)):
            band_labels.extend((f"band{bi}_mean", f"band{bi}_std"))
        _push(band_stats(freq,re_i,im_i), band_labels)
    if cfg.INCLUDE_DIFF_SLOPES:
        slopes = diff_slopes(freq,re_i,im_i)
        slope_labels = []
        for i in range(len(slopes)//2):
            slope_labels.extend((f"slope_re_seg{i}", f"slope_negIm_seg{i}"))
        _push(slopes, slope_labels)
    if cfg.INCLUDE_DRT:
        _push(drt_features(freq,re_i,im_i), DRT_FEATURE_NAMES)
    _push([temp], ["Feat_Temp"])

    vec = np.concatenate(chunks).astype(float)
    vec = np.nan_to_num(vec, nan=0.0, posinf=0.0, neginf=0.0)
    if include_names:
        return vec, labels
    return vec

def build_shape_normalized(re_i, im_i, k: int = 5):
    """Scale a spectrum by the NaN-ignoring median of its first ``k`` real samples.

    At least one sample is always used. If that median is non-finite or smaller
    than 1e-9 in magnitude, a scale of 1.0 is used, so the spectrum is returned
    unchanged. Returns ``(re_i / scale, im_i / scale)``.
    """
    head_len = max(1, min(k, len(re_i)))
    median_hf = float(np.nanmedian(re_i[:head_len]))
    usable = np.isfinite(median_hf) and abs(median_hf) >= 1e-9
    scale = median_hf if usable else 1.0
    return re_i / scale, im_i / scale

# =========================
# 7. CAPACITY & CPP
# =========================
def _capacity_values_to_1d(x):
    """Flatten a MATLAB cell/array field into a 1-D float array, one value per element.

    Array elements are reduced to their NaN-ignoring maximum (NaN if empty);
    scalar elements are converted with ``float`` (NaN if that fails).
    """
    flat = np.array(x, dtype=object).squeeze().flat
    out = []
    for e in flat:
        if isinstance(e, np.ndarray):
            out.append(float(np.nanmax(e.astype(float))) if e.size else np.nan)
        else:
            try:
                out.append(float(e))
            except (TypeError, ValueError):
                out.append(np.nan)
    return np.array(out, dtype=float)


def load_capacity_info(cap_dir: Path) -> pd.DataFrame:
    """
    Robust loader for capacity-check ``.mat`` files found recursively in ``cap_dir``.

    - If a plain 2-D numeric matrix exists, use it (original behavior): the
      column with the largest mean absolute value over the last 50 rows is
      taken, and its NaN-ignoring maximum is the capacity.
    - Else, try MATLAB struct ``data`` with cumulative 'AhAccu' or 'WhAccu'
      arrays ('WhAccu' is divided by 3.7 to approximate Ah).

    Files whose name does not match ``CAP_META_PATTERN`` or that yield no finite
    capacity are ignored. Files that fail to load are skipped and, when
    ``cfg.VERBOSE`` is set, reported with the same ``[Skip]`` message format used
    for training spectra, instead of being swallowed silently.

    Returns a DataFrame with the parsed metadata plus ``MeasuredCapacity_Ah``,
    ``NormCapacity`` (capacity divided by the per-cell maximum) and
    ``SoH_percent``. The DataFrame is empty when the directory is missing,
    refinement is disabled, or nothing could be loaded.
    """
    if not (cap_dir.exists() and cfg.REFINE_SOH_WITH_CAPACITY):
        return pd.DataFrame()
    recs=[]
    for fp in cap_dir.rglob("*.mat"):
        meta = parse_cap_metadata(fp.stem)
        if not meta:
            continue
        try:
            mat = loadmat(fp, squeeze_me=True, struct_as_record=False)

            # Old path: plain 2-D numeric matrix
            arr = _find_matrix(mat)
            cap = None
            if arr is not None:
                col = np.argmax(np.abs(arr[-50:, :]).mean(axis=0))
                cap = float(np.nanmax(arr[:, col]))
            else:
                # New path: MATLAB struct with cumulative capacity
                d = mat.get("data", None)
                if d is not None:
                    if hasattr(d, "AhAccu"):
                        v = _capacity_values_to_1d(getattr(d, "AhAccu"))
                        if v.size: cap = float(np.nanmax(v))
                    if cap is None and hasattr(d, "WhAccu"):
                        v = _capacity_values_to_1d(getattr(d, "WhAccu"))
                        if v.size: cap = float(np.nanmax(v) / 3.7)  # approx Ah from Wh

            if cap is None or not np.isfinite(cap):
                continue

            meta["MeasuredCapacity_Ah"] = cap
            recs.append(meta)
        except Exception as e:
            # Best-effort loader: one bad file must not abort the scan, but the
            # failure should be visible.
            if cfg.VERBOSE: print(f"[Skip] {fp.name}: {e}")

    df = pd.DataFrame(recs)
    if df.empty:
        return df
    # Normalise each cell by its own largest measured capacity.
    ref = df.groupby("CellID")["MeasuredCapacity_Ah"].transform("max")
    df["NormCapacity"] = df["MeasuredCapacity_Ah"] / ref
    df["SoH_percent"] = df["NormCapacity"] * 100.0
    return df

def estimate_cpp_per_cell(capacity_df: pd.DataFrame,
                          window:int, min_points:int)->Dict[str,float]:
    """Estimate cycles-per-percent-SoH for each cell from its latest capacity checks.

    For every ``CellID`` with at least ``min_points`` rows, a line is fitted to
    ``SoH_percent`` versus ``CycleIndex`` over the last ``window`` rows (sorted
    by cycle). Cells with fewer than two distinct cycle values in that window,
    or whose slope is not below -1e-6, are omitted. The stored value is
    ``1 / |slope|``.
    """
    cpp = {}
    for cell_id, history in capacity_df.groupby("CellID"):
        ordered = history.sort_values("CycleIndex")
        if ordered.shape[0] < min_points:
            continue
        recent = ordered.tail(window)
        cycles = recent["CycleIndex"].values.astype(float)
        soh = recent["SoH_percent"].values.astype(float)
        if len(np.unique(cycles)) < 2:
            continue
        slope = np.polyfit(cycles, soh, 1)[0]
        if slope >= -1e-6:
            # Flat or improving trend: no usable degradation rate.
            continue
        cpp[cell_id] = 1.0/abs(slope)
    return cpp

def build_cpp_map(cap_df: pd.DataFrame):
    """Return ``(per_cell_cpp, global_cpp)`` derived from capacity data.

    ``global_cpp`` is the median of the per-cell values. When ``cap_df`` is
    empty or no cell yields an estimate, ``({}, cfg.CPP_FALLBACK)`` is returned.
    """
    if not cap_df.empty:
        per_cell = estimate_cpp_per_cell(
            cap_df[["CellID","CycleIndex","SoH_percent"]],
            cfg.CPP_ROLLING_WINDOW, cfg.CPP_MIN_POINTS
        )
        if per_cell:
            return per_cell, float(np.median(list(per_cell.values())))
    return {}, cfg.CPP_FALLBACK

def get_cpp(meta: dict, cpp_map: Dict[str,float], global_cpp: float):
    """Look up the cell-specific CPP for ``meta["CellID"]``.

    Falls back to ``global_cpp`` when ``meta`` is empty/None or the cell is
    unknown.
    """
    if meta:
        return cpp_map.get(meta.get("CellID"), global_cpp)
    return global_cpp

# =========================
# 8. DATASET BUILD
# =========================
def load_single_eis_mat(fp: Path):
    """Load one training spectrum and build its feature vector.

    The filename must match the EIS metadata pattern, otherwise ``ValueError``
    is raised before the file is read. Both channels are interpolated onto
    ``CANON_FREQ``.

    Returns ``(feature_vector, metadata, re_interp, im_interp)``.
    """
    meta = parse_eis_metadata(fp.stem)
    if meta is None:
        raise ValueError(f"Bad filename: {fp.name}")
    freq, re_z, im_z = load_mat_eis(fp)
    re_i, im_i = (_interp_channel(freq, channel, CANON_FREQ) for channel in (re_z, im_z))
    features = build_feature_vector(re_i, im_i, meta["Temp"], CANON_FREQ)
    return features, meta, re_i, im_i

def build_dataset(eis_dir: Path, cap_df: Optional[pd.DataFrame]):
    """Build the training feature matrices and targets from ``.mat`` spectra.

    Parameters
    ----------
    eis_dir : Path
        Directory searched recursively for ``*.mat`` spectra.
    cap_df : DataFrame or None
        Output of ``load_capacity_info``. When non-empty and
        ``cfg.REFINE_SOH_WITH_CAPACITY`` is set, SoH labels are replaced by
        ``100 * NormCapacity`` for matching (CellID, SOH_stage) pairs.

    Returns
    -------
    (meta_df, X, (X_shape, feature_names, train_hf_median, train_f7_median), y_soc, y_soh)

    Raises
    ------
    FileNotFoundError
        If no ``.mat`` file is found.
    RuntimeError
        If no spectrum could be loaded.

    Notes
    -----
    Two robustness fixes over the previous version:

    * A sample is appended to the raw, metadata and shape lists only after ALL
      of its vectors were computed. Previously a failure while building the
      shape vector left the raw/metadata lists one entry longer than the shape
      list, so ``X_shape`` rows no longer lined up with ``X``.
    * Feature names are taken from the first spectrum that loads successfully
      rather than from ``files[0]`` outside any ``try``, so a single unreadable
      first file no longer aborts the whole build.
    """
    files = sorted(eis_dir.rglob("*.mat"))
    if not files:
        raise FileNotFoundError(f"No .mat spectra in {eis_dir}")

    want_shape = ((cfg.INCLUDE_NORMALIZED_SHAPE_MODEL or cfg.SOC_INCLUDE_SHAPE_MODEL)
                  and cfg.NORMALIZE_SHAPE_BY_HF_RE)

    feature_names = None
    feats=[]; rows=[]; shape_feats=[]
    for fp in tqdm(files, desc="Loading training spectra"):
        try:
            v, m, rei, imi = load_single_eis_mat(fp)
            shape_vec = None
            if want_shape:
                rsh, ish = build_shape_normalized(rei, imi)
                shape_vec = build_feature_vector(rsh, ish, m["Temp"], CANON_FREQ)
            if feature_names is None:
                # Names depend only on the vector layout, not on the values.
                _, feature_names = build_feature_vector(
                    rei, imi, m["Temp"], CANON_FREQ, include_names=True)
        except Exception as e:
            if cfg.VERBOSE: print(f"[Skip] {fp.name}: {e}")
            continue
        # Commit the sample atomically so all per-sample lists stay aligned.
        feats.append(v); rows.append(m)
        if shape_vec is not None:
            shape_feats.append(shape_vec)

    if not rows:
        raise RuntimeError("No valid training spectra after filtering.")

    X = np.vstack(feats)
    X_shape = np.vstack(shape_feats) if shape_feats else None
    meta_df = pd.DataFrame(rows)

    # SoH refinement
    if cap_df is not None and not cap_df.empty and cfg.REFINE_SOH_WITH_CAPACITY:
        lookup = cap_df.set_index(["CellID","SOH_stage"])["NormCapacity"].to_dict()
        refined=[]
        for cid, stage, fallback in zip(meta_df.CellID, meta_df.SOH_stage, meta_df.RealSOH_file):
            nc = lookup.get((cid, stage))
            refined.append(100.0*nc if nc is not None else fallback)
        meta_df["SoH_cont"]=refined
    else:
        meta_df["SoH_cont"]=meta_df["RealSOH_file"]

    # Targets
    y_soc = meta_df["SOC"].astype(float).values
    y_soh = meta_df["SoH_cont"].values

    soh_var = float(np.var(y_soh))
    if cfg.VERBOSE:
        print(f"[DATA] SoH range: {y_soh.min():.2f} – {y_soh.max():.2f} (var={soh_var:.3f})")
        if soh_var < 1.0:
            print("[WARN] Low SoH variance → model may output near-constant SoH.")

    if cfg.SAVE_FEATURE_TABLE:
        pd.concat(
            [meta_df.reset_index(drop=True),
             pd.DataFrame(X, columns=feature_names)], axis=1
        ).to_parquet(cfg.MODEL_DIR/"training_features.parquet", index=False)

    # Save anchors for autoscale (None when the feature group is disabled)
    idx_hf = feature_names.index("hf_re") if "hf_re" in feature_names else None
    idx_f7 = feature_names.index("F7") if "F7" in feature_names else None

    train_hf_median = float(np.median(X[:, idx_hf])) if idx_hf is not None else None
    train_f7_median = float(np.median(X[:, idx_f7])) if idx_f7 is not None else None

    return meta_df, X, (X_shape, feature_names, train_hf_median, train_f7_median), y_soc, y_soh

# =========================
# 9. SPLITTING
# =========================
def cell_split_mask(meta_df: pd.DataFrame):
    """Return a boolean Series marking rows that belong to held-out cells.

    A fraction ``cfg.TEST_FRAC`` of the unique ``CellID`` values (at least one)
    is drawn without replacement using a generator seeded with
    ``cfg.RANDOM_STATE``. All rows of a cell end up on the same side of the
    split.
    """
    unique_cells = meta_df.CellID.unique()
    rng = np.random.default_rng(cfg.RANDOM_STATE)
    n_holdout = max(1, int(len(unique_cells)*cfg.TEST_FRAC))
    holdout = rng.choice(unique_cells, size=n_holdout, replace=False)
    return meta_df.CellID.isin(holdout)

# =========================
# 10. TRAINING
# =========================
def _fit_hgb(X, y, seed):
    """Fit and return a HistGradientBoostingRegressor with this file's fixed hyperparameters."""
    # Solid defaults; robust on small/medium tabular sets
    hyperparams = dict(
        learning_rate=0.06,
        max_iter=650,
        max_depth=None,
        min_samples_leaf=12,
        l2_regularization=1e-3,
        random_state=seed,
    )
    model = HistGradientBoostingRegressor(**hyperparams)
    return model.fit(X, y)

def train_models(meta_df, X_raw, shape_bundle, y_soc, y_soh):
    X_shape, feature_names, train_hf_median, train_f7_median = shape_bundle
    mask_test = cell_split_mask(meta_df)

    # ----- SoC (NO GPR) -----
    soc_scaler = StandardScaler()
    X_soc_s = soc_scaler.fit_transform(X_raw)
    soc_pca = None
    X_soc_model = X_soc_s
    if cfg.USE_PCA_SOC:
        soc_pca = PCA(n_components=min(cfg.PCA_SOC_COMPONENTS, X_soc_s.shape[1]-1),
                      random_state=cfg.RANDOM_STATE)
        X_soc_model = soc_pca.fit_transform(X_soc_s)

    # Optional: slight jitter to de-discretize SoC labels
    y_soc_train_all = y_soc.copy()
    if cfg.SOC_LABEL_JITTER and cfg.SOC_LABEL_JITTER > 0:
        rng = np.random.default_rng(cfg.RANDOM_STATE)
        y_soc_train_all = np.clip(y_soc_train_all + rng.normal(0.0, cfg.SOC_LABEL_JITTER, size=y_soc_train_all.shape), 0.0, 100.0)

    soc_candidates = {}

    # HGB on raw embedding
    soc_hgb_raw = _fit_hgb(X_soc_model[~mask_test], y_soc_train_all[~mask_test], cfg.RANDOM_STATE)
    pred_soc_hgb = soc_hgb_raw.predict(X_soc_model[mask_test])
    r2_soc_hgb = r2_score(y_soc[mask_test], pred_soc_hgb)
    rmse_soc_hgb = math.sqrt(mean_squared_error(y_soc[mask_test], pred_soc_hgb))
    soc_candidates["soc_hgb_raw"] = (soc_hgb_raw, r2_soc_hgb, rmse_soc_hgb)

    # HGB on shape-normalized embedding (if available)
    soc_shape_model = soc_shape_scaler = soc_shape_pca = None
    soc_shape_metrics = None
    if cfg.SOC_INCLUDE_SHAPE_MODEL and (X_shape is not None):
        soc_shape_scaler = StandardScaler()
        Xs = soc_shape_scaler.fit_transform(X_shape)
        if cfg.USE_PCA_SOC:
            soc_shape_pca = PCA(n_components=min(cfg.PCA_SOC_COMPONENTS, Xs.shape[1]-1),
                                random_state=cfg.RANDOM_STATE)
            Xs_model = soc_shape_pca.fit_transform(Xs)
        else:
            Xs_model = Xs
        soc_hgb_shape = _fit_hgb(Xs_model[~mask_test], y_soc_train_all[~mask_test], cfg.RANDOM_STATE)
        sp = soc_hgb_shape.predict(Xs_model[mask_test])
        r2_soc_shape = r2_score(y_soc[mask_test], sp)
        rmse_soc_shape = math.sqrt(mean_squared_error(y_soc[mask_test], sp))
        soc_candidates["soc_hgb_shape"] = (soc_hgb_shape, r2_soc_shape, rmse_soc_shape)
        soc_shape_model = soc_hgb_shape
        soc_shape_metrics = {"r2": r2_soc_shape, "rmse": rmse_soc_shape}
    else:
        Xs_model = None

    # Choose best SoC model by R^2
    soc_best_name = max(soc_candidates.keys(), key=lambda k: soc_candidates[k][1])
    soc_best_model, soc_best_r2, soc_best_rmse = soc_candidates[soc_best_name]

    if cfg.VERBOSE:
        print(f"[SoC] HGB_raw:   R2={r2_soc_hgb:.3f} RMSE={rmse_soc_hgb:.2f}")
        if soc_shape_metrics:
            print(f"[SoC] HGB_shape: R2={soc_shape_metrics['r2']:.3f} RMSE={soc_shape_metrics['rmse']:.2f}")
        print(f"[SoC] Selected = {soc_best_name}")

    # ----- SoH (unchanged core: GPR tends to be best here) -----
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import RBF, WhiteKernel

    soh_scaler = StandardScaler()
    X_soh_s = soh_scaler.fit_transform(X_raw)
    soh_pca=None
    X_soh_model = X_soh_s
    if cfg.USE_PCA_SOH:
        soh_pca = PCA(n_components=min(cfg.PCA_SOH_COMPONENTS, X_soh_s.shape[1]-1),
                      random_state=cfg.RANDOM_STATE)
        X_soh_model = soh_pca.fit_transform(X_soh_s)

    soh_candidates = {}
    dim = X_soh_model.shape[1]
    kernel = RBF(length_scale=np.ones(dim)*3.0,
                 length_scale_bounds=(1e-1,1e6)) + \
             WhiteKernel(noise_level=1e-2,
                         noise_level_bounds=(1e-6,1e-1))
    gpr = GaussianProcessRegressor(
        kernel=kernel, alpha=0.0, normalize_y=True,
        random_state=cfg.RANDOM_STATE, n_restarts_optimizer=3
    )
    if X_soh_model.shape[0] > cfg.MAX_GPR_TRAIN_SAMPLES:
        idx = np.random.default_rng(cfg.RANDOM_STATE).choice(
            X_soh_model.shape[0], size=cfg.MAX_GPR_TRAIN_SAMPLES, replace=False)
        gpr.fit(X_soh_model[idx], y_soh[idx])
    else:
        gpr.fit(X_soh_model, y_soh)
    pred_gpr = gpr.predict(X_soh_model[mask_test])
    r2_gpr = r2_score(y_soh[mask_test], pred_gpr)
    rmse_gpr = math.sqrt(mean_squared_error(y_soh[mask_test], pred_gpr))
    soh_candidates["gpr_raw"] = (gpr, r2_gpr, rmse_gpr)

    hgb = HistGradientBoostingRegressor(
        learning_rate=0.05, max_iter=500,
        l2_regularization=1e-3, random_state=cfg.RANDOM_STATE
    )
    hgb.fit(X_soh_model[~mask_test], y_soh[~mask_test])
    pred_hgb = hgb.predict(X_soh_model[mask_test])
    r2_hgb = r2_score(y_soh[mask_test], pred_hgb)
    rmse_hgb = math.sqrt(mean_squared_error(y_soh[mask_test], pred_hgb))
    soh_candidates["hgb_raw"] = (hgb, r2_hgb, rmse_hgb)

    shape_model=shape_scaler=shape_pca=None
    shape_metrics=None
    if cfg.INCLUDE_NORMALIZED_SHAPE_MODEL and (X_shape is not None):
        shape_scaler = StandardScaler()
        X_shape_s = shape_scaler.fit_transform(X_shape)
        X_shape_model = X_shape_s
        if cfg.USE_PCA_SOH:
            shape_pca = PCA(n_components=min(cfg.PCA_SOH_COMPONENTS, X_shape_s.shape[1]-1),
                            random_state=cfg.RANDOM_STATE)
            X_shape_model = shape_pca.fit_transform(X_shape_s)

        # Build a brand-new GPR matched to the shape-space dimension (fixes kernel__k1 param leak)
        dim_shape = X_shape_model.shape[1]
        kernel_shape = (
            RBF(length_scale=np.ones(dim_shape) * 3.0,
                length_scale_bounds=(1e-1, 1e6))
            + WhiteKernel(noise_level=1e-2,
                          noise_level_bounds=(1e-6, 1e-1))
        )
        shape_model = GaussianProcessRegressor(
            kernel=kernel_shape,
            alpha=0.0,
            normalize_y=True,
            random_state=cfg.RANDOM_STATE,
            n_restarts_optimizer=3
        )

        # Fit new GPR for shape space
        if X_shape_model.shape[0] > cfg.MAX_GPR_TRAIN_SAMPLES:
            idx = np.random.default_rng(cfg.RANDOM_STATE).choice(
                X_shape_model.shape[0], size=cfg.MAX_GPR_TRAIN_SAMPLES, replace=False)
            shape_model.fit(X_shape_model[idx], y_soh[idx])
        else:
            shape_model.fit(X_shape_model, y_soh)

        spred = shape_model.predict(X_shape_model[mask_test])
        r2_shape = r2_score(y_soh[mask_test], spred)
        rmse_shape = math.sqrt(mean_squared_error(y_soh[mask_test], spred))
        soh_candidates["gpr_shape"] = (shape_model, r2_shape, rmse_shape)
        shape_metrics = {"r2": r2_shape, "rmse": rmse_shape}

    best_name = max(["gpr_raw","hgb_raw"], key=lambda k: soh_candidates[k][1])
    best_model, best_r2, best_rmse = soh_candidates[best_name]

    if cfg.VERBOSE:
        print(f"[SoH] GPR_raw:  R2={r2_gpr:.3f} RMSE={rmse_gpr:.2f}")
        print(f"[SoH] HGB_raw: R2={r2_hgb:.3f} RMSE={rmse_hgb:.2f}")
        if shape_metrics:
            print(f"[SoH] ShapeGP: R2={shape_metrics['r2']:.3f} RMSE={shape_metrics['rmse']:.2f}")
        print(f"[SoH] Selected raw model = {best_name}")

    # --- Robust covariance (Ledoit–Wolf) for OOD ---
    def _lw_cov_inv(X):
        try:
            lw = LedoitWolf().fit(X)
            cov = lw.covariance_
            return np.linalg.pinv(cov), X.mean(axis=0)
        except Exception:
            return np.eye(X.shape[1]), X.mean(axis=0)

    # SoH-space OOD (kept for SoH uncertainty capping)
    cov_inv_soh, center_soh = _lw_cov_inv(X_soh_s)

    # SoC-space OOD stats (raw embed)
    soc_cov_inv_raw, soc_center_raw = _lw_cov_inv(X_soc_model[~mask_test])
    d_raw = []
    X_tr_raw = X_soc_model[~mask_test]
    for i in range(X_tr_raw.shape[0]):
        diff = X_tr_raw[i] - soc_center_raw
        d_raw.append(float(np.sqrt(diff @ soc_cov_inv_raw @ diff.T)))
    thr_raw = float(np.quantile(d_raw, cfg.OOD_SOC_Q)) if len(d_raw) else None

    # SoC-shape-space OOD stats (use the model's own embedding)
    soc_cov_inv_shape = None; soc_center_shape = None; thr_shape = None
    X_tr_shape = None
    if cfg.SOC_INCLUDE_SHAPE_MODEL and (soc_shape_model is not None):
        X_tr_shape = Xs_model[~mask_test]  # same space as the trained shape model
        if X_tr_shape.size:
            soc_cov_inv_shape, soc_center_shape = _lw_cov_inv(X_tr_shape)
            d_shape=[]
            for i in range(X_tr_shape.shape[0]):
                diff = X_tr_shape[i] - soc_center_shape
                d_shape.append(float(np.sqrt(diff @ soc_cov_inv_shape @ diff.T)))
            thr_shape = float(np.quantile(d_shape, cfg.OOD_SOC_Q)) if len(d_shape) else None

    # SoC calibrator (fit on holdout predictions of the selected model)
    if soc_best_name == "soc_hgb_raw":
        val_pred = soc_hgb_raw.predict(X_soc_model[mask_test])
    else:
        val_pred = soc_candidates["soc_hgb_shape"][0].predict(Xs_model[mask_test])

    soc_calibrator = IsotonicRegression(y_min=0.0, y_max=100.0, out_of_bounds="clip")
    soc_calibrator.fit(val_pred, y_soc[mask_test])

    bundle = {
        # SoC
        "soc_scaler": soc_scaler,
        "soc_pca": soc_pca,
        "soc_model": soc_best_model,
        "soc_model_name": soc_best_name,

        # shape-based SoC path
        "soc_shape_scaler": None if soc_shape_model is None else soc_shape_scaler,
        "soc_shape_pca": None if soc_shape_model is None else soc_shape_pca,
        "soc_shape_model": soc_shape_model,

        "soc_calibrator": soc_calibrator,

        # Embeddings for KNN prior (both)
        "soc_train_embed_raw": X_tr_raw,
        "soc_train_y_raw": y_soc[~mask_test],
        "soc_train_embed_shape": X_tr_shape,
        "soc_train_y_shape": y_soc[~mask_test] if X_tr_shape is not None else None,

        # SoH
        "soh_scaler": soh_scaler,
        "soh_pca": soh_pca,
        "soh_model": best_model,
        "soh_model_name": best_name,

        # Optional SoH shape model
        "shape_scaler": shape_scaler,
        "shape_pca": shape_pca,
        "shape_model": shape_model,

        # Meta
        "freq_grid": CANON_FREQ,
        "feature_version": cfg.FEATURE_VERSION,
        "feature_manifest": feature_names,
        "config_signature": config_signature(cfg),
        "config": to_jsonable(asdict(cfg)),
        "metrics": {
            "soc_r2_selected": soc_best_r2,
            "soc_rmse_selected": soc_best_rmse,
            "soh_r2_selected": best_r2,
            "soh_rmse_selected": best_rmse
        },
        "soc_candidates_metrics": {
            "soc_hgb_raw": {"r2": r2_soc_hgb, "rmse": rmse_soc_hgb},
            "soc_hgb_shape": soc_shape_metrics
        },
        "soh_candidates_metrics": {
            "gpr_raw": {"r2": r2_gpr, "rmse": rmse_gpr},
            "hgb_raw": {"r2": r2_hgb, "rmse": rmse_hgb},
            "gpr_shape": shape_metrics
        },
        # SoH-space OOD stats
        "train_mahal": {"center": center_soh.tolist(), "cov_inv": cov_inv_soh.tolist()},
        # SoC OOD stats (dual)
        "soc_train_mahal_raw": {
            "center": soc_center_raw.tolist(),
            "cov_inv": soc_cov_inv_raw.tolist(),
            "threshold": thr_raw
        },
        "soc_train_mahal_shape": None if soc_center_shape is None else {
            "center": soc_center_shape.tolist(),
            "cov_inv": soc_cov_inv_shape.tolist(),
            "threshold": thr_shape
        },
        # autoscale anchors
        "train_hf_re_median": train_hf_median,
        "train_f7_median": train_f7_median,
    }
    out_path = cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib"
    joblib.dump(bundle, out_path)
    if cfg.VERBOSE:
        print(f"[MODEL] Saved bundle → {out_path}")
        print(json.dumps(bundle["metrics"], indent=2))
    return bundle

# =========================
# 11. LOAD (WITH LEGACY UPGRADE; SIGNATURE CHECK IS DONE BY CALLERS)
# =========================
def load_bundle():
    """
    Load the persisted model bundle from ``cfg.MODEL_DIR`` and upgrade legacy layouts.

    A legacy bundle (it has ``"scaler"`` but no ``"soc_scaler"``) shares a single
    scaler / PCA between SoC and SoH. It is rewritten in place to the current key
    layout and given fallback model names, metrics, Mahalanobis statistics and a
    ``feature_version`` of -1.

    Returns:
        dict: the bundle, with every required key present and not None.

    Raises:
        KeyError: if a required key is missing (or None) after the legacy upgrade.
    """
    path = cfg.MODEL_DIR / "eis_soc_soh_phys_models.joblib"
    # NOTE(security): joblib.load unpickles arbitrary objects; only point this at
    # bundles written by this pipeline.
    bundle = joblib.load(path)

    legacy = ("scaler" in bundle) and ("soc_scaler" not in bundle)
    if legacy:
        scaler = bundle["scaler"]
        pca = bundle.get("pca")
        bundle["soc_scaler"] = scaler
        bundle["soh_scaler"] = scaler
        bundle["soc_pca"] = pca
        bundle["soh_pca"] = pca
        # Models are deliberately NOT back-filled with None: a legacy bundle without
        # a model must fail the required-key check below instead of crashing later
        # inside inference.
        bundle.setdefault("soh_model_name", "legacy_model")
        bundle.setdefault("soc_model_name", "legacy_soc_model")
        bundle.setdefault("metrics", {"soh_rmse_selected": 5.0, "soc_rmse_selected": 8.0})
        if "train_mahal" not in bundle:
            try:
                # Identity inverse covariance: distance degrades to Euclidean around the scaler mean.
                center = scaler.mean_
                cov_inv = np.eye(len(center))
                bundle["train_mahal"] = {"center": center.tolist(), "cov_inv": cov_inv.tolist()}
            except (AttributeError, TypeError, ValueError):
                bundle["train_mahal"] = None
        bundle.setdefault("feature_version", -1)

    for key in ["soc_scaler", "soc_model", "soh_scaler", "soh_model", "freq_grid"]:
        if bundle.get(key) is None:
            raise KeyError(f"Bundle missing required key: {key}")
    return bundle

# =========================
# 12. INFERENCE FEATURIZATION
# =========================
def featurize_any(file_path: Path, bundle):
    """
    Turn one EIS file into the feature vectors used at inference time.

    The spectrum is loaded, interpolated onto ``bundle["freq_grid"]``, optionally
    rescaled so its high-frequency / ~10 Hz real parts line up with the training
    medians stored in the bundle, and then converted to feature vectors.

    Args:
        file_path: path of the EIS file; its stem is parsed for metadata.
        bundle: model bundle providing ``freq_grid`` and the autoscale anchors.

    Returns:
        tuple: ``(vec, norm_vec, meta, checksum, extras)`` where ``vec`` is the raw
        feature vector, ``norm_vec`` the shape-normalised vector (or None),
        ``meta`` the parsed filename metadata (or None), ``checksum`` the SHA-1 hex
        digest of ``vec`` and ``extras`` a dict of loader / autoscale diagnostics.
    """
    freq_grid = bundle["freq_grid"]
    meta = parse_eis_metadata(file_path.stem)
    freq, re_raw, im_raw, used_freq = load_any_inference(file_path)
    if not used_freq:
        warnings.warn(f"[{file_path.name}] No frequency column found. Using geometric grid fallback.")
    re_i = _interp_channel(freq, re_raw, freq_grid)
    im_i = _interp_channel(freq, im_raw, freq_grid)

    # ---- AUTO UNIT ALIGNMENT via hf_re and F7 (10 Hz) ----
    autoscale_factor = 1.0
    hf_test = float(np.nanmedian(re_i[:max(1, min(5, len(re_i)))]))
    # F7 ~ 10 Hz: grid point closest to 10 Hz
    idx_mid = int(np.argmin(np.abs(freq_grid - 10.0)))
    f7_test = float(re_i[idx_mid])

    hf_train = bundle.get("train_hf_re_median", None)
    f7_train = bundle.get("train_f7_median", None)

    try:
        scale_candidates = []
        for test_val, train_val in ((hf_test, hf_train), (f7_test, f7_train)):
            anchor_ok = (np.isfinite(test_val) and abs(test_val) > 1e-12
                         and train_val is not None and np.isfinite(train_val))
            if not anchor_ok:
                continue
            ratio = float(train_val) / test_val
            # A unit correction is a positive gain. A zero or negative ratio (sign
            # mismatch, zero training anchor) would wipe out or flip the spectrum,
            # so such an anchor is ignored.
            if np.isfinite(ratio) and ratio > 0.0:
                scale_candidates.append(ratio)

        scale = None
        if len(scale_candidates) == 2:
            # If both anchors agree within 2×, use geometric mean; else prefer the one further from 1
            s1, s2 = scale_candidates
            if max(s1, s2) / max(1e-12, min(s1, s2)) <= 2.0:
                scale = float(np.sqrt(s1 * s2))
            else:
                # pick the one implying stronger correction
                scale = s1 if abs(s1 - 1) > abs(s2 - 1) else s2
        elif len(scale_candidates) == 1:
            scale = scale_candidates[0]

        # Apply if clearly off (outside [0.5, 2.0])
        if scale is not None and (scale < 0.5 or scale > 2.0):
            re_i = re_i * scale
            im_i = im_i * scale
            autoscale_factor = float(scale)
            if cfg.VERBOSE:
                print(f"[AUTO-SCALE] {file_path.name}: hf_test={hf_test:.4g}, f7_test={f7_test:.4g} "
                      f"→ scaled by {autoscale_factor:.4g} to align.")
    except Exception as exc:
        # Autoscale is best-effort: continue unscaled, but leave a trace instead of hiding the failure.
        warnings.warn(f"[{file_path.name}] Auto-scale skipped: {exc}")

    # The UI / config override only applies when the filename carried no metadata.
    if meta is None and cfg.TEST_TEMPERATURE_OVERRIDE is not None:
        temp = cfg.TEST_TEMPERATURE_OVERRIDE
    else:
        temp = meta["Temp"] if meta else -1
    vec = build_feature_vector(re_i, im_i, temp, freq_grid)

    norm_vec = None
    if (cfg.INCLUDE_NORMALIZED_SHAPE_MODEL or cfg.SOC_INCLUDE_SHAPE_MODEL) and \
       (bundle.get("shape_model") is not None or bundle.get("soc_shape_model") is not None):
        if cfg.NORMALIZE_SHAPE_BY_HF_RE:
            rsh, ish = build_shape_normalized(re_i, im_i)
            norm_vec = build_feature_vector(rsh, ish, temp, freq_grid)
    checksum = hashlib.sha1(np.ascontiguousarray(vec).tobytes()).hexdigest()

    extras = {
        "used_freq_from_file": bool(used_freq),
        "hf_train": None if hf_train is None else float(hf_train),
        "hf_test_before_scale": float(hf_test) if np.isfinite(hf_test) else None,
        "f7_train": None if f7_train is None else float(f7_train),
        "f7_test_before_scale": float(f7_test) if np.isfinite(f7_test) else None,
        "autoscale_factor": float(autoscale_factor),
    }
    return vec, norm_vec, meta, checksum, extras

# =========================
# 13. OOD & KNN PRIOR UTILITIES
# =========================
def mahalanobis_distance(x, center, cov_inv):
    """Return the Mahalanobis distance of ``x`` from ``center`` given the inverse covariance ``cov_inv``."""
    offset = x - center
    squared = offset @ cov_inv @ offset.T
    return float(np.sqrt(squared))

def _knn_prior_soc(x_embed: np.ndarray, train_embed: np.ndarray, y_train: np.ndarray, k: int):
    """
    Weighted KNN prior in an embedding space.

    Neighbours are weighted by inverse squared Euclidean distance. At least three
    neighbours are requested, but never more than the number of training points.

    Args:
        x_embed: 1-D query point in the embedding space.
        train_embed: 2-D array of training embeddings (one row per sample), or None.
        y_train: training targets aligned with ``train_embed``, or None.
        k: requested neighbour count.

    Returns:
        (prior_mean, prior_std, neighbor_count); ``(None, None, 0)`` when no
        training data is available.
    """
    if train_embed is None or y_train is None or len(y_train) == 0:
        return None, None, 0
    targets = np.asarray(y_train, dtype=float)
    # Clamp to the data size LAST: the floor of 3 must not exceed the number of
    # points, otherwise argpartition raises on tiny training sets.
    k = min(max(3, int(k)), len(targets))
    d = np.linalg.norm(np.asarray(train_embed) - x_embed[None, :], axis=1)
    idx = np.argpartition(d, k - 1)[:k]
    di = d[idx]
    yi = targets[idx]
    w = 1.0 / (di**2 + 1e-6)  # epsilon keeps an exact match from dividing by zero
    w = w / (np.sum(w) + 1e-12)
    prior_mean = float(np.sum(w * yi))
    prior_var = float(np.sum(w * (yi - prior_mean)**2))
    prior_std = math.sqrt(max(prior_var, 1e-9))
    return prior_mean, prior_std, k

# =========================
# 14. PROJECTION PLOT
# =========================
def build_projection(soh_current, cpp, lower, exponent=None, n=160):
    """
    Build the projected SoH-versus-remaining-cycles curve.

    The curve starts at ``soh_current`` and reaches ``lower`` after
    ``(soh_current - lower) * cpp`` cycles, following a power law in the remaining
    cycle fraction. ``exponent`` defaults to ``cfg.PLOT_EXPONENT``.

    Returns:
        (cycles, soh_curve) arrays of length ``n``; a single point at zero cycles
        when the cell is already at or below ``lower`` or ``cpp`` is not positive.
    """
    nothing_to_project = (soh_current <= lower) or (cpp <= 0)
    if nothing_to_project:
        return np.array([0.0]), np.array([soh_current])

    span = soh_current - lower
    horizon = span * cpp
    cycle_axis = np.linspace(0, horizon, n)
    power = cfg.PLOT_EXPONENT if exponent is None else exponent
    remaining_fraction = 1 - cycle_axis / horizon
    projected = lower + span * remaining_fraction**power
    return cycle_axis, projected

def plot_projection(file_base, soh_current, soh_std, cycles_to_map, cpp, ood_flag, out_path, thresholds):
    """
    Draw the RUL projection for one file and save it to ``out_path``.

    Nothing is drawn or written when ``soh_current`` is already at or below the
    lowest threshold.

    Args:
        file_base: label used in the plot title.
        soh_current: current SoH estimate in percent.
        soh_std: SoH uncertainty shown next to the current point.
        cycles_to_map: mapping threshold -> remaining cycles; missing entries count as 0.
        cpp: cycles per percent of SoH used for the projected curve.
        ood_flag: when true, an "OOD" badge is drawn on the axes.
        out_path: destination image path.
        thresholds: SoH thresholds to mark; falls back to ``(50.0, 40.0)`` when empty.
    """
    if not thresholds:
        thresholds = (50.0, 40.0)
    min_thr = min(thresholds)
    if soh_current <= min_thr:
        return
    cycles, curve = build_projection(soh_current, cpp, min_thr)

    # Work on an explicit figure so it can always be closed, even if drawing or
    # saving fails; otherwise figures pile up in a long-running UI process.
    fig, ax = plt.subplots(figsize=(6.4, 4))
    try:
        ax.plot(cycles, curve, lw=2, label="Projected SoH")

        for thr in thresholds:
            upper = thr >= 50  # thresholds at/above 50 % are drawn in the "warning" style
            color = "orange" if upper else "red"
            ax.axhline(thr, color=color, ls="--" if upper else ":", label=f"{int(thr)}%")
            x = cycles_to_map.get(thr, 0.0)
            if x > 0:
                ax.axvline(x, color=color, ls="-." if upper else ":")
                ax.scatter([x], [thr], s=45)
                txty = thr + (1.0 if upper else -2.0)
                ax.text(x, txty, f"{x:.0f} cyc", ha="center", fontsize=8, color=color)

        ax.scatter([0], [soh_current], c="green", s=55, label=f"Current {soh_current:.2f}%")
        ax.text(0, soh_current + 0.7, f"±{soh_std:.2f}", color="green", fontsize=8)

        if ood_flag:
            ax.text(0.98, 0.05, "OOD", transform=ax.transAxes,
                    ha="right", va="bottom", color="crimson", fontsize=11,
                    bbox=dict(boxstyle="round", fc="w", ec="crimson"))

        ax.set_xlabel("Remaining Cycles")
        ax.set_ylabel("SoH (%)")
        ax.set_title(f"RUL Projection – {file_base}")
        ax.grid(alpha=0.35)
        ax.legend(fontsize=8)
        fig.tight_layout()
        fig.savefig(out_path, dpi=140)
    finally:
        plt.close(fig)

# =========================
# 15. INFERENCE (SINGLE FILE)
# =========================
def predict_file(file_path: Path, bundle, cpp_map, global_cpp):
    """
    Run SoC / SoH inference and the multi-threshold RUL estimate for one EIS file.

    Steps: featurize the file, predict SoC (selected model, averaged with the SoC
    shape model when available), predict SoH (optionally ensembled with the SoH
    shape model), compute Mahalanobis OOD diagnostics in SoC and SoH space,
    calibrate / blend / cap the SoC estimate, and convert the SoH margin above each
    threshold in ``cfg.TARGET_SOH_THRESHOLDS`` into remaining cycles.

    Args:
        file_path: path of the EIS measurement to score.
        bundle: model bundle (see ``load_bundle``).
        cpp_map: cycles-per-percent lookup forwarded to ``get_cpp``.
        global_cpp: fallback cycles-per-percent forwarded to ``get_cpp``.

    Returns:
        tuple: ``(result, ood_flag, cycles_to)`` — the report dict, the SoH-space
        OOD flag, and a mapping threshold -> remaining cycles.
    """
    vec, norm_vec, meta, checksum, extras = featurize_any(file_path, bundle)
    val_rmse_soc = float(bundle["metrics"].get("soc_rmse_selected", 8.0))
    soh_val_rmse = float(bundle["metrics"].get("soh_rmse_selected", 5.0))

    # ----- SoC -----
    soc_scaler = bundle["soc_scaler"]
    soc_pca = bundle.get("soc_pca")
    soc_model = bundle["soc_model"]
    soc_model_name = bundle.get("soc_model_name", "unknown")
    X_soc = soc_scaler.transform(vec.reshape(1, -1))
    X_soc_in_raw = soc_pca.transform(X_soc) if soc_pca is not None else X_soc

    # shape SoC path (if present)
    soc_shape_model = bundle.get("soc_shape_model")
    X_soc_in_shape = None
    if soc_shape_model is not None and norm_vec is not None:
        sscaler = bundle.get("soc_shape_scaler")
        spca = bundle.get("soc_shape_pca")
        Xs = sscaler.transform(norm_vec.reshape(1, -1))
        X_soc_in_shape = spca.transform(Xs) if spca is not None else Xs

    # --- SoC-space OOD check (dual spaces; pick the smaller exceedance) ---
    def _dist_block(X1, stats_key):
        """Return (distance, threshold, stats) for one embedding, or Nones when unavailable."""
        info = bundle.get(stats_key)
        if not info or X1 is None:
            return None, None, None
        thr = info.get("threshold", None)
        try:
            center = np.asarray(info["center"], dtype=float)
            inv = np.asarray(info["cov_inv"], dtype=float)
            diff = X1[0] - center
            d = float(np.sqrt(diff @ inv @ diff.T))
        except (KeyError, TypeError, ValueError, IndexError):
            d = None
        return d, float(thr) if thr is not None else None, info

    d_raw, thr_raw, _ = _dist_block(X_soc_in_raw, "soc_train_mahal_raw")
    d_shape, thr_shape, _ = _dist_block(X_soc_in_shape, "soc_train_mahal_shape")

    def _exceed(d, thr):
        """Distance above threshold; +inf when the space cannot be evaluated."""
        if d is None or thr is None or not np.isfinite(thr):
            return np.inf
        return d - thr

    choose_shape = (_exceed(d_shape, thr_shape) < _exceed(d_raw, thr_raw))
    # .get(): legacy bundles carry no training embeddings; the KNN prior copes with None.
    if choose_shape:
        d_used, thr_used, X_used = d_shape, thr_shape, X_soc_in_shape
        train_embed = bundle.get("soc_train_embed_shape")
        train_y = bundle.get("soc_train_y_shape")
    else:
        d_used, thr_used, X_used = d_raw, thr_raw, X_soc_in_raw
        train_embed = bundle.get("soc_train_embed_raw")
        train_y = bundle.get("soc_train_y_raw")

    soc_mahal = d_used
    soc_oob = bool(d_used is not None and thr_used is not None and d_used > thr_used)

    # Selected SoC model prediction (HGB). If the selected model is the shape-space
    # candidate it must be fed the shape embedding, not the raw one.
    selected_is_shape = (soc_model_name == "soc_hgb_shape") and (X_soc_in_shape is not None)
    soc_model_input = X_soc_in_shape if selected_is_shape else X_soc_in_raw
    raw_soc_mean = float(soc_model.predict(soc_model_input)[0])
    raw_soc_std = val_rmse_soc  # model has no native std

    # optional shape SoC prediction (HGB)
    soc_shape_mean = None
    soc_shape_std = None
    if soc_shape_model is not None and X_soc_in_shape is not None:
        soc_shape_mean = float(soc_shape_model.predict(X_soc_in_shape)[0])
        soc_shape_std = val_rmse_soc
        # stabilize pre-OOD by averaging raw + shape
        raw_soc_mean = 0.5 * (raw_soc_mean + soc_shape_mean)
        raw_soc_std = float(np.sqrt(0.5 * (raw_soc_std**2 + soc_shape_std**2)))

    # ----- SoH -----
    soh_scaler = bundle["soh_scaler"]
    soh_pca = bundle.get("soh_pca")
    soh_model = bundle["soh_model"]
    model_name = bundle.get("soh_model_name", "unknown")
    X_soh_s = soh_scaler.transform(vec.reshape(1, -1))
    X_soh_in = soh_pca.transform(X_soh_s) if soh_pca is not None else X_soh_s

    # Only a genuine GP exposes a predictive std; everything else falls back to validation RMSE.
    use_gp_std = False
    if "gpr" in model_name:
        from sklearn.gaussian_process import GaussianProcessRegressor
        use_gp_std = isinstance(soh_model, GaussianProcessRegressor)
    if use_gp_std:
        sm, ss = soh_model.predict(X_soh_in, return_std=True)
        soh_mean_raw = float(sm[0])
        soh_std_raw = float(ss[0])
    else:
        soh_mean_raw = float(soh_model.predict(X_soh_in)[0])
        soh_std_raw = soh_val_rmse

    shape_model = bundle.get("shape_model")
    shape_soh_mean = None
    shape_soh_std = None
    if shape_model is not None and norm_vec is not None:
        sscaler = bundle.get("shape_scaler")
        spca = bundle.get("shape_pca")
        X_shape_s = sscaler.transform(norm_vec.reshape(1, -1))
        X_shape_in = spca.transform(X_shape_s) if spca is not None else X_shape_s
        try:
            sm2, ss2 = shape_model.predict(X_shape_in, return_std=True)
            shape_soh_mean = float(sm2[0])
            shape_soh_std = float(ss2[0])
        except Exception:
            # Model without return_std support: plain prediction + validation RMSE.
            shape_soh_mean = float(shape_model.predict(X_shape_in)[0])
            shape_soh_std = soh_val_rmse

    if cfg.ENSEMBLE_SOH and shape_soh_mean is not None:
        soh_mean = 0.5 * (soh_mean_raw + shape_soh_mean)
        stds = [soh_std_raw]
        if shape_soh_std is not None:
            stds.append(shape_soh_std)
        soh_std = float(np.sqrt(np.mean(np.array(stds)**2)))
    else:
        soh_mean, soh_std = soh_mean_raw, soh_std_raw

    # ----- OOD diagnostics (SoH space ONLY for SoH uncertainty) -----
    train_mahal = bundle.get("train_mahal")
    mahal_dist = None
    if train_mahal:
        cov_inv = np.array(train_mahal["cov_inv"])
        center = np.array(train_mahal["center"])
        mahal_dist = mahalanobis_distance(X_soh_s[0], center, cov_inv)
    ood_flag = bool(mahal_dist is not None and mahal_dist > cfg.MAHAL_THRESHOLD)

    # ----- SoC post-hoc calibration -----
    soc_mean = raw_soc_mean
    soc_std = raw_soc_std
    soc_cal = bundle.get("soc_calibrator")
    if soc_cal is not None and (not soc_oob or cfg.SOC_CALIBRATE_ON_OOD):
        try:
            soc_mean = float(soc_cal.predict([soc_mean])[0])
        except Exception as exc:
            # Best-effort: keep the uncalibrated value, but do not hide the failure.
            warnings.warn(f"[{Path(file_path).name}] SoC calibration skipped: {exc}")
    soc_mean = float(np.clip(soc_mean, 0.0, 100.0))

    # ----- If SoC OOD: blend raw + (optional) shape + KNN prior; compute capped std -----
    if cfg.OOD_SOC_ENABLE and soc_oob:
        thr = thr_used if (thr_used is not None and np.isfinite(thr_used)) else np.inf
        delta = max(0.0, (soc_mahal or 0.0) - thr)
        s = max(1e-6, cfg.OOD_SOC_SHRINK_SCALE)

        # Gentle decay of raw weight; never below W_MIN
        w_raw = max(cfg.OOD_SOC_W_MIN, float(1.0 / (1.0 + (delta / s))))

        # Shape weight gets a portion of remaining if available
        has_shape_soc = (soc_shape_model is not None and soc_shape_mean is not None)
        w_shape = (1.0 - w_raw) * (0.60 if has_shape_soc else 0.0)

        # KNN prior from chosen embedding space
        prior_mean, prior_std, k_used = _knn_prior_soc(
            X_used[0], train_embed, train_y, k=cfg.SOC_OOD_K
        )
        if prior_mean is None:
            prior_mean = soc_mean
            prior_std = val_rmse_soc

        w_prior = max(0.0, 1.0 - w_raw - w_shape)

        # Combine means
        soc_mean = (w_raw * soc_mean) + \
                   (w_shape * (soc_shape_mean if has_shape_soc else soc_mean)) + \
                   (w_prior * prior_mean)
        soc_mean = float(np.clip(soc_mean, 0.0, 100.0))

        # Combine stds (RMS of weighted components) then cap
        comp_raw_std = soc_std
        comp_shape_std = soc_shape_std if has_shape_soc and soc_shape_std is not None else soc_std
        comp_prior_std = max(prior_std, val_rmse_soc)

        var = (w_raw * comp_raw_std)**2 + (w_shape * comp_shape_std)**2 + (w_prior * comp_prior_std)**2
        soc_std = float(math.sqrt(max(var, 1e-12)))

        # Cap std for OOD (aggressively reduce huge values)
        cap_ood = val_rmse_soc * cfg.SOC_STD_CAP_MULT_OOD
        soc_std = float(min(soc_std, cap_ood))

        if cfg.VERBOSE:
            which = "shape" if choose_shape else "raw"
            print(f"[SoC-OOD] space={which} d={soc_mahal:.2f} thr={thr:.2f} "
                  f"w_raw={w_raw:.3f} w_shape={w_shape:.3f} w_prior={w_prior:.3f} "
                  f"prior_mean={prior_mean:.2f} prior_std={prior_std:.2f} k={k_used}")
    else:
        # In-domain: cap std tightly to validation RMSE
        cap_in = val_rmse_soc * cfg.SOC_STD_CAP_MULT_IN
        soc_std = float(min(soc_std, cap_in))

    # ----- SoH uncertainty capping (SoH OOD only) -----
    if ood_flag:
        soh_std = min(soh_std, cfg.SOH_STD_MAX_OOD)
    else:
        soh_std = min(soh_std, soh_val_rmse)

    # ----- RUL (multi-threshold) -----
    cpp = get_cpp(meta, cpp_map, global_cpp)
    cycles_to = {
        thr_val: float((soh_mean - thr_val) * cpp) if soh_mean > thr_val else 0.0
        for thr_val in cfg.TARGET_SOH_THRESHOLDS
    }

    if cfg.VERBOSE:
        print(f"[SoC] {Path(file_path).name}: mean={soc_mean:.2f} std={soc_std:.2f}  "
              f"SOC_mahal_raw={d_raw} thr_raw={thr_raw}  SOC_mahal_shape={d_shape} thr_shape={thr_shape}  "
              f"used={'shape' if choose_shape else 'raw'} OOD={soc_oob}")
        print(f"[SoH] {Path(file_path).name}: mean={soh_mean:.2f} std={soh_std:.2f}  "
              f"OOD_flag(SoH)={ood_flag}  Mahalanobis={mahal_dist}")

    result = {
        "file": str(file_path),
        "feature_checksum": checksum,
        "parsed_metadata": meta,
        # SoC
        "predicted_SoC_percent": float(soc_mean),
        "SoC_std_estimate": float(soc_std),
        "soc_model_chosen": soc_model_name,
        "SoC_probabilities": None,
        "SOC_mahal": soc_mahal,
        "SOC_ood": bool(soc_oob),
        # helpful debug
        "raw_soc_precal": float(raw_soc_mean),
        "raw_soc_precal_std": val_rmse_soc,
        "shape_soc_mean": None if soc_shape_mean is None else float(soc_shape_mean),
        "shape_soc_std": None if soc_shape_std is None else float(soc_shape_std),
        # SoH
        "predicted_SoH_percent": float(soh_mean),
        "SoH_std_estimate": float(soh_std),
        "raw_model_mean": float(soh_mean_raw),
        "raw_model_std": float(soh_std_raw),
        # Keyed on the prediction, not the model: a shape model may exist while no
        # normalised vector was built, in which case there is nothing to report.
        "shape_model_mean": None if shape_soh_mean is None else float(shape_soh_mean),
        "shape_model_std": None if shape_soh_std is None else float(shape_soh_std),
        "soh_model_chosen": model_name,
        # RUL
        "cycles_per_percent_used": float(cpp),
        "cycles_to_thresholds": {str(int(k)): v for k, v in cycles_to.items()},
        "decision_threshold_percent": cfg.DECISION_SOH_PERCENT,
        "lower_threshold_percent": cfg.ILLUSTRATIVE_MIN_SOH,
        # OOD (SoH diagnostics)
        "OOD_mahal": None if mahal_dist is None else float(mahal_dist),
        "OOD_gp_ard_norm": None,
        "OOD_flag": bool(ood_flag),
        # Loader / autoscale diagnostics
        "used_freq_from_file": extras.get("used_freq_from_file"),
        "hf_train": extras.get("hf_train"),
        "hf_test_before_scale": extras.get("hf_test_before_scale"),
        "f7_train": extras.get("f7_train"),
        "f7_test_before_scale": extras.get("f7_test_before_scale"),
        "autoscale_factor": extras.get("autoscale_factor"),
    }
    return result, ood_flag, cycles_to

# =========================
# 16. MAIN (batch mode)
# =========================
def main():
    """
    Batch entry point: prepare capacity data, train or load the bundle, then score
    every file in ``cfg.EIS_TEST_FILES``.

    For each test file a projection plot and a ``*_prediction.json`` are written to
    ``cfg.MODEL_DIR``. Files that are missing or fail to predict are reported and skipped.

    Raises:
        AssertionError: if ``cfg.EIS_DIR`` (or ``cfg.CAP_DIR`` when capacity
            refinement is enabled) does not exist.
    """
    if cfg.VERBOSE:
        print("Configuration:\n", json.dumps(to_jsonable(asdict(cfg)), indent=2))

    # Explicit raises (same exception type as before) so the checks survive `python -O`.
    if not cfg.EIS_DIR.exists():
        raise AssertionError(f"EIS_DIR missing: {cfg.EIS_DIR}")
    if cfg.REFINE_SOH_WITH_CAPACITY and not cfg.CAP_DIR.exists():
        raise AssertionError(f"CAP_DIR missing: {cfg.CAP_DIR}")

    # Capacity + dynamic CPP
    cap_df = load_capacity_info(cfg.CAP_DIR)
    if cap_df.empty:
        if cfg.VERBOSE:
            print("[INFO] No / empty capacity data -> fallback Cpp.")
        cpp_map, global_cpp = {}, cfg.CPP_FALLBACK
    else:
        cpp_map, global_cpp = build_cpp_map(cap_df)
        if cfg.VERBOSE:
            print(f"[CPP] dynamic cells={len(cpp_map)} global_cpp_median={global_cpp:.2f}")

    # Train or load with signature check
    bundle_path = cfg.MODEL_DIR / "eis_soc_soh_phys_models.joblib"
    bundle = None
    need_retrain = True
    if bundle_path.exists():
        try:
            bundle = load_bundle()
            same_sig = (bundle.get("config_signature") == config_signature(cfg)) and \
                       (bundle.get("feature_version") == cfg.FEATURE_VERSION)
            need_retrain = not same_sig
            if cfg.VERBOSE:
                print(f"[LOAD] Found bundle. Signature match: {same_sig}")
        except Exception as e:
            print(f"[LOAD] Could not load existing bundle cleanly: {e}")
            bundle = None
            need_retrain = True

    if cfg.FORCE_RETRAIN:
        need_retrain = True

    if need_retrain or bundle is None:
        if cfg.VERBOSE:
            print("[TRAIN] Building dataset & training models...")
        meta_df, X_raw, shape_bundle, y_soc, y_soh = build_dataset(cfg.EIS_DIR, cap_df)
        if cfg.VERBOSE:
            print(f"[TRAIN] Samples={X_raw.shape[0]} Features={X_raw.shape[1]} Cells={meta_df.CellID.nunique()}")
        bundle = train_models(meta_df, X_raw, shape_bundle, y_soc, y_soh)
    # else: the bundle validated above is reused instead of being deserialized a second time

    # Inference (batch over cfg.EIS_TEST_FILES)
    for test_fp in cfg.EIS_TEST_FILES:
        test_fp = Path(test_fp)  # entries may be plain strings; normalise before using .name / .stem
        print(f"\n===== TEST: {test_fp.name} =====")
        if not test_fp.exists():
            print(f"[WARN] Test file not found: {test_fp}")
            continue
        try:
            result, ood_flag, cycles_to_map = predict_file(test_fp, bundle, cpp_map, global_cpp)
        except Exception as e:
            print(f"[ERROR] Prediction failed for {test_fp.name}: {e}")
            continue

        out_plot = cfg.MODEL_DIR / f"{test_fp.stem}_projection.png"
        plot_projection(
            test_fp.stem,
            result["predicted_SoH_percent"],
            result["SoH_std_estimate"],
            cycles_to_map,
            result["cycles_per_percent_used"],
            result["OOD_flag"],
            out_plot,
            thresholds=cfg.TARGET_SOH_THRESHOLDS
        )

        out_json = cfg.MODEL_DIR / f"{test_fp.stem}_prediction.json"
        with out_json.open("w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)
        print(json.dumps(result, indent=2))
        # plot_projection writes nothing once SoH is at/below the lowest threshold.
        if out_plot.exists():
            print(f"[PLOT] Saved: {out_plot}")
        else:
            print(f"[PLOT] Skipped (SoH at or below lowest threshold): {out_plot}")
        print(f"[JSON] Saved: {out_json}")

    print("\nDone.")

# =========================
# 17. GRADIO UI (Jupyter-friendly)
# =========================
def _prepare_models_for_ui(force_retrain: bool=False):
    """
    Return a ready-to-use ``(bundle, cpp_map, global_cpp)`` triple for the UI.

    The saved bundle is reused when it loads cleanly and both its config signature
    and feature version match the current configuration; otherwise (or when
    ``force_retrain`` is set) the dataset is rebuilt and the models retrained.

    Raises:
        FileNotFoundError: if retraining is required but ``cfg.EIS_DIR`` (or
            ``cfg.CAP_DIR`` with capacity refinement enabled) is missing.
    """
    bundle_path = cfg.MODEL_DIR / "eis_soc_soh_phys_models.joblib"

    # Capacity + dynamic CPP (mirrors main)
    cap_df = load_capacity_info(cfg.CAP_DIR) if cfg.REFINE_SOH_WITH_CAPACITY else pd.DataFrame()
    if cap_df.empty:
        cpp_map, global_cpp = {}, cfg.CPP_FALLBACK
    else:
        cpp_map, global_cpp = build_cpp_map(cap_df)

    bundle = None
    need_retrain = force_retrain or (not bundle_path.exists())
    if not need_retrain:
        try:
            bundle = load_bundle()
            same_sig = (bundle.get("config_signature") == config_signature(cfg)) and \
                       (bundle.get("feature_version") == cfg.FEATURE_VERSION)
            need_retrain = not same_sig
        except Exception:
            # Unreadable / incompatible bundle: fall back to retraining.
            bundle = None
            need_retrain = True

    if need_retrain or bundle is None:
        if not cfg.EIS_DIR.exists():
            raise FileNotFoundError(f"EIS_DIR missing: {cfg.EIS_DIR}. Update cfg.EIS_DIR before training.")
        if cfg.REFINE_SOH_WITH_CAPACITY and not cfg.CAP_DIR.exists():
            raise FileNotFoundError(f"CAP_DIR missing: {cfg.CAP_DIR}. Update cfg.CAP_DIR or disable REFINE_SOH_WITH_CAPACITY.")
        meta_df, X_raw, shape_bundle, y_soc, y_soh = build_dataset(cfg.EIS_DIR, cap_df)
        bundle = train_models(meta_df, X_raw, shape_bundle, y_soc, y_soh)
    # else: reuse the bundle validated above rather than deserializing it again

    return bundle, cpp_map, global_cpp

def _ui_predict(file_obj, override_temp, force_retrain):
    """
    Gradio callback: score one uploaded EIS file.

    Args:
        file_obj: the upload, as a path, a dict with a ``"name"`` entry, or a
            file-like object (which is copied into ``cfg.MODEL_DIR``).
        override_temp: optional temperature override; blank or unparseable values
            leave ``cfg.TEST_TEMPERATURE_OVERRIDE`` untouched.
        force_retrain: when truthy, retrain before predicting.

    Returns:
        tuple: ``(plot_image, pretty_json, soc_value)``. ``plot_image`` is None when
        no projection plot was produced. On failure the tuple is
        ``(None, <json with "error">, None)``.
    """
    orig_temp_override = cfg.TEST_TEMPERATURE_OVERRIDE
    try:
        # Optionally override test temperature from UI
        if override_temp is not None and str(override_temp).strip() != "":
            try:
                cfg.TEST_TEMPERATURE_OVERRIDE = float(override_temp)
            except (TypeError, ValueError):
                cfg.TEST_TEMPERATURE_OVERRIDE = orig_temp_override

        # Prepare model assets
        bundle, cpp_map, global_cpp = _prepare_models_for_ui(force_retrain=bool(force_retrain))

        # Normalize Gradio File
        test_fp: Optional[Path] = None
        if file_obj is None:
            raise ValueError("Please upload a file.")
        if isinstance(file_obj, (str, Path)):
            test_fp = Path(file_obj)
        elif isinstance(file_obj, dict) and "name" in file_obj:
            test_fp = Path(file_obj["name"])
        elif hasattr(file_obj, "name"):
            # File-like upload: copy its bytes to a uniquely named file, keeping the suffix
            name = Path(getattr(file_obj, "name", "upload")).name
            suffix = Path(name).suffix or ""
            tmp_name = cfg.MODEL_DIR / f"ui_{uuid.uuid4().hex}{suffix}"
            try:
                file_obj.seek(0)
            except Exception:
                pass  # not every upload object is seekable; read from the current position
            data = file_obj.read()
            if isinstance(data, str):
                data = data.encode("utf-8")
            with open(tmp_name, "wb") as f:
                f.write(data)
            test_fp = tmp_name
        else:
            tmp_name = cfg.MODEL_DIR / f"ui_{uuid.uuid4().hex}"
            data = file_obj.read()
            if isinstance(data, str):
                data = data.encode("utf-8")
            with open(tmp_name, "wb") as f:
                f.write(data)
            test_fp = tmp_name

        # Predict
        result, ood_flag, cycles_to_map = predict_file(test_fp, bundle, cpp_map, global_cpp)

        # Build plot (written to disk, then read back as a PIL image)
        out_plot = cfg.MODEL_DIR / f"{Path(test_fp).stem}_projection_ui.png"
        plot_projection(
            Path(test_fp).stem,
            result["predicted_SoH_percent"],
            result["SoH_std_estimate"],
            cycles_to_map,
            result["cycles_per_percent_used"],
            result["OOD_flag"],
            out_plot,
            thresholds=cfg.TARGET_SOH_THRESHOLDS
        )
        # plot_projection writes nothing once SoH is at/below the lowest threshold;
        # that must not turn a valid prediction into an error response.
        plot_img = None
        if out_plot.exists():
            with open(out_plot, "rb") as f:
                img_bytes = f.read()
            plot_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        # Compact, user-facing JSON
        pretty = {
            "file": result["file"],
            "feature_checksum": result.get("feature_checksum"),
            "SoC_percent": round(result["predicted_SoC_percent"], 2),
            "SoC_std": round(result["SoC_std_estimate"], 2),
            "SoH_percent": round(result["predicted_SoH_percent"], 2),
            "SoH_std": round(result["SoH_std_estimate"], 2),
            "OOD": bool(result["OOD_flag"]),
            "cycles_per_percent": result["cycles_per_percent_used"],
            "cycles_to_thresholds": {k: round(v, 1) for k, v in result["cycles_to_thresholds"].items()},
            "soc_model": result.get("soc_model_chosen"),
            "soh_model": result.get("soh_model_chosen"),
            "decision_threshold_percent": result["decision_threshold_percent"],
            "lower_threshold_percent": result["lower_threshold_percent"]
        }
        pretty_json = json.dumps(pretty, indent=2)

        soc_value = float(result["predicted_SoC_percent"])
        return plot_img, pretty_json, soc_value
    except Exception as e:
        err = {"error": str(e)}
        return None, json.dumps(err, indent=2), None
    finally:
        # Restore on every path: a failed request must not leak its override into the next one.
        cfg.TEST_TEMPERATURE_OVERRIDE = orig_temp_override

def _is_port_in_use_error(exc: OSError) -> bool:
    """Return True when *exc* looks like a "port already taken" failure.

    The errno is checked first because it does not depend on message wording
    or locale. The message substrings are a fallback for OSErrors raised
    without an errno.
    NOTE(review): the exact wording Gradio uses differs between releases --
    confirm the substrings below against the installed version.
    """
    import errno  # local import: keeps this block self-contained

    if getattr(exc, "errno", None) == errno.EADDRINUSE:
        return True
    message = str(exc).lower()
    return any(
        needle in message
        for needle in ("address already in use", "is in use", "cannot find empty port")
    )


def _enable_queue(demo) -> None:
    """Enable request queuing on *demo* across differing ``queue()`` signatures.

    The original call (``concurrency_count=2, max_size=10``) is tried first.
    If that keyword is rejected, the call is retried with
    ``default_concurrency_limit=2`` and finally with ``max_size`` only.
    NOTE(review): ``default_concurrency_limit`` is assumed to be the closest
    replacement for ``concurrency_count`` in newer Gradio -- confirm.
    """
    try:
        demo.queue(concurrency_count=2, max_size=10)
    except (TypeError, DeprecationWarning):
        try:
            demo.queue(default_concurrency_limit=2, max_size=10)
        except TypeError:
            demo.queue(max_size=10)


def launch_gradio(server_name: str = "127.0.0.1",
                  server_port: int = 7860,
                  share: bool = False,
                  inbrowser: bool = False):
    """Build the Gradio UI and launch it, retrying on the next port if busy.

    Args:
        server_name: Host/interface passed to ``demo.launch``.
        server_port: First port to try; incremented by one after each
            "port in use" failure.
        share: Passed to ``demo.launch`` as ``share``.
        inbrowser: Passed to ``demo.launch`` when not running in a notebook;
            in a notebook ``inbrowser=False`` is always used.

    Returns:
        Whatever ``demo.launch`` returns.

    Raises:
        ImportError: If ``gradio`` cannot be imported.
        OSError: If launching fails for a reason other than a busy port, or
            if the port is still busy after the last attempt (6 launches in
            a notebook, 2 otherwise).
        TypeError: If ``demo.launch`` rejects its arguments even after the
            notebook-only keywords have been removed.
    """
    try:
        import gradio as gr
    except Exception as e:
        raise ImportError("Gradio is not installed. Please: pip install gradio") from e

    in_notebook = _running_in_notebook()

    if in_notebook:
        # Best effort only: a missing nest_asyncio is reported, not fatal.
        try:
            import nest_asyncio
            nest_asyncio.apply()
        except Exception as e:
            print("[WARN] nest_asyncio not available. If you see 'event loop' errors, pip install nest_asyncio.", e)

    with gr.Blocks(title="EIS SoC/SoH + RUL (v8.8)") as demo:
        gr.Markdown("## Unified EIS: SoC / SoH inference & RUL projection\nUpload a single EIS file (.csv / .xlsx / .xls / .mat) to get the projection plot and SoC value.")
        with gr.Row():
            file_in = gr.File(label="Upload EIS file", file_count="single")
        with gr.Row():
            override_temp = gr.Textbox(label="Test temperature override (°C, optional)", placeholder=str(cfg.TEST_TEMPERATURE_OVERRIDE))
            force_retrain = gr.Checkbox(label="Force retrain before inference", value=False)
        predict_btn = gr.Button("Predict")
        with gr.Row():
            img_out = gr.Image(label="RUL Projection Plot", type="pil")
            json_out = gr.Code(label="Results (JSON)")
        soc_out = gr.Number(label="Predicted SoC (%)", precision=2)

        with gr.Row():
            rt_btn = gr.Button("Retrain bundle only")
            rt_status = gr.Markdown()

        def _do_predict(file_obj, temp, fr):
            # Thin adapter so the click handler has a plain positional signature.
            return _ui_predict(file_obj, temp, fr)

        def _do_retrain():
            # Report the outcome as a status string instead of raising into the UI.
            try:
                _prepare_models_for_ui(force_retrain=True)
                return "✅ Retrained successfully."
            except Exception as e:
                return f"❌ Retrain failed: {e}"

        predict_btn.click(_do_predict, inputs=[file_in, override_temp, force_retrain], outputs=[img_out, json_out, soc_out])
        rt_btn.click(_do_retrain, inputs=None, outputs=rt_status)

    launch_kwargs = dict(server_name=server_name, server_port=server_port, share=share)
    if in_notebook:
        launch_kwargs.update(inline=True, inbrowser=False, prevent_thread_lock=True, debug=False)
        _enable_queue(demo)
        launches_left = 6
    else:
        launch_kwargs["inbrowser"] = inbrowser
        launches_left = 2

    # Keywords that may be dropped if demo.launch() rejects them with TypeError.
    optional_keys = ("inline", "prevent_thread_lock")

    while True:
        try:
            return demo.launch(**launch_kwargs)
        except TypeError:
            present = [k for k in optional_keys if k in launch_kwargs]
            if not present:
                # Nothing left to strip: the TypeError has another cause.
                raise
            for k in present:
                del launch_kwargs[k]
        except OSError as e:
            launches_left -= 1
            if launches_left <= 0 or not _is_port_in_use_error(e):
                # Surface the failure instead of silently returning None.
                raise
            launch_kwargs["server_port"] = int(launch_kwargs["server_port"]) + 1

# =========================
# 18. ENTRYPOINT
# =========================
if __name__ == "__main__":
    import argparse

    # Command-line switches as (flag, add_argument options), registered in order.
    _cli_flags = (
        ("--ui", {"action": "store_true", "help": "Launch the Gradio UI"}),
        ("--share", {"action": "store_true", "help": "Create a public share link"}),
        ("--host", {"default": "127.0.0.1", "help": "Server host (default: 127.0.0.1)"}),
        ("--port", {"type": int, "default": 7860, "help": "Server port (default: 7860)"}),
        ("--inbrowser", {"action": "store_true", "help": "Open UI in browser automatically"}),
    )

    parser = argparse.ArgumentParser()
    for _flag, _options in _cli_flags:
        parser.add_argument(_flag, **_options)

    # parse_known_args() hands back unrecognized arguments separately rather
    # than exiting on them; they are discarded here.
    # NOTE(review): presumably this lets the cell run under a Jupyter kernel,
    # which passes its own argv -- confirm.
    args, _ = parser.parse_known_args()

    if not args.ui:
        # Default mode: run the batch pipeline.
        main()
    else:
        launch_gradio(
            server_name=args.host,
            server_port=args.port,
            share=args.share,
            inbrowser=args.inbrowser,
        )


Configuration:
 {
  "EIS_DIR": "C:\\Users\\tgondal0\\OneDrive - Edith Cowan University\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\EIS_Test",
  "CAP_DIR": "C:\\Users\\tgondal0\\OneDrive - Edith Cowan University\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\Capacity_Check",
  "MODEL_DIR": "models_eis_phase2_phys",
  "EIS_TEST_FILES": [
    "Mazda-Battery-Cell9.xlsx"
  ],
  "F_MIN": 0.01,
  "F_MAX": 10000.0,
  "N_FREQ": 60,
  "SOH_STD_MAX_OOD": 2.0,
  "TEST_FRAC": 0.2,
  "GROUP_KFOLDS": 0,
  "RANDOM_STATE": 42,
  "USE_PCA_SOC": false,
  "USE_PCA_SOH": false,
  "PCA_SOC_COMPONENTS": 25,
  "PCA_SOH_COMPONENTS": 30,
  "INCLUDE_RAW_RE_IM": true,
  "INCLUDE_BASICS": true,
  "INCLUDE_F_FEATS": true,
  "INCLUDE_PHYSICAL": true,
  "INCLUDE_DRT": true,
  "INCLUDE_BAND_STATS": true,
  "INCLUDE_DIFF_SLOPES": true,
  "DRT_POINTS": 60,
  "DRT_TAU_MIN": 0.0001,
  "DRT_TAU_MAX": 10000.0,
  "DRT_LAMBDA": 0.01,
  "REFINE_SOH_WITH_CAPACIT