In [None]:
#!/usr/bin/env python3
# ============================================================
# Signed error maps (PDF):
#   4 maps = (North/South) x (MAE, GEH) with signed coloring
#   - Color: red = over (pred>gt), blue = under (pred<gt)
#   - Magnitude controls alpha (opacity) using percentile cap
#   - Basemap: same as before (GBOverview.tif), greyscaled
#   - No Train/Test split: merges train + test predictions
#   - Legends: 2 standalone files (MAE and GEH), saved as SVG
# ============================================================

import json
from pathlib import Path
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import rasterio
from rasterio.windows import from_bounds
from rasterio.enums import Resampling
from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize

# ========================= INPUTS =========================
# Raster basemap; read_basemap_crop_27700 rejects it unless its CRS is EPSG:27700.
BASEMAP_TIF   = "../data/basemap/GBOverview.tif"  # EPSG:27700
# Road-network edges; matched to the predictions through (u, v, key).
EDGES_GEOJSON = "../data/highway_network/uk_driving_edges_simplified.geojson"

# Prediction dumps read by load_preds: a JSON list of records with the keys
# "edge_id", "gt" and "pred".
TRAIN_JSON   = "error_analysis/pred_results_train.json"
TEST_JSON    = "error_analysis/pred_results_test.json"

# England-ish bbox in EPSG:27700, ordered (xmin, ymin, xmax, ymax).
# Used both to crop the basemap and to crop the joined edges.
ENGLAND_BBOX_27700 = (0, 0, 700000, 700000)

# Output directory; created at import time (side effect) if it does not exist.
OUT_DIR = Path("error_analysis/pdf_error_maps")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# One map per (direction, metric) combination.
OUT_N_MAE = OUT_DIR / "errors_north_mae.pdf"
OUT_S_MAE = OUT_DIR / "errors_south_mae.pdf"
OUT_N_GEH = OUT_DIR / "errors_north_geh.pdf"
OUT_S_GEH = OUT_DIR / "errors_south_geh.pdf"

# Standalone legends. Note the .svg suffix: unlike the maps, these are not PDFs
# even though they live in the pdf_error_maps directory.
LEGEND_MAE = OUT_DIR / "legend_mae.svg"
LEGEND_GEH = OUT_DIR / "legend_geh.svg"

# ========================= STYLING =========================
FIGSIZE = (10, 10)  # passed to plt.subplots(figsize=...) for every map
LINEWIDTH = 3       # passed as the LineCollection linewidths for the edges

# Upper bound on the width/height (pixels) of the resampled basemap crop.
BASEMAP_MAX_PIX = 2000
# Opacity of the basemap image underneath the edges.
TILES_ALPHA = 0.5

# color: red over (pred > gt) / blue under (everything else)
COLOR_OVER  = "#d73027"
COLOR_UNDER = "#4575b4"

# alpha scaling: magnitude 0 maps to ALPHA_MIN, magnitude >= cap maps to ALPHA_MAX
ALPHA_MIN, ALPHA_MAX = 0.15, 0.95
PCTL_CAP = 95  # use 95th percentile as cap

# CRS assigned to the edges when the GeoJSON carries none (see ensure_27700)
EDGES_ASSUME_CRS_IF_MISSING = "EPSG:4326"

# ========================= HELPERS =========================
def read_basemap_crop_27700(tif_path: str, bbox_27700, max_pix: int):
    """Read the part of a raster covered by a bounding box, shrunk for plotting.

    Args:
        tif_path: Path of the raster file; its CRS must be EPSG:27700.
        bbox_27700: ``(xmin, ymin, xmax, ymax)`` in EPSG:27700 coordinates.
        max_pix: Upper bound on the width and on the height (in pixels) of the
            returned image. Larger crops are shrunk; smaller ones are not
            enlarged.

    Returns:
        ``(img, extent)`` where ``img`` has the band axis last
        (rows, cols, bands) and ``extent`` is ``(left, right, bottom, top)``
        of the window that was actually read.

    Raises:
        ValueError: If the raster has no CRS or a CRS other than EPSG:27700.
    """
    x_lo, y_lo, x_hi, y_hi = bbox_27700
    with rasterio.open(tif_path) as raster:
        crs = raster.crs
        if crs is None or crs.to_epsg() != 27700:
            raise ValueError("Basemap must have CRS EPSG:27700.")

        # Pixel window of the bbox, snapped to whole pixels ...
        window = from_bounds(x_lo, y_lo, x_hi, y_hi, transform=raster.transform)
        window = window.round_offsets().round_lengths()
        # ... and clipped to the area the raster really covers.
        whole_raster = rasterio.windows.Window(0, 0, raster.width, raster.height)
        window = window.intersection(whole_raster)

        # Shrink factor is never below 1, so the crop is only ever downsampled.
        shrink = max(window.width / max_pix, window.height / max_pix, 1.0)
        n_cols = int(max(1, window.width / shrink))
        n_rows = int(max(1, window.height / shrink))

        bands = raster.read(
            window=window,
            out_shape=(raster.count, n_rows, n_cols),
            resampling=Resampling.bilinear,
        )
        # rasterio returns (bands, rows, cols); move the band axis last.
        image = np.moveaxis(bands, 0, -1)

        left, bottom, right, top = rasterio.windows.bounds(window, raster.transform)

    return image, (left, right, bottom, top)

