In [None]:
# -*- coding: utf-8 -*-
"""
3.4.3 Spatial Visualization over a world basemap
Priority order:
  - Cartopy (Robinson projection by default, see PROJECTION; land/sea/borders basemap)
  - If unavailable, fall back to GeoPandas (Natural Earth world basemap)
  - If neither is available, fall back to a plain longitude-latitude scatter (plots are still produced).

Input:
  ./eval_3_4_2_outputs/all_site_metrics_all_models.csv
  Required columns: ['scale','target','site','lat','lon'] plus the column named by METRIC
  (e.g. 'R2', 'RMSE', 'MAE', 'PearsonR').

Output:
  ./eval_3_4_3_maps/*.png
"""

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ===== Backend detection =================================================
# Each optional mapping library is probed independently. Any failure while
# importing it (not only ImportError) marks that backend as unavailable.
try:
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
except Exception:
    HAVE_CARTOPY = False
else:
    HAVE_CARTOPY = True

try:
    import geopandas as gpd
except Exception:
    HAVE_GPD = False
else:
    HAVE_GPD = True

# Highest-priority available backend wins: cartopy > geopandas > plain.
if HAVE_CARTOPY:
    USE = "cartopy"
elif HAVE_GPD:
    USE = "geopandas"
else:
    USE = "plain"
print(f"[MAP BACKEND] {USE} (cartopy={HAVE_CARTOPY}, geopandas={HAVE_GPD})")

# ===== Config ============================================================
# Input table produced by step 3.4.2, and the directory the maps are written to
# (created eagerly at import time).
IN_CSV = Path("./eval_3_4_2_outputs/all_site_metrics_all_models.csv")
OUT_DIR = Path("./eval_3_4_3_maps")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# What to plot.
TARGETS = ["GPP", "NEE"]                  # targets to plot
METRIC = "R2"                             # options: "R2" | "RMSE" | "MAE" | "PearsonR" | "MAPE" | "RMSLE"
SINGLE_SCALES = ["170sites"]              # scales that get a single-scale map
COMPARE_PAIR = ("5sites", "170sites")     # side-by-side comparison panels
DELTA_PAIR = ("5sites", "170sites")       # delta maps are second minus first

# Map appearance.
PROJECTION = "Robinson"                   # Cartopy only: "Robinson" | "PlateCarree"
DOT_SIZE = 48
EDGE_KW = {"edgecolor": "white", "linewidth": 0.4}

# Colormaps.
CMAP_R2 = "viridis"
CMAP_ERR = "magma_r"                      # reversed so low error reads as bright
CMAP_DELTA = "coolwarm"                   # diverging map for signed deltas

# ===== Utility functions =================================================
def robust_limits(vals: np.ndarray, lq=2, uq=98):
    """Return robust ``(vmin, vmax)`` colour limits from percentiles of *vals*.

    Non-finite entries (NaN / +-inf) are ignored, so missing values and a few
    outliers cannot stretch the colour scale.

    Args:
        vals: Array-like of numbers (ndarray, list or pandas Series, any shape).
        lq: Lower percentile (0-100) used for ``vmin``.
        uq: Upper percentile (0-100) used for ``vmax``.

    Returns:
        ``(vmin, vmax)`` as Python floats, or ``(None, None)`` when *vals*
        contains no finite value.
    """
    arr = np.asarray(vals, dtype=float).ravel()
    arr = arr[np.isfinite(arr)]
    if arr.size == 0:
        return None, None
    vmin, vmax = np.percentile(arr, [lq, uq])
    if np.isclose(vmin, vmax):
        # Widen a degenerate range so the colour normalisation has a non-zero
        # span. The pad scales with magnitude because a fixed 1e-6 is lost to
        # float64 rounding for large values (e.g. 1e12 - 1e-6 == 1e12).
        eps = max(1e-6, 1e-6 * abs(vmin))
        vmin, vmax = vmin - eps, vmax + eps
    return float(vmin), float(vmax)

