In [None]:
#!/usr/bin/env python3
"""

Generates open-loop and closed-loop forecasts for the SFC data, with calibrated
CQR prediction bands + full metrics & plots.

Mapping to Fig S5 (page 65) as per info given:
    • "prediction model" in the paper  →  **closed-loop** here
    • "simulation model"               →  **open-loop** here
"""

# ──  Imports  ──────────────────────────────────────────────────────────────
from __future__ import annotations

import json, pickle
from pathlib import Path
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import trange
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import mean_absolute_error, mean_squared_error

# ──  Paths & runtime constants  ────────────────────────────────────────────
TEST_DIR   = Path(r"C:/Users/nishi/Desktop/MLME Project/dta")
MODEL_ROOT = Path(r"C:/Users/nishi/Desktop/MLME Project/model_5files")
OUT_DIR    = MODEL_ROOT/'Results_CQR_5files'
OUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Unit configuration ---------------------------------------------------
# PSD diameters are handled in metres throughout (see to_metres).
PSD_COLS    = ('d10', 'd50', 'd90')

# ----------  Load metadata to stay in sync with training  -----------------
# JSON is UTF-8 by specification; be explicit so the platform locale
# encoding (e.g. cp1252 on Windows) is never used to decode it.
meta         = json.loads((MODEL_ROOT/'metadata.json').read_text(encoding='utf-8'))
STATE_COLS   = meta['state_cols']
EXOG_COLS    = meta['exog_cols']
LAG          = meta['lag']
CLUST_COLS   = STATE_COLS + EXOG_COLS
HORIZON      = 5         # closed-loop rollout length  (= Fig S5 horizon)

print(f"[INFO]  LAG = {LAG},  horizon = {HORIZON}")

# ──  Pre-processing helpers (same as training)  ───────────────────────────
def read_txt(p):
    """Load a tab-separated text file and coerce every column to numeric.

    Cells that cannot be parsed as numbers become NaN.
    """
    frame = pd.read_csv(p, sep='\t', engine='python')
    return frame.apply(pd.to_numeric, errors='coerce')
def clean_df(df):
    """
    Drop rows that are unusable for the model.

    First removes rows with NaN in any of CLUST_COLS, then keeps only rows with
    T_PM and T_TM in [250, 400] (inclusive), strictly positive d10 / d50 / d90,
    and non-negative mf_PM, mf_TM and Q_g.  The index is reset to 0 … n-1.
    """
    df = df.dropna(subset=CLUST_COLS)
    valid = (df.T_PM.between(250, 400) & df.T_TM.between(250, 400)
             & (df.d10 > 0) & (df.d50 > 0) & (df.d90 > 0)
             & (df.mf_PM >= 0) & (df.mf_TM >= 0) & (df.Q_g >= 0))
    return df[valid].reset_index(drop=True)
def to_metres(df):
    """
    Express the PSD columns (d10 / d50 / d90) in metres; modifies and returns df.

    Per column, a median (NaNs skipped) above 1e-2 — more than 1 cm — is taken
    to mean the column is in µm, so it is divided by 1e6.  Any other column is
    assumed to be in metres already and is left untouched.
    """
    micron_cols = [name for name in PSD_COLS
                   if df[name].median(skipna=True) > 1e-2]
    for name in micron_cols:
        df[name] /= 1e6                     # µm → m
    return df


def preprocess(path: Path) -> pd.DataFrame:
    """Read one raw file, drop invalid rows and express d10/d50/d90 in metres."""
    return to_metres(clean_df(read_txt(path)))


# ──  Cluster artefacts  ───────────────────────────────────────────────────
# File-level clustering artefacts: a feature scaler and the KMeans model
# used by detect_cluster().
# NOTE(review): unpickling executes code — only load trusted artefacts.
sc_feat, kmeans = (
    pickle.loads((MODEL_ROOT/fname).read_bytes())
    for fname in ('feature_scaler.pkl', 'kmeans_model.pkl')
)

def file_signature(df):
    """Summarise a file as [mean | std | min | max] of CLUST_COLS, shape (1, 4·n).

    This must be the same feature vector that was used to train the clustering.
    """
    values = df[CLUST_COLS].values
    stats = (values.mean(axis=0), values.std(axis=0),
             values.min(axis=0), values.max(axis=0))
    return np.concatenate(stats).reshape(1, -1)

def detect_cluster(df) -> int:
    """Return the cluster id that the pickled `kmeans` model assigns to this file."""
    scaled_signature = sc_feat.transform(file_signature(df))
    labels = kmeans.predict(scaled_signature)
    return int(labels[0])

# ──  Build lag matrix (newest-to-oldest)  ──────────────────────────────────
def build_lagged(df, lag=LAG):
    """
    Build the open-loop NARX design matrix, newest-to-oldest.

    For every window end i = lag … len(df) - 2, the row is the concatenation
    of the CLUST_COLS values of rows i, i-1, …, i-lag.  The last row of ``df``
    never ends a window because its target y_{i+1} would not exist.

    Returns
    -------
    float32 array of shape (len(df) - lag - 1, (lag + 1) * len(CLUST_COLS));
    it has zero rows (but still two dimensions) when ``df`` is too short.
    """
    data   = df[CLUST_COLS].to_numpy()      # hoisted: no DataFrame indexing per cell
    n_feat = (lag + 1) * data.shape[1]
    rows   = [data[i - lag : i + 1][::-1].ravel()   # rows i … i-lag, newest first
              for i in range(lag, len(data) - 1)]
    if not rows:
        return np.empty((0, n_feat), np.float32)
    return np.asarray(rows, np.float32)

# ──  Load per-cluster artefacts  ───────────────────────────────────────────
def load_cluster(cid):
    """Load the X scaler, Y scaler and NARX model of cluster ``cid``.

    Returns ``(scX, scY, narx)``.  The model is loaded with ``compile=False``;
    it is only used for prediction in this script.
    NOTE(review): the scalers are unpickled — MODEL_ROOT must be trusted.
    """
    def unpickle(rel_path):
        return pickle.loads((MODEL_ROOT/rel_path).read_bytes())

    scaler_x = unpickle(f'narx/scaler_X_{cid}.pkl')
    scaler_y = unpickle(f'narx/scaler_Y_{cid}.pkl')
    model = tf.keras.models.load_model(MODEL_ROOT/f'narx/cluster_{cid}',
                                       compile=False)
    return scaler_x, scaler_y, model

# ──  Closed-loop rollout  (scaled space)  ──────────────────────────────────
def rollout(model, lag_scaled, horizon, scX, scY, exog_future_raw):
    """
    Predict a ``horizon``-step sequence *closed-loop* (recursive NARX).

    After each step the history is shifted back by one (y, u) block, the
    predicted (scaled) state is written into the newest state slot and the
    *true* exogenous input of the next step into the newest input slot, so
    later steps are conditioned on the model's own forecasts.

    Parameters
    ----------
    model : Keras model mapping a (1, input_dim) scaled lag vector to the
        (1, |STATE_COLS|) scaled next state.
    lag_scaled : 1-D scaled lag vector, newest (y, u) block first; not modified.
    horizon : number of steps to roll out; must be >= 1.
    scX, scY : fitted scalers of the lag vector / the targets; only
        ``scX.mean_``, ``scX.scale_`` and ``scY.inverse_transform`` are used.
    exog_future_raw : raw exogenous inputs; row k is inserted after step k.
        Only the first ``horizon - 1`` rows are read.

    Returns
    -------
    np.ndarray of shape (horizon, |STATE_COLS|) in raw (unscaled) units.

    Raises
    ------
    ValueError if ``horizon < 1``.
    """
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon}")

    n_y    = len(STATE_COLS)
    n_u    = len(EXOG_COLS)
    stride = n_y + n_u                      # one (y+u) block per time step

    # The scaling of the newest exogenous slot is loop-invariant.
    mu_u  = scX.mean_[n_y:stride]
    sig_u = scX.scale_[n_y:stride]

    x = np.array(lag_scaled, copy=True)     # private working copy
    preds_scaled = []
    for k in range(horizon):
        # Calling the model directly avoids Model.predict's per-call setup
        # cost, which dominates for a single sample.
        y_scaled = np.asarray(model(x[None], training=False))[0]   # (n_y,)
        preds_scaled.append(y_scaled)
        if k == horizon - 1:
            break                           # no later step reads the history

        # --- push history backwards (newest-first layout) ---
        x[stride:]    = x[:-stride].copy()  # copy: source and target overlap
        x[:n_y]       = y_scaled            # newest state  = prediction
        x[n_y:stride] = (exog_future_raw[k] - mu_u) / sig_u   # newest input = truth

    # sklearn scalers transform each row independently, so one batched
    # inverse_transform equals the per-step calls.
    return scY.inverse_transform(np.asarray(preds_scaled))

