In [2]:
# ======================================================================
# Unified EIS Training + Inference + Dynamic RUL  (v9 – single file)
# ======================================================================
#   • Accepts ONE EIS test file (cfg.EIS_TEST_FILE or --test path)
#   • Learns a data-driven calibration factor for the GP’s predictive σ
#   • Caps SoH uncertainty to ≈3 percentage-points
#   • Outputs the single most-likely SoC class
#   • Back-compatible with legacy model bundles
# ======================================================================

from __future__ import annotations
import sys, argparse, json, math, random, re, warnings, joblib
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.io import loadmat
from scipy import linalg
from scipy.interpolate import interp1d

from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, r2_score
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# =========================
# 1. CONFIGURATION
# =========================
@dataclass
class Config:
    """All tunable settings for the EIS SoC / SoH / RUL pipeline.

    Normally only the four path fields need editing; ``MODEL_DIR`` is relative
    to the current working directory.  Declaration order is kept stable because
    ``asdict`` (used for the configuration dump in ``main``) follows it.
    """

    # ---- local data / model locations (edit for your machine) ----
    EIS_DIR: Path = Path(r"C:\Users\tmgon\OneDrive - Edith Cowan University (1)\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\EIS_Test")
    CAP_DIR: Path = Path(r"C:\Users\tmgon\OneDrive - Edith Cowan University (1)\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\Capacity_Check")
    MODEL_DIR: Path = Path("models_eis_phase2_phys")
    EIS_TEST_FILE: Path = Path(r"C:\Users\tmgon\OneDrive - Edith Cowan University (1)\00 - Megallan Power\NMC Batteries Warwick Station\NMC\TestFile\Mazda-Battery-Cell5.xlsx")

    # ---- canonical spectrum grid (N_FREQ log-spaced points, F_MAX -> F_MIN) ----
    F_MIN: float = 1e-2
    F_MAX: float = 1e4
    N_FREQ: int = 60

    # ---- cell-wise validation split and global seed ----
    TEST_FRAC: float = 0.20
    RANDOM_STATE: int = 42

    # ---- optional PCA in front of the SoC / SoH models ----
    USE_PCA_SOC: bool = True
    PCA_SOC_COMPONENTS: int = 25
    USE_PCA_SOH: bool = False
    PCA_SOH_COMPONENTS: int = 30

    # ---- feature-group switches ----
    INCLUDE_RAW_RE_IM: bool = True
    INCLUDE_BASICS: bool = True
    INCLUDE_F_FEATS: bool = True
    INCLUDE_PHYSICAL: bool = True
    INCLUDE_DRT: bool = True
    INCLUDE_BAND_STATS: bool = True
    INCLUDE_DIFF_SLOPES: bool = True

    # ---- distribution of relaxation times (DRT) ----
    DRT_POINTS: int = 60
    DRT_TAU_MIN: float = 1e-4
    DRT_TAU_MAX: float = 1e4
    DRT_LAMBDA: float = 1e-2          # Tikhonov regularisation weight

    # ---- SoH labels, SoH models and RUL extrapolation ----
    REFINE_SOH_WITH_CAPACITY: bool = True
    MAX_GPR_TRAIN_SAMPLES: int = 3500
    INCLUDE_NORMALIZED_SHAPE_MODEL: bool = True
    NORMALIZE_SHAPE_BY_HF_RE: bool = True
    ENSEMBLE_SOH: bool = True
    DECISION_SOH_PERCENT: float = 50.0
    ILLUSTRATIVE_MIN_SOH: float = 40.0
    CPP_ROLLING_WINDOW: int = 5
    CPP_MIN_POINTS: int = 6
    CPP_FALLBACK: float = 20.0        # cycles per SoH %-point when none can be estimated

    # ---- misc ----
    TEST_TEMPERATURE_OVERRIDE: Optional[float] = 25.0   # used when the test file name has no metadata
    FORCE_RETRAIN: bool = False       # retrain even if a saved bundle exists (a missing bundle always triggers training)
    SAVE_FEATURE_TABLE: bool = True
    VERBOSE: bool = True
    FEATURE_VERSION: int = 9
    PLOT_EXPONENT: float = 1.25

# Single shared configuration instance; the functions below read settings from it.
cfg = Config()
# Create the model/output directory up front so later joblib/parquet/PNG/JSON writes succeed.
cfg.MODEL_DIR.mkdir(parents=True, exist_ok=True)

