# SSI vs RF: Figures and Tables

- Figure: 3×3 multipanel SSI–RF alignment boxplots (by method × metric)
- Table 1: rf_ssi_population_summary.csv  (per country + overall)
- Table 2: rf_ssi_population_summary_GLOBAL_rule_threshold_table.csv
- Table 3: rf_ssi_population_summary_GLOBAL_rule_threshold_table_millions.csv

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import geopandas as gpd

import rasterio
from rasterio.mask import mask as rio_mask
from shapely.geometry import box

from joblib import Parallel, delayed

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# SSI vs RF country-level metrics from Notebook B
# NOTE(review): "...Pooled_Results" (three dots, no path separator) looks like a
# truncated/redacted path — confirm the intended location. The figure cell
# writes its output into COUNTRY_METRICS.parent / "Figures".
COUNTRY_METRICS = Path("...Pooled_Results/ssi_rf_country_metrics.csv")


# RF predictions (80% quality filtered, all attributes).
# One "<country>_rf_preds_filtered80.gpkg" per country is read from here.
RF_DIR = Path("../2_modelling/02_application/Filtered_80pct_allattributes")

# SSI per-country outputs from Notebook A (02_ssi_clip_to_cities.ipynb):
# one sub-directory per country, each containing SSI_FILENAME.
SSI_PARENT = Path("./PerCountry_Outputs")

# Per-country SSI TIF (clipped to blocks, SIGNAL policy)
SSI_FILENAME = "SSIpack100m_clipped_to_city_blocks_SIGNAL.tif"

# Output *directory* for the population summary tables. Despite the name this
# is a folder, not a CSV file: the table cells append file names to it, e.g.
# POP_SUMMARY_CSV / "rf_ssi_population_summary_GLOBAL_rule_threshold_table.csv".
POP_SUMMARY_CSV = Path("./Pooled_Results")


# Countries with SSI coverage
COUNTRIES = [
    "democratic_republic_of_the_congo","south_africa","sudan","ethiopia","nigeria","niger",
    "mozambique","angola","namibia","mauritania","somalia","egypt","tanzania","zambia",
    "central_african_republic","cameroon","kenya","republic_of_congo","burkina_faso",
    "cote_d_ivoire","ghana","guinea","uganda","senegal","malawi","benin","liberia","togo",
    "sierra_leone","rwanda","burundi","gambia","djibouti"
]

# SSI thresholds to evaluate (their string forms must match `tau_order`
# used by the figure cell: "0.1", "0.2", "0.3").
THRESHOLDS = [0.1, 0.2, 0.3]

# Confirm band indices are consistent in your stack.
# Indices are 1-based, as rasterio expects:
# [WaterDef=1, SanitationDef=2, HousingDef=3, SpaceDef=4, SSI=5]
BANDS = {"SpaceDef": 4, "SSI": 5}

# 1️⃣ MULTIPANEL SSI–RF BOX PLOTS (3 methods × 3 metrics)

In [None]:
FIG_DIR = COUNTRY_METRICS.parent / "Figures"
FIG_DIR.mkdir(parents=True, exist_ok=True)

# --- Load and prepare country metrics ---
df = pd.read_csv(COUNTRY_METRICS)
df = df.rename(columns={"τ (threshold)": "tau"})

metrics = ["precision", "recall", "F1"]
for c in metrics:
    df[c] = pd.to_numeric(df[c], errors="coerce")

# Thresholds are treated as categorical x values ("0.1", "0.2", "0.3").
df["tau"] = df["tau"].astype(str)

rule_map = {
    "SSI (any-pixel)": "p_any",
    "SSI (mean-SSI)": "m_ssi",
    "SSI (SpaceDef only)": "p_space",
}
df["Method"] = df["Rule / Comparison"].map(rule_map)

# Order: rows = Mean-SSI, SpaceDef, Any-pixel
keep_methods = ["m_ssi", "p_space", "p_any"]
df = df[df["Method"].isin(keep_methods)].copy()

tau_order = ["0.1", "0.2", "0.3"]
row_labels = {
    "m_ssi": "Mean-SSI",
    "p_space": "SpaceDef",
    "p_any": "Any-pixel",
}

# --- Styling ---
sns.set_theme(context="paper", style="white", rc={
    "axes.edgecolor": "0.4",
    "axes.linewidth": 0.8,
    "axes.labelsize": 10,
    "font.size": 9.5,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
})
plt.rcParams["figure.dpi"] = 300
plt.rcParams["font.family"] = "DejaVu Sans"