def load_df(csv_path: Path, metric=None) -> pd.DataFrame:
    """Load the per-site metrics table and normalise longitudes into [-180, 180].

    Args:
        csv_path: Path to the CSV produced by step 3.4.2.
        metric: Name of the metric column that must be present. Defaults to
            the module-level ``METRIC`` (the original behaviour).

    Returns:
        The loaded DataFrame, with ``lon`` as float in [-180, 180].

    Raises:
        FileNotFoundError: if *csv_path* does not exist.
        ValueError: if any of scale/target/site/lat/lon/<metric> is missing.
    """
    if metric is None:
        metric = METRIC
    csv_path = Path(csv_path)
    if not csv_path.exists():
        raise FileNotFoundError(csv_path)
    df = pd.read_csv(csv_path)

    needed = {"scale", "target", "site", "lat", "lon", metric}
    missing = needed - set(df.columns)
    if missing:
        # Sorted so the message is deterministic (set order is arbitrary).
        raise ValueError(f"CSV missing columns: {sorted(missing)}")

    # Wrap every longitude outside [-180, 180] back into range. This covers
    # the 0-360 convention as well as values below -180 or beyond 540, which
    # a single "-= 360" would miss. In-range values (including +-180) are
    # left untouched.
    df["lon"] = df["lon"].astype(float)
    out_of_range = (df["lon"] > 180) | (df["lon"] < -180)
    if out_of_range.any():
        df.loc[out_of_range, "lon"] = (df.loc[out_of_range, "lon"] + 180.0) % 360.0 - 180.0
    return df

# ===== Backend: Cartopy ==================================================
def make_ax_cartopy(title: str):
    """Create one global Cartopy map axis with a Natural Earth basemap.

    The projection follows ``PROJECTION`` (Robinson if set to "robinson",
    otherwise PlateCarree). Ocean, land, lakes, rivers, borders and
    coastlines come from Natural Earth. If any of those datasets cannot be
    loaded, the stock background image is used instead.

    Args:
        title: Axis title.

    Returns:
        ``(fig, ax)``, where ``ax`` is a Cartopy GeoAxes.
    """
    proj = ccrs.Robinson() if PROJECTION.lower() == "robinson" else ccrs.PlateCarree()
    fig = plt.figure(figsize=(10.5, 5.4), constrained_layout=True)
    # Bind the axes to *fig* explicitly instead of pyplot's "current figure".
    ax = fig.add_subplot(1, 1, 1, projection=proj)

    layers = [
        (cfeature.OCEAN, dict(zorder=0, facecolor="#DCE7F3")),
        (cfeature.LAND, dict(zorder=0, facecolor="#F3F1ED")),
        (cfeature.LAKES, dict(zorder=1, edgecolor="none", facecolor="#DCE7F3")),
        (cfeature.RIVERS, dict(zorder=1, linewidth=0.2)),
        (cfeature.BORDERS, dict(linewidth=0.2, zorder=2)),
        (cfeature.COASTLINE, dict(linewidth=0.4, zorder=2)),
    ]
    try:
        # add_feature() only registers an artist. The shapefiles are read
        # (and downloaded if absent) at draw time, i.e. inside savefig().
        # Touch the geometries now so a missing or offline dataset triggers
        # the fallback here rather than failing later.
        for feature, _ in layers:
            next(iter(feature.geometries()), None)
    except Exception as err:
        print(f"[WARN] Natural Earth data unavailable ({err}); using stock image basemap")
        ax.stock_img()
    else:
        for feature, kw in layers:
            ax.add_feature(feature, **kw)
    ax.set_global()
    ax.set_title(title)
    return fig, ax

