In [1]:
import geopandas as gpd
import pandas as pd
import requests
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from shapely.geometry import LineString, MultiLineString, Point, MultiPoint
from shapely import set_precision
import contextily as ctx
from shapely.ops import unary_union, linemerge, snap
from shapely.validation import make_valid
from math import isfinite
from scipy.spatial import cKDTree
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import pickle


## Counties

In [2]:
# TIGER/Line 2025 county polygons, reprojected to EPSG:5070 (projected CRS in metres;
# later cells convert lengths with 0.000621371 miles per metre)
counties = gpd.read_file(
    '../raw data/county level data/tl_2025_us_county/tl_2025_us_county.shp'
).to_crs(epsg=5070)

In [3]:
# Keep the contiguous US only: drop Alaska (02), Hawaii (15), Puerto Rico (72),
# the other territories and the reserved/unused state codes.
counties = counties.loc[~counties['STATEFP'].isin(
    ['02', '03', '07', '14', '15', '43', '52', '60', '66', '69', '72', '78']
)]
counties.shape

In [4]:
# Polygon centroid of each county (same CRS as the polygons); used later to anchor counties on the road graph
counties['centroid'] = counties.centroid

In [5]:
# Turn COUNTYFP into the 5-digit county GEOID (2-digit state FIPS + 3-digit county FIPS).
# Taking only the last 3 characters keeps the cell idempotent: re-running it no
# longer prepends the state code a second time (e.g. '01001' -> '0101001').
counties['COUNTYFP'] = counties['STATEFP'] + counties['COUNTYFP'].str[-3:]
counties.head()

In [6]:
# Persist the ordered list of county GEOIDs (single 'COUNTYFP' column, no index)
counties['COUNTYFP'].to_csv(
    path_or_buf='../raw data/county level data/counties_list.csv',
    index=False,
)

## Highways

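Weights are $W^H_{i,j} = \frac{M_i M_j}{C_{ij}^2}$, where $M_i = \sum_{s \in i} \text{AADT}_s \times \text{lanes}_s \times \text{miles}_s$ is summed over the interstate segments $s$ lying within county $i$, and $C_{ij}$ is the shortest highway travel time (in minutes) between the graph nodes nearest to the centroids of counties $i$ and $j$. Weights are then to be normalized by dividing by the maximum value; note that the weights saved in this notebook are not normalized.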

In [None]:
# National Highway System line features, reprojected to the same CRS as the counties (EPSG:5070)
highways = gpd.read_file(
    '../raw data/state level data/NHS/National_Highway_System_(NHS).shp'
).to_crs(epsg=5070)

In [None]:
highways.head()

In [None]:
highways.shape

In [None]:
# Keep only features whose SIGNT1 code is 'I' (presumably the primary route-sign type = Interstate)
interstates = highways.loc[highways['SIGNT1'] == 'I']

In [None]:
# Retain only the attributes used downstream, plus geometry
interstates = interstates.loc[:, [
    'STFIPS', 'CTFIPS', 'ROUTEID', 'SIGNN1', 'LNAME',
    'SPEED_LIMI', 'AADT', 'THROUGH_LA', 'MILES', 'geometry',
]]

In [None]:
# Normalise FIPS codes to zero-padded strings: 2-digit state, 3-digit county.
# (The state code was previously left unpadded, e.g. '1' instead of '01', which
# would not match the 2-digit exclusion codes used below.)
interstates['STFIPS'] = interstates['STFIPS'].astype(int).astype(str).str.zfill(2)
interstates['CTFIPS'] = interstates['CTFIPS'].astype(int).astype(str).str.zfill(3)
interstates.head()

In [None]:
# Re-pad the state FIPS to two digits (round-trips through int, so it is safe to re-run)
interstates['STFIPS'] = (
    interstates['STFIPS']
    .astype(int)
    .astype(str)
    .str.zfill(2)
)

In [None]:
# Contiguous US only: drop Alaska, Hawaii, Puerto Rico, territories and reserved codes
interstates = interstates.loc[~interstates['STFIPS'].isin(
    ['02', '03', '07', '14', '15', '43', '52', '60', '66', '69', '72', '78']
)]

In [None]:
# Posted speeds below 45 mph (presumably placeholders or implausible for an interstate) become 65 mph;
# missing speeds stay missing because NaN < 45 is False.
interstates['SPEED_LIMI'] = interstates['SPEED_LIMI'].mask(interstates['SPEED_LIMI'] < 45, 65)

## Graph construction

In [None]:
# Repair invalid geometries before noding. make_valid may return a different
# geometry type than its input (e.g. a collection) for degenerate lines.
interstates["geometry"] = interstates["geometry"].apply(make_valid)

In [None]:
# explode multilines so we work with LineString pieces
interstates = interstates.explode(index_parts=False).reset_index(drop=True)


In [None]:
# Snap every vertex to a 200-unit grid (metres in EPSG:5070) so nearly coincident
# endpoints land on identical coordinates before noding. NOTE(review): this is
# coarse; nearby parallel lines (e.g. divided carriageways) may collapse onto the
# same geometry and very short pieces may become degenerate.
interstates["geometry"] = interstates.geometry.apply(lambda g: set_precision(g, 200))

In [None]:
# Node the network: unary_union splits overlapping/crossing lines at their intersections.
U = unary_union(interstates.geometry)

def iter_lines(geom):
    """Yield every non-empty LineString contained in ``geom``.

    A LineString yields itself. A MultiLineString or GeometryCollection yields the
    lines of its parts, found recursively. Points, polygons, ``None`` and empty
    geometries yield nothing. The previous ``linemerge`` fallback could not build
    lines from a mixed collection (e.g. stray points possibly left by make_valid /
    set_precision), so it raised instead of skipping those parts.
    """
    if geom is None or geom.is_empty:
        return
    if isinstance(geom, LineString):
        yield geom
    elif isinstance(geom, MultiLineString) or geom.geom_type == "GeometryCollection":
        for part in geom.geoms:
            yield from iter_lines(part)

noded_lines = list(iter_lines(U))
noded = gpd.GeoDataFrame(geometry=noded_lines, crs=interstates.crs)
noded = noded.reset_index().rename(columns={"index": "nid"})  # nid = id of each noded piece

In [None]:
# Transfer interstate attributes onto the noded pieces. how="identity" keeps all of
# each noded line, split where it overlaps interstate features, and attaches the
# overlapping feature's AADT / lanes / speed (NaN where nothing overlaps). A noded
# piece covered by several source features appears once per feature.
# keep_geom_type=False also keeps non-line results (possibly points where lines only touch).
noded_attr = gpd.overlay(noded, interstates[['AADT', 'THROUGH_LA', 'SPEED_LIMI', "geometry"]],
                         how="identity", keep_geom_type=False)

In [None]:
# Length of each overlay piece, metres (EPSG:5070) -> statute miles
noded_attr['miles_piece'] = noded_attr.geometry.length * 0.000621371

In [None]:
# Collapse the overlay pieces back to one row per noded segment (nid): speed, AADT
# and lane count are length-weighted averages of the overlapping source features.
def _nan_weighted_average(values, weights):
    """Weighted mean of ``values`` that skips NaN values/weights; NaN if none are usable."""
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    usable = np.isfinite(values) & np.isfinite(weights) & (weights > 0)
    if not usable.any():
        return np.nan
    return np.average(values[usable], weights=weights[usable])

w = noded_attr["miles_piece"].clip(lower=1e-6)  # zero-length pieces (e.g. points) still get a tiny weight
agg = (noded_attr
       .assign(w=w)
       .groupby("nid")
       .apply(lambda df: pd.Series({
           "speed": _nan_weighted_average(df['SPEED_LIMI'], df["w"]),
           "AADT":  _nan_weighted_average(df['AADT'], df["w"]),
           "lanes": _nan_weighted_average(df['THROUGH_LA'], df["w"]),
       }))
       .reset_index())

noded_final = noded.merge(agg, on="nid", how="left")
# Take the segment length from the noded geometry itself. Summing the overlay
# pieces counts a stretch once per overlapping source feature (e.g. carriageways
# snapped together by set_precision), which inflated miles and travel time.
noded_final["miles_piece"] = noded_final.length * 0.000621371
noded_final["travel_min"] = (noded_final["miles_piece"] / noded_final["speed"].clip(lower=1e-6)) * 60.0


In [None]:
# Restore the NHS-style column names that the graph-building cell reads (SPEED_LIMI, MILES)
noded_final = noded_final.rename(
    columns={'speed': 'SPEED_LIMI', 'lanes': 'THROUGH_LA', 'miles_piece': 'MILES'}
)

In [None]:
counties.geom_type.value_counts()

In [None]:
# Clip every interstate piece to the county polygons it crosses, tagging each
# resulting piece with that county's 5-digit COUNTYFP. keep_geom_type=True keeps
# only results of the same (line) type as the left input.
seg_in_county = gpd.overlay(
    interstates[['AADT', 'THROUGH_LA', 'MILES', "geometry"]],
    counties[['COUNTYFP', "geometry"]],
    how="intersection",
    keep_geom_type=True,
)


In [None]:
# Cache the (expensive) county overlay as a shapefile
seg_in_county.to_file(filename='../raw data/county level data/seg_in_county.shp')

In [None]:
# seg_in_county = gpd.read_file('../raw data/state level data/NHS/highway_overlay_shp/seg_in_state.shp')
# seg_in_county.head()

In [None]:
# Clipped piece length, metres -> miles
seg_in_county['part_miles'] = seg_in_county.geometry.length * 0.000621371

In [None]:
# Highway "capacity" of each clipped piece = AADT x through lanes x miles.
# Missing AADT counts as 0, missing lanes as 1, and lengths are floored at 1e-6 miles.
seg_in_county["cap"] = (
    seg_in_county["AADT"].astype(float).fillna(0.0)
    .mul(seg_in_county["THROUGH_LA"].astype(float).fillna(1.0))
    .mul(seg_in_county["part_miles"].clip(lower=1e-6))
)

# County mass M = total capacity of all pieces inside the county
M_by_county = (
    seg_in_county
    .groupby("COUNTYFP")["cap"]
    .sum()
    .reset_index(name="M")
)


In [None]:
seg_in_county.head()

In [None]:
M_by_county

In [None]:
# Attach highway mass M to every county; counties without interstate mileage get 0.
# Dropping any existing "M" first keeps the cell idempotent; a re-run previously
# produced M_x / M_y columns and broke counties["M"].
counties = counties.drop(columns="M", errors="ignore").merge(M_by_county, on='COUNTYFP', how="left")
counties["M"] = counties["M"].fillna(0.0)


In [None]:
def iter_lines(geom):
    """Yield the non-empty LineString parts of ``geom``.

    Accepts a LineString, MultiLineString or GeometryCollection and recurses into
    the parts of the latter two. ``None``, empty geometries and non-linear parts
    (points, polygons) yield nothing.
    NOTE: this re-defines (shadows) the ``iter_lines`` from the noding cell.
    """
    if geom is None or geom.is_empty:
        return
    if isinstance(geom, LineString):
        yield geom
    elif geom.geom_type in ("MultiLineString", "GeometryCollection"):
        for part in geom.geoms:
            yield from iter_lines(part)


def coords2node(x, y, ndp=6):
    """Return ``(x, y)`` rounded to ``ndp`` decimals as a hashable graph-node key.

    Rounding collapses floating-point near-duplicates of the same vertex onto one key.
    """
    return round(float(x), ndp), round(float(y), ndp)


In [None]:
G = nx.Graph()

# Undirected routing graph: one edge per noded segment, joining its two endpoints.
# Node keys are (x, y) tuples quantised by coords2node.
for r in noded_final.itertuples(index=False):
    speed = float(r.SPEED_LIMI)
    miles = float(r.MILES)
    # Skip rows with missing / non-positive speed or length
    if not (isfinite(speed) and isfinite(miles)) or miles <= 0 or speed <= 0:
        continue
    travel_min = (miles / speed) * 60.0

    for ls in iter_lines(r.geometry):
        # Endpoints suffice for routing: segments were already split at every junction
        u = coords2node(*ls.coords[0])
        v = coords2node(*ls.coords[-1])
        # For parallel segments between the same nodes keep the fastest one. Its
        # miles stay paired with its travel time; taking the min of each attribute
        # separately could mix values from two different segments.
        if not G.has_edge(u, v) or travel_min < G[u][v]["travel_min"]:
            G.add_edge(u, v, travel_min=travel_min, miles=miles)


In [None]:
# Connected components of the road graph (each a set of nodes) and how many there are
components = [*nx.connected_components(G)]
len(components)

In [None]:
list(map(len, components))  # node count of every component

In [None]:
# Distribution of component sizes (nodes per component)
fig, ax = plt.subplots()
ax.hist([len(c) for c in components])
plt.show()

In [None]:
# Node set of the largest connected component (first one wins on ties)
largest_nodes = max(components, key=len)

In [None]:
# Keep only the largest component so every anchor node can reach every other;
# .copy() turns the read-only subgraph view into an independent Graph.
G = G.subgraph(largest_nodes).copy()

In [None]:
def graph_edges_gdf(G, crs):
    """Return a GeoDataFrame with one straight LineString per edge of ``G``.

    Node keys of ``G`` are (x, y) coordinate tuples, so each edge is drawn as the
    segment between its two endpoint nodes. Columns: u, v, travel_min, miles, geometry;
    missing edge attributes become NaN.
    """
    records = [
        {
            "u": a,
            "v": b,
            "travel_min": attrs.get("travel_min", np.nan),
            "miles": attrs.get("miles", np.nan),
            "geometry": LineString([a, b]),
        }
        for a, b, attrs in G.edges(data=True)
    ]
    return gpd.GeoDataFrame(records, geometry="geometry", crs=crs)

edges_gdf = graph_edges_gdf(G, crs=counties.crs)  # G's (x, y) node keys are in the same CRS as `counties`
edges_gdf.plot()

In [None]:
node_xy = np.asarray(list(G.nodes()))  # (n_nodes, 2) array of node coordinates, in G's node order

In [None]:
kdt = cKDTree(node_xy)

def nearest_node(pt):
    """Return the graph node (x, y) tuple closest to shapely point ``pt``."""
    _, nearest_idx = kdt.query([pt.x, pt.y])
    return tuple(node_xy[nearest_idx])

# Anchor every county on the road graph at the node nearest its centroid
counties["graph_node"] = counties["centroid"].apply(nearest_node)

In [None]:
# Sanity check: prints 'False' for every county anchor that is not a node of G
for n in counties["graph_node"]:
    if not G.has_node(n):
        print('False')

In [None]:
county_ids = counties['COUNTYFP'].to_list()  # row/column labels, in `counties` row order
n = len(counties)
T = np.full((n, n), np.inf)  # float64 travel-time matrix (minutes); inf = not yet computed

In [None]:
# Precompute single-source Dijkstra from each anchor
for i, src in enumerate(counties["graph_node"]):
    # print(src)
    dist = nx.single_source_dijkstra_path_length(G, src, weight="travel_min")
    # print(len(dist))
    # Map to destination anchors
    for j, dst in enumerate(counties["graph_node"]):
        T[i, j] = dist[dst]

In [None]:
# Clean up any zeros/diagonal
for i in range(n):
    T[i, i] = np.inf  # set to inf so weight becomes 0 on diagonal


In [None]:
T[:5, :5]

In [None]:
# Gravity-model exponents: alpha on each county's mass M, beta on travel time C
alpha, beta = 1, 2

In [None]:
M = counties["M"].to_numpy()  # county highway mass, one entry per row of `counties`

# Impedance C = travel time T (minutes); zero, negative and non-finite times become NaN
C = np.where((T > 0) & np.isfinite(T), T, np.nan)

# Gravity weights W_ij = M_i^alpha * M_j^alpha / C_ij^beta; undefined entries -> 0
W = np.power(M[:, None], alpha) * np.power(M[None, :], alpha) / np.power(C, beta)
W = np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0)
np.fill_diagonal(W, 0.0)
# W is left unnormalised here (no row-sum or max scaling).


In [None]:
masses_df = pd.DataFrame({'mass': M}, index=county_ids)  # county highway mass keyed by GEOID

In [None]:
travel_times_df = pd.DataFrame(data=T, index=county_ids, columns=county_ids)  # minutes, GEOID x GEOID

In [None]:
weights = pd.DataFrame(data=W, index=county_ids, columns=county_ids)  # gravity weights, GEOID x GEOID

In [None]:
# State FIPS prefixes present on both axes of the weight matrix
row_states = set(weights.index.str[:2])
col_states = set(weights.columns.str[:2])
states = sorted(row_states.intersection(col_states))

In [None]:
# Split the county x county weight matrix into one within-state block per state
blocks = {}
for st in states:
    in_rows = weights.index[weights.index.str.startswith(st)]
    in_cols = weights.columns[weights.columns.str.startswith(st)]
    shared = in_rows.intersection(in_cols)  # labels on both axes, in case the label sets differ
    if len(shared) > 0:
        blocks[st] = weights.loc[shared, shared]


In [None]:
# Persist {state FIPS: county weight block}
with open('../processed data/county level/county_highway_weights_by_state.pkl', 'wb') as fh:
    pickle.dump(blocks, fh, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Same per-state split, applied to the travel-time matrix
blocks_travel_times = {}
for st in states:
    in_rows = travel_times_df.index[travel_times_df.index.str.startswith(st)]
    in_cols = travel_times_df.columns[travel_times_df.columns.str.startswith(st)]
    shared = in_rows.intersection(in_cols)  # labels on both axes, in case the label sets differ
    if len(shared) > 0:
        blocks_travel_times[st] = travel_times_df.loc[shared, shared]


In [None]:
# Persist {state FIPS: county travel-time block}
with open('../processed data/county level/county_highway_travel_times_by_state.pkl', 'wb') as fh:
    pickle.dump(blocks_travel_times, fh, protocol=pickle.HIGHEST_PROTOCOL)

## Airports

Weights are $W^A_{i,j} = E_i E_j$ for $i \neq j$ (with $W^A_{i,i} = 0$), where $E_i$ is the total CY 2024 enplanement of all large, medium, and small hub airports located in county $i$ (each airport is assigned to a county by reverse-geocoding its reference point). The weights saved in this notebook are not normalized.

In [None]:
# FAA airport master data (one row per airport)
airports = pd.read_excel(io='../raw data/state level data/all-airport-data.xlsx')
airports.head()

In [None]:
airports['NPIAS Hub'].value_counts()

In [None]:
# Keep only airports whose 'NPIAS Hub' category is Large, Medium or Small.
# .copy() makes this an independent frame: later cells add columns to it, which on a
# filtered slice triggers SettingWithCopyWarning (a chained assignment).
relevant_airports = airports[airports['NPIAS Hub'].isin(['Large', 'Medium', 'Small'])].copy()

In [None]:
# Reverse-geocode each airport reference point (ARP) to its county with the Census
# geocoder. Each entry of `locs` is [county GEOID, county name], in row order.
locs = []
with requests.Session() as session:  # reuse one HTTP connection for all lookups
    for _, row in relevant_airports.iterrows():
        lat, lon = row['ARP Latitude DD'], row['ARP Longitude DD']
        response = session.get(
            "https://geocoding.geo.census.gov/geocoder/geographies/coordinates",
            params={
                "x": lon,  # the geocoder takes x = longitude, y = latitude
                "y": lat,
                "benchmark": "Public_AR_Current",
                "vintage": "Current_Current",
                "format": "json",
            },
            timeout=30,  # fail instead of hanging forever on a stalled request
        )
        response.raise_for_status()  # surface HTTP errors instead of a confusing KeyError
        county_matches = response.json()['result']['geographies']['Counties']
        if not county_matches:
            raise ValueError(
                f"No county returned for airport {row['Loc Id']} at lat={lat}, lon={lon}"
            )
        locs.append([county_matches[0]['GEOID'], county_matches[0]['NAME']])

In [None]:
# Temporarily store the [GEOID, name] pair per airport in COUNTYFP
relevant_airports = relevant_airports.assign(COUNTYFP=locs)

In [None]:
relevant_airports = relevant_airports.loc[
    :, ['Loc Id', 'COUNTYFP', 'NPIAS Hub', 'ARP Latitude DD', 'ARP Longitude DD']
]

In [None]:
# Keep only the county GEOID from each [GEOID, name] pair
relevant_airports['COUNTYFP'] = relevant_airports['COUNTYFP'].map(lambda pair: pair[0])
relevant_airports

In [None]:
# WGS84 point (lon, lat) for every airport reference point
relevant_airports['geometry'] = [
    Point(lon, lat)
    for lon, lat in zip(relevant_airports['ARP Longitude DD'], relevant_airports['ARP Latitude DD'])
]

In [None]:
# Airports as a GeoDataFrame, reprojected from WGS84 to the counties' CRS
geo_airports = gpd.GeoDataFrame(
    relevant_airports, geometry='geometry', crs='EPSG:4326'
).to_crs(crs=counties.crs)

In [None]:
geo_airports['STATEFP'] = geo_airports['COUNTYFP'].str[:2]  # state FIPS = first 2 GEOID digits

In [None]:
# Contiguous US only (drop Alaska, Hawaii, Puerto Rico, territories, reserved codes)
geo_airports = geo_airports.loc[~geo_airports['STATEFP'].isin(
    ['02', '03', '07', '14', '15', '43', '52', '60', '66', '69', '72', '78']
)]

In [None]:
# One GeoDataFrame per NPIAS hub size, for plotting
large, medium, small = (
    geo_airports[geo_airports['NPIAS Hub'] == size] for size in ('Large', 'Medium', 'Small')
)

In [None]:
# Map of retained airports by hub size over the county outlines
fig, ax = plt.subplots(figsize=(10, 10))
counties.plot(ax=ax, facecolor='grey', edgecolor='lightgray', linewidth=0.5, zorder=1)
for hub_gdf, edge, marker, size, z in (
    (large, 'red', 'H', 20, 3),
    (medium, 'black', 's', 10, 2),
    (small, 'green', '^', 5, 2),
):
    hub_gdf.plot(ax=ax, facecolor='white', edgecolor=edge, linewidth=0.5,
                 zorder=z, marker=marker, markersize=size)

# Proxy artists for the legend (legend marker sizes differ from the map marker sizes)
handles = [
    mlines.Line2D([], [], color=colour, marker=marker, markersize=ms,
                  markerfacecolor='white', label=label)
    for colour, marker, ms, label in (
        ('red', 'H', 15, 'Large Airports'),
        ('black', 's', 10, 'Medium Airports'),
        ('green', '^', 5, 'Small Airports'),
    )
]
ax.legend(handles=handles, loc='lower left')
plt.show()

In [None]:
# FAA CY 2024 enplanement counts per airport (keyed by 'Locid')
traffic = pd.read_excel(io='../raw data/state level data/ARP-cy2024-all-enplanements.xlsx')
traffic.head()

In [None]:
# Attach enplanements to each airport; airports with no record get NaN
cand_traffic = geo_airports.merge(
    traffic.loc[:, ['Locid', 'CY 24 Enplanements']],
    left_on='Loc Id', right_on='Locid', how='left',
)

In [None]:
cand_traffic = (
    cand_traffic
    .drop(columns=['ARP Latitude DD', 'ARP Longitude DD', 'Locid'])
    .rename(columns={'CY 24 Enplanements': 'enplanements'})
)

In [None]:
cand_traffic.head()

In [None]:
# Airport mass per county: total enplanements of the retained airports in each
# COUNTYFP (one row per county, column 'enplanements'). NaN enplanements
# (airports without a traffic record) are skipped by the sum.
mass = cand_traffic.pivot_table(values='enplanements', index='COUNTYFP', aggfunc='sum')

In [None]:
mass.to_csv('../processed data/county level/county_level_airport_masses.csv')  # index (COUNTYFP) written by default

In [None]:
# Gravity-style airport weights W_ij = E_i * E_j (outer product of county masses)
airweights = pd.DataFrame(np.outer(mass.to_numpy(), mass.to_numpy()))

In [None]:
np.allclose(airweights, airweights.T, atol=1e-9)

In [None]:
# Label both axes by county GEOID and zero the diagonal (no self-interaction).
# Work on an explicit NumPy copy: under pandas Copy-on-Write `.values` can be a
# read-only view, so filling it in place may fail or not reach the DataFrame.
airweights_arr = airweights.to_numpy(copy=True)
np.fill_diagonal(airweights_arr, 0)
airweights = pd.DataFrame(airweights_arr, index=mass.index, columns=mass.index)

In [None]:
# State FIPS prefixes present on both axes of the airport weight matrix
row_states = set(airweights.index.str[:2])
col_states = set(airweights.columns.str[:2])
states = sorted(row_states.intersection(col_states))

In [None]:
# Per-state within-state blocks of the airport weight matrix
blocks_airweights = {}
for st in states:
    in_rows = airweights.index[airweights.index.str.startswith(st)]
    in_cols = airweights.columns[airweights.columns.str.startswith(st)]
    shared = in_rows.intersection(in_cols)  # labels on both axes, in case the label sets differ
    if len(shared) > 0:
        blocks_airweights[st] = airweights.loc[shared, shared]

In [None]:
with open('../processed data/county level/county_airport_weights_by_state.pkl', 'wb') as fh:
    pickle.dump(blocks_airweights, fh, protocol=pickle.HIGHEST_PROTOCOL)

## Adjacency Weights

In [None]:
# Census 2025 county adjacency list (pipe-delimited); every column read as string
county_adj = pd.read_csv('../raw data/county_adjacency2025.txt', sep='|', dtype=str)

In [None]:
# Restrict the adjacency list to contiguous-US counties (drop Alaska, Hawaii, Puerto
# Rico, territories, reserved codes) and parse 'Length' as float.
# .copy() makes the filtered frame independent, so the column assignment below is
# not a chained assignment on a slice (SettingWithCopyWarning).
county_adj['state_fips'] = county_adj['County GEOID'].str[:2]
county_adj = county_adj[~county_adj['state_fips'].isin(
    ['02', '03', '07', '14', '15', '43', '52', '60', '66', '69', '72', '78']
)].copy()
county_adj['Length'] = county_adj['Length'].astype(float)
county_adj.head()

In [None]:
# County x neighbour matrix of summed 'Length' (presumably shared-border length).
border_adj_matrix = county_adj.pivot_table(
    index='County GEOID', columns='Neighbor GEOID', values='Length',
    aggfunc='sum', fill_value=0,
)
# np.fill_diagonal is positional, so first put rows and columns on the same GEOID
# labels in the same order; otherwise the "diagonal" could pair two different counties.
county_geoids = border_adj_matrix.index.union(border_adj_matrix.columns)
border_adj_matrix = border_adj_matrix.reindex(index=county_geoids, columns=county_geoids, fill_value=0)
# Zero self-adjacency on an explicit copy: `.values` may be a read-only view under
# pandas Copy-on-Write, so filling it in place is not reliable.
adj_values = border_adj_matrix.to_numpy(copy=True)
np.fill_diagonal(adj_values, 0)
border_adj_matrix = pd.DataFrame(
    adj_values,
    index=county_geoids.rename('County GEOID'),
    columns=county_geoids.rename('Neighbor GEOID'),
)
border_adj_matrix.iloc[:5, :5]

In [None]:
border_adj_matrix.shape

In [None]:
# State FIPS prefixes present on both axes of the adjacency matrix
row_states = set(border_adj_matrix.index.str[:2])
col_states = set(border_adj_matrix.columns.str[:2])
states = sorted(row_states.intersection(col_states))

In [None]:
# Per-state within-state blocks of the border-adjacency matrix
blocks_adj = {}
for st in states:
    in_rows = border_adj_matrix.index[border_adj_matrix.index.str.startswith(st)]
    in_cols = border_adj_matrix.columns[border_adj_matrix.columns.str.startswith(st)]
    shared = in_rows.intersection(in_cols)  # labels on both axes, in case the label sets differ
    if len(shared) > 0:
        blocks_adj[st] = border_adj_matrix.loc[shared, shared]

In [None]:
len(blocks_adj)

In [None]:
with open('../processed data/county level/county_adj_by_state.pkl', 'wb') as fh:
    pickle.dump(blocks_adj, fh, protocol=pickle.HIGHEST_PROTOCOL)

## County-level Variant Data

In [31]:
# State-level variant prevalences; 'location' kept as string to preserve leading zeros
prev_interp_comb = pd.read_csv(
    filepath_or_buffer='../processed data/state_level/state_level_prevalences.csv',
    dtype={'location': str},
)

In [None]:
prev_interp_comb['date'] = pd.to_datetime(prev_interp_comb['date'])
# 'Other' = 1 minus the summed named-variant columns (presumably prevalence shares)
prev_interp_comb['Other'] = (
    prev_interp_comb[['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Omicron']]
    .sum(axis=1)
    .rsub(1)
)
prev_interp_comb.head()

In [108]:
prev_interp_comb.shape

In [144]:
# Yearly county case files; 'fips' forced to string in case the column is present
cases_2021, cases_2022 = (
    pd.read_csv(path, dtype={'fips': str})
    for path in (
        '../raw data/county level data/us-counties-2021.csv',
        '../raw data/county level data/us-counties-2022.csv',
    )
)

In [145]:
cases = pd.concat([cases_2021, cases_2022])  # stack the two years row-wise

In [146]:
cases.columns

In [147]:
# The aggregate 'New York City' rows are mapped onto New York County's GEOID
cases.loc[cases['state'].eq('New York') & cases['county'].eq('New York City'), 'geoid'] = 'USA-36061'

In [148]:
cases['date'] = pd.to_datetime(cases['date'])
cases = cases.set_index('date')  # daily DatetimeIndex, needed for the per-county reindex below

In [None]:
# 5-digit county FIPS from geoid strings of the form 'USA-SSCCC'. The vectorised
# .str accessor yields NaN for a missing geoid, where x[4:] raised TypeError.
cases['fips'] = cases['geoid'].str[4:]
cases.head()

In [159]:
# Build a complete daily panel per county: within each fips group, reindex the
# date index onto every calendar day 2021-01-01 .. 2022-12-31 (days absent from
# the data become NaN rows). NOTE(review): reindex raises if a county has
# duplicate dates, e.g. overlapping rows between the two yearly files.
daily_cases = (
    cases
    .groupby('fips')
    .apply(lambda g: g.reindex(pd.date_range('2021-01-01', '2022-12-31')))   # force same daily index
)
daily_cases.index = daily_cases.index.set_names(['fips', 'date'])
# Missing days are treated as 0 cases; identifier and death/average columns are
# dropped (presumably leaving only 'cases' — confirm against the CSV schema).
daily_cases = daily_cases.fillna(0).drop(['fips', 'geoid', 'county', 'state', 'cases_avg', 'cases_avg_per_100k', 'deaths', 'deaths_avg', 'deaths_avg_per_100k'], axis=1)

# Move (fips, date) back into ordinary columns
daily_cases = daily_cases.reset_index()


In [160]:
daily_cases['state'] = daily_cases['fips'].str[:2]  # state FIPS prefix of the county code

In [161]:
# Contiguous US only (drop Alaska, Hawaii, Puerto Rico, territories, reserved codes)
daily_cases = daily_cases.loc[~daily_cases['state'].isin(
    ['02', '03', '07', '14', '15', '43', '52', '60', '66', '69', '72', '78']
)]

In [162]:
daily_cases.head()

In [163]:
# Attach the state-level variant prevalences for the same state and day
cases_prev = pd.merge(
    daily_cases, prev_interp_comb,
    how='inner', left_on=['state', 'date'], right_on=['location', 'date'],
)
cases_prev.head()

In [203]:
comb = cases_prev.copy(deep=True)  # working copy; cases_prev stays untouched

In [204]:
# Variant-specific cases = county daily cases x state-level variant prevalence
comb[['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Omicron', 'Other']] = (
    cases_prev[['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Omicron', 'Other']]
    .multiply(cases_prev['cases'], axis='index')
)

In [205]:
comb = comb.drop(columns=['cases', 'location'])  # totals and the duplicate state key are no longer needed

In [206]:
comb.to_parquet(path='../processed data/county level/daily_cases_by_county.parquet')

In [207]:
states = comb['state'].unique()  # state FIPS codes present in the case panel

In [208]:
comb.head()

In [209]:
# 7-row trailing mean of every variant column, computed separately per county.
# The variant columns are named explicitly: selecting them by position
# (comb.columns[3:]) would silently roll any extra column the prevalence file carries.
# NOTE(review): rolling counts rows, so this is a 7-day window only if every county
# has one row per consecutive day after the inner merge — confirm.
variant_cols = ['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Omicron', 'Other']
comb = comb.sort_values(['fips', 'date'])
comb[variant_cols] = (
    comb.groupby('fips')[variant_cols]
    .transform(lambda s: s.rolling(7, min_periods=1).mean())
)
comb.head()

In [210]:
# {state FIPS: all county rows of that state}
blocks_cases = {st: comb[comb['fips'].str.startswith(st)] for st in states}

In [213]:
with open('../processed data/county level/rolled_county_cases.pkl', 'wb') as fh:
    pickle.dump(blocks_cases, fh, protocol=pickle.HIGHEST_PROTOCOL)