In [None]:
# ==============================================================
# 03_WRI_PerCountry_Metrics
# --------------------------------------------------------------
# Purpose
#   For each country where WRI Urban Land Use rasters exist:
#     1) Sample informal probabilities and classes per CSD block.
#     2) Compare WRI-based "deprived" labels to RF-based rf_label
#        across thresholds τ ∈ {0.1, 0.2, 0.3}.
#     3) Save per-block metrics, country sweeps, and per-city counts.
#
# Inputs (NOT shipped in repo; user must prepare):
#   - GPKG_ROOT: CSD segment files, one per country:
#       {country}_rf_preds.gpkg   (contains rf_label, UC_NM_MN, etc.)
#   - WRI_ROOT: city WRI rasters organised by country, e.g.:
#       WRI_urban_landuse_v1/PerCountry_Files/<country>/*.tif
#
# Outputs (SHIPPED in repo/Zenodo):
#   WRI/Outputs/PerCountry_Outputs/<country>/:
#       - {country}_wri_informal_per_block.csv
#       - {country}_wri_overlap_audit.csv
#       - {country}_wri_vs_rf_threshold_sweep_country.csv
#       - {country}_wri_vs_rf_per_block_with_preds.csv
#       - {country}_overall_deprived_counts_by_tau.csv
#       - {country}_per_city_deprived_counts_by_tau.csv
#
#   And a global report:
#       WRI/Outputs/wri_country_processing_report.csv
#
# Notes
#   - WRI rasters are downloaded from GEE (01_WRI_DataDownload.js),
#     then manually arranged into PerCountry_Files/<country>/.
#   - These rasters are not included in the repository due to size.
# ==============================================================

# ==============================================================
# NOTE ON INFORMALITY METRICS
# --------------------------------------------------------------
# This script computes two WRI-based indicators:
#   1) p_informal — share of pixels labeled {2,3} in the categorical LULC band.
#   2) m_informal — average probability from the "atomistic" and
#                   "informal_subdivision" bands (0–1).
#
# For FULL reproducibility, both are calculated.
#
# HOWEVER: all further downstream tasks use ONLY p_informal for all WRI–CSD
# comparisons, figures, and quantitative results. m_informal is included
# only as a secondary diagnostic and is not used in any analysis.
# ==============================================================


In [None]:
from pathlib import Path
from typing import Dict, List
import os
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio as rio
from rasterio import mask as rio_mask
from shapely.geometry import box
from joblib import Parallel, delayed


In [None]:
# --------------------------------------------------------------
# 1️⃣ CONFIG (EDIT PATHS FOR YOUR ENVIRONMENT)
# --------------------------------------------------------------

# CSD segment GPKGs (with rf_label), one file per country:
#   {country}_rf_preds.gpkg
GPKG_ROOT = Path("../2_modelling/02_application/Filtered_80pct_allattributes")


# WRI city rasters organised by country (NOT in repo)
#   Example structure:
#   WRI_urban_landuse_v1/
#       PerCountry_Files/
#           india/
#               mumbai_y2020.tif
#               ...
#           kenya/
#               ...
WRI_ROOT = Path(".../WRI/PerCountry_Files")


# Output parent (this is what you expose in GitHub/Zenodo).
# Per-country files are written to OUT_PARENT/<country>/ and the global
# processing report is written to OUT_PARENT.parent, so OUT_PARENT must be the
# "PerCountry_Outputs" folder for the results to land where the file header
# documents them:
#   WRI/Outputs/PerCountry_Outputs/<country>/...
#   WRI/Outputs/wri_country_processing_report.csv
OUT_PARENT = Path(".../WRI/Outputs/PerCountry_Outputs")

OUT_PARENT.mkdir(parents=True, exist_ok=True)

# Column names expected in the block GPKGs.
CITY_COL = "UC_NM_MN"
RF_COL   = "rf_label"

