Prediction Maps for FRH04 and Belle-Ile Region subset of FRH04 for Julia Results

In [None]:
import os
import time
import torch
import pandas as pd
import sklearn.metrics
import breizhcrops
from torch.utils.data import DataLoader, ConcatDataset
from torch.optim import Adam
from tqdm import tqdm 
from breizhcrops import BreizhCrops
from breizhcrops.models import TempCNN

# Dataset location and loading options shared by both regions below.
DATA_PATH = "/breizh_data"
LEVEL = "L1C"
PRELOAD_RAM = False

# Keyword arguments common to every BreizhCrops region loaded in this cell.
_dataset_kwargs = dict(root=DATA_PATH, level=LEVEL, preload_ram=PRELOAD_RAM)
belle_ile = BreizhCrops(region="belle-ile", **_dataset_kwargs)
frh04 = BreizhCrops(region="frh04", **_dataset_kwargs)

# Parcels used by the mapping cell that follows; use belle_ile.geodataframe()
# instead to map the Belle-Ile subset.
field_parcels_geodataframe = frh04.geodataframe()

In [None]:
import os, math
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Rectangle
from pyproj import CRS

# CONFIG
# Path of the results CSV read by this cell.
# NOTE(review): the Belle-Ile cell further down reads "maps/julia_sample_results.csv";
# confirm which location is the intended one.
CSV_PATH = "julia_sample_results.csv"
# Directory that receives the rendered PNG maps.
SAVE_DIR = "maps"
# Which prediction column of the CSV to map: "julia" or "onnx".
PRED_SRC = "onnx"
# EPSG code the parcels are reprojected to before plotting.
CRS_EPSG = 3857
# Region name used as the prefix of every figure title.
TITLE_REGION = "FRH04"
# Scale-bar length in km; None lets the scale bar choose a length itself.
FIXED_SCALE_KM = None

os.makedirs(SAVE_DIR, exist_ok=True)


def _safe_series(x):
    """Wrap *x* in a new pandas Series built from its values as Python objects.

    The input is converted with ``np.asarray(..., dtype=object)`` first, so any
    index carried by *x* is dropped and the result gets a default index.
    """
    as_objects = np.asarray(x, dtype=object)
    return pd.Series(as_objects)

def fmt_pct(p):
    """Format a value already expressed in percent for use in a legend label.

    Non-positive values give "0%", positive values below 0.1 give "<0.1%",
    and everything else is shown with one decimal place.
    """
    if p <= 0:
        label = "0%"
    elif p < 0.1:
        label = "<0.1%"
    else:
        label = f"{p:.1f}%"
    return label

def add_north_arrow(ax, xy=(0.95, 0.92), length=0.10):
    """Draw a black, upward-pointing arrow with an 'N' label on *ax*.

    *xy* is the arrow tip and *length* the shaft length, both in axes-fraction
    coordinates; the label is placed slightly above the tip.
    """
    tip_x = xy[0]
    tip_y = xy[1]
    arrow_style = {'arrowstyle': '-|>', 'lw': 1.8, 'color': 'k'}

    # An empty annotation is used only for its arrow, drawn from tail to tip.
    ax.annotate(
        '',
        xy=xy,
        xytext=(tip_x, tip_y - length),
        xycoords='axes fraction',
        textcoords='axes fraction',
        arrowprops=arrow_style,
    )
    ax.text(
        tip_x, tip_y + 0.02, 'N',
        transform=ax.transAxes,
        ha='center', va='bottom',
        fontsize=12, weight='bold', color='k',
    )

def _nice_number_m(x_m):
    """Snap a positive length to 1, 2, 5 or 10 times a power of ten.

    Non-positive input returns 1.0.
    """
    if x_m <= 0:
        return 1.0

    exponent = int(np.floor(np.log10(x_m)))
    magnitude = 10**exponent
    mantissa = x_m / magnitude

    if mantissa < 1.5:
        step = 1
    elif mantissa < 3:
        step = 2
    elif mantissa < 7:
        step = 5
    else:
        step = 10
    return step * magnitude

def _meters_per_deg_lon_at_lat(lat_deg):
    """Return 111320 m scaled by the cosine of the latitude *lat_deg* (degrees)."""
    lat_rad = math.radians(lat_deg)
    return 111320.0 * math.cos(lat_rad)

def _ground_metres_per_axis_unit(crs, minx, maxx, miny, maxy):
    """Return how many ground metres one horizontal axis unit spans at the view centre."""
    from pyproj import Geod, Transformer

    width = maxx - minx
    geodetic = crs.geodetic_crs
    if width <= 0 or geodetic is None:
        # Degenerate view or no geodetic base CRS: treat axis units as metres.
        return 1.0

    # Measure a short horizontal segment (1% of the view) around the centre so
    # the result is the local scale, whatever the CRS (degrees, Web Mercator, ...).
    cx = 0.5 * (minx + maxx)
    cy = 0.5 * (miny + maxy)
    half = 0.005 * width
    to_lonlat = Transformer.from_crs(crs, geodetic, always_xy=True)
    lon_a, lat_a = to_lonlat.transform(cx - half, cy)
    lon_b, lat_b = to_lonlat.transform(cx + half, cy)
    geod = crs.get_geod() or Geod(ellps="WGS84")
    _, _, dist_m = geod.inv(lon_a, lat_a, lon_b, lat_b)
    return max(dist_m / (2.0 * half), 1e-12)


