In [None]:
# Inject a CSS rule so that tables with class `dataframe` get automatic left/right
# margins (i.e. are centred) in the rendered notebook output.
from IPython.core.display import HTML
display(HTML("""
<style>
table.dataframe {
    margin-left: auto !important;
    margin-right: auto !important;
    /* optional: keep the table width no wider than its contents */
    width: auto;
}
</style>
"""))
# Limit how many columns / rows pandas shows when a DataFrame is displayed.
import pandas as pd
pd.set_option("display.max_columns", 10)
pd.set_option("display.max_rows", 25)
# Ignore warnings whose message matches the geographic-CRS centroid pattern below.
import warnings
warnings.filterwarnings("ignore",message="Geometry is in a geographic CRS.*centroid.*")

import itertools
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import geopandas as gpd
import h3
import networkx as nx
import numpy as np
import pandas as pd
from h3 import H3FailedError

In [None]:
R_KM = 6_371.0                                   # mean Earth radius (km)

def _step_km(res):
    """Approximate centre-to-centre spacing of neighbouring cells, in km.

    Parameters
    ----------
    res : int
        H3 resolution.

    Returns
    -------
    float
        ``sqrt(3) * edge``, where ``edge`` is the average hexagon edge length
        (km) that h3 reports for ``res``.

    Notes
    -----
    Tries the h3-py 4.x function name first and falls back to the 3.x name
    when the former does not exist.
    """
    try:                                         # h3-py >= 4.0
        edge = h3.average_hexagon_edge_length(res, 'km')
    except AttributeError:                       # 3.x fallback
        edge = h3.edge_length(res, 'km')
    # Two regular hexagons that share an edge have centres two apothems apart:
    # 2 * (sqrt(3) / 2 * edge) = sqrt(3) * edge.  (The previous expression,
    # 2 * edge / sqrt(3), under-estimated this spacing by a factor of 1.5.)
    return math.sqrt(3) * edge

def _dist(a, b, res, step):
    """Distance between cells ``a`` and ``b`` expressed in grid steps.

    ``h3.grid_distance`` is used when it succeeds.  If it raises
    ``H3FailedError``, the haversine distance between the two cell centres is
    computed (Earth radius ``R_KM``), divided by ``step`` (km per grid step)
    and rounded to the nearest integer.

    ``res`` is accepted for signature compatibility but is not used here.
    """
    try:
        return h3.grid_distance(a, b)
    except H3FailedError:
        (lat_a, lng_a), (lat_b, lng_b) = h3.cell_to_latlng(a), h3.cell_to_latlng(b)

        # Differences are taken in degrees first, then converted to radians.
        d_phi = math.radians(lat_b - lat_a)
        d_lambda = math.radians(lng_b - lng_a)
        phi_a = math.radians(lat_a)
        phi_b = math.radians(lat_b)

        # Haversine term.
        hav = (math.sin(d_phi / 2) ** 2 +
               math.cos(phi_a) * math.cos(phi_b) * math.sin(d_lambda / 2) ** 2)
        great_circle_km = 2 * R_KM * math.atan2(math.sqrt(hav), math.sqrt(1 - hav))
        return int(round(great_circle_km / step))

def origin_distances(origin_hex, cells, in_km=False):
    """Distance from ``origin_hex`` to every cell in ``cells``.

    Parameters
    ----------
    origin_hex : str
        H3 cell used as the origin; its resolution determines the km-per-step
        factor obtained from ``_step_km``.
    cells : iterable of str
        Cells to measure to.
    in_km : bool, default False
        When true, the step counts are multiplied by the km-per-step factor.

    Returns
    -------
    pandas.Series
        Indexed by cell; values are the results of ``_dist`` (grid steps), or
        those values times the step length when ``in_km`` is true.
    """
    resolution = h3.get_resolution(origin_hex)
    km_per_step = _step_km(resolution)

    steps = pd.Series(
        {cell: _dist(origin_hex, cell, resolution, km_per_step) for cell in cells}
    )
    return steps * km_per_step if in_km else steps
    
