In [1]:
#!/usr/bin/env python
"""
Unified EIS → SoC (classification) & SoH (regression) Model (Realistic v8a)
---------------------------------------------------------------------------

Changes vs v8:
  * Temperature feature renamed to Temp_feat (avoid duplicate with metadata Temp).
  * Defensive duplicate renaming before parquet save.

See previous description for full feature list & rationale.
"""

from __future__ import annotations
import re, json, math, random, joblib, warnings
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.io import loadmat
from scipy import linalg
from scipy.interpolate import interp1d

from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.metrics import (
    accuracy_score, f1_score, classification_report,
    mean_squared_error, r2_score
)
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt

# ====================================
# 1. CONFIGURATION
# ====================================
@dataclass
class Config:
    EIS_DIR: Path = Path(r"C:\Users\tgondal0\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\EIS_Test")
    CAP_DIR: Path = Path(r"C:\Users\tgondal0\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\Capacity_Check")
    MODEL_DIR: Path = Path("models_eis_phase2_phys")

    EIS_TEST_FILES: List[Path] = None

    F_MIN: float = 1e-2
    F_MAX: float = 1e4
    N_FREQ: int = 60

    TEST_FRAC: float = 0.2
    RANDOM_STATE: int = 42

    USE_PCA_SOC: bool = True
    USE_PCA_SOH: bool = False
    PCA_SOC_COMPONENTS: int = 25
    PCA_SOH_COMPONENTS: int = 30

    INCLUDE_RAW_RE_IM: bool = True
    INCLUDE_BASICS: bool = True
    INCLUDE_F_FEATS: bool = True
    INCLUDE_PHYSICAL: bool = True
    INCLUDE_DRT: bool = True
    INCLUDE_BAND_STATS: bool = True
    INCLUDE_DIFF_SLOPES: bool = True

    INCLUDE_SHAPE_NORMALIZED_BRANCH: bool = True
    NORMALIZE_SHAPE_BY_HF_RE: bool = True

    DRT_POINTS: int = 60
    DRT_TAU_MIN: float = 1e-4
    DRT_TAU_MAX: float = 1e4
    DRT_LAMBDA: float = 1e-2

    REFINE_SOH_WITH_CAPACITY: bool = True

    MAX_GPR_TRAIN_SAMPLES: int = 3500
    ENSEMBLE_SOH: bool = True
    ENSEMBLE_STD_MODE: str = "rms"

    DECISION_SOH_PERCENT: float = 50.0
    ILLUSTRATIVE_MIN_SOH: float = 40.0
    CPP_ROLLING_WINDOW: int = 5
    CPP_MIN_POINTS: int = 6
    CPP_FALLBACK: float = 20.0

    TEST_TEMPERATURE_OVERRIDE: Optional[float] = 25.0
    FORCE_RETRAIN: bool = True

    MAHAL_THRESHOLD: float = 10.0
    GP_ARD_NORM_THRESHOLD: float = 6.0

    MIN_UNIQUE_SOH: int = 4
    MIN_REQUIRED_SOH_STD: float = 0.75

    PLOT_EXPONENT: float = 1.25

    SAVE_FEATURE_TABLE: bool = True
    VERBOSE: bool = True
    FEATURE_VERSION: int = 8

cfg = Config()
if cfg.EIS_TEST_FILES is None:
    cfg.EIS_TEST_FILES = [
        Path("Mazda-Battery-Cell1.xlsx"),
        Path("Mazda-Battery-Cell2.xlsx")
    ]
cfg.MODEL_DIR.mkdir(parents=True, exist_ok=True)

# ====================================
# 2. UTILITIES
# ====================================
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
set_seed(cfg.RANDOM_STATE)

def to_jsonable(x):
    if isinstance(x, Path): return str(x)
    if isinstance(x, dict): return {k: to_jsonable(v) for k,v in x.items()}
    if isinstance(x, (list,tuple)): return [to_jsonable(i) for i in x]
    return x

CANON_FREQ = np.geomspace(cfg.F_MAX, cfg.F_MIN, cfg.N_FREQ)

# ====================================
# 3. REGEX
# ====================================
EIS_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_(?P<Temp>\d+)degC_(?P<SOC>\d+)SOC_(?P<RealSOH>\d+)"
)
CAP_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_Capacity_Check_(?P<Temp>\d+)degC_(?P<Cycle>\d+)cycle"
)

def parse_eis_metadata(stem: str):
    m = EIS_META_PATTERN.search(stem)
    if not m: return None
    d = m.groupdict()
    return {
        "CellID": f"Cell{d['CellID']}",
        "SOH_stage": int(d["SOH"]),
        "SOC": int(d["SOC"]),
        "Temp": int(d["Temp"]),
        "RealSOH_file": int(d["RealSOH"]) / 100.0
    }

def parse_cap_metadata(stem: str):
    m = CAP_META_PATTERN.search(stem)
    if not m: return None
    d = m.groupdict()
    return {
        "CellID": f"Cell{d['CellID']}",
        "SOH_stage": int(d["SOH"]),
        "Temp": int(d["Temp"]),
        "CycleIndex": int(d["Cycle"])
    }

# ====================================
# 4. LOADING / INTERPOLATION
# ====================================
def _find_matrix(mat_dict: dict):
    for v in mat_dict.values():
        if isinstance(v, np.ndarray) and v.ndim==2 and v.shape[1]>=3 and v.shape[0]>=10:
            return v
    return None

def _interp_channel(freq_raw, y_raw, freq_target):
    freq_raw=np.asarray(freq_raw).astype(float)
    y_raw=np.asarray(y_raw).astype(float)
    if freq_raw[0] < freq_raw[-1]:
        freq_raw=freq_raw[::-1]; y_raw=y_raw[::-1]
    uniq, idx = np.unique(freq_raw, return_index=True)
    if len(uniq)!=len(freq_raw):
        order=np.argsort(idx)
        freq_raw=uniq[order]; y_raw=y_raw[idx][order]
    f = interp1d(freq_raw, y_raw, bounds_error=False,
                 fill_value=(y_raw[0], y_raw[-1]), kind="linear")
    return f(freq_target)

