In [2]:
import os
import warnings
import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Silence every warning to keep notebook output clean.
# NOTE(review): this also hides pandas/sklearn deprecation and dtype warnings.
warnings.filterwarnings("ignore")
# Default figure resolution; each savefig call below passes its own dpi.
plt.rcParams["figure.dpi"] = 120

# ==============================
#            PATHS
# ==============================
# Hard-coded, machine-specific Windows paths; change BASE_DIR on another machine.
BASE_DIR = r"C:\Users\sagni\Downloads\Med Assist"
# The two Fitabase export folders that are loaded and stacked by the scratch-curve,
# Pred-vs-Actual and correlation plots.
FOLDER1  = os.path.join(BASE_DIR, r"archive\mturkfitbit_export_4.12.16-5.12.16\Fitabase Data 4.12.16-5.12.16")
FOLDER2  = os.path.join(BASE_DIR, r"archive\mturkfitbit_export_3.12.16-4.11.16\Fitabase Data 3.12.16-4.11.16")

# Artifacts read by plot_pred_vs_actual_heatmap: a joblib bundle holding the fitted
# "preprocess" transformer (and optionally "target_col"), plus the saved model.
PKL_PATH = os.path.join(BASE_DIR, "medassist_preprocess.pkl")
H5_PATH  = os.path.join(BASE_DIR, "medassist_model.h5")
HIST_CSV = os.path.join(BASE_DIR, "medassist_history.csv")  # optional: if missing, scratch-model curves are plotted instead

# Output figure paths.
OUT_LOSS_MAE = os.path.join(BASE_DIR, "medassist_loss_mae.png")
OUT_PVA      = os.path.join(BASE_DIR, "medassist_pred_vs_actual.png")
OUT_CORR     = os.path.join(BASE_DIR, "medassist_corr_heatmap.png")

# ==============================
#     MEMORY & LOADING OPTS
# ==============================
# When False, CSVs whose names contain an EXCLUDE_PATTERNS entry are skipped.
INCLUDE_MINUTE_FILES = False
ROW_LIMIT_PER_CSV    = None           # e.g. 150_000 to hard-cap rows per CSV (a cap disables fractional sampling)
# Fraction of rows kept from CSVs having at least BIG_FILE_MIN_ROWS rows;
# applied only when ROW_LIMIT_PER_CSV is None.
SAMPLE_FRAC_PER_CSV  = 0.15
BIG_FILE_MIN_ROWS    = 200_000
RANDOM_STATE         = 42             # seed for row sampling and the scratch train/val split

# Small scratch training (only if no history CSV is found);
# used by plot_loss_mae_fallback_sample.
SCRATCH_EPOCHS   = 6
SCRATCH_VAL_SIZE = 0.2
SCRATCH_BATCH    = 64
SCRATCH_MAX_ROWS = 10_000  # numeric-only, tiny to avoid OOM

# Rows densified per model.predict call in sparse_batch_predict.
PRED_BATCH = 2048          # small enough to avoid RAM spikes

# File-name filters used by _want_file. Matching is a case-insensitive substring
# test, so the mixed-case duplicates below are redundant (and "day" matches any
# name containing "day").
INCLUDE_PATTERNS = [
    "daily", "day", "Sleep", "sleep", "weight", "Weight",
    "summary", "calories", "Calories", "activities",
    "heart", "resting", "steps", "Steps", "distance", "Distance"
]
# Checked before INCLUDE_PATTERNS, and only while INCLUDE_MINUTE_FILES is False.
EXCLUDE_PATTERNS = ["minute", "Minute", "seconds", "second", "intraday", "Intraday"]

# Target column names tried in order by pick_target (exact, case-sensitive match).
TARGET_CANDIDATES = ["Calories", "calories", "TotalCalories", "Calories Burned", "Calories_Burned"]

# ==============================
#      BASIC HELPERS
# ==============================
def ensure_exists(p: str, name: str):
    """Raise ``FileNotFoundError`` unless the path ``p`` exists.

    ``name`` is a human-readable label for the artifact, used in the message.
    """
    if os.path.exists(p):
        return
    raise FileNotFoundError(f"{name} not found at: {p}")