def predict_closed(df, scX, scY, narx, horizon=HORIZON):
    """
    Rolling closed-loop (multi-step) prediction with stride = 1.

    For each start t0 the lag window covers rows t0 … i with i = t0 + LAG
    (newest first).  The model is rolled out ``horizon`` steps with the *true*
    exogenous inputs of rows i+1, i+2, …, and only the final step (the
    forecast for row i + horizon) is kept.

    Returns
    -------
    df_out : ground-truth rows LAG+horizon … matching the predictions
             (index reset to 0).
    Xs_all : scaled lag vectors, one per prediction (input of the QR nets).
    Yh_all : raw h-step predictions, shape (rows, |STATE_COLS|).

    Raises
    ------
    ValueError if ``df`` has no more than LAG + horizon rows.
    """
    data   = df[CLUST_COLS].to_numpy()      # hoisted: no per-row DataFrame indexing
    exog   = df[EXOG_COLS].to_numpy()
    usable = len(df) - LAG - horizon        # == (len(df) - 1 - LAG) - horizon + 1
    if usable <= 0:
        raise ValueError(f"need more than LAG + horizon = {LAG + horizon} rows, "
                         f"got {len(df)}")

    Xs_all, Yh_all = [], []
    for t0 in trange(usable, desc="closed", leave=False):
        i       = t0 + LAG                                  # newest row of the window
        lag_raw = data[t0:i + 1][::-1].ravel().astype(np.float32)   # newest-first
        lag_s   = scX.transform(lag_raw[None])[0]
        # The inputs that follow the window start at row i + 1 (not t0 + 1).
        exog_f  = exog[i + 1 : i + 1 + horizon]
        y_seq   = rollout(narx, lag_s, horizon, scX, scY, exog_f)
        Xs_all.append(lag_s)
        Yh_all.append(y_seq[-1])            # only last step (t+h)

    df_out = df.iloc[LAG+horizon : LAG+horizon+usable].reset_index(drop=True)
    return df_out, np.vstack(Xs_all), np.vstack(Yh_all)

def predict_open(df, scX, scY, narx):
    """
    One-step-ahead (open-loop) prediction for every complete lag window.

    Returns the ground-truth rows ``df[LAG+1:]`` (index reset), the scaled
    design matrix fed to the model, and the inverse-scaled predictions, all
    aligned row-for-row.
    """
    truth    = df.iloc[LAG+1:].reset_index(drop=True)
    Xs       = scX.transform(build_lagged(df))
    y_scaled = narx.predict(Xs, verbose=0)
    return truth, Xs, scY.inverse_transform(y_scaled)

# ──  QR nets & conformal deltas  ───────────────────────────────────────────
# One quantile-regression net per (state, quantile) pair, loaded for inference.
QR = {
    (col, q): tf.keras.models.load_model(MODEL_ROOT/f'qr/{col}_{q:.1f}',
                                         compile=False)
    for col in STATE_COLS
    for q in (0.1, 0.9)
}
# Per-state conformal offsets; add_cqr widens the QR band by DELTAS[col] on both sides.
# NOTE(review): unpickling executes code — only load trusted artefacts.
DELTAS = pickle.loads((MODEL_ROOT/'conformal_deltas.pkl').read_bytes())

def add_cqr(df, Xs, base_pred, mode: str):
    """
    Return a copy of ``df`` with the point prediction and CQR band per state.

    For every ``col`` in STATE_COLS (position i) the columns added are
        ``<col>_<mode>``     = base_pred[:, i]
        ``<col>_<mode>_lo``  = base_pred[:, i] - q_lo(Xs) - DELTAS[col]
        ``<col>_<mode>_hi``  = base_pred[:, i] + q_hi(Xs) + DELTAS[col]
    where q_lo / q_hi are the outputs of the 0.1 / 0.9 quantile nets.

    NOTE(review): ``pred - q_lo`` / ``pred + q_hi`` is only right if the QR
    nets output non-negative residual *magnitudes* below/above the forecast.
    If they predict signed residual quantiles (or the target quantiles
    themselves) the band is wrong — confirm against the training script.
    NOTE(review): the QR nets only see the lag vector, so for mode="closed"
    the same band is reused for the h-step closed-loop forecast — TODO confirm
    this is intended.
    """
    out = df.copy()
    for idx, col in enumerate(STATE_COLS):
        point = base_pred[:, idx]
        q_lo  = QR[(col, 0.1)].predict(Xs, verbose=0).flatten()
        q_hi  = QR[(col, 0.9)].predict(Xs, verbose=0).flatten()
        delta = DELTAS[col]
        out[f"{col}_{mode}"]    = point
        out[f"{col}_{mode}_lo"] = point - q_lo - delta
        out[f"{col}_{mode}_hi"] = point + q_hi + delta
    return out

# ──  Metrics helper  ───────────────────────────────────────────────────────
def metric_table(df: pd.DataFrame, mode: str, cols=None):
    """
    Per-variable MAE, MSE and empirical band coverage for one forecast mode.

    Parameters
    ----------
    df   : frame with the truth column ``<col>`` and the prediction columns
           ``<col>_<mode>``, ``<col>_<mode>_lo`` and ``<col>_<mode>_hi``.
    mode : suffix of the prediction columns, e.g. "open" or "closed".
    cols : variables to score; defaults to STATE_COLS.

    Returns
    -------
    dict with ``<col>_MAE``, ``<col>_MSE`` and ``<col>_COV`` (coverage in %).
    Rows where truth or point prediction is not finite are ignored; when no
    row is usable the three metrics are NaN instead of raising.
    """
    if cols is None:
        cols = STATE_COLS
    res = {}
    for col in cols:
        y_true = df[col].to_numpy(dtype=float)
        y_pred = df[f"{col}_{mode}"].to_numpy(dtype=float)
        lo     = df[f"{col}_{mode}_lo"].to_numpy(dtype=float)
        hi     = df[f"{col}_{mode}_hi"].to_numpy(dtype=float)
        msk    = np.isfinite(y_true) & np.isfinite(y_pred)

        if not msk.any():
            # nothing to score, e.g. a file shorter than LAG + HORIZON
            res[f"{col}_MAE"] = res[f"{col}_MSE"] = res[f"{col}_COV"] = np.nan
            continue

        err = y_pred[msk] - y_true[msk]
        res[f"{col}_MAE"] = float(np.mean(np.abs(err)))
        res[f"{col}_MSE"] = float(np.mean(err ** 2))
        inside            = (y_true[msk] >= lo[msk]) & (y_true[msk] <= hi[msk])
        res[f"{col}_COV"] = 100. * float(inside.mean())
    return res

# ----------  Plot helpers  -------------------------------------------------
def plot_ts(df, out, mode):
    """
    Save one time-series plot (truth, prediction, CQR band) per state variable.

    Files are written to ``<out>/<col>_<mode>.png``; the x-axis is the row
    position in ``df``.
    """
    t = np.arange(len(df))
    for col in STATE_COLS:
        fig = plt.figure(figsize=(7,3))
        try:
            plt.plot(t, df[col],  label='truth', lw=1)
            plt.plot(t, df[f'{col}_{mode}'], label='pred', lw=1)
            # The band is built from the 0.1 / 0.9 quantile nets (+ conformal
            # delta): nominally 80 %, not 90 %, unless delta was calibrated for
            # another level — so the label states the quantiles, not a percentage.
            plt.fill_between(t, df[f'{col}_{mode}_lo'], df[f'{col}_{mode}_hi'],
                             alpha=.25, label='CQR band (q0.1–q0.9)')
            plt.title(col); plt.tight_layout()
            plt.legend()
            plt.savefig(out/f"{col}_{mode}.png", dpi=150)
        finally:
            plt.close(fig)      # never leak a figure, even if saving fails