def add_qgis_scalebar_auto(ax, gdf_plotted, total_km=None, segments=4,
                           loc='lower right', pad_frac=0.05, edge_lw=0.8, textsize=10):
    """Draw a segmented, alternating black/white scale bar on *ax*.

    The bar length is converted from ground metres to axis units using the
    local ground scale at the centre of the visible extent. This matters for
    CRSs whose units are not true ground metres (e.g. EPSG:3857, where one
    map unit shrinks with cos(latitude)).

    Parameters
    ----------
    ax : matplotlib axes whose current x/y limits define the visible extent.
    gdf_plotted : GeoDataFrame that was plotted; only its ``crs`` is read.
    total_km : bar length in km, or None to pick a "nice" length of roughly
        one fifth of the visible width.
    segments : number of alternating blocks (must be >= 1).
    loc : 'lower right', 'lower left' or 'upper right'; any other value
        places the bar in the upper left.
    pad_frac : padding from the axes edge, as a fraction of the extent.
    edge_lw : line width of block outlines and ticks.
    textsize : font size of the labels.

    Raises
    ------
    ValueError
        If ``gdf_plotted`` has no CRS or ``segments`` is smaller than 1.
    """
    if gdf_plotted.crs is None:
        raise ValueError("GeoDataFrame has no CRS. Set/reproject before drawing scalebar.")
    if segments < 1:
        raise ValueError("segments must be a positive integer.")
    crs = CRS.from_user_input(gdf_plotted.crs)

    # current visible extent
    minx, maxx = ax.get_xlim()
    miny, maxy = ax.get_ylim()
    width  = maxx - minx
    height = maxy - miny

    m_per_unit = _ground_metres_per_axis_unit(crs, minx, maxx, miny, maxy)
    if total_km is not None:
        meters_total = total_km * 1000.0
    else:
        meters_total = _nice_number_m(width * m_per_unit / 5.0)
    bar_len_native = meters_total / m_per_unit
    seg_native     = bar_len_native / float(segments)
    bar_thick      = 0.014 * height

    # anchor (lower-left corner of the bar)
    if loc == 'lower right':
        x0 = maxx - pad_frac*width - bar_len_native; y0 = miny + pad_frac*height; label_above = True
    elif loc == 'lower left':
        x0 = minx + pad_frac*width; y0 = miny + pad_frac*height; label_above = True
    elif loc == 'upper right':
        x0 = maxx - pad_frac*width - bar_len_native; y0 = maxy - pad_frac*height - 2.5*bar_thick; label_above = False
    else:
        x0 = minx + pad_frac*width; y0 = maxy - pad_frac*height - 2.5*bar_thick; label_above = False

    # blocks
    for i in range(segments):
        xi = x0 + i*seg_native
        ax.add_patch(Rectangle((xi, y0), seg_native, bar_thick,
                               facecolor=('k' if i % 2 == 0 else 'white'),
                               edgecolor='k', lw=edge_lw))
    # ticks at start, middle and end
    for xt in (x0, x0 + bar_len_native/2, x0 + bar_len_native):
        ax.plot([xt, xt], [y0, y0 + bar_thick], color='k', lw=edge_lw)

    # labels: use ONE unit for all of them so the midpoint agrees with the end
    # label (previously the midpoint was always in km, even on a metre bar).
    if meters_total >= 1000:
        unit, divisor = "km", 1000.0
    else:
        unit, divisor = "m", 1.0
    mid_label = f"{meters_total / 2.0 / divisor:g}"
    end_label = f"{meters_total / divisor:g} {unit}"

    ty = (y0 + 1.9*bar_thick) if label_above else (y0 - 0.9*bar_thick)
    va = 'bottom' if label_above else 'top'
    label_box = dict(facecolor='white', edgecolor='none', alpha=0.7)
    ax.text(x0, ty, "0", ha='center', va=va, fontsize=textsize, bbox=dict(label_box))
    ax.text(x0 + bar_len_native/2, ty, mid_label, ha='center', va=va, fontsize=textsize,
            bbox=dict(label_box))
    ax.text(x0 + bar_len_native, ty, end_label, ha='center', va=va, fontsize=textsize,
            bbox=dict(label_box))

