# In [3]:  (notebook cell marker)
#!/usr/bin/env python3
"""
Downscale *Wales country-level* population projections (already cleaned/tidy)
to Wales LSOA21 by proportional split using 2022 LSOA baseline strata (population_lv3, 18 features).

Inputs:
  1) Clean Wales projections CSV (tidy):
       wales_population_projections_tidy.csv
     expected columns:
       sex,age_group,year,population
     where:
       sex in: total, male, female
       age_group in: all, 0_15, 16_24, 25_49, 50_64, 65+
       year includes: 2030, 2035, 2040, 2045

  2) EW lookup (to identify Wales LSOA21):
     LSOA_(2011)_to_LSOA_(2021)_to_Local_Authority_District_(2022)_Exact_Fit_Lookup_for_EW_(V3).csv
     uses: LSOA21CD, LAD22CD (Wales LAD22CD starts with 'W')

  3) 2022 LSOA baseline JSON:
     ../2022/population_2022_LSOA21.json
     pop[lsoa21cd]["population_lv3"] length 18 (order fixed)

Method:
  For each (sex, age_group, year):
    share_lsoa = baseline_2022_lsoa_stratum / sum_wales_baseline_2022_stratum
    lsoa_projection = wales_projection * share_lsoa

Output:
  lsoa21_population_projections_2030_2035_2040_2045_long_Wales.csv
  columns: lsoa21cd,country,sex,age_group,year,population
"""

import json
import numpy as np
import pandas as pd

# ----------------------------- PATHS -----------------------------
# Tidy country-level projections for Wales (sex, age_group, year, population).
WALES_TIDY_CSV = "wales_population_projections_tidy.csv"
# England & Wales lookup used to pick out the Wales LSOA21 codes.
LOOKUP_CSV = (
    "LSOA_(2011)_to_LSOA_(2021)_to_Local_Authority_District_(2022)"
    "_Exact_Fit_Lookup_for_EW_(V3).csv"
)
# 2022 baseline: one "population_lv3" vector per LSOA21 code.
LSOA_BASELINE_JSON = "../2022/population_2022_LSOA21.json"

# Long-format output written by main().
OUT_LONG_WALES = "lsoa21_population_projections_2030_2035_2040_2045_long_Wales.csv"

# ----------------------------- OPTIONS -----------------------------
# Projection years to keep: 2030, 2035, 2040, 2045.
YEARS_KEEP = set(range(2030, 2050, 5))

# What to do when the Wales-wide 2022 total of a stratum is 0:
#   "zero"  -> every LSOA receives 0 for that stratum
#   "equal" -> the projection is split evenly across the Wales LSOAs
ZERO_BASELINE_POLICY = "zero"

# ----------------------------- STRATA MAP (18) -----------------------------
# Position of each (sex, age_group) stratum inside "population_lv3":
#   0-2   all ages, in the order total / male / female
#   3-7   total,  age bands in the order below
#   8-12  male,   age bands in the order below
#   13-17 female, age bands in the order below
_SEXES = ("total", "male", "female")
_AGE_BANDS = ("0_15", "16_24", "25_49", "50_64", "65+")

STRATA_INDEX = {(sex, "all"): sex_pos for sex_pos, sex in enumerate(_SEXES)}
STRATA_INDEX.update(
    ((sex, band), len(_SEXES) + sex_pos * len(_AGE_BANDS) + band_pos)
    for sex_pos, sex in enumerate(_SEXES)
    for band_pos, band in enumerate(_AGE_BANDS)
)

# Columns the tidy Wales projections CSV must provide.
REQUIRED_WALES_COLS = {"population", "year", "age_group", "sex"}