def plot_scatter(df, out, mode):
    """
    Save one predicted-vs-true parity scatter per state variable.

    Files are written to ``<out>/<col>_<mode>_scatter.png``.  The red dashed
    line is y = x, spanning the finite values only: rows without a forecast
    (NaN) no longer turn the line's end points into NaN.
    """
    for col in STATE_COLS:
        pair = df[[col, f'{col}_{mode}']].to_numpy(dtype=float)
        fig = plt.figure(figsize=(3.5,3.5))
        try:
            plt.scatter(pair[:, 0], pair[:, 1], s=8, alpha=.6)
            finite = pair[np.isfinite(pair)]
            if finite.size:
                mn, mx = finite.min(), finite.max()
                plt.plot([mn, mx], [mn, mx], 'r--')
            plt.xlabel('truth'); plt.ylabel('pred')
            plt.title(col); plt.tight_layout()
            plt.savefig(out/f"{col}_{mode}_scatter.png", dpi=150)
        finally:
            plt.close(fig)      # never leak a figure, even if saving fails

# ──────────────────────────────────────────────────────────────────────────
#  Main loop over test files
# --------------------------------------------------------------------------
summary = []
for p in sorted(TEST_DIR.glob("*.txt")):
    stem  = p.stem
    out_f = OUT_DIR / stem
    out_f.mkdir(exist_ok=True)
    print(f"\n⚙  Processing {stem} …")

    try:
        # 1. preprocess & cluster
        df    = preprocess(p)
        cid   = detect_cluster(df)
        scX, scY, narx = load_cluster(cid)

        # 2. open-loop: row j is the forecast for df row LAG + 1 + j
        df_o, Xo_s, y_open = predict_open(df, scX, scY, narx)
        df_o = add_cqr(df_o, Xo_s, y_open, mode="open")

        # 3. closed-loop: row j is the forecast for df row LAG + HORIZON + j
        df_c, Xc_s, y_closed = predict_closed(df, scX, scY, narx, HORIZON)
        df_c = add_cqr(df_c, Xc_s, y_closed, mode="closed")

        # 4. merge on *time*, not on position: shift the closed-loop rows by
        #    HORIZON - 1 so every forecast sits next to the truth it predicts.
        #    The first HORIZON - 1 rows have no closed-loop forecast (NaN).
        closed_cols = [f"{c}_{m}" for c in STATE_COLS
                       for m in ("closed", "closed_lo", "closed_hi")]
        df_c_aligned = df_c[closed_cols].copy()
        df_c_aligned.index = df_c_aligned.index + (HORIZON - 1)
        df_pred = pd.concat([df_o, df_c_aligned], axis=1)

        # 5. save & plots
        df_pred.to_csv(out_f/"predictions.csv", index=False)
        plot_ts(df_pred, out_f, mode="open")
        plot_ts(df_pred, out_f, mode="closed")
        plot_scatter(df_pred, out_f, mode="open")
        plot_scatter(df_pred, out_f, mode="closed")

        # 6. metrics
        m_open   = metric_table(df_pred, mode="open")
        m_closed = metric_table(df_pred, mode="closed")
        summary.append(
            {"file": stem, **m_open,
                        **{f"{k}_closed": v for k, v in m_closed.items()}}
        )

    except Exception as e:      # top-level boundary: report and continue
        print(f"⨯  {stem} skipped  →  {e}")

# ──  Aggregate summary  ────────────────────────────────────────────────────
df_sum = pd.DataFrame(summary)
df_sum.to_csv(OUT_DIR/"metrics_summary.csv", index=False)

if df_sum.empty:
    # no metric columns exist, so the per-variable averages cannot be built
    print("\n📊  No test file was processed successfully – nothing to average.")
else:
    rows = []
    for col in STATE_COLS:
        rows.append([
            col,
            df_sum[f"{col}_MAE"].mean(),
            df_sum[f"{col}_MAE_closed"].mean(),
            df_sum[f"{col}_MSE"].mean(),
            df_sum[f"{col}_MSE_closed"].mean(),
            df_sum[f"{col}_COV"].mean(),
            df_sum[f"{col}_COV_closed"].mean()
        ])

    print("\n📊  Average error & coverage (open vs closed)\n")
    print(pd.DataFrame(
            rows,
            columns=["Var",
                     "Open MAE", "Closed MAE",
                     "Open MSE", "Closed MSE",
                     "Open COV %", "Closed COV %"]
    ).to_string(index=False, float_format="%.5f"))

[INFO]  LAG = 10,  horizon = 5

⚙  Processing file_20726 …


                                                         


⚙  Processing file_49595 …


                                                         


⚙  Processing file_550 …


                                                         


⚙  Processing file_63816 …


                                                         


⚙  Processing file_96991 …


                                                         


📊  Average error & coverage (open vs closed)

 Var  Open MAE  Closed MAE  Open MSE  Closed MSE  Open COV %  Closed COV %
T_PM   0.16463     0.38134   0.06438     0.48875    99.65622      97.60406
   c   0.00034     0.00074   0.00000     0.00000    94.56016      94.09137
 d10   0.00010     0.00010   0.00000     0.00000    57.59353      56.36548
 d50   0.00011     0.00011   0.00000     0.00000    51.10212      50.17259
 d90   0.00008     0.00008   0.00000     0.00000    46.41052      45.21827
T_TM   0.17469     0.40536   0.07187     0.55408    99.65622      97.52284


In [None]:
#!/usr/bin/env python3
"""

Generates open-loop and closed-loop forecasts for the SFC data, with calibrated
CQR prediction bands + full metrics & plots.

Mapping to Fig S5 (page 65) as per info given:
    • "prediction model" in the paper  →  **closed-loop** here
    • "simulation model"               →  **open-loop** here
"""

# ──  Imports  ──────────────────────────────────────────────────────────────
from __future__ import annotations

import json, pickle
from pathlib import Path
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import trange
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import mean_absolute_error, mean_squared_error

# ──  Paths & runtime constants  ────────────────────────────────────────────
TEST_DIR   = Path(r"C:/Users/nishi/Desktop/MLME Project/Data/Test")
MODEL_ROOT = Path(r"C:/Users/nishi/Desktop/MLME Project/model_5files")
OUT_DIR    = MODEL_ROOT/'Results_CQR_5files'
OUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Unit configuration ---------------------------------------------------
# PSD diameters are handled in metres throughout (see to_metres).
PSD_COLS    = ('d10', 'd50', 'd90')

# ----------  Load metadata to stay in sync with training  -----------------
# JSON is UTF-8 by specification; be explicit so the platform locale
# encoding (e.g. cp1252 on Windows) is never used to decode it.
meta         = json.loads((MODEL_ROOT/'metadata.json').read_text(encoding='utf-8'))
STATE_COLS   = meta['state_cols']
EXOG_COLS    = meta['exog_cols']
LAG          = meta['lag']
CLUST_COLS   = STATE_COLS + EXOG_COLS
HORIZON      = 5         # closed-loop rollout length  (= Fig S5 horizon)

print(f"[INFO]  LAG = {LAG},  horizon = {HORIZON}")

# ──  Pre-processing helpers (same as training)  ───────────────────────────
def read_txt(p):
    """Load a tab-separated text file and coerce every column to numeric.

    Cells that cannot be parsed as numbers become NaN.
    """
    frame = pd.read_csv(p, sep='\t', engine='python')
    return frame.apply(pd.to_numeric, errors='coerce')
def clean_df(df):
    """
    Drop rows that are unusable for the model.

    First removes rows with NaN in any of CLUST_COLS, then keeps only rows with
    T_PM and T_TM in [250, 400] (inclusive), strictly positive d10 / d50 / d90,
    and non-negative mf_PM, mf_TM and Q_g.  The index is reset to 0 … n-1.
    """
    df = df.dropna(subset=CLUST_COLS)
    valid = (df.T_PM.between(250, 400) & df.T_TM.between(250, 400)
             & (df.d10 > 0) & (df.d50 > 0) & (df.d90 > 0)
             & (df.mf_PM >= 0) & (df.mf_TM >= 0) & (df.Q_g >= 0))
    return df[valid].reset_index(drop=True)
