In [14]:
#!/usr/bin/env python
"""
EIS → SoC (classification) + SoH (regression) unified training & inference (Phase 2 core)
----------------------------------------------------------------------------------------
Features:
  - Raw interpolated Re/Im (canonical frequency grid)
  - Basic magnitude stats
  - F-features (peak & heuristic points)
  - Physical features (Rs, Rct, tau_peak, Warburg proxy, etc.)
  - Band statistics over log-frequency ranges
  - Differential log-frequency slopes
  - DRT-based regularized gamma summary features
  - Temperature feature (stored as Temp_feat to avoid collision)

Enhancements vs earlier snippet:
  * Robust recursive to_jsonable to print config (fixes Path JSON error)
  * Optional variance inflation when SoH variance is extremely low
  * Clear training / validation split & logging
  * Graceful handling of missing capacity refinement
"""

from __future__ import annotations
import re, json, math, joblib, random
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, Tuple, List, Optional, Any

import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.io import loadmat
from scipy.interpolate import interp1d
from scipy import linalg

from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.metrics import (
    accuracy_score, f1_score, classification_report,
    mean_squared_error, r2_score
)
from sklearn.decomposition import PCA
from sklearn.model_selection import GroupKFold

# =========================
# 1. CONFIG
# =========================
@dataclass
class Config:
    EIS_DIR: Path = Path(r"C:\Users\tmgon\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\EIS_Test")
    CAP_DIR: Path = Path(r"C:\Users\tmgon\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\Capacity_Check")
    MODEL_DIR: Path = Path("models_eis_phase2_phys")

    # Frequency grid
    F_MIN: float = 1e-2
    F_MAX: float = 1e4
    N_FREQ: int = 60

    # Split / CV
    TEST_FRAC: float = 0.2
    GROUP_KFOLDS: int = 0
    RANDOM_STATE: int = 42

    # PCA
    USE_PCA: bool = True
    PCA_COMPONENTS: int = 25

    # GP limit
    MAX_GPR_TRAIN_SAMPLES: int = 2500

    # Feature toggles
    INCLUDE_RAW_RE_IM: bool = True
    INCLUDE_BASICS: bool = True
    INCLUDE_F_FEATS: bool = True
    INCLUDE_PHYSICAL: bool = True
    INCLUDE_DRT: bool = True
    INCLUDE_BAND_STATS: bool = True
    INCLUDE_DIFF_SLOPES: bool = True

    # DRT
    DRT_POINTS: int = 60
    DRT_TAU_MIN: float = 1e-4
    DRT_TAU_MAX: float = 1e4
    DRT_LAMBDA: float = 1e-2

    # Capacity refinement
    REFINE_SOH_WITH_CAPACITY: bool = True

    # Saving & meta
    SAVE_FEATURE_TABLE: bool = True
    FEATURE_VERSION: int = 4
    VERBOSE: bool = True
    FORCE_RETRAIN: bool = True          # set False to reuse existing bundle
    LOW_SOH_VAR_EPS: float = 0.0        # set >0 (e.g. 0.5) to inflate very low variance labels slightly

cfg = Config()
cfg.MODEL_DIR.mkdir(parents=True, exist_ok=True)

# =========================
# 2. UTILITIES
# =========================
def set_seed(seed: int):
    random.seed(seed); np.random.seed(seed)
set_seed(cfg.RANDOM_STATE)

def to_jsonable(x):
    if isinstance(x, Path):
        return str(x)
    if isinstance(x, dict):
        return {k: to_jsonable(v) for k,v in x.items()}
    if isinstance(x, (list, tuple)):
        return [to_jsonable(i) for i in x]
    return x

CANON_FREQ = np.geomspace(cfg.F_MAX, cfg.F_MIN, cfg.N_FREQ)

# =========================
# 3. REGEX
# =========================
EIS_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_(?P<Temp>\d+)degC_(?P<SOC>\d+)SOC_(?P<RealSOH>\d+)"
)
CAP_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_Capacity_Check_(?P<Temp>\d+)degC_(?P<Cycle>\d+)cycle"
)

def parse_eis_metadata(stem: str) -> Optional[Dict[str, Any]]:
    m = EIS_META_PATTERN.search(stem)
    if not m: return None
    d = m.groupdict()
    return {
        "CellID": f"Cell{d['CellID']}",
        "SOH_stage": int(d["SOH"]),
        "SOC": int(d["SOC"]),
        "Temp": int(d["Temp"]),
        "RealSOH_file": int(d["RealSOH"]) / 100.0
    }

def parse_cap_metadata(stem: str) -> Optional[Dict[str, Any]]:
    m = CAP_META_PATTERN.search(stem)
    if not m: return None
    d = m.groupdict()
    return {
        "CellID": f"Cell{d['CellID']}",
        "SOH_stage": int(d["SOH"]),
        "Temp": int(d["Temp"]),
        "CycleIndex": int(d["Cycle"])
    }

# =========================
# 4. LOADERS
# =========================
def _find_matrix(mat_dict: dict):
    for v in mat_dict.values():
        if isinstance(v, np.ndarray) and v.ndim==2 and v.shape[1]>=3 and v.shape[0]>=10:
            return v
    return None

def _interp_channel(freq_raw, y_raw, freq_target):
    freq_raw = np.asarray(freq_raw).astype(float)
    y_raw = np.asarray(y_raw).astype(float)
    if freq_raw[0] < freq_raw[-1]:
        freq_raw = freq_raw[::-1]; y_raw = y_raw[::-1]
    uniq, idx = np.unique(freq_raw, return_index=True)
    if len(uniq)!=len(freq_raw):
        order = np.argsort(idx)
        freq_raw = uniq[order]; y_raw = y_raw[idx][order]
    f = interp1d(freq_raw, y_raw, bounds_error=False,
                 fill_value=(y_raw[0], y_raw[-1]), kind="linear")
    return f(freq_target)

# =========================
# 5. FEATURE COMPONENTS
# =========================
def compute_F_features(freq, re_i, im_i):
    neg_im=-im_i
    idx_peak = int(np.argmax(neg_im))
    F1=re_i[0]; F2=re_i[idx_peak]; F3=re_i[-1]
    sc=np.where(np.sign(im_i[:-1]) != np.sign(im_i[1:]))[0]
    if len(sc):
        k=sc[0]; y0,y1=im_i[k],im_i[k+1]; w=-y0/(y1-y0+1e-12)
        F4 = re_i[k] + w*(re_i[k+1]-re_i[k])
    else:
        F4 = np.nan
    F5 = (re_i[idx_peak]-F1) if idx_peak>0 else np.nan
    F6 = np.min(im_i)
    mid_target=10.0
    idx_mid = int(np.argmin(np.abs(freq-mid_target)))
    F7 = re_i[idx_mid]
    return [F1,F2,F3,F4,F5,F6,F7]

PHYSICAL_FEATURE_NAMES = [
    "Rs","Rct","tau_peak","warburg_sigma","arc_quality",
    "phase_mean_mid","phase_std_mid","phase_min","lf_slope_negIm","norm_arc"
]

