In [3]:
import os, warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Try seaborn for nicer heatmaps; fall back to matplotlib if unavailable.
try:
    import seaborn as sns
    _HAS_SEABORN = True
except Exception:
    _HAS_SEABORN = False

import joblib
import tensorflow as tf
from tensorflow import keras

warnings.filterwarnings("ignore", category=FutureWarning)

# ----------------------------
# Paths (edit if needed)
# ----------------------------
# Project root; every artifact and every generated plot lives underneath it.
BASE_DIR = r"C:\Users\sagni\Downloads\Agri Mind"
ARCHIVE = os.path.join(BASE_DIR, "archive")

# Training artifacts stored directly in the project root:
# training-history CSV, preprocessing bundle, and the saved Keras model.
HISTCSV, PKL_PATH, H5_PATH = (
    os.path.join(BASE_DIR, artifact_name)
    for artifact_name in ("neuro_history.csv", "neuro_preprocess.pkl", "neuro_model.h5")
)

# Raw dataset used for the prediction, correlation, and error plots.
DATACSV = os.path.join(ARCHIVE, "yield_df.csv")

# Candidate target column names, tried in order by _detect_target().
# If your target column is different, change it here; otherwise the
# auto-detection below falls back to the last numeric column.
POSSIBLE_TARGETS = ["hg/ha_yield", "yield", "Yield", "target", "y"]

# ----------------------------
# Utils
# ----------------------------
def _ensure_dir(p):
    """Create the parent directory of file path *p* if it is missing.

    Args:
        p: Path of a file that is about to be written.

    A path without a directory component (e.g. ``"plot.png"``) has an empty
    ``dirname``; ``os.makedirs("")`` would raise ``FileNotFoundError``, so in
    that case nothing is created and the file is left to land in the current
    working directory.
    """
    parent = os.path.dirname(p)
    if parent:
        os.makedirs(parent, exist_ok=True)

def _detect_target(df: pd.DataFrame) -> str:
    """Pick the target column of *df*.

    The first entry of ``POSSIBLE_TARGETS`` that is a column of *df* wins.
    If none is present, the last numeric column is used as a fallback.

    Raises:
        ValueError: If no candidate matches and *df* has no numeric columns.
    """
    known = next((name for name in POSSIBLE_TARGETS if name in df.columns), None)
    if known is not None:
        return known

    # Fallback: last numeric column in the frame.
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    if numeric_cols:
        return numeric_cols[-1]
    raise ValueError("No numeric columns available to choose a target.")

def _to_numpy(X):
    """Return *X* as a dense array.

    Objects exposing a ``toarray`` method are converted through it;
    anything else goes through ``np.asarray``.
    """
    if not hasattr(X, "toarray"):
        return np.asarray(X)
    return X.toarray()

# ----------------------------
# 1) Accuracy Graph (Loss & MAE)
# ----------------------------
def plot_training_curves(history_csv: str, out_path: str):
    """Plot the training-history curves found in *history_csv* and save them.

    Args:
        history_csv: CSV with per-epoch metrics. The columns ``loss``,
            ``val_loss``, ``mae`` and ``val_mae`` are plotted when present
            (the layout written by Keras ``CSVLogger``); others are ignored.
        out_path: Image file to write; its parent directory is created.

    Prints a warning and returns without plotting if *history_csv* is missing.
    """
    if not os.path.exists(history_csv):
        print(f"[WARN] history CSV not found: {history_csv}")
        return

    history = pd.read_csv(history_csv)

    # (column in the CSV, legend label), drawn in this order.
    curves = (
        ("loss", "Train Loss"),
        ("val_loss", "Val Loss"),
        ("mae", "Train MAE"),
        ("val_mae", "Val MAE"),
    )

    plt.figure(figsize=(9, 5))
    for column, legend_label in curves:
        if column in history.columns:
            plt.plot(history[column], label=legend_label)

    plt.xlabel("Epoch")
    plt.ylabel("Value")
    plt.title("Training Curves (Loss & MAE)")
    plt.legend()
    plt.grid(True, alpha=0.3)

    _ensure_dir(out_path)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()
    print(f"[OK] Saved accuracy graph -> {out_path}")

