In [None]:
# ==============================================================
# Figure 5 — Combined 3×3 comparison (SSI, MN, WRI vs CSD, τ = 0.1)
#
# Inputs (relative to repo root):
#   - 3_comparitive_analysis/SSI/Pooled_Results/
#       └─ ssi_rf_country_metrics.csv
#   - 3_comparitive_analysis/MN/Outputs/MN_Comparison_Files/
#       └─ {country}/{country}_segments_mnlabels_k3_maj10.gpkg
#   - 3_comparitive_analysis/WRI/Outputs/PerCountry_Outputs/
#       └─ {country}/{country}_wri_vs_rf_threshold_sweep_country.csv
#
# The "summary" block for the right-hand metric bars is HARD-CODED
# but derived from the following global summary tables:
#   - 3_comparitive_analysis/SSI/Pooled_Results/
#       rf_ssi_population_summary_GLOBAL_rule_threshold_table_millions.csv
#   - 3_comparitive_analysis/MN/Outputs/
#       mn_rf_summary_segments_population_GLOBAL_k_tau_table_millions.csv
#   - 3_comparitive_analysis/WRI/Outputs/
#       wri_rf_population_summary_GLOBAL_rule_threshold_table_millions.csv
#
# Output:
#   - 4_Figures_Tables/Figures/Figure5_ThreeComparisons.png
# ==============================================================

In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import geopandas as gpd
from sklearn.metrics import precision_score, recall_score, f1_score
from matplotlib.gridspec import GridSpec

In [None]:
# --------------------------------------------------------------
# 1️⃣ Paths — all resolved against the kernel's working directory,
#    which is expected to be the repository root.
# --------------------------------------------------------------
REPO_ROOT = Path.cwd()
_ANALYSIS_DIR = REPO_ROOT.joinpath("3_comparitive_analysis")

# Per-method comparison inputs
SSI_CSV = _ANALYSIS_DIR.joinpath("SSI", "Pooled_Results", "ssi_rf_country_metrics.csv")
MN_ROOT = _ANALYSIS_DIR.joinpath("MN", "Outputs", "MN_Comparison_Files")
WRI_ROOT = _ANALYSIS_DIR.joinpath("WRI", "Outputs", "PerCountry_Outputs")

# Destination for the rendered figure (created on first run)
FIG_OUT_DIR = REPO_ROOT.joinpath("4_Figures_Tables", "Figures")
FIG_OUT_DIR.mkdir(exist_ok=True, parents=True)

# Echo the resolved locations so a wrong working directory is obvious
for _label, _path in (
    ("SSI metrics CSV :", SSI_CSV),
    ("MN comparison   :", MN_ROOT),
    ("WRI outputs     :", WRI_ROOT),
    ("Figure outdir   :", FIG_OUT_DIR),
):
    print(_label, _path)

In [None]:
# --------------------------------------------------------------
# 2️⃣ Load and prepare SSI (SpaceDef, τ = 0.1)
# --------------------------------------------------------------
ssi_raw = pd.read_csv(SSI_CSV)

# Locate the threshold column by its "(threshold)" suffix instead of its full
# header, so the lookup does not depend on how the τ glyph was encoded.
ssi_tau_col = next(
    (col for col in ssi_raw.columns if str(col).strip().endswith("(threshold)")),
    None,
)
if ssi_tau_col is None:
    raise KeyError(
        f"No '(threshold)' column in {SSI_CSV}; found {list(ssi_raw.columns)}"
    )

# Keep only τ = 0.1 and the SpaceDef-only rule. np.isclose tolerates float
# round-off in the parsed thresholds; non-numeric values coerce to NaN and
# are excluded.
ssi_is_tau = np.isclose(
    pd.to_numeric(ssi_raw[ssi_tau_col], errors="coerce").to_numpy(), 0.1
)
ssi_is_rule = (ssi_raw["Rule / Comparison"] == "SSI (SpaceDef only)").to_numpy()

ssi_df = ssi_raw.loc[ssi_is_tau & ssi_is_rule].copy()
if ssi_df.empty:
    raise RuntimeError(
        f"No SSI rows for 'SSI (SpaceDef only)' at τ = 0.1 in {SSI_CSV}."
    )
ssi_df["Rule"] = "SpaceDef"

# Long format: one row per (country, metric) for the boxplot panels
ssi_long = ssi_df.melt(
    id_vars=["country", "Rule"],
    value_vars=["precision", "recall", "F1"],
    var_name="Metric",
    value_name="Score",
)
ssi_long["Metric"] = ssi_long["Metric"].str.capitalize()  # "F1" is unchanged
ssi_long["Dataset"] = "SSI"

