In [None]:
%load_ext autoreload
%autoreload 2

import geopandas as gpd
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd

from typing import Final
from pathlib import Path

# Cache directory for downloaded map data. Path.resolve() does not expand "~"
# (it would yield "<cwd>/~/.cache/us_maps"), so expanduser() must run first.
CACHE_DIR: Final[Path] = Path('~/.cache/us_maps').expanduser().resolve()
CACHE_DIR.mkdir(exist_ok=True, parents=True)


_US_STATES_GEOJSON_URL: Final[str] = (
    'https://raw.githubusercontent.com/jgoodall/us-maps/master/geojson/state.geo.json'
)

# Border colour, legend label and legend-box vertical position (axes fraction)
# for each recognised value of the optional set column.
_SET_BORDER_STYLES: Final[dict[str, tuple[str, str, float]]] = {
    'backgnd': ('blue', 'background', 0.24),
    'trial': ('green', 'trial', 0.17),
    'val': ('orange', 'validate', 0.1),
}
_SET_BORDER_LINEWIDTH: Final[float] = 2.5


def _cached_us_states_geojson(cache_dir: Path) -> Path:
    """Return the path of the cached U.S. states GeoJSON, downloading it on a cache miss.

    The download is written to a sibling ``.part`` file and then renamed into place,
    so an interrupted download never leaves a truncated file that later calls would
    treat as a valid cache hit. ``cache_dir`` is created if it does not exist.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    geojson_path = cache_dir / 'us_states.geojson'
    if geojson_path.exists():
        return geojson_path

    import requests  # only needed when the cache is empty

    print("Downloading U.S. states GeoJSON...")
    response = requests.get(_US_STATES_GEOJSON_URL, timeout=60)
    response.raise_for_status()
    partial_path = geojson_path.with_name(geojson_path.name + '.part')
    partial_path.write_bytes(response.content)
    partial_path.replace(geojson_path)
    print("Done downloading")
    return geojson_path


def plot_us_state_choropleth(
    df: pd.DataFrame,
    value_col: str,
    cmap: str = 'OrRd',
    set_col: str | None = None,
    cache_dir: Path = CACHE_DIR,
    out_path: Path | None = None,
):
    """
    Plot a U.S. state choropleth map, optionally outlining states by set membership.

    Parameters:
    - df: DataFrame indexed by state name (any casing; normalized with str.title())
      that contains the column to plot.
    - value_col: Name of the column used for the fill colour.
    - cmap: Matplotlib colormap name (default: 'OrRd').
    - set_col: Optional column whose values 'backgnd', 'trial' or 'val' select a
      blue, green or orange border respectively; other values get no highlighted
      border. Ignored if the column is absent.
    - cache_dir: Directory where the states GeoJSON is cached (created if missing).
    - out_path: Optional output path (pdf, png, etc) for the figure.

    States present in the GeoJSON but missing from df end up with NaN values after
    the join.
    """
    usa_states = gpd.read_file(_cached_us_states_geojson(cache_dir))

    # Normalize index casing so e.g. "new york" matches the GeoJSON's "New York".
    df = df.copy()
    df.index = df.index.str.title()

    # Left join on state name: every state polygon is kept.
    merged = usa_states.rename(columns={'NAME10': 'state'}).set_index('state').join(df)

    fig, ax = plt.subplots(1, 1, figsize=(15, 10))

    # One normalization shared by the map fill and the colorbar, so they cannot disagree.
    norm = mpl.colors.Normalize(vmin=merged[value_col].min(), vmax=merged[value_col].max())
    merged.plot(
        column=value_col,
        cmap=cmap,
        vmin=norm.vmin,
        vmax=norm.vmax,
        linewidth=0.8,
        ax=ax,
        edgecolor='0.8',
        legend=False,
    )
    # Smaller-than-default colorbar.
    fig.colorbar(mpl.cm.ScalarMappable(cmap=cmap, norm=norm), ax=ax, fraction=0.025, pad=0.04)

    if set_col and set_col in merged.columns:
        for set_value, (color, _label, _y_pos) in _SET_BORDER_STYLES.items():
            subset = merged[merged[set_col] == set_value]
            if not subset.empty:
                subset.boundary.plot(ax=ax, edgecolor=color, linewidth=_SET_BORDER_LINEWIDTH)

        # Legend: one framed text box per set, stacked in the lower-left of the axes.
        for color, label, y_pos in _SET_BORDER_STYLES.values():
            ax.text(
                0.1,
                y_pos,
                label,
                transform=ax.transAxes,
                fontsize=14,
                bbox=dict(
                    facecolor="white",
                    edgecolor=color,
                    boxstyle="round,pad=0.3",
                    alpha=0.8,
                    linewidth=_SET_BORDER_LINEWIDTH,
                ),
            )

    ax.set_title(f'U.S. States Colored by {value_col}', fontsize=16)
    ax.axis('off')
    if out_path:
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.show()

In [None]:
# Adjacency table of the 50 U.S. states, keyed by title-cased full state name.
# Each value lists the states sharing a border with the key. The relation is
# symmetric (if A lists B, B lists A) and no state lists itself.
# Alaska and Hawaii have no neighbours; the District of Columbia is not a key.
# Pairs that meet only at a point (Four Corners: Arizona-Colorado, New Mexico-Utah)
# or only across a lake (Michigan-Illinois, Michigan-Minnesota) are counted as borders.
state_borders = {
        'Alabama': ['Florida', 'Georgia', 'Mississippi', 'Tennessee'],
        'Alaska': [],
        'Arizona': ['California', 'Colorado', 'Nevada', 'New Mexico', 'Utah'],
        'Arkansas': ['Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'Tennessee', 'Texas'],
        'California': ['Arizona', 'Nevada', 'Oregon'],
        'Colorado': ['Arizona', 'Kansas', 'Nebraska', 'New Mexico', 'Oklahoma', 'Utah', 'Wyoming'],
        'Connecticut': ['Massachusetts', 'New York', 'Rhode Island'],
        'Delaware': ['Maryland', 'New Jersey', 'Pennsylvania'],
        'Florida': ['Alabama', 'Georgia'],
        'Georgia': ['Alabama', 'Florida', 'North Carolina', 'South Carolina', 'Tennessee'],
        'Hawaii': [],
        'Idaho': ['Montana', 'Nevada', 'Oregon', 'Utah', 'Washington', 'Wyoming'],
        'Illinois': ['Indiana', 'Iowa', 'Michigan', 'Kentucky', 'Missouri', 'Wisconsin'],
        'Indiana': ['Illinois', 'Kentucky', 'Michigan', 'Ohio'],
        'Iowa': ['Illinois', 'Minnesota', 'Missouri', 'Nebraska', 'South Dakota', 'Wisconsin'],
        'Kansas': ['Colorado', 'Missouri', 'Nebraska', 'Oklahoma'],
        'Kentucky': ['Illinois', 'Indiana', 'Missouri', 'Ohio', 'Tennessee', 'Virginia', 'West Virginia'],
        'Louisiana': ['Arkansas', 'Mississippi', 'Texas'],
        'Maine': ['New Hampshire'],
        'Maryland': ['Delaware', 'Pennsylvania', 'Virginia', 'West Virginia'],
        'Massachusetts': ['Connecticut', 'New Hampshire', 'New York', 'Rhode Island', 'Vermont'],
        'Michigan': ['Illinois', 'Indiana', 'Minnesota', 'Ohio', 'Wisconsin'],
        'Minnesota': ['Iowa', 'Michigan', 'North Dakota', 'South Dakota', 'Wisconsin'],
        'Mississippi': ['Alabama', 'Arkansas', 'Louisiana', 'Tennessee'],
        'Missouri': ['Arkansas', 'Illinois', 'Iowa', 'Kansas', 'Kentucky', 'Nebraska', 'Oklahoma', 'Tennessee'],
        'Montana': ['Idaho', 'North Dakota', 'South Dakota', 'Wyoming'],
        'Nebraska': ['Colorado', 'Iowa', 'Kansas', 'Missouri', 'South Dakota', 'Wyoming'],
        'Nevada': ['Arizona', 'California', 'Idaho', 'Oregon', 'Utah'],
        'New Hampshire': ['Maine', 'Massachusetts', 'Vermont'],
        'New Jersey': ['Delaware', 'New York', 'Pennsylvania'],
        'New Mexico': ['Arizona', 'Colorado', 'Oklahoma', 'Texas', 'Utah'],
        'New York': ['Connecticut', 'Massachusetts', 'New Jersey', 'Pennsylvania', 'Vermont'],
        'North Carolina': ['Georgia', 'South Carolina', 'Tennessee', 'Virginia'],
        'North Dakota': ['Minnesota', 'Montana', 'South Dakota'],
        'Ohio': ['Indiana', 'Kentucky', 'Michigan', 'Pennsylvania', 'West Virginia'],
        'Oklahoma': ['Arkansas', 'Colorado', 'Kansas', 'Missouri', 'New Mexico', 'Texas'],
        'Oregon': ['California', 'Idaho', 'Nevada', 'Washington'],
        'Pennsylvania': ['Delaware', 'Maryland', 'New Jersey', 'New York', 'Ohio', 'West Virginia'],
        'Rhode Island': ['Connecticut', 'Massachusetts'],
        'South Carolina': ['Georgia', 'North Carolina'],
        'South Dakota': ['Iowa', 'Minnesota', 'Montana', 'Nebraska', 'North Dakota', 'Wyoming'],
        'Tennessee': ['Alabama', 'Arkansas', 'Georgia', 'Kentucky', 'Mississippi', 'Missouri', 'North Carolina', 'Virginia'],
        'Texas': ['Arkansas', 'Louisiana', 'New Mexico', 'Oklahoma'],
        'Utah': ['Arizona', 'Colorado', 'Idaho', 'Nevada', 'New Mexico', 'Wyoming'],
        'Vermont': ['Massachusetts', 'New Hampshire', 'New York'],
        'Virginia': ['Kentucky', 'Maryland', 'North Carolina', 'Tennessee', 'West Virginia'],
        'Washington': ['Idaho', 'Oregon'],
        'West Virginia': ['Kentucky', 'Maryland', 'Ohio', 'Pennsylvania', 'Virginia'],
        'Wisconsin': ['Illinois', 'Iowa', 'Michigan', 'Minnesota'],
        'Wyoming': ['Colorado', 'Idaho', 'Montana', 'Nebraska', 'South Dakota', 'Utah']
}

def share_common_border_us_states(s1: str, s2: str) -> bool:
    # Normalize inputs
    s1 = s1.strip().title()
    s2 = s2.strip().title()

    if s1 not in state_borders or s2 not in state_borders:
        raise ValueError(f"State '{s1}' or '{s2}' is not a valid U.S. state")

    return s2 in state_borders[s1]

In [None]:
import kagglehub

# Version-pinned Kaggle dataset. kagglehub caches downloads locally and returns the
# cached directory on later calls, so this replaces the hard-coded per-user path
# without re-downloading on every run.
DATASET_HANDLE: Final[str] = "justin2028/unemployment-in-america-per-us-state/versions/3"
path = Path(kagglehub.dataset_download(DATASET_HANDLE))
print("Path to dataset files:", path)

In [None]:
# Pick the CSV deterministically (glob order is filesystem-dependent) and fail with a
# clear message instead of an IndexError if the dataset directory has no CSV.
csv_files = sorted(Path(path).glob("*.csv"))
if not csv_files:
    raise FileNotFoundError(f"No CSV file found in {path}")
data = pd.read_csv(csv_files[0])
example_year: Final[int] = 1976
example_month: Final[int] = 1

# Keep only the contiguous states: rows whose "State/Area" is a key of state_borders,
# excluding Alaska and Hawaii. The state name becomes the index and stays a column.
contiguous_states = set(state_borders) - {"Alaska", "Hawaii"}
filtered_data = data[data["State/Area"].isin(contiguous_states)].set_index("State/Area", drop=False)

x_cols = [
    "Total Civilian Labor Force in State/Area",
    "Total Civilian Non-Institutional Population in State/Area",
    "Percent (%) of State/Area's Population",
]
y_col = "Percent (%) of Labor Force Employed in State/Area"

# The two head-count columns carry "," thousands separators; strip them and parse
# as numbers. astype(str) keeps this working if a column is already numeric.
count_cols = x_cols[:2]
filtered_data[count_cols] = filtered_data[count_cols].apply(
    lambda col: pd.to_numeric(col.astype(str).str.replace(",", "", regex=False), errors="coerce")
)


def subset_month(year: int, month: int) -> pd.DataFrame:
    """Return the rows of ``filtered_data`` for one (year, month), in their original order."""
    in_period = (filtered_data["Year"] == year) & (filtered_data["Month"] == month)
    return filtered_data[in_period]



# .copy() decouples the example month from filtered_data: subset_month returns a
# boolean-indexed slice, and a later cell adds a column to this frame, which would
# otherwise trigger SettingWithCopyWarning.
data_month = subset_month(year=example_year, month=example_month).copy()
assert not data_month.empty, f"No rows for {example_year}-{example_month:02d}"
data_month.head()

In [None]:
# Map of the target (percent of labor force employed) for the example month.
plot_us_state_choropleth(df=data_month, value_col=y_col)

In [None]:
import seaborn as sns 

# Short, plot-friendly names for the feature (x) and target (y) columns.
x_cols_short = ["Labor Force", "Population", "Percent Eligible"]
y_col_short = "Percent Employed"
cols_short = [*x_cols_short, y_col_short]
renamed_data = data_month.rename(columns=dict(zip([*x_cols, y_col], cols_short)))
# Optional exploratory pair plot (uses seaborn, imported above):
# sns.pairplot(renamed_data, vars=cols_short)
# plt.savefig("employment_pairplot.pdf")

The first two features (labor force and population) are almost perfectly correlated, so we drop the first one (labor force).

In [None]:
# Drop the first feature (labor force), which is nearly collinear with population.
x_cols_reduced, x_cols_short_reduced = x_cols[1:], x_cols_short[1:]
# sns.scatterplot(data_month, x=x_cols_reduced[0], y=x_cols_reduced[1], hue=y_col)

In [None]:
from itertools import product
from collections import defaultdict
from typing import Final
import numpy as np
import pandas as pd

# # Northwestern US states
# query_states = [
#     "Washington", "Oregon", "Idaho", "Montana", "Wyoming",
#     "North Dakota", "South Dakota", "Nebraska", "Minnesota", "Iowa",
#     "Colorado", "Utah", "Nevada", "Kansas", "Missouri"
# ]

# # Eastern US states
# val_states = [
#     "Maine", "New Hampshire", "Vermont", "Massachusetts", "Connecticut",
#     "New York", "New Jersey", "Pennsylvania", "Delaware", "Maryland",
#     "Rhode Island", "Virginia", "North Carolina", "South Carolina", "Georgia"
# ]

np.random.seed(42)
# Sort before shuffling: the iteration order of a set of strings depends on
# PYTHONHASHSEED, so shuffling list(set(...)) would produce a different split in every
# new interpreter session despite the fixed NumPy seed.
states_shuffled = sorted(set(filtered_data.index))
np.random.shuffle(states_shuffled)
n_query = n_val = len(states_shuffled) // 3

query_states = states_shuffled[:n_query]
val_states = states_shuffled[n_query : n_query + n_val]


set_col: Final[str] = "set"
# Write the new column into an explicit copy rather than a filtered slice of
# filtered_data (avoids SettingWithCopyWarning). Re-running this cell is idempotent.
data_month = data_month.copy()
data_month[set_col] = "backgnd"
data_month.loc[query_states, set_col] = "trial"
data_month.loc[val_states, set_col] = "val"

# plot_us_state_choropleth(data_month, value_col=y_col, set_col=set_col, out_path=Path(f"employment_{example_year}_{example_month}.png"))

In [None]:
from tqdm import tqdm
from tabrel.benchmark.nw_regr import MlpConfig, run_training
from tabrel.train import train_relnet



def compute_relation_matrix(states: list[str]) -> np.ndarray:
    """Build the float adjacency matrix R for ``states``.

    R[i, j] is 1.0 when states[i] and states[j] share a border (per
    share_common_border_us_states) and 0.0 otherwise; rows and columns follow the
    order of ``states``.
    """
    size = len(states)
    relation = np.zeros((size, size))
    for row, state_a in enumerate(states):
        for col, state_b in enumerate(states):
            if share_common_border_us_states(state_a, state_b):
                relation[row, col] = 1
    return relation

# Border adjacency for the example month, in data_month's row order.
r_month = compute_relation_matrix(data_month.index.tolist())

## Pairwise Y distances

In [None]:
from tabrel.utils.plot import calc_diffs

# Target values (percent of labor force employed) for the example month, in the
# same row order used to build r_month.
y_month = data_month[y_col].to_numpy()
# NOTE(review): calc_diffs comes from tabrel.utils.plot; the meaning of the positional
# arguments .5 and True is defined there — consider passing them by keyword.
calc_diffs(y_month, r_month, .5, True)

## Training

In [None]:

LR: Final[float] = 0.006
N_EPOCHS: Final[int] = 5000
# Hyperparameters of the MLP passed to run_training.
MLP_HIDDEN_DIM: Final[int] = 10
MLP_OUT_DIM: Final[int] = 20
MLP_DROPOUT: Final[float] = 0.19
# Optional RelNet baseline (previously commented out inside the loop); off by default.
RUN_RELNET: Final[bool] = False
RELNET_LR: Final[float] = 0.003
RELNET_N_EPOCHS: Final[int] = 5000


def train_relnet_shorthand(
    x: np.ndarray,
    y: np.ndarray,
    r: np.ndarray,
    backgnd_indices: np.ndarray,
    query_indices: np.ndarray,
    val_indices: np.ndarray,
    n_layers: int,
) -> tuple[float, float]:
    """Train a RelNet on one month of data; return the first two values of train_relnet's result."""
    return train_relnet(
        x=x,
        y=y,
        r=r,
        # NOTE(review): background and query roles are swapped here relative to the
        # run_training call below; kept as in the original experiment — confirm intended.
        backgnd_indices=query_indices,
        query_indices=backgnd_indices,
        val_indices=val_indices,
        lr=RELNET_LR,
        n_epochs=RELNET_N_EPOCHS,
        n_layers=n_layers,
        progress_bar=False,
        print_loss=False,
    )[:2]