def mst_series(cells):
    """Minimum-spanning-tree edge lengths for a set of same-resolution cells.

    Invalid cells and duplicates are discarded.  A complete graph is built
    over the remaining cells with ``_dist`` as the edge weight, and the edges
    of its minimum spanning tree are returned.

    Returns
    -------
    pandas.Series
        Keyed by ``(u, v)`` tuples; each tree edge appears in both directions
        with the same weight.  An empty integer Series is returned when fewer
        than two valid cells remain.

    Raises
    ------
    ValueError
        If the valid cells do not all share one resolution.
    """
    unique_cells = list(set(filter(h3.is_valid_cell, cells)))
    if len(unique_cells) < 2:
        return pd.Series(dtype=int)

    res = h3.get_resolution(unique_cells[0])
    for cell in unique_cells:
        if h3.get_resolution(cell) != res:
            raise ValueError('mixed resolutions')

    step = _step_km(res)
    weighted_pairs = []
    for u, v in itertools.combinations(unique_cells, 2):
        weighted_pairs.append((u, v, _dist(u, v, res, step)))

    graph = nx.Graph()
    graph.add_weighted_edges_from(weighted_pairs)
    tree = nx.minimum_spanning_tree(graph)

    # Forward direction first, then the mirrored entries.
    lengths = {}
    for u, v, w in tree.edges(data='weight'):
        lengths[(u, v)] = w
    for (u, v), w in list(lengths.items()):
        lengths[(v, u)] = w
    return pd.Series(lengths)

def update_mst_series(T, mapping):
    """
    Relabel the endpoints of a symmetric edge Series and merge duplicates.

    T maps (u, v) tuples to a distance and is expected to contain both
    directions of every edge.  Each endpoint is replaced by mapping[endpoint]
    when present (otherwise kept).  Edges whose endpoints end up identical
    are dropped; when several edges end up between the same pair, the
    smallest distance is kept.  The result again contains both directions.
    """
    shortest = {}
    for (src, dst), length in T.items():
        if src > dst:
            # the mirrored entry (dst, src) already represents this edge
            continue

        new_src = mapping.get(src, src)
        new_dst = mapping.get(dst, dst)
        if new_src == new_dst:
            continue

        key = (new_src, new_dst) if new_src < new_dst else (new_dst, new_src)
        if key not in shortest or length < shortest[key]:
            shortest[key] = length

    # ordered pairs first, then their mirrors
    both_ways = dict(shortest)
    for (a, b), length in shortest.items():
        both_ways[(b, a)] = length
    return pd.Series(both_ways)

def youngest_common_ancestor(h1, h2):
    """Return the finest-resolution cell that is an ancestor of both inputs.

    Parameters
    ----------
    h1, h2 : str
        H3 cells; they may have different resolutions.

    Returns
    -------
    tuple
        ``(cell, resolution)`` of the common ancestor.  A cell counts as its
        own ancestor, so identical inputs return that cell.  ``(None, None)``
        is returned when the two cells have different base cell numbers or no
        common ancestor is found.
    """
    if h1 == h2:
        return h1, h3.get_resolution(h1)

    # Cells under different base cells cannot share an ancestor.
    if h3.get_base_cell_number(h1) != h3.get_base_cell_number(h2):
        return None, None

    # Lift the finer cell so that both sit at the same resolution.
    r1, r2 = h3.get_resolution(h1), h3.get_resolution(h2)
    res = min(r1, r2)
    if r1 > res:
        h1 = h3.cell_to_parent(h1, res)
    if r2 > res:
        h2 = h3.cell_to_parent(h2, res)

    # Climb one level at a time.  Both cells are lifted to the SAME target
    # resolution on every pass; previously the second cell's resolution was
    # never decremented, so it stopped climbing after the first pass and any
    # ancestor more than one level up was missed.
    while h1 != h2 and res > 0:
        res -= 1
        h1 = h3.cell_to_parent(h1, res)
        h2 = h3.cell_to_parent(h2, res)

    return (h1, res) if h1 == h2 else (None, None)


def leaves_above_min_res(mst, min_res):
    """Find tree leaves whose resolution is finer than ``min_res``.

    Parameters
    ----------
    mst : pandas.Series
        Edge Series whose index entries are ``(u, v)`` pairs.
    min_res : int
        Only cells with a resolution strictly greater than this qualify.

    Returns
    -------
    tuple
        ``(leaves, adj)``: the list of qualifying cells that have exactly one
        neighbour, and the adjacency mapping (cell -> set of neighbours)
        built from the index.
    """
    adj = defaultdict(set)
    for edge in mst.index:
        first, second = edge
        adj[first].add(second)
        adj[second].add(first)

    leaves = []
    for cell, neighbours in adj.items():
        if len(neighbours) != 1:
            continue
        if h3.get_resolution(cell) > min_res:
            leaves.append(cell)
    return leaves, adj