def scatter_cartopy(ax, df, color_col, vmin, vmax, cmap, add_colorbar=True):
    """Draw site markers coloured by *color_col* on a Cartopy axis.

    The ``lon``/``lat`` columns are plain degrees, hence the PlateCarree
    transform. Markers sit at zorder 5, above every basemap layer.
    Returns the scatter collection. When *add_colorbar* is true, a colorbar
    labelled with *color_col* is attached to *ax*.
    """
    points = ax.scatter(
        df["lon"], df["lat"],
        c=df[color_col], s=DOT_SIZE, cmap=cmap, vmin=vmin, vmax=vmax,
        transform=ccrs.PlateCarree(), zorder=5, **EDGE_KW
    )
    if not add_colorbar:
        return points
    colorbar = plt.colorbar(points, ax=ax, fraction=0.03, pad=0.02)
    colorbar.set_label(color_col)
    return points

# ===== Backend: GeoPandas ================================================
WORLD = None  # cached GeoDataFrame, filled on the first get_world() call
# Natural Earth 1:110m admin-0 countries. Used when the dataset bundled with
# older GeoPandas releases is not available.
WORLD_URL = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"


def get_world():
    """Return (and cache in ``WORLD``) Natural Earth low-res world polygons.

    First tries the ``naturalearth_lowres`` dataset bundled with GeoPandas
    < 1.0. GeoPandas 1.0 removed ``geopandas.datasets``, so on newer
    versions the equivalent 1:110m countries layer is read from
    ``WORLD_URL`` instead (this needs network access on first use).

    Raises:
        RuntimeError: if neither source can be read.
    """
    global WORLD
    if WORLD is not None:
        return WORLD
    try:
        world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
    except Exception as legacy_err:
        try:
            world = gpd.read_file(WORLD_URL)
        except Exception as url_err:
            raise RuntimeError(
                "Could not load a world basemap: the bundled GeoPandas dataset is "
                f"unavailable ({legacy_err}) and {WORLD_URL} could not be read."
            ) from url_err
    WORLD = world
    return WORLD

def make_ax_gpd(title: str):
    """Create a lon/lat axis with the GeoPandas world polygons as basemap.

    Countries are filled first (zorder 0) and then outlined (zorder 1). The
    view is fixed to the whole globe with 60/30-degree ticks and a light
    dashed grid. Returns ``(fig, ax)``.
    """
    fig, ax = plt.subplots(figsize=(10.5, 5.4), constrained_layout=True)
    countries = get_world()
    countries.plot(ax=ax, color="#F3F1ED", edgecolor="#999", linewidth=0.2, zorder=0)
    countries.boundary.plot(ax=ax, linewidth=0.3, color="#444", zorder=1)

    ax.set_xlim(-180, 180)
    ax.set_ylim(-90, 90)
    ax.set_xticks(np.arange(-180, 181, 60))
    ax.set_yticks(np.arange(-90, 91, 30))
    ax.grid(True, ls="--", lw=0.3, alpha=0.35)
    ax.set(xlabel="Longitude", ylabel="Latitude", title=title)
    return fig, ax

def scatter_gpd(ax, df, color_col, vmin, vmax, cmap, add_colorbar=True):
    """Draw site markers coloured by *color_col* on a GeoPandas-backed axis.

    Markers use zorder 2, above the country fill and outlines. Returns the
    scatter collection, optionally adding a colorbar labelled *color_col*.
    """
    points = ax.scatter(
        df["lon"], df["lat"],
        c=df[color_col], s=DOT_SIZE, cmap=cmap, vmin=vmin, vmax=vmax,
        zorder=2, **EDGE_KW
    )
    if not add_colorbar:
        return points
    colorbar = plt.colorbar(points, ax=ax, fraction=0.032, pad=0.02)
    colorbar.set_label(color_col)
    return points

# ===== Backend: Plain ====================================================
def make_ax_plain(title: str):
    """Create a bare lon/lat axis with no basemap (dependency-free fallback).

    The view is fixed to the whole globe with 60/30-degree ticks and a light
    dashed grid. Returns ``(fig, ax)``.
    """
    fig, ax = plt.subplots(figsize=(10.5, 5.4), constrained_layout=True)
    ax.set_xlim(-180, 180)
    ax.set_ylim(-90, 90)
    ax.set_xticks(np.arange(-180, 181, 60))
    ax.set_yticks(np.arange(-90, 91, 30))
    ax.grid(True, ls="--", lw=0.3, alpha=0.35)
    ax.set(xlabel="Longitude", ylabel="Latitude", title=title)
    return fig, ax