results = defaultdict(list)

for year in tqdm(filtered_data["Year"].unique()):
    for month in range(1, 13):
        # Loop-local name, so the example-month `data_month` (with its set column)
        # is not silently replaced by whichever month runs last.
        month_df: pd.DataFrame = subset_month(year, month)
        if month_df.empty:
            continue

        # Row positions of the fixed query/validation states in this month; skip
        # months in which any of them is missing (get_indexer returns -1).
        query_indices = month_df.index.get_indexer(query_states)
        val_indices = month_df.index.get_indexer(val_states)
        if (query_indices == -1).any() or (val_indices == -1).any():
            continue

        # Every row that is neither query nor validation is background.
        backgnd_indices = np.setdiff1d(
            np.arange(len(month_df)), np.concatenate([query_indices, val_indices])
        )

        x = month_df[x_cols_reduced].to_numpy()
        y = month_df[y_col].to_numpy()
        r = compute_relation_matrix(list(month_df.index))

        month_results = run_training(
            x=x,
            y=y,
            r=r,
            backgnd_indices=backgnd_indices,
            query_indices=query_indices,
            val_indices=val_indices,
            lr=LR,
            n_epochs=N_EPOCHS,
            rel_as_feats=r,
            mlp_config=MlpConfig(
                in_dim=x.shape[1],
                hidden_dim=MLP_HIDDEN_DIM,
                out_dim=MLP_OUT_DIM,
                dropout=MLP_DROPOUT,
            ),
        )
        for k, v in month_results.items():
            results[k].append(v)

        if RUN_RELNET:
            relnet_results = train_relnet_shorthand(
                x, y, r, backgnd_indices, query_indices, val_indices, n_layers=2
            )
            # Fall back to a single layer when the first returned value exceeds 100.
            if relnet_results[0] > 100:
                relnet_results = train_relnet_shorthand(
                    x, y, r, backgnd_indices, query_indices, val_indices, n_layers=1
                )
            results["relnet"].append([*relnet_results, None])


In [None]:
from tabrel.utils.misc import to_df


results_df = to_df(results, decimal_places=3)
# results_df.to_csv("usUnemployment_results_firstMonth_ave.csv")
results_df