palette = {"0.1": "#f0f0f0", "0.2": "#bdbdbd", "0.3": "#636363"}

# Fixed seed so the jittered scatter overlay is identical between runs.
rng = np.random.default_rng(0)

# --- 3×3 layout: rows = methods, cols = metrics ---
fig, axes = plt.subplots(
    nrows=3, ncols=3, figsize=(8.8, 6.8),
    sharey=True, sharex=False,
    gridspec_kw={"hspace": 0.35, "wspace": 0.25},
)

for r, method in enumerate(keep_methods):
    sub = df[df["Method"] == method]
    for c_idx, metric in enumerate(metrics):
        ax = axes[r, c_idx]

        # 0.5 reference line
        ax.axhline(0.5, lw=0.6, ls="--", color="#bdbdbd", zorder=0)

        # Boxplot. hue mirrors x so the palette maps onto the tau boxes
        # (passing `palette` without `hue` is deprecated in seaborn >= 0.13);
        # dodge=False keeps a single centred box per tau.
        sns.boxplot(
            data=sub,
            x="tau", y=metric,
            order=tau_order,
            hue="tau", hue_order=tau_order, dodge=False,
            palette=palette, width=0.5,
            whis=(5, 95), showfliers=False,
            boxprops=dict(linewidth=0.9, edgecolor="0.4"),
            whiskerprops=dict(linewidth=0.8, color="0.4"),
            medianprops=dict(linewidth=2.0, color="black"),
            capprops=dict(linewidth=0.8, color="0.4"),
            ax=ax,
        )
        # The hue legend is redundant with the x tick labels.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

        # Scatter overlay + median labels
        for i, tau in enumerate(tau_order):
            vals = sub.loc[sub["tau"] == tau, metric].dropna()
            x_vals = rng.normal(i, 0.05, size=len(vals))
            ax.scatter(
                x_vals, vals,
                s=14, color="black", alpha=0.6,
                edgecolor="white", linewidth=0.3, zorder=3,
            )
            if not vals.empty:
                m = vals.median()
                ax.text(
                    i + 0.29, m + 0.01, f"{m:.2f}",
                    va="center", ha="left", fontsize=9,
                    color="#222222", fontweight="bold",
                )

        # Pin tick positions before relabelling them.
        ax.set_xticks(range(len(tau_order)))
        ax.set_xticklabels([f"τ={t}" for t in tau_order])
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_ylim(0, 1.02)
        sns.despine(ax=ax)

        # Row labels (methods), drawn left of the first column in axes coords
        if c_idx == 0:
            ax.text(
                -0.30, 0.5, row_labels[method],
                fontsize=10.5, fontweight="bold",
                rotation=90, va="center", ha="center",
                transform=ax.transAxes,
            )

        # Column titles (metrics)
        if r == 0:
            ax.set_title(metric.capitalize(), fontsize=10, pad=4)

fig.suptitle(
    "Country-level alignment of SSI with CSD — Threshold effects across methods",
    fontsize=12.5, fontweight="bold", y=0.95,
)

fig.tight_layout(rect=[0, 0, 1, 0.96])

outfile_fig = FIG_DIR / "multipanel_ssi_rf_boxplots_gray_compact.png"
fig.savefig(outfile_fig, bbox_inches="tight")
plt.show()

print(f"✅ Figure saved to: {outfile_fig}")

# 2️⃣ SSI vs RF Population Summary — Per Country + Overall
#     (m_ssi, p_space, p_any)

In [None]:
def _extract_band(ds, geom, band_idx):
    """Clip a single raster band to a polygon.

    Parameters
    ----------
    ds : open rasterio dataset to read from.
    geom : shapely geometry, expected in the dataset's CRS.
    band_idx : band index forwarded to ``rasterio.mask.mask(indexes=...)``.

    Returns
    -------
    The clipped band as a masked array (``filled=False``), with a leading
    band axis removed if present, or ``None`` if masking raised (e.g. the
    polygon does not overlap the raster).
    """
    shapes = [geom.__geo_interface__]
    try:
        clipped, _transform = rio_mask(
            ds,
            shapes,
            crop=True,
            filled=False,
            indexes=band_idx,
        )
    except Exception:
        # Deliberate best-effort: any masking failure means "nothing to sample".
        return None
    if clipped.ndim != 3:
        return clipped
    return clipped[0]