def scatter_plain(ax, df, color_col, vmin, vmax, cmap, add_colorbar=True):
    """Draw site markers coloured by *color_col* on a plain lon/lat axis.

    Returns the scatter collection, optionally adding a colorbar labelled
    *color_col*.
    """
    points = ax.scatter(
        df["lon"], df["lat"],
        c=df[color_col], s=DOT_SIZE, cmap=cmap, vmin=vmin, vmax=vmax,
        **EDGE_KW
    )
    if not add_colorbar:
        return points
    colorbar = plt.colorbar(points, ax=ax, fraction=0.032, pad=0.02)
    colorbar.set_label(color_col)
    return points

# ===== Router ============================================================
def make_ax(title: str):
    """Create a single map axis with the backend chosen in ``USE``.

    Unknown values of ``USE`` fall through to the plain lon/lat axis.
    Returns ``(fig, ax)``.
    """
    builders = {"cartopy": make_ax_cartopy, "geopandas": make_ax_gpd}
    build = builders.get(USE, make_ax_plain)
    return build(title)

def do_scatter(ax, df, color_col, vmin, vmax, cmap, add_colorbar=True):
    """Scatter site markers using the backend-specific helper chosen by ``USE``.

    Unknown values of ``USE`` fall through to the plain helper. Returns
    whatever the selected helper returns (the scatter collection).
    """
    painters = {"cartopy": scatter_cartopy, "geopandas": scatter_gpd}
    paint = painters.get(USE, scatter_plain)
    return paint(ax, df, color_col, vmin, vmax, cmap, add_colorbar)

# ===== Plotting functions ================================================
def plot_single(df: pd.DataFrame, scale: str, target: str, metric: str):
    """Draw one map of *metric* per site for a single (scale, target) pair.

    R2 uses a fixed [0, 1] range with ``CMAP_R2``. Every other metric uses
    its 2nd-98th percentile range with ``CMAP_ERR``. The figure is written
    to ``OUT_DIR/map_{metric}_{scale}_{target}.png``. If no row has a finite
    metric value, a warning is printed and nothing is drawn.
    """
    selection = (df["scale"] == scale) & (df["target"] == target)
    subset = df.loc[selection].copy()
    subset = subset.loc[np.isfinite(subset[metric])]
    if subset.empty:
        print(f"[WARN] No data: {scale}, {target}")
        return

    if metric != "R2":
        cmap = CMAP_ERR
        vmin, vmax = robust_limits(subset[metric].values, 2, 98)
    else:
        cmap = CMAP_R2
        vmin, vmax = 0.0, 1.0

    fig, ax = make_ax(f"{metric} by site — {scale} ({target})")
    do_scatter(ax, subset, metric, vmin, vmax, cmap, add_colorbar=True)
    out_path = OUT_DIR / f"map_{metric}_{scale}_{target}.png"
    fig.savefig(out_path, dpi=240)
    plt.close(fig)
    print(f"[Saved] {out_path}")

