# In [1]:
# eco_forest_predict_nb.py (or paste this whole cell into Jupyter)

import os
import json
import logging
from typing import Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

# ---------------------------
# Filesystem layout (keep identical to the training script)
# ---------------------------
# Folder that holds the trained model and receives every prediction artefact.
OUT_DIR = r"C:\Users\sagni\Downloads\Eco Detect"
# CSV that is scored when the caller does not pass an explicit input file.
DATA_CSV_DEFAULT = r"C:\Users\sagni\Downloads\Eco Detect\archive\goal15.forest_shares.csv"

# Trained sklearn pipeline written by the training step.
PKL_PATH = os.path.join(OUT_DIR, "eco_forest_rf.pkl")

# Prediction CSV and the three diagnostic plots, all placed inside OUT_DIR.
PRED_CSV_DEFAULT, PRED_HIST_PNG, SCATTER_PNG, HEAT_PNG = (
    os.path.join(OUT_DIR, filename)
    for filename in (
        "predictions.csv",
        "pred_hist.png",
        "pred_vs_actual.png",
        "residual_heatmap.png",
    )
)

# Root logger: INFO and above, timestamped, pipe-separated fields.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

# ---------------------------
# Helpers (keep in sync with training heuristics)
# ---------------------------
def pick_target_column(df: pd.DataFrame) -> Optional[str]:
    """Return the label of the column that looks like the regression target.

    Column labels are compared case-insensitively against a fixed, ordered
    list of well-known target names; the first name in that list that is
    present wins. Only the labels are inspected, never the cell values.

    Args:
        df: Frame whose column labels are searched.

    Returns:
        The matching column label exactly as it appears in ``df.columns``,
        or ``None`` when none of the known names is present.
    """
    preferred = (
        "forest_share", "forest_share_percent", "forest_area_pct",
        "forest_area_percent", "share", "value", "trend",
    )
    # Stringify before lower-casing: labels are not guaranteed to be `str`
    # (e.g. integer headers), and `int.lower()` would raise AttributeError.
    # When two labels differ only by case, the right-most column wins.
    by_lower = {str(col).lower(): col for col in df.columns}
    for name in preferred:
        if name in by_lower:
            return by_lower[name]
    return None

def best_country_col(df: pd.DataFrame) -> Optional[str]:
    """Pick the column most likely to identify a country.

    Exact (case-sensitive) matches against a short list of known labels are
    tried first, in priority order. If none is present, the first column
    whose dtype is ``object`` is used instead.

    Returns:
        The chosen column label, or ``None`` if neither rule matches.
    """
    known_labels = (
        "country", "country_name", "Country", "Country Name", "Entity", "entity",
    )
    hits = [label for label in known_labels if label in df.columns]
    if hits:
        return hits[0]

    # Fallback: any object-typed column, leftmost first.
    object_cols = [label for label in df.columns if df[label].dtype == "object"]
    if not object_cols:
        return None
    return object_cols[0]

def best_year_col(df: pd.DataFrame) -> Optional[str]:
    """Pick the column most likely to hold the observation year or date.

    Labels are matched exactly (case-sensitive). Year/time labels outrank
    date labels; within each group the listed order is the priority.

    Returns:
        The first matching column label, or ``None`` if none is present.
    """
    # One ordered tuple: the year/time names first, then the date names.
    ranked_labels = ("year", "Year", "Time", "time", "date", "Date")
    available = df.columns
    return next((label for label in ranked_labels if label in available), None)