# ----------------------------------------------------------------------
# helper: edge length for a leaf → its single neighbour
# ----------------------------------------------------------------------
def edge_length(leaf, adj, mst):
    """Length of the edge joining ``leaf`` to the first neighbour in ``adj[leaf]``.

    The ``(leaf, neighbour)`` entry of ``mst`` is preferred; the
    ``(neighbour, leaf)`` entry is used when the former is absent.
    """
    neighbours = iter(adj[leaf])
    nbr = next(neighbours)
    reverse_entry = mst.get((nbr, leaf))
    return mst.get((leaf, nbr), reverse_entry)

# ----------------------------------------------------------------------
# main: one greedy aggregation step
# ----------------------------------------------------------------------
def lowest_leaf_and_aggregation(od, mst, min_res=4):
    """Choose the next leaf to merge and the ancestor cell it merges into.

    Parameters
    ----------
    od : pandas.DataFrame
        Needs ``destination`` and ``weight`` columns.
    mst : pandas.Series
        Symmetric edge Series keyed by ``(u, v)`` pairs.
    min_res : int, default 4
        Leaves must be finer than this, and the chosen ancestor is never
        coarser than this.

    Returns
    -------
    dict
        Keys ``leaf``, ``neighbor``, ``ancestor``, ``ancestor_resolution``,
        ``descendants_in_od`` and ``descendant_to_ancestor``.

    Raises
    ------
    ValueError
        If the tree has no leaf with resolution above ``min_res``.
    """
    leaves, adj = leaves_above_min_res(mst, min_res)
    if len(leaves) == 0:
        raise ValueError("no eligible leaves with resolution > min_res")

    # Flow per destination; leaves absent from `od` count as infinite flow.
    flow_by_dest = od.set_index("destination")["weight"].to_dict()
    unknown = float("inf")
    leaf_flow = {cell: flow_by_dest.get(cell, unknown) for cell in leaves}
    min_flow = min(leaf_flow.values())

    # Among the lowest-flow leaves, take the one hanging on the longest edge.
    candidates = [cell for cell in leaves if leaf_flow[cell] == min_flow]
    leaf = max(candidates, key=lambda cell: edge_length(cell, adj, mst))
    neighbor = next(iter(adj[leaf]))

    # Common ancestor of leaf and neighbour; fall back to the leaf's own
    # parent at `min_res` when there is none or it would be too coarse.
    ancestor, anc_res = youngest_common_ancestor(leaf, neighbor)
    if anc_res is None or anc_res < min_res:
        anc_res = min_res
        ancestor = h3.cell_to_parent(leaf, anc_res)

    # Every valid destination cell in `od` that lies under the ancestor.
    descendants = []
    for cell in od["destination"].unique():
        if not h3.is_valid_cell(cell):
            continue
        if h3.get_resolution(cell) < anc_res:
            continue
        if h3.cell_to_parent(cell, anc_res) == ancestor:
            descendants.append(cell)

    return {
        "leaf": leaf,
        "neighbor": neighbor,
        "ancestor": ancestor,
        "ancestor_resolution": anc_res,
        "descendants_in_od": descendants,
        "descendant_to_ancestor": dict.fromkeys(descendants, ancestor),
    }

## Load ZIP code and home data

In [None]:
# Read the ZIP-code GeoJSON with geopandas. A later cell selects a row by
# `geography_id` and uses the centroid of its geometry as the origin point.
california_zipcodes = gpd.read_file('zip_codes_ca.geojson')

In [None]:
# Load the home table. A later cell filters it on `baseline_zip` and reads the
# period-specific `<prefix>_zip` and `h3_cell_<prefix>` columns.
home_table = pd.read_csv("home_table.csv")

### Altadena

In [None]:
# =====================================================================
# RUN PARAMETERS -- edit these before each run
# =====================================================================
ORIGIN_ZIP = "US.91001"           # baseline ZIP to analyse, e.g. 'US.90272' for Palisades
TIME_PERIOD_PREFIX = "jan_feb"    # selects the period columns, e.g. 'feb_apr'
K_SENSITIVITY = 5                 # weight threshold used by the privacy aggregation
DISTANCE_THRESHOLD_KM = 60        # cells farther than this from the origin share one label

# ---------------------------------------------------------------------
# Names derived from the parameters above
# ---------------------------------------------------------------------
zip_col = TIME_PERIOD_PREFIX + "_zip"
h3_col = "h3_cell_" + TIME_PERIOD_PREFIX
far_cells_label = ">{}km".format(int(DISTANCE_THRESHOLD_KM))

In [None]:
print(f"Processing OD matrix for {ORIGIN_ZIP} ({TIME_PERIOD_PREFIX})...")

# H3 resolution used for the origin cell and assumed for the destination cells.
CELL_RES = 8