def physical_features(freq, re_i, im_i):
    freq=np.asarray(freq); re_i=np.asarray(re_i); im_i=np.asarray(im_i)
    neg_im=-im_i
    idx_peak=int(np.argmax(neg_im))
    Rs=float(re_i[0]); Rpeak=float(re_i[idx_peak]); Rlow=float(re_i[-1])
    Rct=max(Rpeak - Rs,0.0)
    arc_diam=Rlow - Rs
    norm_arc = arc_diam/(Rs+1e-9)
    f_peak=float(freq[idx_peak])
    tau_peak = 1/(2*math.pi*f_peak) if f_peak>0 else np.nan
    K=min(10,len(freq)//3)
    if K>=4:
        w_section=(2*np.pi*freq[-K:])**(-0.5)
        re_section=re_i[-K:]
        warburg_sigma=float(np.polyfit(w_section, re_section,1)[0]) if len(np.unique(w_section))>2 else np.nan
    else:
        warburg_sigma=np.nan
    phase=np.arctan2(-im_i, re_i)
    mid=(freq>=1)&(freq<=100)
    if mid.sum()>2:
        phase_mean_mid=float(phase[mid].mean())
        phase_std_mid=float(phase[mid].std())
    else:
        phase_mean_mid=np.nan; phase_std_mid=np.nan
    phase_min=float(phase.min())
    lf_mask=(freq<=1.0)
    if lf_mask.sum()>=4:
        x=np.log10(freq[lf_mask]+1e-12); y=neg_im[lf_mask]
        lf_slope=np.polyfit(x,y,1)[0]
    else:
        lf_slope=np.nan
    arc_quality=(neg_im.max()-neg_im.min())/(abs(neg_im.mean())+1e-9)
    return [Rs,Rct,tau_peak,warburg_sigma,arc_quality,
            phase_mean_mid,phase_std_mid,phase_min,lf_slope,norm_arc]

BANDS=[(1e4,1e3),(1e3,1e2),(1e2,10),(10,1),(1,1e-1),(1e-1,1e-2)]
def band_stats(freq, re_i, im_i):
    feats=[]; freq=np.asarray(freq)
    for hi,lo in BANDS:
        m=(freq<=hi)&(freq>=lo)
        if m.sum()>1:
            z=np.hypot(re_i[m], im_i[m])
            feats+=[z.mean(), z.std()]
        else:
            feats+=[np.nan,np.nan]
    return feats

def diff_slopes(freq, re_i, im_i, segments=5):
    logf=np.log10(freq)
    edges=np.linspace(logf.min(), logf.max(), segments+1)
    out=[]
    for i in range(segments):
        m=(logf>=edges[i])&(logf<=edges[i+1])
        if m.sum()>=3:
            x=logf[m]
            out += [np.polyfit(x,re_i[m],1)[0], np.polyfit(x,(-im_i)[m],1)[0]]
        else:
            out += [np.nan, np.nan]
    return out

DRT_FEATURE_NAMES=[
    "drt_sum","drt_mean_logtau","drt_var_logtau","drt_peak_tau",
    "drt_peak_gamma","drt_frac_low_tau","drt_frac_high_tau"
]

def compute_drt(freq,re_i,im_i,tau_min,tau_max,n_tau,lam):
    w=2*np.pi*freq
    tau=np.geomspace(tau_max,tau_min,n_tau)
    WT=w[:,None]*tau[None,:]
    denom=1+WT**2
    K_re=1.0/denom
    K_im=-WT/denom
    R_inf=re_i[0]
    y_re=re_i - R_inf
    y_im=im_i
    Y=np.concatenate([y_re,y_im])
    K=np.vstack([K_re,K_im])
    A=K.T@K + lam*np.eye(n_tau)
    b=K.T@Y
    gamma=linalg.solve(A,b,assume_a='pos')
    gamma=np.clip(gamma,0,None)
    return tau,gamma

def drt_features(freq,re_i,im_i):
    if not cfg.INCLUDE_DRT:
        return []
    try:
        tau,gamma=compute_drt(freq,re_i,im_i,
                               cfg.DRT_TAU_MIN,cfg.DRT_TAU_MAX,
                               cfg.DRT_POINTS,cfg.DRT_LAMBDA)
        log_tau=np.log10(tau)
        g_sum=gamma.sum()+1e-12
        w_norm=gamma/g_sum
        mean_logtau=float((w_norm*log_tau).sum())
        var_logtau=float((w_norm*(log_tau-mean_logtau)**2).sum())
        p=int(np.argmax(gamma))
        peak_tau=float(tau[p]); peak_gamma=float(gamma[p])
        mid=np.median(log_tau)
        frac_low=float(w_norm[log_tau<=mid].sum())
        frac_high=1-frac_low
        return [g_sum,mean_logtau,var_logtau,peak_tau,peak_gamma,frac_low,frac_high]
    except Exception:
        return [np.nan]*7

def build_feature_vector(re_i, im_i, temp, freq, include_manifest=False):
    parts=[]; names=[]
    if cfg.INCLUDE_RAW_RE_IM:
        parts += [re_i, im_i]
        names += [f"Re_{i}" for i in range(len(re_i))] + [f"Im_{i}" for i in range(len(im_i))]
    if cfg.INCLUDE_BASICS:
        z=np.hypot(re_i, im_i)
        basics=[re_i[0], re_i[-1], re_i[-1]-re_i[0], z.max(), z.mean(), z.std()]
        parts.append(np.array(basics)); names += ["hf_re","lf_re","arc_diam","zmag_max","zmag_mean","zmag_std"]
    if cfg.INCLUDE_F_FEATS:
        Ff=compute_F_features(freq,re_i,im_i)
        parts.append(np.array(Ff)); names += [f"F{i}" for i in range(1,8)]
    if cfg.INCLUDE_PHYSICAL:
        Pf=physical_features(freq,re_i,im_i)
        parts.append(np.array(Pf)); names += PHYSICAL_FEATURE_NAMES
    if cfg.INCLUDE_BAND_STATS:
        Bf=band_stats(freq,re_i,im_i)
        parts.append(np.array(Bf))
        for bi in range(len(BANDS)):
            names += [f"band{bi}_mean", f"band{bi}_std"]
    if cfg.INCLUDE_DIFF_SLOPES:
        Ds=diff_slopes(freq,re_i,im_i)
        parts.append(np.array(Ds))
        for i in range(len(Ds)//2):
            names += [f"slope_re_seg{i}", f"slope_negIm_seg{i}"]
    if cfg.INCLUDE_DRT:
        Df=drt_features(freq,re_i,im_i)
        parts.append(np.array(Df)); names += DRT_FEATURE_NAMES
    parts.append(np.array([temp])); names += ["Temp_feat"]
    vec=np.concatenate(parts).astype(float)
    vec=np.nan_to_num(vec, nan=0.0, posinf=0.0, neginf=0.0)
    if include_manifest:
        return vec, names
    return vec

# =========================
# 6. CAPACITY
# =========================
def load_capacity_info(cap_dir: Path) -> pd.DataFrame:
    if not (cap_dir.exists() and cfg.REFINE_SOH_WITH_CAPACITY):
        return pd.DataFrame()
    recs=[]
    for fp in cap_dir.rglob("*.mat"):
        meta=parse_cap_metadata(fp.stem)
        if not meta: continue
        try:
            mat=loadmat(fp); arr=_find_matrix(mat)
            if arr is None: continue
            col=np.argmax(np.abs(arr[-50:, :]).mean(axis=0))
            cap=float(np.nanmax(arr[:,col]))
            meta["MeasuredCapacity_Ah"]=cap
            recs.append(meta)
        except Exception:
            pass
    df=pd.DataFrame(recs)
    if df.empty: return df
    ref=df.groupby("CellID")["MeasuredCapacity_Ah"].transform("max")
    df["NormCapacity"]=df["MeasuredCapacity_Ah"]/ref
    return df

# =========================
# 7. DATASET
# =========================
def load_single_eis(fp: Path):
    meta=parse_eis_metadata(fp.stem)
    if meta is None:
        raise ValueError(f"Bad filename: {fp.name}")
    mat=loadmat(fp); arr=_find_matrix(mat)
    if arr is None: raise ValueError(f"No EIS matrix in {fp.name}")
    freq_raw=arr[:,0].astype(float); re_raw=arr[:,1].astype(float); im_raw=arr[:,2].astype(float)
    re_i=_interp_channel(freq_raw, re_raw, CANON_FREQ)
    im_i=_interp_channel(freq_raw, im_raw, CANON_FREQ)
    vec=build_feature_vector(re_i, im_i, meta["Temp"], CANON_FREQ)
    return vec, meta

def build_dataset(eis_dir: Path, cap_df: Optional[pd.DataFrame]):
    files=sorted(eis_dir.rglob("*.mat"))
    if not files: raise FileNotFoundError(f"No .mat files in {eis_dir}")
    # Feature names
    arr0=_find_matrix(loadmat(files[0]))
    f0=arr0[:,0].astype(float); r0=arr0[:,1].astype(float); i0=arr0[:,2].astype(float)
    re0=_interp_channel(f0,r0,CANON_FREQ); im0=_interp_channel(f0,i0,CANON_FREQ)
    _, feature_names = build_feature_vector(re0, im0, 25.0, CANON_FREQ, include_manifest=True)

    feats=[]; rows=[]
    for fp in tqdm(files, desc="Loading spectra"):
        try:
            v,m=load_single_eis(fp)
            feats.append(v); rows.append(m)
        except Exception as e:
            if cfg.VERBOSE: print(f"[Skip] {fp.name}: {e}")
    if not rows: raise RuntimeError("No valid spectra")
    X=np.vstack(feats)
    meta_df=pd.DataFrame(rows)

    # SoH refinement
    if cap_df is not None and not cap_df.empty and cfg.REFINE_SOH_WITH_CAPACITY:
        lookup=cap_df.set_index(["CellID","SOH_stage"])["NormCapacity"].to_dict()
        refined=[]
        for cid,stage,fallback in zip(meta_df.CellID, meta_df.SOH_stage, meta_df.RealSOH_file):
            nc=lookup.get((cid,stage))
            refined.append(100.0*nc if nc is not None else fallback)
        meta_df["SoH_cont"]=refined
    else:
        meta_df["SoH_cont"]=meta_df["RealSOH_file"]

    y_soc=meta_df["SOC"].values
    y_soh=meta_df["SoH_cont"].values.astype(float)

    soh_var=float(np.var(y_soh))
    if cfg.VERBOSE:
        print(f"[DATA] SoH stats: min={y_soh.min():.2f} max={y_soh.max():.2f} mean={y_soh.mean():.2f} var={soh_var:.4f}")
    if soh_var < 1.0 and cfg.LOW_SOH_VAR_EPS > 0:
        # mild variance inflation (center + noise) — optional
        rng=np.random.default_rng(cfg.RANDOM_STATE+7)
        noise=rng.normal(0, cfg.LOW_SOH_VAR_EPS, size=y_soh.shape)
        y_soh = y_soh + noise
        if cfg.VERBOSE:
            print(f"[INFO] Applied variance inflation (epsilon={cfg.LOW_SOH_VAR_EPS}). New var={np.var(y_soh):.4f}")

    # Save feature table
    if cfg.SAVE_FEATURE_TABLE:
        feature_df=pd.DataFrame(X, columns=feature_names)
        dup = set(meta_df.columns).intersection(feature_df.columns)
        if dup:
            feature_df=feature_df.rename(columns={c: f"{c}_feat" for c in dup})
        pd.concat([meta_df.reset_index(drop=True), feature_df], axis=1)\
          .to_parquet(cfg.MODEL_DIR/"training_features.parquet", index=False)

    return meta_df, X, y_soc, y_soh, feature_names

# =========================
# 8. SPLIT
# =========================
def simple_cell_split(meta_df: pd.DataFrame):
    cells=meta_df.CellID.unique()
    rng=np.random.default_rng(cfg.RANDOM_STATE)
    n_test=max(1,int(len(cells)*cfg.TEST_FRAC))
    test_cells=rng.choice(cells,size=n_test,replace=False)
    return meta_df.CellID.isin(test_cells)

# =========================
# 9. TRAIN
# =========================
def train_models(X,y_soc,y_soh,meta_df,feature_names):
    scaler=StandardScaler()
    Xs=scaler.fit_transform(X)
    pca=None; Xm=Xs
    if cfg.USE_PCA:
        pca=PCA(n_components=min(cfg.PCA_COMPONENTS, Xs.shape[1]-1),
                random_state=cfg.RANDOM_STATE)
        Xm=pca.fit_transform(Xs)

    def fit_once(Xt, y_soc_t, y_soh_t):
        soc=RandomForestClassifier(
            n_estimators=600, min_samples_leaf=2, class_weight='balanced',
            n_jobs=-1, random_state=cfg.RANDOM_STATE
        )
        soc.fit(Xt, y_soc_t)
        dim=Xt.shape[1]
        kernel=RBF(length_scale=np.ones(dim),
                   length_scale_bounds=(1e-2,1e3)) + \
               WhiteKernel(noise_level=1e-3,
                           noise_level_bounds=(1e-6,1e-1))
        gpr=GaussianProcessRegressor(
            kernel=kernel, alpha=0.0, normalize_y=True,
            random_state=cfg.RANDOM_STATE, n_restarts_optimizer=3
        )
        if Xt.shape[0] > cfg.MAX_GPR_TRAIN_SAMPLES:
            idx=np.random.default_rng(cfg.RANDOM_STATE).choice(
                Xt.shape[0], size=cfg.MAX_GPR_TRAIN_SAMPLES, replace=False)
            gpr.fit(Xt[idx], y_soh_t[idx])
        else:
            gpr.fit(Xt, y_soh_t)
        return soc,gpr

    if cfg.GROUP_KFOLDS and cfg.GROUP_KFOLDS>1:
        gkf=GroupKFold(n_splits=cfg.GROUP_KFOLDS)
        groups=meta_df.CellID.values
        r2s=[]; fold_models=[]
        for i,(tr,te) in enumerate(gkf.split(Xm,y_soc,groups)):
            soc,gpr=fit_once(Xm[tr], y_soc[tr], y_soh[tr])
            pred=gpr.predict(Xm[te]); r2=r2_score(y_soh[te], pred)
            r2s.append(r2); fold_models.append((soc,gpr))
            if cfg.VERBOSE: print(f"[Fold {i}] SoH R2={r2:.3f}")
        best=int(np.argmax(r2s))
        soc_model,soh_model=fold_models[best]
        metrics={"cv_soh_r2_mean":float(np.mean(r2s)),"chosen_fold":best}
    else:
        mask_test=simple_cell_split(meta_df)
        Xt,Xv = Xm[~mask_test], Xm[mask_test]
        yst,ysv = y_soc[~mask_test], y_soc[mask_test]
        yht,yhv = y_soh[~mask_test], y_soh[mask_test]
        soc_model,soh_model=fit_once(Xt, yst, yht)
        soc_pred=soc_model.predict(Xv)
        acc=accuracy_score(ysv,soc_pred); f1m=f1_score(ysv,soc_pred,average='macro')
        soh_pred=soh_model.predict(Xv)
        rmse=math.sqrt(mean_squared_error(yhv,soh_pred)); r2=r2_score(yhv,soh_pred)
        if cfg.VERBOSE:
            print("[SoC] holdout classification report:")
            print(classification_report(ysv,soc_pred,digits=4))
            print(f"[SoH] holdout RMSE={rmse:.3f} R2={r2:.3f}")
        metrics={"soc_accuracy":acc,"soc_macro_f1":f1m,"soh_rmse":rmse,"soh_r2":r2}

    bundle={
        "scaler":scaler,"pca":pca,
        "soc_model":soc_model,"soh_model":soh_model,
        "freq_grid":CANON_FREQ,"feature_version":cfg.FEATURE_VERSION,
        "feature_manifest":feature_names,
        "config":to_jsonable(asdict(cfg)),
        "metrics":metrics
    }
    out_path=cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib"
    joblib.dump(bundle,out_path)
    if cfg.VERBOSE:
        print(f"Saved model bundle → {out_path}")
        print("Metrics:", json.dumps(metrics, indent=2))
    return bundle

# =========================
# 10. INFERENCE
# =========================
def load_bundle()->dict:
    return joblib.load(cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib")

def featurize_new_eis(file_path:Path, bundle=None):
    if bundle is None:
        bundle=load_bundle()
    freq_grid=bundle["freq_grid"]
    meta=parse_eis_metadata(file_path.stem)
    mat=loadmat(file_path); arr=_find_matrix(mat)
    if arr is None: raise ValueError("No valid EIS matrix in file.")
    fr=arr[:,0].astype(float); re_raw=arr[:,1].astype(float); im_raw=arr[:,2].astype(float)
    re_i=_interp_channel(fr,re_raw,freq_grid)
    im_i=_interp_channel(fr,im_raw,freq_grid)
    temp=meta["Temp"] if meta else -1
    vec=build_feature_vector(re_i, im_i, temp, freq_grid)
    return vec, meta

def predict_from_file(file_path:Path)->Dict[str,Any]:
    bundle=load_bundle()
    scaler=bundle["scaler"]; pca=bundle["pca"]
    soc_model=bundle["soc_model"]; soh_model=bundle["soh_model"]
    feat_vec, meta=featurize_new_eis(file_path, bundle)
    Xs=scaler.transform(feat_vec.reshape(1,-1))
    Xp=pca.transform(Xs) if pca is not None else Xs
    soc_probs=soc_model.predict_proba(Xp)[0]
    soc_classes=soc_model.classes_
    soc_pred=int(soc_classes[np.argmax(soc_probs)])
    soh_mean_arr, soh_std_arr=soh_model.predict(Xp, return_std=True)
    return {
        "file": file_path.name,
        "metadata": meta,
        "predicted_SoC": soc_pred,
        "SoC_probabilities": {int(c): float(p) for c,p in zip(soc_classes, soc_probs)},
        "predicted_SoH_percent": float(soh_mean_arr[0]),
        "SoH_std_estimate": float(soh_std_arr[0]),
        "feature_version": bundle.get("feature_version")
    }

# =========================
# 11. MAIN
# =========================
def main():
    if cfg.VERBOSE:
        print("Configuration:\n", json.dumps(to_jsonable(asdict(cfg)), indent=2))
    assert cfg.EIS_DIR.exists(), f"EIS_DIR missing: {cfg.EIS_DIR}"
    if cfg.REFINE_SOH_WITH_CAPACITY:
        assert cfg.CAP_DIR.exists(), f"CAP_DIR missing: {cfg.CAP_DIR}"

    bundle_path=cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib"
    if bundle_path.exists() and not cfg.FORCE_RETRAIN:
        if cfg.VERBOSE: print(f"[LOAD] Reusing existing model: {bundle_path}")
    else:
        if cfg.VERBOSE: print("[TRAIN] Building dataset & fitting models...")
        cap_df=load_capacity_info(cfg.CAP_DIR)
        if cfg.REFINE_SOH_WITH_CAPACITY and cap_df.empty and cfg.VERBOSE:
            print("[INFO] Capacity file set empty → using filename RealSOH only.")
        meta_df, X, y_soc, y_soh, feature_names = build_dataset(cfg.EIS_DIR, cap_df)
        if cfg.VERBOSE:
            print(f"Training set: {X.shape[0]} spectra | feat_dim={X.shape[1]} | cells={meta_df.CellID.nunique()}")
        train_models(X,y_soc,y_soh,meta_df,feature_names)

    # Demonstration prediction on first file
    sample = sorted(cfg.EIS_DIR.rglob("*.mat"))[0]
    demo = predict_from_file(sample)
    print("\nExample prediction:\n", json.dumps(demo, indent=2))

if __name__ == "__main__":
    main()


Configuration:
 {
  "EIS_DIR": "C:\\Users\\tmgon\\OneDrive - Edith Cowan University\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\EIS_Test",
  "CAP_DIR": "C:\\Users\\tmgon\\OneDrive - Edith Cowan University\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\Capacity_Check",
  "MODEL_DIR": "models_eis_phase2_phys",
  "F_MIN": 0.01,
  "F_MAX": 10000.0,
  "N_FREQ": 60,
  "TEST_FRAC": 0.2,
  "GROUP_KFOLDS": 0,
  "RANDOM_STATE": 42,
  "USE_PCA": true,
  "PCA_COMPONENTS": 25,
  "MAX_GPR_TRAIN_SAMPLES": 2500,
  "INCLUDE_RAW_RE_IM": true,
  "INCLUDE_BASICS": true,
  "INCLUDE_F_FEATS": true,
  "INCLUDE_PHYSICAL": true,
  "INCLUDE_DRT": true,
  "INCLUDE_BAND_STATS": true,
  "INCLUDE_DIFF_SLOPES": true,
  "DRT_POINTS": 60,
  "DRT_TAU_MIN": 0.0001,
  "DRT_TAU_MAX": 10000.0,
  "DRT_LAMBDA": 0.01,
  "REFINE_SOH_WITH_CAPACITY": true,
  "SAVE_FEATURE_TABLE": true,
  "FEATURE_VERSION": 4,
  "VERBOSE": true,
  "FORCE_RETRAIN": true,
  "LOW_SOH

Loading spectra: 100%|██████████| 360/360 [00:01<00:00, 215.31it/s]


[DATA] SoH stats: min=80.46 max=100.00 mean=90.35 var=52.0277
Training set: 360 spectra | feat_dim=173 | cells=24




[SoC] holdout classification report:
              precision    recall  f1-score   support

           5     1.0000    0.9167    0.9565        12
          20     1.0000    1.0000    1.0000        12
          50     0.9231    1.0000    0.9600        12
          70     1.0000    1.0000    1.0000        12
          95     1.0000    1.0000    1.0000        12

    accuracy                         0.9833        60
   macro avg     0.9846    0.9833    0.9833        60
weighted avg     0.9846    0.9833    0.9833        60

[SoH] holdout RMSE=1.066 R2=0.963


FileNotFoundError: [Errno 2] No such file or directory: 'models_eis_phase2_phys\\eis_soc_soh_phys_models.joblib'

In [13]:
#!/usr/bin/env python
"""
Optimized Unified EIS Training + Inference + Dynamic RUL (v7 with Legacy Bundle Compatibility)
----------------------------------------------------------------------------------------------
Features:
  * Extensive impedance feature set: raw Re/Im, basics, F-features, physical (Rs, Rct, etc.),
    DRT-derived descriptors, band statistics, log-frequency differential slopes, temperature.
  * Optional shape-normalized GP (normalizes spectrum by high-frequency Re) + ensemble with raw model.
  * Multiple SoH regressors (Gaussian Process, HistGradientBoosting); automatic selection by validation R².
  * Dynamic cycles-per-percent (Cpp) estimation from capacity .mat files (rolling linear fit); fallback heuristic.
  * Multi-format inference (.mat, .csv, .xls, .xlsx) with column auto-detection & freq interpolation.
  * Out-of-distribution (OOD) diagnostics: Mahalanobis distance in scaled feature space + GP ARD norm heuristic.
  * Backward compatibility loader to use older bundles (with keys: "scaler","pca","soc_model","soh_model").
  * Per-test JSON output + projection plot annotated with OOD flag.
  * Configurable PCA for SoC and SoH pipelines (decoupled).

Outputs in MODEL_DIR:
  - eis_soc_soh_phys_models.joblib
  - training_features.parquet (if SAVE_FEATURE_TABLE)
  - <TestFile>_prediction.json
  - <TestFile>_projection.png
"""

from __future__ import annotations
import re, json, math, random, warnings, joblib
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.io import loadmat
from scipy import linalg
from scipy.interpolate import interp1d

from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.metrics import (
    accuracy_score, f1_score, classification_report,
    mean_squared_error, r2_score
)
from sklearn.decomposition import PCA
from sklearn.model_selection import GroupKFold

import matplotlib.pyplot as plt

# =========================
# 1. CONFIGURATION
# =========================
@dataclass
class Config:
    # Training data directories (update if needed)
    EIS_DIR: Path = Path(r"C:\Users\tmgon\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\EIS_Test")
    CAP_DIR: Path = Path(r"C:\Users\tmgon\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\Capacity_Check")
    MODEL_DIR: Path = Path("models_eis_phase2_phys")

    # Test files (multi-format allowed)
    EIS_TEST_FILES: List[Path] = None  # assigned after instantiation

    # Frequency interpolation grid
    F_MIN: float = 1e-2
    F_MAX: float = 1e4
    N_FREQ: int = 60

    # Train / split settings
    TEST_FRAC: float = 0.2        # fraction of cells for hold-out test (if GROUP_KFOLDS==0)
    GROUP_KFOLDS: int = 0         # >1 to enable grouped CV
    RANDOM_STATE: int = 42

    # PCA toggles (separate for SoC & SoH)
    USE_PCA_SOC: bool = True
    USE_PCA_SOH: bool = False
    PCA_SOC_COMPONENTS: int = 25
    PCA_SOH_COMPONENTS: int = 30

    # Feature group toggles
    INCLUDE_RAW_RE_IM: bool = True
    INCLUDE_BASICS: bool = True
    INCLUDE_F_FEATS: bool = True
    INCLUDE_PHYSICAL: bool = True
    INCLUDE_DRT: bool = True
    INCLUDE_BAND_STATS: bool = True
    INCLUDE_DIFF_SLOPES: bool = True

    # DRT params
    DRT_POINTS: int = 60
    DRT_TAU_MIN: float = 1e-4
    DRT_TAU_MAX: float = 1e4
    DRT_LAMBDA: float = 1e-2

    # Capacity-based refinement
    REFINE_SOH_WITH_CAPACITY: bool = True

    # SoH modeling
    MAX_GPR_TRAIN_SAMPLES: int = 3500
    INCLUDE_NORMALIZED_SHAPE_MODEL: bool = True
    ENSEMBLE_SOH: bool = True
    NORMALIZE_SHAPE_BY_HF_RE: bool = True

    # RUL parameters
    DECISION_SOH_PERCENT: float = 50.0
    ILLUSTRATIVE_MIN_SOH: float = 40.0
    CPP_ROLLING_WINDOW: int = 5
    CPP_MIN_POINTS: int = 6
    CPP_FALLBACK: float = 20.0  # fallback cycles-per-percent

    # Inference extras
    TEST_TEMPERATURE_OVERRIDE: Optional[float] = 25.0  # applied if metadata absent
    FORCE_RETRAIN: bool = False  # force retraining even if bundle exists

    # Saving / logging
    SAVE_FEATURE_TABLE: bool = True
    VERBOSE: bool = True
    FEATURE_VERSION: int = 7

    # OOD thresholds
    MAHAL_THRESHOLD: float = 10.0
    GP_ARD_NORM_THRESHOLD: float = 6.0

    # Projection curve shape exponent
    PLOT_EXPONENT: float = 1.25

cfg = Config()
if cfg.EIS_TEST_FILES is None:
    cfg.EIS_TEST_FILES = [
        Path("Mazda-Battery-Cell1.xlsx"),
        Path("Mazda-Battery-Cell2.xlsx")
    ]
cfg.MODEL_DIR.mkdir(parents=True, exist_ok=True)

# =========================
# 2. UTILITIES
# =========================
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)

set_seed(cfg.RANDOM_STATE)

def to_jsonable(x):
    if isinstance(x, Path): return str(x)
    if isinstance(x, dict): return {k: to_jsonable(v) for k,v in x.items()}
    if isinstance(x, (list, tuple)): return [to_jsonable(i) for i in x]
    return x

CANON_FREQ = np.geomspace(cfg.F_MAX, cfg.F_MIN, cfg.N_FREQ)

# =========================
# 3. REGEX
# =========================
EIS_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_(?P<Temp>\d+)degC_(?P<SOC>\d+)SOC_(?P<RealSOH>\d+)"
)
CAP_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_Capacity_Check_(?P<Temp>\d+)degC_(?P<Cycle>\d+)cycle"
)

# =========================
# 4. PARSERS
# =========================
def parse_eis_metadata(stem: str) -> Optional[Dict[str, Any]]:
    m = EIS_META_PATTERN.search(stem)
    if not m: return None
    d = m.groupdict()
    return {
        "CellID": f"Cell{d['CellID']}",
        "SOH_stage": int(d["SOH"]),
        "SOC": int(d["SOC"]),
        "Temp": int(d["Temp"]),
        "RealSOH_file": int(d["RealSOH"])/100.0
    }

def parse_cap_metadata(stem: str) -> Optional[Dict[str, Any]]:
    m = CAP_META_PATTERN.search(stem)
    if not m: return None
    d = m.groupdict()
    return {
        "CellID": f"Cell{d['CellID']}",
        "SOH_stage": int(d["SOH"]),
        "Temp": int(d["Temp"]),
        "CycleIndex": int(d["Cycle"])
    }

# =========================
# 5. LOW-LEVEL LOADERS / INTERPOLATION
# =========================
def _find_matrix(mat_dict: dict):
    for v in mat_dict.values():
        if isinstance(v, np.ndarray) and v.ndim == 2 and v.shape[1] >= 3 and v.shape[0] >= 10:
            return v
    return None

def _interp_channel(freq_raw, y_raw, freq_target):
    freq_raw = np.asarray(freq_raw).astype(float)
    y_raw = np.asarray(y_raw).astype(float)
    if freq_raw[0] < freq_raw[-1]:
        freq_raw = freq_raw[::-1]; y_raw = y_raw[::-1]
    uniq, idx = np.unique(freq_raw, return_index=True)
    if len(uniq) != len(freq_raw):
        order = np.argsort(idx)
        freq_raw = uniq[order]; y_raw = y_raw[idx][order]
    f = interp1d(freq_raw, y_raw, bounds_error=False,
                 fill_value=(y_raw[0], y_raw[-1]), kind="linear")
    return f(freq_target)

FREQ_CANDS = ["frequency","freq","f","hz","frequency(hz)","Frequency(Hz)"]
RE_CANDS   = ["zreal","re(z)","re","real","z_re","zreal(ohm)","re (ohm)","re(z) (ohm)","Zreal","Zreal (ohm)","Zreal(ohm)"]
IM_CANDS   = ["-zimag","zimag","im(z)","im","imag","imaginary","z_im","zimg","z_imag"," -Zimag (ohm)"," -Zimag(ohm)","-Zimag","Zimag","Zimag (ohm)"]

def _select_column(df: pd.DataFrame, cands: List[str]) -> Optional[str]:
    low = {c.lower(): c for c in df.columns}
    for c in cands:
        if c.lower() in low: return low[c.lower()]
    for c in cands:
        for col in df.columns:
            if c.lower() in col.lower():
                return col
    return None

def load_mat_eis(path: Path):
    mat = loadmat(path)
    arr = _find_matrix(mat)
    if arr is None:
        raise ValueError(f"No valid EIS matrix in {path.name}")
    return arr[:,0].astype(float), arr[:,1].astype(float), arr[:,2].astype(float)

def load_table_eis(path: Path):
    if path.suffix.lower() == ".csv":
        df = pd.read_csv(path)
    else:
        df = pd.read_excel(path)
    if df.empty:
        raise ValueError("Empty table.")
    fcol = _select_column(df, FREQ_CANDS)
    recol = _select_column(df, RE_CANDS)
    imcol = _select_column(df, IM_CANDS)
    if recol is None or imcol is None:
        raise ValueError(f"Missing Re/Im columns in {path.name}")
    re_vals = pd.to_numeric(df[recol], errors="coerce").to_numpy()
    im_vals = pd.to_numeric(df[imcol], errors="coerce").to_numpy()
    if fcol is not None:
        freq_vals = pd.to_numeric(df[fcol], errors="coerce").to_numpy()
    else:
        n = min(len(re_vals), len(im_vals))
        freq_vals = np.geomspace(cfg.F_MAX, cfg.F_MIN, n)
    n = min(len(freq_vals), len(re_vals), len(im_vals))
    freq_vals = freq_vals[:n]; re_vals = re_vals[:n]; im_vals = im_vals[:n]
    # Imag sign normalization (prefer negative semicircle)
    if np.nanmean(im_vals) > 0:
        im_vals = -im_vals
    return freq_vals, re_vals.astype(float), im_vals.astype(float)

def load_any_inference(path: Path):
    suf = path.suffix.lower()
    if suf == ".mat": return load_mat_eis(path)
    if suf in (".csv",".xls",".xlsx"): return load_table_eis(path)
    raise ValueError(f"Unsupported test file extension: {suf}")

# =========================
# 6. FEATURE ENGINEERING
# =========================
def compute_F_features(freq, re_i, im_i):
    neg_im = -im_i
    idx_peak = int(np.argmax(neg_im))
    F1 = re_i[0]; F2 = re_i[idx_peak]; F3 = re_i[-1]
    sc = np.where(np.sign(im_i[:-1]) != np.sign(im_i[1:]))[0]
    if len(sc):
        k = sc[0]; y0,y1 = im_i[k], im_i[k+1]
        w = -y0/(y1 - y0 + 1e-12)
        F4 = re_i[k] + w*(re_i[k+1]-re_i[k])
    else:
        F4 = np.nan
    F5 = (re_i[idx_peak]-F1) if idx_peak>0 else np.nan
    F6 = np.min(im_i)
    mid_target = 10.0
    idx_mid = int(np.argmin(np.abs(freq-mid_target)))
    F7 = re_i[idx_mid]
    return [F1,F2,F3,F4,F5,F6,F7]

PHYSICAL_FEATURE_NAMES = [
    "Rs","Rct","tau_peak","warburg_sigma","arc_quality",
    "phase_mean_mid","phase_std_mid","phase_min","lf_slope_negIm","norm_arc"
]

def physical_features(freq, re_i, im_i):
    freq = np.asarray(freq); re_i = np.asarray(re_i); im_i = np.asarray(im_i)
    neg_im = -im_i
    idx_peak = int(np.argmax(neg_im))
    Rs = float(re_i[0]); Rpeak = float(re_i[idx_peak]); Rlow = float(re_i[-1])
    Rct = max(Rpeak - Rs, 0.0)
    arc_diam = Rlow - Rs
    norm_arc = arc_diam / (Rs + 1e-9)
    f_peak = float(freq[idx_peak])
    tau_peak = 1.0/(2*math.pi*f_peak) if f_peak>0 else np.nan
    K = min(10, len(freq)//3)
    if K >= 4:
        w_section = (2*np.pi*freq[-K:])**(-0.5)
        re_section = re_i[-K:]
        if len(np.unique(w_section)) > 2:
            warburg_sigma = float(np.polyfit(w_section, re_section, 1)[0])
        else:
            warburg_sigma = np.nan
    else:
        warburg_sigma = np.nan
    phase = np.arctan2(-im_i, re_i)
    mid_mask = (freq>=1) & (freq<=100)
    if mid_mask.sum()>2:
        phase_mean_mid = float(phase[mid_mask].mean())
        phase_std_mid  = float(phase[mid_mask].std())
    else:
        phase_mean_mid = np.nan; phase_std_mid = np.nan
    phase_min = float(phase.min())
    lf_mask = (freq<=1.0)
    if lf_mask.sum() >= 4:
        x = np.log10(freq[lf_mask]+1e-12); y = neg_im[lf_mask]
        lf_slope = np.polyfit(x, y, 1)[0]
    else:
        lf_slope = np.nan
    arc_quality = (neg_im.max() - neg_im.min())/(abs(neg_im.mean())+1e-9)
    return [Rs,Rct,tau_peak,warburg_sigma,arc_quality,
            phase_mean_mid,phase_std_mid,phase_min,lf_slope,norm_arc]

BANDS = [(1e4,1e3),(1e3,1e2),(1e2,10),(10,1),(1,1e-1),(1e-1,1e-2)]
def band_stats(freq, re_i, im_i):
    feats=[]; freq=np.asarray(freq)
    for hi,lo in BANDS:
        m=(freq<=hi)&(freq>=lo)
        if m.sum()>1:
            z=np.hypot(re_i[m], im_i[m])
            feats += [z.mean(), z.std()]
        else:
            feats += [np.nan, np.nan]
    return feats

def diff_slopes(freq, re_i, im_i, segments=5):
    logf = np.log10(freq)
    edges = np.linspace(logf.min(), logf.max(), segments+1)
    out=[]
    for i in range(segments):
        m=(logf>=edges[i])&(logf<=edges[i+1])
        if m.sum()>=3:
            x=logf[m]
            out += [np.polyfit(x,re_i[m],1)[0], np.polyfit(x,(-im_i)[m],1)[0]]
        else:
            out += [np.nan, np.nan]
    return out

DRT_FEATURE_NAMES = [
    "drt_sum","drt_mean_logtau","drt_var_logtau","drt_peak_tau",
    "drt_peak_gamma","drt_frac_low_tau","drt_frac_high_tau"
]

def compute_drt(freq,re_i,im_i,tau_min,tau_max,n_tau,lam):
    w = 2*np.pi*freq
    tau = np.geomspace(tau_max, tau_min, n_tau)
    WT = w[:,None]*tau[None,:]
    denom = 1+WT**2
    K_re = 1.0/denom
    K_im = -WT/denom
    R_inf = re_i[0]
    y_re = re_i - R_inf
    y_im = im_i
    Y = np.concatenate([y_re, y_im])
    K = np.vstack([K_re, K_im])
    A = K.T @ K + lam*np.eye(n_tau)
    b = K.T @ Y
    gamma = linalg.solve(A,b,assume_a='pos')
    gamma = np.clip(gamma,0,None)
    return tau, gamma

def drt_features(freq,re_i,im_i):
    try:
        tau,gamma = compute_drt(freq,re_i,im_i,
                                 cfg.DRT_TAU_MIN,cfg.DRT_TAU_MAX,
                                 cfg.DRT_POINTS,cfg.DRT_LAMBDA)
        log_tau = np.log10(tau)
        g_sum = gamma.sum()+1e-12
        w_norm = gamma/g_sum
        mean_logtau = float((w_norm*log_tau).sum())
        var_logtau  = float((w_norm*(log_tau-mean_logtau)**2).sum())
        p = int(np.argmax(gamma))
        peak_tau = float(tau[p]); peak_gamma=float(gamma[p])
        mid = np.median(log_tau)
        frac_low = float(w_norm[log_tau<=mid].sum())
        frac_high = 1-frac_low
        return [g_sum,mean_logtau,var_logtau,peak_tau,peak_gamma,frac_low,frac_high]
    except Exception:
        return [np.nan]*7

def build_feature_vector(re_i, im_i, temp, freq, include_names=False):
    parts=[]; names=[]
    if cfg.INCLUDE_RAW_RE_IM:
        parts += [re_i, im_i]
        names += [f"Re_{i}" for i in range(len(re_i))] + [f"Im_{i}" for i in range(len(im_i))]
    if cfg.INCLUDE_BASICS:
        z = np.hypot(re_i, im_i)
        basics=[re_i[0], re_i[-1], re_i[-1]-re_i[0], z.max(), z.mean(), z.std()]
        parts.append(np.array(basics)); names += ["hf_re","lf_re","arc_diam","zmag_max","zmag_mean","zmag_std"]
    if cfg.INCLUDE_F_FEATS:
        Ff=compute_F_features(freq,re_i,im_i); parts.append(np.array(Ff)); names += [f"F{i}" for i in range(1,8)]
    if cfg.INCLUDE_PHYSICAL:
        Pf=physical_features(freq,re_i,im_i); parts.append(np.array(Pf)); names += PHYSICAL_FEATURE_NAMES
    if cfg.INCLUDE_BAND_STATS:
        Bf=band_stats(freq,re_i,im_i); parts.append(np.array(Bf))
        for bi in range(len(BANDS)): names += [f"band{bi}_mean", f"band{bi}_std"]
    if cfg.INCLUDE_DIFF_SLOPES:
        Ds=diff_slopes(freq,re_i,im_i); parts.append(np.array(Ds))
        for i in range(len(Ds)//2): names += [f"slope_re_seg{i}", f"slope_negIm_seg{i}"]
    if cfg.INCLUDE_DRT:
        Df=drt_features(freq,re_i,im_i); parts.append(np.array(Df)); names += DRT_FEATURE_NAMES
    parts.append(np.array([temp])); names += ["Temp"]
    vec = np.concatenate(parts).astype(float)
    vec = np.nan_to_num(vec, nan=0.0, posinf=0.0, neginf=0.0)
    if include_names: return vec, names
    return vec

def build_shape_normalized(re_i, im_i):
    hf = re_i[0] if re_i[0] != 0 else 1.0
    return re_i / hf, im_i / hf

# =========================
# 7. CAPACITY & CPP
# =========================
def load_capacity_info(cap_dir: Path) -> pd.DataFrame:
    if not (cap_dir.exists() and cfg.REFINE_SOH_WITH_CAPACITY):
        return pd.DataFrame()
    recs=[]
    for fp in cap_dir.rglob("*.mat"):
        meta = parse_cap_metadata(fp.stem)
        if not meta: continue
        try:
            mat=loadmat(fp); arr=_find_matrix(mat)
            if arr is None: continue
            col=np.argmax(np.abs(arr[-50:, :]).mean(axis=0))
            cap=float(np.nanmax(arr[:,col]))
            meta["MeasuredCapacity_Ah"]=cap
            recs.append(meta)
        except Exception:
            pass
    df=pd.DataFrame(recs)
    if df.empty: return df
    ref=df.groupby("CellID")["MeasuredCapacity_Ah"].transform("max")
    df["NormCapacity"]=df["MeasuredCapacity_Ah"]/ref
    df["SoH_percent"]=df["NormCapacity"]*100.0
    return df

def estimate_cpp_per_cell(capacity_df: pd.DataFrame,
                          window:int, min_points:int)->Dict[str,float]:
    cpp={}
    for cid,grp in capacity_df.groupby("CellID"):
        g=grp.sort_values("CycleIndex")
        if g.shape[0]<min_points: continue
        tail=g.tail(window)
        x=tail["CycleIndex"].values.astype(float)
        y=tail["SoH_percent"].values.astype(float)
        if len(np.unique(x))<2: continue
        slope=np.polyfit(x,y,1)[0]  # SoH% / cycle
        if slope >= -1e-6:  # non-degrading
            continue
        cpp[cid]=1.0/abs(slope)
    return cpp

def build_cpp_map(cap_df: pd.DataFrame):
    if cap_df.empty: return {}, cfg.CPP_FALLBACK
    cpp_map=estimate_cpp_per_cell(
        cap_df[["CellID","CycleIndex","SoH_percent"]],
        cfg.CPP_ROLLING_WINDOW, cfg.CPP_MIN_POINTS
    )
    if not cpp_map:
        return {}, cfg.CPP_FALLBACK
    return cpp_map, float(np.median(list(cpp_map.values())))

def get_cpp(meta: dict, cpp_map: Dict[str,float], global_cpp: float):
    if not meta: return global_cpp
    return cpp_map.get(meta.get("CellID"), global_cpp)

# =========================
# 8. DATASET BUILD (training on .mat only)
# =========================
def load_single_eis_mat(fp: Path):
    meta = parse_eis_metadata(fp.stem)
    if meta is None:
        raise ValueError(f"Bad filename: {fp.name}")
    freq,re_z,im_z = load_mat_eis(fp)
    re_i=_interp_channel(freq, re_z, CANON_FREQ)
    im_i=_interp_channel(freq, im_z, CANON_FREQ)
    vec=build_feature_vector(re_i, im_i, meta["Temp"], CANON_FREQ)
    return vec, meta, re_i, im_i

def build_dataset(eis_dir: Path, cap_df: Optional[pd.DataFrame]):
    files = sorted(eis_dir.rglob("*.mat"))
    if not files:
        raise FileNotFoundError(f"No .mat spectra in {eis_dir}")

    # Feature names from first file
    f0,r0,i0 = load_mat_eis(files[0])
    re0=_interp_channel(f0,r0,CANON_FREQ); im0=_interp_channel(f0,i0,CANON_FREQ)
    _, feature_names = build_feature_vector(re0, im0, 25.0, CANON_FREQ, include_names=True)

    feats=[]; rows=[]; shape_feats=[]
    for fp in tqdm(files, desc="Loading training spectra"):
        try:
            v, m, rei, imi = load_single_eis_mat(fp)
            feats.append(v); rows.append(m)
            if cfg.INCLUDE_NORMALIZED_SHAPE_MODEL and cfg.NORMALIZE_SHAPE_BY_HF_RE:
                rsh, ish = build_shape_normalized(rei, imi)
                shape_vec = build_feature_vector(rsh, ish, m["Temp"], CANON_FREQ)
                shape_feats.append(shape_vec)
        except Exception as e:
            if cfg.VERBOSE: print(f"[Skip] {fp.name}: {e}")

    if not rows:
        raise RuntimeError("No valid training spectra after filtering.")

    X = np.vstack(feats)
    X_shape = np.vstack(shape_feats) if (cfg.INCLUDE_NORMALIZED_SHAPE_MODEL and shape_feats) else None
    meta_df = pd.DataFrame(rows)

    # SoH refinement with capacity
    if cap_df is not None and not cap_df.empty and cfg.REFINE_SOH_WITH_CAPACITY:
        lookup = cap_df.set_index(["CellID","SOH_stage"])["NormCapacity"].to_dict()
        refined=[]
        for cid, stage, fallback in zip(meta_df.CellID, meta_df.SOH_stage, meta_df.RealSOH_file):
            nc = lookup.get((cid, stage))
            refined.append(100.0*nc if nc is not None else fallback)
        meta_df["SoH_cont"]=refined
    else:
        meta_df["SoH_cont"]=meta_df["RealSOH_file"]

    y_soc = meta_df["SOC"].values
    y_soh = meta_df["SoH_cont"].values

    soh_var = float(np.var(y_soh))
    if cfg.VERBOSE:
        print(f"[DATA] SoH range: {y_soh.min():.2f} – {y_soh.max():.2f} (var={soh_var:.3f})")
        if soh_var < 1.0:
            print("[WARN] Low SoH variance → model may output near-constant SoH.")

    if cfg.SAVE_FEATURE_TABLE:
        pd.concat(
            [meta_df.reset_index(drop=True),
             pd.DataFrame(X, columns=feature_names)], axis=1
        ).to_parquet(cfg.MODEL_DIR/"training_features.parquet", index=False)

    return meta_df, X, (X_shape, feature_names), y_soc, y_soh

# =========================
# 9. SPLITTING
# =========================
def cell_split_mask(meta_df: pd.DataFrame):
    cells = meta_df.CellID.unique()
    rng = np.random.default_rng(cfg.RANDOM_STATE)
    n_test = max(1, int(len(cells)*cfg.TEST_FRAC))
    test_cells = rng.choice(cells, size=n_test, replace=False)
    return meta_df.CellID.isin(test_cells)

# =========================
# 10. TRAINING
# =========================
def train_models(meta_df, X_raw, shape_bundle, y_soc, y_soh):
    X_shape, feature_names = shape_bundle
    mask_test = cell_split_mask(meta_df)

    # ----- SoC pipeline -----
    soc_scaler = StandardScaler()
    X_soc_s = soc_scaler.fit_transform(X_raw)
    soc_pca = None
    X_soc_model = X_soc_s
    if cfg.USE_PCA_SOC:
        soc_pca = PCA(n_components=min(cfg.PCA_SOC_COMPONENTS, X_soc_s.shape[1]-1),
                      random_state=cfg.RANDOM_STATE)
        X_soc_model = soc_pca.fit_transform(X_soc_s)

    soc_model = RandomForestClassifier(
        n_estimators=600, min_samples_leaf=2, class_weight='balanced',
        n_jobs=-1, random_state=cfg.RANDOM_STATE
    )
    soc_model.fit(X_soc_model[~mask_test], y_soc[~mask_test])
    soc_pred = soc_model.predict(X_soc_model[mask_test])
    soc_acc = accuracy_score(y_soc[mask_test], soc_pred)
    soc_f1  = f1_score(y_soc[mask_test], soc_pred, average='macro')
    if cfg.VERBOSE:
        print(f"[SoC] Acc={soc_acc:.3f} MacroF1={soc_f1:.3f}")
        print(classification_report(y_soc[mask_test], soc_pred, digits=4))

    # ----- SoH raw pipeline -----
    soh_scaler = StandardScaler()
    X_soh_s = soh_scaler.fit_transform(X_raw)
    soh_pca=None
    X_soh_model = X_soh_s
    if cfg.USE_PCA_SOH:
        soh_pca = PCA(n_components=min(cfg.PCA_SOH_COMPONENTS, X_soh_s.shape[1]-1),
                      random_state=cfg.RANDOM_STATE)
        X_soh_model = soh_pca.fit_transform(X_soh_s)

    candidates = {}

    # (1) Raw Gaussian Process
    dim = X_soh_model.shape[1]
    kernel = RBF(length_scale=np.ones(dim)*3.0,
                 length_scale_bounds=(1e-1,1e4)) + \
             WhiteKernel(noise_level=1e-2,
                         noise_level_bounds=(1e-6,1e-1))
    gpr = GaussianProcessRegressor(
        kernel=kernel, alpha=0.0, normalize_y=True,
        random_state=cfg.RANDOM_STATE, n_restarts_optimizer=3
    )
    if X_soh_model.shape[0] > cfg.MAX_GPR_TRAIN_SAMPLES:
        idx = np.random.default_rng(cfg.RANDOM_STATE).choice(
            X_soh_model.shape[0], size=cfg.MAX_GPR_TRAIN_SAMPLES, replace=False)
        gpr.fit(X_soh_model[idx], y_soh[idx])
    else:
        gpr.fit(X_soh_model, y_soh)
    pred_gpr = gpr.predict(X_soh_model[mask_test])
    r2_gpr = r2_score(y_soh[mask_test], pred_gpr)
    rmse_gpr = math.sqrt(mean_squared_error(y_soh[mask_test], pred_gpr))
    candidates["gpr_raw"] = (gpr, r2_gpr, rmse_gpr)

    # (2) HistGradientBoosting
    hgb = HistGradientBoostingRegressor(
        learning_rate=0.05, max_iter=500,
        l2_regularization=1e-3, random_state=cfg.RANDOM_STATE
    )
    hgb.fit(X_soh_model[~mask_test], y_soh[~mask_test])
    pred_hgb = hgb.predict(X_soh_model[mask_test])
    r2_hgb = r2_score(y_soh[mask_test], pred_hgb)
    rmse_hgb = math.sqrt(mean_squared_error(y_soh[mask_test], pred_hgb))
    candidates["hgb_raw"] = (hgb, r2_hgb, rmse_hgb)

    # (3) Shape-normalized GP
    shape_model = None; shape_scaler=None; shape_pca=None; shape_metrics=None
    if cfg.INCLUDE_NORMALIZED_SHAPE_MODEL and (X_shape is not None):
        shape_scaler = StandardScaler()
        X_shape_s = shape_scaler.fit_transform(X_shape)
        X_shape_model = X_shape_s
        if cfg.USE_PCA_SOH:
            shape_pca = PCA(n_components=min(cfg.PCA_SOH_COMPONENTS, X_shape_s.shape[1]-1),
                            random_state=cfg.RANDOM_STATE)
            X_shape_model = shape_pca.fit_transform(X_shape_s)
        dim_s = X_shape_model.shape[1]
        kernel_s = RBF(length_scale=np.ones(dim_s)*3.0,
                       length_scale_bounds=(1e-1,1e4)) + \
                   WhiteKernel(noise_level=1e-2,
                               noise_level_bounds=(1e-6,1e-1))
        shape_model = GaussianProcessRegressor(
            kernel=kernel_s, alpha=0.0, normalize_y=True,
            random_state=cfg.RANDOM_STATE, n_restarts_optimizer=3
        )
        if X_shape_model.shape[0] > cfg.MAX_GPR_TRAIN_SAMPLES:
            idx_s = np.random.default_rng(cfg.RANDOM_STATE).choice(
                X_shape_model.shape[0], size=cfg.MAX_GPR_TRAIN_SAMPLES, replace=False)
            shape_model.fit(X_shape_model[idx_s], y_soh[idx_s])
        else:
            shape_model.fit(X_shape_model, y_soh)
        spred = shape_model.predict(X_shape_model[mask_test])
        r2_shape = r2_score(y_soh[mask_test], spred)
        rmse_shape = math.sqrt(mean_squared_error(y_soh[mask_test], spred))
        candidates["gpr_shape"] = (shape_model, r2_shape, rmse_shape)
        shape_metrics = {"r2": r2_shape, "rmse": rmse_shape}

    # Select best raw candidate
    best_name = max(["gpr_raw","hgb_raw"], key=lambda k: candidates[k][1])
    best_model, best_r2, best_rmse = candidates[best_name]

    if cfg.VERBOSE:
        print(f"[SoH] GPR_raw:  R2={r2_gpr:.3f} RMSE={rmse_gpr:.2f}")
        print(f"[SoH] HGB_raw:  R2={r2_hgb:.3f} RMSE={rmse_hgb:.2f}")
        if shape_metrics:
            print(f"[SoH] ShapeGP: R2={shape_metrics['r2']:.3f} RMSE={shape_metrics['rmse']:.2f}")
        print(f"[SoH] Selected raw model = {best_name}")

    # Mahalanobis precomputed on scaled raw SoH space
    cov = np.cov(X_soh_s.T)
    try:
        cov_inv = np.linalg.pinv(cov)
    except Exception:
        cov_inv = np.eye(cov.shape[0])
    center = X_soh_s.mean(axis=0)

    bundle = {
        # SOC
        "soc_scaler": soc_scaler,
        "soc_pca": soc_pca,
        "soc_model": soc_model,
        # SOH primary
        "soh_scaler": soh_scaler,
        "soh_pca": soh_pca,
        "soh_model": best_model,
        "soh_model_name": best_name,
        # Shape model (optional)
        "shape_scaler": shape_scaler,
        "shape_pca": shape_pca,
        "shape_model": shape_model,
        # Meta
        "freq_grid": CANON_FREQ,
        "feature_version": cfg.FEATURE_VERSION,
        "feature_manifest": feature_names,
        "config": to_jsonable(asdict(cfg)),
        "metrics": {
            "soc_accuracy": soc_acc,
            "soc_macro_f1": soc_f1,
            "soh_r2_selected": best_r2,
            "soh_rmse_selected": best_rmse
        },
        "soh_candidates_metrics": {
            "gpr_raw": {"r2": r2_gpr, "rmse": rmse_gpr},
            "hgb_raw": {"r2": r2_hgb, "rmse": rmse_hgb},
            "gpr_shape": shape_metrics
        },
        "train_mahal": {"center": center.tolist(), "cov_inv": cov_inv.tolist()}
    }
    out_path = cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib"
    joblib.dump(bundle, out_path)
    if cfg.VERBOSE:
        print(f"[MODEL] Saved bundle → {out_path}")
        print(json.dumps(bundle["metrics"], indent=2))
    return bundle

# =========================
# 11. LOAD (WITH LEGACY COMPATIBILITY SHIM)
# =========================
def load_bundle():
    """
    Load model bundle with backward compatibility for earlier versions.

    Legacy schema keys:
        "scaler", "pca", "soc_model", "soh_model"
    New schema keys (v7):
        "soc_scaler","soc_pca","soc_model","soh_scaler","soh_pca","soh_model","soh_model_name",...

    We synthesize missing fields for legacy bundles so inference does not break.
    """
    path = cfg.MODEL_DIR / "eis_soc_soh_phys_models.joblib"
    bundle = joblib.load(path)

    # Detect legacy
    legacy = ("scaler" in bundle) and ("soc_scaler" not in bundle)
    if legacy:
        scaler = bundle["scaler"]
        pca = bundle.get("pca")
        soc_model = bundle.get("soc_model")
        soh_model = bundle.get("soh_model")

        bundle["soc_scaler"] = scaler
        bundle["soh_scaler"] = scaler
        bundle["soc_pca"]    = pca
        bundle["soh_pca"]    = pca
        bundle["soh_model"]  = soh_model
        bundle["soh_model_name"] = bundle.get("soh_model_name","legacy_model")
        if "metrics" not in bundle:
            bundle["metrics"] = {}
        bundle["metrics"].setdefault("soh_rmse_selected", 5.0)
        if "train_mahal" not in bundle:
            try:
                center = scaler.mean_
                cov_inv = np.eye(len(center))
                bundle["train_mahal"] = {"center": center.tolist(), "cov_inv": cov_inv.tolist()}
            except Exception:
                bundle["train_mahal"] = None
        bundle.setdefault("feature_version", -1)

    # Sanity check
    for key in ["soc_scaler","soc_model","soh_scaler","soh_model","freq_grid"]:
        if key not in bundle:
            raise KeyError(f"Bundle missing required key: {key}")

    return bundle

# =========================
# 12. INFERENCE FEATURIZATION
# =========================
def featurize_any(file_path: Path, bundle):
    freq_grid = bundle["freq_grid"]
    meta = parse_eis_metadata(file_path.stem)
    freq,re_raw,im_raw = load_any_inference(file_path)
    re_i=_interp_channel(freq, re_raw, freq_grid)
    im_i=_interp_channel(freq, im_raw, freq_grid)
    if meta is None and cfg.TEST_TEMPERATURE_OVERRIDE is not None:
        temp = cfg.TEST_TEMPERATURE_OVERRIDE
    else:
        temp = meta["Temp"] if meta else -1
    vec = build_feature_vector(re_i, im_i, temp, freq_grid)
    norm_vec=None
    if cfg.INCLUDE_NORMALIZED_SHAPE_MODEL and bundle.get("shape_model") is not None:
        if cfg.NORMALIZE_SHAPE_BY_HF_RE:
            rsh, ish = build_shape_normalized(re_i, im_i)
            norm_vec = build_feature_vector(rsh, ish, temp, freq_grid)
    return vec, norm_vec, meta

# =========================
# 13. OOD UTILITIES
# =========================
def mahalanobis_distance(x, center, cov_inv):
    diff = x - center
    return float(np.sqrt(diff @ cov_inv @ diff.T))

def gp_ard_norm(Xp, model):
    try:
        K = model.kernel_
        from sklearn.gaussian_process.kernels import RBF
        rbf = None
        if hasattr(K,"k1") and isinstance(K.k1,RBF): rbf=K.k1
        elif hasattr(K,"k2") and isinstance(K.k2,RBF): rbf=K.k2
        if rbf is None: return None
        ls = np.atleast_1d(rbf.length_scale)
        z = (Xp / ls).ravel()
        return float(np.linalg.norm(z))
    except Exception:
        return None

# =========================
# 14. PROJECTION PLOT
# =========================
def build_projection(soh_current, cpp, lower, exponent=None, n=160):
    if soh_current <= lower or cpp <= 0:
        return np.array([0.0]), np.array([soh_current])
    total = (soh_current - lower) * cpp
    cycles = np.linspace(0, total, n)
    S0 = soh_current; Smin=lower
    if exponent is None: exponent = cfg.PLOT_EXPONENT
    soh_curve = Smin + (S0 - Smin)*(1 - cycles/total)**exponent
    return cycles, soh_curve

def plot_projection(file_base, soh_current, soh_std, cycles_to_target,
                    cycles_to_lower, cpp, ood_flag, out_path):
    if cycles_to_lower <= 0:
        return
    cycles, curve = build_projection(soh_current, cpp, cfg.ILLUSTRATIVE_MIN_SOH)
    plt.figure(figsize=(6.4,4))
    plt.plot(cycles, curve, lw=2, label="Projected SoH")
    plt.axhline(cfg.DECISION_SOH_PERCENT, color="orange", ls="--", label=f"{cfg.DECISION_SOH_PERCENT:.0f}% target")
    plt.axhline(cfg.ILLUSTRATIVE_MIN_SOH, color="red", ls=":", label=f"{cfg.ILLUSTRATIVE_MIN_SOH:.0f}% lower")
    plt.scatter([0],[soh_current], c="green", s=55, label=f"Current {soh_current:.2f}%")
    plt.text(0, soh_current+0.7, f"±{soh_std:.2f}", color="green", fontsize=8)
    if cycles_to_target > 0:
        plt.axvline(cycles_to_target, color="orange", ls="-.")
        plt.scatter([cycles_to_target],[cfg.DECISION_SOH_PERCENT], c="orange", s=45)
        plt.text(cycles_to_target, cfg.DECISION_SOH_PERCENT+1.0,
                 f"{cycles_to_target:.0f} cyc", ha="center", color="orange", fontsize=8)
    plt.scatter([cycles[-1]],[cfg.ILLUSTRATIVE_MIN_SOH], c="red", s=50)
    plt.text(cycles[-1], cfg.ILLUSTRATIVE_MIN_SOH-2,
             f"{cycles[-1]:.0f} cyc", ha="center", color="red", fontsize=8)
    if ood_flag:
        plt.text(0.98,0.05,"OOD", transform=plt.gca().transAxes,
                 ha="right", va="bottom", color="crimson", fontsize=11,
                 bbox=dict(boxstyle="round", fc="w", ec="crimson"))
    plt.xlabel("Remaining Cycles")
    plt.ylabel("SoH (%)")
    plt.title(f"RUL Projection – {file_base}")
    plt.grid(alpha=0.35)
    plt.legend(fontsize=8)
    plt.tight_layout()
    plt.savefig(out_path, dpi=140)
    plt.close()

# =========================
# 15. INFERENCE (SINGLE FILE)
# =========================
def predict_file(file_path: Path, bundle, cpp_map, global_cpp):
    vec, norm_vec, meta = featurize_any(file_path, bundle)

    # SoC
    soc_scaler=bundle["soc_scaler"]; soc_pca=bundle.get("soc_pca")
    soc_model=bundle["soc_model"]
    X_soc = soc_scaler.transform(vec.reshape(1,-1))
    X_soc_in = soc_pca.transform(X_soc) if soc_pca else X_soc
    soc_probs = soc_model.predict_proba(X_soc_in)[0]
    soc_classes = soc_model.classes_
    soc_pred = int(soc_classes[np.argmax(soc_probs)])

    # SoH (raw primary)
    soh_scaler=bundle["soh_scaler"]; soh_pca=bundle.get("soh_pca")
    soh_model=bundle["soh_model"]; model_name=bundle.get("soh_model_name","unknown")
    X_soh_s = soh_scaler.transform(vec.reshape(1,-1))
    X_soh_in = soh_pca.transform(X_soh_s) if soh_pca else X_soh_s

    if "gpr" in model_name:
        sm, ss = soh_model.predict(X_soh_in, return_std=True)
        soh_mean_raw = float(sm[0]); soh_std_raw=float(ss[0])
    else:
        soh_mean_raw = float(soh_model.predict(X_soh_in)[0])
        soh_std_raw  = float(bundle["metrics"].get("soh_rmse_selected", 5.0))

    # Shape model
    shape_model = bundle.get("shape_model")
    shape_soh_mean=None; shape_soh_std=None
    if shape_model is not None and norm_vec is not None:
        sscaler = bundle.get("shape_scaler")
        spca = bundle.get("shape_pca")
        X_shape_s = sscaler.transform(norm_vec.reshape(1,-1))
        X_shape_in = spca.transform(X_shape_s) if spca else X_shape_s
        if hasattr(shape_model,"predict"):
            if isinstance(shape_model, GaussianProcessRegressor):
                sm2, ss2 = shape_model.predict(X_shape_in, return_std=True)
                shape_soh_mean=float(sm2[0]); shape_soh_std=float(ss2[0])
            else:
                shape_soh_mean=float(shape_model.predict(X_shape_in)[0])
                shape_soh_std=float(bundle["metrics"].get("soh_rmse_selected", 5.0))

    # Ensemble
    if cfg.ENSEMBLE_SOH and shape_soh_mean is not None:
        soh_mean = 0.5*(soh_mean_raw + shape_soh_mean)
        stds = [soh_std_raw]
        if shape_soh_std is not None: stds.append(shape_soh_std)
        soh_std = float(np.sqrt(np.mean(np.array(stds)**2)))
    else:
        soh_mean, soh_std = soh_mean_raw, soh_std_raw

    # RUL
    cpp = get_cpp(meta, cpp_map, global_cpp)
    if soh_mean > cfg.DECISION_SOH_PERCENT:
        cycles_to_target = (soh_mean - cfg.DECISION_SOH_PERCENT)*cpp
    else:
        cycles_to_target = 0.0
    if soh_mean > cfg.ILLUSTRATIVE_MIN_SOH:
        cycles_to_lower = (soh_mean - cfg.ILLUSTRATIVE_MIN_SOH)*cpp
    else:
        cycles_to_lower = 0.0

    # OOD diagnostics
    train_mahal = bundle.get("train_mahal")
    mahal_dist=None
    if train_mahal:
        cov_inv = np.array(train_mahal["cov_inv"])
        center = np.array(train_mahal["center"])
        mahal_dist = mahalanobis_distance(X_soh_s[0], center, cov_inv)
    ard_norm=None
    if "gpr" in model_name:
        ard_norm = gp_ard_norm(X_soh_in, soh_model)
    ood_flag=False
    if (mahal_dist is not None and mahal_dist > cfg.MAHAL_THRESHOLD) or \
       (ard_norm is not None and ard_norm > cfg.GP_ARD_NORM_THRESHOLD):
        ood_flag=True

    result={
        "file": str(file_path),
        "parsed_metadata": meta,
        "predicted_SoC": soc_pred,
        "SoC_probabilities": {int(c): float(p) for c,p in zip(soc_classes, soc_probs)},
        "predicted_SoH_percent": soh_mean,
        "SoH_std_estimate": soh_std,
        "raw_model_mean": soh_mean_raw,
        "raw_model_std": soh_std_raw,
        "shape_model_mean": shape_soh_mean,
        "shape_model_std": shape_soh_std,
        "cycles_per_percent_used": cpp,
        "cycles_to_target": cycles_to_target,
        "cycles_to_lower": cycles_to_lower,
        "decision_threshold_percent": cfg.DECISION_SOH_PERCENT,
        "lower_threshold_percent": cfg.ILLUSTRATIVE_MIN_SOH,
        "feature_version": bundle.get("feature_version"),
        "soh_model_chosen": model_name,
        "OOD_mahal": mahal_dist,
        "OOD_gp_ard_norm": ard_norm,
        "OOD_flag": ood_flag
    }
    return result, ood_flag, cycles_to_target, cycles_to_lower

# =========================
# 16. MAIN
# =========================
def main():
    if cfg.VERBOSE:
        print("Configuration:\n", json.dumps(to_jsonable(asdict(cfg)), indent=2))

    # Directory assertions
    assert cfg.EIS_DIR.exists(), f"EIS_DIR missing: {cfg.EIS_DIR}"
    if cfg.REFINE_SOH_WITH_CAPACITY:
        assert cfg.CAP_DIR.exists(), f"CAP_DIR missing: {cfg.CAP_DIR}"

    # Capacity + dynamic CPP
    cap_df = load_capacity_info(cfg.CAP_DIR)
    if cap_df.empty:
        if cfg.VERBOSE:
            print("[INFO] No / empty capacity data -> fallback Cpp.")
        cpp_map, global_cpp = {}, cfg.CPP_FALLBACK
    else:
        cpp_map, global_cpp = build_cpp_map(cap_df)
        if cfg.VERBOSE:
            print(f"[CPP] dynamic cells={len(cpp_map)} global_cpp_median={global_cpp:.2f}")

    bundle_path = cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib"
    if bundle_path.exists() and not cfg.FORCE_RETRAIN:
        if cfg.VERBOSE:
            print(f"[LOAD] Using existing model bundle: {bundle_path}")
        bundle = load_bundle()
    else:
        if cfg.VERBOSE:
            print("[TRAIN] Building dataset & training models...")
        meta_df, X_raw, shape_bundle, y_soc, y_soh = build_dataset(cfg.EIS_DIR, cap_df)
        if cfg.VERBOSE:
            print(f"[TRAIN] Samples={X_raw.shape[0]} Features={X_raw.shape[1]} Cells={meta_df.CellID.nunique()}")
        bundle = train_models(meta_df, X_raw, shape_bundle, y_soc, y_soh)

    # Inference
    for test_fp in cfg.EIS_TEST_FILES:
        print(f"\n===== TEST: {test_fp.name} =====")
        if not test_fp.exists():
            print(f"[WARN] Test file not found: {test_fp}")
            continue
        try:
            result, ood_flag, cyc_target, cyc_lower = predict_file(test_fp, bundle, cpp_map, global_cpp)
        except Exception as e:
            print(f"[ERROR] Prediction failed for {test_fp.name}: {e}")
            continue

        out_plot = cfg.MODEL_DIR / f"{test_fp.stem}_projection.png"
        plot_projection(
            test_fp.stem,
            result["predicted_SoH_percent"],
            result["SoH_std_estimate"],
            result["cycles_to_target"],
            result["cycles_to_lower"],
            result["cycles_per_percent_used"],
            result["OOD_flag"],
            out_plot
        )

        out_json = cfg.MODEL_DIR / f"{test_fp.stem}_prediction.json"
        with out_json.open("w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)
        print(json.dumps(result, indent=2))
        print(f"[PLOT] Saved: {out_plot}")
        print(f"[JSON] Saved: {out_json}")

    print("\nDone.")

if __name__ == "__main__":
    main()


Configuration:
 {
  "EIS_DIR": "C:\\Users\\tmgon\\OneDrive - Edith Cowan University\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\EIS_Test",
  "CAP_DIR": "C:\\Users\\tmgon\\OneDrive - Edith Cowan University\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\Capacity_Check",
  "MODEL_DIR": "models_eis_phase2_phys",
  "EIS_TEST_FILES": [
    "Mazda-Battery-Cell1.xlsx",
    "Mazda-Battery-Cell2.xlsx"
  ],
  "F_MIN": 0.01,
  "F_MAX": 10000.0,
  "N_FREQ": 60,
  "TEST_FRAC": 0.2,
  "GROUP_KFOLDS": 0,
  "RANDOM_STATE": 42,
  "USE_PCA_SOC": true,
  "USE_PCA_SOH": false,
  "PCA_SOC_COMPONENTS": 25,
  "PCA_SOH_COMPONENTS": 30,
  "INCLUDE_RAW_RE_IM": true,
  "INCLUDE_BASICS": true,
  "INCLUDE_F_FEATS": true,
  "INCLUDE_PHYSICAL": true,
  "INCLUDE_DRT": true,
  "INCLUDE_BAND_STATS": true,
  "INCLUDE_DIFF_SLOPES": true,
  "DRT_POINTS": 60,
  "DRT_TAU_MIN": 0.0001,
  "DRT_TAU_MAX": 10000.0,
  "DRT_LAMBDA": 0.01,
  "REFINE_SOH_WITH_CAPACITY

In [1]:
s

NameError: name 's' is not defined