# ----------------------------
# 2) Predictions vs Actual (Hexbin heatmap)
# ----------------------------
def plot_pred_vs_actual_hex(preprocess_pkl: str, model_h5: str, data_csv: str, out_path: str):
    """Save a hexbin density plot of model predictions against actual targets.

    Args:
        preprocess_pkl: joblib file holding a mapping with a ``"preprocess"``
            transformer and, optionally, a ``"target_col"`` entry.
        model_h5: Saved Keras model file.
        data_csv: Dataset CSV; rows with a missing target are dropped.
        out_path: Image file to write; its parent directory is created.

    Prints a warning and returns without plotting if any input file is missing.
    """
    required_files = (preprocess_pkl, model_h5, data_csv)
    if not all(os.path.exists(path) for path in required_files):
        print("[WARN] Missing one of preprocess/model/data files; skipping pred-vs-actual plot.")
        return

    # NOTE: joblib.load unpickles the file; only load bundles you trust.
    bundle = joblib.load(preprocess_pkl)
    preprocessor = bundle["preprocess"]
    frame = pd.read_csv(data_csv)

    # Prefer the target recorded in the bundle; otherwise auto-detect it.
    target = bundle.get("target_col", None) or _detect_target(frame)
    frame = frame.dropna(subset=[target]).copy()

    raw_features = frame.drop(columns=[target])
    actual = frame[target].astype(float).values
    features = preprocessor.transform(raw_features)

    # IMPORTANT: compile=False avoids deserializing the legacy 'mse' entry.
    model = keras.models.load_model(model_h5, compile=False)
    predicted = model.predict(_to_numpy(features), verbose=0).ravel()

    # Hexbin 2D histogram of (actual, predicted) pairs.
    plt.figure(figsize=(6.8, 6))
    hexes = plt.hexbin(actual, predicted, gridsize=50, mincnt=1)
    plt.colorbar(hexes, label="Count")
    plt.xlabel("Actual")
    plt.ylabel("Predicted")

    # Diagonal y = x spanning the combined range of both series.
    low = np.nanmin([actual.min(), predicted.min()])
    high = np.nanmax([actual.max(), predicted.max()])
    diagonal = [low, high]
    plt.plot(diagonal, diagonal, ls="--", lw=1, color="black", label="Ideal")
    plt.legend()
    plt.title("Predicted vs Actual (Hexbin Heatmap)")
    plt.tight_layout()

    _ensure_dir(out_path)
    plt.savefig(out_path, dpi=160)
    plt.close()
    print(f"[OK] Saved prediction heatmap -> {out_path}")

# ----------------------------
# 3) Correlation Heatmap (numeric features)
# ----------------------------
def plot_corr_heatmap(data_csv: str, out_path: str, top_n: int = 30):
    """Save a correlation heatmap of the numeric columns in *data_csv*.

    Args:
        data_csv: Dataset CSV.
        out_path: Image file to write; its parent directory is created.
        top_n: When there are more numeric columns than this, only the
            *top_n* columns with the highest variance are kept.

    Prints a warning and returns without plotting if the dataset is missing
    or has fewer than two numeric columns.
    """
    if not os.path.exists(data_csv):
        print(f"[WARN] Dataset not found: {data_csv}")
        return

    # Restrict to numeric columns; correlation is undefined for the rest.
    numeric = pd.read_csv(data_csv).select_dtypes(include=[np.number]).copy()
    if numeric.shape[1] < 2:
        print("[WARN] Not enough numeric columns for correlation heatmap.")
        return

    # Too many columns to read comfortably: keep the highest-variance ones.
    if numeric.shape[1] > top_n:
        by_variance = numeric.var().sort_values(ascending=False)
        keep = by_variance.head(top_n).index.tolist()
        numeric = numeric[keep]

    corr = numeric.corr(numeric_only=True)

    plt.figure(figsize=(10, 8))
    if not _HAS_SEABORN:
        # Plain-matplotlib fallback: draw the matrix and label the ticks by hand.
        image = plt.imshow(corr.values, cmap="viridis", aspect="equal")
        plt.colorbar(image)
        plt.xticks(range(corr.shape[1]), corr.columns, rotation=90)
        plt.yticks(range(corr.shape[0]), corr.index)
    else:
        sns.heatmap(corr, cmap="viridis", annot=False, square=True, cbar=True)

    plt.title("Correlation Heatmap (Numeric Features)")
    plt.tight_layout()
    _ensure_dir(out_path)
    plt.savefig(out_path, dpi=180)
    plt.close()
    print(f"[OK] Saved correlation heatmap -> {out_path}")

