In [4]:
#!/usr/bin/env python
"""
Optimized Unified EIS Training + Inference + Dynamic RUL (v8 – SoC Regression)
-------------------------------------------------------------------------------
What's new vs v7:
  • SoC is now a regression task (not classification). We train both:
      - Gaussian Process Regressor (with uncertainty)
      - HistGradientBoostingRegressor
    and auto-select the best by validation R² (optionally also a shape-normalized GP).
  • Backward-compat loader still supports legacy bundles with a classifier SoC model.
    Inference handles both cases gracefully.
  • JSON now includes:
      - predicted_SoC_percent (float), SoC_std_estimate, soc_model_chosen
      - SoC_probabilities only if a legacy classifier is used

All other features remain: rich EIS features, DRT descriptors, SoH ensemble, dynamic CPP, OOD checks, plots.
"""

from __future__ import annotations
import re, json, math, random, warnings, joblib
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.io import loadmat
from scipy import linalg
from scipy.interpolate import interp1d

from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.decomposition import PCA
from sklearn.isotonic import IsotonicRegression

import matplotlib.pyplot as plt

# =========================
# 1. CONFIGURATION
# =========================
@dataclass
class Config:
    # Training data directories (update if needed)
    EIS_DIR: Path = Path(r"C:\Users\tgondal0\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\EIS_Test")
    CAP_DIR: Path = Path(r"C:\Users\tgondal0\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\Capacity_Check")
    MODEL_DIR: Path = Path("models_eis_phase2_phys")

    # Test files (multi-format allowed)
    EIS_TEST_FILES: List[Path] = None  # assigned after instantiation

    # Frequency interpolation grid
    F_MIN: float = 1e-2
    F_MAX: float = 1e4
    N_FREQ: int = 60

    # Train / split settings
    TEST_FRAC: float = 0.2        # fraction of cells for hold-out test
    GROUP_KFOLDS: int = 0         # (reserved)
    RANDOM_STATE: int = 42

    # PCA toggles (separate for SoC & SoH)
    USE_PCA_SOC: bool = True
    USE_PCA_SOH: bool = False
    PCA_SOC_COMPONENTS: int = 25
    PCA_SOH_COMPONENTS: int = 30

    # Feature group toggles
    INCLUDE_RAW_RE_IM: bool = True
    INCLUDE_BASICS: bool = True
    INCLUDE_F_FEATS: bool = True
    INCLUDE_PHYSICAL: bool = True
    INCLUDE_DRT: bool = True
    INCLUDE_BAND_STATS: bool = True
    INCLUDE_DIFF_SLOPES: bool = True

    # DRT params
    DRT_POINTS: int = 60
    DRT_TAU_MIN: float = 1e-4
    DRT_TAU_MAX: float = 1e4
    DRT_LAMBDA: float = 1e-2

    # Capacity-based refinement
    REFINE_SOH_WITH_CAPACITY: bool = True

    # SoH modeling
    MAX_GPR_TRAIN_SAMPLES: int = 3500
    INCLUDE_NORMALIZED_SHAPE_MODEL: bool = True
    ENSEMBLE_SOH: bool = True
    NORMALIZE_SHAPE_BY_HF_RE: bool = True

    # NEW: SoC modeling
    SOC_INCLUDE_SHAPE_MODEL: bool = True
    SOC_MAX_GPR_TRAIN_SAMPLES: int = 3500

    # RUL parameters
    DECISION_SOH_PERCENT: float = 50.0
    ILLUSTRATIVE_MIN_SOH: float = 40.0
    CPP_ROLLING_WINDOW: int = 5
    CPP_MIN_POINTS: int = 6
    CPP_FALLBACK: float = 20.0  # fallback cycles-per-percent

    # Inference extras
    TEST_TEMPERATURE_OVERRIDE: Optional[float] = 25.0  # applied if metadata absent
    FORCE_RETRAIN: bool = False  # force retraining even if bundle exists

    # Saving / logging
    SAVE_FEATURE_TABLE: bool = True
    VERBOSE: bool = True
    FEATURE_VERSION: int = 8

    # OOD thresholds
    MAHAL_THRESHOLD: float = 10.0
    GP_ARD_NORM_THRESHOLD: float = 6.0

    # Projection curve shape exponent
    PLOT_EXPONENT: float = 1.25

cfg = Config()
if cfg.EIS_TEST_FILES is None:
    cfg.EIS_TEST_FILES = [
        Path("Mazda-Battery-Cell1.xlsx")
    ]
cfg.MODEL_DIR.mkdir(parents=True, exist_ok=True)

# =========================
# 2. UTILITIES
# =========================
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
set_seed(cfg.RANDOM_STATE)

def to_jsonable(x):
    if isinstance(x, Path): return str(x)
    if isinstance(x, dict): return {k: to_jsonable(v) for k,v in x.items()}
    if isinstance(x, (list, tuple)): return [to_jsonable(i) for i in x]
    return x

CANON_FREQ = np.geomspace(cfg.F_MAX, cfg.F_MIN, cfg.N_FREQ)

# =========================
# 3. REGEX
# =========================
EIS_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_(?P<Temp>\d+)degC_(?P<SOC>\d+)SOC_(?P<RealSOH>\d+)"
)
CAP_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_Capacity_Check_(?P<Temp>\d+)degC_(?P<Cycle>\d+)cycle"
)

# =========================
# 4. PARSERS
# =========================
def parse_eis_metadata(stem: str) -> Optional[Dict[str, Any]]:
    m = EIS_META_PATTERN.search(stem)
    if not m: return None
    d = m.groupdict()
    return {
        "CellID": f"Cell{d['CellID']}",
        "SOH_stage": int(d["SOH"]),
        "SOC": float(d["SOC"]),           # keep as float for regression
        "Temp": int(d["Temp"]),
        "RealSOH_file": int(d["RealSOH"])/100.0
    }

def parse_cap_metadata(stem: str) -> Optional[Dict[str, Any]]:
    m = CAP_META_PATTERN.search(stem)
    if not m: return None
    d = m.groupdict()
    return {
        "CellID": f"Cell{d['CellID']}",
        "SOH_stage": int(d["SOH"]),
        "Temp": int(d["Temp"]),
        "CycleIndex": int(d["Cycle"])
    }

# =========================
# 5. LOW-LEVEL LOADERS / INTERPOLATION
# =========================
def _find_matrix(mat_dict: dict):
    for v in mat_dict.values():
        if isinstance(v, np.ndarray) and v.ndim == 2 and v.shape[1] >= 3 and v.shape[0] >= 10:
            return v
    return None

def _interp_channel(freq_raw, y_raw, freq_target):
    freq_raw = np.asarray(freq_raw).astype(float)
    y_raw = np.asarray(y_raw).astype(float)
    if freq_raw[0] < freq_raw[-1]:
        freq_raw = freq_raw[::-1]; y_raw = y_raw[::-1]
    uniq, idx = np.unique(freq_raw, return_index=True)
    if len(uniq) != len(freq_raw):
        order = np.argsort(idx)
        freq_raw = uniq[order]; y_raw = y_raw[idx][order]
    f = interp1d(freq_raw, y_raw, bounds_error=False,
                 fill_value=(y_raw[0], y_raw[-1]), kind="linear")
    return f(freq_target)