def to_metres(df):
    """
    Express the PSD columns (d10 / d50 / d90) in metres; modifies and returns df.

    Per column, a median (NaNs skipped) above 1e-2 — more than 1 cm — is taken
    to mean the column is in µm, so it is divided by 1e6.  Any other column is
    assumed to be in metres already and is left untouched.
    """
    micron_cols = [name for name in PSD_COLS
                   if df[name].median(skipna=True) > 1e-2]
    for name in micron_cols:
        df[name] /= 1e6                     # µm → m
    return df


def preprocess(path: Path) -> pd.DataFrame:
    """Read one raw file, drop invalid rows and express d10/d50/d90 in metres."""
    return to_metres(clean_df(read_txt(path)))


# ──  Cluster artefacts  ───────────────────────────────────────────────────
# File-level clustering artefacts: a feature scaler and the KMeans model
# used by detect_cluster().
# NOTE(review): unpickling executes code — only load trusted artefacts.
sc_feat, kmeans = (
    pickle.loads((MODEL_ROOT/fname).read_bytes())
    for fname in ('feature_scaler.pkl', 'kmeans_model.pkl')
)

def file_signature(df):
    """Summarise a file as [mean | std | min | max] of CLUST_COLS, shape (1, 4·n).

    This must be the same feature vector that was used to train the clustering.
    """
    values = df[CLUST_COLS].values
    stats = (values.mean(axis=0), values.std(axis=0),
             values.min(axis=0), values.max(axis=0))
    return np.concatenate(stats).reshape(1, -1)

def detect_cluster(df) -> int:
    """Return the cluster id that the pickled `kmeans` model assigns to this file."""
    scaled_signature = sc_feat.transform(file_signature(df))
    labels = kmeans.predict(scaled_signature)
    return int(labels[0])

# ──  Build lag matrix (newest-to-oldest)  ──────────────────────────────────
def build_lagged(df, lag=LAG):
    """
    Build the open-loop NARX design matrix, newest-to-oldest.

    For every window end i = lag … len(df) - 2, the row is the concatenation
    of the CLUST_COLS values of rows i, i-1, …, i-lag.  The last row of ``df``
    never ends a window because its target y_{i+1} would not exist.

    Returns
    -------
    float32 array of shape (len(df) - lag - 1, (lag + 1) * len(CLUST_COLS));
    it has zero rows (but still two dimensions) when ``df`` is too short.
    """
    data   = df[CLUST_COLS].to_numpy()      # hoisted: no DataFrame indexing per cell
    n_feat = (lag + 1) * data.shape[1]
    rows   = [data[i - lag : i + 1][::-1].ravel()   # rows i … i-lag, newest first
              for i in range(lag, len(data) - 1)]
    if not rows:
        return np.empty((0, n_feat), np.float32)
    return np.asarray(rows, np.float32)

# ──  Load per-cluster artefacts  ───────────────────────────────────────────
def load_cluster(cid):
    """Load the X scaler, Y scaler and NARX model of cluster ``cid``.

    Returns ``(scX, scY, narx)``.  The model is loaded with ``compile=False``;
    it is only used for prediction in this script.
    NOTE(review): the scalers are unpickled — MODEL_ROOT must be trusted.
    """
    def unpickle(rel_path):
        return pickle.loads((MODEL_ROOT/rel_path).read_bytes())

    scaler_x = unpickle(f'narx/scaler_X_{cid}.pkl')
    scaler_y = unpickle(f'narx/scaler_Y_{cid}.pkl')
    model = tf.keras.models.load_model(MODEL_ROOT/f'narx/cluster_{cid}',
                                       compile=False)
    return scaler_x, scaler_y, model

# ──  Closed-loop rollout  (scaled space)  ──────────────────────────────────
def rollout(model, lag_scaled, horizon, scX, scY, exog_future_raw):
    """
    Predict a ``horizon``-step sequence *closed-loop* (recursive NARX).

    After each step the history is shifted back by one (y, u) block, the
    predicted (scaled) state is written into the newest state slot and the
    *true* exogenous input of the next step into the newest input slot, so
    later steps are conditioned on the model's own forecasts.

    Parameters
    ----------
    model : Keras model mapping a (1, input_dim) scaled lag vector to the
        (1, |STATE_COLS|) scaled next state.
    lag_scaled : 1-D scaled lag vector, newest (y, u) block first; not modified.
    horizon : number of steps to roll out; must be >= 1.
    scX, scY : fitted scalers of the lag vector / the targets; only
        ``scX.mean_``, ``scX.scale_`` and ``scY.inverse_transform`` are used.
    exog_future_raw : raw exogenous inputs; row k is inserted after step k.
        Only the first ``horizon - 1`` rows are read.

    Returns
    -------
    np.ndarray of shape (horizon, |STATE_COLS|) in raw (unscaled) units.

    Raises
    ------
    ValueError if ``horizon < 1``.
    """
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon}")

    n_y    = len(STATE_COLS)
    n_u    = len(EXOG_COLS)
    stride = n_y + n_u                      # one (y+u) block per time step

    # The scaling of the newest exogenous slot is loop-invariant.
    mu_u  = scX.mean_[n_y:stride]
    sig_u = scX.scale_[n_y:stride]

    x = np.array(lag_scaled, copy=True)     # private working copy
    preds_scaled = []
    for k in range(horizon):
        # Calling the model directly avoids Model.predict's per-call setup
        # cost, which dominates for a single sample.
        y_scaled = np.asarray(model(x[None], training=False))[0]   # (n_y,)
        preds_scaled.append(y_scaled)
        if k == horizon - 1:
            break                           # no later step reads the history

        # --- push history backwards (newest-first layout) ---
        x[stride:]    = x[:-stride].copy()  # copy: source and target overlap
        x[:n_y]       = y_scaled            # newest state  = prediction
        x[n_y:stride] = (exog_future_raw[k] - mu_u) / sig_u   # newest input = truth

    # sklearn scalers transform each row independently, so one batched
    # inverse_transform equals the per-step calls.
    return scY.inverse_transform(np.asarray(preds_scaled))

def predict_closed(df, scX, scY, narx, horizon=HORIZON):
    """
    Rolling closed-loop (multi-step) prediction with stride = 1.

    For each start t0 the lag window covers rows t0 … i with i = t0 + LAG
    (newest first).  The model is rolled out ``horizon`` steps with the *true*
    exogenous inputs of rows i+1, i+2, …, and only the final step (the
    forecast for row i + horizon) is kept.

    Returns
    -------
    df_out : ground-truth rows LAG+horizon … matching the predictions
             (index reset to 0).
    Xs_all : scaled lag vectors, one per prediction (input of the QR nets).
    Yh_all : raw h-step predictions, shape (rows, |STATE_COLS|).

    Raises
    ------
    ValueError if ``df`` has no more than LAG + horizon rows.
    """
    data   = df[CLUST_COLS].to_numpy()      # hoisted: no per-row DataFrame indexing
    exog   = df[EXOG_COLS].to_numpy()
    usable = len(df) - LAG - horizon        # == (len(df) - 1 - LAG) - horizon + 1
    if usable <= 0:
        raise ValueError(f"need more than LAG + horizon = {LAG + horizon} rows, "
                         f"got {len(df)}")

    Xs_all, Yh_all = [], []
    for t0 in trange(usable, desc="closed", leave=False):
        i       = t0 + LAG                                  # newest row of the window
        lag_raw = data[t0:i + 1][::-1].ravel().astype(np.float32)   # newest-first
        lag_s   = scX.transform(lag_raw[None])[0]
        # The inputs that follow the window start at row i + 1 (not t0 + 1).
        exog_f  = exog[i + 1 : i + 1 + horizon]
        y_seq   = rollout(narx, lag_s, horizon, scX, scY, exog_f)
        Xs_all.append(lag_s)
        Yh_all.append(y_seq[-1])            # only last step (t+h)

    df_out = df.iloc[LAG+horizon : LAG+horizon+usable].reset_index(drop=True)
    return df_out, np.vstack(Xs_all), np.vstack(Yh_all)