# Thresholds applied to the WRI indicators to derive a binary "deprived" label.
TAUS = [0.1, 0.2, 0.3]
# Minimum number of valid (unmasked) raster pixels a block needs to be scored.
MIN_VALID_PX = 1
# Divisor used to rescale the probability bands (stored as 0–100) to 0–1.
PROB_DIVISOR = 100.0
# Candidate block identifier columns; those present are copied to the outputs.
POSSIBLE_ID_COLS = ["block_id", "ID_SEG", "ID_HDC_G0", "ID_HDC_G0_SEG", "ID"]

# Keep GDAL single-threaded inside each process (more stable).
os.environ.setdefault("GDAL_NUM_THREADS", "1")
os.environ.setdefault("RIO_MAX_WORKERS", "1")

In [None]:
# --------------------------------------------------------------
# 2️⃣ HELPERS
# --------------------------------------------------------------

def list_wri_countries(wri_root: Path) -> List[str]:
    """List the names of the sub-directories of ``wri_root`` in sorted order.

    Every directory directly under ``wri_root`` is treated as one country
    folder; plain files are ignored.
    """
    names = []
    for entry in wri_root.iterdir():
        if entry.is_dir():
            names.append(entry.name)
    names.sort()
    return names


def list_wri_tifs(country_dir: Path) -> List[Path]:
    """Collect the ``*.tif`` and ``*.tiff`` files directly inside ``country_dir``.

    The search is not recursive; the result is sorted by path.
    """
    found: List[Path] = []
    for pattern in ("*.tif", "*.tiff"):
        found.extend(country_dir.glob(pattern))
    return sorted(found)


def read_blocks(gpkg_path: Path, city_col: str, rf_col: str) -> gpd.GeoDataFrame:
    """Load the block polygons of one country and sanitise their geometry.

    Rows with a missing or empty geometry are dropped and the remaining
    geometries are passed through ``buffer(0)``.

    Raises
    ------
    ValueError
        If ``city_col`` or ``rf_col`` is not a column of the file, or if the
        file carries no CRS.
    """
    gdf = gpd.read_file(gpkg_path)
    for required in (city_col, rf_col):
        if required not in gdf.columns:
            raise ValueError(f"Missing '{required}' in {gpkg_path}")

    non_empty = ~gdf.geometry.is_empty
    present = gdf.geometry.notna()
    gdf = gdf[non_empty & present].copy()
    gdf["geometry"] = gdf.geometry.buffer(0)

    if gdf.crs is None:
        raise ValueError(f"{gpkg_path} has no CRS; please define it.")
    return gdf


def read_raster_meta(tif: Path):
    """Open a raster and return a dict describing it.

    Keys: ``crs`` (the dataset CRS), ``bounds`` (the dataset extent as a
    rectangle polygon) and ``descriptions`` (list of band descriptions, empty
    when the dataset has none).
    """
    with rio.open(tif) as ds:
        band_names = list(ds.descriptions) if ds.descriptions else []
        footprint = box(*ds.bounds)
        return {"crs": ds.crs, "bounds": footprint, "descriptions": band_names}


def band_map_from_descriptions(desc: List[str]) -> Dict[str, int]:
    """
    Translate raster band descriptions into 1-based band indices.

    The returned dict has the keys ``lulc``, ``atomistic`` and
    ``informal_subdivision``. Matching is case-insensitive and exact; when a
    description occurs more than once the first occurrence wins. A band that
    is not found maps to ``None``, except ``lulc`` which falls back to band 1.
    """
    first_position: Dict[str, int] = {}
    for band_no, label in enumerate(desc, start=1):
        key = label.lower() if label else ""
        # setdefault keeps the earliest band carrying this description
        first_position.setdefault(key, band_no)

    band_map = {
        name: first_position.get(name)
        for name in ("lulc", "atomistic", "informal_subdivision")
    }
    # fallback if lulc band not tagged
    if band_map["lulc"] is None:
        band_map["lulc"] = 1
    return band_map