def sample_mean(ds, geom, band_idx):
    """Mean of the valid band values within a polygon.

    Parameters
    ----------
    ds : open rasterio dataset.
    geom : shapely geometry in the dataset's CRS.
    band_idx : band index to sample.

    Returns
    -------
    float
        Mean over unmasked, finite pixels; NaN when the polygon cannot be
        clipped or contains no such pixels.
    """
    a = _extract_band(ds, geom, band_idx)
    if a is None:
        return np.nan
    # np.ma.compressed drops masked pixels and is correct whether the mask is
    # a full boolean array or the scalar `nomask` (`~a.mask` + boolean
    # indexing is not); plain ndarrays pass through unchanged.
    vals = np.ma.compressed(a)
    # Guard against NaN nodata that the raster does not declare as nodata.
    vals = vals[np.isfinite(vals)]
    if vals.size == 0:
        return np.nan
    return float(vals.mean())

def sample_prop_pos(ds, geom, band_idx):
    """Proportion of valid pixels > 0 within a polygon for one band.

    Parameters
    ----------
    ds : open rasterio dataset.
    geom : shapely geometry in the dataset's CRS.
    band_idx : band index to sample.

    Returns
    -------
    float
        Share in [0, 1] of unmasked, finite pixels whose value is > 0; NaN
        when the polygon cannot be clipped or contains no such pixels.
    """
    a = _extract_band(ds, geom, band_idx)
    if a is None:
        return np.nan
    # With a scalar `nomask`, `~a.mask` sums to 1 and would turn this
    # "proportion" into a raw pixel count; compressed() avoids that.
    vals = np.ma.compressed(a)
    # Undeclared NaN nodata must not inflate the denominator.
    vals = vals[np.isfinite(vals)]
    if vals.size == 0:
        return np.nan
    return float(np.count_nonzero(vals > 0) / vals.size)

def process_country_pop(country):
    """Compute RF vs SSI segment and population summaries for one country.

    Reads ``RF_DIR/<country>_rf_preds_filtered80.gpkg`` (columns ``rf_label``,
    ``POP_SEG``, ``geometry``) and the raster ``SSI_PARENT/<country>/SSI_FILENAME``.
    For every segment it samples:

    - ``m_ssi``   : mean of the SSI band,
    - ``p_space`` : mean of the SpaceDef band,
    - ``p_any``   : share of SSI pixels > 0.

    For each threshold in ``THRESHOLDS`` it then counts segments and sums
    ``POP_SEG`` flagged deprived by RF (``rf_label == 1``, independent of the
    threshold) and by each SSI rule (value >= tau).

    Returns
    -------
    list[dict]
        One row per threshold (``Status == "OK"``) on success.
    dict
        A single ``{"Country", "Status"}`` record explaining why the country
        was skipped.
    """
    gpkg_path = RF_DIR / f"{country}_rf_preds_filtered80.gpkg"
    ssi_path = SSI_PARENT / country / SSI_FILENAME

    if not gpkg_path.exists():
        return {"Country": country, "Status": "Missing GPKG"}
    if not ssi_path.exists():
        return {"Country": country, "Status": "Missing SSI TIFF"}

    # --- Read RF file ---
    try:
        gdf = gpd.read_file(gpkg_path)[["rf_label", "POP_SEG", "geometry"]].copy()
    except Exception as e:
        return {"Country": country, "Status": f"Error reading GPKG: {e}"}

    # --- Drop rows missing POP_SEG or geometry ---
    # Done before sampling: each segment is sampled independently, so the
    # result is the same, but no raster reads are spent on discarded rows.
    gdf = gdf.dropna(subset=["POP_SEG", "geometry"]).copy()
    if gdf.empty:
        return {"Country": country, "Status": "No valid data"}

    # --- Read SSI raster + align CRS ---
    try:
        with rasterio.open(ssi_path) as ds:
            r_crs = ds.crs
            bounds_poly = box(*ds.bounds)
            if gdf.crs != r_crs:
                gdf = gdf.to_crs(r_crs)

            m_ssi_vals, p_space_vals, p_any_vals = [], [], []
            for geom in gdf.geometry:
                # Unsampleable geometries get NaN, which never meets a threshold.
                if (
                    geom is None
                    or geom.is_empty
                    or not geom.is_valid
                    or not geom.intersects(bounds_poly)
                ):
                    m_ssi_vals.append(np.nan)
                    p_space_vals.append(np.nan)
                    p_any_vals.append(np.nan)
                    continue

                m_ssi_vals.append(sample_mean(ds, geom, BANDS["SSI"]))
                p_space_vals.append(sample_mean(ds, geom, BANDS["SpaceDef"]))
                p_any_vals.append(sample_prop_pos(ds, geom, BANDS["SSI"]))
    except Exception as e:
        return {"Country": country, "Status": f"Error reading SSI TIFF: {e}"}

    gdf["m_ssi"] = m_ssi_vals
    gdf["p_space"] = p_space_vals
    gdf["p_any"] = p_any_vals

    total_pop = float(gdf["POP_SEG"].sum())
    total_segments = len(gdf)

    # RF labels do not depend on tau: compute them once, not per threshold.
    rf_mask = gdf["rf_label"] == 1
    rf_seg = int(rf_mask.sum())
    rf_pop = float(gdf.loc[rf_mask, "POP_SEG"].sum())

    rows = []
    for tau in THRESHOLDS:
        ssi_mask = gdf["m_ssi"] >= tau
        space_mask = gdf["p_space"] >= tau
        any_mask = gdf["p_any"] >= tau

        rows.append({
            "Country": country,
            "Threshold": tau,
            "TotalSegments": total_segments,
            "RF_Deprived_Seg": rf_seg,
            "SSI_Deprived_Seg": int(ssi_mask.sum()),
            "SpaceDef_Deprived_Seg": int(space_mask.sum()),
            "AnyPixel_Deprived_Seg": int(any_mask.sum()),
            "RF_Deprived_Pop": rf_pop,
            "SSI_Deprived_Pop": float(gdf.loc[ssi_mask, "POP_SEG"].sum()),
            "SpaceDef_Deprived_Pop": float(gdf.loc[space_mask, "POP_SEG"].sum()),
            "AnyPixel_Deprived_Pop": float(gdf.loc[any_mask, "POP_SEG"].sum()),
            "Total_Pop": total_pop,
            "Status": "OK",
        })

    print(f"✅ Processed {country}: {total_segments} segments")
    return rows