def predict_open(df, scX, scY, narx):
    """
    One-step-ahead (open-loop) prediction for every complete lag window.

    Returns the ground-truth rows ``df[LAG+1:]`` (index reset), the scaled
    design matrix fed to the model, and the inverse-scaled predictions, all
    aligned row-for-row.
    """
    truth    = df.iloc[LAG+1:].reset_index(drop=True)
    Xs       = scX.transform(build_lagged(df))
    y_scaled = narx.predict(Xs, verbose=0)
    return truth, Xs, scY.inverse_transform(y_scaled)

# ──  QR nets & conformal deltas  ───────────────────────────────────────────
# One quantile-regression net per (state, quantile) pair, loaded for inference.
QR = {
    (col, q): tf.keras.models.load_model(MODEL_ROOT/f'qr/{col}_{q:.1f}',
                                         compile=False)
    for col in STATE_COLS
    for q in (0.1, 0.9)
}
# Per-state conformal offsets; add_cqr widens the QR band by DELTAS[col] on both sides.
# NOTE(review): unpickling executes code — only load trusted artefacts.
DELTAS = pickle.loads((MODEL_ROOT/'conformal_deltas.pkl').read_bytes())

def add_cqr(df, Xs, base_pred, mode: str):
    """
    Return a copy of ``df`` with the point prediction and CQR band per state.

    For every ``col`` in STATE_COLS (position i) the columns added are
        ``<col>_<mode>``     = base_pred[:, i]
        ``<col>_<mode>_lo``  = base_pred[:, i] - q_lo(Xs) - DELTAS[col]
        ``<col>_<mode>_hi``  = base_pred[:, i] + q_hi(Xs) + DELTAS[col]
    where q_lo / q_hi are the outputs of the 0.1 / 0.9 quantile nets.

    NOTE(review): ``pred - q_lo`` / ``pred + q_hi`` is only right if the QR
    nets output non-negative residual *magnitudes* below/above the forecast.
    If they predict signed residual quantiles (or the target quantiles
    themselves) the band is wrong — confirm against the training script.
    NOTE(review): the QR nets only see the lag vector, so for mode="closed"
    the same band is reused for the h-step closed-loop forecast — TODO confirm
    this is intended.
    """
    out = df.copy()
    for idx, col in enumerate(STATE_COLS):
        point = base_pred[:, idx]
        q_lo  = QR[(col, 0.1)].predict(Xs, verbose=0).flatten()
        q_hi  = QR[(col, 0.9)].predict(Xs, verbose=0).flatten()
        delta = DELTAS[col]
        out[f"{col}_{mode}"]    = point
        out[f"{col}_{mode}_lo"] = point - q_lo - delta
        out[f"{col}_{mode}_hi"] = point + q_hi + delta
    return out

# ──  Metrics helper  ───────────────────────────────────────────────────────
def metric_table(df: pd.DataFrame, mode: str, cols=None):
    """
    Per-variable MAE, MSE and empirical band coverage for one forecast mode.

    Parameters
    ----------
    df   : frame with the truth column ``<col>`` and the prediction columns
           ``<col>_<mode>``, ``<col>_<mode>_lo`` and ``<col>_<mode>_hi``.
    mode : suffix of the prediction columns, e.g. "open" or "closed".
    cols : variables to score; defaults to STATE_COLS.

    Returns
    -------
    dict with ``<col>_MAE``, ``<col>_MSE`` and ``<col>_COV`` (coverage in %).
    Rows where truth or point prediction is not finite are ignored; when no
    row is usable the three metrics are NaN instead of raising.
    """
    if cols is None:
        cols = STATE_COLS
    res = {}
    for col in cols:
        y_true = df[col].to_numpy(dtype=float)
        y_pred = df[f"{col}_{mode}"].to_numpy(dtype=float)
        lo     = df[f"{col}_{mode}_lo"].to_numpy(dtype=float)
        hi     = df[f"{col}_{mode}_hi"].to_numpy(dtype=float)
        msk    = np.isfinite(y_true) & np.isfinite(y_pred)

        if not msk.any():
            # nothing to score, e.g. a file shorter than LAG + HORIZON
            res[f"{col}_MAE"] = res[f"{col}_MSE"] = res[f"{col}_COV"] = np.nan
            continue

        err = y_pred[msk] - y_true[msk]
        res[f"{col}_MAE"] = float(np.mean(np.abs(err)))
        res[f"{col}_MSE"] = float(np.mean(err ** 2))
        inside            = (y_true[msk] >= lo[msk]) & (y_true[msk] <= hi[msk])
        res[f"{col}_COV"] = 100. * float(inside.mean())
    return res

# ----------  Plot helpers  -------------------------------------------------
def plot_ts(df, out, mode):
    """
    Save one time-series plot (truth, prediction, CQR band) per state variable.

    Files are written to ``<out>/<col>_<mode>.png``; the x-axis is the row
    position in ``df``.
    """
    t = np.arange(len(df))
    for col in STATE_COLS:
        fig = plt.figure(figsize=(7,3))
        try:
            plt.plot(t, df[col],  label='truth', lw=1)
            plt.plot(t, df[f'{col}_{mode}'], label='pred', lw=1)
            # The band is built from the 0.1 / 0.9 quantile nets (+ conformal
            # delta): nominally 80 %, not 90 %, unless delta was calibrated for
            # another level — so the label states the quantiles, not a percentage.
            plt.fill_between(t, df[f'{col}_{mode}_lo'], df[f'{col}_{mode}_hi'],
                             alpha=.25, label='CQR band (q0.1–q0.9)')
            plt.title(col); plt.tight_layout()
            plt.legend()
            plt.savefig(out/f"{col}_{mode}.png", dpi=150)
        finally:
            plt.close(fig)      # never leak a figure, even if saving fails

def plot_scatter(df, out, mode):
    """
    Save one predicted-vs-true parity scatter per state variable.

    Files are written to ``<out>/<col>_<mode>_scatter.png``.  The red dashed
    line is y = x, spanning the finite values only: rows without a forecast
    (NaN) no longer turn the line's end points into NaN.
    """
    for col in STATE_COLS:
        pair = df[[col, f'{col}_{mode}']].to_numpy(dtype=float)
        fig = plt.figure(figsize=(3.5,3.5))
        try:
            plt.scatter(pair[:, 0], pair[:, 1], s=8, alpha=.6)
            finite = pair[np.isfinite(pair)]
            if finite.size:
                mn, mx = finite.min(), finite.max()
                plt.plot([mn, mx], [mn, mx], 'r--')
            plt.xlabel('truth'); plt.ylabel('pred')
            plt.title(col); plt.tight_layout()
            plt.savefig(out/f"{col}_{mode}_scatter.png", dpi=150)
        finally:
            plt.close(fig)      # never leak a figure, even if saving fails

# ──────────────────────────────────────────────────────────────────────────
#  Main loop over test files
# --------------------------------------------------------------------------
summary = []
for p in sorted(TEST_DIR.glob("*.txt")):
    stem  = p.stem
    out_f = OUT_DIR / stem
    out_f.mkdir(exist_ok=True)
    print(f"\n⚙  Processing {stem} …")

    try:
        # 1. preprocess & cluster
        df    = preprocess(p)
        cid   = detect_cluster(df)
        scX, scY, narx = load_cluster(cid)

        # 2. open-loop: row j is the forecast for df row LAG + 1 + j
        df_o, Xo_s, y_open = predict_open(df, scX, scY, narx)
        df_o = add_cqr(df_o, Xo_s, y_open, mode="open")

        # 3. closed-loop: row j is the forecast for df row LAG + HORIZON + j
        df_c, Xc_s, y_closed = predict_closed(df, scX, scY, narx, HORIZON)
        df_c = add_cqr(df_c, Xc_s, y_closed, mode="closed")

        # 4. merge on *time*, not on position: shift the closed-loop rows by
        #    HORIZON - 1 so every forecast sits next to the truth it predicts.
        #    The first HORIZON - 1 rows have no closed-loop forecast (NaN).
        closed_cols = [f"{c}_{m}" for c in STATE_COLS
                       for m in ("closed", "closed_lo", "closed_hi")]
        df_c_aligned = df_c[closed_cols].copy()
        df_c_aligned.index = df_c_aligned.index + (HORIZON - 1)
        df_pred = pd.concat([df_o, df_c_aligned], axis=1)

        # 5. save & plots
        df_pred.to_csv(out_f/"predictions.csv", index=False)
        plot_ts(df_pred, out_f, mode="open")
        plot_ts(df_pred, out_f, mode="closed")
        plot_scatter(df_pred, out_f, mode="open")
        plot_scatter(df_pred, out_f, mode="closed")

        # 6. metrics
        m_open   = metric_table(df_pred, mode="open")
        m_closed = metric_table(df_pred, mode="closed")
        summary.append(
            {"file": stem, **m_open,
                        **{f"{k}_closed": v for k, v in m_closed.items()}}
        )

    except Exception as e:      # top-level boundary: report and continue
        print(f"⨯  {stem} skipped  →  {e}")