def ensure_27700(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    """Return ``gdf`` expressed in EPSG:27700.

    A frame without any CRS is first tagged with
    ``EDGES_ASSUME_CRS_IF_MISSING``; a frame whose CRS is not EPSG:27700 is
    then reprojected. A frame already in EPSG:27700 is returned as is.
    """
    tagged = gdf if gdf.crs is not None else gdf.set_crs(EDGES_ASSUME_CRS_IF_MISSING)
    if tagged.crs.to_epsg() == 27700:
        return tagged
    return tagged.to_crs(27700)

def _geom_to_segments(geom):
    if geom is None or geom.is_empty:
        return []
    gt = geom.geom_type
    if gt == "LineString":
        return [np.asarray(geom.coords, dtype=float)]
    if gt == "MultiLineString":
        return [np.asarray(g.coords, dtype=float) for g in geom.geoms if (g is not None and not g.is_empty)]
    return []

def infer_direction(geom):
    """Label a line geometry "north" or "south" from its first part.

    Only the first segment returned by ``_geom_to_segments`` is inspected:
    the y of its first vertex is compared with the y of its last vertex.
    "north" means the start lies strictly south of the end (y_start < y_end);
    every other case, including equal y values, is "south". Returns ``None``
    when the geometry yields no segments.
    """
    parts = _geom_to_segments(geom)
    if len(parts) == 0:
        return None

    first = parts[0]
    start_y, end_y = float(first[0, 1]), float(first[-1, 1])
    if start_y < end_y:
        return "north"
    return "south"

def parse_edge_id(eid: str):
    """Split an ``u_v_key`` edge identifier into three integers.

    The value is converted with ``str`` and split on "_"; the first three
    fields are returned as ``(u, v, key)`` and any further fields are ignored.

    Raises:
        ValueError: If there are fewer than three fields, or (from ``int``)
            if one of the first three fields is not an integer.
    """
    fields = str(eid).split("_")
    if len(fields) >= 3:
        u, v, key = fields[:3]
        return int(u), int(v), int(key)
    raise ValueError(f"Bad edge_id (expected u_v_key): {eid}")

def load_preds(json_path):
    """Load per-edge predictions from a JSON file.

    The file must contain a list of records, each providing the keys
    ``edge_id``, ``gt`` and ``pred``.

    Args:
        json_path: Path of the JSON file (read as UTF-8).

    Returns:
        A DataFrame with the columns ``edge_id`` (str), ``gt`` and ``pred``
        (float), one row per distinct ``edge_id``. When an ``edge_id`` occurs
        more than once, its ``gt`` and ``pred`` are averaged. An empty record
        list yields an empty frame that still has those three columns.

    Raises:
        KeyError: If a record lacks one of the three keys.
        ValueError, TypeError: If ``gt`` or ``pred`` cannot be turned into a
            float.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(json_path, "r", encoding="utf-8") as f:
        rows = json.load(f)

    records = [
        {
            "edge_id": str(r["edge_id"]),
            "gt": float(r["gt"]),
            "pred": float(r["pred"]),
        }
        for r in rows
    ]

    if not records:
        # pd.DataFrame([]) has no columns, so grouping by "edge_id" would raise
        # KeyError. Return a typed empty frame so callers can still concat it.
        return pd.DataFrame(
            {
                "edge_id": pd.Series(dtype=object),
                "gt": pd.Series(dtype=float),
                "pred": pd.Series(dtype=float),
            }
        )

    # If duplicates exist, average pred/gt within edge_id (conservative)
    dfp = pd.DataFrame(records)
    return dfp.groupby("edge_id", as_index=False).mean(numeric_only=True)

def compute_errors(df_pred: pd.DataFrame):
    """Return a copy of ``df_pred`` extended with signed and absolute errors.

    Reads the columns ``gt`` and ``pred`` and adds, in this order:
      - ``signed_err``: pred - gt
      - ``mae``: |pred - gt|
      - ``signed_geh``: GEH carrying the sign of pred - gt
      - ``geh``: sqrt(2 * (pred - gt)^2 / (pred + gt))

    The input frame is not modified.
    """
    truth = df_pred["gt"].to_numpy(dtype=float)
    estimate = df_pred["pred"].to_numpy(dtype=float)
    diff = estimate - truth

    # Floor the denominator so that pred + gt <= 0 cannot divide by zero.
    denominator = np.maximum(estimate + truth, 1e-9)
    geh = np.sqrt(2.0 * (diff ** 2) / denominator)

    return df_pred.assign(
        signed_err=diff,
        mae=np.abs(diff),
        signed_geh=np.sign(diff) * geh,
        geh=geh,
    )

def alpha_from_magnitude(mag: np.ndarray, cap: float):
    """Map magnitudes linearly onto opacities in [ALPHA_MIN, ALPHA_MAX].

    ``mag / cap`` is clipped to [0, 1], so a magnitude of 0 (or below) gives
    ``ALPHA_MIN`` and anything at or above ``cap`` gives ``ALPHA_MAX``.
    ``cap`` is floored at 1e-12 to keep the division defined.
    """
    safe_cap = float(max(cap, 1e-12))
    fraction = np.clip(mag / safe_cap, 0.0, 1.0)
    span = ALPHA_MAX - ALPHA_MIN
    return ALPHA_MIN + span * fraction

def add_signed_lines(ax, gdf, signed_vals, mags, cap, linewidth):
    """
    Draw the geometries of ``gdf`` on ``ax`` as one LineCollection.

    Color depends on sign (over/under). Alpha depends on magnitude.

    Args:
        ax: Matplotlib axes that receives the collection.
        gdf: Frame whose ``geometry`` values are drawn; each part of a
            MultiLineString becomes its own line with the row's color.
        signed_vals: One value per row; ``> 0`` selects COLOR_OVER, anything
            else (including exactly 0) selects COLOR_UNDER.
        mags: One magnitude per row, converted to opacity by
            alpha_from_magnitude.
        cap: Magnitude that maps to the maximum opacity.
        linewidth: Line width handed to the LineCollection.

    Nothing is added to ``ax`` when no geometry yields a line with at least
    two vertices.
    """
    alphas = alpha_from_magnitude(np.asarray(mags, dtype=float), cap)

    def _hex_to_rgb(hex_color):
        # "#rrggbb" -> (r, g, b) with each channel scaled to 0..1
        digits = hex_color.lstrip("#")
        return tuple(int(digits[i:i + 2], 16) / 255.0 for i in (0, 2, 4))

    # Only two colors are ever used, so convert them once rather than once
    # per edge inside the loop.
    over_rgb = _hex_to_rgb(COLOR_OVER)
    under_rgb = _hex_to_rgb(COLOR_UNDER)

    segs = []
    cols = []
    for geom, sgn, a in zip(gdf.geometry.values, signed_vals, alphas):
        rgba = (over_rgb if sgn > 0 else under_rgb) + (float(a),)
        for s in _geom_to_segments(geom):
            # A line needs at least two vertices to be drawable.
            if s.shape[0] >= 2:
                segs.append(s)
                cols.append(rgba)

    if not segs:
        return
    lc = LineCollection(segs, colors=cols, linewidths=linewidth, capstyle="round", joinstyle="round")
    ax.add_collection(lc)

def save_standalone_alpha_legend(out_pdf, title, cap_value, over_color=COLOR_OVER, under_color=COLOR_UNDER):
    """
    Write a small standalone legend figure to ``out_pdf`` and print its path.

    The legend has two parts:
      - Sign: one swatch in ``over_color`` (Pred > GT) and one in
        ``under_color`` (Pred < GT)
      - Magnitude: six black swatches whose opacity follows
        alpha_from_magnitude for magnitudes evenly spaced from 0 to
        ``cap_value``; the upper end is labelled with ``cap_value`` and
        PCTL_CAP
    """
    fig = plt.figure(figsize=(4.5, 1.5))
    canvas = fig.add_axes([0, 0, 1, 1])  # axes fill the whole figure
    canvas.axis("off")

    canvas.text(0.02, 0.92, title, fontsize=12, weight="bold", va="top")

    # --- sign section: one colored swatch + label per row
    canvas.text(0.02, 0.72, "Sign", fontsize=10, weight="bold", va="top")
    sign_rows = (
        (0.62, over_color, "Overestimation (Pred > GT)"),
        (0.50, under_color, "Underestimation (Pred < GT)"),
    )
    for row_y, swatch_color, label in sign_rows:
        canvas.plot([0.05, 0.20], [row_y, row_y], color=swatch_color, linewidth=6)
        canvas.text(0.22, row_y, label, fontsize=10, va="center")

    # --- magnitude section: opacity ramp drawn in black
    canvas.text(0.02, 0.34, "Magnitude (opacity)", fontsize=10, weight="bold", va="top")
    n_steps = 6
    ramp_x = np.linspace(0.05, 0.55, n_steps)
    ramp_alpha = alpha_from_magnitude(np.linspace(0.0, cap_value, n_steps), cap_value)
    for x0, opacity in zip(ramp_x, ramp_alpha):
        canvas.plot([x0, x0 + 0.05], [0.18, 0.18], color="black", linewidth=8, alpha=float(opacity))

    canvas.text(0.05, 0.08, "0", fontsize=9, ha="left")
    canvas.text(0.55, 0.08, f"{cap_value:.3g} (cap, p{PCTL_CAP})", fontsize=9, ha="right")
    canvas.text(0.05, 0.25, "More opaque = larger error", fontsize=9)

    fig.savefig(out_pdf)
    plt.close(fig)
    print(f"Saved: {out_pdf}")

# ========================= LOAD + MERGE PREDS (Train+Test) =========================
# Each file is already de-duplicated per edge_id by load_preds.
df_train = load_preds(TRAIN_JSON)
df_test  = load_preds(TEST_JSON)

df_pred = pd.concat([df_train, df_test], axis=0, ignore_index=True)

# If the same edge appears in both, average (so we truly "do not distinguish")
df_pred = df_pred.groupby("edge_id", as_index=False).mean(numeric_only=True)

# Adds signed_err / mae / signed_geh / geh columns.
df_err = compute_errors(df_pred)

# Parse u/v/key for join (parse_edge_id returns integer triples)
uvk = df_err["edge_id"].apply(parse_edge_id)
df_err["u"] = uvk.apply(lambda x: x[0])
df_err["v"] = uvk.apply(lambda x: x[1])
df_err["key"] = uvk.apply(lambda x: x[2])

# ========================= LOAD EDGES =========================
edges = gpd.read_file(EDGES_GEOJSON)
edges = ensure_27700(edges)

# The join needs u, v and key on the edges. If any is missing, derive all three
# from the first id-like column that exists, parsed as "u_v_key".
need = {"u", "v", "key"}
if not need.issubset(set(edges.columns)):
    cand = None
    for c in ["edge_id", "eid", "id", "osmid", "fid"]:
        if c in edges.columns:
            cand = c
            break
    if cand is None:
        raise ValueError("Edges GeoJSON must contain columns u,v,key OR an edge_id-like column.")
    parsed = edges[cand].astype(str).apply(parse_edge_id)
    edges["u"] = parsed.apply(lambda x: x[0])
    edges["v"] = parsed.apply(lambda x: x[1])
    edges["key"] = parsed.apply(lambda x: x[2])

# Inner join: only edges that have a prediction survive.
# NOTE(review): df_err's u/v/key are Python ints; if the GeoJSON already ships
# u/v/key with another dtype the keys may not match — confirm.
g2 = edges.merge(df_err, on=["u", "v", "key"], how="inner")
if len(g2) == 0:
    raise RuntimeError("No edges matched between GeoJSON and pred_results_* (check id consistency).")

# Crop to bbox
# NOTE(review): the emptiness check above runs before this crop; if the crop
# removes every row, the percentile calls below receive empty arrays.
xmin, ymin, xmax, ymax = ENGLAND_BBOX_27700
g2 = g2.cx[xmin:xmax, ymin:ymax].copy()

# Split direction (rows for which infer_direction returns None land in neither subset)
g2["dir"] = g2.geometry.apply(infer_direction)
north = g2[g2["dir"] == "north"].copy()
south = g2[g2["dir"] == "south"].copy()

# Basemap (greyscale): collapse the first three bands with fixed channel weights
basemap_img, basemap_extent = read_basemap_crop_27700(BASEMAP_TIF, ENGLAND_BBOX_27700, BASEMAP_MAX_PIX)
if basemap_img.ndim == 3 and basemap_img.shape[2] >= 3:
    basemap_img = np.dot(basemap_img[..., :3], [0.299, 0.587, 0.114])

# Caps (percentile) for consistent opacity scaling across all 4 maps.
# Computed on all cropped edges (g2), not per direction.
mae_cap = float(np.percentile(g2["mae"].to_numpy(dtype=float), PCTL_CAP))
geh_cap = float(np.percentile(g2["geh"].to_numpy(dtype=float), PCTL_CAP))

# ========================= PLOTTING =========================
def plot_map(gdf, which_metric, out_pdf, title):
    """
    Render one signed-error map on top of the greyscale basemap and save it.

    Uses the module-level ``basemap_img`` / ``basemap_extent`` for the
    background and ``mae_cap`` / ``geh_cap`` for opacity scaling, so all maps
    share the same scale.

    Args:
        gdf: Edges to draw; must have a geometry plus the columns
            ``signed_err`` and ``mae`` (for "mae") or ``signed_geh`` and
            ``geh`` (for "geh").
        which_metric: "mae" or "geh"
        out_pdf: Output path handed to ``fig.savefig``.
        title: Axes title.

    Raises:
        ValueError: If ``which_metric`` is neither "mae" nor "geh".
    """
    # Resolve the metric before any figure exists, so an invalid name cannot
    # leave an open figure behind.
    if which_metric == "mae":
        signed_col, mag_col, cap = "signed_err", "mae", mae_cap   # sign from pred-gt
    elif which_metric == "geh":
        signed_col, mag_col, cap = "signed_geh", "geh", geh_cap   # signed GEH
    else:
        raise ValueError("which_metric must be 'mae' or 'geh'")

    signed_vals = gdf[signed_col].to_numpy(dtype=float)
    mags = gdf[mag_col].to_numpy(dtype=float)

    fig, ax = plt.subplots(figsize=FIGSIZE)
    try:
        # vmin/vmax pin the grey ramp to 0..255.
        # NOTE(review): assumes 8-bit basemap values — confirm.
        ax.imshow(basemap_img, extent=basemap_extent, cmap="gray", vmin=0, vmax=255, alpha=TILES_ALPHA)
        ax.set_xlim(basemap_extent[0], basemap_extent[1])
        ax.set_ylim(basemap_extent[2], basemap_extent[3])

        add_signed_lines(ax, gdf, signed_vals=signed_vals, mags=mags, cap=cap, linewidth=LINEWIDTH)

        ax.set_axis_off()
        ax.set_title(title)
        # Act on this figure explicitly instead of pyplot's "current figure".
        fig.tight_layout()
        fig.savefig(out_pdf)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    print(f"Saved: {out_pdf} (n={len(gdf)})")


Saved: error_analysis\pdf_error_maps\errors_north_mae.pdf (n=2531)
Saved: error_analysis\pdf_error_maps\errors_south_mae.pdf (n=2557)
Saved: error_analysis\pdf_error_maps\errors_north_geh.pdf (n=2531)
Saved: error_analysis\pdf_error_maps\errors_south_geh.pdf (n=2557)
Saved: error_analysis\pdf_error_maps\legend_mae.pdf
Saved: error_analysis\pdf_error_maps\legend_geh.pdf


In [10]:
# 4 maps: (direction subset, metric, output path, title), rendered in this order
_MAP_JOBS = (
    (north, "mae", OUT_N_MAE, "Signed MAE — Northbound edges (red=over, blue=under)"),
    (south, "mae", OUT_S_MAE, "Signed MAE — Southbound edges (red=over, blue=under)"),
    (north, "geh", OUT_N_GEH, "Signed GEH — Northbound edges (red=over, blue=under)"),
    (south, "geh", OUT_S_GEH, "Signed GEH — Southbound edges (red=over, blue=under)"),
)
for _subset, _metric, _out_path, _heading in _MAP_JOBS:
    plot_map(_subset, _metric, _out_path, _heading)

# 2 standalone legends (no colorbar; alpha legend + sign legend)
_LEGEND_JOBS = (
    (LEGEND_MAE, "Legend: Signed MAE", mae_cap),
    (LEGEND_GEH, "Legend: Signed GEH", geh_cap),
)
for _legend_path, _legend_title, _legend_cap in _LEGEND_JOBS:
    save_standalone_alpha_legend(_legend_path, _legend_title, _legend_cap)

Saved: error_analysis\pdf_error_maps\errors_north_mae.pdf (n=2531)
Saved: error_analysis\pdf_error_maps\errors_south_mae.pdf (n=2557)
Saved: error_analysis\pdf_error_maps\errors_north_geh.pdf (n=2531)
Saved: error_analysis\pdf_error_maps\errors_south_geh.pdf (n=2557)
Saved: error_analysis\pdf_error_maps\legend_mae.svg
Saved: error_analysis\pdf_error_maps\legend_geh.svg
