In [4]:

from __future__ import annotations

# ── std / third-party imports ─────────────────────────────────────────
import argparse, json, math, os, pickle, random, warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Tuple

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from scipy.io import loadmat
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.metrics import r2_score
from sklearn.model_selection import GroupKFold

# ── 0 · GLOBALS  ──────────────────────────────────────────────────────
CANON_FREQ = np.logspace(-2, 4, 50)  # 0.01 Hz → 10 kHz


def set_seed(seed: int = 17) -> None:
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)


def to_jsonable(o: Any) -> Any:
    if isinstance(o, (np.integer, np.floating)):
        return o.item()
    if isinstance(o, (np.ndarray, list, tuple, set)):
        return [to_jsonable(x) for x in o]
    if isinstance(o, dict):
        return {k: to_jsonable(v) for k, v in o.items()}
    if isinstance(o, Path):
        return str(o)
    return o


# ── 1 · PATH CONFIG  ─────────────────────────────────────────────────
@dataclass
class Config:
    EIS_DIR: Path = Path(r"C:\Users\tgondal0\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\EIS_Test")
    CAP_DIR: Path = Path(r"C:\Users\tgondal0\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\DIB_Data\.matfiles\Capacity_Check")
    MODEL_DIR: Path = Path("models_eis_phase2_phys")
    EIS_TEST_FILE: Path = Path(r"C:\Users\tgondal0\OneDrive - Edith Cowan University\00 - Megallan Power\NMC Batteries Warwick Station\NMC\TestFile\Mazda-Battery-Cell5.xlsx")

    FORCE_RETRAIN: bool = False
    INCLUDE_DRT: bool = True
    USE_PCA_SOC_REG: bool = True
    PCA_SOC_REG_COMPONENTS: int = 25
    MAX_GP_TRAIN_SAMPLES_SOC: int = 8_000
    MAX_GP_TRAIN_SAMPLES_RUL: int = 8_000
    SOC_STD_CLAMP: float = 3.0
    RANDOM_STATE: int = 17


cfg = Config()
set_seed(cfg.RANDOM_STATE)

# ── 2 · HELPERS – parsing & feature engineering ──────────────────────
import re
_EIS_RE = re.compile(r"(?P<cell>[A-Za-z0-9\-]+)[_\-]?(?:Cycl[e]?|CYL)?(?P<cycle>\d+)[_\-]?(?:SoC)?(?P<soc>\d+)?", re.I)
_CAP_RE = re.compile(r"(?P<cell>[A-Za-z0-9\-]+)[_\-]?(?:Cap|CYC|Cycle)?(?P<cycle>\d+)", re.I)


def parse_eis_metadata(fp: Path) -> Dict[str, Any]:
    m = _EIS_RE.search(fp.stem)
    return dict(cell_id=fp.stem, cycle_idx=np.nan, soc=np.nan) if not m else \
        dict(cell_id=m["cell"], cycle_idx=int(m["cycle"]), soc=float(m["soc"]) if m["soc"] else np.nan)


def parse_cap_metadata(fp: Path) -> Dict[str, Any]:
    m = _CAP_RE.search(fp.stem)
    return dict(cell_id=fp.stem, cycle_idx=np.nan) if not m else \
        dict(cell_id=m["cell"], cycle_idx=int(m["cycle"]))


def _load_one_capacity(fp: Path) -> pd.DataFrame:
    df = pd.read_csv(fp) if fp.suffix.lower() == ".csv" else pd.read_excel(fp, engine="openpyxl" if fp.suffix.lower()==".xlsx" else None)
    df = df.rename(columns=lambda s: s.strip().lower())
    if "capacity" not in df.columns:
        for alt in ("capacity_ah", "cap_ah", "capacity (ah)"):
            if alt in df.columns:
                df["capacity"] = df[alt]; break
    if "cycle" not in df.columns:
        raise ValueError(f"{fp.name} missing 'cycle' column")
    meta = parse_cap_metadata(fp)
    df = df[["cycle", "capacity"]].copy()
    df.insert(0, "cell_id", meta["cell_id"])
    return df