# ──  Aggregate summary  ────────────────────────────────────────────────────
df_sum = pd.DataFrame(summary)
df_sum.to_csv(OUT_DIR/"metrics_summary.csv", index=False)

if df_sum.empty:
    # no metric columns exist, so the per-variable averages cannot be built
    print("\n No test file was processed successfully - nothing to average.")
else:
    rows = []
    for col in STATE_COLS:
        rows.append([
            col,
            df_sum[f"{col}_MAE"].mean(),
            df_sum[f"{col}_MAE_closed"].mean(),
            df_sum[f"{col}_MSE"].mean(),
            df_sum[f"{col}_MSE_closed"].mean(),
            df_sum[f"{col}_COV"].mean(),
            df_sum[f"{col}_COV_closed"].mean()
        ])

    print("\n Average error & coverage (open vs closed)\n")
    print(pd.DataFrame(
            rows,
            columns=["Var",
                     "Open MAE", "Closed MAE",
                     "Open MSE", "Closed MSE",
                     "Open COV %", "Closed COV %"]
    ).to_string(index=False, float_format="%0.8f"))

[INFO]  LAG = 10,  horizon = 5

⚙  Processing file_16361 …


                                                         


⚙  Processing file_20726 …


                                                         


⚙  Processing file_22636 …


                                                         


⚙  Processing file_3388 …


                                                         


⚙  Processing file_34890 …


                                                         


⚙  Processing file_39455 …


                                                         


⚙  Processing file_41551 …


                                                         


⚙  Processing file_49595 …


                                                         


⚙  Processing file_5325 …


                                                         


⚙  Processing file_54889 …


                                                         


⚙  Processing file_550 …


                                                         


⚙  Processing file_56035 …


                                                         


⚙  Processing file_62851 …


                                                         


⚙  Processing file_63816 …


                                                         


⚙  Processing file_68111 …


                                                         


⚙  Processing file_77484 …


                                                         


⚙  Processing file_82278 …


                                                         


⚙  Processing file_87603 …


                                                         


⚙  Processing file_96991 …


                                                         


⚙  Processing file_9985 …


                                                         


 Average error & coverage (open vs closed)

 Var   Open MAE  Closed MAE   Open MSE  Closed MSE  Open COV %  Closed COV %
T_PM 0.14411705  0.30596885 0.05252401  0.34400150 99.81294237   98.63959391
   c 0.00030193  0.00062255 0.00000025  0.00000127 60.55611729   60.09137056
 d10 0.00010021  0.00010313 0.00000186  0.00000186  7.52780586    7.44162437
 d50 0.00010400  0.00010712 0.00000189  0.00000190 10.15166835   10.13197970
 d90 0.00011353  0.00011639 0.00000145  0.00000146 12.14863498   12.14720812
T_TM 0.15961778  0.32704550 0.06174111  0.37979913 99.81294237   98.73604061


In [None]:
#!/usr/bin/env python3
"""

Generates open-loop and closed-loop forecasts for the SFC data, with calibrated
CQR prediction bands + full metrics & plots.

Mapping to Fig S5 (page 65) as per info given:
    • "prediction model" in the paper  →  **closed-loop** here
    • "simulation model"               →  **open-loop** here
"""

# ──  Imports  ──────────────────────────────────────────────────────────────
from __future__ import annotations

import json, pickle
from pathlib import Path
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import trange
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import mean_absolute_error, mean_squared_error

# ──  Paths & runtime constants  ────────────────────────────────────────────
TEST_DIR   = Path(r"C:/Users/nishi/Desktop/MLME Project/Beat-the-Felix")
MODEL_ROOT = Path(r"C:/Users/nishi/Desktop/MLME Project/model_5files")
OUT_DIR    = MODEL_ROOT/'BEAT'
OUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Unit configuration ---------------------------------------------------
# PSD diameters are handled in metres throughout (see to_metres).
PSD_COLS    = ('d10', 'd50', 'd90')

# ----------  Load metadata to stay in sync with training  -----------------
# JSON is UTF-8 by specification; be explicit so the platform locale
# encoding (e.g. cp1252 on Windows) is never used to decode it.
meta         = json.loads((MODEL_ROOT/'metadata.json').read_text(encoding='utf-8'))
STATE_COLS   = meta['state_cols']
EXOG_COLS    = meta['exog_cols']
LAG          = meta['lag']
CLUST_COLS   = STATE_COLS + EXOG_COLS
HORIZON      = 5         # closed-loop rollout length  (= Fig S5 horizon)

print(f"[INFO]  LAG = {LAG},  horizon = {HORIZON}")

# ──  Pre-processing helpers (same as training)  ───────────────────────────
def read_txt(p):
    """Load a tab-separated text file and coerce every column to numeric.

    Cells that cannot be parsed as numbers become NaN.
    """
    frame = pd.read_csv(p, sep='\t', engine='python')
    return frame.apply(pd.to_numeric, errors='coerce')
def clean_df(df):
    """
    Drop rows that are unusable for the model.

    First removes rows with NaN in any of CLUST_COLS, then keeps only rows with
    T_PM and T_TM in [250, 400] (inclusive), strictly positive d10 / d50 / d90,
    and non-negative mf_PM, mf_TM and Q_g.  The index is reset to 0 … n-1.
    """
    df = df.dropna(subset=CLUST_COLS)
    valid = (df.T_PM.between(250, 400) & df.T_TM.between(250, 400)
             & (df.d10 > 0) & (df.d50 > 0) & (df.d90 > 0)
             & (df.mf_PM >= 0) & (df.mf_TM >= 0) & (df.Q_g >= 0))
    return df[valid].reset_index(drop=True)
def to_metres(df):
    """
    Express the PSD columns (d10 / d50 / d90) in metres; modifies and returns df.

    Per column, a median (NaNs skipped) above 1e-2 — more than 1 cm — is taken
    to mean the column is in µm, so it is divided by 1e6.  Any other column is
    assumed to be in metres already and is left untouched.
    """
    micron_cols = [name for name in PSD_COLS
                   if df[name].median(skipna=True) > 1e-2]
    for name in micron_cols:
        df[name] /= 1e6                     # µm → m
    return df


def preprocess(path: Path) -> pd.DataFrame:
    """Read one raw file, drop invalid rows and express d10/d50/d90 in metres."""
    return to_metres(clean_df(read_txt(path)))


# ──  Cluster artefacts  ───────────────────────────────────────────────────
# File-level clustering artefacts: a feature scaler and the KMeans model
# used by detect_cluster().
# NOTE(review): unpickling executes code — only load trusted artefacts.
sc_feat, kmeans = (
    pickle.loads((MODEL_ROOT/fname).read_bytes())
    for fname in ('feature_scaler.pkl', 'kmeans_model.pkl')
)

def file_signature(df):
    """Summarise a file as [mean | std | min | max] of CLUST_COLS, shape (1, 4·n).

    This must be the same feature vector that was used to train the clustering.
    """
    values = df[CLUST_COLS].values
    stats = (values.mean(axis=0), values.std(axis=0),
             values.min(axis=0), values.max(axis=0))
    return np.concatenate(stats).reshape(1, -1)

def detect_cluster(df) -> int:
    """Return the KMeans cluster id assigned to the whole file *df*."""
    scaled_signature = sc_feat.transform(file_signature(df))
    labels = kmeans.predict(scaled_signature)
    return int(labels[0])