# =========================
# 2. UTILITIES
# =========================
def set_seed(seed: int):
    """Seed both Python's ``random`` module and NumPy's legacy global RNG with *seed*."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
# Seed the global RNGs once at import so any use of random / np.random is reproducible.
set_seed(cfg.RANDOM_STATE)

def to_jsonable(x):
    """Recursively turn ``Path`` objects into strings so *x* can go through ``json.dumps``.

    Dict values and list/tuple items are converted recursively (tuples come back as
    lists); every other object is returned unchanged.
    """
    if isinstance(x, dict):
        return {key: to_jsonable(val) for key, val in x.items()}
    if isinstance(x, (list, tuple)):
        return list(map(to_jsonable, x))
    return str(x) if isinstance(x, Path) else x

# Canonical frequency grid: N_FREQ log-spaced points from F_MAX down to F_MIN.
# Every spectrum is interpolated onto it, so index 0 is the high-frequency end
# and index -1 the low-frequency end in all feature code below.
CANON_FREQ = np.geomspace(cfg.F_MAX, cfg.F_MIN, cfg.N_FREQ)

# =========================
# 3. REGEX & METADATA PARSERS
# =========================
# EIS file stems: Cell<id>_<stage>SOH_<T>degC_<soc>SOC_<digits>, where <stage> is one of
# 80/85/90/95/100.  Used with .search(), so extra prefix/suffix text in the stem is allowed.
EIS_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_(?P<Temp>\d+)degC_(?P<SOC>\d+)SOC_(?P<RealSOH>\d+)"
)
# Capacity-check file stems: Cell<id>_<stage>SOH_Capacity_Check_<T>degC_<n>cycle.
CAP_META_PATTERN = re.compile(
    r"Cell(?P<CellID>\d+)_(?P<SOH>80|85|90|95|100)SOH_Capacity_Check_(?P<Temp>\d+)degC_(?P<Cycle>\d+)cycle"
)

def parse_eis_metadata(stem:str)->Optional[Dict[str,Any]]:
    """Extract cell/label metadata from an EIS file stem, or return None if it doesn't match.

    Keys: CellID ("Cell<id>"), SOH_stage, SOC, Temp (ints) and RealSOH_file
    (the trailing digit group divided by 100).
    """
    match=EIS_META_PATTERN.search(stem)
    if match is None:
        return None
    cell,stage,temp,soc,real=(match.group(key) for key in ("CellID","SOH","Temp","SOC","RealSOH"))
    return {
        "CellID":f"Cell{cell}",
        "SOH_stage":int(stage),
        "SOC":int(soc),
        "Temp":int(temp),
        "RealSOH_file":int(real)/100.0,
    }

def parse_cap_metadata(stem:str)->Optional[Dict[str,Any]]:
    """Extract metadata from a capacity-check file stem, or return None if it doesn't match.

    Keys: CellID ("Cell<id>"), SOH_stage, Temp and CycleIndex (all ints except CellID).
    """
    match=CAP_META_PATTERN.search(stem)
    if match is None:
        return None
    fields=match.groupdict()
    return {
        "CellID":"Cell"+fields["CellID"],
        "SOH_stage":int(fields["SOH"]),
        "Temp":int(fields["Temp"]),
        "CycleIndex":int(fields["Cycle"]),
    }

# =========================
# 4. LOW-LEVEL LOADERS
# =========================
def _find_matrix(mat_dict:dict):
    """Return the first 2-D ndarray with at least 3 columns among *mat_dict*'s values, else None.

    Intended for the dict returned by ``scipy.io.loadmat``; header entries are skipped
    because they are not 2-D arrays.
    """
    return next(
        (val for val in mat_dict.values()
         if isinstance(val,np.ndarray) and val.ndim==2 and val.shape[1]>=3),
        None,
    )

def _interp_channel(freq_raw,y_raw,freq_target):
    """Linearly resample one impedance channel onto ``freq_target``.

    Outside the measured band the nearest measured value is held constant:
    targets above the highest measured frequency get the value measured at the
    highest frequency, targets below the lowest get the lowest-frequency value.
    (The previous interp1d version applied ``fill_value=(y[0], y[-1])`` from the
    descending-ordered arrays to the *ascending* sorted axis, swapping the two.)

    Non-finite (freq, y) pairs are dropped.  For duplicate frequencies the first
    occurrence in descending-frequency order is kept, as before.

    Returns:
        ndarray of interpolated values, one per entry of ``freq_target``.
    Raises:
        ValueError: no finite (freq, y) pair is left to interpolate from.
    """
    freq=np.asarray(freq_raw,float).ravel(); y=np.asarray(y_raw,float).ravel()
    ok=np.isfinite(freq)&np.isfinite(y)
    freq,y=freq[ok],y[ok]
    if freq.size==0:
        raise ValueError("No finite (frequency, value) pairs to interpolate")
    if freq.size>1 and freq[0]<freq[-1]:
        # present data high->low so np.unique's "first occurrence" matches the old dedupe rule
        freq,y=freq[::-1],y[::-1]
    # uniq is strictly ascending (required by np.interp); y[first] is aligned with it
    uniq,first=np.unique(freq,return_index=True)
    return np.interp(freq_target,uniq,y[first])

# Column-name candidates for tabular (CSV/Excel) EIS exports, tried in order by
# _select_column.  "-zimag" names a column that already holds -Im(Z); load_table_eis
# normalises the sign afterwards either way.
FREQ_CANDS=["frequency","freq","f","hz","frequency(hz)","Frequency(Hz)"]
RE_CANDS  =["zreal","re(z)","re","real","z_re","Zreal","Re (ohm)","Zreal (ohm)"]
IM_CANDS  =["-zimag","zimag","im","imag","z_im","Zimag","Zimag (ohm)"]

def _select_column(df:pd.DataFrame,cands:List[str])->Optional[str]:
    """Return the column of *df* that best matches one of *cands*, or None.

    Pass 1: case-insensitive exact match, trying candidates in order (for
    duplicate lower-cased labels the last column wins, as before).
    Pass 2: fuzzy match, candidates in order, columns in order.  Candidates of one
    or two characters ("f", "re", "im", "hz") must be the *prefix of a word* in the
    column label.  Words are split on non-alphanumerics and CamelCase, so "ReZ",
    "Z_Re" and "-Im(Z)/Ohm" match.  A plain substring test would let "re" hit
    "Frequency(Hz)" and "im" hit "Time (s)".  Longer candidates keep the substring
    test.
    """
    labels=list(df.columns)
    lowered=[str(col).lower() for col in labels]
    exact=dict(zip(lowered,labels))
    for cand in cands:
        key=cand.lower()
        if key in exact: return exact[key]

    def _words(label)->List[str]:
        # CamelCase-aware split: "ReZ" -> ["re","z"], "Frequency(Hz)" -> ["frequency","hz"]
        return [w.lower() for w in re.findall(r"[A-Z]+(?![a-z])|[A-Z]?[a-z]+|\d+",str(label))]

    word_lists=[_words(col) for col in labels]
    for cand in cands:
        key=cand.lower()
        for col,low,words in zip(labels,lowered,word_lists):
            if len(key)<=2:
                if any(w.startswith(key) for w in words): return col
            elif key in low:
                return col
    return None

def load_mat_eis(path:Path):
    """Load (freq, Re, Im) from the first three columns of the EIS matrix in a .mat file.

    Raises:
        ValueError: the file contains no 2-D array with at least three columns.
    """
    matrix=_find_matrix(loadmat(path))
    if matrix is None:
        raise ValueError("No EIS matrix")
    freq,z_re,z_im=(matrix[:,col].astype(float) for col in range(3))
    return freq,z_re,z_im

def load_table_eis(path:Path):
    """Load (freq, Re, Im) from a CSV or Excel EIS export.

    Columns are located with _select_column.  If no frequency column is found, a
    log-spaced grid from cfg.F_MAX to cfg.F_MIN with one point per row is assumed.
    Rows whose frequency, Re or Im is non-numeric/non-finite are dropped (e.g. a
    units row under the header).  Im is flipped so that its mean is non-positive.

    Raises:
        ValueError: empty table, Re/Im columns not found, or no numeric rows left.
    """
    df=pd.read_csv(path) if path.suffix.lower()==".csv" else pd.read_excel(path)
    if df.empty: raise ValueError("Empty table")
    fcol=_select_column(df,FREQ_CANDS)
    recol=_select_column(df,RE_CANDS); imcol=_select_column(df,IM_CANDS)
    if recol is None or imcol is None: raise ValueError("Re/Im columns missing")
    re_vals=pd.to_numeric(df[recol],errors="coerce").to_numpy(dtype=float)
    im_vals=pd.to_numeric(df[imcol],errors="coerce").to_numpy(dtype=float)
    if fcol is not None:
        freq_vals=pd.to_numeric(df[fcol],errors="coerce").to_numpy(dtype=float)
    else:
        # grid is assigned per original row position, before any row is dropped
        freq_vals=np.geomspace(cfg.F_MAX,cfg.F_MIN,len(re_vals))
    keep=np.isfinite(freq_vals)&np.isfinite(re_vals)&np.isfinite(im_vals)
    if not keep.any(): raise ValueError("No numeric EIS rows")
    freq_vals,re_vals,im_vals=freq_vals[keep],re_vals[keep],im_vals[keep]
    # Training spectra use the negative-Im convention; flip exports that store -Im(Z).
    if np.mean(im_vals)>0: im_vals=-im_vals
    return freq_vals,re_vals,im_vals

def load_any_inference(path:Path):
    """Load (freq, Re, Im) from a .mat, .csv, .xls or .xlsx EIS file, picked by extension.

    Raises:
        ValueError: the extension is not one of the supported ones.
    """
    ext=path.suffix.lower()
    if ext in (".csv",".xls",".xlsx"):
        return load_table_eis(path)
    if ext==".mat":
        return load_mat_eis(path)
    raise ValueError(f"Unsupported ext {ext}")

# =========================
# 5. FEATURE ENGINEERING
# =========================
def compute_F_features(freq,re_i,im_i):
    """Return the seven hand-crafted features [F1..F7] of one spectrum.

    F1/F3: Re at the first/last grid point; F2: Re at the largest -Im (arc apex);
    F4: Re where Im first changes sign (linear interpolation, NaN if never);
    F5: F2 - F1 (NaN if the apex is the first point); F6: min(Im);
    F7: Re at the grid point closest to freq == 10.
    """
    apex=int(np.argmax(-im_i))
    first_re,apex_re,last_re=re_i[0],re_i[apex],re_i[-1]
    crossings=np.where(np.sign(im_i[:-1])!=np.sign(im_i[1:]))[0]
    if len(crossings):
        j=crossings[0]
        y0,y1=im_i[j],im_i[j+1]
        frac=-y0/(y1-y0+1e-12)
        zero_cross_re=re_i[j]+frac*(re_i[j+1]-re_i[j])
    else:
        zero_cross_re=np.nan
    arc_width=apex_re-first_re if apex>0 else np.nan
    im_min=np.min(im_i)
    re_near_10=re_i[int(np.argmin(np.abs(freq-10)))]
    return [first_re,apex_re,last_re,zero_cross_re,arc_width,im_min,re_near_10]

# Names of the 10 values returned by physical_features, in the same order.
PHYSICAL_FEATURE_NAMES=["Rs","Rct","tau_peak","warburg_sigma","arc_quality",
                        "phase_mean_mid","phase_std_mid","phase_min","lf_slope_negIm","norm_arc"]

def physical_features(freq,re_i,im_i):
    """Return the 10 physics-inspired features named in PHYSICAL_FEATURE_NAMES.

    Assumes ``freq`` is ordered high -> low (as CANON_FREQ is), so index 0 is the
    high-frequency end and index -1 the low-frequency end.  Features that cannot
    be computed for the given grid are returned as NaN.
    """
    neg=-im_i; k=int(np.argmax(neg))  # k: index of the arc apex (largest -Im)
    Rs,Rpeak,Rlow=re_i[0],re_i[k],re_i[-1]
    Rct=max(Rpeak-Rs,0); arc_diam=Rlow-Rs; norm_arc=arc_diam/(Rs+1e-9)  # 1e-9 guards Rs == 0
    tau_peak=1/(2*math.pi*freq[k]) if freq[k]>0 else np.nan  # 1/(2*pi*f) at the apex
    K=min(10,len(freq)//3); warburg_sigma=np.nan
    if K>=4:
        # slope of Re versus omega^-1/2 over the last K (lowest-frequency) points
        ws=(2*np.pi*freq[-K:])**-0.5; rs=re_i[-K:]
        if len(np.unique(ws))>2: warburg_sigma=float(np.polyfit(ws,rs,1)[0])
    ph=np.arctan2(-im_i,re_i); mid=(freq>=1)&(freq<=100)  # phase in radians; mid band 1..100 (presumably Hz)
    phase_mean=ph[mid].mean() if mid.sum()>2 else np.nan
    phase_std =ph[mid].std()  if mid.sum()>2 else np.nan
    lf=(freq<=1); lf_slope=np.nan
    if lf.sum()>=4:
        # slope of -Im versus log10(freq) over the freq <= 1 points
        x=np.log10(freq[lf]); y=neg[lf]
        lf_slope=float(np.polyfit(x,y,1)[0])
    arc_q=(neg.max()-neg.min())/(abs(neg.mean())+1e-9)  # range of -Im relative to its mean magnitude
    return [Rs,Rct,tau_peak,warburg_sigma,arc_q,
            phase_mean,phase_std,float(ph.min()),lf_slope,norm_arc]

# (upper, lower) frequency limits of the six decade bands used by band_stats, high -> low.
BANDS=[(1e4,1e3),(1e3,1e2),(1e2,10),(10,1),(1,1e-1),(1e-1,1e-2)]
def band_stats(freq,re_i,im_i):
    """Return [mean, std] of |Z| for every band in BANDS (12 values in total).

    A band is inclusive of both limits; bands covering fewer than two grid
    points contribute [NaN, NaN].
    """
    freq=np.asarray(freq)
    stats=[]
    for upper,lower in BANDS:
        in_band=(freq<=upper)&(freq>=lower)
        if in_band.sum()<=1:
            stats.extend([np.nan,np.nan])
            continue
        magnitude=np.hypot(re_i[in_band],im_i[in_band])
        stats.extend([magnitude.mean(),magnitude.std()])
    return stats

def diff_slopes(freq,re_i,im_i,segments=5):
    """Return per-segment linear slopes of Re and -Im versus log10(freq).

    The log-frequency range is cut into *segments* equal-width pieces (edges
    inclusive).  Each piece contributes [slope_Re, slope_-Im], or [NaN, NaN]
    when it holds fewer than three points, giving ``2 * segments`` values.
    """
    log_f=np.log10(freq)
    bounds=np.linspace(log_f.min(),log_f.max(),segments+1)
    slopes=[]
    for lo,hi in zip(bounds[:-1],bounds[1:]):
        sel=(log_f>=lo)&(log_f<=hi)
        if sel.sum()<3:
            slopes+=[np.nan,np.nan]
            continue
        slopes.append(np.polyfit(log_f[sel],re_i[sel],1)[0])
        slopes.append(np.polyfit(log_f[sel],-im_i[sel],1)[0])
    return slopes

# Names of the 7 values returned by drt_features, in the same order.
DRT_FEATURE_NAMES=["drt_sum","drt_mean_logtau","drt_var_logtau","drt_peak_tau",
                   "drt_peak_gamma","drt_frac_low_tau","drt_frac_high_tau"]

def compute_drt(freq,re_i,im_i,tmin,tmax,n,lam):
    """Estimate a distribution of relaxation times by Tikhonov-regularised least squares.

    Fits weights ``g`` on ``n`` log-spaced time constants (tmax -> tmin) so that RC
    kernels reproduce ``Re - Re[0]`` and ``Im`` jointly, by solving
    ``(K^T K + lam * I) g = K^T Y``; negative weights are clipped to 0.

    Returns:
        (tau, gamma): the tau grid and the clipped weights, both of length ``n``.
    """
    w=2*np.pi*freq; tau=np.geomspace(tmax,tmin,n); WT=w[:,None]*tau[None,:]  # omega*tau, shape (len(freq), n)
    Kre=1/(1+WT**2); Kim=-WT/(1+WT**2)  # real / imaginary parts of 1/(1 + j*omega*tau)
    yre=re_i-re_i[0]; yim=im_i  # Re offset by its first (highest-frequency) sample
    Y=np.concatenate([yre,yim]); K=np.vstack([Kre,Kim])
    A=K.T@K+lam*np.eye(n); b=K.T@Y
    g=linalg.solve(A,b,assume_a='pos')  # A is symmetric positive definite for lam > 0
    return tau,np.clip(g,0,None)

def drt_features(freq,re_i,im_i):
    """Summarise the DRT of one spectrum as the 7 values named in DRT_FEATURE_NAMES.

    Weights are normalised to sum to 1 and then give a weighted mean and variance of
    log10(tau), the tau/gamma at the largest weight, and the weight fraction at
    log10(tau) <= median (plus its complement).  Any failure (e.g. a singular
    solve) is deliberately swallowed and yields seven NaNs.
    """
    try:
        tau,gamma=compute_drt(freq,re_i,im_i,
                              cfg.DRT_TAU_MIN,cfg.DRT_TAU_MAX,cfg.DRT_POINTS,cfg.DRT_LAMBDA)
        logt=np.log10(tau); s=gamma.sum()+1e-12; w=gamma/s  # 1e-12 avoids /0 for an all-zero DRT
        mean=(w*logt).sum(); var=(w*(logt-mean)**2).sum()
        p=int(np.argmax(gamma))
        frac_low=w[logt<=np.median(logt)].sum()
        return [s,mean,var,float(tau[p]),float(gamma[p]),frac_low,1-frac_low]
    except Exception:
        return [np.nan]*7

def build_feature_vector(re_i,im_i,temp,freq,include_names=False):
    """Concatenate the enabled feature groups of one spectrum into a flat float vector.

    Group order, each switched by a cfg.INCLUDE_* flag: raw Re/Im, basics, F1-F7,
    physical, band stats, segment slopes, DRT.  The temperature is always appended
    last.  Non-finite entries (NaN, +inf, -inf) are replaced by 0.

    Returns:
        ``vec``, or ``(vec, names)`` when *include_names* is True.
    """
    parts=[]; names=[]
    if cfg.INCLUDE_RAW_RE_IM:
        parts+=[re_i,im_i]
        names+=[f"Re_{i}" for i in range(len(re_i))]+[f"Im_{i}" for i in range(len(im_i))]
    if cfg.INCLUDE_BASICS:
        z=np.hypot(re_i,im_i)
        basics=[re_i[0],re_i[-1],re_i[-1]-re_i[0],z.max(),z.mean(),z.std()]
        parts.append(basics); names+=["hf_re","lf_re","arc_diam","zmag_max","zmag_mean","zmag_std"]
    if cfg.INCLUDE_F_FEATS:
        parts.append(compute_F_features(freq,re_i,im_i)); names+=[f"F{i}" for i in range(1,8)]
    if cfg.INCLUDE_PHYSICAL:
        parts.append(physical_features(freq,re_i,im_i)); names+=PHYSICAL_FEATURE_NAMES
    if cfg.INCLUDE_BAND_STATS:
        parts.append(band_stats(freq,re_i,im_i))
        for i in range(len(BANDS)): names+=[f"band{i}_mean",f"band{i}_std"]
    if cfg.INCLUDE_DIFF_SLOPES:
        ds=diff_slopes(freq,re_i,im_i); parts.append(ds)
        for i in range(len(ds)//2): names+=[f"slope_re_seg{i}",f"slope_negIm_seg{i}"]
    if cfg.INCLUDE_DRT:
        parts.append(drt_features(freq,re_i,im_i)); names+=DRT_FEATURE_NAMES
    parts.append([temp]); names+=["Temp_feat"]
    vec=np.concatenate(parts).astype(float)
    # Keywords matter: the old positional nan_to_num(vec,0,0,0) set copy/nan/posinf
    # only, leaving neginf at its default (-1.8e308) instead of 0.
    vec=np.nan_to_num(vec,nan=0.0,posinf=0.0,neginf=0.0)
    return (vec,names) if include_names else vec

def build_shape_normalized(re_i,im_i):
    """Divide Re and Im by the first Re sample (the high-frequency end on CANON_FREQ).

    A first sample of exactly 0 is replaced by 1.0, so such spectra pass through unscaled.
    """
    scale=1.0 if re_i[0]==0 else re_i[0]
    return re_i/scale, im_i/scale

# =========================
# 6. CAPACITY  →  CPP HELPERS
# =========================
def load_capacity_info(cap_dir:Path)->pd.DataFrame:
    """Collect one measured-capacity value per capacity-check .mat file under *cap_dir*.

    Returns an empty frame when *cap_dir* is missing, capacity refinement is
    disabled, or no file could be read.  Otherwise each row holds the
    parse_cap_metadata fields and:

    * MeasuredCapacity_Ah: the max of the column with the largest mean |value| over
      the last 50 rows (a heuristic for the capacity column; units assumed Ah).
    * NormCapacity: that value divided by the cell's maximum.
    * SoH_percent: 100 * NormCapacity.

    Unreadable files are skipped and reported when cfg.VERBOSE is set.
    """
    if not (cap_dir.exists() and cfg.REFINE_SOH_WITH_CAPACITY):
        return pd.DataFrame()
    recs=[]
    # sorted() makes row order, and hence later dict lookups with duplicate keys, deterministic
    for fp in sorted(cap_dir.rglob("*.mat")):
        meta=parse_cap_metadata(fp.stem)
        if not meta: continue
        try:
            arr=_find_matrix(loadmat(fp))
            if arr is None: raise ValueError("no 2-D matrix with >=3 columns")
            col=int(np.nanargmax(np.nanmean(np.abs(arr[-50:]),axis=0)))
            cap=float(np.nanmax(arr[:,col]))
            if not np.isfinite(cap): raise ValueError("capacity column has no finite values")
        except Exception as e:
            # best-effort: one bad file must not stop the scan, but don't hide it either
            if cfg.VERBOSE: print(f"[Skip capacity] {fp.name}: {e}")
            continue
        meta["MeasuredCapacity_Ah"]=cap
        recs.append(meta)
    if not recs: return pd.DataFrame()
    df=pd.DataFrame(recs)
    ref=df.groupby("CellID")["MeasuredCapacity_Ah"].transform("max")
    df["NormCapacity"]=df["MeasuredCapacity_Ah"]/ref
    df["SoH_percent"]=df["NormCapacity"]*100.0
    return df

def estimate_cpp_per_cell(cap_df:pd.DataFrame,window:int,min_pts:int)->Dict[str,float]:
    """Estimate cycles-per-SoH-percentage-point for each cell from its capacity history.

    For every cell with at least *min_pts* rows, a line is fitted to SoH_percent
    versus CycleIndex over the last *window* cycles.  Cells whose fitted slope is
    not clearly negative (>= -1e-6), or whose window has fewer than two distinct
    cycle indices, are left out.  Otherwise the value is ``1 / |slope|``.
    """
    result:Dict[str,float]={}
    for cell_id,cell_rows in cap_df.groupby("CellID"):
        ordered=cell_rows.sort_values("CycleIndex")
        if len(ordered)<min_pts:
            continue
        recent=ordered.tail(window)
        cycles=recent["CycleIndex"].values
        soh=recent["SoH_percent"].values
        if np.unique(cycles).size<2:
            continue
        slope=np.polyfit(cycles,soh,1)[0]
        if slope>=-1e-6:
            continue
        result[cell_id]=1/abs(slope)
    return result

def build_cpp_map(df:pd.DataFrame):
    """Return ``(per_cell_cpp, fleet_cpp)`` from the capacity table *df*.

    ``fleet_cpp`` is the median of the per-cell estimates, or cfg.CPP_FALLBACK when
    *df* is empty or no cell yields an estimate.
    """
    if df.empty:
        return {},cfg.CPP_FALLBACK
    per_cell=estimate_cpp_per_cell(df,cfg.CPP_ROLLING_WINDOW,cfg.CPP_MIN_POINTS)
    fleet=float(np.median(list(per_cell.values()))) if per_cell else cfg.CPP_FALLBACK
    return per_cell,fleet

def get_cpp(meta:dict,cpp_map:Dict[str,float],global_cpp:float)->float:
    """Return the cell-specific cycles-per-percent for *meta*'s CellID, else *global_cpp*.

    Falsy *meta* (None or an empty dict) always returns *global_cpp*.
    """
    return cpp_map.get(meta["CellID"],global_cpp) if meta else global_cpp

# =========================
# 7. DATASET HELPERS
# =========================
def cell_split_mask(meta_df:pd.DataFrame,test_frac:Optional[float]=None,seed:Optional[int]=None):
    """Return a boolean Series marking the rows of randomly chosen validation cells.

    ``max(1, int(n_cells * test_frac))`` whole cells go to validation, capped at
    ``n_cells - 1`` so at least one cell is always left for training.  With a single
    cell nothing is held out (all False).  *test_frac* and *seed* default to
    cfg.TEST_FRAC and cfg.RANDOM_STATE.
    """
    frac=cfg.TEST_FRAC if test_frac is None else test_frac
    rs=cfg.RANDOM_STATE if seed is None else seed
    cells=meta_df.CellID.unique()
    if len(cells)<2:
        # previously the only cell became validation, leaving nothing to train on
        return pd.Series(False,index=meta_df.index,name="CellID",dtype=bool)
    rng=np.random.default_rng(rs)
    n=min(max(1,int(len(cells)*frac)),len(cells)-1)
    val=rng.choice(cells,size=n,replace=False)
    return meta_df.CellID.isin(val)

def build_dataset(eis_dir:Path,cap_df:Optional[pd.DataFrame]):
    """Featurise every labelled EIS ``.mat`` file under *eis_dir* for training.

    Files whose stem does not match EIS_META_PATTERN are ignored.  Files that fail
    to load or featurise are skipped, and reported when cfg.VERBOSE is set.  A
    file's rows are added to X, X_shape and meta_df together, so the three stay
    row-aligned.

    SoH label (percent): ``100 * NormCapacity`` from *cap_df* for the file's
    (CellID, SOH_stage) when capacity refinement is on and a match exists.
    Otherwise the SoH encoded in the file name is used, normalised to percent
    (see ``_as_percent``).

    Returns:
        (meta_df, X, (X_shape or None, feat_names), y_soc, y_soh)
    Raises:
        FileNotFoundError: no ``.mat`` file under *eis_dir*.
        RuntimeError: no file produced a valid spectrum.
    """
    files=sorted(eis_dir.rglob("*.mat"))
    if not files: raise FileNotFoundError("No .mat in EIS_DIR")

    want_shape=cfg.INCLUDE_NORMALIZED_SHAPE_MODEL and cfg.NORMALIZE_SHAPE_BY_HF_RE
    feats,rows,shape_feats=[],[],[]
    feat_names:Optional[List[str]]=None
    for fp in tqdm(files,desc="Loading"):
        meta=parse_eis_metadata(fp.stem)
        if meta is None: continue            # unlabelled file - nothing to train on
        try:
            f,r,i=load_mat_eis(fp)
            re_i=_interp_channel(f,r,CANON_FREQ)
            im_i=_interp_channel(f,i,CANON_FREQ)
            vec,names=build_feature_vector(re_i,im_i,meta["Temp"],CANON_FREQ,include_names=True)
            shape_vec=None
            if want_shape:
                rn,in_=build_shape_normalized(re_i,im_i)
                shape_vec=build_feature_vector(rn,in_,meta["Temp"],CANON_FREQ)
        except Exception as e:
            if cfg.VERBOSE: print(f"[Skip] {fp.name}: {e}")
            continue
        # append only after every feature of this file succeeded -> X / X_shape / meta_df aligned
        feats.append(vec); rows.append(meta)
        if shape_vec is not None: shape_feats.append(shape_vec)
        if feat_names is None: feat_names=names

    if not rows: raise RuntimeError("No valid spectra")

    X=np.vstack(feats)
    X_shape=np.vstack(shape_feats) if shape_feats else None
    meta_df=pd.DataFrame(rows)

    def _as_percent(v:float)->float:
        # RealSOH_file is int(<digits>)/100, which is a fraction (e.g. 0.95) or a percent
        # (e.g. 95.05) depending on how the file name encodes it; the two old branches
        # disagreed by x100.  Values <= 1.5 are treated as fractions.
        return 100.0*v if v<=1.5 else float(v)

    use_cap=cfg.REFINE_SOH_WITH_CAPACITY and cap_df is not None and not cap_df.empty
    if use_cap:
        lookup=cap_df.set_index(["CellID","SOH_stage"])["NormCapacity"].to_dict()
        meta_df["SoH_cont"]=[
            100.0*lookup[(cid,stg)] if (cid,stg) in lookup else _as_percent(fallback)
            for cid,stg,fallback in zip(meta_df.CellID,meta_df.SOH_stage,meta_df.RealSOH_file)
        ]
    else:
        meta_df["SoH_cont"]=[_as_percent(v) for v in meta_df.RealSOH_file]

    y_soc=meta_df["SOC"].values
    y_soh=meta_df["SoH_cont"].values

    if cfg.SAVE_FEATURE_TABLE:
        table=pd.concat([meta_df.reset_index(drop=True),
                         pd.DataFrame(X,columns=feat_names)],axis=1)
        try:
            table.to_parquet(cfg.MODEL_DIR/"training_features.parquet",index=False)
        except ImportError as e:          # no parquet engine installed - don't lose the run
            if cfg.VERBOSE: print(f"[Warn] feature table not saved: {e}")

    return meta_df,X,(X_shape,feat_names),y_soc,y_soh

# =========================
# 8. TRAIN MODELS (incl. σ-calibration)
# =========================
def _make_soh_gpr(n_dims:int)->GaussianProcessRegressor:
    """Build the ARD-RBF + white-noise GP used for both SoH regressors."""
    kernel=RBF(length_scale=np.ones(n_dims)*3.0,
               length_scale_bounds=(1e-1,1e4))+WhiteKernel(noise_level=1e-2,
                                                           noise_level_bounds=(1e-6,1e-1))
    return GaussianProcessRegressor(kernel=kernel,alpha=0,normalize_y=True,
                                    random_state=cfg.RANDOM_STATE,n_restarts_optimizer=3)

def _fit_pca(X:np.ndarray,n_components:int):
    """Fit PCA with n_components capped by the data's limits; return (pca, transformed X)."""
    pca=PCA(n_components=min(n_components,X.shape[1]-1,X.shape[0]),
            random_state=cfg.RANDOM_STATE)
    return pca,pca.fit_transform(X)