def add_bg_if_3857(ax, gdf_proj, **kwargs):
    """Best-effort: add a contextily basemap to *ax* when the data is in EPSG:3857.

    Providers are tried in a fixed order of preference; if one fails (for
    example its tiles cannot be fetched) the next one is tried, and finally
    contextily's own default source. Extra keyword arguments are forwarded to
    ``contextily.add_basemap``.

    Returns the name of the provider that was drawn ("default(OSM)" for the
    contextily default), or None when the CRS is not EPSG:3857 or no basemap
    could be added. Failures never raise; they are reported as RuntimeWarning.
    """
    import warnings

    try:
        epsg = CRS.from_user_input(gdf_proj.crs).to_epsg()
    except Exception as exc:  # missing/unparsable CRS: a basemap is optional
        warnings.warn(f"Basemap skipped, CRS could not be read: {exc!r}",
                      RuntimeWarning, stacklevel=2)
        return None
    if epsg != 3857:
        return None

    try:
        import contextily as ctx
    except ImportError as exc:
        warnings.warn(f"Basemap skipped, contextily is not available: {exc!r}",
                      RuntimeWarning, stacklevel=2)
        return None

    providers_try = [
        "Esri.WorldImagery", "CartoDB.Positron", "CartoDB.Voyager",
        "OpenStreetMap.Mapnik", "Esri.WorldTopoMap", "OpenTopoMap",
    ]

    def _resolve(path):
        # Walk "A.B" down ctx.providers; None if any part is unknown.
        prov = ctx.providers
        for part in path.split("."):
            prov = getattr(prov, part, None)
            if prov is None:
                return None
        return prov

    for name in providers_try:
        prov = _resolve(name)
        if prov is None:
            continue
        try:
            ctx.add_basemap(ax, crs=gdf_proj.crs, source=prov, **kwargs)
        except Exception as exc:  # tile fetch/render errors vary by backend
            # Fall through to the next provider instead of giving up entirely.
            warnings.warn(f"Basemap provider {name} failed: {exc!r}",
                          RuntimeWarning, stacklevel=2)
            continue
        return name

    try:
        ctx.add_basemap(ax, crs=gdf_proj.crs, **kwargs)
    except Exception as exc:
        warnings.warn(f"Default basemap failed: {exc!r}", RuntimeWarning, stacklevel=2)
        return None
    return "default(OSM)"

def col(df, name_like):
    """Return the actual column name of *df* that equals *name_like* ignoring case.

    Raises KeyError when no column matches.
    """
    wanted = name_like.lower()
    by_lower = {}
    for column in df.columns:
        # Later columns win when two names differ only by case.
        by_lower[column.lower()] = column
    if wanted in by_lower:
        return by_lower[wanted]
    raise KeyError(f"Required column like '{name_like}' not found in: {list(df.columns)}")

def majority_label(s: pd.Series):
    """Return the most frequent non-missing value of *s* as a string.

    Values are compared after conversion to ``str``. Returns None when *s*
    has no non-missing values; ties resolve to the first entry of
    ``Series.mode()``.
    """
    labels = s.dropna().astype(str)
    if labels.empty:
        return None
    modes = labels.mode()
    if modes.empty:
        return str(labels.iloc[0])
    return str(modes.iloc[0])

# ---------- read Julia/ONNX results and aggregate per field ----------
_pred_src = PRED_SRC.lower()
if _pred_src not in ("julia", "onnx"):
    # Fail early: otherwise pred_name would be None and the aggregation below
    # would fail with an unrelated-looking error.
    raise ValueError(f"PRED_SRC must be 'julia' or 'onnx', got {PRED_SRC!r}")

raw = pd.read_csv(CSV_PATH)
fid_col   = col(raw, "field_id")
gt_name   = col(raw, "ground_truth_name")
julia_col = col(raw, "julia_pred_name") if _pred_src == "julia" else None
onnx_col  = col(raw, "onnx_pred_name")  if _pred_src == "onnx"  else None
pred_name = julia_col or onnx_col
title_src = "Flux" if julia_col else "ONNX — Julia"

raw[fid_col] = raw[fid_col].astype(int)

# One row per field id: the most frequent ground-truth and predicted label.
agg = (raw.groupby(fid_col, as_index=False)
          .agg({gt_name: majority_label, pred_name: majority_label})
          .rename(columns={fid_col: "id", gt_name: "gt_name_csv", pred_name: "pred_name_csv"}))

# ---------- merge with parcels ----------
gdf = field_parcels_geodataframe.copy()
gdf["id"] = gdf["id"].astype(int)
# Left join: parcels that do not occur in the CSV keep missing CSV labels.
gdf = gdf.merge(agg, on="id", how="left")


def _labels_keep_missing(values):
    """Return *values* as string labels, leaving missing entries (None/NaN) missing."""
    # A plain astype(str) would turn missing values into the literal class "nan",
    # which then shows up in the legend and is scored as a wrong prediction.
    return _safe_series(values).map(lambda v: str(v), na_action="ignore").to_numpy()


if "classname" in gdf.columns and not _safe_series(gdf["classname"]).isna().all():
    gdf["gt_class_name"] = _labels_keep_missing(gdf["classname"])
else:
    gdf["gt_class_name"] = _labels_keep_missing(gdf["gt_name_csv"])

gdf["pred_class_name"] = _labels_keep_missing(gdf["pred_name_csv"])
gdf_proj = gdf.to_crs(epsg=CRS_EPSG)

# ---------- stable classes + palette ----------
classes = sorted(pd.unique(pd.concat([
    _safe_series(gdf_proj["gt_class_name"]),
    _safe_series(gdf_proj["pred_class_name"])
], ignore_index=True).dropna()))
# plt.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap takes the same arguments.
cmap = plt.get_cmap("tab20", max(len(classes), 1))
class2color = {cls: cmap(i) for i, cls in enumerate(classes)}

# Colours for parcels without a label / prediction (same grey, RGBA and hex form).
NO_DATA_RGBA = (0.85, 0.85, 0.85, 1.0)
NO_DATA_HEX  = "#d9d9d9"

gdf_proj["gt_class_name"]   = pd.Categorical(gdf_proj["gt_class_name"],   categories=classes, ordered=True)
gdf_proj["pred_class_name"] = pd.Categorical(gdf_proj["pred_class_name"], categories=classes, ordered=True)