def _make_compare_axes():
    """Create a 1x2 figure whose two panels carry the active backend's basemap.

    Returns:
        ``(fig, axs)``, where ``axs`` is the length-2 array from ``plt.subplots``.
    """
    if USE == "cartopy":
        proj = ccrs.Robinson() if PROJECTION.lower() == "robinson" else ccrs.PlateCarree()
        fig, axs = plt.subplots(1, 2, figsize=(14, 5), constrained_layout=True,
                                subplot_kw=dict(projection=proj))
        layers = [
            (cfeature.OCEAN, dict(facecolor="#DCE7F3", zorder=0)),
            (cfeature.LAND, dict(facecolor="#F3F1ED", zorder=0)),
            (cfeature.BORDERS, dict(linewidth=0.2, zorder=2)),
            (cfeature.COASTLINE, dict(linewidth=0.4, zorder=2)),
        ]
        # add_feature() defers reading (and downloading) the Natural Earth
        # shapefiles until draw time. Probe them up front so a missing dataset
        # selects the stock-image fallback instead of failing inside savefig().
        try:
            for feature, _ in layers:
                next(iter(feature.geometries()), None)
        except Exception as err:
            print(f"[WARN] Natural Earth data unavailable ({err}); using stock image basemap")
            have_vectors = False
        else:
            have_vectors = True
        for ax in axs:
            if have_vectors:
                for feature, kw in layers:
                    ax.add_feature(feature, **kw)
            else:
                ax.stock_img()
            ax.set_global()
        return fig, axs

    fig, axs = plt.subplots(1, 2, figsize=(14, 5), constrained_layout=True)
    world = get_world() if USE == "geopandas" else None
    for ax in axs:
        if world is not None:
            world.plot(ax=ax, color="#F3F1ED", edgecolor="#999", linewidth=0.2, zorder=0)
            world.boundary.plot(ax=ax, linewidth=0.3, color="#444", zorder=1)
        # Same lon/lat frame, ticks and labels as the single-map axes.
        ax.set_xlim(-180, 180); ax.set_ylim(-90, 90)
        ax.set_xticks(np.arange(-180, 181, 60)); ax.set_yticks(np.arange(-90, 91, 30))
        ax.grid(True, ls="--", lw=0.3, alpha=0.35)
        ax.set_xlabel("Longitude"); ax.set_ylabel("Latitude")
    return fig, axs


def plot_compare(df: pd.DataFrame, scale_a: str, scale_b: str, target: str, metric: str):
    """Draw *metric* for two scales side by side with one shared colorbar.

    Both panels use the same colour limits. For R2 these are [0, 1] with
    ``CMAP_R2``. Otherwise they are the 2nd-98th percentile of the pooled
    values of both scales, with ``CMAP_ERR``. The figure is written to
    ``OUT_DIR/compare_{metric}_{scale_a}_vs_{scale_b}_{target}.png``. If
    either scale has no finite value, a warning is printed and nothing is
    drawn.
    """
    is_target = df["target"] == target
    dA = df[(df["scale"] == scale_a) & is_target].copy()
    dB = df[(df["scale"] == scale_b) & is_target].copy()
    dA, dB = dA[np.isfinite(dA[metric])], dB[np.isfinite(dB[metric])]
    if dA.empty or dB.empty:
        print(f"[WARN] Comparison missing data: {scale_a} vs {scale_b} ({target})")
        return

    if metric == "R2":
        vmin, vmax, cmap = 0.0, 1.0, CMAP_R2
    else:
        # Pool both scales so the shared colorbar is valid for both panels.
        pooled = np.r_[dA[metric].values, dB[metric].values]
        vmin, vmax = robust_limits(pooled, 2, 98)
        cmap = CMAP_ERR

    fig, axs = _make_compare_axes()
    axs[0].set_title(f"{metric} — {scale_a} ({target})")
    axs[1].set_title(f"{metric} — {scale_b} ({target})")

    do_scatter(axs[0], dA, metric, vmin, vmax, cmap, add_colorbar=False)
    sc = do_scatter(axs[1], dB, metric, vmin, vmax, cmap, add_colorbar=False)

    # One colorbar spanning both panels (they share vmin/vmax/cmap).
    cbar = fig.colorbar(sc, ax=axs.ravel().tolist(), fraction=0.03, pad=0.03)
    cbar.set_label(metric)

    out = OUT_DIR / f"compare_{metric}_{scale_a}_vs_{scale_b}_{target}.png"
    fig.savefig(out, dpi=240); plt.close(fig); print(f"[Saved] {out}")