# ──  Build lag matrix (newest-to-oldest)  ──────────────────────────────────
def build_lagged(df, lag=LAG):
    """Build the NARX input matrix with newest-first lag blocks.

    Row j corresponds to time index i = lag + j and equals
    [z_i, z_{i-1}, …, z_{i-lag}], where z_t = df[CLUST_COLS].iloc[t].
    Only i < len(df) - 1 is used, because the target y_{i+1} must exist.

    Returns
        float32 array of shape (len(df) - 1 - lag, (lag + 1) * len(CLUST_COLS));
        zero rows when *df* is too short.
    """
    arr = df[CLUST_COLS].to_numpy()
    n_rows = len(arr) - 1 - lag
    if n_rows <= 0:
        return np.empty((0, (lag + 1) * arr.shape[1]), np.float32)
    # Block k of every row is z_{i-k}: the same window shifted back by k rows.
    # Explicit start/stop indices (never negative) keep all blocks equal-length.
    blocks = [arr[lag - k : lag - k + n_rows] for k in range(lag + 1)]
    return np.hstack(blocks).astype(np.float32)

# ──  Load per-cluster artefacts  ───────────────────────────────────────────
def load_cluster(cid):
    """Load the artefacts of cluster *cid* from MODEL_ROOT.

    Returns (scX, scY, narx): the pickled input scaler, the pickled output
    scaler, and the Keras NARX model (loaded without compiling).
    """
    def _unpickle(rel_path):
        # Artefacts are produced by our own training run; pickle is trusted here.
        return pickle.loads((MODEL_ROOT / rel_path).read_bytes())

    scX = _unpickle(f'narx/scaler_X_{cid}.pkl')
    scY = _unpickle(f'narx/scaler_Y_{cid}.pkl')
    narx = tf.keras.models.load_model(MODEL_ROOT / f'narx/cluster_{cid}',
                                      compile=False)
    return scX, scY, narx

# ──  Closed-loop rollout  (scaled space)  ──────────────────────────────────
def rollout(model, lag_scaled, horizon, scX, scY, exog_future_raw):
    """
    Predict a *horizon*-step sequence closed-loop from a single lag vector.

    At each step the model's own state prediction is fed back as the newest
    state, together with the *true* exogenous input for that step.

    Parameters
        model            – NARX net; ``model.predict`` maps a (1, D) lag vector
                           in scX space to a (1, n_y) prediction in scY space.
        lag_scaled       – 1-D lag vector in scX space, newest-first blocks of
                           (STATE_COLS, EXOG_COLS). Not modified (copied).
        horizon          – number of steps to roll out.
        scX, scY         – fitted scalers exposing ``mean_`` / ``scale_`` /
                           ``inverse_transform`` (StandardScaler-like).
        exog_future_raw  – unscaled exogenous rows. Row k is inserted as the
                           newest input after the k-th prediction, so it must
                           belong to the same time step as that prediction.
    Returns
        array of shape (horizon, n_y) with predictions in raw units.
    """
    n_y    = len(STATE_COLS)
    n_u    = len(EXOG_COLS)
    stride = n_y + n_u                      # one (y+u) block per time step
    y_slice = slice(0, n_y)                 # state part of the newest block
    u_slice = slice(n_y, n_y + n_u)         # exog part of the newest block

    # The model input lives in scX space, so BOTH the fed-back state and the
    # exogenous truth must be scaled with scX's statistics for the newest
    # block. scY's statistics (the model's output space) need not match them.
    mu_y, sig_y = scX.mean_[y_slice], scX.scale_[y_slice]
    mu_u, sig_u = scX.mean_[u_slice], scX.scale_[u_slice]

    x     = lag_scaled.copy()               # shape = (input_dim,)
    preds = []
    for k in range(horizon):
        y_scaled = model.predict(x[None], verbose=0)[0]      # scY space
        y_raw    = scY.inverse_transform(y_scaled[None])[0]  # raw units
        preds.append(y_raw)

        # push history one block older (newest-first layout)
        x[stride:] = x[:-stride].copy()
        x[y_slice] = (y_raw - mu_y) / sig_y                  # predicted state
        x[u_slice] = (exog_future_raw[k] - mu_u) / sig_u     # exogenous truth

    return np.asarray(preds)

def predict_closed(df, scX, scY, narx, horizon=HORIZON):
    """
    Rolling closed-loop prediction with stride = 1.

    For every start index t0 the lag window covers df rows t0 … t0+LAG
    (newest-first, newest = t0+LAG). The model is rolled *horizon* steps ahead
    with ``rollout``, using the true exogenous inputs of rows
    t0+LAG+1 … t0+LAG+horizon. Only the final (t+horizon) prediction is kept.

    Returns
        df_out   – ground-truth rows matching predictions
                   (df rows LAG+horizon …, re-indexed from 0)
        Xs_all   – scaled lag vectors at each rollout start (for QR nets)
        Yh_all   – raw predictions      (shape rows × |STATE|)
    Raises
        ValueError – if *df* is too short for a single rollout.
    """
    total  = len(df) - 1 - LAG
    usable = total - horizon + 1
    if usable <= 0:
        raise ValueError(
            f"need at least LAG + horizon + 1 = {LAG + horizon + 1} rows "
            f"for closed-loop prediction, got {len(df)}")

    feats = df[CLUST_COLS].to_numpy()
    exog  = df[EXOG_COLS].to_numpy()
    Xs_all, Yh_all = [], []

    for t0 in trange(usable, desc="closed", leave=False):
        newest  = t0 + LAG
        # lag vector (newest-first): rows newest, newest-1, …, t0
        lag_raw = feats[t0 : newest + 1][::-1].ravel().astype(np.float32)
        lag_s   = scX.transform(lag_raw[None])[0]
        # The window ends at row `newest`, so the first rolled-out step is
        # newest+1; its exogenous input must come from that row onward.
        exog_f  = exog[newest + 1 : newest + 1 + horizon]
        y_seq   = rollout(narx, lag_s, horizon, scX, scY, exog_f)
        Xs_all.append(lag_s)
        Yh_all.append(y_seq[-1])            # only last step (t+h)

    df_out = df.iloc[LAG + horizon : LAG + horizon + usable].reset_index(drop=True)
    return df_out, np.vstack(Xs_all), np.vstack(Yh_all)

def predict_open(df, scX, scY, narx):
    """One-step-ahead (open-loop) prediction from true lagged history.

    Returns
        truth   – df rows LAG+1 … (re-indexed from 0), aligned with predictions
        Xs      – scaled lag matrix fed to the model (reused by the QR nets)
        y_pred  – predictions inverse-transformed to raw units by scY
    """
    Xs = scX.transform(build_lagged(df))
    y_scaled = narx.predict(Xs, verbose=0)
    truth = df.iloc[LAG + 1:].reset_index(drop=True)
    return truth, Xs, scY.inverse_transform(y_scaled)

# ──  QR nets & conformal deltas  ───────────────────────────────────────────
# One quantile-regression net per (state column, quantile level).
QR = {
    (col, q): tf.keras.models.load_model(MODEL_ROOT/f'qr/{col}_{q:.1f}',
                                         compile=False)
    for col in STATE_COLS
    for q in (0.1, 0.9)
}

# Conformal correction terms, keyed by state column.
DELTAS = pickle.loads((MODEL_ROOT/'conformal_deltas.pkl').read_bytes())

# Optional, hand-tuned inflation of the deltas to widen the bands.
# NOTE(review): scaling calibrated deltas by ad-hoc factors presumably voids
# the split-conformal coverage guarantee. Confirm these factors were not
# tuned on the same test files that are evaluated below.
for _name, _factor in (('c', 1.3), ('d10', 2.5), ('d50', 2.5), ('d90', 2.5)):
    DELTAS[_name] *= _factor