def _read_masked_band(ds, shapes, index):
    """Read one band of ``ds`` clipped to ``shapes`` as a 2-D masked array."""
    arr, _ = rio_mask.mask(ds, shapes, crop=True, filled=False, indexes=index)
    return arr[0] if arr.ndim == 3 else arr


def block_stats_for_tif(block_row, tif_path: Path, bm: Dict[str, int]) -> Dict:
    """
    Compute the WRI informality indicators for one block on one raster.

    Parameters
    ----------
    block_row
        Object exposing the block polygon as ``.geometry``. The geometry is
        used as-is against the raster, so it is assumed to already be in the
        raster CRS (the caller reprojects) — no reprojection happens here.
    tif_path
        Path of the WRI raster to sample.
    bm
        Band map with the keys ``lulc``, ``atomistic`` and
        ``informal_subdivision`` (1-based band index, or ``None`` when the
        band is unavailable).

    Returns
    -------
    dict
        ``valid_px``   – number of block pixels that are unmasked in every
                         band that was read;
        ``p_informal`` – share of those pixels whose lulc value is 2 or 3;
        ``m_informal`` – mean of the atomistic and informal_subdivision
                         values divided by ``PROB_DIVISOR``, averaged over the
                         two bands and over the valid pixels. ``NaN`` when
                         either probability band is missing from ``bm``.
        When the block does not overlap the raster, or has fewer than
        ``MIN_VALID_PX`` valid pixels, the result is
        ``{"valid_px": 0, "p_informal": nan, "m_informal": nan}``.

    Notes
    -----
    ``p_informal`` only needs the lulc band, so it is still computed when the
    probability bands are absent; previously such rasters produced no
    statistics at all and every block on them was silently discarded.
    """
    geom = block_row.geometry
    out = {"valid_px": 0, "p_informal": np.nan, "m_informal": np.nan}

    try:
        with rio.open(tif_path) as ds:
            # Quick reject by geometry bbox
            if not box(*geom.bounds).intersects(box(*ds.bounds)):
                return out

            shapes = [geom.__geo_interface__]

            # LULC categorical band
            a_lulc = _read_masked_band(ds, shapes, bm["lulc"])
            # getmaskarray always yields a full boolean array, even when the
            # masked array carries the scalar ``nomask``.
            valid = ~np.ma.getmaskarray(a_lulc)

            # Probability bands are only needed for m_informal.
            has_prob = (
                bm["atomistic"] is not None
                and bm["informal_subdivision"] is not None
            )
            a_atom = a_info = None
            if has_prob:
                a_atom = _read_masked_band(ds, shapes, bm["atomistic"])
                a_info = _read_masked_band(ds, shapes, bm["informal_subdivision"])
                # Valid pixels: where none of the bands are masked
                valid &= ~np.ma.getmaskarray(a_atom)
                valid &= ~np.ma.getmaskarray(a_info)

            valid_px = int(valid.sum())
            # The explicit zero check protects the divisions below should
            # MIN_VALID_PX ever be configured as 0.
            if valid_px == 0 or valid_px < MIN_VALID_PX:
                return out

            # p_informal from lulc ∈ {2,3}
            lulc_vals = np.ma.getdata(a_lulc)[valid]
            p_informal = float(np.count_nonzero(np.isin(lulc_vals, (2, 3)))) / valid_px

            # m_informal: mean of atomistic + informal_subdivision (0–100), scaled to 0–1
            m_informal = np.nan
            if has_prob:
                atom = np.ma.getdata(a_atom)[valid].astype(float)
                info = np.ma.getdata(a_info)[valid].astype(float)
                m_informal = float(((atom + info) / PROB_DIVISOR / 2.0).mean())

            out.update(
                {"valid_px": valid_px, "p_informal": p_informal, "m_informal": m_informal}
            )
            return out
    except ValueError:
        # rasterio's mask() raises ValueError when the shapes do not overlap
        # the raster; treat that (and other geometry value errors) as "no data".
        return out