# Shared extent for all three maps.
minx, miny, maxx, maxy = gdf_proj.total_bounds
pad_x = 0.02 * (maxx - minx); pad_y = 0.02 * (maxy - miny)

# Region and prediction source go into the file names so that maps of another
# region or source written to the same SAVE_DIR are not overwritten.
_file_tag = f"{TITLE_REGION}_{_pred_src}".lower().replace(" ", "_")


def _class_counts(column):
    """Count parcels per class in ``gdf_proj[column]`` (missing values excluded)."""
    return (_safe_series(gdf_proj[column]).value_counts(dropna=True)
            .reindex(classes, fill_value=0))


def _class_colors(column):
    """Return one RGBA colour per parcel; missing labels get NO_DATA_RGBA."""
    return pd.Series([class2color.get(lbl, NO_DATA_RGBA) for lbl in _safe_series(gdf_proj[column])],
                     index=gdf_proj.index, dtype=object)


def _class_handles(counts, pct, n_missing, missing_label):
    """Build legend patches for every class plus an optional no-data entry."""
    handles = [mpatches.Patch(facecolor=class2color[cls], edgecolor="none",
                              label=f"{cls} — {fmt_pct(float(pct.loc[cls]))} ({int(counts.loc[cls]):,})")
               for cls in classes]
    if n_missing:
        handles.append(mpatches.Patch(facecolor=NO_DATA_RGBA, edgecolor="none",
                                      label=f"{missing_label} ({n_missing:,})"))
    return handles