def _downcast_numeric(df: pd.DataFrame) -> pd.DataFrame:
    for c in df.select_dtypes(include=["float64"]).columns:
        df[c] = pd.to_numeric(df[c], downcast="float")
    for c in df.select_dtypes(include=["int64"]).columns:
        df[c] = pd.to_numeric(df[c], downcast="integer")
    return df

def _safe_union_concat(dfs):
    if not dfs:
        return pd.DataFrame()
    cols = set()
    for d in dfs:
        cols.update(d.columns.tolist())
    cols = list(cols)
    aligned = [d.reindex(columns=cols) for d in dfs]
    return pd.concat(aligned, ignore_index=True)

def _want_file(fname: str) -> bool:
    """Decide whether a CSV with this file name should be loaded.

    All matching is a case-insensitive substring test. While
    INCLUDE_MINUTE_FILES is False, a name containing any EXCLUDE_PATTERNS entry
    is rejected outright. Otherwise a name containing any INCLUDE_PATTERNS entry
    is accepted. Names that match neither list are accepted only when
    INCLUDE_MINUTE_FILES is True.
    """
    lowered = fname.lower()
    if not INCLUDE_MINUTE_FILES and any(pat.lower() in lowered for pat in EXCLUDE_PATTERNS):
        return False
    if any(pat.lower() in lowered for pat in INCLUDE_PATTERNS):
        return True
    return INCLUDE_MINUTE_FILES

def _read_csv_safely(path: str, row_limit):
    try:
        if row_limit:
            df = pd.read_csv(path, nrows=row_limit)
        else:
            df = pd.read_csv(path)
        return _downcast_numeric(df)
    except Exception as e:
        print(f"[WARN] Failed to read {path}: {e}")
        return pd.DataFrame()

def _maybe_sample(df: pd.DataFrame, frac: float) -> pd.DataFrame:
    if frac is None or frac >= 1.0 or len(df) == 0:
        return df
    if len(df) >= BIG_FILE_MIN_ROWS:
        df = df.sample(frac=frac, random_state=RANDOM_STATE)
    return df

def load_fitbit_folder(folder: str) -> pd.DataFrame:
    """Load the selected Fitbit CSVs in ``folder`` and stack them vertically.

    The folder's ``*.csv`` files are filtered by ``_want_file`` and read in
    sorted name order with ``_read_csv_safely``. Large files are down-sampled
    by ``_maybe_sample``, but only when ROW_LIMIT_PER_CSV is None. Each
    file's rows are tagged with a ``__source_file`` column, and all frames are
    union-concatenated.

    Files are read in a deterministic order because ``os.listdir`` returns
    entries in an arbitrary, filesystem-dependent order. Without sorting, the
    stacked row order, and therefore every seeded ``df.sample`` downstream,
    would not be reproducible across machines.

    Returns
    -------
    pandas.DataFrame
        The stacked data, or an empty DataFrame if the folder is missing or no
        file is selected/readable (a warning is printed in those cases).
    """
    if not os.path.isdir(folder):
        print(f"[WARN] Folder not found: {folder}")
        return pd.DataFrame()
    csvs = sorted(f for f in os.listdir(folder) if f.lower().endswith(".csv"))
    kept = [f for f in csvs if _want_file(f)]
    if not kept:
        print(f"[WARN] No CSVs selected to load after filtering in: {folder}")
        return pd.DataFrame()

    dfs, total = [], 0
    for i, fname in enumerate(kept, 1):
        df = _read_csv_safely(os.path.join(folder, fname), ROW_LIMIT_PER_CSV)
        if df.empty:
            continue
        # An explicit per-file row cap replaces fractional sampling.
        if ROW_LIMIT_PER_CSV is None:
            df = _maybe_sample(df, SAMPLE_FRAC_PER_CSV)
        df["__source_file"] = fname
        dfs.append(df)
        total += len(df)
        # Merge the pending frames once 8 accumulate, so only a few partial
        # frames are held before concatenation.
        if len(dfs) >= 8:
            dfs = [_safe_union_concat(dfs)]
        if i % 10 == 0:
            print(f"[INFO] Loaded ~{i} files; rows so far: {total:,}")
    out = _safe_union_concat(dfs)
    print(f"[INFO] Loaded {folder} -> shape={out.shape}")
    return out