# 1. OD TABLE: one row per (origin, destination) with the record count as weight
od = home_table.query("baseline_zip == @ORIGIN_ZIP").copy()

od["origin"] = od.baseline_zip
# Destination is the period ZIP when present, otherwise the period H3 cell,
# otherwise the literal label "missing".
od["destination"] = od[zip_col].where(od[zip_col].notna(), other=od[h3_col]).fillna("missing")

od_agg = (
    od[["origin", "destination"]]
    .groupby(["origin", "destination"])
    .size()
    .reset_index(name="weight")
)

# 2. SPATIAL FILTERING: Relabel destinations that are too far away
# Get origin centroid to calculate distances from
centroid = california_zipcodes.query("geography_id == @ORIGIN_ZIP").geometry.centroid.iat[0]
origin_cell = h3.latlng_to_cell(lat=centroid.y, lng=centroid.x, res=CELL_RES)

# Identify valid H3 cells and calculate their distance from the origin
unique_cells = [c for c in od_agg.destination.unique() if isinstance(c, str) and h3.is_valid_cell(c)]
dists = origin_distances(origin_cell, unique_cells, in_km=True)
far_cells = dists[dists > DISTANCE_THRESHOLD_KM].index

# Map the far cells to the shared label (e.g. '>60km')
od_agg.loc[od_agg.destination.isin(far_cells), 'destination'] = far_cells_label
# Re-group after the relabeling
od_agg = od_agg.groupby(["origin", "destination"], as_index=False).sum()

# 3. SEPARATION: label rows are passed through untouched, the rest is aggregated
non_h3_labels = [far_cells_label, 'missing', ORIGIN_ZIP]
non_cell_dests = od_agg[od_agg.destination.isin(non_h3_labels)].copy()
od_cells = od_agg[~od_agg.destination.isin(non_h3_labels)].copy()


# 4. PRIVACY AGGREGATION
# At each resolution, rows that already reach K are emitted as they are and
# removed from the working set; only the remaining rows are merged into their
# parent cells and re-checked.  This way every unit of weight is reported at
# most once, and the rows that reach K on the last pass are emitted too.
print(f"Aggregating cells to meet K={K_SENSITIVITY} threshold...")
compliant_flows = []
current_od = od_cells.copy()
current_res = CELL_RES

while not current_od.empty:
    meets_k = current_od.weight >= K_SENSITIVITY
    if meets_k.any():
        # Record the resolution at which each emitted row became compliant.
        compliant_flows.append(current_od[meets_k].assign(res=current_res))

    current_od = current_od[~meets_k]
    if current_od.empty or current_res == 0:
        # Resolution 0 has no parent, so coarsening stops here.
        break

    current_res -= 1
    current_od = (
        current_od
        .assign(destination=current_od.destination.map(
            lambda cell: h3.cell_to_parent(cell, current_res)))
        .groupby(["origin", "destination"], as_index=False)
        .sum()
    )

# Anything still below K at resolution 0 is withheld rather than published.
suppressed_weight = int(current_od.weight.sum())
if suppressed_weight:
    print(f"Suppressed weight (below K={K_SENSITIVITY} even at resolution 0): {suppressed_weight}")

final_od = pd.concat(compliant_flows + [non_cell_dests], ignore_index=True)
total_weight = final_od.weight.sum()

print("total displacements: ", total_weight)
print(final_od)

# Convert counts to shares of the published total.  Plain column assignment is
# used so the integer column can be replaced by a float one.
final_od["weight"] = (final_od.weight / total_weight).round(3)

In [None]:
# Build the output path in this cell so it also runs after Restart & Run All;
# it previously used `output_filename` before any cell had defined it.
output_filename = f"od_matrices/od_{ORIGIN_ZIP.replace('.', '_')}_{TIME_PERIOD_PREFIX}_k={K_SENSITIVITY}.csv"
# `to_csv` does not create missing folders, so make sure the target exists.
Path(output_filename).parent.mkdir(parents=True, exist_ok=True)
final_od.to_csv(output_filename, index=False)

In [None]:
# Output path encodes the origin ZIP (dots replaced by underscores), the time
# period and the K threshold.
output_filename = f"od_matrices/od_{ORIGIN_ZIP.replace('.', '_')}_{TIME_PERIOD_PREFIX}_k={K_SENSITIVITY}.csv"
# `to_csv` does not create missing folders, so make sure the target exists.
Path(output_filename).parent.mkdir(parents=True, exist_ok=True)
final_od.to_csv(output_filename, index=False)
print(f"Complete: {ORIGIN_ZIP} | Total Weight: {total_weight} | Saved to {output_filename}")