# SSI vs RF Comparison — Notebook B (Per-country + Pooled)
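This notebook:
- Loads each country's `rf_label` GPKG (80% quality-filtered RF predictions)
- Loads the clipped SSI raster produced by Notebook A
- Extracts SSI metrics for every block:
    - `p_any`   = proportion of valid pixels with SSI > 0
    - `m_ssi`   = mean SSI over valid pixels
    - `p_space` = proportion of valid pixels with SpaceDef > 0
- Computes precision, recall, F1 and balanced accuracy (with `rf_label` as ground truth) at thresholds τ = 0.1, 0.2, 0.3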

Writes:

a) Per-country SSI vs RF metrics (per_country/*.csv)

b) Pooled metrics across all SSI countries

c) Stacked per-country metrics

d) Country processing audit

In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio as rio
from rasterio import mask as rio_mask
from shapely.geometry import box

# 1️⃣ Paths (edit relative paths in the final repo)

In [None]:
# --- Inputs ---------------------------------------------------------------
# RF block predictions: 80% quality-filtered GPKGs with all attributes
GPKG_DIR = Path("../2_modelling/02_application/Filtered_80pct_allattributes")

# Per-country SSI folders written by Notebook A (02_ssi_clip_to_cities.ipynb);
# each holds one clipped SSI GeoTIFF named SSI_FILENAME
SSI_PARENT = Path("./PerCountry_Outputs")
SSI_FILENAME = "SSIpack100m_clipped_to_city_blocks_SIGNAL.tif"

# --- Outputs --------------------------------------------------------------
# Pooled/stacked CSVs go in OUT_ROOT; per-country CSVs in its subfolder
OUT_ROOT = Path("./Pooled_Results")
OUT_COUNTRY = OUT_ROOT / "per_country"

for _out_dir in (OUT_ROOT, OUT_COUNTRY):
    _out_dir.mkdir(parents=True, exist_ok=True)




# 2️⃣ Countries (SSI-covered set)

In [None]:
# SSI-covered countries, in processing order. Written as one whitespace-separated
# block for readability; .split() turns it into the list of country keys.
COUNTRIES = """
    democratic_republic_of_the_congo south_africa sudan ethiopia nigeria niger
    mozambique angola namibia mauritania somalia egypt tanzania zambia
    central_african_republic cameroon kenya republic_of_congo burkina_faso
    cote_d_ivoire ghana guinea uganda senegal malawi benin liberia togo
    sierra_leone rwanda burundi gambia djibouti
""".split()

# Cut-offs applied to the SSI-derived block scores: score >= tau -> deprived (1)
THRESHOLDS = [0.1, 0.2, 0.3]

# Positional band order assumed when the raster carries no band descriptions
EXPECTED_BANDS = "WaterDef SanitationDef HousingDef SpaceDef SSI".split()

# 3️⃣ Metric helpers

In [None]:
def metrics(y_true, y_pred):
    """
    Score binary predictions against binary ground truth.

    Both inputs are coerced to integer arrays; only entries equal to 0 or 1
    contribute to the confusion-matrix counts. Any ratio whose denominator
    is zero is reported as 0.0.

    Returns:
        dict with, in this order: TP, FP, TN, FN (ints) and precision,
        recall, F1, balanced_acc (floats). balanced_acc is the mean of
        recall and specificity (TN / (TN + FP)).
    """
    def _ratio(num, den):
        # zero denominator -> 0.0 rather than ZeroDivisionError
        return num / den if den else 0.0

    truth = np.asarray(y_true).astype(int)
    guess = np.asarray(y_pred).astype(int)

    is_pos, is_neg = truth == 1, truth == 0
    said_pos, said_neg = guess == 1, guess == 0

    tp = int((is_pos & said_pos).sum())
    tn = int((is_neg & said_neg).sum())
    fp = int((is_neg & said_pos).sum())
    fn = int((is_pos & said_neg).sum())

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    specificity = _ratio(tn, tn + fp)
    f1 = _ratio(2 * precision * recall, precision + recall)

    return dict(
        TP=tp, FP=fp, TN=tn, FN=fn,
        precision=precision, recall=recall, F1=f1,
        balanced_acc=0.5 * (recall + specificity),
    )


def sweep_metrics(df, prob_col, thresholds, rule_name, country):
    """
    Score one SSI-derived rule at several thresholds.

    For each ``tau`` in *thresholds*, a block is predicted deprived (1) when
    ``df[prob_col] >= tau`` (a NaN score therefore predicts 0), and the
    predictions are compared with ``df["rf_label"]`` via ``metrics``.

    Returns:
        DataFrame with one row per threshold: the ``metrics`` keys followed by
        "country", "Rule / Comparison", "τ (threshold)" and "n (blocks)".
    """
    n_blocks = int(len(df))

    def _score_at(tau):
        predicted = (df[prob_col] >= tau).astype(int)
        scored = metrics(df["rf_label"], predicted)
        scored.update({
            "country": country,
            "Rule / Comparison": rule_name,
            "τ (threshold)": tau,
            "n (blocks)": n_blocks,
        })
        return scored

    return pd.DataFrame([_score_at(tau) for tau in thresholds])

# 4️⃣ Raster band utilities

In [None]:
def detect_bands(ds, expected_bands=None):
    """
    Return 1-based band indices for the 'SSI' and 'SpaceDef' layers of *ds*.

    Each band is located by name in ``ds.descriptions`` when that name is
    present there; otherwise its position in the fallback band order is used.
    Resolving each band independently means a raster that describes only
    one of the two bands still uses that description.

    Parameters:
        ds: open raster dataset; only ``ds.descriptions`` and ``ds.count``
            are read.
        expected_bands: fallback band order (sequence of band names). When
            None, the module-level ``EXPECTED_BANDS`` is used
            (WaterDef, SanitationDef, HousingDef, SpaceDef, SSI -> SSI=5,
            SpaceDef=4).

    Returns:
        dict ``{"SSI": idx, "SpaceDef": idx}`` with 1-based band indices.

    Raises:
        ValueError: if a band name is in neither the descriptions nor the
            fallback order, or if a resolved index lies outside
            ``1..ds.count``, so a mismatched raster fails here with a clear
            message instead of later inside the per-block masking loop.
    """
    desc = tuple(ds.descriptions or ())
    order = list(EXPECTED_BANDS if expected_bands is None else expected_bands)

    bands = {}
    for name in ("SSI", "SpaceDef"):
        if name in desc:
            bands[name] = desc.index(name) + 1
        elif name in order:
            bands[name] = order.index(name) + 1
        else:
            raise ValueError(
                f"Band {name!r} not found in raster descriptions {desc} "
                f"or fallback order {order}"
            )

    n_bands = ds.count
    for name, idx in bands.items():
        if not 1 <= idx <= n_bands:
            raise ValueError(
                f"Band {name!r} resolved to index {idx}, but raster has "
                f"only {n_bands} band(s)"
            )
    return bands


def sample_block_summaries(ds, geom, band_idx):
    """
    Summarise one raster band over a single polygon.

    A pixel is valid when ``rasterio.mask.mask`` leaves it unmasked (inside
    the polygon and not nodata) AND its value is finite. Non-finite pixels
    are excluded so that a raster without a declared nodata value cannot
    turn the mean into NaN or dilute the positive proportion.

    Parameters:
        ds: open rasterio dataset.
        geom: shapely geometry, assumed to be in the raster's CRS.
        band_idx: 1-based index of the band to sample.

    Returns:
        (prop_pos, mean_val, valid_px)
        - prop_pos: proportion of valid pixels with value > 0
        - mean_val: mean of all valid pixels (including zeros)
        - valid_px: number of valid pixels
        ``(nan, nan, 0)`` when the polygon does not overlap the raster or
        covers no valid pixel.
    """
    empty = (np.nan, np.nan, 0)

    try:
        arr, _ = rio_mask.mask(
            ds, [geom.__geo_interface__],
            crop=True, filled=False, indexes=band_idx
        )
    except ValueError:
        # rasterio raises ValueError when the shapes do not overlap the raster
        return empty

    a = arr[0] if arr.ndim == 3 else arr
    if a.size == 0:
        return empty

    # getmaskarray always returns a full-shape boolean array, even when the
    # mask is the scalar ``nomask`` (where ``~a.mask`` would be one True and
    # valid.sum() would wrongly report 1 pixel).
    data = np.ma.getdata(a)
    valid = ~np.ma.getmaskarray(a) & np.isfinite(data)

    vals = data[valid]
    vpx = int(vals.size)
    if vpx == 0:
        return empty

    prop_pos = float(np.count_nonzero(vals > 0) / vpx)
    mean_val = float(vals.mean())

    return (prop_pos, mean_val, vpx)

# 5️⃣ Per-country processing

In [None]:
def process_country(country: str):
    """
    Compare SSI-derived block scores with RF labels for one country.

    Steps:
      - load the RF GPKG (keeping ``rf_label`` and, if present, ``UC_NM_MN``)
        and the per-country SSI TIFF
      - reproject blocks to the raster CRS, drop null/empty geometries and
        apply ``buffer(0)``
      - compute per block: p_any, m_ssi, valid_px (SSI band) and
        p_space, space_valid_px (SpaceDef band)
      - build threshold-sweep tables for the any-pixel, mean-SSI and
        SpaceDef-only rules
      - write ``<country>_ssi_rf_summary.csv`` (rounded) and
        ``<country>_ssi_rf_summary_with_counts.csv`` (full) into OUT_COUNTRY

    Returns:
        ``(cmp_df, summary_df, audit_dict)`` on success. When the GPKG or TIFF
        is missing, or the GPKG has no ``rf_label`` column, returns
        ``(None, None, audit_dict)`` with ``audit_dict["status"]`` set to
        "missing_gpkg", "missing_ssi" or "no_rf_label" respectively.
    """
    print(f"\n=== {country} ===")

    gpkg_path = GPKG_DIR / f"{country}_rf_preds_filtered80.gpkg"
    ssi_path  = SSI_PARENT / country / SSI_FILENAME

    if not gpkg_path.exists():
        print(f"  ! Missing GPKG: {gpkg_path}")
        return None, None, dict(country=country, status="missing_gpkg")

    if not ssi_path.exists():
        # (In practice this should exist because of Notebook A)
        print(f"  ! Missing SSI TIFF: {ssi_path}")
        return None, None, dict(country=country, status="missing_ssi")

    # --- Read RF GPKG ---
    blocks = gpd.read_file(gpkg_path)

    if "rf_label" not in blocks.columns:
        print("  ! rf_label column missing; skipping.")
        return None, None, dict(country=country, status="no_rf_label")

    keep_cols = [c for c in ("rf_label", "UC_NM_MN") if c in blocks.columns]
    blocks = blocks[keep_cols + ["geometry"]].copy()

    # Per-block table columns. Passed explicitly to pd.DataFrame so that a
    # country with zero usable blocks still yields these columns; without
    # them pd.DataFrame([]) has none and dropna(subset=["rf_label"]) below
    # raises KeyError, aborting the whole country loop.
    block_cols = [
        "block_id", "rf_label", "UC_NM_MN",
        "p_any", "m_ssi", "valid_px",
        "p_space", "space_valid_px",
    ]

    rows = []
    # Open the raster once for metadata, CRS alignment and sampling
    with rio.open(ssi_path) as ds:
        r_bounds = box(*ds.bounds)
        bands = detect_bands(ds)

        # Align CRS so block polygons and raster pixels share coordinates
        if blocks.crs != ds.crs:
            blocks = blocks.to_crs(ds.crs)

        # Basic clean: drop null/empty geometries; buffer(0) presumably
        # repairs invalid polygons before masking
        blocks = blocks[blocks.geometry.notna() & ~blocks.geometry.is_empty].copy()
        blocks["geometry"] = blocks.geometry.buffer(0)
        blocks["block_id"] = np.arange(len(blocks))

        # --- Compute per-block SSI summaries ---
        for _, r in blocks.iterrows():
            geom = r.geometry

            if box(*geom.bounds).intersects(r_bounds):
                p_any, m_ssi, vpx = sample_block_summaries(ds, geom, bands["SSI"])
                p_space, _, vpx_space = sample_block_summaries(ds, geom, bands["SpaceDef"])
            else:
                # quick bbox rejection: no overlap, so skip the masking calls
                p_any = m_ssi = p_space = np.nan
                vpx = vpx_space = 0

            label = r["rf_label"]
            rows.append({
                "block_id": r["block_id"],
                "rf_label": np.nan if pd.isna(label) else int(label),
                "UC_NM_MN": r.get("UC_NM_MN", None),
                "p_any": p_any,
                "m_ssi": m_ssi,
                "valid_px": int(vpx),
                "p_space": p_space,
                "space_valid_px": int(vpx_space),
            })

    blk = pd.DataFrame(rows, columns=block_cols)

    # --- Comparison sets ---
    # Blocks need an RF label and at least one valid SSI pixel
    cmp_df = blk.dropna(subset=["rf_label"]).copy()
    cmp_df = cmp_df[cmp_df["valid_px"] > 0].copy()
    cmp_df["rf_label"] = cmp_df["rf_label"].astype(int)

    # Require SpaceDef coverage for the SpaceDef-only rule
    cmp_df_space = cmp_df[cmp_df["space_valid_px"] > 0].copy()

    # --- Build tables (per-country) ---
    any_tbl   = sweep_metrics(cmp_df,       "p_any",  THRESHOLDS, "SSI (any-pixel)",      country)
    mean_tbl  = sweep_metrics(cmp_df,       "m_ssi",  THRESHOLDS, "SSI (mean-SSI)",       country)
    space_tbl = sweep_metrics(cmp_df_space, "p_space",THRESHOLDS, "SSI (SpaceDef only)",  country)

    summary = pd.concat([any_tbl, mean_tbl, space_tbl], ignore_index=True)

    # Light version for main use
    score_cols = ["precision", "recall", "F1", "balanced_acc"]
    light = summary[[
        "country", "Rule / Comparison", "τ (threshold)",
        *score_cols, "n (blocks)"
    ]].copy()
    light[score_cols] = light[score_cols].round(2)

    # Save per-country CSVs in OUT_COUNTRY
    light.to_csv(OUT_COUNTRY / f"{country}_ssi_rf_summary.csv", index=False)
    summary.to_csv(OUT_COUNTRY / f"{country}_ssi_rf_summary_with_counts.csv", index=False)

    # Audit entry
    audit = dict(
        country=country,
        blocks_total=int(len(blocks)),
        blocks_with_rf=int(blk["rf_label"].notna().sum()),
        blocks_with_ssi=int((blk["valid_px"] > 0).sum()),
        blocks_used=int(len(cmp_df)),
        blocks_used_space=int(len(cmp_df_space)),
        status="ok"
    )
    print(f"  ok: used {audit['blocks_used']} blocks (space rule: {audit['blocks_used_space']})")

    return cmp_df, summary, audit

# 6️⃣ Loop over all countries

In [None]:
all_cmp_rows, all_country_tables, audits = [], [], []

# Per-block columns carried forward into the pooled analysis
pooled_cols = ["country", "block_id", "rf_label", "p_any", "m_ssi", "valid_px",
               "p_space", "space_valid_px"]

for c in COUNTRIES:
    cmp_df, summary, audit = process_country(c)
    audits.append(audit)  # every country gets an audit row, processed or not
    if cmp_df is None or summary is None:
        continue
    all_cmp_rows.append(cmp_df.assign(country=c)[pooled_cols])
    all_country_tables.append(summary)

# Save country-level audit
audit_df = pd.DataFrame(audits)
audit_df.to_csv(OUT_ROOT / "country_processing_audit.csv", index=False)
audit_df

# 7️⃣ Pooled SSI vs RF metrics across all countries

In [None]:
if not all_cmp_rows:
    raise RuntimeError("No countries processed successfully; check paths and inputs.")

pooled_df = pd.concat(all_cmp_rows, ignore_index=True)

# Blocks with SSI coverage (any-pixel and mean rules) / with SpaceDef coverage
pooled_any = pooled_df[pooled_df["valid_px"] > 0].copy()
pooled_space = pooled_df[pooled_df["space_valid_px"] > 0].copy()

# (rule label, block subset, score column) — scored in this order per threshold
pooled_rules = (
    ("SSI (any-pixel)", pooled_any, "p_any"),
    ("SSI (mean-SSI)", pooled_any, "m_ssi"),
    ("SSI (SpaceDef only)", pooled_space, "p_space"),
)

rows = []
for tau in THRESHOLDS:
    for rule_label, subset, score_col in pooled_rules:
        scored = metrics(subset["rf_label"], (subset[score_col] >= tau).astype(int))
        scored.update({
            "Rule / Comparison": rule_label,
            "τ (threshold)": tau,
            "n (blocks)": int(len(subset)),
        })
        rows.append(scored)

pooled_tbl = pd.DataFrame(rows)

score_cols = ["precision", "recall", "F1", "balanced_acc"]
pooled_tbl_light = pooled_tbl[
    ["Rule / Comparison", "τ (threshold)", *score_cols, "n (blocks)"]
].copy()
pooled_tbl_light[score_cols] = pooled_tbl_light[score_cols].round(3)

# Save pooled metrics
pooled_tbl.to_csv(OUT_ROOT / "pooled_ssi_rf_metrics_with_counts.csv", index=False)
pooled_tbl_light.to_csv(OUT_ROOT / "pooled_ssi_rf_metrics.csv", index=False)

print("Saved pooled metrics to:", OUT_ROOT)
pooled_tbl_light

# 8️⃣ Stacked per-country metrics table

In [None]:
# Initialised up front so the preview expression at the end of the cell is
# valid even when no per-country tables were collected.
stacked_light = None

if all_country_tables:
    stacked_country = pd.concat(all_country_tables, ignore_index=True)

    score_cols = ["precision", "recall", "F1", "balanced_acc"]
    stacked_light = stacked_country[[
        "country", "Rule / Comparison", "τ (threshold)",
        *score_cols, "n (blocks)"
    ]].copy()
    stacked_light[score_cols] = stacked_light[score_cols].round(3)

    stacked_country.to_csv(OUT_ROOT / "ssi_rf_country_metrics_with_counts.csv", index=False)
    stacked_light.to_csv(OUT_ROOT / "ssi_rf_country_metrics.csv", index=False)
else:
    print("No per-country tables collected.")

# Jupyter only renders the cell's final top-level expression; the original
# `stacked_light.head(12)` inside the if-block was silently discarded.
stacked_light.head(12) if stacked_light is not None else None