FREQ_CANDS=["frequency","freq","f","hz","frequency(hz)","Frequency(Hz)"]
RE_CANDS=["zreal","re(z)","re","real","z_re","zreal(ohm)","re (ohm)","re(z) (ohm)","Zreal","Zreal (ohm)","Zreal(ohm)"]
IM_CANDS=["-zimag","zimag","im(z)","im","imag","imaginary","z_im","zimg","z_imag",
          " -Zimag (ohm)"," -Zimag(ohm)","-Zimag","Zimag","Zimag (ohm)"]

def _select_column(df: pd.DataFrame, cands: List[str]):
    lower={c.lower(): c for c in df.columns}
    for c in cands:
        if c.lower() in lower: return lower[c.lower()]
    for c in cands:
        for col in df.columns:
            if c.lower() in col.lower(): return col
    return None

def load_mat_eis(path: Path):
    mat=loadmat(path); arr=_find_matrix(mat)
    if arr is None: raise ValueError(f"No valid matrix in {path.name}")
    return arr[:,0].astype(float), arr[:,1].astype(float), arr[:,2].astype(float)

def load_table_eis(path: Path):
    if path.suffix.lower()==".csv":
        df=pd.read_csv(path)
    else:
        df=pd.read_excel(path)
    if df.empty: raise ValueError("Empty EIS table.")
    fcol=_select_column(df,FREQ_CANDS)
    recol=_select_column(df,RE_CANDS)
    imcol=_select_column(df,IM_CANDS)
    if recol is None or imcol is None:
        raise ValueError(f"Missing Re/Im columns in {path.name}")
    re_vals = pd.to_numeric(df[recol], errors="coerce").to_numpy()
    im_vals = pd.to_numeric(df[imcol], errors="coerce").to_numpy()
    if fcol:
        freq_vals = pd.to_numeric(df[fcol], errors="coerce").to_numpy()
    else:
        n = min(len(re_vals), len(im_vals))
        freq_vals = np.geomspace(cfg.F_MAX, cfg.F_MIN, n)
    n=min(len(freq_vals), len(re_vals), len(im_vals))
    freq_vals=freq_vals[:n]; re_vals=re_vals[:n]; im_vals=im_vals[:n]
    if np.nanmean(im_vals) > 0: im_vals = -im_vals
    return freq_vals, re_vals, im_vals

def load_any_inference(path: Path):
    suf=path.suffix.lower()
    if suf==".mat": return load_mat_eis(path)
    if suf in (".csv",".xls",".xlsx"): return load_table_eis(path)
    raise ValueError(f"Unsupported extension {suf}")

# ====================================
# 5. FEATURES
# ====================================
def compute_F_features(freq, re_i, im_i):
    neg_im=-im_i
    idx_peak=int(np.argmax(neg_im))
    F1=re_i[0]; F2=re_i[idx_peak]; F3=re_i[-1]
    sc=np.where(np.sign(im_i[:-1])!=np.sign(im_i[1:]))[0]
    if len(sc):
        k=sc[0]; y0,y1=im_i[k],im_i[k+1]
        w=-y0/(y1-y0+1e-12)
        F4=re_i[k]+w*(re_i[k+1]-re_i[k])
    else:
        F4=np.nan
    F5=(re_i[idx_peak]-F1) if idx_peak>0 else np.nan
    F6=np.min(im_i)
    mid_target=10.0
    idx_mid=int(np.argmin(np.abs(freq-mid_target)))
    F7=re_i[idx_mid]
    return [F1,F2,F3,F4,F5,F6,F7]

PHYSICAL_FEATURE_NAMES=[
    "Rs","Rct","tau_peak","warburg_sigma","arc_quality",
    "phase_mean_mid","phase_std_mid","phase_min","lf_slope_negIm","norm_arc"
]