def coerce_numeric_like(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with number-like object columns made numeric.

    Every column whose dtype is ``object`` is passed through
    ``pd.to_numeric``. If the whole column parses, the numeric result
    replaces it; if any value fails to parse, the column is left untouched.
    Columns of any other dtype are never modified, and the input frame is
    not mutated.

    Args:
        df: Source frame.

    Returns:
        A new frame with the same columns in the same order.
    """
    result = df.copy()
    for col in result.columns:
        if result[col].dtype != "object":
            continue
        # `errors="ignore"` is deprecated in pandas (2.2+); catching the
        # parse failure explicitly gives the same "all or nothing" outcome.
        try:
            result[col] = pd.to_numeric(result[col])
        except (ValueError, TypeError):
            # Genuine text (or unparseable objects): keep the column as is.
            continue
    return result

# ---------------------------
# Plotters
# ---------------------------
def _plot_pred_hist(y_pred: np.ndarray, out_png: str, show: bool) -> None:
    """Save (and optionally display) a 30-bin histogram of the predictions.

    Args:
        y_pred: Predicted values to bin.
        out_png: Destination path of the PNG (written at 160 dpi).
        show: When true, also call ``plt.show()`` after saving. Display is
            best-effort: a failure is logged and does not propagate.
    """
    fig = plt.figure(figsize=(8, 5))
    try:
        plt.hist(y_pred, bins=30, alpha=0.85)
        plt.title("Prediction Distribution")
        plt.xlabel("Predicted value")
        plt.ylabel("Count")
        plt.grid(True, linestyle="--", linewidth=0.5)
        plt.tight_layout()
        plt.savefig(out_png, dpi=160)
        if show:
            try:
                plt.show()
            except Exception as exc:  # display is optional; the PNG is already on disk
                logging.warning("[PLOT] Could not display %s: %s", out_png, exc)
    finally:
        # Close this specific figure even if plotting/saving raised, so a
        # failed call does not leave an open figure behind.
        plt.close(fig)

def _plot_pred_vs_actual(y_true: np.ndarray, y_pred: np.ndarray, out_png: str, show: bool) -> None:
    """Save (and optionally display) a predicted-vs-true scatter plot.

    A dashed red identity line spanning the joint min/max of both inputs is
    drawn as the "perfect prediction" reference.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values, aligned element-wise with ``y_true``.
        out_png: Destination path of the PNG (written at 160 dpi).
        show: When true, also call ``plt.show()`` after saving. Display is
            best-effort: a failure is logged and does not propagate.
    """
    # Accept any array-like (list, Series, ...) rather than ndarray only.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.size == 0 or y_pred.size == 0:
        # min()/max() of an empty array raises; there is nothing to draw anyway.
        logging.warning("[SCATTER] Empty input; skipping.")
        return

    lim_min = float(min(y_true.min(), y_pred.min()))
    lim_max = float(max(y_true.max(), y_pred.max()))

    fig = plt.figure(figsize=(7, 7))
    try:
        plt.scatter(y_true, y_pred, s=18, alpha=0.7)
        plt.plot([lim_min, lim_max], [lim_min, lim_max], "r--", linewidth=1.5)
        plt.xlabel("True")
        plt.ylabel("Predicted")
        plt.title("Predicted vs. True")
        plt.grid(True, linestyle="--", linewidth=0.5)
        plt.tight_layout()
        plt.savefig(out_png, dpi=160)
        if show:
            try:
                plt.show()
            except Exception as exc:  # display is optional; the PNG is already on disk
                logging.warning("[PLOT] Could not display %s: %s", out_png, exc)
    finally:
        # Close this specific figure even if plotting/saving raised.
        plt.close(fig)

def _plot_residual_heatmap(df_pred: pd.DataFrame, out_png: str, show: bool) -> None:
    """Save (and optionally display) a heatmap of mean residuals.

    Rows are the values of the country-like column found by
    ``best_country_col``. If ``best_year_col`` finds a column that is numeric
    (or can be parsed into calendar years), columns are those years;
    otherwise a single ``mean_residual`` column is drawn. Only the 35 groups
    with the most rows are kept so the image stays readable.

    The function returns without writing a file when no country-like column
    exists or nothing is left after filtering. ``df_pred`` is not modified.

    Args:
        df_pred: Frame with a ``residual`` column, plus country/year columns
            if available.
        out_png: Destination path of the PNG (written at 160 dpi).
        show: When true, also call ``plt.show()`` after saving. Display is
            best-effort: a failure is logged and does not propagate.

    Raises:
        KeyError: If ``df_pred`` has no ``residual`` column.
    """
    ccol = best_country_col(df_pred)
    if ccol is None:
        logging.warning("[HEATMAP] No country-like column found; skipping.")
        return

    ycol = best_year_col(df_pred)
    if ycol == ccol:
        # The country fallback (first object column) can land on the date
        # column; one column cannot be both pivot axes.
        ycol = None

    # Work on a private copy: the year conversion below must not leak into
    # the caller's frame.
    frame = df_pred.copy()

    def _is_numeric(series: pd.Series) -> bool:
        # pandas' own check copes with extension dtypes (string, categorical,
        # nullable ints) on which np.issubdtype raises TypeError. Booleans
        # are excluded because np.number never covered them.
        return (
            pd.api.types.is_numeric_dtype(series)
            and not pd.api.types.is_bool_dtype(series)
        )

    # If the year column is not numeric, try to parse it as dates and keep
    # only the calendar year.
    if ycol is not None and not _is_numeric(frame[ycol]):
        try:
            years = pd.to_datetime(frame[ycol], errors="coerce").dt.year
        except (AttributeError, TypeError, ValueError, OverflowError) as exc:
            logging.debug("[HEATMAP] Could not parse '%s' as dates: %s", ycol, exc)
        else:
            if years.notna().any():
                frame[ycol] = years

    by_year = ycol is not None and _is_numeric(frame[ycol])
    if by_year:
        pivot = frame.pivot_table(index=ccol, columns=ycol, values="residual", aggfunc="mean")
        title = "Residual Heatmap (mean error by Country × Year)"
    else:
        pivot = frame.groupby(ccol)["residual"].mean().to_frame("mean_residual")
        # Say what is actually plotted: there is no year axis in this branch.
        title = "Residual Heatmap (mean error by Country)"

    # Keep top-35 countries by sample count to keep the image readable
    counts = frame.groupby(ccol).size().sort_values(ascending=False)
    keep = counts.head(35).index
    pivot = pivot.loc[pivot.index.intersection(keep)]

    if pivot.empty:
        logging.warning("[HEATMAP] Nothing to plot after filtering; skipping.")
        return

    fig = plt.figure(figsize=(12, 9))
    try:
        im = plt.imshow(pivot.values, aspect="auto", interpolation="nearest")
        plt.colorbar(im, fraction=0.046, pad=0.04, label="Residual (pred − true)")
        plt.yticks(range(len(pivot.index)), list(pivot.index))
        try:
            plt.xticks(range(len(pivot.columns)), list(pivot.columns), rotation=90)
        except (TypeError, ValueError) as exc:
            # Tick labels are cosmetic; keep the default ticks on failure.
            logging.debug("[HEATMAP] Could not label x ticks: %s", exc)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(out_png, dpi=160)
        if show:
            try:
                plt.show()
            except Exception as exc:  # display is optional; the PNG is already on disk
                logging.warning("[PLOT] Could not display %s: %s", out_png, exc)
    finally:
        # Close this specific figure even if plotting/saving raised.
        plt.close(fig)

# ---------------------------
# Public API (Notebook-friendly)
# ---------------------------
def predict_csv(
    input_csv: str = DATA_CSV_DEFAULT,
    out_csv: str = PRED_CSV_DEFAULT,
    show_plots: bool = True,
    model_path: Optional[str] = None,
) -> str:
    """Score a CSV with the saved sklearn pipeline and write diagnostics.

    Steps:
    - Load the pipeline from ``model_path`` (either the estimator itself or a
      dict holding it under the ``"pipeline"`` key).
    - Read ``input_csv``, coerce number-like text columns, and predict.
    - Save the input rows plus a ``y_pred`` column to ``out_csv``.
    - Always save the prediction histogram PNG.
    - If the CSV contains a target column (e.g. 'forest_share' or 'trend')
      with at least one numeric value:
        * log R2 / RMSE / MAE
        * save the scatter plot (pred_vs_actual.png)
        * save the residual heatmap (residual_heatmap.png), using
          Country×Year if available

    Args:
        input_csv: CSV file to score.
        out_csv: Where the predictions CSV is written; its folder is created
            if missing.
        show_plots: Also display each figure after saving it.
        model_path: Serialized pipeline to load. ``None`` (default) means the
            module-level ``PKL_PATH``.

    Returns:
        The path to the predictions CSV (``out_csv``).

    Raises:
        FileNotFoundError: If the model file does not exist.
        AttributeError: If the loaded object has no ``predict`` method.
    """
    if model_path is None:
        model_path = PKL_PATH
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Missing sklearn pipeline: {model_path}. Train first.")

    # The CSV and the PNGs may live in different folders; make sure each exists.
    for target_path in (out_csv, PRED_HIST_PNG, SCATTER_PNG, HEAT_PNG):
        os.makedirs(os.path.dirname(target_path) or ".", exist_ok=True)

    # Load model.
    # SECURITY: joblib files are pickles, and unpickling can execute arbitrary
    # code. Only load model files that come from a trusted source.
    bundle = joblib.load(model_path)
    sk_pipe = bundle.get("pipeline", bundle) if isinstance(bundle, dict) else bundle
    if not hasattr(sk_pipe, "predict"):
        # Same exception type the failing `.predict` lookup would raise, but
        # with a message that says what is actually wrong with the file.
        raise AttributeError(
            f"Object loaded from {model_path} has no 'predict' method "
            f"(got {type(sk_pipe).__name__}); expected a fitted pipeline or "
            "a dict with a 'pipeline' entry."
        )

    # Read & lightly coerce
    df = pd.read_csv(input_csv)
    df = coerce_numeric_like(df)

    # Detect target: it must not be fed to the model as a feature.
    target = pick_target_column(df)
    if target is not None:
        logging.info("[PRED] Detected target column in input: '%s'", target)
        X = df.drop(columns=[target])
        y_true = pd.to_numeric(df[target], errors="coerce")
    else:
        logging.info("[PRED] No target column found — predictions only.")
        X = df
        y_true = None

    # Predict
    y_pred = np.asarray(sk_pipe.predict(X))

    # Save predictions CSV (all input columns, target included, plus y_pred)
    pred_df = df.copy()
    pred_df["y_pred"] = y_pred
    pred_df.to_csv(out_csv, index=False, encoding="utf-8")
    logging.info("[SAVE] Predictions CSV → %s", out_csv)

    # Plots
    _plot_pred_hist(y_pred, PRED_HIST_PNG, show=show_plots)

    if y_true is None or not y_true.notna().any():
        logging.info("[PRED] No ground truth present → skipping scatter & residual heatmap.")
        return out_csv

    # Score only the rows whose target parsed to a number. At least one such
    # row exists here, so no further emptiness check is needed.
    scored = y_true.notna().to_numpy()
    y_true_v = y_true.to_numpy()[scored]
    y_pred_v = y_pred[scored]

    r2 = r2_score(y_true_v, y_pred_v)
    rmse = np.sqrt(mean_squared_error(y_true_v, y_pred_v))
    mae = mean_absolute_error(y_true_v, y_pred_v)
    logging.info("[METRICS] R2=%.4f | RMSE=%.4f | MAE=%.4f", r2, rmse, mae)

    _plot_pred_vs_actual(y_true_v, y_pred_v, SCATTER_PNG, show=show_plots)

    df_resid = df.loc[scored].copy()
    df_resid["y_true"] = y_true_v
    df_resid["y_pred"] = y_pred_v
    df_resid["residual"] = df_resid["y_pred"] - df_resid["y_true"]
    _plot_residual_heatmap(df_resid, HEAT_PNG, show=show_plots)

    return out_csv