def load_baseline_json(path: str) -> dict:
    """Read the LSOA baseline JSON file at *path* and return the parsed object.

    The file is decoded explicitly as UTF-8 (the encoding JSON interchange
    uses) instead of the platform default, so the result does not depend on
    the locale of the machine running the script.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def safe_get_lv3(pop_dict: dict, lsoa21cd: str, expected_len: int = 18):
    """Return the "population_lv3" vector stored for *lsoa21cd*, or None.

    None is returned (never an exception) when the code is absent from
    *pop_dict*, when its entry is not a mapping, or when "population_lv3" is
    missing, is not a list/tuple, or does not have exactly *expected_len*
    items.

    Args:
        pop_dict: mapping of LSOA21 code -> per-LSOA record.
        lsoa21cd: the LSOA21 code to look up.
        expected_len: required length of the vector (18 by default).

    Returns:
        The stored list/tuple itself (not a copy), or None if unusable.
    """
    entry = pop_dict.get(lsoa21cd)
    # A malformed record (e.g. a bare list or number) must not crash the
    # caller: treat it like a missing one.
    if not isinstance(entry, dict):
        return None

    lv3 = entry.get("population_lv3")
    # Reject strings and scalars explicitly; a str of the right length would
    # otherwise pass a plain len() check.
    if not isinstance(lv3, (list, tuple)) or len(lv3) != expected_len:
        return None
    return lv3


def _read_wales_lsoa_codes(lookup_csv: str) -> list:
    """Return the unique Wales LSOA21 codes from the EW lookup, in file order.

    A row belongs to Wales when its LAD22CD starts with 'W'.
    """
    lookup = pd.read_csv(lookup_csv, dtype=str, usecols=["LSOA21CD", "LAD22CD"])
    lookup = lookup.dropna(subset=["LSOA21CD", "LAD22CD"])
    in_wales = lookup["LAD22CD"].str.startswith("W")
    # De-duplicate on the code alone: the lookup has one row per LSOA11, and a
    # code repeated under different LAD22CD values must still be counted once
    # or its baseline would be double-counted in the Wales totals.
    return lookup.loc[in_wales, "LSOA21CD"].drop_duplicates().tolist()


def _build_baseline_table(pop: dict, lsoa_codes: list):
    """Return (long baseline table, number of LSOAs skipped).

    The table has one row per (lsoa21cd, sex, age_group) with the 2022 value in
    "baseline_2022". An LSOA is skipped when it has no usable "population_lv3"
    vector or when any of its values cannot be converted to float.
    """
    strata = list(STRATA_INDEX.items())
    rows = []
    skipped = 0
    for code in lsoa_codes:
        lv3 = safe_get_lv3(pop, code)
        if lv3 is None:
            skipped += 1
            continue
        try:
            values = [float(lv3[idx]) for _, idx in strata]
        except (TypeError, ValueError):
            # e.g. a null or non-numeric entry inside the vector
            skipped += 1
            continue
        rows.extend(
            (code, sex, age_group, value)
            for ((sex, age_group), _), value in zip(strata, values)
        )

    baseline = pd.DataFrame(
        rows, columns=["lsoa21cd", "sex", "age_group", "baseline_2022"]
    )
    return baseline, skipped


def _load_wales_projections(path: str):
    """Load, normalise and validate the tidy Wales projections.

    Returns the rows whose year is in YEARS_KEEP, with "sex" and "age_group"
    stripped and lower-cased.

    Raises:
        RuntimeError: if required columns are missing, if no row falls in
            YEARS_KEEP, or if a (sex, age_group) pair is not in STRATA_INDEX.
    """
    wproj = pd.read_csv(path, dtype={"sex": str, "age_group": str, "year": int})
    missing_cols = REQUIRED_WALES_COLS - set(wproj.columns)
    if missing_cols:
        raise RuntimeError(f"WALES_TIDY_CSV missing required columns: {sorted(missing_cols)}")

    wproj["sex"] = wproj["sex"].str.strip().str.lower()
    wproj["age_group"] = wproj["age_group"].str.strip().str.lower()
    wproj = wproj[wproj["year"].isin(YEARS_KEEP)].copy()

    # Fail loudly instead of silently writing an empty output file.
    if wproj.empty:
        raise RuntimeError(f"WALES_TIDY_CSV has no rows for years {sorted(YEARS_KEEP)}.")
    absent_years = YEARS_KEEP - set(wproj["year"].unique().tolist())
    if absent_years:
        print(f"[WARN] No projection rows for year(s): {sorted(absent_years)}")

    # Validate strata exist in map (plain membership test per row; avoids the
    # much slower DataFrame.apply(axis=1)).
    known = np.fromiter(
        (key in STRATA_INDEX for key in zip(wproj["sex"], wproj["age_group"])),
        dtype=bool,
        count=len(wproj),
    )
    bad = wproj[~known]
    if not bad.empty:
        ex = bad[["sex", "age_group"]].drop_duplicates().head(50)
        raise RuntimeError(f"Found strata not in STRATA_INDEX (showing up to 50):\n{ex}")

    # Repeated (sex, age_group, year) rows would each be expanded to every
    # LSOA, duplicating output rows; surface that rather than hide it.
    n_dupes = int(wproj.duplicated(subset=["sex", "age_group", "year"]).sum())
    if n_dupes:
        print(f"[WARN] {n_dupes} duplicate (sex, age_group, year) rows in WALES_TIDY_CSV.")

    return wproj


def _allocation_shares(expanded, baseline):
    """Return each expanded row's share of its Wales stratum, as an array.

    share = baseline_2022 / wales_baseline_2022 where the Wales total is > 0.
    Otherwise ZERO_BASELINE_POLICY decides: "zero" gives 0.0, "equal" gives
    1 / (number of LSOAs in the stratum), or 0.0 if that number is 0.
    """
    wales_total = expanded["wales_baseline_2022"]
    # Division by a zero total yields inf/NaN here; those entries are replaced
    # by the fallback in np.where below.
    proportional = expanded["baseline_2022"] / wales_total

    if ZERO_BASELINE_POLICY == "equal":
        counts = (
            baseline.groupby(["sex", "age_group"], as_index=False)["lsoa21cd"]
            .nunique()
            .rename(columns={"lsoa21cd": "n_lsoa"})
        )
        n_lsoa = (
            expanded[["sex", "age_group"]]
            .merge(counts, on=["sex", "age_group"], how="left")["n_lsoa"]
            .fillna(0)
            .to_numpy(dtype=float)
        )
        fallback = np.divide(1.0, n_lsoa, out=np.zeros(len(n_lsoa)), where=n_lsoa > 0)
    else:
        fallback = 0.0

    return np.where(wales_total > 0, proportional, fallback)


def main():
    """Downscale the Wales projections to LSOA21 and write the long CSV.

    Steps: pick the Wales LSOA21 codes from the lookup, build the 2022 baseline
    per (LSOA, sex, age_group), compute each LSOA's share of the Wales stratum
    total, multiply the Wales projection by that share for every kept year,
    and write OUT_LONG_WALES with columns
    lsoa21cd,country,sex,age_group,year,population.

    Raises:
        ValueError: if ZERO_BASELINE_POLICY is not 'zero' or 'equal'.
        RuntimeError: if no Wales LSOA is found, the baseline table is empty,
            or the projections file fails validation.
    """
    # Check configuration before doing any file I/O.
    if ZERO_BASELINE_POLICY not in ("zero", "equal"):
        raise ValueError("ZERO_BASELINE_POLICY must be 'zero' or 'equal'.")

    # ---- Wales LSOAs from EW lookup: LAD22CD startswith 'W' ----
    lsoa_codes = _read_wales_lsoa_codes(LOOKUP_CSV)
    if not lsoa_codes:
        raise RuntimeError("No Wales LSOA21 found via LAD22CD starting with 'W'. Check LOOKUP_CSV.")
    print(f"[INFO] Wales LSOA21 count from lookup: {len(lsoa_codes):,}")

    # ---- Load baseline JSON and assemble Wales baseline table ----
    pop = load_baseline_json(LSOA_BASELINE_JSON)
    baseline, n_skipped = _build_baseline_table(pop, lsoa_codes)
    if baseline.empty:
        raise RuntimeError("Baseline table is empty for Wales. Check JSON keys vs LSOA21 codes.")
    if n_skipped > 0:
        print(f"[WARN] Missing/invalid baseline for {n_skipped} Wales LSOA21 codes (skipped).")

    # Wales baseline totals by stratum (only LSOAs with a valid baseline count).
    wales_base = (
        baseline.groupby(["sex", "age_group"], as_index=False)["baseline_2022"]
        .sum()
        .rename(columns={"baseline_2022": "wales_baseline_2022"})
    )

    # ---- Load tidy Wales projections ----
    wproj = _load_wales_projections(WALES_TIDY_CSV)

    # ---- Attach Wales totals, then expand every projection row to all LSOAs ----
    # The second merge is many-to-many by design: each (sex, age_group) has
    # several years on the left and many LSOAs on the right.
    expanded = wproj.merge(wales_base, on=["sex", "age_group"], how="left").merge(
        baseline, on=["sex", "age_group"], how="left"
    )
    expanded["wales_baseline_2022"] = expanded["wales_baseline_2022"].fillna(0.0)
    expanded["baseline_2022"] = expanded["baseline_2022"].fillna(0.0)

    # ---- Shares and downscaled population ----
    expanded["share"] = _allocation_shares(expanded, baseline)
    expanded["population"] = expanded["population"].astype(float) * expanded["share"]

    out = expanded[["lsoa21cd", "sex", "age_group", "year", "population"]].copy()
    out.insert(1, "country", "Wales")

    out.to_csv(OUT_LONG_WALES, index=False)
    print(f"[OK] Wrote: {OUT_LONG_WALES}")
    print(f"Rows: {len(out):,}")
    print(f"Unique Wales LSOAs: {out['lsoa21cd'].nunique():,}")
    print(f"Years: {sorted(out['year'].unique().tolist())}")
    print("Done.")


if __name__ == "__main__":
    main()

# Sample output from a run of this cell:
#   [INFO] Wales LSOA21 count from lookup: 1,917
#   [OK] Wrote: lsoa21_population_projections_2030_2035_2040_2045_long_Wales.csv
#   Rows: 138,024
#   Unique Wales LSOAs: 1,917
#   Years: [2030, 2035, 2040, 2045]
#   Done.