print(f"\nProcessing population summary for {len(COUNTRIES)} countries...")

results = Parallel(n_jobs=8, verbose=10)(
    delayed(process_country_pop)(country) for country in COUNTRIES
)

# process_country_pop returns a list of per-threshold rows on success and a
# single status dict when a country is skipped. Keep the failures visible.
flat_rows = [r for sub in results if isinstance(sub, list) for r in sub]
skipped = [sub for sub in results if isinstance(sub, dict)]
for status in skipped:
    print(f"⚠️ Skipped {status.get('Country')}: {status.get('Status')}")

if not flat_rows:
    raise RuntimeError(
        "No country produced a population summary; see skipped countries above."
    )

summary_df = pd.DataFrame(flat_rows)

# --- Add overall row (summing across countries) ---
sum_cols = [
    "TotalSegments",
    "RF_Deprived_Seg",
    "SSI_Deprived_Seg",
    "SpaceDef_Deprived_Seg",
    "AnyPixel_Deprived_Seg",
    "RF_Deprived_Pop",
    "SSI_Deprived_Pop",
    "SpaceDef_Deprived_Pop",
    "AnyPixel_Deprived_Pop",
    "Total_Pop",
]
overall_rows = []
for tau in THRESHOLDS:
    sub = summary_df[summary_df["Threshold"] == tau]
    overall_row = {"Country": "Overall", "Threshold": tau}
    overall_row.update({col: sub[col].sum() for col in sum_cols})
    overall_row["Status"] = "OK"
    overall_rows.append(overall_row)

summary_df = pd.concat([summary_df, pd.DataFrame(overall_rows)], ignore_index=True)

# --- Save population summary ---
# POP_SUMMARY_CSV is configured as an output *directory* (the table cells
# append file names to it), so writing a CSV straight to that path would fail
# or shadow the folder. Write Table 1 inside it; also accept an explicit *.csv.
if POP_SUMMARY_CSV.suffix.lower() == ".csv":
    pop_summary_file = POP_SUMMARY_CSV
else:
    pop_summary_file = POP_SUMMARY_CSV / "rf_ssi_population_summary.csv"
pop_summary_file.parent.mkdir(parents=True, exist_ok=True)
summary_df.to_csv(pop_summary_file, index=False)
print(f"\n✅ Saved population summary to:\n{pop_summary_file}")

try:
    from IPython.display import display
except ImportError:
    # Not running under IPython: the preview is optional.
    pass
else:
    display(summary_df.head(10))

# 3️⃣ GLOBAL RULE × THRESHOLD TABLE (COUNTS)

In [None]:
IN_CSV = POP_SUMMARY_CSV