def metrics(y_true, y_pred):
    """Compute confusion counts and derived scores for binary labels.

    Both inputs are cast to ``int``; label 1 is the positive class and label 0
    the negative class. Returns a dict with the keys ``TP``, ``FP``, ``TN``,
    ``FN``, ``precision``, ``recall``, ``F1`` and ``balanced_acc``. Any ratio
    whose denominator is zero is reported as ``0.0``.
    """
    truth = np.asarray(y_true).astype(int)
    guess = np.asarray(y_pred).astype(int)

    def _count(true_val, pred_val):
        return int(np.sum((truth == true_val) & (guess == pred_val)))

    def _ratio(num, den):
        return num / den if den else 0.0

    TP = _count(1, 1)
    TN = _count(0, 0)
    FP = _count(0, 1)
    FN = _count(1, 0)

    precision = _ratio(TP, TP + FP)
    recall = _ratio(TP, TP + FN)
    specificity = _ratio(TN, TN + FP)
    f1 = _ratio(2 * precision * recall, precision + recall)
    # balanced accuracy = mean of recall (sensitivity) and specificity
    balanced = 0.5 * (recall + specificity)

    return {
        "TP": TP,
        "FP": FP,
        "TN": TN,
        "FN": FN,
        "precision": precision,
        "recall": recall,
        "F1": f1,
        "balanced_acc": balanced,
    }


def sweep(df, col, taus, rule_name):
    """Score ``df[col] >= tau`` against ``df["rf_label"]`` for each threshold.

    Returns a DataFrame with one row per value in ``taus``: the output of
    ``metrics`` plus the rule name, the threshold and the number of rows in
    ``df``.
    """
    n_rows = len(df)
    records = [
        {
            **metrics(df["rf_label"], (df[col] >= tau).astype(int)),
            "Rule / Comparison": rule_name,
            "τ (threshold)": tau,
            "n (blocks)": n_rows,
        }
        for tau in taus
    ]
    return pd.DataFrame(records)


def pick_block_uid(df: gpd.GeoDataFrame) -> List[str]:
    """List the entries of ``POSSIBLE_ID_COLS`` that are columns of ``df``.

    The order of ``POSSIBLE_ID_COLS`` is preserved.
    """
    available = set(df.columns)
    return [name for name in POSSIBLE_ID_COLS if name in available]


In [None]:
# --------------------------------------------------------------
# 3️⃣ PER-COUNTRY WORKER
# --------------------------------------------------------------