In [None]:
# --------------------------------------------------------------
# 3️⃣ Load and prepare MN (k > 3, τ = 0.1 only → maj10)
# --------------------------------------------------------------
# One metrics row per country, computed from the per-segment labels in
# {country}/{country}_segments_mnlabels_k3_maj10.gpkg.
mn_rows = []
mn_skipped = {}  # country -> reason; reported below instead of dropped silently

for country_dir in sorted(MN_ROOT.glob("*")):
    if not country_dir.is_dir():
        continue
    country = country_dir.name
    # τ = 0.1 corresponds to maj10 in file naming
    fpath = country_dir / f"{country}_segments_mnlabels_k3_maj10.gpkg"
    if not fpath.exists():
        mn_skipped[country] = "no k3_maj10 file"
        continue
    try:
        gdf = gpd.read_file(fpath)
    except Exception as exc:
        # best-effort: one unreadable file should not abort the whole figure
        mn_skipped[country] = f"read error ({type(exc).__name__}: {exc})"
        continue

    if not {"rf_label", "label_final"}.issubset(gdf.columns):
        mn_skipped[country] = "missing rf_label/label_final column"
        continue

    sub = gdf[["rf_label", "label_final"]].dropna()
    if sub.empty:
        mn_skipped[country] = "no segments with both labels"
        continue

    # rf_label is passed as y_true, label_final as y_pred
    y_true = sub["rf_label"].astype(int)
    y_pred = sub["label_final"].astype(int)
    mn_rows.append({
        "country": country,
        "tau": 0.1,
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "F1": f1_score(y_true, y_pred, zero_division=0),
    })

if mn_skipped:
    print(f"MN: skipped {len(mn_skipped)} country folder(s):")
    for _name, _reason in mn_skipped.items():
        print(f"  - {_name}: {_reason}")

# Without rows, the melt below would fail with an opaque KeyError
if not mn_rows:
    raise RuntimeError("No usable MN comparison files found under MN_Comparison_Files.")

mn_df = pd.DataFrame(mn_rows)

# Long format: one row per (country, metric) for the boxplot panels
mn_long = mn_df.melt(
    id_vars=["country", "tau"],
    value_vars=["precision", "recall", "F1"],
    var_name="Metric",
    value_name="Score",
)
mn_long["Metric"] = mn_long["Metric"].str.capitalize()
mn_long["Dataset"] = "MN"


In [None]:
# --------------------------------------------------------------
# 4️⃣ Load and prepare WRI (p_informal, τ = 0.1 only)
# --------------------------------------------------------------
wri_frames = []
wri_skipped = {}  # country -> reason; reported below instead of dropped silently

for c_dir in sorted(WRI_ROOT.glob("*")):
    if not c_dir.is_dir():
        continue
    country = c_dir.name
    f = c_dir / f"{country}_wri_vs_rf_threshold_sweep_country.csv"
    if not f.exists():
        continue

    try:
        df = pd.read_csv(f)
        df.insert(0, "country", country)
    except Exception as exc:
        # best-effort: skip the problematic country but record why
        wri_skipped[country] = f"{type(exc).__name__}: {exc}"
        continue
    wri_frames.append(df)

if wri_skipped:
    print(f"WRI: skipped {len(wri_skipped)} unreadable country file(s):")
    for _name, _reason in wri_skipped.items():
        print(f"  - {_name}: {_reason}")

if not wri_frames:
    raise RuntimeError("No WRI sweep files found under PerCountry_Outputs.")

wri_all = pd.concat(wri_frames, ignore_index=True)


def _rule_key(x: object) -> "str | None":
    """Map a 'Rule / Comparison' value to "p_informal" if it mentions it, else None.

    The match is a case-insensitive substring test; non-string values map to None.
    """
    if isinstance(x, str) and "p_informal" in x.lower():
        return "p_informal"
    return None


# Locate the threshold column by its "(threshold)" suffix instead of its full
# header, so the lookup does not depend on how the τ glyph was encoded.
wri_tau_col = next(
    (col for col in wri_all.columns if str(col).strip().endswith("(threshold)")),
    None,
)
if wri_tau_col is None:
    raise KeyError(
        f"No '(threshold)' column in WRI sweep files; found {list(wri_all.columns)}"
    )

# Keep only the p_informal rule at τ = 0.1 (np.isclose tolerates float round-off)
wri_is_rule = (wri_all["Rule / Comparison"].map(_rule_key) == "p_informal").to_numpy()
wri_is_tau = np.isclose(
    pd.to_numeric(wri_all[wri_tau_col], errors="coerce").to_numpy(), 0.1
)

wri_df = (
    wri_all.loc[wri_is_rule & wri_is_tau]
    .assign(Rule_key="p_informal")
    .rename(columns={wri_tau_col: "tau"})
)
if wri_df.empty:
    raise RuntimeError("No WRI rows for rule p_informal at τ = 0.1.")