def load_capacity_info(cap_dir: Path) -> pd.DataFrame:
    dfs: List[pd.DataFrame] = []
    for fp in cap_dir.rglob("*"):
        if fp.suffix.lower() in {".csv", ".xls", ".xlsx"}:
            try:
                dfs.append(_load_one_capacity(fp))
            except Exception as e:
                warnings.warn(f"Skip capacity file {fp.name}: {e}")
    if not dfs:
        raise FileNotFoundError(f"No capacity files in {cap_dir}")
    cap = pd.concat(dfs, ignore_index=True)
    cap["init_cap"] = cap.groupby("cell_id")["capacity"].transform("first")
    cap["soh_percent"] = 100 * cap["capacity"] / cap["init_cap"]
    return cap


def build_cpp_map(cap_dir: Path) -> Dict[float, float]:
    cap = load_capacity_info(cap_dir)
    cap["cycles_remaining"] = cap.groupby("cell_id")["cycle"].transform("max") - cap["cycle"]
    med = cap.groupby(cap["soh_percent"].round())["cycles_remaining"].median()
    return med.dropna().to_dict()


def get_cpp(cpp: float, cpp_map: Dict[float, float]) -> float:
    if not cpp_map or math.isnan(cpp):
        return np.nan
    keys = np.array(list(cpp_map))
    return float(cpp_map[keys[np.argmin(np.abs(keys - cpp))]])