def physical_features(freq, re_i, im_i):
    freq=np.asarray(freq); re_i=np.asarray(re_i); im_i=np.asarray(im_i)
    neg_im=-im_i
    idx_peak=int(np.argmax(neg_im))
    Rs=float(re_i[0]); Rpeak=float(re_i[idx_peak]); Rlow=float(re_i[-1])
    Rct=max(Rpeak-Rs,0.0)
    arc_diam=Rlow-Rs
    norm_arc=arc_diam/(Rs+1e-9)
    f_peak=float(freq[idx_peak])
    tau_peak=1/(2*math.pi*f_peak) if f_peak>0 else np.nan
    K=min(10,len(freq)//3)
    if K>=4:
        w_section=(2*np.pi*freq[-K:])**(-0.5)
        re_section=re_i[-K:]
        warburg_sigma=float(np.polyfit(w_section,re_section,1)[0]) if len(np.unique(w_section))>2 else np.nan
    else:
        warburg_sigma=np.nan
    phase=np.arctan2(-im_i,re_i)
    mid_mask=(freq>=1)&(freq<=100)
    if mid_mask.sum()>2:
        phase_mean_mid=float(phase[mid_mask].mean())
        phase_std_mid=float(phase[mid_mask].std())
    else:
        phase_mean_mid=np.nan; phase_std_mid=np.nan
    phase_min=float(phase.min())
    lf_mask=(freq<=1.0)
    if lf_mask.sum()>=4:
        x=np.log10(freq[lf_mask]+1e-12); y=neg_im[lf_mask]
        lf_slope=np.polyfit(x,y,1)[0]
    else:
        lf_slope=np.nan
    arc_quality=(neg_im.max()-neg_im.min())/(abs(neg_im.mean())+1e-9)
    return [Rs,Rct,tau_peak,warburg_sigma,arc_quality,
            phase_mean_mid,phase_std_mid,phase_min,lf_slope,norm_arc]

BANDS=[(1e4,1e3),(1e3,1e2),(1e2,10),(10,1),(1,1e-1),(1e-1,1e-2)]
def band_stats(freq,re_i,im_i):
    feats=[]; freq=np.asarray(freq)
    for hi,lo in BANDS:
        m=(freq<=hi)&(freq>=lo)
        if m.sum()>1:
            z=np.hypot(re_i[m], im_i[m])
            feats += [z.mean(), z.std()]
        else:
            feats += [np.nan, np.nan]
    return feats

def diff_slopes(freq,re_i,im_i,segments=5):
    logf=np.log10(freq)
    edges=np.linspace(logf.min(),logf.max(),segments+1)
    out=[]
    for i in range(segments):
        m=(logf>=edges[i])&(logf<=edges[i+1])
        if m.sum()>=3:
            x=logf[m]
            out += [np.polyfit(x,re_i[m],1)[0], np.polyfit(x,(-im_i)[m],1)[0]]
        else:
            out += [np.nan, np.nan]
    return out

DRT_FEATURE_NAMES=[
    "drt_sum","drt_mean_logtau","drt_var_logtau","drt_peak_tau",
    "drt_peak_gamma","drt_frac_low_tau","drt_frac_high_tau"
]

def compute_drt(freq,re_i,im_i,tau_min,tau_max,n_tau,lam):
    w=2*np.pi*freq
    tau=np.geomspace(tau_max,tau_min,n_tau)
    WT=w[:,None]*tau[None,:]
    denom=1+WT**2
    K_re=1.0/denom
    K_im=-WT/denom
    R_inf=re_i[0]
    y_re=re_i-R_inf
    y_im=im_i
    Y=np.concatenate([y_re,y_im])
    K=np.vstack([K_re,K_im])
    A=K.T@K + lam*np.eye(n_tau)
    b=K.T@Y
    gamma=linalg.solve(A,b,assume_a='pos')
    gamma=np.clip(gamma,0,None)
    return tau,gamma

def drt_features(freq,re_i,im_i):
    try:
        tau,gamma=compute_drt(freq,re_i,im_i,
                              cfg.DRT_TAU_MIN,cfg.DRT_TAU_MAX,
                              cfg.DRT_POINTS,cfg.DRT_LAMBDA)
        log_tau=np.log10(tau)
        g_sum=gamma.sum()+1e-12
        w_norm=gamma/g_sum
        mean_logtau=float((w_norm*log_tau).sum())
        var_logtau=float((w_norm*(log_tau-mean_logtau)**2).sum())
        p=int(np.argmax(gamma))
        peak_tau=float(tau[p]); peak_gamma=float(gamma[p])
        mid=np.median(log_tau)
        frac_low=float(w_norm[log_tau<=mid].sum())
        frac_high=1-frac_low
        return [g_sum,mean_logtau,var_logtau,peak_tau,peak_gamma,frac_low,frac_high]
    except Exception:
        return [np.nan]*7

def build_feature_vector(re_i, im_i, temp, freq, include_names=False):
    parts=[]; names=[]
    if cfg.INCLUDE_RAW_RE_IM:
        parts += [re_i, im_i]
        names += [f"Re_{i}" for i in range(len(re_i))] + [f"Im_{i}" for i in range(len(im_i))]
    if cfg.INCLUDE_BASICS:
        z=np.hypot(re_i,im_i)
        basics=[re_i[0], re_i[-1], re_i[-1]-re_i[0], z.max(), z.mean(), z.std()]
        parts.append(np.array(basics)); names += ["hf_re","lf_re","arc_diam","zmag_max","zmag_mean","zmag_std"]
    if cfg.INCLUDE_F_FEATS:
        Ff=compute_F_features(freq,re_i,im_i); parts.append(np.array(Ff)); names += [f"F{i}" for i in range(1,8)]
    if cfg.INCLUDE_PHYSICAL:
        Pf=physical_features(freq,re_i,im_i); parts.append(np.array(Pf)); names += PHYSICAL_FEATURE_NAMES
    if cfg.INCLUDE_BAND_STATS:
        Bf=band_stats(freq,re_i,im_i); parts.append(np.array(Bf))
        for bi in range(len(BANDS)): names += [f"band{bi}_mean", f"band{bi}_std"]
    if cfg.INCLUDE_DIFF_SLOPES:
        Ds=diff_slopes(freq,re_i,im_i); parts.append(np.array(Ds))
        for i in range(len(Ds)//2): names += [f"slope_re_seg{i}", f"slope_negIm_seg{i}"]
    if cfg.INCLUDE_DRT:
        Df=drt_features(freq,re_i,im_i); parts.append(np.array(Df)); names += DRT_FEATURE_NAMES
    parts.append(np.array([temp])); names += ["Temp_feat"]  # <-- renamed
    vec=np.concatenate(parts).astype(float)
    vec=np.nan_to_num(vec, nan=0.0, posinf=0.0, neginf=0.0)
    if include_names: return vec, names
    return vec

def shape_normalize(re_i, im_i):
    hf = re_i[0] if re_i[0] != 0 else 1.0
    return re_i / hf, im_i / hf, hf

# ====================================
# 6. CAPACITY / CPP
# ====================================
def load_capacity_info(cap_dir: Path)->pd.DataFrame:
    if not (cap_dir.exists() and cfg.REFINE_SOH_WITH_CAPACITY):
        return pd.DataFrame()
    recs=[]
    for fp in cap_dir.rglob("*.mat"):
        meta=parse_cap_metadata(fp.stem)
        if not meta: continue
        try:
            mat=loadmat(fp); arr=_find_matrix(mat)
            if arr is None: continue
            col=np.argmax(np.abs(arr[-50:, :]).mean(axis=0))
            cap=float(np.nanmax(arr[:,col]))
            meta["MeasuredCapacity_Ah"]=cap
            recs.append(meta)
        except Exception:
            pass
    df=pd.DataFrame(recs)
    if df.empty: return df
    ref=df.groupby("CellID")["MeasuredCapacity_Ah"].transform("max")
    df["NormCapacity"]=df["MeasuredCapacity_Ah"]/ref
    df["SoH_percent"]=df["NormCapacity"]*100.0
    return df

def estimate_cpp(cap_df: pd.DataFrame, window:int, min_points:int):
    cpp={}
    for cid,grp in cap_df.groupby("CellID"):
        g=grp.sort_values("CycleIndex")
        if g.shape[0]<min_points: continue
        tail=g.tail(window)
        x=tail["CycleIndex"].values.astype(float)
        y=tail["SoH_percent"].values.astype(float)
        if len(np.unique(x))<2: continue
        slope=np.polyfit(x,y,1)[0]
        if slope >= -1e-6: continue
        cpp[cid]=1.0/abs(slope)
    return cpp

def build_cpp_map(cap_df: pd.DataFrame):
    if cap_df.empty: return {}, cfg.CPP_FALLBACK
    cpp_map=estimate_cpp(cap_df, cfg.CPP_ROLLING_WINDOW, cfg.CPP_MIN_POINTS)
    if not cpp_map: return {}, cfg.CPP_FALLBACK
    return cpp_map, float(np.median(list(cpp_map.values())))

def get_cpp(meta: dict, cpp_map: Dict[str,float], global_cpp: float):
    if not meta: return global_cpp
    return cpp_map.get(meta.get("CellID"), global_cpp)

# ====================================
# 7. DATASET BUILD
# ====================================
def load_single_mat(fp: Path):
    meta=parse_eis_metadata(fp.stem)
    if meta is None: raise ValueError(f"Pattern mismatch: {fp.name}")
    freq,re_raw,im_raw=load_mat_eis(fp)
    re_i=_interp_channel(freq,re_raw,CANON_FREQ)
    im_i=_interp_channel(freq,im_raw,CANON_FREQ)
    return meta, re_i, im_i

def build_dataset(eis_dir: Path, cap_df: Optional[pd.DataFrame]):
    files=sorted(eis_dir.rglob("*.mat"))
    if not files: raise FileNotFoundError(f"No .mat spectra in {eis_dir}")

    # For feature names
    first=None
    for f in files:
        try:
            m,re_i,im_i=load_single_mat(f)
            first=(m,re_i,im_i); break
        except Exception: continue
    if first is None: raise RuntimeError("No parsable .mat files.")
    _, r0, i0 = first
    _, feature_names = build_feature_vector(r0, i0, 25.0, CANON_FREQ, include_names=True)

    rows=[]; raw_feats=[]; shape_feats=[]
    for fp in tqdm(files, desc="Loading training spectra"):
        try:
            meta, re_i, im_i = load_single_mat(fp)
            raw_vec = build_feature_vector(re_i, im_i, meta["Temp"], CANON_FREQ)
            raw_feats.append(raw_vec); rows.append(meta)
            if cfg.INCLUDE_SHAPE_NORMALIZED_BRANCH and cfg.NORMALIZE_SHAPE_BY_HF_RE:
                rsh, ish, _ = shape_normalize(re_i, im_i)
                shape_vec = build_feature_vector(rsh, ish, meta["Temp"], CANON_FREQ)
                shape_feats.append(shape_vec)
        except Exception as e:
            if cfg.VERBOSE: print(f"[Skip] {fp.name}: {e}")

    if not rows: raise RuntimeError("No usable spectra.")
    X_raw=np.vstack(raw_feats)
    X_shape=np.vstack(shape_feats) if (cfg.INCLUDE_SHAPE_NORMALIZED_BRANCH and shape_feats) else None
    meta_df=pd.DataFrame(rows)

    # SoH refinement
    if cap_df is not None and not cap_df.empty and cfg.REFINE_SOH_WITH_CAPACITY:
        lookup=cap_df.set_index(["CellID","SOH_stage"])["NormCapacity"].to_dict()
        refined=[]
        for cid, stage, fallback in zip(meta_df.CellID, meta_df.SOH_stage, meta_df.RealSOH_file):
            val=lookup.get((cid,stage))
            refined.append(100.0*val if val is not None else fallback)
        meta_df["SoH_cont"]=refined
    else:
        meta_df["SoH_cont"]=meta_df["RealSOH_file"]

    y_soc=meta_df["SOC"].values
    y_soh=meta_df["SoH_cont"].values

    unique_soh=np.unique(y_soh)
    soh_std=float(np.std(y_soh))
    soh_range=(float(np.min(y_soh)), float(np.max(y_soh)))
    if cfg.VERBOSE:
        print(f"[LABELS] SoH unique count={len(unique_soh)} range={soh_range} std={soh_std:.3f}")
        print(f"[LABELS] Unique SoH values (truncated): {unique_soh[:20]}{'...' if len(unique_soh)>20 else ''}")
    if len(unique_soh) < cfg.MIN_UNIQUE_SOH or soh_std < cfg.MIN_REQUIRED_SOH_STD:
        raise RuntimeError(
            f"Insufficient SoH diversity (unique={len(unique_soh)}, std={soh_std:.3f}). "
            f"Add more varied degradation stages."
        )

    if cfg.SAVE_FEATURE_TABLE:
        feat_df = pd.DataFrame(X_raw, columns=feature_names)
        # Defensive duplicate renaming
        dup = set(meta_df.columns).intersection(feat_df.columns)
        if dup:
            rename_map = {c: f"{c}_feat" for c in dup}
            feat_df = feat_df.rename(columns=rename_map)
            if cfg.VERBOSE:
                print(f"[INFO] Renamed duplicate feature columns: {rename_map}")
        pd.concat([meta_df.reset_index(drop=True), feat_df], axis=1)\
          .to_parquet(cfg.MODEL_DIR/"training_features.parquet", index=False)

    return meta_df, X_raw, X_shape, y_soc, y_soh, feature_names

# ====================================
# 8. TRAINING
# ====================================
def split_mask(meta_df: pd.DataFrame):
    cells=meta_df.CellID.unique()
    rng=np.random.default_rng(cfg.RANDOM_STATE)
    n_test=max(1,int(len(cells)*cfg.TEST_FRAC))
    test_cells=rng.choice(cells,size=n_test,replace=False)
    return meta_df.CellID.isin(test_cells)

def train_models(meta_df, X_raw, X_shape, y_soc, y_soh, feature_names):
    mask_test=split_mask(meta_df)

    # SoC
    soc_scaler=StandardScaler()
    X_soc_s=soc_scaler.fit_transform(X_raw)
    soc_pca=None
    X_soc_model=X_soc_s
    if cfg.USE_PCA_SOC:
        soc_pca=PCA(n_components=min(cfg.PCA_SOC_COMPONENTS,X_soc_s.shape[1]-1),
                    random_state=cfg.RANDOM_STATE)
        X_soc_model=soc_pca.fit_transform(X_soc_s)
    soc_model=RandomForestClassifier(
        n_estimators=600, min_samples_leaf=2, class_weight='balanced',
        n_jobs=-1, random_state=cfg.RANDOM_STATE
    )
    soc_model.fit(X_soc_model[~mask_test], y_soc[~mask_test])
    soc_pred=soc_model.predict(X_soc_model[mask_test])
    soc_acc=accuracy_score(y_soc[mask_test], soc_pred)
    soc_f1=f1_score(y_soc[mask_test], soc_pred, average='macro')
    if cfg.VERBOSE:
        print(f"[SoC] Acc={soc_acc:.3f} MacroF1={soc_f1:.3f}")
        print(classification_report(y_soc[mask_test], soc_pred, digits=4))

    # SoH raw
    soh_scaler=StandardScaler()
    X_soh_s=soh_scaler.fit_transform(X_raw)
    soh_pca=None
    X_soh_in=X_soh_s
    if cfg.USE_PCA_SOH:
        soh_pca=PCA(n_components=min(cfg.PCA_SOH_COMPONENTS,X_soh_s.shape[1]-1),
                    random_state=cfg.RANDOM_STATE)
        X_soh_in=soh_pca.fit_transform(X_soh_s)

    dim_raw=X_soh_in.shape[1]
    kernel_raw=RBF(length_scale=np.ones(dim_raw)*3.0,
                   length_scale_bounds=(1e-1,1e4)) + \
               WhiteKernel(noise_level=1e-2,
                           noise_level_bounds=(1e-6,1e-1))
    gpr_raw=GaussianProcessRegressor(
        kernel=kernel_raw, alpha=0.0, normalize_y=True,
        n_restarts_optimizer=3, random_state=cfg.RANDOM_STATE
    )
    if X_soh_in.shape[0] > cfg.MAX_GPR_TRAIN_SAMPLES:
        idx=np.random.default_rng(cfg.RANDOM_STATE).choice(
            X_soh_in.shape[0], size=cfg.MAX_GPR_TRAIN_SAMPLES, replace=False)
        gpr_raw.fit(X_soh_in[idx], y_soh[idx])
    else:
        gpr_raw.fit(X_soh_in, y_soh)
    pred_raw=gpr_raw.predict(X_soh_in[mask_test])
    r2_raw=r2_score(y_soh[mask_test], pred_raw)
    rmse_raw=math.sqrt(mean_squared_error(y_soh[mask_test], pred_raw))

    hgb=HistGradientBoostingRegressor(
        learning_rate=0.05, max_iter=500,
        l2_regularization=1e-3, random_state=cfg.RANDOM_STATE
    )
    hgb.fit(X_soh_in[~mask_test], y_soh[~mask_test])
    pred_hgb=hgb.predict(X_soh_in[mask_test])
    r2_hgb=r2_score(y_soh[mask_test], pred_hgb)
    rmse_hgb=math.sqrt(mean_squared_error(y_soh[mask_test], pred_hgb))

    # Shape branch
    shape_bundle=None
    if cfg.INCLUDE_SHAPE_NORMALIZED_BRANCH and X_shape is not None:
        shape_scaler=StandardScaler()
        X_sh_s=shape_scaler.fit_transform(X_shape)
        shape_pca=None
        X_sh_in=X_sh_s
        if cfg.USE_PCA_SOH:
            shape_pca=PCA(n_components=min(cfg.PCA_SOH_COMPONENTS,X_sh_s.shape[1]-1),
                          random_state=cfg.RANDOM_STATE)
            X_sh_in=shape_pca.fit_transform(X_sh_s)
        dim_sh=X_sh_in.shape[1]
        kernel_sh=RBF(length_scale=np.ones(dim_sh)*3.0,
                      length_scale_bounds=(1e-1,1e4)) + \
                  WhiteKernel(noise_level=1e-2,
                              noise_level_bounds=(1e-6,1e-1))
        gpr_shape=GaussianProcessRegressor(
            kernel=kernel_sh, normalize_y=True, alpha=0.0,
            n_restarts_optimizer=3, random_state=cfg.RANDOM_STATE
        )
        if X_sh_in.shape[0] > cfg.MAX_GPR_TRAIN_SAMPLES:
            idxs=np.random.default_rng(cfg.RANDOM_STATE).choice(
                X_sh_in.shape[0], size=cfg.MAX_GPR_TRAIN_SAMPLES, replace=False)
            gpr_shape.fit(X_sh_in[idxs], y_soh[idxs])
        else:
            gpr_shape.fit(X_sh_in, y_soh)
        pred_sh=gpr_shape.predict(X_sh_in[mask_test])
        r2_sh=r2_score(y_soh[mask_test], pred_sh)
        rmse_sh=math.sqrt(mean_squared_error(y_soh[mask_test], pred_sh))
        shape_bundle={
            "shape_scaler": shape_scaler,
            "shape_pca": shape_pca,
            "shape_model": gpr_shape,
            "shape_metrics": {"r2": r2_sh, "rmse": rmse_sh}
        }
    else:
        r2_sh=rmse_sh=None

    # Select primary
    if r2_raw >= r2_hgb:
        primary_model=gpr_raw; primary_name="gpr_raw"; primary_r2=r2_raw; primary_rmse=rmse_raw
    else:
        primary_model=hgb; primary_name="hgb_raw"; primary_r2=r2_hgb; primary_rmse=rmse_hgb

    if cfg.VERBOSE:
        print(f"[SoH] GPR_raw R2={r2_raw:.3f} RMSE={rmse_raw:.2f}")
        print(f"[SoH] HGB     R2={r2_hgb:.3f} RMSE={rmse_hgb:.2f}")
        if shape_bundle:
            print(f"[SoH] ShapeGP R2={shape_bundle['shape_metrics']['r2']:.3f} "
                  f"RMSE={shape_bundle['shape_metrics']['rmse']:.2f}")
        print(f"[SoH] Selected raw = {primary_name}")

    cov=np.cov(X_soh_s.T)
    try:
        cov_inv=np.linalg.pinv(cov)
    except Exception:
        cov_inv=np.eye(cov.shape[0])
    center=X_soh_s.mean(axis=0)

    bundle={
        "soc_scaler": soc_scaler,
        "soc_pca": soc_pca,
        "soc_model": soc_model,
        "soh_scaler": soh_scaler,
        "soh_pca": soh_pca,
        "soh_model": primary_model,
        "soh_model_name": primary_name,
        "shape_scaler": shape_bundle["shape_scaler"] if shape_bundle else None,
        "shape_pca": shape_bundle["shape_pca"] if shape_bundle else None,
        "shape_model": shape_bundle["shape_model"] if shape_bundle else None,
        "shape_metrics": shape_bundle["shape_metrics"] if shape_bundle else None,
        "freq_grid": CANON_FREQ,
        "feature_version": cfg.FEATURE_VERSION,
        "feature_manifest": feature_names,
        "config": to_jsonable(asdict(cfg)),
        "metrics": {
            "soc_accuracy": soc_acc,
            "soc_macro_f1": soc_f1,
            "soh_primary_r2": primary_r2,
            "soh_primary_rmse": primary_rmse,
            "soh_gpr_r2": r2_raw,
            "soh_hgb_r2": r2_hgb,
            "soh_shape_r2": r2_sh
        },
        "train_mahal": {"center": center.tolist(), "cov_inv": cov_inv.tolist()}
    }
    out=cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib"
    joblib.dump(bundle,out)
    if cfg.VERBOSE:
        print(f"[MODEL] Saved → {out}")
        print(json.dumps(bundle["metrics"], indent=2))
    return bundle

# ====================================
# 9. LOAD
# ====================================
def load_bundle():
    path=cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib"
    b=joblib.load(path)
    required={"soc_scaler","soc_model","soh_scaler","soh_model","freq_grid"}
    if not required.issubset(b.keys()):
        raise RuntimeError("Bundle schema mismatch; retrain required.")
    if b.get("feature_version", -1) < cfg.FEATURE_VERSION:
        warnings.warn("Older feature version detected; consider retraining.")
    return b

# ====================================
# 10. INFERENCE
# ====================================
def mahalanobis_distance(x, center, cov_inv):
    diff=x-center
    return float(np.sqrt(diff @ cov_inv @ diff.T))

def gp_ard_norm(Xp, model):
    try:
        K=model.kernel_
        from sklearn.gaussian_process.kernels import RBF
        rbf=None
        if hasattr(K,"k1") and isinstance(K.k1,RBF): rbf=K.k1
        elif hasattr(K,"k2") and isinstance(K.k2,RBF): rbf=K.k2
        if rbf is None: return None
        ls=np.atleast_1d(rbf.length_scale)
        return float(np.linalg.norm((Xp/ls).ravel()))
    except Exception:
        return None

def featurize_for_inference(path: Path, bundle):
    freq_grid=bundle["freq_grid"]
    meta=parse_eis_metadata(path.stem)
    freq,re_raw,im_raw=load_any_inference(path)
    re_i=_interp_channel(freq,re_raw,freq_grid)
    im_i=_interp_channel(freq,im_raw,freq_grid)
    temp = meta["Temp"] if meta else (cfg.TEST_TEMPERATURE_OVERRIDE if cfg.TEST_TEMPERATURE_OVERRIDE is not None else -1)
    raw_vec=build_feature_vector(re_i, im_i, temp, freq_grid)
    shape_vec=None
    if cfg.INCLUDE_SHAPE_NORMALIZED_BRANCH and bundle.get("shape_model") is not None:
        rsh, ish, _ = shape_normalize(re_i, im_i)
        shape_vec=build_feature_vector(rsh, ish, temp, freq_grid)
    return raw_vec, shape_vec, meta

def ensemble_soh(raw_mean, raw_std, shape_mean, shape_std):
    if shape_mean is None: return raw_mean, raw_std
    if raw_mean is None: return shape_mean, shape_std
    m=0.5*(raw_mean+shape_mean)
    if cfg.ENSEMBLE_STD_MODE.lower()=="rms":
        s=math.sqrt(np.mean([raw_std**2, shape_std**2]))
    else:
        s=0.5*(raw_std+shape_std)
    return m,s

def predict_file(path: Path, bundle, cpp_map, global_cpp):
    raw_vec, shape_vec, meta=featurize_for_inference(path, bundle)

    # SoC
    soc_scaler=bundle["soc_scaler"]; soc_pca=bundle.get("soc_pca")
    soc_model=bundle["soc_model"]
    X_soc=soc_scaler.transform(raw_vec.reshape(1,-1))
    X_soc_in=soc_pca.transform(X_soc) if soc_pca else X_soc
    soc_probs=soc_model.predict_proba(X_soc_in)[0]
    soc_classes=soc_model.classes_
    soc_pred=int(soc_classes[np.argmax(soc_probs)])

    # SoH raw
    soh_scaler=bundle["soh_scaler"]; soh_pca=bundle.get("soh_pca")
    soh_model=bundle["soh_model"]; model_name=bundle.get("soh_model_name")
    X_soh_s=soh_scaler.transform(raw_vec.reshape(1,-1))
    X_soh_in=soh_pca.transform(X_soh_s) if soh_pca else X_soh_s
    if isinstance(soh_model, GaussianProcessRegressor):
        sm, ss=soh_model.predict(X_soh_in, return_std=True)
        raw_mean=float(sm[0]); raw_std=float(ss[0])
    else:
        raw_mean=float(soh_model.predict(X_soh_in)[0])
        raw_std=float(b["metrics"].get("soh_primary_rmse",5.0))

    # Shape branch
    shape_mean=None; shape_std=None
    shape_model=bundle.get("shape_model")
    if shape_model is not None and shape_vec is not None:
        sscaler=bundle.get("shape_scaler"); spca=bundle.get("shape_pca")
        X_shape_s=sscaler.transform(shape_vec.reshape(1,-1))
        X_shape_in=spca.transform(X_shape_s) if spca else X_shape_s
        if isinstance(shape_model, GaussianProcessRegressor):
            sm2, ss2=shape_model.predict(X_shape_in, return_std=True)
            shape_mean=float(sm2[0]); shape_std=float(ss2[0])
        else:
            shape_mean=float(shape_model.predict(X_shape_in)[0])
            shape_std=float(bundle["metrics"].get("soh_primary_rmse",5.0))

    soh_mean, soh_std = ensemble_soh(raw_mean, raw_std, shape_mean, shape_std)

    cpp = get_cpp(meta, cpp_map, global_cpp)
    cycles_to_target = (soh_mean - cfg.DECISION_SOH_PERCENT)*cpp if soh_mean > cfg.DECISION_SOH_PERCENT else 0.0
    cycles_to_lower = (soh_mean - cfg.ILLUSTRATIVE_MIN_SOH)*cpp if soh_mean > cfg.ILLUSTRATIVE_MIN_SOH else 0.0

    train_mahal=bundle.get("train_mahal")
    mahal=None
    if train_mahal:
        center=np.array(train_mahal["center"])
        cov_inv=np.array(train_mahal["cov_inv"])
        mahal=mahalanobis_distance(X_soh_s[0], center, cov_inv)
    ard_norm=None
    if isinstance(soh_model, GaussianProcessRegressor):
        ard_norm=gp_ard_norm(X_soh_in, soh_model)
    ood_flag=False
    if (mahal is not None and mahal>cfg.MAHAL_THRESHOLD) or \
       (ard_norm is not None and ard_norm>cfg.GP_ARD_NORM_THRESHOLD):
        ood_flag=True

    return {
        "file": str(path),
        "parsed_metadata": meta,
        "predicted_SoC": soc_pred,
        "SoC_probabilities": {int(c): float(p) for c,p in zip(soc_classes, soc_probs)},
        "predicted_SoH_percent": soh_mean,
        "SoH_std_estimate": soh_std,
        "raw_model_mean": raw_mean,
        "raw_model_std": raw_std,
        "shape_model_mean": shape_mean,
        "shape_model_std": shape_std,
        "cycles_per_percent_used": cpp,
        "cycles_to_target": cycles_to_target,
        "cycles_to_lower": cycles_to_lower,
        "decision_threshold_percent": cfg.DECISION_SOH_PERCENT,
        "lower_threshold_percent": cfg.ILLUSTRATIVE_MIN_SOH,
        "feature_version": bundle.get("feature_version"),
        "soh_model_chosen": model_name,
        "OOD_mahal": mahal,
        "OOD_gp_ard_norm": ard_norm,
        "OOD_flag": ood_flag
    }

# ====================================
# 11. PROJECTION PLOT
# ====================================
def build_projection(soh_current, cpp, lower, exponent=None, n=160):
    if soh_current <= lower or cpp <= 0: return np.array([0.0]), np.array([soh_current])
    total=(soh_current-lower)*cpp
    cycles=np.linspace(0,total,n)
    exp=exponent if exponent is not None else cfg.PLOT_EXPONENT
    curve=lower + (soh_current-lower)*(1 - cycles/total)**exp
    return cycles, curve

def plot_projection(base, soh_current, soh_std, cyc_target, cyc_lower, cpp, ood_flag, out_path):
    if cyc_lower <= 0: return
    cycles, curve=build_projection(soh_current, cpp, cfg.ILLUSTRATIVE_MIN_SOH)
    plt.figure(figsize=(6.4,4))
    plt.plot(cycles, curve, lw=2, label="Projected SoH")
    plt.axhline(cfg.DECISION_SOH_PERCENT, color="orange", ls="--", label="Decision")
    plt.axhline(cfg.ILLUSTRATIVE_MIN_SOH, color="red", ls=":", label="Lower")
    plt.scatter([0],[soh_current], c="green", s=50)
    plt.text(0, soh_current+0.6, f"{soh_current:.2f}±{soh_std:.2f}", color="green", fontsize=8)
    if cyc_target>0:
        plt.axvline(cyc_target, color="orange", ls="-.")
        plt.text(cyc_target, cfg.DECISION_SOH_PERCENT+1, f"{cyc_target:.0f} cyc",
                 ha="center", color="orange", fontsize=8)
    plt.scatter([cycles[-1]],[cfg.ILLUSTRATIVE_MIN_SOH], c="red", s=45)
    plt.text(cycles[-1], cfg.ILLUSTRATIVE_MIN_SOH-2, f"{cycles[-1]:.0f} cyc",
             ha="center", color="red", fontsize=8)
    if ood_flag:
        plt.text(0.98,0.05,"OOD", transform=plt.gca().transAxes,
                 ha="right", va="bottom", color="crimson",
                 bbox=dict(boxstyle="round", fc="white", ec="crimson"))
    plt.xlabel("Remaining Cycles")
    plt.ylabel("SoH (%)")
    plt.title(f"RUL Projection – {base}")
    plt.grid(alpha=0.35)
    plt.legend(fontsize=8)
    plt.tight_layout()
    plt.savefig(out_path, dpi=140)
    plt.close()

# ====================================
# 12. MAIN
# ====================================
def main():
    if cfg.VERBOSE:
        print("Configuration:\n", json.dumps({k: to_jsonable(v) for k,v in asdict(cfg).items()}, indent=2))

    assert cfg.EIS_DIR.exists(), f"EIS_DIR missing: {cfg.EIS_DIR}"
    cap_df = load_capacity_info(cfg.CAP_DIR)

    bundle_path = cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib"
    need_retrain = cfg.FORCE_RETRAIN or (not bundle_path.exists())
    bundle=None
    if not need_retrain:
        try:
            bundle=load_bundle()
            if cfg.VERBOSE: print("[LOAD] Existing bundle loaded (set FORCE_RETRAIN=True to rebuild).")
        except Exception as e:
            print(f"[LOAD] Failed to load bundle ({e}); retraining.")
            need_retrain=True

    if need_retrain:
        if cfg.VERBOSE: print("[TRAIN] Building dataset & training...")
        try:
            meta_df, X_raw, X_shape, y_soc, y_soh, feature_names = build_dataset(cfg.EIS_DIR, cap_df)
        except RuntimeError as e:
            print(f"[FATAL] {e}")
            return
        if cfg.VERBOSE:
            print(f"[TRAIN] Samples={X_raw.shape[0]} Features={X_raw.shape[1]} Cells={meta_df.CellID.nunique()}")
        bundle=train_models(meta_df, X_raw, X_shape, y_soc, y_soh, feature_names)

    # CPP map
    if cap_df.empty:
        if cfg.VERBOSE: print("[CPP] No capacity data; using fallback.")
        cpp_map, global_cpp = {}, cfg.CPP_FALLBACK
    else:
        cpp_map, global_cpp = build_cpp_map(cap_df)
        if cfg.VERBOSE:
            print(f"[CPP] dynamic cells={len(cpp_map)} global median cpp={global_cpp:.2f}")

    for tf in cfg.EIS_TEST_FILES:
        print(f"\n===== TEST: {tf.name} =====")
        if not tf.exists():
            print(f"[WARN] Missing file: {tf}")
            continue
        try:
            result = predict_file(tf, bundle, cpp_map, global_cpp)
        except Exception as e:
            print(f"[ERROR] Prediction failed: {e}")
            continue
        out_json = cfg.MODEL_DIR / f"{tf.stem}_prediction.json"
        with out_json.open("w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)
        out_plot = cfg.MODEL_DIR / f"{tf.stem}_projection.png"
        plot_projection(
            tf.stem,
            result["predicted_SoH_percent"],
            result["SoH_std_estimate"],
            result["cycles_to_target"],
            result["cycles_to_lower"],
            result["cycles_per_percent_used"],
            result["OOD_flag"],
            out_plot
        )
        print(json.dumps(result, indent=2))
        print(f"[JSON] Saved: {out_json}")
        if result["cycles_to_lower"] > 0:
            print(f"[PLOT] Saved: {out_plot}")
        else:
            print("[PLOT] Skipped (SoH below lower threshold).")

    print("\nDone.")

if __name__ == "__main__":
    main()


Configuration:
 {
  "EIS_DIR": "C:\\Users\\tmgon\\OneDrive - Edith Cowan University\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\EIS_Test",
  "CAP_DIR": "C:\\Users\\tmgon\\OneDrive - Edith Cowan University\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\Capacity_Check",
  "MODEL_DIR": "models_eis_phase2_phys",
  "EIS_TEST_FILES": [
    "Mazda-Battery-Cell1.xlsx",
    "Mazda-Battery-Cell2.xlsx"
  ],
  "F_MIN": 0.01,
  "F_MAX": 10000.0,
  "N_FREQ": 60,
  "TEST_FRAC": 0.2,
  "RANDOM_STATE": 42,
  "USE_PCA_SOC": true,
  "USE_PCA_SOH": false,
  "PCA_SOC_COMPONENTS": 25,
  "PCA_SOH_COMPONENTS": 30,
  "INCLUDE_RAW_RE_IM": true,
  "INCLUDE_BASICS": true,
  "INCLUDE_F_FEATS": true,
  "INCLUDE_PHYSICAL": true,
  "INCLUDE_DRT": true,
  "INCLUDE_BAND_STATS": true,
  "INCLUDE_DIFF_SLOPES": true,
  "INCLUDE_SHAPE_NORMALIZED_BRANCH": true,
  "NORMALIZE_SHAPE_BY_HF_RE": true,
  "DRT_POINTS": 60,
  "DRT_TAU_MIN": 0.0001,
  "DRT_TAU_MAX": 

AssertionError: EIS_DIR missing: C:\Users\tmgon\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\EIS_Test