# Long format: one row per (country, metric) for the boxplot panels
wri_long = wri_df.melt(
    id_vars=["country", "tau"],
    value_vars=["precision", "recall", "F1"],
    var_name="Metric",
    value_name="Score",
)
wri_long["Metric"] = wri_long["Metric"].str.capitalize()
wri_long["Dataset"] = "WRI"

In [None]:
# --------------------------------------------------------------
# 5️⃣ Combine datasets for unified plotting
# --------------------------------------------------------------
combined = pd.concat([ssi_long, mn_long, wri_long], ignore_index=True)

metrics_order = ["Precision", "Recall", "F1"]
datasets = ["SSI", "MN", "WRI"]

# Number of distinct countries actually loaded per dataset. It feeds the
# "n = …" in the row titles so they cannot drift from the plotted data
# (previously hard-coded as n = 33 / 45 / 48).
n_countries = (
    combined.groupby("Dataset")["country"]
    .nunique()
    .reindex(datasets, fill_value=0)
)

# --------------------------------------------------------------
# 6️⃣ Palette and hard-coded global summaries
# --------------------------------------------------------------
gray_color = "#bdbdbd"

# NOTE: The summary dictionary below is HARD-CODED but derived from:
#   - rf_ssi_population_summary_GLOBAL_rule_threshold_table_millions.csv
#   - mn_rf_summary_segments_population_GLOBAL_k_tau_table_millions.csv
#   - wri_rf_population_summary_GLOBAL_rule_threshold_table_millions.csv
# Values are percentages. Inner keys are used verbatim as bar labels in the
# right-hand column.
summary = {
    "SSI": {
        "Segments":   {"CSD": 35.4, "SpaceDef": 43.2},
        "Population": {"CSD": 30.6, "SpaceDef": 42.6},
    },
    "MN": {
        "Segments":   {"CSD": 35.6, "τ=0.1": 29.8},
        "Population": {"CSD": 31.6, "τ=0.1": 40.9},
    },
    "WRI": {
        "Segments":   {"CSD": 29.3, "τ=0.1": 48.6},
        "Population": {"CSD": 27.9, "τ=0.1": 46.2},
    },
}

_row_title_templates = {
    "SSI": "Country-level alignment of SSI (SpaceDef) (τ = 0.1) with CSD (n = {n})",
    "MN": "Country-level alignment of MN (k > 3, τ = 0.1) with CSD (n = {n})",
    "WRI": "Country-level alignment of WRI (p_informal, τ = 0.1) with CSD (n = {n})",
}
# Same order as `datasets`
row_titles = [
    _row_title_templates[d].format(n=int(n_countries[d])) for d in datasets
]

In [None]:
# --------------------------------------------------------------
# 7️⃣ Figure layout
# --------------------------------------------------------------
# Fixed seed for the horizontal jitter of the per-country points, so the
# saved figure is identical on every Restart & Run All.
JITTER_SEED = 42

sns.set_theme(
    context="paper",
    style="white",
    rc={
        "axes.edgecolor": "0.4",
        "axes.linewidth": 0.8,
        "axes.labelsize": 11,
        "font.size": 10,
        "xtick.labelsize": 9,
        "ytick.labelsize": 9,
    },
)
plt.rcParams["figure.dpi"] = 300
plt.rcParams["font.family"] = "DejaVu Sans"


def _draw_metric_box(ax, sub, metric, xtick_label, rng):
    """Draw one metric panel: boxplot, jittered country points and a median label.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    sub : pandas.DataFrame
        Rows for a single dataset/metric; only the ``Score`` column is used.
    metric : str
        Metric name, used as the y-axis label.
    xtick_label : str
        Text for the single x tick under the box.
    rng : numpy.random.Generator
        Source of the horizontal jitter for the scatter overlay.
    """
    # 0.5 reference line, drawn behind everything else
    ax.axhline(0.5, lw=0.7, ls="--", color="#bdbdbd", zorder=0)

    vals = sub["Score"].dropna()
    if len(vals) > 0:
        sns.boxplot(
            data=sub,
            y="Score",
            color=gray_color,
            width=0.50,
            whis=(5, 95),       # whiskers at the 5th / 95th percentiles
            showfliers=False,   # every country is drawn by the scatter below
            boxprops=dict(linewidth=0.9, edgecolor="0.4"),
            whiskerprops=dict(linewidth=0.8, color="0.4"),
            medianprops=dict(linewidth=2.0, color="black"),
            capprops=dict(linewidth=0.8, color="0.4"),
            ax=ax,
        )

        # Scatter overlay (one point per country) + median label
        ax.scatter(
            rng.normal(0, 0.05, size=len(vals)),
            vals,
            s=14,
            color="black",
            alpha=0.6,
            edgecolor="white",
            linewidth=0.3,
            zorder=3,
        )
        m = vals.median()
        ax.text(
            0.26,
            m + 0.01,
            f"{m:.2f}",
            va="center",
            ha="left",
            fontsize=9.5,
            color="#222222",
            fontweight="bold",
        )

    ax.set_ylim(0, 1.02)
    ax.set_xlabel("")
    ax.set_ylabel(metric)
    ax.set_xticks([0])
    ax.set_xticklabels([xtick_label])
    sns.despine(ax=ax)