# ----------------------------
# 4) Error Heatmap by Area × Year (if available)
# ----------------------------
def plot_error_pivot_heatmap(preprocess_pkl: str, model_h5: str, data_csv: str, out_path: str,
                             area_cols=("Area","area","Country","country"),
                             year_cols=("Year","year")):
    """Save a heatmap of mean absolute prediction error per area and year.

    Args:
        preprocess_pkl: joblib file holding a mapping with a ``"preprocess"``
            transformer and, optionally, a ``"target_col"`` entry.
        model_h5: Saved Keras model file.
        data_csv: Dataset CSV; rows with a missing target are dropped.
        out_path: Image file to write; its parent directory is created.
        area_cols: Candidate names for the area column; the first one
            present in the dataset is used.
        year_cols: Candidate names for the year column; the first one
            present in the dataset is used.

    Prints a warning and returns without plotting if an input file is
    missing or if no area/year column can be found.
    """
    required_files = (preprocess_pkl, model_h5, data_csv)
    if not all(os.path.exists(path) for path in required_files):
        print("[WARN] Missing one of preprocess/model/data files; skipping error pivot heatmap.")
        return

    # NOTE: joblib.load unpickles the file; only load bundles you trust.
    bundle = joblib.load(preprocess_pkl)
    preprocessor = bundle["preprocess"]
    frame = pd.read_csv(data_csv)

    # Prefer the target recorded in the bundle; otherwise auto-detect it.
    target = bundle.get("target_col", None) or _detect_target(frame)
    frame = frame.dropna(subset=[target]).copy()

    # Resolve the grouping columns from the candidate names.
    area_col = next((name for name in area_cols if name in frame.columns), None)
    year_col = next((name for name in year_cols if name in frame.columns), None)
    if None in (area_col, year_col):
        print("[WARN] Area/Year columns not found; skipping Area×Year error heatmap.")
        return

    raw_features = frame.drop(columns=[target])
    actual = frame[target].astype(float).values
    features = preprocessor.transform(raw_features)

    # IMPORTANT: compile=False avoids the 'mse' lookup error on load.
    model = keras.models.load_model(model_h5, compile=False)
    predicted = model.predict(_to_numpy(features), verbose=0).ravel()
    abs_err = np.abs(predicted - actual)

    errors = pd.DataFrame({
        "area": frame[area_col].astype(str).values,
        "year": frame[year_col].values,
        "mae": abs_err,
    })

    # Average the absolute error per (area, year), then spread years across
    # the columns. Combinations with no rows stay NaN (the fillna is a no-op).
    pivot_tbl = (
        errors.groupby(["area", "year"])["mae"]
        .mean()
        .reset_index()
        .pivot(index="area", columns="year", values="mae")
        .fillna(np.nan)
    )

    # Grow the figure with the number of areas so row labels stay legible.
    fig_height = max(6, 0.25 * len(pivot_tbl))
    plt.figure(figsize=(12, fig_height))
    if not _HAS_SEABORN:
        # Plain-matplotlib fallback: draw the matrix and label the ticks by hand.
        image = plt.imshow(pivot_tbl.values, cmap="magma", aspect="auto")
        plt.colorbar(image)
        plt.yticks(range(pivot_tbl.shape[0]), pivot_tbl.index)
        plt.xticks(range(pivot_tbl.shape[1]), pivot_tbl.columns, rotation=90)
    else:
        sns.heatmap(pivot_tbl, cmap="magma", cbar=True)

    plt.title("Mean Absolute Error by Area × Year")
    plt.xlabel("Year")
    plt.ylabel("Area")
    plt.tight_layout()
    _ensure_dir(out_path)
    plt.savefig(out_path, dpi=180)
    plt.close()
    print(f"[OK] Saved error pivot heatmap -> {out_path}")

# ----------------------------
# Run all plots
# ----------------------------
if __name__ == "__main__":
    # Inputs shared by the two plots that need the trained model.
    model_inputs = dict(preprocess_pkl=PKL_PATH, model_h5=H5_PATH, data_csv=DATACSV)

    # Every image is written next to the other artifacts in BASE_DIR.
    curves_png = os.path.join(BASE_DIR, "loss_mae.png")
    hexbin_png = os.path.join(BASE_DIR, "pred_vs_actual_hex.png")
    corr_png = os.path.join(BASE_DIR, "corr_heatmap.png")
    error_png = os.path.join(BASE_DIR, "error_heatmap.png")

    # 1) training curves
    plot_training_curves(HISTCSV, curves_png)

    # 2) pred vs actual hexbin
    plot_pred_vs_actual_hex(out_path=hexbin_png, **model_inputs)

    # 3) correlation heatmap
    plot_corr_heatmap(data_csv=DATACSV, out_path=corr_png, top_n=30)

    # 4) error pivot heatmap (Area × Year)
    plot_error_pivot_heatmap(out_path=error_png, **model_inputs)

    print("[DONE] Plots saved in:", BASE_DIR)


[OK] Saved accuracy graph -> C:\Users\sagni\Downloads\Agri Mind\loss_mae.png
[OK] Saved prediction heatmap -> C:\Users\sagni\Downloads\Agri Mind\pred_vs_actual_hex.png
[OK] Saved correlation heatmap -> C:\Users\sagni\Downloads\Agri Mind\corr_heatmap.png
[OK] Saved error pivot heatmap -> C:\Users\sagni\Downloads\Agri Mind\error_heatmap.png
[DONE] Plots saved in: C:\Users\sagni\Downloads\Agri Mind