def drop_ids_dates(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` without identifier- and timestamp-like columns.

    A column is dropped when its lower-cased name contains "id", "date",
    "time" or "datetime" anywhere.

    NOTE(review): this is a plain substring test. It also removes genuine
    features such as "TotalTimeInBed" (contains "time") and any name containing
    "id". It is left unchanged because, presumably, the preprocessor in the
    PKL bundle was fitted on columns filtered the same way. Confirm against the
    training code before tightening it.
    """
    markers = ("id", "date", "time", "datetime")
    doomed = [col for col in df.columns if any(m in col.lower() for m in markers)]
    return df.drop(columns=doomed, errors="ignore")

def pick_target(df: pd.DataFrame) -> str:
    """Choose the regression target column of ``df``.

    Returns the first name in TARGET_CANDIDATES that is a column of ``df``
    (exact, case-sensitive match). Failing that, returns the right-most numeric
    column.

    Raises
    ------
    ValueError
        If no candidate matches and ``df`` has no numeric columns.
    """
    preferred = next((name for name in TARGET_CANDIDATES if name in df.columns), None)
    if preferred is not None:
        return preferred
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    if len(numeric_cols) == 0:
        raise ValueError("No numeric target column detected.")
    return numeric_cols[-1]

def sparse_batch_predict(model, X, batch_size=2048, dtype=np.float32):
    """Run ``model.predict`` over the rows of ``X`` in fixed-size chunks.

    ``X`` may be a sparse matrix (anything exposing ``toarray``) or any
    array-like. Each slice of at most ``batch_size`` rows is densified to
    ``dtype`` right before prediction, so only one dense chunk exists at a
    time. The flattened predictions are written into a float32 vector of
    length ``X.shape[0]``, which is returned.
    """
    n_rows = X.shape[0]
    preds = np.empty(n_rows, dtype=np.float32)
    for lo in range(0, n_rows, batch_size):
        hi = min(lo + batch_size, n_rows)
        chunk = X[lo:hi]
        dense = (
            chunk.toarray().astype(dtype, copy=False)
            if hasattr(chunk, "toarray")
            else np.asarray(chunk, dtype=dtype)
        )
        batch_pred = model.predict(dense, verbose=0)
        preds[lo:hi] = batch_pred.ravel().astype(np.float32)
    return preds

# ==============================
#   1) TRAINING CURVES PLOT
# ==============================
def plot_loss_mae(history_csv: str, out_path: str):
    """Plot training curves recorded in a history CSV and save them as an image.

    Each row of ``history_csv`` is treated as one epoch (numbered from 1).
    Whichever of the ``loss``, ``val_loss``, ``mae`` and ``val_mae`` columns
    exist are drawn, in that order, and the figure is written to ``out_path``.
    If the CSV does not exist, a warning is printed and nothing is written.
    """
    if not os.path.exists(history_csv):
        print(f"[WARN] No history CSV at {history_csv}; skipping curve plot.")
        return

    history = pd.read_csv(history_csv)
    epoch_axis = np.arange(1, len(history) + 1)
    curves = (
        ("loss", "Train Loss"),
        ("val_loss", "Val Loss"),
        ("mae", "Train MAE"),
        ("val_mae", "Val MAE"),
    )

    plt.figure(figsize=(8,5))
    for column, label in curves:
        if column in history.columns:
            plt.plot(epoch_axis, history[column], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Value")
    plt.title("Training Curves (Loss & MAE)")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()
    print(f"[OK] Saved curves -> {out_path}")

def plot_loss_mae_fallback_sample(out_path: str):
    """Plot training curves of a tiny scratch model when no history CSV exists.

    Steps:
    - Load and stack both Fitbit folders, drop id/date/time-named columns and
      keep numeric columns only.
    - Choose a target with ``pick_target``: the first TARGET_CANDIDATES match,
      else the right-most numeric column.
    - Sample at most SCRATCH_MAX_ROWS labelled rows and standardize the
      features.
    - Fit a small Dense(64)-Dense(32)-Dense(1) network for SCRATCH_EPOCHS
      epochs.

    The recorded loss/MAE metrics are drawn against 1-based epoch numbers,
    matching ``plot_loss_mae``, and saved to ``out_path``.

    NOTE: these curves describe this throwaway model, not the model at H5_PATH.

    Best-effort: any ``Exception`` is printed as a ``[WARN]`` and swallowed.
    """
    try:
        d1 = load_fitbit_folder(FOLDER1)
        d2 = load_fitbit_folder(FOLDER2)
        df = _safe_union_concat([d1, d2])
        if df.empty:
            print("[WARN] No data for scratch history.")
            return

        df = drop_ids_dates(df)
        # Numeric-only so no one-hot encoding (and its memory blow-up) is needed.
        num_df = df.select_dtypes(include=[np.number])
        if num_df.shape[1] < 2:
            print("[WARN] Not enough numeric columns for scratch history.")
            return

        # Prefer a known calorie column; pick_target otherwise falls back to the
        # right-most numeric column.
        target_col = pick_target(num_df)
        num_df = num_df.dropna(subset=[target_col])

        # train_test_split needs at least one row on each side of the split.
        if len(num_df) < 2:
            print("[WARN] Not enough labelled rows for scratch history.")
            return

        if len(num_df) > SCRATCH_MAX_ROWS:
            num_df = num_df.sample(n=SCRATCH_MAX_ROWS, random_state=RANDOM_STATE)

        y = num_df[target_col].astype("float32").values
        X = num_df.drop(columns=[target_col]).fillna(0.0).astype("float32")

        from sklearn.model_selection import train_test_split
        from sklearn.preprocessing import StandardScaler
        Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=SCRATCH_VAL_SIZE, random_state=RANDOM_STATE)

        scaler = StandardScaler()
        Xtr = scaler.fit_transform(Xtr)
        Xte = scaler.transform(Xte)

        from tensorflow import keras
        from tensorflow.keras import layers

        model = keras.Sequential([
            layers.Input(shape=(Xtr.shape[1],)),
            layers.Dense(64, activation="relu"),
            layers.Dense(32, activation="relu"),
            layers.Dense(1)
        ])
        model.compile(optimizer="adam", loss="mse", metrics=["mae"])

        fit_result = model.fit(
            Xtr, ytr,
            validation_data=(Xte, yte),
            epochs=SCRATCH_EPOCHS,
            batch_size=SCRATCH_BATCH,
            verbose=0
        )
        history = fit_result.history

        plt.figure(figsize=(8,5))
        try:
            # Same layout as plot_loss_mae: 1-based epochs, only metrics that exist.
            for key, label in (
                ("loss", "Train Loss"),
                ("val_loss", "Val Loss"),
                ("mae", "Train MAE"),
                ("val_mae", "Val MAE"),
            ):
                values = history.get(key)
                if values:
                    plt.plot(np.arange(1, len(values) + 1), values, label=label)
            plt.xlabel("Epoch"); plt.ylabel("Value")
            plt.title("Training Curves (Scratch Sample, Numeric Only)")
            plt.grid(alpha=0.3); plt.legend()
            plt.tight_layout()
            plt.savefig(out_path, dpi=160)
        finally:
            # Release the figure even if drawing/saving fails.
            plt.close()
        print(f"[OK] Saved scratch curves -> {out_path}")
    except Exception as e:
        print(f"[WARN] Scratch history failed: {e}")

# ==============================
#  2) PRED vs ACTUAL HEATMAP
# ==============================
def plot_pred_vs_actual_heatmap(pkl_path: str, h5_path: str, out_path: str,
                                max_rows: int = 20_000):
    """Save a 2-D histogram of model predictions against actual target values.

    Loads the dict-like preprocessing bundle from ``pkl_path``, which must hold
    ``"preprocess"`` and may hold ``"target_col"``. Then:
    - Loads the model from ``h5_path`` without compiling it.
    - Re-reads and stacks both Fitbit folders, and samples up to ``max_rows``
      labelled rows.
    - Predicts in RAM-safe batches.
    - Bins the finite (actual, predicted) pairs into a 60x60 heatmap. Both axes
      span the 1st-99th percentile of the pooled values.

    Parameters
    ----------
    pkl_path, h5_path : str
        Preprocessing bundle and saved model.
    out_path : str
        Destination image path.
    max_rows : int, default 20_000
        Upper bound on rows predicted (seeded random sample) to bound time/RAM.

    Raises
    ------
    FileNotFoundError
        If either artifact path does not exist.
    """
    ensure_exists(pkl_path, "preprocess PKL")
    ensure_exists(h5_path, "model H5")

    # NOTE: joblib.load unpickles arbitrary objects; only load trusted files.
    bundle = joblib.load(pkl_path)
    pre = bundle["preprocess"]
    target_col = bundle.get("target_col", None)

    d1 = load_fitbit_folder(FOLDER1)
    d2 = load_fitbit_folder(FOLDER2)
    df = _safe_union_concat([d1, d2])
    if df.empty:
        print("[WARN] No data for Pred vs Actual plot.")
        return
    df = drop_ids_dates(df)

    if target_col is None:
        target_col = pick_target(df)
    # The bundle's target may be absent from (or filtered out of) this data.
    if target_col not in df.columns:
        print(f"[WARN] Target column {target_col!r} not found in data; skipping Pred vs Actual plot.")
        return
    df = df.dropna(subset=[target_col])
    if df.empty:
        print("[WARN] No labelled rows for Pred vs Actual plot.")
        return

    # Keep a seeded sample so prediction stays fast and memory-safe.
    if len(df) > max_rows:
        df = df.sample(n=max_rows, random_state=RANDOM_STATE)

    y = df[target_col].astype("float32").values
    X = df.drop(columns=[target_col])

    # Transform (likely sparse)
    Xp = pre.transform(X)

    # compile=False: the plot needs only inference, so skip restoring the
    # loss/metrics (avoids 'mse' deserialization mismatches across versions).
    from tensorflow import keras
    model = keras.models.load_model(h5_path, compile=False)

    yhat = sparse_batch_predict(model, Xp, batch_size=PRED_BATCH, dtype=np.float32)

    # Drop pairs with NaN/inf on either side so they cannot distort range or bins.
    finite = np.isfinite(y) & np.isfinite(yhat)
    y, yhat = y[finite], yhat[finite]
    if y.size == 0:
        print("[WARN] No finite predictions for Pred vs Actual plot.")
        return

    vmin, vmax = np.percentile(np.concatenate([y, yhat]), [1, 99])
    if vmin == vmax:
        # Constant values would give a zero-width histogram range; pad it.
        pad = max(abs(float(vmin)) * 0.01, 1.0)
        vmin, vmax = vmin - pad, vmax + pad

    plt.figure(figsize=(6.5, 6))
    try:
        plt.hist2d(y, yhat, bins=60, range=[[vmin, vmax], [vmin, vmax]], cmap="viridis")
        plt.xlabel("Actual"); plt.ylabel("Predicted")
        plt.title("Predicted vs Actual (Heatmap)")
        cbar = plt.colorbar(); cbar.set_label("Count")
        # Identity line: points on it are perfect predictions.
        plt.plot([vmin, vmax], [vmin, vmax], ls="--", lw=1, color="white")
        plt.tight_layout()
        plt.savefig(out_path, dpi=160)
    finally:
        plt.close()
    print(f"[OK] Saved Pred vs Actual heatmap -> {out_path}")

# ==============================
#   3) CORRELATION HEATMAP
# ==============================
def plot_corr_heatmap(out_path: str, max_cols: int = 30):
    """Save a correlation heatmap of the numeric Fitbit columns.

    Re-loads and stacks both Fitbit folders (same sampling rules as the other
    plots) and drops id/date/time-named columns. It keeps numeric columns with
    more than one distinct non-null value and, if more than ``max_cols``
    remain, keeps the ``max_cols`` highest-variance ones. Pairwise correlations
    come from ``DataFrame.corr``. The colour scale is pinned to [-1, 1] so that
    colours mean the same thing across runs.

    NOTE(review): rows from different source CSVs rarely share columns, so some
    pairs may have no overlapping rows and show as NaN (blank) cells.

    Parameters
    ----------
    out_path : str
        Destination image path.
    max_cols : int, default 30
        Maximum number of columns shown, for readability.
    """
    d1 = load_fitbit_folder(FOLDER1)
    d2 = load_fitbit_folder(FOLDER2)
    df = _safe_union_concat([d1, d2])
    if df.empty:
        print("[WARN] No data for correlation heatmap.")
        return
    df = drop_ids_dates(df)

    num_df = df.select_dtypes(include=[np.number])
    # Constant or all-NaN columns have undefined correlation; drop them.
    nunique = num_df.nunique()
    num_df = num_df[nunique[nunique > 1].index.tolist()]

    if num_df.shape[1] < 2:
        print("[WARN] Not enough numeric cols for correlation heatmap.")
        return

    if num_df.shape[1] > max_cols:
        # NOTE: variance is scale-dependent, so this favours large-magnitude columns.
        top_cols = num_df.var().sort_values(ascending=False).head(max_cols).index.tolist()
        num_df = num_df[top_cols]

    # num_df is numeric-only already, so no numeric_only flag is needed
    # (that keyword does not exist before pandas 1.5).
    corr = num_df.corr()

    plt.figure(figsize=(min(12, 0.6*corr.shape[1]+4), min(10, 0.6*corr.shape[0]+4)))
    try:
        try:
            import seaborn as sns
        except ImportError:
            sns = None
        if sns is not None:
            sns.heatmap(corr, cmap="viridis", vmin=-1, vmax=1, square=False, cbar=True)
        else:
            im = plt.imshow(corr.values, cmap="viridis", vmin=-1, vmax=1, aspect="auto")
            plt.colorbar(im)
            plt.xticks(range(corr.shape[1]), corr.columns, rotation=90)
            plt.yticks(range(corr.shape[0]), corr.index)
        plt.title("Correlation Heatmap (Numeric Features)")
        plt.tight_layout()
        plt.savefig(out_path, dpi=180)
    finally:
        plt.close()
    print(f"[OK] Saved correlation heatmap -> {out_path}")

# ==============================
#            RUN
# ==============================
if __name__ == "__main__":
    # 1) Curves: use the recorded training history when available; otherwise
    #    fall back to a small scratch model (its curves do not describe the
    #    saved model).
    if os.path.exists(HIST_CSV):
        plot_loss_mae(HIST_CSV, OUT_LOSS_MAE)
    else:
        print(f"[INFO] No history CSV at {HIST_CSV}. Creating a small scratch curve plot (numeric-only, capped rows)...")
        plot_loss_mae_fallback_sample(OUT_LOSS_MAE)

    # 2) Pred vs Actual heatmap (batch, RAM-safe). A missing PKL/H5 must not
    #    prevent the independent correlation heatmap from being produced.
    try:
        plot_pred_vs_actual_heatmap(PKL_PATH, H5_PATH, OUT_PVA)
    except FileNotFoundError as e:
        print(f"[WARN] Skipping Pred vs Actual heatmap: {e}")

    # 3) Correlation heatmap
    plot_corr_heatmap(OUT_CORR)

    # List only figures present on disk; steps above may skip writing theirs.
    print("\n[DONE] Figures saved to:")
    for fig_path in (OUT_LOSS_MAE, OUT_PVA, OUT_CORR):
        if os.path.exists(fig_path):
            print(" -", fig_path)


[INFO] No history CSV at C:\Users\sagni\Downloads\Med Assist\medassist_history.csv. Creating a small scratch curve plot (numeric-only, capped rows)...
[INFO] Loaded C:\Users\sagni\Downloads\Med Assist\archive\mturkfitbit_export_4.12.16-5.12.16\Fitabase Data 4.12.16-5.12.16 -> shape=(48438, 30)
[INFO] Loaded C:\Users\sagni\Downloads\Med Assist\archive\mturkfitbit_export_3.12.16-4.11.16\Fitabase Data 3.12.16-4.11.16 -> shape=(48658, 25)
[OK] Saved scratch curves -> C:\Users\sagni\Downloads\Med Assist\medassist_loss_mae.png
[INFO] Loaded C:\Users\sagni\Downloads\Med Assist\archive\mturkfitbit_export_4.12.16-5.12.16\Fitabase Data 4.12.16-5.12.16 -> shape=(48438, 30)
[INFO] Loaded C:\Users\sagni\Downloads\Med Assist\archive\mturkfitbit_export_3.12.16-4.11.16\Fitabase Data 3.12.16-4.11.16 -> shape=(48658, 25)
[OK] Saved Pred vs Actual heatmap -> C:\Users\sagni\Downloads\Med Assist\medassist_pred_vs_actual.png
[INFO] Loaded C:\Users\sagni\Downloads\Med Assist\archive\mturkfitbit_export_4.12.1