def add_cqr(df, Xs, base_pred, mode: str):
    """Return a copy of *df* with point predictions and CQR bounds attached.

    For each state column ``col`` (i-th column of *base_pred*) this adds
        ``{col}_{mode}``     = base_pred[:, i]
        ``{col}_{mode}_lo``  = base_pred[:, i] - QR_0.1(Xs) - DELTAS[col]
        ``{col}_{mode}_hi``  = base_pred[:, i] + QR_0.9(Xs) + DELTAS[col]

    NOTE(review): subtracting the 0.1-net output and adding the 0.9-net output
    presumes the QR nets predict non-negative lower/upper widths around the
    point forecast, not quantiles of y itself. Confirm against training. For
    mode="closed" the nets see the rollout's starting lag vector, so they are
    presumably calibrated for one-step rather than HORIZON-step errors.
    """
    out = df.copy()
    for i, col in enumerate(STATE_COLS):
        point    = base_pred[:, i]
        width_lo = QR[(col, 0.1)].predict(Xs, verbose=0).flatten()
        width_hi = QR[(col, 0.9)].predict(Xs, verbose=0).flatten()
        margin   = DELTAS[col]
        out[f"{col}_{mode}"]    = point
        out[f"{col}_{mode}_lo"] = point - width_lo - margin
        out[f"{col}_{mode}_hi"] = point + width_hi + margin
    return out

# ──  Metrics helper  ───────────────────────────────────────────────────────
def metric_table(df: pd.DataFrame, mode: str):
    """Compute MAE, MSE and interval coverage (%) per state column.

    Rows where the truth or the point prediction is non-finite (e.g. NaN
    padding) are excluded from all three metrics. A row whose bounds are NaN
    counts as not covered. Keys are ``{col}_MAE``, ``{col}_MSE``, ``{col}_COV``.
    """
    res = {}
    for col in STATE_COLS:
        truth = df[col].to_numpy()
        pred  = df[f"{col}_{mode}"].to_numpy()
        lower = df[f"{col}_{mode}_lo"].to_numpy()
        upper = df[f"{col}_{mode}_hi"].to_numpy()

        valid   = np.isfinite(truth) & np.isfinite(pred)
        covered = (truth >= lower) & (truth <= upper)
        truth_v, pred_v = truth[valid], pred[valid]

        res[f"{col}_MAE"] = mean_absolute_error(truth_v, pred_v)
        res[f"{col}_MSE"] = mean_squared_error(truth_v, pred_v)
        res[f"{col}_COV"] = 100. * covered[valid].mean()
    return res

# ----------  Plot helpers  -------------------------------------------------
def plot_ts(df, out, mode):
    """Save one truth / prediction / band time-series figure per state column.

    The x axis is the row position in *df*. Files are written to
    ``out/{col}_{mode}.png``.
    NOTE(review): the band is labelled "90 % PI". Its nominal level actually
    depends on the alpha used to calibrate DELTAS (and the ad-hoc inflation),
    so confirm the label.
    """
    steps = np.arange(len(df))
    for col in STATE_COLS:
        pred = f'{col}_{mode}'
        fig, ax = plt.subplots(figsize=(7, 3))
        ax.plot(steps, df[col], label='truth', lw=1)
        ax.plot(steps, df[pred], label='pred', lw=1)
        ax.fill_between(steps, df[f'{pred}_lo'], df[f'{pred}_hi'],
                        alpha=.25, label='90 % PI')
        ax.set_title(col)
        fig.tight_layout()
        ax.legend()
        fig.savefig(out / f"{pred}.png", dpi=150)
        plt.close(fig)

def plot_scatter(df, out, mode):
    """Save a truth-vs-prediction scatter plot with a y = x line per state.

    Files are written to ``out/{col}_{mode}_scatter.png``.
    """
    for col in STATE_COLS:
        pred = f'{col}_{mode}'
        pair = df[[col, pred]].to_numpy(dtype=float)
        # Prediction columns can contain NaN rows (the merged closed-loop
        # columns always do). Plain .min()/.max() would return NaN and the
        # reference line would silently vanish, so ignore NaNs here.
        mn, mx = np.nanmin(pair), np.nanmax(pair)

        plt.figure(figsize=(3.5, 3.5))
        plt.scatter(df[col], df[pred], s=8, alpha=.6)
        plt.plot([mn, mx], [mn, mx], 'r--')
        plt.title(col)
        plt.tight_layout()
        plt.savefig(out/f"{col}_{mode}_scatter.png", dpi=150)
        plt.close()

# ──────────────────────────────────────────────────────────────────────────
#  Main loop over test files
# --------------------------------------------------------------------------
summary = []
for p in sorted(TEST_DIR.glob("*.txt")):
    stem  = p.stem
    out_f = OUT_DIR / stem
    out_f.mkdir(exist_ok=True)
    print(f"\n⚙  Processing {stem} …")

    try:
        # 1. preprocess & cluster
        df    = preprocess(p)
        cid   = detect_cluster(df)
        scX, scY, narx = load_cluster(cid)

        # 2. open-loop: row j of df_o is df row LAG + 1 + j
        df_o, Xo_s, y_open = predict_open(df, scX, scY, narx)
        df_o = add_cqr(df_o, Xo_s, y_open, mode="open")

        # 3. closed-loop: row j of df_c is df row LAG + HORIZON + j
        df_c, Xc_s, y_closed = predict_closed(df, scX, scY, narx, HORIZON)
        df_c = add_cqr(df_c, Xc_s, y_closed, mode="closed")

        # 4. merge on the TIME axis. Both frames carry a fresh 0..n-1 index but
        #    start at different df rows, so shift the closed-loop rows by
        #    HORIZON - 1. Otherwise every closed-loop prediction would sit next
        #    to (and be scored against) the truth HORIZON - 1 steps too early.
        #    The first HORIZON - 1 rows get NaN closed-loop columns.
        closed_cols = [f"{c}_{m}" for c in STATE_COLS
                       for m in ("closed", "closed_lo", "closed_hi")]
        closed = df_c[closed_cols].copy()
        closed.index = closed.index + (HORIZON - 1)
        df_pred = df_o.join(closed)          # left join keeps every df_o row

        # 5. save & plots
        df_pred.to_csv(out_f/"predictions.csv", index=False)
        plot_ts(df_pred, out_f, mode="open")
        plot_ts(df_pred, out_f, mode="closed")
        plot_scatter(df_pred, out_f, mode="open")
        plot_scatter(df_pred, out_f, mode="closed")

        # 6. metrics
        m_open   = metric_table(df_pred, mode="open")
        m_closed = metric_table(df_pred, mode="closed")
        summary.append(
            {"file": stem, **m_open,
                        **{f"{k}_closed": v for k, v in m_closed.items()}}
        )

    except Exception as e:
        # Top-level boundary: report the failure and continue with the next file.
        print(f"⨯  {stem} skipped  →  {e}")

# ──  Aggregate summary  ────────────────────────────────────────────────────
df_sum = pd.DataFrame(summary)
df_sum.to_csv(OUT_DIR/"metrics_summary.csv", index=False)

if df_sum.empty:
    # Every file was skipped: df_sum has no metric columns, so indexing
    # them below would raise KeyError.
    print("\n No file was processed successfully – nothing to aggregate.")
else:
    # One row per state: mean over files of open/closed MAE, MSE, coverage.
    rows = [
        [col,
         df_sum[f"{col}_MAE"].mean(), df_sum[f"{col}_MAE_closed"].mean(),
         df_sum[f"{col}_MSE"].mean(), df_sum[f"{col}_MSE_closed"].mean(),
         df_sum[f"{col}_COV"].mean(), df_sum[f"{col}_COV_closed"].mean()]
        for col in STATE_COLS
    ]

    print("\n Average error & coverage (open vs closed)\n")
    print(pd.DataFrame(
            rows,
            columns=["Var",
                     "Open MAE", "Closed MAE",
                     "Open MSE", "Closed MSE",
                     "Open COV %", "Closed COV %"]
    ).to_string(index=False, float_format="%0.8f"))

[INFO]  LAG = 10,  horizon = 5

⚙  Processing file_12738 …


                                                         


 Average error & coverage (open vs closed)

 Var   Open MAE  Closed MAE   Open MSE  Closed MSE   Open COV %  Closed COV %
T_PM 0.11271435  0.22819815 0.03088757  0.22546750 100.00000000   98.88324873
   c 0.00019928  0.00045378 0.00000012  0.00000056  81.90091001   81.42131980
 d10 0.00010081  0.00010187 0.00000078  0.00000078  19.31243680   19.39086294
 d50 0.00004028  0.00004010 0.00000001  0.00000001  14.15571284   14.21319797
 d90 0.00011414  0.00011434 0.00000154  0.00000155  10.81900910   10.86294416
T_TM 0.10957127  0.22187008 0.02948032  0.24022710 100.00000000   98.78172589