FREQ_CANDS = ["frequency","freq","f","hz","frequency(hz)","Frequency(Hz)"]
RE_CANDS   = ["zreal","re(z)","re","real","z_re","zreal(ohm)","re (ohm)","re(z) (ohm)","Zreal","Zreal (ohm)","Zreal(ohm)"]
IM_CANDS   = ["-zimag","zimag","im(z)","im","imag","imaginary","z_im","zimg","z_imag"," -Zimag (ohm)"," -Zimag(ohm)","-Zimag","Zimag","Zimag (ohm)"]

def _select_column(df: pd.DataFrame, cands: List[str]) -> Optional[str]:
    low = {c.lower(): c for c in df.columns}
    for c in cands:
        if c.lower() in low: return low[c.lower()]
    for c in cands:
        for col in df.columns:
            if c.lower() in col.lower():
                return col
    return None

def load_mat_eis(path: Path):
    mat = loadmat(path)
    arr = _find_matrix(mat)
    if arr is None:
        raise ValueError(f"No valid EIS matrix in {path.name}")
    return arr[:,0].astype(float), arr[:,1].astype(float), arr[:,2].astype(float)

def load_table_eis(path: Path):
    if path.suffix.lower() == ".csv":
        df = pd.read_csv(path)
    else:
        df = pd.read_excel(path)
    if df.empty:
        raise ValueError("Empty table.")
    fcol = _select_column(df, FREQ_CANDS)
    recol = _select_column(df, RE_CANDS)
    imcol = _select_column(df, IM_CANDS)
    if recol is None or imcol is None:
        raise ValueError(f"Missing Re/Im columns in {path.name}")
    re_vals = pd.to_numeric(df[recol], errors="coerce").to_numpy()
    im_vals = pd.to_numeric(df[imcol], errors="coerce").to_numpy()
    if fcol is not None:
        freq_vals = pd.to_numeric(df[fcol], errors="coerce").to_numpy()
    else:
        n = min(len(re_vals), len(im_vals))
        freq_vals = np.geomspace(cfg.F_MAX, cfg.F_MIN, n)
    n = min(len(freq_vals), len(re_vals), len(im_vals))
    freq_vals = freq_vals[:n]; re_vals = re_vals[:n]; im_vals = im_vals[:n]
    if np.nanmean(im_vals) > 0:
        im_vals = -im_vals
    return freq_vals, re_vals.astype(float), im_vals.astype(float)

def load_any_inference(path: Path):
    suf = path.suffix.lower()
    if suf == ".mat": return load_mat_eis(path)
    if suf in (".csv",".xls",".xlsx"): return load_table_eis(path)
    raise ValueError(f"Unsupported test file extension: {suf}")

# =========================
# 6. FEATURE ENGINEERING
# =========================
def compute_F_features(freq, re_i, im_i):
    neg_im = -im_i
    idx_peak = int(np.argmax(neg_im))
    F1 = re_i[0]; F2 = re_i[idx_peak]; F3 = re_i[-1]
    sc = np.where(np.sign(im_i[:-1]) != np.sign(im_i[1:]))[0]
    if len(sc):
        k = sc[0]; y0,y1 = im_i[k], im_i[k+1]
        w = -y0/(y1 - y0 + 1e-12)
        F4 = re_i[k] + w*(re_i[k+1]-re_i[k])
    else:
        F4 = np.nan
    F5 = (re_i[idx_peak]-F1) if idx_peak>0 else np.nan
    F6 = np.min(im_i)
    mid_target = 10.0
    idx_mid = int(np.argmin(np.abs(freq-mid_target)))
    F7 = re_i[idx_mid]
    return [F1,F2,F3,F4,F5,F6,F7]

PHYSICAL_FEATURE_NAMES = [
    "Rs","Rct","tau_peak","warburg_sigma","arc_quality",
    "phase_mean_mid","phase_std_mid","phase_min","lf_slope_negIm","norm_arc"
]