def process_country(country: str) -> dict:
    """
    Process one country end to end and write its CSV outputs.

    Steps
    -----
    1. Load the country's block GPKG and list its WRI rasters.
    2. For every raster, select the blocks intersecting the raster extent,
       reproject them to the raster CRS and sample the WRI indicators.
    3. Write the per-block table and the raster/overlap audit.
    4. Compare WRI-derived labels with ``rf_label`` for every threshold in
       ``TAUS`` and write the sweep, per-block predictions, overall counts
       and per-city counts.

    Parameters
    ----------
    country
        Name of the country folder under ``WRI_ROOT``; also the prefix of the
        GPKG file name and of every output file.

    Returns
    -------
    dict
        Always contains ``country`` and ``status``. ``status`` is one of
        ``"no_gpkg"``, ``"no_wri_tifs"``, ``"gpkg_error: ..."``,
        ``"no_overlap"`` or ``"ok"``; for ``"ok"`` the dict also holds
        ``n_blocks`` and ``rf_dep``.
    """
    wri_dir = WRI_ROOT / country
    gpkg    = GPKG_ROOT / f"{country}_rf_preds.gpkg"
    out_dir = OUT_PARENT / country
    out_dir.mkdir(parents=True, exist_ok=True)

    if not gpkg.exists():
        return {"country": country, "status": "no_gpkg"}

    tifs = list_wri_tifs(wri_dir)
    if not tifs:
        return {"country": country, "status": "no_wri_tifs"}

    # Load blocks once (native CRS)
    try:
        blocks = read_blocks(gpkg, CITY_COL, RF_COL)
    except Exception as e:
        return {"country": country, "status": f"gpkg_error: {e}"}

    id_cols = pick_block_uid(blocks)
    # Fixed column layout so the per-block table keeps its header (and its
    # "rf_label" column) even when no block overlaps any raster.
    stat_columns = list(dict.fromkeys(
        ["tif_name", "valid_px", "p_informal", "m_informal",
         CITY_COL, "rf_label", *id_cols]
    ))
    rows, audit = [], []

    for tif in tifs:
        meta = read_raster_meta(tif)
        bm = band_map_from_descriptions(meta["descriptions"])

        # BBox in blocks CRS for quick filtering
        if blocks.crs != meta["crs"]:
            rb_blocks_crs = (
                gpd.GeoDataFrame(geometry=[meta["bounds"]], crs=meta["crs"])
                .to_crs(blocks.crs)
                .geometry.iloc[0]
            )
        else:
            rb_blocks_crs = meta["bounds"]

        # Spatial index if available; both paths apply an exact "intersects"
        # test, so no further filtering is needed afterwards.
        try:
            subset_idx = blocks.sindex.query(rb_blocks_crs, predicate="intersects")
            subset = blocks.iloc[subset_idx].copy()
        except Exception:
            subset = blocks[blocks.intersects(rb_blocks_crs)].copy()

        if subset.empty:
            audit.append({"tif": tif.name, "overlapping_blocks": 0})
            continue

        # Match CRS of raster
        if subset.crs != meta["crs"]:
            subset = subset.to_crs(meta["crs"])

        # Per-block metrics on this TIFF
        # NOTE(review): a block intersecting several rasters yields one row
        # per raster and is therefore counted once per raster in the tables
        # below — confirm this is intended.
        for _, r in subset.iterrows():
            st = block_stats_for_tif(r, tif, bm)
            rec = {
                "tif_name": tif.name,
                "valid_px": st["valid_px"],
                "p_informal": st["p_informal"],
                "m_informal": st["m_informal"],
                CITY_COL: r[CITY_COL],
                "rf_label": int(r[RF_COL]) if pd.notna(r[RF_COL]) else np.nan,
            }
            for c in id_cols:
                rec[c] = r[c]
            rows.append(rec)

        audit.append({"tif": tif.name, "overlapping_blocks": int(len(subset))})

    # Save per-block + audit
    stats_df = pd.DataFrame(rows, columns=stat_columns)
    audit_df = pd.DataFrame(audit).sort_values("tif")

    stats_csv = out_dir / f"{country}_wri_informal_per_block.csv"
    audit_csv = out_dir / f"{country}_wri_overlap_audit.csv"
    stats_df.to_csv(stats_csv, index=False)
    audit_df.to_csv(audit_csv, index=False)

    # No block intersected any raster: report it instead of failing on the
    # column lookups below (which would abort the whole parallel run).
    if stats_df.empty:
        return {"country": country, "status": "no_overlap"}

    # Prepare comparison table (only blocks with RF labels + valid pixels)
    cmp = stats_df.dropna(subset=["rf_label"]).copy()
    cmp = cmp[cmp["valid_px"].fillna(0) >= MIN_VALID_PX].copy()
    if cmp.empty:
        return {"country": country, "status": "no_overlap"}

    cmp["rf_label"] = cmp["rf_label"].astype(int)

    # Threshold sweeps: p_informal and m_informal
    country_tbl = pd.concat(
        [
            sweep(cmp, "p_informal", TAUS, "WRI (p_informal)"),
            sweep(cmp, "m_informal", TAUS, "WRI (m_informal)"),
        ],
        ignore_index=True,
    )
    sweep_csv = out_dir / f"{country}_wri_vs_rf_threshold_sweep_country.csv"
    country_tbl.to_csv(sweep_csv, index=False)

    # Per-block predictions for each τ (a NaN indicator compares as False,
    # i.e. it is labelled 0)
    for tau in TAUS:
        cmp[f"pred_p_{tau:.2f}"] = (cmp["p_informal"] >= tau).astype(int)
        cmp[f"pred_m_{tau:.2f}"] = (cmp["m_informal"] >= tau).astype(int)

    per_block_preds_csv = out_dir / f"{country}_wri_vs_rf_per_block_with_preds.csv"
    cmp.to_csv(per_block_preds_csv, index=False)

    # Overall counts (country-level)
    n_blocks = len(cmp)
    n_rf_dep = int((cmp["rf_label"] == 1).sum())
    n_rf_non = n_blocks - n_rf_dep

    overall_rows = []
    for tau in TAUS:
        overall_rows.append({
            "country": country,
            "τ (threshold)": tau,
            "n_blocks": n_blocks,
            "rf_deprived_n": n_rf_dep,
            "rf_non_deprived_n": n_rf_non,
            "WRI(p)_deprived_n": int((cmp[f"pred_p_{tau:.2f}"] == 1).sum()),
            "WRI(m)_deprived_n": int((cmp[f"pred_m_{tau:.2f}"] == 1).sum()),
        })
    overall_counts = pd.DataFrame(overall_rows)
    overall_counts_csv = out_dir / f"{country}_overall_deprived_counts_by_tau.csv"
    overall_counts.to_csv(overall_counts_csv, index=False)

    # Per-city counts
    per_city_tables = []
    for tau in TAUS:
        grp = (
            cmp.groupby(CITY_COL)
               .agg(
                   n_blocks=("rf_label", "size"),
                   rf_deprived_n=("rf_label", lambda s: int((s == 1).sum())),
                   WRI_p_deprived_n=(f"pred_p_{tau:.2f}", "sum"),
                   WRI_m_deprived_n=(f"pred_m_{tau:.2f}", "sum"),
               )
               .reset_index()
        )
        grp.insert(0, "country", country)
        grp.insert(1, "τ (threshold)", tau)
        per_city_tables.append(grp)

    per_city_counts = pd.concat(per_city_tables, ignore_index=True)
    per_city_counts_csv = out_dir / f"{country}_per_city_deprived_counts_by_tau.csv"
    per_city_counts.to_csv(per_city_counts_csv, index=False)

    return {"country": country, "status": "ok", "n_blocks": n_blocks, "rf_dep": n_rf_dep}