def _render_map(color_column, title, handles, legend_title, filename, right):
    """Plot gdf_proj coloured by *color_column*, decorate, save it and return the path."""
    fig, ax = plt.subplots(1, 1, figsize=(14, 10))
    gdf_proj.plot(color=gdf_proj[color_column], linewidth=0.05, edgecolor="none", ax=ax, zorder=3)
    ax.set_xlim(minx - pad_x, maxx + pad_x); ax.set_ylim(miny - pad_y, maxy + pad_y)
    ax.set_title(title, pad=12)
    ax.set_xlabel(""); ax.set_ylabel(""); ax.set_aspect("equal")
    add_north_arrow(ax)
    add_qgis_scalebar_auto(ax, gdf_proj, total_km=FIXED_SCALE_KM, segments=4, loc='lower right')
    ax.legend(handles=handles, loc="center left", bbox_to_anchor=(1.02, 0.5),
              frameon=True, title=legend_title)
    fig.tight_layout(); fig.subplots_adjust(right=right)
    out_path = os.path.join(SAVE_DIR, filename)
    fig.savefig(out_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.show(); plt.close(fig)
    return out_path


# =========================
# 1) Ground Truth
# =========================
gt_counts = _class_counts("gt_class_name")
gt_total  = int(gt_counts.sum())
gt_pct    = gt_counts / max(gt_total, 1) * 100.0
gdf_proj["gt_color"] = _class_colors("gt_class_name")

handles_gt = _class_handles(gt_counts, gt_pct,
                            int(gdf_proj["gt_class_name"].isna().sum()), "No label")
gt_path = _render_map("gt_color", f"{TITLE_REGION} — Ground Truth (BreizhCrops)",
                      handles_gt, "Class", f"gr_{_file_tag}_pct.png", right=0.82)

# ==============================
# 2) Predicted
# ==============================
pred_counts = _class_counts("pred_class_name")
pred_total  = int(pred_counts.sum())
pred_pct    = pred_counts / max(pred_total, 1) * 100.0
gdf_proj["pred_color"] = _class_colors("pred_class_name")

n_no_pred  = int(gdf_proj["pred_class_name"].isna().sum())
handles_pr = _class_handles(pred_counts, pred_pct, n_no_pred, "No prediction")
pr_path = _render_map("pred_color", f"{TITLE_REGION} — Predicted (TempCNN, {title_src})",
                      handles_pr, "Predicted Class", f"pr_{_file_tag}_pct.png", right=0.82)

# ==========================================
# 3) Correct vs Incorrect
# ==========================================
s_pred   = _safe_series(gdf_proj["pred_class_name"])
has_pred = s_pred.notna()
# A parcel is correct only when it has a prediction that equals its ground truth.
gdf_proj["correct"] = ((_safe_series(gdf_proj["gt_class_name"]) == s_pred) & has_pred).to_numpy()
palette_ci = {True: "#1a9850", False: "#d73027"}
# Parcels without a prediction are drawn grey instead of being shown as errors.
gdf_proj["ci_color"] = [palette_ci[bool(ok)] if known else NO_DATA_HEX
                        for ok, known in zip(gdf_proj["correct"], has_pred)]

# Accuracy is computed over parcels that actually have a prediction.
n_total     = int(has_pred.sum())
n_correct   = int(gdf_proj["correct"].sum())
n_incorrect = n_total - n_correct
p_correct   = 100.0 * n_correct / max(n_total, 1)
p_incorrect = 100.0 * n_incorrect / max(n_total, 1)

handles_ci = [
    mpatches.Patch(color=palette_ci[True],  label=f"Correct — {p_correct:.1f}% ({n_correct:,})"),
    mpatches.Patch(color=palette_ci[False], label=f"Incorrect — {p_incorrect:.1f}% ({n_incorrect:,})"),
]
if n_no_pred:
    handles_ci.append(mpatches.Patch(color=NO_DATA_HEX, label=f"No prediction ({n_no_pred:,})"))

co_path = _render_map("ci_color", f"{TITLE_REGION} — Correct vs Incorrect — {title_src}",
                      handles_ci, "Prediction", f"co_{_file_tag}_pct.png", right=0.72)

print("Saved:", gt_path, pr_path, co_path)

#Belle_Ile

In [None]:
field_parcels_geodataframe = belle_ile.geodataframe()

In [None]:
import os, math
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Rectangle
from pyproj import CRS


# CONFIG
# Path of the results CSV read by this cell.
# NOTE(review): the FRH04 cell above reads "julia_sample_results.csv" (no
# "maps/" prefix); confirm which location is the intended one.
CSV_PATH = "maps/julia_sample_results.csv"
# Directory that receives the rendered PNG maps.
SAVE_DIR = "maps"
# Which prediction column of the CSV to map: "julia" or "onnx".
PRED_SRC = "julia"
# EPSG code the parcels are reprojected to before plotting.
CRS_EPSG = 3857
# Region name used as the prefix of every figure title.
TITLE_REGION = "Belle-ile"
# Scale-bar length in km; None lets the scale bar choose a length itself.
FIXED_SCALE_KM = None

os.makedirs(SAVE_DIR, exist_ok=True)

def _safe_series(x):
    """Return a new Series holding the values of *x* converted to an object array.

    Conversion goes through ``np.asarray(..., dtype=object)``, so the index of
    *x* (if it has one) is not carried over.
    """
    object_values = np.asarray(x, dtype=object)
    result = pd.Series(object_values)
    return result

def fmt_pct(p):
    """Render a percent value for a legend: "0%", "<0.1%" or one decimal place."""
    if p <= 0:
        text = "0%"
    elif p < 0.1:
        text = "<0.1%"
    else:
        text = f"{p:.1f}%"
    return text

def add_north_arrow(ax, xy=(0.95, 0.92), length=0.10):
    """Draw a white, upward-pointing arrow with an 'N' label on *ax*.

    *xy* is the arrow tip and *length* the shaft length, both in axes-fraction
    coordinates; the label is placed slightly above the tip.
    """
    tip_x = xy[0]
    tip_y = xy[1]
    arrow_style = {'arrowstyle': '-|>', 'lw': 1.8, 'color': 'w'}

    # An empty annotation is used only for its arrow, drawn from tail to tip.
    ax.annotate(
        '',
        xy=xy,
        xytext=(tip_x, tip_y - length),
        xycoords='axes fraction',
        textcoords='axes fraction',
        arrowprops=arrow_style,
    )
    ax.text(
        tip_x, tip_y + 0.02, 'N',
        transform=ax.transAxes,
        ha='center', va='bottom',
        fontsize=12, weight='bold', color='w',
    )

def _nice_number_m(x_m):
    """Round a positive length to the nearest-style 1/2/5/10 multiple of a power of ten.

    Returns 1.0 for zero or negative input.
    """
    if x_m <= 0:
        return 1.0

    power = int(np.floor(np.log10(x_m)))
    scale = 10**power
    leading = x_m / scale

    # Thresholds 1.5 / 3 / 7 decide which of 1, 2, 5, 10 the leading value maps to.
    if leading < 1.5:
        factor = 1
    elif leading < 3:
        factor = 2
    elif leading < 7:
        factor = 5
    else:
        factor = 10
    return factor * scale

def _meters_per_deg_lon_at_lat(lat_deg):
    """Return 111320 m multiplied by cos(latitude), with *lat_deg* in degrees."""
    cos_lat = math.cos(math.radians(lat_deg))
    return 111320.0 * cos_lat

def _ground_metres_per_axis_unit(crs, minx, maxx, miny, maxy):
    """Return how many ground metres one horizontal axis unit spans at the view centre."""
    from pyproj import Geod, Transformer

    width = maxx - minx
    geodetic = crs.geodetic_crs
    if width <= 0 or geodetic is None:
        # Degenerate view or no geodetic base CRS: treat axis units as metres.
        return 1.0

    # Measure a short horizontal segment (1% of the view) around the centre so
    # the result is the local scale, whatever the CRS (degrees, Web Mercator, ...).
    cx = 0.5 * (minx + maxx)
    cy = 0.5 * (miny + maxy)
    half = 0.005 * width
    to_lonlat = Transformer.from_crs(crs, geodetic, always_xy=True)
    lon_a, lat_a = to_lonlat.transform(cx - half, cy)
    lon_b, lat_b = to_lonlat.transform(cx + half, cy)
    geod = crs.get_geod() or Geod(ellps="WGS84")
    _, _, dist_m = geod.inv(lon_a, lat_a, lon_b, lat_b)
    return max(dist_m / (2.0 * half), 1e-12)


def add_qgis_scalebar_auto(ax, gdf_plotted, total_km=None, segments=4,
                           loc='lower left', pad_frac=0.05, edge_lw=0.8, textsize=10):
    """Draw a segmented, alternating grey/white scale bar on *ax*.

    The bar length is converted from ground metres to axis units using the
    local ground scale at the centre of the visible extent. This matters for
    CRSs whose units are not true ground metres (e.g. EPSG:3857, where one
    map unit shrinks with cos(latitude)).

    Parameters
    ----------
    ax : matplotlib axes whose current x/y limits define the visible extent.
    gdf_plotted : GeoDataFrame that was plotted; only its ``crs`` is read.
    total_km : bar length in km, or None to pick a "nice" length of roughly
        one fifth of the visible width.
    segments : number of alternating blocks (must be >= 1).
    loc : 'lower right', 'lower left' or 'upper right'; any other value
        places the bar in the upper left.
    pad_frac : padding from the axes edge, as a fraction of the extent.
    edge_lw : line width of block outlines and ticks.
    textsize : font size of the labels.

    Raises
    ------
    ValueError
        If ``gdf_plotted`` has no CRS or ``segments`` is smaller than 1.
    """
    if gdf_plotted.crs is None:
        raise ValueError("GeoDataFrame has no CRS. Set/reproject before drawing scalebar.")
    if segments < 1:
        raise ValueError("segments must be a positive integer.")
    crs = CRS.from_user_input(gdf_plotted.crs)

    # current visible extent
    minx, maxx = ax.get_xlim()
    miny, maxy = ax.get_ylim()
    width  = maxx - minx
    height = maxy - miny

    m_per_unit = _ground_metres_per_axis_unit(crs, minx, maxx, miny, maxy)
    if total_km is not None:
        meters_total = total_km * 1000.0
    else:
        meters_total = _nice_number_m(width * m_per_unit / 5.0)
    bar_len_native = meters_total / m_per_unit
    seg_native     = bar_len_native / float(segments)
    bar_thick      = 0.014 * height

    # anchor (lower-left corner of the bar)
    if loc == 'lower right':
        x0 = maxx - pad_frac*width - bar_len_native; y0 = miny + pad_frac*height; label_above = True
    elif loc == 'lower left':
        x0 = minx + pad_frac*width; y0 = miny + pad_frac*height; label_above = True
    elif loc == 'upper right':
        x0 = maxx - pad_frac*width - bar_len_native; y0 = maxy - pad_frac*height - 2.5*bar_thick; label_above = False
    else:
        x0 = minx + pad_frac*width; y0 = maxy - pad_frac*height - 2.5*bar_thick; label_above = False

    # blocks
    for i in range(segments):
        xi = x0 + i*seg_native
        ax.add_patch(Rectangle((xi, y0), seg_native, bar_thick,
                               facecolor=('grey' if i % 2 == 0 else 'white'),
                               edgecolor='grey', lw=edge_lw))
    # ticks at start, middle and end
    for xt in (x0, x0 + bar_len_native/2, x0 + bar_len_native):
        ax.plot([xt, xt], [y0, y0 + bar_thick], color='k', lw=edge_lw)

    # labels: use ONE unit for all of them so the midpoint agrees with the end
    # label (previously the midpoint was always in km, even on a metre bar).
    if meters_total >= 1000:
        unit, divisor = "km", 1000.0
    else:
        unit, divisor = "m", 1.0
    mid_label = f"{meters_total / 2.0 / divisor:g}"
    end_label = f"{meters_total / divisor:g} {unit}"

    ty = (y0 + 1.9*bar_thick) if label_above else (y0 - 0.9*bar_thick)
    va = 'bottom' if label_above else 'top'
    label_box = dict(facecolor='white', edgecolor='none', alpha=0.7)
    ax.text(x0, ty, "0", ha='center', va=va, fontsize=textsize, bbox=dict(label_box))
    ax.text(x0 + bar_len_native/2, ty, mid_label, ha='center', va=va, fontsize=textsize,
            bbox=dict(label_box))
    ax.text(x0 + bar_len_native, ty, end_label, ha='center', va=va, fontsize=textsize,
            bbox=dict(label_box))

def add_bg_if_3857(ax, gdf_proj, **kwargs):
    """Best-effort: add a contextily basemap to *ax* when the data is in EPSG:3857.

    Providers are tried in a fixed order of preference; if one fails (for
    example its tiles cannot be fetched) the next one is tried, and finally
    contextily's own default source. Extra keyword arguments are forwarded to
    ``contextily.add_basemap``.

    Returns the name of the provider that was drawn ("default(OSM)" for the
    contextily default), or None when the CRS is not EPSG:3857 or no basemap
    could be added. Failures never raise; they are reported as RuntimeWarning.
    """
    import warnings

    try:
        epsg = CRS.from_user_input(gdf_proj.crs).to_epsg()
    except Exception as exc:  # missing/unparsable CRS: a basemap is optional
        warnings.warn(f"Basemap skipped, CRS could not be read: {exc!r}",
                      RuntimeWarning, stacklevel=2)
        return None
    if epsg != 3857:
        return None

    try:
        import contextily as ctx
    except ImportError as exc:
        warnings.warn(f"Basemap skipped, contextily is not available: {exc!r}",
                      RuntimeWarning, stacklevel=2)
        return None

    providers_try = [
        "Esri.WorldImagery", "CartoDB.Positron", "CartoDB.Voyager",
        "OpenStreetMap.Mapnik", "Esri.WorldTopoMap", "OpenTopoMap",
    ]

    def _resolve(path):
        # Walk "A.B" down ctx.providers; None if any part is unknown.
        prov = ctx.providers
        for part in path.split("."):
            prov = getattr(prov, part, None)
            if prov is None:
                return None
        return prov

    for name in providers_try:
        prov = _resolve(name)
        if prov is None:
            continue
        try:
            ctx.add_basemap(ax, crs=gdf_proj.crs, source=prov, **kwargs)
        except Exception as exc:  # tile fetch/render errors vary by backend
            # Fall through to the next provider instead of giving up entirely.
            warnings.warn(f"Basemap provider {name} failed: {exc!r}",
                          RuntimeWarning, stacklevel=2)
            continue
        return name

    try:
        ctx.add_basemap(ax, crs=gdf_proj.crs, **kwargs)
    except Exception as exc:
        warnings.warn(f"Default basemap failed: {exc!r}", RuntimeWarning, stacklevel=2)
        return None
    return "default(OSM)"

def col(df, name_like):
    """Look up a column of *df* by case-insensitive name and return its real name.

    Raises KeyError when no column matches *name_like*.
    """
    target = name_like.lower()
    lowered_to_real = {}
    for real_name in df.columns:
        # When two columns differ only by case, the later one is kept.
        lowered_to_real[real_name.lower()] = real_name
    if target in lowered_to_real:
        return lowered_to_real[target]
    raise KeyError(f"Required column like '{name_like}' not found in: {list(df.columns)}")

def majority_label(s: pd.Series):
    """Return the modal non-missing value of *s*, converted to ``str``.

    Missing values are dropped first; None is returned if nothing remains.
    When several values tie, the first entry of ``Series.mode()`` is used.
    """
    present = s.dropna().astype(str)
    if present.empty:
        return None
    most_common = present.mode()
    if most_common.empty:
        return str(present.iloc[0])
    return str(most_common.iloc[0])

# ---------- read Julia/ONNX results and aggregate per field ----------
_pred_src = PRED_SRC.lower()
if _pred_src not in ("julia", "onnx"):
    # Fail early: otherwise pred_name would be None and the aggregation below
    # would fail with an unrelated-looking error.
    raise ValueError(f"PRED_SRC must be 'julia' or 'onnx', got {PRED_SRC!r}")

raw = pd.read_csv(CSV_PATH)
fid_col   = col(raw, "field_id")
gt_name   = col(raw, "ground_truth_name")
julia_col = col(raw, "julia_pred_name") if _pred_src == "julia" else None
onnx_col  = col(raw, "onnx_pred_name")  if _pred_src == "onnx"  else None
pred_name = julia_col or onnx_col
title_src = "Flux" if julia_col else "ONNX — Julia"

raw[fid_col] = raw[fid_col].astype(int)

# One row per field id: the most frequent ground-truth and predicted label.
agg = (raw.groupby(fid_col, as_index=False)
          .agg({gt_name: majority_label, pred_name: majority_label})
          .rename(columns={fid_col: "id", gt_name: "gt_name_csv", pred_name: "pred_name_csv"}))

# ---------- merge with parcels ----------
gdf = field_parcels_geodataframe.copy()
gdf["id"] = gdf["id"].astype(int)
# Left join: parcels that do not occur in the CSV keep missing CSV labels.
gdf = gdf.merge(agg, on="id", how="left")


def _labels_keep_missing(values):
    """Return *values* as string labels, leaving missing entries (None/NaN) missing."""
    # A plain astype(str) would turn missing values into the literal class "nan",
    # which then shows up in the legend and is scored as a wrong prediction.
    return _safe_series(values).map(lambda v: str(v), na_action="ignore").to_numpy()


if "classname" in gdf.columns and not _safe_series(gdf["classname"]).isna().all():
    gdf["gt_class_name"] = _labels_keep_missing(gdf["classname"])
else:
    gdf["gt_class_name"] = _labels_keep_missing(gdf["gt_name_csv"])

gdf["pred_class_name"] = _labels_keep_missing(gdf["pred_name_csv"])
gdf_proj = gdf.to_crs(epsg=CRS_EPSG)

# ---------- stable classes + palette ----------
classes = sorted(pd.unique(pd.concat([
    _safe_series(gdf_proj["gt_class_name"]),
    _safe_series(gdf_proj["pred_class_name"])
], ignore_index=True).dropna()))
# plt.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap takes the same arguments.
cmap = plt.get_cmap("tab20", max(len(classes), 1))
class2color = {cls: cmap(i) for i, cls in enumerate(classes)}

# Colours for parcels without a label / prediction (same grey, RGBA and hex form).
NO_DATA_RGBA = (0.85, 0.85, 0.85, 1.0)
NO_DATA_HEX  = "#d9d9d9"

gdf_proj["gt_class_name"]   = pd.Categorical(gdf_proj["gt_class_name"],   categories=classes, ordered=True)
gdf_proj["pred_class_name"] = pd.Categorical(gdf_proj["pred_class_name"], categories=classes, ordered=True)

# Shared extent for all three maps.
minx, miny, maxx, maxy = gdf_proj.total_bounds
pad_x = 0.02 * (maxx - minx); pad_y = 0.02 * (maxy - miny)

# Region and prediction source go into the file names so that maps of another
# region or source written to the same SAVE_DIR are not overwritten.
_file_tag = f"{TITLE_REGION}_{_pred_src}".lower().replace(" ", "_")


def _class_counts(column):
    """Count parcels per class in ``gdf_proj[column]`` (missing values excluded)."""
    return (_safe_series(gdf_proj[column]).value_counts(dropna=True)
            .reindex(classes, fill_value=0))


def _class_colors(column):
    """Return one RGBA colour per parcel; missing labels get NO_DATA_RGBA."""
    return pd.Series([class2color.get(lbl, NO_DATA_RGBA) for lbl in _safe_series(gdf_proj[column])],
                     index=gdf_proj.index, dtype=object)


def _class_handles(counts, pct, n_missing, missing_label):
    """Build legend patches for every class plus an optional no-data entry."""
    handles = [mpatches.Patch(facecolor=class2color[cls], edgecolor="none",
                              label=f"{cls} — {fmt_pct(float(pct.loc[cls]))} ({int(counts.loc[cls]):,})")
               for cls in classes]
    if n_missing:
        handles.append(mpatches.Patch(facecolor=NO_DATA_RGBA, edgecolor="none",
                                      label=f"{missing_label} ({n_missing:,})"))
    return handles


def _render_map(color_column, title, handles, legend_title, filename, right):
    """Plot gdf_proj over a basemap coloured by *color_column*, save it and return the path."""
    fig, ax = plt.subplots(1, 1, figsize=(14, 10))
    gdf_proj.plot(color=gdf_proj[color_column], linewidth=0.05, edgecolor="none", ax=ax, zorder=3)
    # The extent must be fixed before the basemap is requested for it.
    ax.set_xlim(minx - pad_x, maxx + pad_x); ax.set_ylim(miny - pad_y, maxy + pad_y)
    # No `attribution=True` here: contextily expects a string (or False) and
    # would print the literal text "True"; leaving it unset keeps the
    # provider's own attribution.
    _ = add_bg_if_3857(ax, gdf_proj, zoom="auto", alpha=1.0)
    ax.set_title(title, pad=12)
    ax.set_xlabel(""); ax.set_ylabel(""); ax.set_aspect("equal")
    add_north_arrow(ax)
    add_qgis_scalebar_auto(ax, gdf_proj, total_km=FIXED_SCALE_KM, segments=4, loc='lower left')
    ax.legend(handles=handles, loc="center left", bbox_to_anchor=(1.02, 0.5),
              frameon=True, title=legend_title)
    fig.tight_layout(); fig.subplots_adjust(right=right)
    out_path = os.path.join(SAVE_DIR, filename)
    fig.savefig(out_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.show(); plt.close(fig)
    return out_path


# =========================
# 1) Ground Truth
# =========================
gt_counts = _class_counts("gt_class_name")
gt_total  = int(gt_counts.sum())
gt_pct    = gt_counts / max(gt_total, 1) * 100.0
gdf_proj["gt_color"] = _class_colors("gt_class_name")

handles_gt = _class_handles(gt_counts, gt_pct,
                            int(gdf_proj["gt_class_name"].isna().sum()), "No label")
gt_path = _render_map("gt_color", f"{TITLE_REGION} — Ground Truth (BreizhCrops)",
                      handles_gt, "Class", f"gr_{_file_tag}_pct.png", right=0.82)

# ==============================
# 2) Predicted
# ==============================
pred_counts = _class_counts("pred_class_name")
pred_total  = int(pred_counts.sum())
pred_pct    = pred_counts / max(pred_total, 1) * 100.0
gdf_proj["pred_color"] = _class_colors("pred_class_name")

n_no_pred  = int(gdf_proj["pred_class_name"].isna().sum())
handles_pr = _class_handles(pred_counts, pred_pct, n_no_pred, "No prediction")
pr_path = _render_map("pred_color", f"{TITLE_REGION} — Predicted (TempCNN, {title_src})",
                      handles_pr, "Predicted Class", f"pr_{_file_tag}_pct.png", right=0.82)

# ==========================================
# 3) Correct vs Incorrect
# ==========================================
s_pred   = _safe_series(gdf_proj["pred_class_name"])
has_pred = s_pred.notna()
# A parcel is correct only when it has a prediction that equals its ground truth.
gdf_proj["correct"] = ((_safe_series(gdf_proj["gt_class_name"]) == s_pred) & has_pred).to_numpy()
palette_ci = {True: "#1a9850", False: "#d73027"}
# Parcels without a prediction are drawn grey instead of being shown as errors.
gdf_proj["ci_color"] = [palette_ci[bool(ok)] if known else NO_DATA_HEX
                        for ok, known in zip(gdf_proj["correct"], has_pred)]

# Accuracy is computed over parcels that actually have a prediction.
n_total     = int(has_pred.sum())
n_correct   = int(gdf_proj["correct"].sum())
n_incorrect = n_total - n_correct
p_correct   = 100.0 * n_correct / max(n_total, 1)
p_incorrect = 100.0 * n_incorrect / max(n_total, 1)

handles_ci = [
    mpatches.Patch(color=palette_ci[True],  label=f"Correct — {p_correct:.1f}% ({n_correct:,})"),
    mpatches.Patch(color=palette_ci[False], label=f"Incorrect — {p_incorrect:.1f}% ({n_incorrect:,})"),
]
if n_no_pred:
    handles_ci.append(mpatches.Patch(color=NO_DATA_HEX, label=f"No prediction ({n_no_pred:,})"))

co_path = _render_map("ci_color", f"{TITLE_REGION} — Correct vs Incorrect — {title_src}",
                      handles_ci, "Prediction", f"co_{_file_tag}_pct.png", right=0.72)

print("Saved:", gt_path, pr_path, co_path)