def plot_delta(df: pd.DataFrame, scale_a: str, scale_b: str, target: str, metric: str):
    """
    Plot delta map: Δmetric = metric(scale_b) - metric(scale_a) at matched sites.

    Sites are matched on ``site``, and markers are placed at scale_b's lat/lon.
    For R2: delta > 0 means improvement.
    For RMSE/MAE: delta < 0 means error reduction (also improvement).
    Colour limits are symmetric about zero (+- the 98th percentile of |Δ|)
    on ``CMAP_DELTA``, and the colorbar is labelled ``Δ<metric>``.
    The figure is written to
    ``OUT_DIR/delta_{metric}_{scale_b}minus{scale_a}_{target}.png``. If no
    matched site has a finite delta, a warning is printed and nothing is
    drawn.
    """
    cols = ["site", "lat", "lon", metric]
    is_target = df["target"] == target
    A = df.loc[(df["scale"] == scale_a) & is_target, cols].copy()
    B = df.loc[(df["scale"] == scale_b) & is_target, cols].copy()
    # NOTE(review): if the input holds several rows per (scale, target, site)
    # -- e.g. one per model -- this merge pairs every combination. Confirm the
    # CSV is one row per site.
    M = pd.merge(B, A, on="site", suffixes=("_B", "_A"))
    if M.empty:
        print(f"[WARN] No overlapping sites: {scale_b} vs {scale_a} ({target})")
        return

    # Use coordinates from scale_b.
    M["lat"] = M["lat_B"]
    M["lon"] = M["lon_B"]
    M["delta"] = M[f"{metric}_B"] - M[f"{metric}_A"]
    M = M[np.isfinite(M["delta"])]
    if M.empty:
        # Otherwise the percentile below would be NaN and an empty map saved.
        print(f"[WARN] No finite Δ{metric} values: {scale_b} vs {scale_a} ({target})")
        return

    vmax = float(np.percentile(np.abs(M["delta"].to_numpy()), 98))
    if not (np.isfinite(vmax) and vmax > 0):
        vmax = 1e-6  # all deltas are zero: keep a non-degenerate colour range
    vmin = -vmax

    label = f"Δ{metric}"
    fig, ax = make_ax(f"Δ{metric} = {scale_b} − {scale_a} ({target})")
    do_scatter(ax, M.rename(columns={"delta": label}), label,
               vmin=vmin, vmax=vmax, cmap=CMAP_DELTA, add_colorbar=True)
    out = OUT_DIR / f"delta_{metric}_{scale_b}minus{scale_a}_{target}.png"
    fig.savefig(out, dpi=240); plt.close(fig); print(f"[Saved] {out}")

# ===== Main ==============================================================
def main():
    """Render the full 3.4.3 map suite.

    The maps are produced in this order: single-scale maps for every target,
    then side-by-side comparisons, then delta maps.
    """
    metrics_df = load_df(IN_CSV)

    # 1) One map per target for each scale in SINGLE_SCALES.
    for target in TARGETS:
        for scale in SINGLE_SCALES:
            plot_single(metrics_df, scale, target, METRIC)

    # 2) Side-by-side panels for COMPARE_PAIR.
    left, right = COMPARE_PAIR[0], COMPARE_PAIR[1]
    for target in TARGETS:
        plot_compare(metrics_df, left, right, target, METRIC)

    # 3) Delta maps (second minus first) for DELTA_PAIR.
    base, other = DELTA_PAIR[0], DELTA_PAIR[1]
    for target in TARGETS:
        plot_delta(metrics_df, base, other, target, METRIC)


if __name__ == "__main__":
    main()


[MAP BACKEND] cartopy (cartopy=True, geopandas=False)




[Saved] eval_3_4_3_maps/map_R2_170sites_GPP.png
[Saved] eval_3_4_3_maps/map_R2_170sites_NEE.png
[Saved] eval_3_4_3_maps/compare_R2_5sites_vs_170sites_GPP.png
[Saved] eval_3_4_3_maps/compare_R2_5sites_vs_170sites_NEE.png
[Saved] eval_3_4_3_maps/delta_R2_170sitesminus5sites_GPP.png
[Saved] eval_3_4_3_maps/delta_R2_170sitesminus5sites_NEE.png