def physical_features(freq, re_i, im_i):
    freq = np.asarray(freq); re_i = np.asarray(re_i); im_i = np.asarray(im_i)
    neg_im = -im_i
    idx_peak = int(np.argmax(neg_im))
    Rs = float(re_i[0]); Rpeak = float(re_i[idx_peak]); Rlow = float(re_i[-1])
    Rct = max(Rpeak - Rs, 0.0)
    arc_diam = Rlow - Rs
    norm_arc = arc_diam / (Rs + 1e-9)
    f_peak = float(freq[idx_peak])
    tau_peak = 1.0/(2*math.pi*f_peak) if f_peak>0 else np.nan
    K = min(10, len(freq)//3)
    if K >= 4:
        w_section = (2*np.pi*freq[-K:])**(-0.5)
        re_section = re_i[-K:]
        if len(np.unique(w_section)) > 2:
            warburg_sigma = float(np.polyfit(w_section, re_section, 1)[0])
        else:
            warburg_sigma = np.nan
    else:
        warburg_sigma = np.nan
    phase = np.arctan2(-im_i, re_i)
    mid_mask = (freq>=1) & (freq<=100)
    if mid_mask.sum()>2:
        phase_mean_mid = float(phase[mid_mask].mean())
        phase_std_mid  = float(phase[mid_mask].std())
    else:
        phase_mean_mid = np.nan; phase_std_mid = np.nan
    phase_min = float(phase.min())
    lf_mask = (freq<=1.0)
    if lf_mask.sum() >= 4:
        x = np.log10(freq[lf_mask]+1e-12); y = neg_im[lf_mask]
        lf_slope = np.polyfit(x, y, 1)[0]
    else:
        lf_slope = np.nan
    arc_quality = (neg_im.max() - neg_im.min())/(abs(neg_im.mean())+1e-9)
    return [Rs,Rct,tau_peak,warburg_sigma,arc_quality,
            phase_mean_mid,phase_std_mid,phase_min,lf_slope,norm_arc]

BANDS = [(1e4,1e3),(1e3,1e2),(1e2,10),(10,1),(1,1e-1),(1e-1,1e-2)]
def band_stats(freq, re_i, im_i):
    feats=[]; freq=np.asarray(freq)
    for hi,lo in BANDS:
        m=(freq<=hi)&(freq>=lo)
        if m.sum()>1:
            z=np.hypot(re_i[m], im_i[m])
            feats += [z.mean(), z.std()]
        else:
            feats += [np.nan, np.nan]
    return feats

def diff_slopes(freq, re_i, im_i, segments=5):
    logf = np.log10(freq)
    edges = np.linspace(logf.min(), logf.max(), segments+1)
    out=[]
    for i in range(segments):
        m=(logf>=edges[i])&(logf<=edges[i+1])
        if m.sum()>=3:
            x=logf[m]
            out += [np.polyfit(x,re_i[m],1)[0], np.polyfit(x,(-im_i)[m],1)[0]]
        else:
            out += [np.nan, np.nan]
    return out

DRT_FEATURE_NAMES = [
    "drt_sum","drt_mean_logtau","drt_var_logtau","drt_peak_tau",
    "drt_peak_gamma","drt_frac_low_tau","drt_frac_high_tau"
]

def compute_drt(freq,re_i,im_i,tau_min,tau_max,n_tau,lam):
    w = 2*np.pi*freq
    tau = np.geomspace(tau_max, tau_min, n_tau)
    WT = w[:,None]*tau[None,:]
    denom = 1+WT**2
    K_re = 1.0/denom
    K_im = -WT/denom
    R_inf = re_i[0]
    y_re = re_i - R_inf
    y_im = im_i
    Y = np.concatenate([y_re, y_im])
    K = np.vstack([K_re, K_im])
    A = K.T @ K + lam*np.eye(n_tau)
    b = K.T @ Y
    gamma = linalg.solve(A,b,assume_a='pos')
    gamma = np.clip(gamma,0,None)
    return tau, gamma

def drt_features(freq,re_i,im_i):
    try:
        tau,gamma = compute_drt(freq,re_i,im_i,
                                 cfg.DRT_TAU_MIN,cfg.DRT_TAU_MAX,
                                 cfg.DRT_POINTS,cfg.DRT_LAMBDA)
        log_tau = np.log10(tau)
        g_sum = gamma.sum()+1e-12
        w_norm = gamma/g_sum
        mean_logtau = float((w_norm*log_tau).sum())
        var_logtau  = float((w_norm*(log_tau-mean_logtau)**2).sum())
        p = int(np.argmax(gamma))
        peak_tau = float(tau[p]); peak_gamma=float(gamma[p])
        mid = np.median(log_tau)
        frac_low = float(w_norm[log_tau<=mid].sum())
        frac_high = 1-frac_low
        return [g_sum,mean_logtau,var_logtau,peak_tau,peak_gamma,frac_low,frac_high]
    except Exception:
        return [np.nan]*7

def build_feature_vector(re_i, im_i, temp, freq, include_names=False):
    parts=[]; names=[]
    if cfg.INCLUDE_RAW_RE_IM:
        parts += [re_i, im_i]
        names += [f"Re_{i}" for i in range(len(re_i))] + [f"Im_{i}" for i in range(len(im_i))]
    if cfg.INCLUDE_BASICS:
        z = np.hypot(re_i, im_i)
        basics=[re_i[0], re_i[-1], re_i[-1]-re_i[0], z.max(), z.mean(), z.std()]
        parts.append(np.array(basics)); names += ["hf_re","lf_re","arc_diam","zmag_max","zmag_mean","zmag_std"]
    if cfg.INCLUDE_F_FEATS:
        Ff=compute_F_features(freq,re_i,im_i); parts.append(np.array(Ff)); names += [f"F{i}" for i in range(1,8)]
    if cfg.INCLUDE_PHYSICAL:
        Pf=physical_features(freq,re_i,im_i); parts.append(np.array(Pf)); names += PHYSICAL_FEATURE_NAMES
    if cfg.INCLUDE_BAND_STATS:
        Bf=band_stats(freq,re_i,im_i); parts.append(np.array(Bf))
        for bi in range(len(BANDS)): names += [f"band{bi}_mean", f"band{bi}_std"]
    if cfg.INCLUDE_DIFF_SLOPES:
        Ds=diff_slopes(freq,re_i,im_i); parts.append(np.array(Ds))
        for i in range(len(Ds)//2): names += [f"slope_re_seg{i}", f"slope_negIm_seg{i}"]
    if cfg.INCLUDE_DRT:
        Df=drt_features(freq,re_i,im_i); parts.append(np.array(Df)); names += DRT_FEATURE_NAMES
    parts.append(np.array([temp])); names += ["Feat_Temp"]
    vec = np.concatenate(parts).astype(float)
    vec = np.nan_to_num(vec, nan=0.0, posinf=0.0, neginf=0.0)
    if include_names: return vec, names
    return vec

def build_shape_normalized(re_i, im_i, k: int = 5):
    # Use median of the K highest-frequency Re points as HF scale (robust to noise/missing HF)
    hf = float(np.nanmedian(re_i[:max(1, min(k, len(re_i)))]))
    if not np.isfinite(hf) or abs(hf) < 1e-9:
        hf = 1.0
    return re_i / hf, im_i / hf

# =========================
# 7. CAPACITY & CPP
# =========================
def load_capacity_info(cap_dir: Path) -> pd.DataFrame:
    if not (cap_dir.exists() and cfg.REFINE_SOH_WITH_CAPACITY):
        return pd.DataFrame()
    recs=[]
    for fp in cap_dir.rglob("*.mat"):
        meta = parse_cap_metadata(fp.stem)
        if not meta: continue
        try:
            mat=loadmat(fp); arr=_find_matrix(mat)
            if arr is None: continue
            col=np.argmax(np.abs(arr[-50:, :]).mean(axis=0))
            cap=float(np.nanmax(arr[:,col]))
            meta["MeasuredCapacity_Ah"]=cap
            recs.append(meta)
        except Exception:
            pass
    df=pd.DataFrame(recs)
    if df.empty: return df
    ref=df.groupby("CellID")["MeasuredCapacity_Ah"].transform("max")
    df["NormCapacity"]=df["MeasuredCapacity_Ah"]/ref
    df["SoH_percent"]=df["NormCapacity"]*100.0
    return df

def estimate_cpp_per_cell(capacity_df: pd.DataFrame,
                          window:int, min_points:int)->Dict[str,float]:
    cpp={}
    for cid,grp in capacity_df.groupby("CellID"):
        g=grp.sort_values("CycleIndex")
        if g.shape[0]<min_points: continue
        tail=g.tail(window)
        x=tail["CycleIndex"].values.astype(float)
        y=tail["SoH_percent"].values.astype(float)
        if len(np.unique(x))<2: continue
        slope=np.polyfit(x,y,1)[0]  # SoH% / cycle
        if slope >= -1e-6:  # non-degrading
            continue
        cpp[cid]=1.0/abs(slope)
    return cpp

def build_cpp_map(cap_df: pd.DataFrame):
    if cap_df.empty: return {}, cfg.CPP_FALLBACK
    cpp_map=estimate_cpp_per_cell(
        cap_df[["CellID","CycleIndex","SoH_percent"]],
        cfg.CPP_ROLLING_WINDOW, cfg.CPP_MIN_POINTS
    )
    if not cpp_map:
        return {}, cfg.CPP_FALLBACK
    return cpp_map, float(np.median(list(cpp_map.values())))

def get_cpp(meta: dict, cpp_map: Dict[str,float], global_cpp: float):
    if not meta: return global_cpp
    return cpp_map.get(meta.get("CellID"), global_cpp)

# =========================
# 8. DATASET BUILD (training on .mat only)
# =========================
def load_single_eis_mat(fp: Path):
    meta = parse_eis_metadata(fp.stem)
    if meta is None:
        raise ValueError(f"Bad filename: {fp.name}")
    freq,re_z,im_z = load_mat_eis(fp)
    re_i=_interp_channel(freq, re_z, CANON_FREQ)
    im_i=_interp_channel(freq, im_z, CANON_FREQ)
    vec=build_feature_vector(re_i, im_i, meta["Temp"], CANON_FREQ)
    return vec, meta, re_i, im_i

def build_dataset(eis_dir: Path, cap_df: Optional[pd.DataFrame]):
    files = sorted(eis_dir.rglob("*.mat"))
    if not files:
        raise FileNotFoundError(f"No .mat spectra in {eis_dir}")

    f0,r0,i0 = load_mat_eis(files[0])
    re0=_interp_channel(f0,r0,CANON_FREQ); im0=_interp_channel(f0,i0,CANON_FREQ)
    _, feature_names = build_feature_vector(re0, im0, 25.0, CANON_FREQ, include_names=True)

    feats=[]; rows=[]; shape_feats=[]
    for fp in tqdm(files, desc="Loading training spectra"):
        try:
            v, m, rei, imi = load_single_eis_mat(fp)
            feats.append(v); rows.append(m)
            if (cfg.INCLUDE_NORMALIZED_SHAPE_MODEL or cfg.SOC_INCLUDE_SHAPE_MODEL) and cfg.NORMALIZE_SHAPE_BY_HF_RE:
                rsh, ish = build_shape_normalized(rei, imi)
                shape_vec = build_feature_vector(rsh, ish, m["Temp"], CANON_FREQ)
                shape_feats.append(shape_vec)
        except Exception as e:
            if cfg.VERBOSE: print(f"[Skip] {fp.name}: {e}")

    if not rows:
        raise RuntimeError("No valid training spectra after filtering.")

    X = np.vstack(feats)
    X_shape = np.vstack(shape_feats) if shape_feats else None
    meta_df = pd.DataFrame(rows)

    # SoH refinement with capacity
    if cap_df is not None and not cap_df.empty and cfg.REFINE_SOH_WITH_CAPACITY:
        lookup = cap_df.set_index(["CellID","SOH_stage"])["NormCapacity"].to_dict()
        refined=[]
        for cid, stage, fallback in zip(meta_df.CellID, meta_df.SOH_stage, meta_df.RealSOH_file):
            nc = lookup.get((cid, stage))
            refined.append(100.0*nc if nc is not None else fallback)
        meta_df["SoH_cont"]=refined
    else:
        meta_df["SoH_cont"]=meta_df["RealSOH_file"]

    # Targets
    y_soc = meta_df["SOC"].astype(float).values  # continuous SoC
    y_soh = meta_df["SoH_cont"].values

    soh_var = float(np.var(y_soh))
    if cfg.VERBOSE:
        print(f"[DATA] SoH range: {y_soh.min():.2f} – {y_soh.max():.2f} (var={soh_var:.3f})")
        if soh_var < 1.0:
            print("[WARN] Low SoH variance → model may output near-constant SoH.")

    if cfg.SAVE_FEATURE_TABLE:
        pd.concat(
            [meta_df.reset_index(drop=True),
             pd.DataFrame(X, columns=feature_names)], axis=1
        ).to_parquet(cfg.MODEL_DIR/"training_features.parquet", index=False)

    return meta_df, X, (X_shape, feature_names), y_soc, y_soh

# =========================
# 9. SPLITTING
# =========================
def cell_split_mask(meta_df: pd.DataFrame):
    cells = meta_df.CellID.unique()
    rng = np.random.default_rng(cfg.RANDOM_STATE)
    n_test = max(1, int(len(cells)*cfg.TEST_FRAC))
    test_cells = rng.choice(cells, size=n_test, replace=False)
    return meta_df.CellID.isin(test_cells)

# =========================
# 10. TRAINING
# =========================
def _fit_gpr(X, y, seed, max_samples):
    dim = X.shape[1]
    kernel = RBF(length_scale=np.ones(dim)*3.0,
                 length_scale_bounds=(1e-1,1e4)) + \
             WhiteKernel(noise_level=1e-2,
                         noise_level_bounds=(1e-6,1e-1))
    gpr = GaussianProcessRegressor(
        kernel=kernel, alpha=0.0, normalize_y=True,
        random_state=seed, n_restarts_optimizer=3
    )
    if X.shape[0] > max_samples:
        idx = np.random.default_rng(seed).choice(
            X.shape[0], size=max_samples, replace=False)
        gpr.fit(X[idx], y[idx])
    else:
        gpr.fit(X, y)
    return gpr

def train_models(meta_df, X_raw, shape_bundle, y_soc, y_soh):
    X_shape, feature_names = shape_bundle
    mask_test = cell_split_mask(meta_df)

    # ----- SoC pipeline (REGRESSION) -----
    soc_scaler = StandardScaler()
    X_soc_s = soc_scaler.fit_transform(X_raw)
    soc_pca = None
    X_soc_model = X_soc_s
    if cfg.USE_PCA_SOC:
        soc_pca = PCA(n_components=min(cfg.PCA_SOC_COMPONENTS, X_soc_s.shape[1]-1),
                      random_state=cfg.RANDOM_STATE)
        X_soc_model = soc_pca.fit_transform(X_soc_s)

    soc_candidates = {}

    # (1) SoC GPR (raw features)
    soc_gpr = _fit_gpr(X_soc_model, y_soc, cfg.RANDOM_STATE, cfg.SOC_MAX_GPR_TRAIN_SAMPLES)
    pred_soc_gpr = soc_gpr.predict(X_soc_model[mask_test])
    r2_soc_gpr = r2_score(y_soc[mask_test], pred_soc_gpr)
    rmse_soc_gpr = math.sqrt(mean_squared_error(y_soc[mask_test], pred_soc_gpr))
    soc_candidates["soc_gpr_raw"] = (soc_gpr, r2_soc_gpr, rmse_soc_gpr)

    # (2) SoC HGB (raw features)
    soc_hgb = HistGradientBoostingRegressor(
        learning_rate=0.05, max_iter=500,
        l2_regularization=1e-3, random_state=cfg.RANDOM_STATE
    )
    soc_hgb.fit(X_soc_model[~mask_test], y_soc[~mask_test])
    pred_soc_hgb = soc_hgb.predict(X_soc_model[mask_test])
    r2_soc_hgb = r2_score(y_soc[mask_test], pred_soc_hgb)
    rmse_soc_hgb = math.sqrt(mean_squared_error(y_soc[mask_test], pred_soc_hgb))
    soc_candidates["soc_hgb_raw"] = (soc_hgb, r2_soc_hgb, rmse_soc_hgb)

    # (3) Optional shape-normalized SoC GP
    soc_shape_model=None; soc_shape_scaler=None; soc_shape_pca=None; soc_shape_metrics=None
    if cfg.SOC_INCLUDE_SHAPE_MODEL and (X_shape is not None):
        soc_shape_scaler = StandardScaler()
        Xs = soc_shape_scaler.fit_transform(X_shape)
        if cfg.USE_PCA_SOC:
            soc_shape_pca = PCA(n_components=min(cfg.PCA_SOC_COMPONENTS, Xs.shape[1]-1),
                                random_state=cfg.RANDOM_STATE)
            Xs_model = soc_shape_pca.fit_transform(Xs)
        else:
            Xs_model = Xs
        soc_shape_model = _fit_gpr(Xs_model, y_soc, cfg.RANDOM_STATE, cfg.SOC_MAX_GPR_TRAIN_SAMPLES)
        sp = soc_shape_model.predict(Xs_model[mask_test])
        r2_soc_shape = r2_score(y_soc[mask_test], sp)
        rmse_soc_shape = math.sqrt(mean_squared_error(y_soc[mask_test], sp))
        soc_candidates["soc_gpr_shape"] = (soc_shape_model, r2_soc_shape, rmse_soc_shape)
        soc_shape_metrics = {"r2": r2_soc_shape, "rmse": rmse_soc_shape}

    # Select best SoC model
    soc_best_name = max(soc_candidates.keys(), key=lambda k: soc_candidates[k][1])
    soc_best_model, soc_best_r2, soc_best_rmse = soc_candidates[soc_best_name]

    if cfg.VERBOSE:
        print(f"[SoC] GPR_raw:   R2={r2_soc_gpr:.3f} RMSE={rmse_soc_gpr:.2f}")
        print(f"[SoC] HGB_raw:   R2={r2_soc_hgb:.3f} RMSE={rmse_soc_hgb:.2f}")
        if soc_shape_metrics:
            print(f"[SoC] ShapeGP:  R2={soc_shape_metrics['r2']:.3f} RMSE={soc_shape_metrics['rmse']:.2f}")
        print(f"[SoC] Selected = {soc_best_name}")

    # ----- SoH pipeline (unchanged) -----
    soh_scaler = StandardScaler()
    X_soh_s = soh_scaler.fit_transform(X_raw)
    soh_pca=None
    X_soh_model = X_soh_s
    if cfg.USE_PCA_SOH:
        soh_pca = PCA(n_components=min(cfg.PCA_SOH_COMPONENTS, X_soh_s.shape[1]-1),
                      random_state=cfg.RANDOM_STATE)
        X_soh_model = soh_pca.fit_transform(X_soh_s)

    soh_candidates = {}

    # (1) Raw Gaussian Process
    gpr = _fit_gpr(X_soh_model, y_soh, cfg.RANDOM_STATE, cfg.MAX_GPR_TRAIN_SAMPLES)
    pred_gpr = gpr.predict(X_soh_model[mask_test])
    r2_gpr = r2_score(y_soh[mask_test], pred_gpr)
    rmse_gpr = math.sqrt(mean_squared_error(y_soh[mask_test], pred_gpr))
    soh_candidates["gpr_raw"] = (gpr, r2_gpr, rmse_gpr)

    # (2) HistGradientBoosting
    hgb = HistGradientBoostingRegressor(
        learning_rate=0.05, max_iter=500,
        l2_regularization=1e-3, random_state=cfg.RANDOM_STATE
    )
    hgb.fit(X_soh_model[~mask_test], y_soh[~mask_test])
    pred_hgb = hgb.predict(X_soh_model[mask_test])
    r2_hgb = r2_score(y_soh[mask_test], pred_hgb)
    rmse_hgb = math.sqrt(mean_squared_error(y_soh[mask_test], pred_hgb))
    soh_candidates["hgb_raw"] = (hgb, r2_hgb, rmse_hgb)

    # (3) Shape-normalized GP
    shape_model = None; shape_scaler=None; shape_pca=None; shape_metrics=None
    if cfg.INCLUDE_NORMALIZED_SHAPE_MODEL and (X_shape is not None):
        shape_scaler = StandardScaler()
        X_shape_s = shape_scaler.fit_transform(X_shape)
        X_shape_model = X_shape_s
        if cfg.USE_PCA_SOH:
            shape_pca = PCA(n_components=min(cfg.PCA_SOH_COMPONENTS, X_shape_s.shape[1]-1),
                            random_state=cfg.RANDOM_STATE)
            X_shape_model = shape_pca.fit_transform(X_shape_s)
        shape_model = _fit_gpr(X_shape_model, y_soh, cfg.RANDOM_STATE, cfg.MAX_GPR_TRAIN_SAMPLES)
        spred = shape_model.predict(X_shape_model[mask_test])
        r2_shape = r2_score(y_soh[mask_test], spred)
        rmse_shape = math.sqrt(mean_squared_error(y_soh[mask_test], spred))
        soh_candidates["gpr_shape"] = (shape_model, r2_shape, rmse_shape)
        shape_metrics = {"r2": r2_shape, "rmse": rmse_shape}

    best_name = max(["gpr_raw","hgb_raw"], key=lambda k: soh_candidates[k][1])
    best_model, best_r2, best_rmse = soh_candidates[best_name]

    if cfg.VERBOSE:
        print(f"[SoH] GPR_raw:  R2={r2_gpr:.3f} RMSE={rmse_gpr:.2f}")
        print(f"[SoH] HGB_raw:  R2={r2_hgb:.3f} RMSE={rmse_hgb:.2f}")
        if shape_metrics:
            print(f"[SoH] ShapeGP: R2={shape_metrics['r2']:.3f} RMSE={shape_metrics['rmse']:.2f}")
        print(f"[SoH] Selected raw model = {best_name}")

    # Mahalanobis precomputed on scaled raw SoH space
    cov = np.cov(X_soh_s.T)
    try:
        cov_inv = np.linalg.pinv(cov)
    except Exception:
        cov_inv = np.eye(cov.shape[0])
    center = X_soh_s.mean(axis=0)

    bundle = {
        # SoC (regression)
        "soc_scaler": soc_scaler,
        "soc_pca": soc_pca,
        "soc_model": soc_best_model,
        "soc_model_name": soc_best_name,
        "soc_shape_scaler": soc_shape_scaler,
        "soc_shape_pca": soc_shape_pca,
        "soc_shape_model": soc_shape_model,

        # SoH
        "soh_scaler": soh_scaler,
        "soh_pca": soh_pca,
        "soh_model": best_model,
        "soh_model_name": best_name,

        # Optional SoH shape model
        "shape_scaler": shape_scaler,
        "shape_pca": shape_pca,
        "shape_model": shape_model,

        # Meta
        "freq_grid": CANON_FREQ,
        "feature_version": cfg.FEATURE_VERSION,
        "feature_manifest": feature_names,
        "config": to_jsonable(asdict(cfg)),
        "metrics": {
            "soc_r2_selected": soc_best_r2,
            "soc_rmse_selected": soc_best_rmse,
            "soh_r2_selected": best_r2,
            "soh_rmse_selected": best_rmse
        },
        "soc_candidates_metrics": {
            "soc_gpr_raw": {"r2": r2_soc_gpr, "rmse": rmse_soc_gpr},
            "soc_hgb_raw": {"r2": r2_soc_hgb, "rmse": rmse_soc_hgb},
            "soc_gpr_shape": soc_shape_metrics
        },
        "soh_candidates_metrics": {
            "gpr_raw": {"r2": r2_gpr, "rmse": rmse_gpr},
            "hgb_raw": {"r2": r2_hgb, "rmse": rmse_hgb},
            "gpr_shape": shape_metrics
        },
        "train_mahal": {"center": center.tolist(), "cov_inv": cov_inv.tolist()}
    }
    out_path = cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib"
    joblib.dump(bundle, out_path)
    if cfg.VERBOSE:
        print(f"[MODEL] Saved bundle → {out_path}")
        print(json.dumps(bundle["metrics"], indent=2))
    return bundle

# =========================
# 11. LOAD (WITH LEGACY COMPATIBILITY SHIM)
# =========================
def load_bundle():
    """
    Legacy schema keys:
        "scaler", "pca", "soc_model", "soh_model"
    New schema keys (v8):
        SoC regression keys: "soc_scaler","soc_pca","soc_model","soc_model_name",...
    """
    path = cfg.MODEL_DIR / "eis_soc_soh_phys_models.joblib"
    bundle = joblib.load(path)

    # Legacy detection: older single scaler/pca, possibly classifier SoC model
    legacy = ("scaler" in bundle) and ("soc_scaler" not in bundle)
    if legacy:
        scaler = bundle["scaler"]
        pca = bundle.get("pca")
        soc_model = bundle.get("soc_model")   # may be classifier
        soh_model = bundle.get("soh_model")

        bundle["soc_scaler"] = scaler
        bundle["soh_scaler"] = scaler
        bundle["soc_pca"]    = pca
        bundle["soh_pca"]    = pca
        bundle["soh_model"]  = soh_model
        # best guess name
        name = bundle.get("soh_model_name","legacy_model")
        bundle["soh_model_name"] = name
        # mark unknown SoC model type; handled in inference
        bundle["soc_model_name"] = bundle.get("soc_model_name","legacy_soc_model")
        if "metrics" not in bundle:
            bundle["metrics"] = {}
        bundle["metrics"].setdefault("soh_rmse_selected", 5.0)
        bundle["metrics"].setdefault("soc_rmse_selected", 8.0)  # coarse fallback
        if "train_mahal" not in bundle:
            try:
                center = scaler.mean_
                cov_inv = np.eye(len(center))
                bundle["train_mahal"] = {"center": center.tolist(), "cov_inv": cov_inv.tolist()}
            except Exception:
                bundle["train_mahal"] = None
        bundle.setdefault("feature_version", -1)

    # Sanity check
    for key in ["soc_scaler","soc_model","soh_scaler","soh_model","freq_grid"]:
        if key not in bundle:
            raise KeyError(f"Bundle missing required key: {key}")

    return bundle

# =========================
# 12. INFERENCE FEATURIZATION
# =========================
def featurize_any(file_path: Path, bundle):
    freq_grid = bundle["freq_grid"]
    meta = parse_eis_metadata(file_path.stem)
    freq,re_raw,im_raw = load_any_inference(file_path)
    re_i=_interp_channel(freq, re_raw, freq_grid)
    im_i=_interp_channel(freq, im_raw, freq_grid)
    if meta is None and cfg.TEST_TEMPERATURE_OVERRIDE is not None:
        temp = cfg.TEST_TEMPERATURE_OVERRIDE
    else:
        temp = meta["Temp"] if meta else -1
    vec = build_feature_vector(re_i, im_i, temp, freq_grid)
    norm_vec=None
    if (cfg.INCLUDE_NORMALIZED_SHAPE_MODEL or cfg.SOC_INCLUDE_SHAPE_MODEL) and \
       (bundle.get("shape_model") is not None or bundle.get("soc_shape_model") is not None):
        if cfg.NORMALIZE_SHAPE_BY_HF_RE:
            rsh, ish = build_shape_normalized(re_i, im_i)
            norm_vec = build_feature_vector(rsh, ish, temp, freq_grid)
    return vec, norm_vec, meta

# =========================
# 13. OOD UTILITIES
# =========================
def mahalanobis_distance(x, center, cov_inv):
    diff = x - center
    return float(np.sqrt(diff @ cov_inv @ diff.T))

def gp_ard_norm(Xp, model):
    try:
        K = model.kernel_
        from sklearn.gaussian_process.kernels import RBF
        rbf = None
        if hasattr(K,"k1") and isinstance(K.k1,RBF): rbf=K.k1
        elif hasattr(K,"k2") and isinstance(K.k2,RBF): rbf=K.k2
        if rbf is None: return None
        ls = np.atleast_1d(rbf.length_scale)
        z = (Xp / ls).ravel()
        return float(np.linalg.norm(z))
    except Exception:
        return None

# =========================
# 14. PROJECTION PLOT
# =========================
def build_projection(soh_current, cpp, lower, exponent=None, n=160):
    if soh_current <= lower or cpp <= 0:
        return np.array([0.0]), np.array([soh_current])
    total = (soh_current - lower) * cpp
    cycles = np.linspace(0, total, n)
    S0 = soh_current; Smin=lower
    if exponent is None: exponent = cfg.PLOT_EXPONENT
    soh_curve = Smin + (S0 - Smin)*(1 - cycles/total)**exponent
    return cycles, soh_curve

def plot_projection(file_base, soh_current, soh_std, cycles_to_target,
                    cycles_to_lower, cpp, ood_flag, out_path):
    if cycles_to_lower <= 0:
        return
    cycles, curve = build_projection(soh_current, cpp, cfg.ILLUSTRATIVE_MIN_SOH)
    plt.figure(figsize=(6.4,4))
    plt.plot(cycles, curve, lw=2, label="Projected SoH")
    plt.axhline(cfg.DECISION_SOH_PERCENT, color="orange", ls="--", label=f"{cfg.DECISION_SOH_PERCENT:.0f}% target")
    plt.axhline(cfg.ILLUSTRATIVE_MIN_SOH, color="red", ls=":", label=f"{cfg.ILLUSTRATIVE_MIN_SOH:.0f}% lower")
    plt.scatter([0],[soh_current], c="green", s=55, label=f"Current {soh_current:.2f}%")
    plt.text(0, soh_current+0.7, f"±{soh_std:.2f}", color="green", fontsize=8)
    if cycles_to_target > 0:
        plt.axvline(cycles_to_target, color="orange", ls="-.")
        plt.scatter([cycles_to_target],[cfg.DECISION_SOH_PERCENT], c="orange", s=45)
        plt.text(cycles_to_target, cfg.DECISION_SOH_PERCENT+1.0,
                 f"{cycles_to_target:.0f} cyc", ha="center", color="orange", fontsize=8)
    plt.scatter([cycles[-1]],[cfg.ILLUSTRATIVE_MIN_SOH], c="red", s=50)
    plt.text(cycles[-1], cfg.ILLUSTRATIVE_MIN_SOH-2,
             f"{cycles[-1]:.0f} cyc", ha="center", color="red", fontsize=8)
    if ood_flag:
        plt.text(0.98,0.05,"OOD", transform=plt.gca().transAxes,
                 ha="right", va="bottom", color="crimson", fontsize=11,
                 bbox=dict(boxstyle="round", fc="w", ec="crimson"))
    plt.xlabel("Remaining Cycles")
    plt.ylabel("SoH (%)")
    plt.title(f"RUL Projection – {file_base}")
    plt.grid(alpha=0.35)
    plt.legend(fontsize=8)
    plt.tight_layout()
    plt.savefig(out_path, dpi=140)
    plt.close()

# =========================
# 15. INFERENCE (SINGLE FILE)
# =========================
def predict_file(file_path: Path, bundle, cpp_map, global_cpp):
    vec, norm_vec, meta = featurize_any(file_path, bundle)

    # ----- SoC (regression now; legacy classifier supported) -----
    soc_scaler=bundle["soc_scaler"]; soc_pca=bundle.get("soc_pca")
    soc_model=bundle["soc_model"]; soc_model_name=bundle.get("soc_model_name","unknown")
    X_soc = soc_scaler.transform(vec.reshape(1,-1))
    X_soc_in = soc_pca.transform(X_soc) if soc_pca else X_soc

    soc_probs = None
    # Legacy classifier path: if predict_proba exists, use it and also output a continuous estimate by class-prob expectation
    if hasattr(soc_model, "predict_proba"):
        p = soc_model.predict_proba(X_soc_in)[0]
        # Try to read class labels (assumed numeric percent levels)
        classes = getattr(soc_model, "classes_", np.arange(len(p)))
        classes = np.array(classes, dtype=float)
        soc_mean = float(np.dot(p, classes))
        # std as categorical variance proxy
        var = float(np.dot(p, (classes - soc_mean)**2))
        soc_std = math.sqrt(var) if var>0 else float(bundle["metrics"].get("soc_rmse_selected", 8.0))
        soc_probs = {float(c): float(pp) for c,pp in zip(classes,p)}
    else:
        # Regression path
        if isinstance(soc_model, GaussianProcessRegressor):
            sm, ss = soc_model.predict(X_soc_in, return_std=True)
            soc_mean = float(sm[0]); soc_std=float(ss[0])
        else:
            soc_mean = float(soc_model.predict(X_soc_in)[0])
            soc_std  = float(bundle["metrics"].get("soc_rmse_selected", 8.0))

        # If a shape SoC model exists, average (ensemble) for robustness
        soc_shape_model = bundle.get("soc_shape_model")
        if soc_shape_model is not None and norm_vec is not None:
            sscaler = bundle.get("soc_shape_scaler"); spca = bundle.get("soc_shape_pca")
            Xs = sscaler.transform(norm_vec.reshape(1,-1))
            Xs_in = spca.transform(Xs) if spca else Xs
            if isinstance(soc_shape_model, GaussianProcessRegressor):
                sm2, ss2 = soc_shape_model.predict(Xs_in, return_std=True)
                mean2=float(sm2[0]); std2=float(ss2[0])
            else:
                mean2=float(soc_shape_model.predict(Xs_in)[0]); std2=float(bundle["metrics"].get("soc_rmse_selected", 8.0))
            soc_mean = 0.5*(soc_mean + mean2)
            soc_std  = float(np.sqrt(0.5*(soc_std**2 + std2**2)))

    # ----- SoH (raw primary + optional shape ensemble) -----
    soh_scaler=bundle["soh_scaler"]; soh_pca=bundle.get("soh_pca")
    soh_model=bundle["soh_model"]; model_name=bundle.get("soh_model_name","unknown")
    X_soh_s = soh_scaler.transform(vec.reshape(1,-1))
    X_soh_in = soh_pca.transform(X_soh_s) if soh_pca else X_soh_s

    if "gpr" in model_name:
        sm, ss = soh_model.predict(X_soh_in, return_std=True)
        soh_mean_raw = float(sm[0]); soh_std_raw=float(ss[0])
    else:
        soh_mean_raw = float(soh_model.predict(X_soh_in)[0])
        soh_std_raw  = float(bundle["metrics"].get("soh_rmse_selected", 5.0))

    shape_model = bundle.get("shape_model")
    shape_soh_mean=None; shape_soh_std=None
    if shape_model is not None and norm_vec is not None:
        sscaler = bundle.get("shape_scaler")
        spca = bundle.get("shape_pca")
        X_shape_s = sscaler.transform(norm_vec.reshape(1,-1))
        X_shape_in = spca.transform(X_shape_s) if spca else X_shape_s
        if isinstance(shape_model, GaussianProcessRegressor):
            sm2, ss2 = shape_model.predict(X_shape_in, return_std=True)
            shape_soh_mean=float(sm2[0]); shape_soh_std=float(ss2[0])
        else:
            shape_soh_mean=float(shape_model.predict(X_shape_in)[0])
            shape_soh_std=float(bundle["metrics"].get("soh_rmse_selected", 5.0))

    if cfg.ENSEMBLE_SOH and shape_soh_mean is not None:
        soh_mean = 0.5*(soh_mean_raw + shape_soh_mean)
        stds = [soh_std_raw]
        if shape_soh_std is not None: stds.append(shape_soh_std)
        soh_std = float(np.sqrt(np.mean(np.array(stds)**2)))
    else:
        soh_mean, soh_std = soh_mean_raw, soh_std_raw

    # RUL
    cpp = get_cpp(meta, cpp_map, global_cpp)
    if soh_mean > cfg.DECISION_SOH_PERCENT:
        cycles_to_target = (soh_mean - cfg.DECISION_SOH_PERCENT)*cpp
    else:
        cycles_to_target = 0.0
    if soh_mean > cfg.ILLUSTRATIVE_MIN_SOH:
        cycles_to_lower = (soh_mean - cfg.ILLUSTRATIVE_MIN_SOH)*cpp
    else:
        cycles_to_lower = 0.0

    # OOD diagnostics (based on SoH space)
    train_mahal = bundle.get("train_mahal")
    mahal_dist=None
    if train_mahal:
        cov_inv = np.array(train_mahal["cov_inv"])
        center = np.array(train_mahal["center"])
        mahal_dist = mahalanobis_distance(X_soh_s[0], center, cov_inv)
    ard_norm=None
    if "gpr" in model_name:
        ard_norm = gp_ard_norm(X_soh_in, soh_model)
    ood_flag=False
    if (mahal_dist is not None and mahal_dist > cfg.MAHAL_THRESHOLD) or \
       (ard_norm is not None and ard_norm > cfg.GP_ARD_NORM_THRESHOLD):
        ood_flag=True

    result={
        "file": str(file_path),
        "parsed_metadata": meta,
        # SoC (continuous)
        "predicted_SoC_percent": float(soc_mean),
        "SoC_std_estimate": float(soc_std),
        "soc_model_chosen": soc_model_name,
        "SoC_probabilities": soc_probs,  # may be None for regression
        # SoH
        "predicted_SoH_percent": float(soh_mean),
        "SoH_std_estimate": float(soh_std),
        "raw_model_mean": float(soh_mean_raw),
        "raw_model_std": float(soh_std_raw),
        "shape_model_mean": None if shape_model is None else float(shape_soh_mean),
        "shape_model_std": None if shape_model is None else float(shape_soh_std),
        # RUL
        "cycles_per_percent_used": float(cpp),
        "cycles_to_target": float(cycles_to_target),
        "cycles_to_lower": float(cycles_to_lower),
        "decision_threshold_percent": cfg.DECISION_SOH_PERCENT,
        "lower_threshold_percent": cfg.ILLUSTRATIVE_MIN_SOH,
        # Meta
        "feature_version": bundle.get("feature_version"),
        "soh_model_chosen": model_name,
        # OOD
        "OOD_mahal": None if mahal_dist is None else float(mahal_dist),
        "OOD_gp_ard_norm": None if ard_norm is None else float(ard_norm),
        "OOD_flag": bool(ood_flag)
    }
    return result, ood_flag, cycles_to_target, cycles_to_lower

# =========================
# 16. MAIN
# =========================
def main():
    if cfg.VERBOSE:
        print("Configuration:\n", json.dumps(to_jsonable(asdict(cfg)), indent=2))

    # Directory assertions
    assert cfg.EIS_DIR.exists(), f"EIS_DIR missing: {cfg.EIS_DIR}"
    if cfg.REFINE_SOH_WITH_CAPACITY:
        assert cfg.CAP_DIR.exists(), f"CAP_DIR missing: {cfg.CAP_DIR}"

    # Capacity + dynamic CPP
    cap_df = load_capacity_info(cfg.CAP_DIR)
    if cap_df.empty:
        if cfg.VERBOSE:
            print("[INFO] No / empty capacity data -> fallback Cpp.")
        cpp_map, global_cpp = {}, cfg.CPP_FALLBACK
    else:
        cpp_map, global_cpp = build_cpp_map(cap_df)
        if cfg.VERBOSE:
            print(f"[CPP] dynamic cells={len(cpp_map)} global_cpp_median={global_cpp:.2f}")

    bundle_path = cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib"
    if bundle_path.exists() and not cfg.FORCE_RETRAIN:
        if cfg.VERBOSE:
            print(f"[LOAD] Using existing model bundle: {bundle_path}")
        bundle = load_bundle()
    else:
        if cfg.VERBOSE:
            print("[TRAIN] Building dataset & training models...")
        meta_df, X_raw, shape_bundle, y_soc, y_soh = build_dataset(cfg.EIS_DIR, cap_df)
        if cfg.VERBOSE:
            print(f"[TRAIN] Samples={X_raw.shape[0]} Features={X_raw.shape[1]} Cells={meta_df.CellID.nunique()}")
        bundle = train_models(meta_df, X_raw, shape_bundle, y_soc, y_soh)

    # Inference
    for test_fp in cfg.EIS_TEST_FILES:
        print(f"\n===== TEST: {test_fp.name} =====")
        if not test_fp.exists():
            print(f"[WARN] Test file not found: {test_fp}")
            continue
        try:
            result, ood_flag, cyc_target, cyc_lower = predict_file(test_fp, bundle, cpp_map, global_cpp)
        except Exception as e:
            print(f"[ERROR] Prediction failed for {test_fp.name}: {e}")
            continue

        out_plot = cfg.MODEL_DIR / f"{test_fp.stem}_projection.png"
        plot_projection(
            test_fp.stem,
            result["predicted_SoH_percent"],
            result["SoH_std_estimate"],
            result["cycles_to_target"],
            result["cycles_to_lower"],
            result["cycles_per_percent_used"],
            result["OOD_flag"],
            out_plot
        )

        out_json = cfg.MODEL_DIR / f"{test_fp.stem}_prediction.json"
        with out_json.open("w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)
        print(json.dumps(result, indent=2))
        print(f"[PLOT] Saved: {out_plot}")
        print(f"[JSON] Saved: {out_json}")

    print("\nDone.")

if __name__ == "__main__":
    main()


Configuration:
 {
  "EIS_DIR": "C:\\Users\\tgondal0\\OneDrive - Edith Cowan University\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\EIS_Test",
  "CAP_DIR": "C:\\Users\\tgondal0\\OneDrive - Edith Cowan University\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\Capacity_Check",
  "MODEL_DIR": "models_eis_phase2_phys",
  "EIS_TEST_FILES": [
    "Mazda-Battery-Cell1.xlsx"
  ],
  "F_MIN": 0.01,
  "F_MAX": 10000.0,
  "N_FREQ": 60,
  "TEST_FRAC": 0.2,
  "GROUP_KFOLDS": 0,
  "RANDOM_STATE": 42,
  "USE_PCA_SOC": true,
  "USE_PCA_SOH": false,
  "PCA_SOC_COMPONENTS": 25,
  "PCA_SOH_COMPONENTS": 30,
  "INCLUDE_RAW_RE_IM": true,
  "INCLUDE_BASICS": true,
  "INCLUDE_F_FEATS": true,
  "INCLUDE_PHYSICAL": true,
  "INCLUDE_DRT": true,
  "INCLUDE_BAND_STATS": true,
  "INCLUDE_DIFF_SLOPES": true,
  "DRT_POINTS": 60,
  "DRT_TAU_MIN": 0.0001,
  "DRT_TAU_MAX": 10000.0,
  "DRT_LAMBDA": 0.01,
  "REFINE_SOH_WITH_CAPACITY": true,
  "MAX_GPR_TRAIN_

Loading training spectra: 100%|██████████| 360/360 [00:01<00:00, 227.79it/s]


[DATA] SoH range: 80.46 – 100.00 (var=52.028)
[TRAIN] Samples=360 Features=173 Cells=24




[SoC] GPR_raw:   R2=0.977 RMSE=4.91
[SoC] HGB_raw:   R2=0.819 RMSE=13.90
[SoC] ShapeGP:  R2=0.979 RMSE=4.78
[SoC] Selected = soc_gpr_shape




[SoH] GPR_raw:  R2=0.995 RMSE=0.37
[SoH] HGB_raw:  R2=0.571 RMSE=3.61
[SoH] ShapeGP: R2=0.994 RMSE=0.43
[SoH] Selected raw model = gpr_raw
[MODEL] Saved bundle → models_eis_phase2_phys\eis_soc_soh_phys_models.joblib
{
  "soc_r2_selected": 0.978540857869718,
  "soc_rmse_selected": 4.78282819165404,
  "soh_r2_selected": 0.9953987333752046,
  "soh_rmse_selected": 0.37392841054767223
}

===== TEST: Mazda-Battery-Cell1.xlsx =====
{
  "file": "Mazda-Battery-Cell1.xlsx",
  "parsed_metadata": null,
  "predicted_SoC_percent": 48.0,
  "SoC_std_estimate": 33.12242526854917,
  "soc_model_chosen": "soc_gpr_shape",
  "SoC_probabilities": null,
  "predicted_SoH_percent": 90.3475,
  "SoH_std_estimate": 7.241065855707704,
  "raw_model_mean": 90.3475,
  "raw_model_std": 7.241980250868564,
  "shape_model_mean": 90.3475,
  "shape_model_std": 7.240151345063277,
  "cycles_per_percent_used": 20.0,
  "cycles_to_target": 806.9499999999999,
  "cycles_to_lower": 1006.9499999999999,
  "decision_threshold_percent"