# Table 1 (rf_ssi_population_summary.csv) is stored inside POP_SUMMARY_CSV,
# which is configured as an output directory (the OUT_CSV_* paths are built
# from it). pd.read_csv on the directory itself would fail, so resolve the
# actual file. An explicit *.csv path is also accepted.
if IN_CSV.suffix.lower() == ".csv":
    summary_file = IN_CSV
else:
    summary_file = IN_CSV / "rf_ssi_population_summary.csv"
OUT_CSV_COUNTS = summary_file.parent / "rf_ssi_population_summary_GLOBAL_rule_threshold_table.csv"

df_pop = pd.read_csv(summary_file)
overall = df_pop[df_pop["Country"] == "Overall"].copy()
if overall.empty:
    raise ValueError(f"No 'Overall' rows found in {summary_file}")

# (method key, segment-count column, population column, display label)
rules = [
    ("m_ssi",  "SSI_Deprived_Seg",      "SSI_Deprived_Pop",      "Mean-SSI"),
    ("p_space","SpaceDef_Deprived_Seg", "SpaceDef_Deprived_Pop", "SpaceDef"),
    ("p_any",  "AnyPixel_Deprived_Seg", "AnyPixel_Deprived_Pop", "Any-pixel"),
]

# One row per (rule, threshold); RF figures are repeated for side-by-side reading.
rows = []
for _, row in overall.iterrows():
    tau = row["Threshold"]
    for _, seg_col, pop_col, label in rules:
        rows.append({
            "Rule": label,
            "Threshold": tau,
            "TotalSegments": int(row["TotalSegments"]),
            "RF_Deprived_Seg": int(row["RF_Deprived_Seg"]),
            "Rule_Deprived_Seg": int(row[seg_col]),
            "Total_Pop": float(row["Total_Pop"]),
            "RF_Deprived_Pop": float(row["RF_Deprived_Pop"]),
            "Rule_Deprived_Pop": float(row[pop_col]),
        })

out_counts = pd.DataFrame(rows).sort_values(["Rule", "Threshold"]).reset_index(drop=True)

with pd.option_context("display.max_rows", None, "display.float_format", "{:,.0f}".format):
    print("\nGlobal Rule × Threshold table (counts):")
    print(out_counts)

out_counts.to_csv(OUT_CSV_COUNTS, index=False)
print(f"\n✅ Saved global Rule×Threshold table (counts) to:\n{OUT_CSV_COUNTS}")

# 4️⃣ GLOBAL RULE × THRESHOLD TABLE (POPULATIONS IN MILLIONS)

In [None]:
# IN_CSV is normally the output directory; if it points at the summary CSV
# itself, write next to that file instead.
_out_dir = IN_CSV.parent if IN_CSV.suffix.lower() == ".csv" else IN_CSV
OUT_CSV_MILL = _out_dir / "rf_ssi_population_summary_GLOBAL_rule_threshold_table_millions.csv"


def _pct(part, whole):
    """Return 100 * part / whole rounded to 2 dp, or NaN when whole is 0."""
    if not whole:
        return np.nan
    return round(part / whole * 100, 2)


rows = []
for _, row in overall.iterrows():
    tau = row["Threshold"]
    total_pop = float(row["Total_Pop"])
    rf_pop = float(row["RF_Deprived_Pop"])
    for _, seg_col, pop_col, label in rules:
        rule_pop = float(row[pop_col])
        rows.append({
            "Rule": label,
            "Threshold": tau,
            "TotalSegments": int(row["TotalSegments"]),
            "RF_Deprived_Seg": int(row["RF_Deprived_Seg"]),
            "Rule_Deprived_Seg": int(row[seg_col]),
            # populations in millions
            "Total_Pop_M": round(total_pop / 1e6, 2),
            "RF_Deprived_Pop_M": round(rf_pop / 1e6, 2),
            "Rule_Deprived_Pop_M": round(rule_pop / 1e6, 2),
            # % shares from the *unrounded* populations: dividing the
            # 2-dp millions distorts small shares and can divide by a total
            # that rounded to 0.
            "RF_Deprived_Pop_%": _pct(rf_pop, total_pop),
            "Rule_Deprived_Pop_%": _pct(rule_pop, total_pop),
        })

out_mill = pd.DataFrame(rows).sort_values(["Rule", "Threshold"]).reset_index(drop=True)

with pd.option_context("display.max_rows", None):
    print("\nGlobal Rule × Threshold table (millions + shares):")
    print(out_mill)

out_mill.to_csv(OUT_CSV_MILL, index=False)
print(f"\n✅ Saved global Rule×Threshold table (millions) to:\n{OUT_CSV_MILL}")