def _draw_summary_bars(ax, groups, bar_height=0.22):
    """Draw the hard-coded percentage bars for one dataset.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes (right-hand column).
    groups : dict
        Mapping of group name (e.g. "Segments", "Population") to a
        ``{bar label: percent}`` dict, as in ``summary[dataset]``.
    bar_height : float
        Vertical spacing between bars within a group.
    """
    ax.set_xlim(0, 60)
    ax.set_xticks([0, 20, 40, 60])
    ax.set_yticks([])
    ax.set_xlabel("%", fontsize=9)
    sns.despine(ax=ax, left=True, bottom=True)

    # One vertical band per group, top to bottom
    y_base = np.linspace(1.45, 0.55, len(groups))

    for y0, (group, metrics_dict) in zip(y_base, groups.items()):
        for j, (label, val) in enumerate(metrics_dict.items()):
            y_pos = y0 - j * bar_height

            ax.barh(
                y_pos,
                val,
                height=bar_height * 0.75,
                color=gray_color,
                edgecolor="0.4",
                linewidth=0.6,
                zorder=2,
            )
            # value to the right of the bar
            ax.text(
                val + 1.2,
                y_pos,
                f"{val:.1f}%",
                va="center",
                ha="left",
                fontsize=8.5,
                color="#222222",
            )
            # bar label inside the bar, near its left edge
            ax.text(
                2,
                y_pos,
                label,
                va="center",
                ha="left",
                fontsize=7.5,
                color="#111111",
                fontweight="bold",
                zorder=3,
            )

        # Rotated group label, placed left of x = 0 (outside the axes limits)
        label_text = (
            "Deprived\npopulation" if "Pop" in group else "Deprived\nsegments"
        )
        ax.text(
            -5,
            y0 - (bar_height * 0.5),
            label_text,
            fontsize=7.5,
            rotation=90,
            fontweight="bold",
            ha="center",
            va="center",
        )


n_rows = len(datasets)
n_metric_cols = len(metrics_order)

fig = plt.figure(figsize=(13, 10))
gs = GridSpec(
    n_rows,
    n_metric_cols + 1,  # one boxplot column per metric + the summary column
    figure=fig,
    width_ratios=[1] * n_metric_cols + [1.3],
    wspace=0.35,
    hspace=0.45,
)

axes = np.empty((n_rows, n_metric_cols + 1), dtype=object)
for r in range(n_rows):
    for c in range(n_metric_cols + 1):
        axes[r, c] = fig.add_subplot(gs[r, c])

# --------------------------------------------------------------
# 8️⃣ Boxplots (left columns)
# --------------------------------------------------------------
rng = np.random.default_rng(JITTER_SEED)

for r, dataset in enumerate(datasets):
    xtick_label = "SpaceDef" if dataset == "SSI" else "τ = 0.1"
    for c, metric in enumerate(metrics_order):
        sub = combined[
            (combined["Dataset"] == dataset) & (combined["Metric"] == metric)
        ]
        _draw_metric_box(axes[r, c], sub, metric, xtick_label, rng)

    # Row-level title on the middle boxplot column (centred for 3 metrics)
    axes[r, n_metric_cols // 2].set_title(
        row_titles[r], fontsize=11.5, fontweight="bold", pad=15
    )

# --------------------------------------------------------------
# 9️⃣ Right column — metrics summary bars
# --------------------------------------------------------------
for r, dataset in enumerate(datasets):
    _draw_summary_bars(axes[r, -1], summary[dataset])

fig.text(
    0.805,
    0.978,
    "Metrics summary",
    ha="center",
    va="bottom",
    fontsize=11.5,
    fontweight="bold",
)

# --------------------------------------------------------------
# 🔟 Final layout and save
# --------------------------------------------------------------
fig.subplots_adjust(left=0.08, right=0.94, top=0.96, bottom=0.06)

outfile = FIG_OUT_DIR / "Figure5_ThreeComparisons.png"
fig.savefig(outfile, bbox_inches="tight", dpi=600)
print(f"✅ Combined comparison figure saved to: {outfile}")

plt.show()