In [None]:
# --------------------------------------------------------------
# 4️⃣ RUN IN PARALLEL ACROSS COUNTRIES + SAVE REPORT
# --------------------------------------------------------------

# Discover the country folders and process each one in its own worker process.
countries = list_wri_countries(WRI_ROOT)
print(
    f"Processing {len(countries)} country folder(s): "
    f"{countries[:6]}{' ...' if len(countries) > 6 else ''}"
)

N_JOBS = -1  # use all cores; set to e.g. 6 to limit
results = Parallel(
    n_jobs=N_JOBS, backend="loky", prefer="processes"
)(
    delayed(process_country)(c) for c in countries
)

# With no country folders there are no result dicts, so build the frame with
# explicit columns; otherwise the "status" lookup below would raise KeyError.
if results:
    report_df = pd.DataFrame(results)
else:
    report_df = pd.DataFrame(columns=["country", "status"])

print("\nStatus summary:")
if report_df.empty:
    print("(no country folders found; nothing was processed)")
else:
    print(report_df["status"].value_counts())

# Save high-level report alongside OUT_PARENT
report_csv = OUT_PARENT.parent / "wri_country_processing_report.csv"
report_df.to_csv(report_csv, index=False)
print(f"\n✅ Country processing report saved to: {report_csv}")