def train_models(meta_df,X_raw,shape_bundle,y_soc,y_soh):
    """Train the SoC classifier and SoH regressors, persist the bundle and return it.

    Models are fitted on training cells only; validation cells come from
    cell_split_mask and are used for:

    * SoC accuracy / macro-F1, stored as ``soc_acc`` / ``soc_f1``.
    * ``soh_sigma_factor`` = clip(RMSE / mean GP sigma, 0.1, 5).
    * Picking the SoH model (GP vs HistGradientBoosting) by validation R^2.

    Previously the GPs drew training rows from all cells, so validation cells leaked
    into the calibration and model choice.  With no validation cells the GP is kept,
    the sigma factor is 1.0 and the metrics are None.

    Raises:
        RuntimeError: the split leaves no training rows.
    """
    rng=np.random.default_rng(cfg.RANDOM_STATE)
    X_shape,feat_names=shape_bundle
    mask_val=np.asarray(cell_split_mask(meta_df),dtype=bool)
    mask_tr=~mask_val
    has_val=bool(mask_val.any())
    train_idx=np.flatnonzero(mask_tr)
    if train_idx.size==0: raise RuntimeError("No training rows after the cell split")
    y_soc=np.asarray(y_soc); y_soh=np.asarray(y_soh,dtype=float)

    # ---------- SoC ----------
    soc_scal=StandardScaler(); X_soc=soc_scal.fit_transform(X_raw)
    soc_pca=None
    if cfg.USE_PCA_SOC:
        soc_pca,X_soc=_fit_pca(X_soc,cfg.PCA_SOC_COMPONENTS)
    soc_model=RandomForestClassifier(
        n_estimators=800,min_samples_leaf=2,class_weight="balanced",
        n_jobs=-1,random_state=cfg.RANDOM_STATE)
    soc_model.fit(X_soc[mask_tr],y_soc[mask_tr])
    soc_acc=soc_f1=None
    if has_val:
        pv=soc_model.predict(X_soc[mask_val])
        soc_acc=float(accuracy_score(y_soc[mask_val],pv))
        soc_f1=float(f1_score(y_soc[mask_val],pv,average='macro'))
        if cfg.VERBOSE: print(f"[SoC] Acc={soc_acc:.3f}  F1={soc_f1:.3f}")

    # ---------- SoH GP (training cells only) ----------
    soh_scal=StandardScaler(); Xs=soh_scal.fit_transform(X_raw)
    soh_pca=None
    if cfg.USE_PCA_SOH:
        soh_pca,Xs=_fit_pca(Xs,cfg.PCA_SOH_COMPONENTS)
    gpr=_make_soh_gpr(Xs.shape[1])
    idx=rng.choice(train_idx,size=min(cfg.MAX_GPR_TRAIN_SAMPLES,train_idx.size),replace=False)
    gpr.fit(Xs[idx],y_soh[idx])

    # ---------- HGB alternative ----------
    hgb=HistGradientBoostingRegressor(learning_rate=0.05,max_iter=500,
                                      l2_regularization=1e-3,
                                      random_state=cfg.RANDOM_STATE)
    hgb.fit(Xs[mask_tr],y_soh[mask_tr])

    sigma_fact=1.0; r2_gpr=r2_hgb=None
    if has_val:
        mu_val,sig_val=gpr.predict(Xs[mask_val],return_std=True)
        # sqrt(MSE): the squared=False argument was removed in scikit-learn 1.6
        rmse=float(np.sqrt(mean_squared_error(y_soh[mask_val],mu_val)))
        sigma_fact=float(np.clip(rmse/(sig_val.mean()+1e-12),0.1,5.0))
        r2_gpr=float(r2_score(y_soh[mask_val],mu_val))
        r2_hgb=float(r2_score(y_soh[mask_val],hgb.predict(Xs[mask_val])))
    use_gpr=r2_gpr is None or r2_gpr>=r2_hgb
    soh_model,soh_name=(gpr,"gpr_raw") if use_gpr else (hgb,"hgb_raw")
    if cfg.VERBOSE and has_val:
        print(f"[SoH] R2 gpr={r2_gpr:.3f} hgb={r2_hgb:.3f} -> {soh_name}; sigma factor={sigma_fact:.2f}")

    # ---------- shape-GP (optional) ----------
    shape_scal=shape_pca=shape_model=None
    if cfg.INCLUDE_NORMALIZED_SHAPE_MODEL and X_shape is not None:
        if X_shape.shape[0]!=X_raw.shape[0]:
            if cfg.VERBOSE: print("[Shape] X_shape rows do not match X_raw - shape model skipped")
        else:
            shape_scal=StandardScaler(); Xsh=shape_scal.fit_transform(X_shape)
            if cfg.USE_PCA_SOH:
                shape_pca,Xsh=_fit_pca(Xsh,cfg.PCA_SOH_COMPONENTS)
            shape_model=_make_soh_gpr(Xsh.shape[1])
            idxs=rng.choice(train_idx,size=min(cfg.MAX_GPR_TRAIN_SAMPLES,train_idx.size),
                            replace=False)
            shape_model.fit(Xsh[idxs],y_soh[idxs])

    # ---------- bundle ----------
    cov=np.cov(Xs.T); cov_inv=np.linalg.pinv(cov); center=Xs.mean(axis=0)

    bundle={
        "soc_scaler":soc_scal,"soc_pca":soc_pca,"soc_model":soc_model,
        "soh_scaler":soh_scal,"soh_pca":soh_pca,"soh_model":soh_model,
        "soh_model_name":soh_name,
        "shape_scaler":shape_scal,"shape_pca":shape_pca,"shape_model":shape_model,
        "freq_grid":CANON_FREQ,"feature_version":cfg.FEATURE_VERSION,
        "feature_manifest":feat_names,
        "train_mahal":{"center":center.tolist(),"cov_inv":cov_inv.tolist()},
        "soh_sigma_factor":sigma_fact,
        # validation metrics (None when no validation cells) - read by the Gradio UI
        "soc_acc":soc_acc,"soc_f1":soc_f1,
        "soh_val_r2":{"gpr":r2_gpr,"hgb":r2_hgb},
    }
    joblib.dump(bundle,cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib")
    return bundle

# =========================
# 9. FEATURISATION FOR INFERENCE
# =========================
def featurize_any(file_path:Path,bundle):
    """Load one EIS file of any supported type and build its feature vectors.

    The spectrum is interpolated onto the bundle's ``freq_grid`` (CANON_FREQ for
    bundles that lack it).  The temperature comes from the file-name metadata,
    then cfg.TEST_TEMPERATURE_OVERRIDE, then -1.  A 0 degC override is honoured;
    previously ``or`` treated 0.0 as missing.

    Returns:
        (vec, norm_vec, meta): ``norm_vec`` is the HF-normalised shape vector, or
        None when the shape model is disabled or absent; ``meta`` is None when the
        file name carries no metadata.
    """
    freq_grid=bundle.get("freq_grid",CANON_FREQ)
    meta=parse_eis_metadata(file_path.stem)

    freq,re_raw,im_raw=load_any_inference(file_path)
    re_i=_interp_channel(freq,re_raw,freq_grid)
    im_i=_interp_channel(freq,im_raw,freq_grid)

    if meta is not None:
        temp=meta["Temp"]
    elif cfg.TEST_TEMPERATURE_OVERRIDE is not None:
        temp=cfg.TEST_TEMPERATURE_OVERRIDE
    else:
        temp=-1

    vec=build_feature_vector(re_i,im_i,temp,freq_grid)

    norm_vec=None
    if cfg.INCLUDE_NORMALIZED_SHAPE_MODEL and bundle.get("shape_model") is not None:
        rn,in_=build_shape_normalized(re_i,im_i)
        norm_vec=build_feature_vector(rn,in_,temp,freq_grid)
    return vec,norm_vec,meta

# =========================
# 10. INFERENCE FOR ONE FILE
# =========================
def predict_file(fp:Path,bundle,cpp_map,global_cpp):
    """Predict SoC, SoH and remaining cycles for one EIS file.

    * SoC: the classifier's most probable class.
    * SoH: a GP mean with sigma scaled by ``soh_sigma_factor``.  For a non-GP model
      the point prediction is used with a fixed 4.0.  The std is clipped to [0.1, 3.0].
    * Ensemble: when enabled and available, SoH is averaged 50/50 with the
      shape-model prediction.
    * Cycles: ``max((SoH - threshold) * cpp, 0)`` for both thresholds.

    Returns:
        (result_dict, cycles_to_target, cycles_to_lower)
    """
    def _to_model_space(x,scaler,pca):
        # one feature vector -> scaled row, then PCA when the bundle has one
        row=scaler.transform(x.reshape(1,-1))
        return row if pca is None else pca.transform(row)

    vec,norm_vec,meta=featurize_any(fp,bundle)

    # --- SoC ---
    soc_scaler,soc_pca,soc_model=bundle["soc_scaler"],bundle.get("soc_pca"),bundle["soc_model"]
    probs=soc_model.predict_proba(_to_model_space(vec,soc_scaler,soc_pca))[0]
    soc=int(soc_model.classes_[probs.argmax()])

    # --- SoH ---
    soh_scaler,soh_pca,soh_model=bundle["soh_scaler"],bundle.get("soh_pca"),bundle["soh_model"]
    Xs=_to_model_space(vec,soh_scaler,soh_pca)
    if isinstance(soh_model,GaussianProcessRegressor):
        mean,std=soh_model.predict(Xs,return_std=True)
        soh=float(mean[0])
        sd=float(std[0])*bundle.get("soh_sigma_factor",1.0)
    else:
        soh=float(soh_model.predict(Xs)[0])
        sd=4.0
    sd=float(np.clip(sd,0.1,3.0))

    # --- optional shape-model ensemble ---
    if cfg.ENSEMBLE_SOH and bundle.get("shape_model") and norm_vec is not None:
        sh_scaler,sh_pca,sh_model=bundle["shape_scaler"],bundle.get("shape_pca"),bundle["shape_model"]
        shape_soh=float(sh_model.predict(_to_model_space(norm_vec,sh_scaler,sh_pca))[0])
        soh=0.5*(soh+shape_soh)

    cpp=get_cpp(meta,cpp_map,global_cpp)
    cyc_t=max((soh-cfg.DECISION_SOH_PERCENT)*cpp,0.0)
    cyc_l=max((soh-cfg.ILLUSTRATIVE_MIN_SOH)*cpp,0.0)

    result={"file":str(fp),
            "parsed_metadata":meta,
            "predicted_SoC":soc,
            "SoC_probabilities":{int(c):float(p) for c,p in zip(soc_model.classes_,probs)},
            "predicted_SoH_percent":soh,
            "SoH_std_estimate":sd,
            "cycles_per_percent_used":cpp,
            "cycles_to_target":cyc_t,
            "cycles_to_lower":cyc_l,
            "decision_threshold_percent":cfg.DECISION_SOH_PERCENT,
            "lower_threshold_percent":cfg.ILLUSTRATIVE_MIN_SOH,
            "feature_version":bundle["feature_version"],
            "soh_model_chosen":bundle.get("soh_model_name","raw")}
    return result,cyc_t,cyc_l

# =========================
# 11. PLOT HELPER
# =========================
def build_projection(soh,cpp,lower,exp=None,n=160):
    """Return ``(cycles, soh_curve)`` for an illustrative decay from *soh* down to *lower*.

    The horizon is ``(soh - lower) * cpp`` cycles, sampled at *n* points, and the
    curve is ``lower + (soh - lower) * (1 - c / horizon) ** exp``.  *exp* defaults to
    cfg.PLOT_EXPONENT.  When ``soh <= lower`` or ``cpp <= 0``, the single point
    ``([0], [soh])`` is returned.
    """
    if soh<=lower or cpp<=0:
        return np.array([0]),np.array([soh])
    span=soh-lower
    horizon=span*cpp
    cycles=np.linspace(0,horizon,n)
    power=cfg.PLOT_EXPONENT if exp is None else exp
    return cycles,lower+span*(1-cycles/horizon)**power

def plot_projection(name,soh,sd,cyc_t,cyc_l,cpp,out):
    """Draw the SoH-vs-remaining-cycles chart for one prediction and save it to *out*.

    A file is always written.  When there is no remaining life to project
    (``cyc_l <= 0``), the chart shows only the current estimate, the threshold lines
    and a note.  Previously nothing was saved in that case and callers that open
    *out* (the Gradio UI) crashed.  The figure is closed even if saving fails.
    """
    fig,ax=plt.subplots(figsize=(6.4,4))
    try:
        if cyc_l>0:
            cycs,curve=build_projection(soh,cpp,cfg.ILLUSTRATIVE_MIN_SOH)
            ax.plot(cycs,curve,lw=2,label="Projected SoH")
            ax.scatter([cycs[-1]],[cfg.ILLUSTRATIVE_MIN_SOH],c="red",s=50)
        else:
            ax.text(0.5,0.5,"No projection: SoH at/below lower threshold or no degradation rate",
                    transform=ax.transAxes,ha="center",va="center",fontsize=9)
        ax.axhline(cfg.DECISION_SOH_PERCENT,color="orange",ls="--",
                   label=f"Decision threshold ({cfg.DECISION_SOH_PERCENT:g}%)")
        ax.axhline(cfg.ILLUSTRATIVE_MIN_SOH,color="red",ls=":",
                   label=f"Lower threshold ({cfg.ILLUSTRATIVE_MIN_SOH:g}%)")
        ax.scatter([0],[soh],c="green",s=55)
        ax.text(0,soh+0.7,f"±{sd:.2f}",color="green",fontsize=8)
        if cyc_t>0:
            ax.axvline(cyc_t,color="orange",ls="-.")
        ax.set_xlabel("Remaining Cycles"); ax.set_ylabel("SoH (%)")
        ax.set_title(f"RUL Projection – {name}")
        ax.grid(alpha=0.35); ax.legend(fontsize=8,loc="best")
        fig.tight_layout(); fig.savefig(out,dpi=140)
    finally:
        plt.close(fig)

# =========================
# 12. LOAD (legacy shim)
# =========================
def load_bundle():
    """Load the persisted model bundle and back-fill keys that older bundles lack.

    * Bundles without per-model scalers reuse the shared ``scaler``/``pca`` for SoC and SoH.
    * A missing ``freq_grid`` defaults to CANON_FREQ.
    * A missing ``feature_version`` defaults to None.  predict_file reads both of
      these keys directly, so without the defaults old bundles raised KeyError.

    NOTE: joblib.load unpickles arbitrary objects - only load bundles you trust.
    """
    b=joblib.load(cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib")
    if "soc_scaler" not in b:          # older bundles
        b["soc_scaler"]=b["scaler"]; b["soh_scaler"]=b["scaler"]
        b["soc_pca"]=b.get("pca"); b["soh_pca"]=b.get("pca")
    b.setdefault("freq_grid",CANON_FREQ)
    b.setdefault("feature_version",None)
    return b

# =========================
# 13. MAIN
# =========================
def main(argv=None):
    """Train or load the model bundle, predict one EIS test file, and save the plot and JSON.

    Args:
        argv: arguments to parse; ``None`` means ``sys.argv[1:]`` (previously the
            command line was ignored).  Only ``--test <path>`` is recognised; it
            overrides cfg.EIS_TEST_FILE.  Unknown arguments, e.g. the
            ``-f kernel.json`` Jupyter passes, are ignored.
    Raises:
        FileNotFoundError: CAP_DIR (when capacity refinement is on), the test file,
            or EIS_DIR (only when training is needed) is missing.
    """
    p=argparse.ArgumentParser(add_help=False)
    p.add_argument("--test",dest="test_file")
    args,_=p.parse_known_args(sys.argv[1:] if argv is None else argv)
    if args.test_file:
        cfg.EIS_TEST_FILE=Path(args.test_file)

    # explicit raises: assert statements disappear under `python -O`
    if cfg.REFINE_SOH_WITH_CAPACITY and not cfg.CAP_DIR.exists():
        raise FileNotFoundError(f"CAP_DIR missing: {cfg.CAP_DIR}")
    test_fp=cfg.EIS_TEST_FILE
    # check before a potentially long training run, not after it
    if not test_fp.exists(): raise FileNotFoundError(test_fp)

    if cfg.VERBOSE:
        print("Configuration:\n",json.dumps(to_jsonable(asdict(cfg)),indent=2))

    cap_df=load_capacity_info(cfg.CAP_DIR)
    cpp_map,global_cpp=build_cpp_map(cap_df)

    bundle_path=cfg.MODEL_DIR/"eis_soc_soh_phys_models.joblib"
    if bundle_path.exists() and not cfg.FORCE_RETRAIN:
        bundle=load_bundle()
        if cfg.VERBOSE: print(f"[LOAD] Using bundle → {bundle_path}")
    else:
        # training data is only required when no reusable bundle exists
        if not cfg.EIS_DIR.exists():
            raise FileNotFoundError(f"EIS_DIR missing: {cfg.EIS_DIR}")
        if cfg.VERBOSE: print("[TRAIN] Building dataset & training models …")
        meta_df,X,shape_bundle,y_soc,y_soh=build_dataset(cfg.EIS_DIR,cap_df)
        bundle=train_models(meta_df,X,shape_bundle,y_soc,y_soh)

    result,cyc_t,cyc_l=predict_file(test_fp,bundle,cpp_map,global_cpp)

    out_plot=cfg.MODEL_DIR/f"{test_fp.stem}_projection.png"
    plot_projection(test_fp.stem,
                    result["predicted_SoH_percent"],
                    result["SoH_std_estimate"],
                    result["cycles_to_target"],
                    result["cycles_to_lower"],
                    result["cycles_per_percent_used"],
                    out_plot)

    out_json=cfg.MODEL_DIR/f"{test_fp.stem}_prediction.json"
    with out_json.open("w",encoding="utf-8") as f: json.dump(result,f,indent=2)

    print(json.dumps(result,indent=2))
    print(f"[PLOT]  {out_plot}")
    print(f"[JSON]  {out_json}\nDone.")

# =========================
# 14. RUN
# =========================
if __name__ == "__main__":
    # Pass no argv so main() applies its own default instead of a hard-wired empty list.
    main()


Configuration:
 {
  "EIS_DIR": "C:\\Users\\tmgon\\OneDrive - Edith Cowan University (1)\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\EIS_Test",
  "CAP_DIR": "C:\\Users\\tmgon\\OneDrive - Edith Cowan University (1)\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\DIB_Data\\.matfiles\\Capacity_Check",
  "MODEL_DIR": "models_eis_phase2_phys",
  "EIS_TEST_FILE": "C:\\Users\\tmgon\\OneDrive - Edith Cowan University (1)\\00 - Megallan Power\\NMC Batteries Warwick Station\\NMC\\TestFile\\Mazda-Battery-Cell5.xlsx",
  "F_MIN": 0.01,
  "F_MAX": 10000.0,
  "N_FREQ": 60,
  "TEST_FRAC": 0.2,
  "RANDOM_STATE": 42,
  "USE_PCA_SOC": true,
  "PCA_SOC_COMPONENTS": 25,
  "USE_PCA_SOH": false,
  "PCA_SOH_COMPONENTS": 30,
  "INCLUDE_RAW_RE_IM": true,
  "INCLUDE_BASICS": true,
  "INCLUDE_F_FEATS": true,
  "INCLUDE_PHYSICAL": true,
  "INCLUDE_DRT": true,
  "INCLUDE_BAND_STATS": true,
  "INCLUDE_DIFF_SLOPES": true,
  "DRT_POINTS": 60,
  "DRT_TAU_MIN": 0.0001,
  "DRT_

In [None]:
# ────────────────────────────────────────────────────────────────────────
# Gradio demo – upload ONE EIS file → projection plot + predicted SoC
# ────────────────────────────────────────────────────────────────────────
import gradio as gr, tempfile, shutil
from pathlib import Path
from PIL import Image
from sklearn.exceptions import ConvergenceWarning
import warnings; warnings.filterwarnings("ignore", category=ConvergenceWarning)

# ---------- 1. one-time model / CPP setup ------------------------------
cap_df = load_capacity_info(cfg.CAP_DIR)
# build_cpp_map already returns ({}, cfg.CPP_FALLBACK) for an empty capacity table
cpp_map, global_cpp = build_cpp_map(cap_df)

bundle_path = cfg.MODEL_DIR / "eis_soc_soh_phys_models.joblib"
# Train when forced or when no saved bundle exists; otherwise reuse the bundle on disk.
if cfg.FORCE_RETRAIN or not bundle_path.exists():
    print("[GRADIO] Training bundle – first run only …")
    meta_df, X_raw, shape_bundle, y_soc, y_soh = build_dataset(cfg.EIS_DIR, cap_df)
    bundle = train_models(meta_df, X_raw, shape_bundle, y_soc, y_soh)
else:
    bundle = load_bundle()
    print(f"[GRADIO] Loaded bundle → {bundle_path}")

# ---------- 2. helpers --------------------------------------------------
# Markdown shown under the title of the Gradio interface (passed as `description=`).
# The trailing double spaces on the first two lines are Markdown line breaks.
HEADER_MD = """
### SoH prediction for NMC profile  
Addressing *Electric Vehicle Battery Repurposing Challenges*, iMOVE Australia Limited, iMOVE CRC  
*Prepared by Research Team, Centre for Green and Smart Energy Systems, Edith Cowan University (Joondalup)*
""".strip()

def metrics_md(bndl):
    """Return one Markdown line with the bundle's stored SoC validation metrics.

    Reads ``soc_acc`` and ``soc_f1``.  If either is missing or None, a placeholder
    notice is returned instead.
    """
    acc = bndl.get("soc_acc")
    f1 = bndl.get("soc_f1")
    if acc is not None and f1 is not None:
        return f"**Validation —  Acc&nbsp;=&nbsp;{acc:.3f},  F1&nbsp;(macro)&nbsp;=&nbsp;{f1:.3f}**"
    return "_Accuracy / F1 not available (bundle predates metric storage)._"

# ---------- 3. gradio callback -----------------------------------------
def run_inference(uploaded_file):
    """Gradio callback: predict one uploaded EIS file and return the UI outputs.

    Accepts a tempfile-like object exposing ``.name`` (older Gradio ``gr.File``)
    or a plain path string (newer Gradio).

    Returns:
        (PIL image of the RUL chart or None, most-likely SoC class, metrics Markdown)
    Raises:
        gr.Error: nothing was uploaded (shown to the user instead of a stack trace).
    """
    if uploaded_file is None:
        raise gr.Error("Please upload an EIS test file.")
    src = Path(getattr(uploaded_file, "name", uploaded_file))
    # Copy into the temp dir under the upload's own file name (its stem drives metadata parsing).
    tmp_path = Path(tempfile.gettempdir()) / src.name
    try:
        shutil.copy(src, tmp_path)
    except shutil.SameFileError:
        pass  # the upload already lives at tmp_path

    result, *_ = predict_file(tmp_path, bundle, cpp_map, global_cpp)

    plot_path = cfg.MODEL_DIR / f"{tmp_path.stem}_projection.png"
    # Drop a chart left by an earlier upload with the same name so it is never shown stale.
    if plot_path.exists():
        plot_path.unlink()
    plot_projection(
        tmp_path.stem,
        result["predicted_SoH_percent"],
        result["SoH_std_estimate"],
        result["cycles_to_target"],
        result["cycles_to_lower"],
        result["cycles_per_percent_used"],
        plot_path
    )

    image = None
    if plot_path.exists():
        with Image.open(plot_path) as img:
            image = img.copy()             # detach from the file so the handle is closed now

    return (
        image,                             # RUL chart (None if no plot was written)
        int(result["predicted_SoC"]),      # most-likely SoC
        metrics_md(bundle)                 # model metrics
    )

# ---------- 4. build & launch UI ---------------------------------------
# One file input and three outputs, in the order run_inference returns them:
# (RUL projection image, predicted SoC class, validation-metrics Markdown).
demo = gr.Interface(
    title="SoH prediction for NMC profile",
    description=HEADER_MD,
    fn=run_inference,
    inputs=gr.File(label="Upload EIS test file"),
    outputs=[
        gr.Image(type="pil", label="RUL projection"),
        # value is the classifier's SoC class label (an int); "%" assumes the labels are SoC percent - confirm
        gr.Number(label="Predicted SoC (%)"),
        gr.Markdown(label="Model metrics")
    ],
)

# no fixed port → avoids “port already in use”; share=False keeps it local
# debug=True: per Gradio docs this blocks the calling thread and surfaces errors in the output.
demo.launch(debug=True)


[GRADIO] Loaded bundle → models_eis_phase2_phys\eis_soc_soh_phys_models.joblib
* Running on local URL:  http://127.0.0.1:7865
* To create a public link, set `share=True` in `launch()`.


# Updated