# ─── drop-in replacement for `_read_eis` ──────────────────────────
def _read_eis(fp: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Return (frequency [Hz], Zreal [Ω], Zimag [Ω]) from .csv/.xls/.mat files.

    • For .mat we accept several common key variants:
        - freq  | frequency  | f
        - Zreal | Zre        | Z_real
        - Zimag | Zim        | Z_imag
      Keys are matched case-insensitively.
    """
    suf = fp.suffix.lower()

    # ---------- text / Excel ----------
    if suf in {".csv", ".txt"}:
        df = pd.read_csv(fp)
    elif suf in {".xls", ".xlsx"}:
        df = pd.read_excel(fp, engine="openpyxl" if suf == ".xlsx" else None)
    # ---------- MATLAB ----------
    elif suf == ".mat":
        mat = loadmat(fp)
        # normalise keys to lowercase for easier lookup
        lk = {k.lower(): k for k in mat.keys()}
        def grab(cands: List[str]) -> np.ndarray:
            for c in cands:
                if c.lower() in lk:
                    return np.squeeze(mat[lk[c.lower()]])
            raise KeyError(f"{fp.name}: none of {cands} found")
        freq = grab(["freq", "frequency", "f"])
        zre  = grab(["zreal", "zre", "z_real"])
        zim  = grab(["zimag", "zim", "z_imag"])
        return freq, zre, zim
    else:
        raise ValueError(f"Unsupported EIS file type: {fp.suffix}")

    # ---------- tidy dataframe ----------
    df = df.rename(columns=lambda s: s.strip().lower())
    freq = df.iloc[:, 0].to_numpy(float)
    zre  = df.iloc[:, 1].to_numpy(float)
    zim  = df.iloc[:, 2].to_numpy(float)
    return freq, zre, zim



def _interp(freq, arr): return 10**interp1d(np.log10(freq), np.log10(np.abs(arr)), fill_value="extrapolate")(np.log10(CANON_FREQ))


def featurize_any(fp: Path, include_drt: bool=True) -> np.ndarray:
    f,zr,zi = _read_eis(fp)
    mag = np.sqrt(zr**2+zi**2)
    pha = np.degrees(np.arctan2(zi,zr))
    feats = [_interp(f,mag), _interp(f,pha)]
    if include_drt: feats.append(np.zeros_like(CANON_FREQ))  # placeholder DRT
    return np.concatenate(feats).astype(np.float32)


def build_dataset(eis_dir: Path, cap_dir: Path, include_drt=True) -> pd.DataFrame:
    rows=[]
    for fp in eis_dir.rglob("*"):
        if fp.suffix.lower() in {".csv",".txt",".xls",".xlsx",".mat"}:
            meta=parse_eis_metadata(fp)
            rows.append(dict(features=featurize_any(fp,include_drt),
                             cell_id=meta["cell_id"],
                             cycle_idx=meta["cycle_idx"],
                             soc=meta["soc"]/100 if not math.isnan(meta["soc"]) else np.nan,
                             file_id=fp.stem))
    if not rows: raise FileNotFoundError("No EIS files found")
    df=pd.DataFrame(rows)
    try:
        fin=load_capacity_info(cap_dir).groupby("cell_id")["cycle"].max().rename("final_cycle").reset_index()
        df=df.merge(fin,on="cell_id",how="left")
    except FileNotFoundError: df["final_cycle"]=np.nan
    return df


def build_feature_vector(test_fp: Path, include_drt=True) -> Dict[str,Any]:
    feats=featurize_any(test_fp,include_drt)
    meta=parse_eis_metadata(test_fp)
    cpp=100.0
    try:
        caps=load_capacity_info(test_fp.parents[3])
        row=caps[(caps.cell_id==meta["cell_id"])&(caps.cycle==meta["cycle_idx"])]
        if not row.empty: cpp=float(row.soh_percent.iloc[0])
    except Exception: pass
    return {"features":feats,"cpp":cpp}


def _cell_soh_curve(cell_id:str,cap_dir:Path)->Tuple[np.ndarray,np.ndarray]:
    try:
        df=load_capacity_info(cap_dir); sub=df[df.cell_id==cell_id]
        if sub.empty: raise ValueError
        return sub.cycle.to_numpy(), sub.soh_percent.to_numpy()
    except Exception:
        cy=np.linspace(0,1200,240); return cy, 80+20*np.exp(-cy/400)


def plot_projection(test_fp:Path)->Tuple[np.ndarray,np.ndarray]:
    meta=parse_eis_metadata(test_fp)
    return _cell_soh_curve(meta["cell_id"], test_fp.parents[2])

# ── 3 · MODEL TRAINING  ────────────────────────────────────────────
def _subsample(X,y,max_n):
    if len(X)<=max_n: return X,y
    idx=np.random.default_rng(cfg.RANDOM_STATE).choice(len(X),max_n,replace=False)
    return X[idx],y[idx]

def _maybe_pca(X,fit,pca=None):
    if not cfg.USE_PCA_SOC_REG: return X,None if fit else pca
    if fit:
        pca=PCA(cfg.PCA_SOC_REG_COMPONENTS,random_state=cfg.RANDOM_STATE); return pca.fit_transform(X),pca
    return pca.transform(X),pca

def _cv_r2(est,X,y,groups):
    gkf=GroupKFold(n_splits=min(5,len(np.unique(groups))))
    sc=[]
    for tr,te in gkf.split(X,y,groups):
        e=pickle.loads(pickle.dumps(est)); e.fit(X[tr],y[tr]); sc.append(r2_score(y[te],e.predict(X[te])))
    return float(np.mean(sc))

def train_soc(df):
    X=np.vstack(df.features.values); y=df.soc_percent.values
    groups=df.cell_id.values
    X,y=_subsample(X,y,cfg.MAX_GP_TRAIN_SAMPLES_SOC)
    Xp,pca=_maybe_pca(X,True)
    try:
        gp=GaussianProcessRegressor(RBF(np.ones(Xp.shape[1]))+WhiteKernel(),alpha=1e-6,normalize_y=True,random_state=cfg.RANDOM_STATE)
        gp.fit(Xp,y); model=("gp",gp)
    except MemoryError:
        rf=RandomForestRegressor(400,random_state=cfg.RANDOM_STATE); rf.fit(Xp,y); model=("rf",rf)
    return {"model_type":model[0],"est":model[1],"pca":pca,"r2":_cv_r2(model[1],Xp,y,groups)}

def train_rul(df):
    df=df.dropna(subset=["cycles_remaining"])
    X=np.vstack(df.features.values); y=df.cycles_remaining.values
    groups=df.cell_id.values
    X,y=_subsample(X,y,cfg.MAX_GP_TRAIN_SAMPLES_RUL)
    gp=GaussianProcessRegressor(RBF(np.ones(X.shape[1]))+WhiteKernel(),alpha=1e-4,normalize_y=True,random_state=cfg.RANDOM_STATE)
    gp.fit(X,y)
    return {"gp":gp,"r2":_cv_r2(gp,X,y,groups)}

# ── 4 · INFERENCE  ────────────────────────────────────────────────
def predict_soc(bundle,feat):
    X=feat.reshape(1,-1)
    if bundle["pca"] is not None: X,_=_maybe_pca(X,False,bundle["pca"])
    if bundle["model_type"]=="gp":
        mean,std=bundle["est"].predict(X,return_std=True)
    else:
        mean=bundle["est"].predict(X); std=np.std([t.predict(X) for t in bundle["est"].estimators_],ddof=1)
    return float(mean),float(np.clip(std,0.0,cfg.SOC_STD_CLAMP))

def predict_rul(gp,feat):
    mean,std=gp.predict(feat.reshape(1,-1),return_std=True)
    return float(mean),float(std)

# ── 5 · I/O helpers ───────────────────────────────────────────────
def save_json(p,data): p.write_text(json.dumps(to_jsonable(data),indent=2))

def save_plot(p,cycles,mean,std,ref):
    plt.figure(figsize=(6,4))
    plt.plot(cycles,mean,lw=2,label="GP mean")
    plt.fill_between(cycles,mean-std,mean+std,alpha=0.3,label="GP ±1σ")
    plt.plot(cycles,ref,"--",lw=2,label="Parametric decay (ref)")
    plt.xlabel("Cycle"); plt.ylabel("SoH (%)"); plt.legend(); plt.tight_layout(); plt.savefig(p,dpi=300); plt.close()

# ── 6 · End-to-end routine ────────────────────────────────────────
def train_and_save():
    df=build_dataset(cfg.EIS_DIR,cfg.CAP_DIR,cfg.INCLUDE_DRT)
    df["cycles_remaining"]=df.final_cycle-df.cycle_idx
    df["soc_percent"]=df.soc*100
    soc=train_soc(df); rul=train_rul(df)
    cfg.MODEL_DIR.mkdir(parents=True,exist_ok=True)
    joblib.dump(soc,cfg.MODEL_DIR/"soc.pkl"); joblib.dump(rul,cfg.MODEL_DIR/"rul.pkl")
    save_json(cfg.MODEL_DIR/"metrics.json",{"soc_R2":soc["r2"],"rul_R2":rul["r2"]})
    print(f"Training finished · SoC R²={soc['r2']:.3f}  RUL R²={rul['r2']:.3f}")

def infer(test_fp:Path):
    soc=joblib.load(cfg.MODEL_DIR/"soc.pkl"); rul=joblib.load(cfg.MODEL_DIR/"rul.pkl")
    d=build_feature_vector(test_fp,cfg.INCLUDE_DRT); feat=d["features"]
    soc_m,soc_s=predict_soc(soc,feat); rul_m,rul_s=predict_rul(rul["gp"],feat)
    cpp_map=build_cpp_map(cfg.CAP_DIR); cpp_est=get_cpp(d["cpp"],cpp_map)
    base=test_fp.with_suffix("")
    save_json(Path(f"{base}_prediction.json"),
              dict(predicted_SoC_percent=soc_m,SoC_std_estimate=soc_s,
                   predicted_cycles_remaining=rul_m,cycles_remaining_std=rul_s,
                   fallback_cpp_cycles_remaining=cpp_est))
    cy,ref=plot_projection(test_fp)
    save_plot(Path(f"{base}_projection.png"),cy,rul_m-cy,np.full_like(cy,rul_s),ref)
    print("Inference done → JSON & PNG saved.")

# ── 7 · CLI ───────────────────────────────────────────────────────
# ── 7 · CLI ───────────────────────────────────────────────────────
def main():
    ap = argparse.ArgumentParser(description="One-file EIS training & inference")
    ap.add_argument("--test", type=str, default=str(cfg.EIS_TEST_FILE))
    ap.add_argument("--retrain", action="store_true")
    
    # ↓↓↓  ONE-LINE FIX: ignore any extra flags Jupyter adds (like “-f …json”)
    args, _ = ap.parse_known_args()          # <= replaces:  args = ap.parse_args()

    if args.retrain or cfg.FORCE_RETRAIN or not (cfg.MODEL_DIR / "soc.pkl").exists():
        train_and_save()
    infer(Path(args.test))

# ------------------------------------------------------------------
# run automatically in both contexts
if __name__ == "__main__":
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        main()


KeyError: "Cell02_95SOH_15degC_05SOC_9505.mat: none of ['freq', 'frequency', 'f'] found"