In [None]:
import geopandas as gpd
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd

from typing import Final
from pathlib import Path

# Per-user cache directory for downloaded map data. `expanduser()` is required:
# `resolve()` alone does not expand "~", so it would create a literal "~"
# directory under the current working directory instead of under $HOME.
CACHE_DIR: Final[Path] = Path('~/.cache/us_maps').expanduser().resolve()
CACHE_DIR.mkdir(exist_ok=True, parents=True)


def plot_us_state_choropleth(
    df: pd.DataFrame,
    value_col: str,
    cmap: str = 'OrRd',
    set_col: str | None = None,
    cache_dir: Path = CACHE_DIR,
    out_path: Path | None = None,
):
    """
    Plot a U.S. state choropleth map, optionally outlining states by data split.

    Parameters:
    - df: DataFrame indexed by state name (any casing; the index is title-cased
      before joining) with a numeric statistics column.
    - value_col: Name of the column used to colour each state.
    - cmap: Matplotlib colormap to use (default: 'OrRd').
    - set_col: Optional column whose values 'train', 'test' or 'val' select the
      colour of an extra border drawn around each state.
    - cache_dir: Directory where the state GeoJSON is cached (created if missing).
    - out_path: Optional output path (pdf, png, etc) to save the figure to.
    """
    geojson_path = cache_dir / 'us_states.geojson'

    if not geojson_path.exists():
        import requests  # only needed for the first, uncached call

        url = 'https://raw.githubusercontent.com/jgoodall/us-maps/master/geojson/state.geo.json'
        print("Downloading U.S. states GeoJSON...")
        # Without a timeout a stalled connection would hang the call forever.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        # A caller-supplied cache_dir may not exist yet.
        cache_dir.mkdir(parents=True, exist_ok=True)
        # Write to a temporary file and rename, so an interrupted write never
        # leaves a truncated file that later calls would treat as a valid cache.
        tmp_path = geojson_path.with_suffix('.tmp')
        tmp_path.write_bytes(response.content)
        tmp_path.replace(geojson_path)
        print("Done downloading")

    usa_states = gpd.read_file(geojson_path)

    # Title-case the index so e.g. 'NEW YORK' / 'new york' match 'New York'.
    df = df.copy()
    df.index = df.index.str.title()

    # The GeoJSON stores the state name in 'NAME10'; join the statistics on it.
    usa_states = usa_states.rename(columns={'NAME10': 'state'})
    merged = usa_states.set_index('state').join(df)

    fig, ax = plt.subplots(1, 1, figsize=(15, 10))

    # One normalization shared by the fill and the colorbar, so they always agree.
    norm = mpl.colors.Normalize(vmin=merged[value_col].min(), vmax=merged[value_col].max())

    # Draw the states without geopandas' built-in legend; a slimmer colorbar
    # is attached manually below.
    merged.plot(
        column=value_col,
        cmap=cmap,
        vmin=norm.vmin,
        vmax=norm.vmax,
        linewidth=0.8,
        ax=ax,
        edgecolor='0.8',
        legend=False,
    )

    sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    fig.colorbar(sm, ax=ax, fraction=0.025, pad=0.04)
    linewidth: Final[float] = 2.5

    if set_col and set_col in merged.columns:
        # split value -> (border colour, legend label, legend box y position)
        set_styles = {
            'train': ('blue', 'train', 0.24),
            'test': ('green', 'test', 0.17),
            'val': ('orange', 'validate', 0.1),
        }

        for set_value, (color, label, y_pos) in set_styles.items():
            subset = merged[merged[set_col] == set_value]
            if not subset.empty:
                subset.boundary.plot(ax=ax, edgecolor=color, linewidth=linewidth)

            # Legend entry: a text box whose frame uses the split's border colour.
            ax.text(
                0.1,
                y_pos,
                label,
                transform=ax.transAxes,
                fontsize=14,
                bbox=dict(facecolor="white", edgecolor=color, boxstyle="round,pad=0.3", alpha=0.8, linewidth=linewidth),
            )

    ax.set_title(f'U.S. States Colored by {value_col}', fontsize=16)
    ax.axis('off')
    if out_path:
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.show()

In [None]:
# Adjacency list of U.S. states: each title-cased state name maps to the list of
# states it borders. The relation is symmetric (if B is listed under A, A is
# listed under B). Alaska and Hawaii have no neighbours; the District of
# Columbia is not included. Pairs that meet only at the Four Corners point
# (Arizona-Colorado, Utah-New Mexico) are counted as neighbours.
state_borders = {
        'Alabama': ['Florida', 'Georgia', 'Mississippi', 'Tennessee'],
        'Alaska': [],
        'Arizona': ['California', 'Colorado', 'Nevada', 'New Mexico', 'Utah'],
        'Arkansas': ['Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'Tennessee', 'Texas'],
        'California': ['Arizona', 'Nevada', 'Oregon'],
        'Colorado': ['Arizona', 'Kansas', 'Nebraska', 'New Mexico', 'Oklahoma', 'Utah', 'Wyoming'],
        'Connecticut': ['Massachusetts', 'New York', 'Rhode Island'],
        'Delaware': ['Maryland', 'New Jersey', 'Pennsylvania'],
        'Florida': ['Alabama', 'Georgia'],
        'Georgia': ['Alabama', 'Florida', 'North Carolina', 'South Carolina', 'Tennessee'],
        'Hawaii': [],
        'Idaho': ['Montana', 'Nevada', 'Oregon', 'Utah', 'Washington', 'Wyoming'],
        'Illinois': ['Indiana', 'Iowa', 'Michigan', 'Kentucky', 'Missouri', 'Wisconsin'],
        'Indiana': ['Illinois', 'Kentucky', 'Michigan', 'Ohio'],
        'Iowa': ['Illinois', 'Minnesota', 'Missouri', 'Nebraska', 'South Dakota', 'Wisconsin'],
        'Kansas': ['Colorado', 'Missouri', 'Nebraska', 'Oklahoma'],
        'Kentucky': ['Illinois', 'Indiana', 'Missouri', 'Ohio', 'Tennessee', 'Virginia', 'West Virginia'],
        'Louisiana': ['Arkansas', 'Mississippi', 'Texas'],
        'Maine': ['New Hampshire'],
        'Maryland': ['Delaware', 'Pennsylvania', 'Virginia', 'West Virginia'],
        'Massachusetts': ['Connecticut', 'New Hampshire', 'New York', 'Rhode Island', 'Vermont'],
        'Michigan': ['Illinois', 'Indiana', 'Minnesota', 'Ohio', 'Wisconsin'],
        'Minnesota': ['Iowa', 'Michigan', 'North Dakota', 'South Dakota', 'Wisconsin'],
        'Mississippi': ['Alabama', 'Arkansas', 'Louisiana', 'Tennessee'],
        'Missouri': ['Arkansas', 'Illinois', 'Iowa', 'Kansas', 'Kentucky', 'Nebraska', 'Oklahoma', 'Tennessee'],
        'Montana': ['Idaho', 'North Dakota', 'South Dakota', 'Wyoming'],
        'Nebraska': ['Colorado', 'Iowa', 'Kansas', 'Missouri', 'South Dakota', 'Wyoming'],
        'Nevada': ['Arizona', 'California', 'Idaho', 'Oregon', 'Utah'],
        'New Hampshire': ['Maine', 'Massachusetts', 'Vermont'],
        'New Jersey': ['Delaware', 'New York', 'Pennsylvania'],
        'New Mexico': ['Arizona', 'Colorado', 'Oklahoma', 'Texas', 'Utah'],
        'New York': ['Connecticut', 'Massachusetts', 'New Jersey', 'Pennsylvania', 'Vermont'],
        'North Carolina': ['Georgia', 'South Carolina', 'Tennessee', 'Virginia'],
        'North Dakota': ['Minnesota', 'Montana', 'South Dakota'],
        'Ohio': ['Indiana', 'Kentucky', 'Michigan', 'Pennsylvania', 'West Virginia'],
        'Oklahoma': ['Arkansas', 'Colorado', 'Kansas', 'Missouri', 'New Mexico', 'Texas'],
        'Oregon': ['California', 'Idaho', 'Nevada', 'Washington'],
        'Pennsylvania': ['Delaware', 'Maryland', 'New Jersey', 'New York', 'Ohio', 'West Virginia'],
        'Rhode Island': ['Connecticut', 'Massachusetts'],
        'South Carolina': ['Georgia', 'North Carolina'],
        'South Dakota': ['Iowa', 'Minnesota', 'Montana', 'Nebraska', 'North Dakota', 'Wyoming'],
        'Tennessee': ['Alabama', 'Arkansas', 'Georgia', 'Kentucky', 'Mississippi', 'Missouri', 'North Carolina', 'Virginia'],
        'Texas': ['Arkansas', 'Louisiana', 'New Mexico', 'Oklahoma'],
        'Utah': ['Arizona', 'Colorado', 'Idaho', 'Nevada', 'New Mexico', 'Wyoming'],
        'Vermont': ['Massachusetts', 'New Hampshire', 'New York'],
        'Virginia': ['Kentucky', 'Maryland', 'North Carolina', 'Tennessee', 'West Virginia'],
        'Washington': ['Idaho', 'Oregon'],
        'West Virginia': ['Kentucky', 'Maryland', 'Ohio', 'Pennsylvania', 'Virginia'],
        'Wisconsin': ['Illinois', 'Iowa', 'Michigan', 'Minnesota'],
        'Wyoming': ['Colorado', 'Idaho', 'Montana', 'Nebraska', 'South Dakota', 'Utah']
}

def share_common_border_us_states(s1: str, s2: str) -> bool:
    # Normalize inputs
    s1 = s1.strip().title()
    s2 = s2.strip().title()

    if s1 not in state_borders or s2 not in state_borders:
        raise ValueError(f"State '{s1}' or '{s2}' is not a valid U.S. state")

    return s2 in state_borders[s1]

In [None]:
import kagglehub

# Fetch the latest version of the Kaggle dataset; `path` is the local
# directory that holds its files (globbed for CSVs in the next cell).
path = kagglehub.dataset_download(
    "justin2028/unemployment-in-america-per-us-state",
)
print("Path to dataset files:", path)

In [None]:
# Load the dataset's CSV. Sort the matches so the chosen file is deterministic,
# and fail with a clear message (instead of an IndexError) if there is none.
_csv_files = sorted(Path(path).glob("*.csv"))
if not _csv_files:
    raise FileNotFoundError(f"No CSV file found in {path}")
data = pd.read_csv(_csv_files[0])
example_year: Final[int] = 1976
example_month: Final[int] = 1

# Keep rows for states listed in state_borders, minus Alaska and Hawaii (i.e.
# the contiguous states). `.copy()` makes this an independent frame, so the
# column assignment below does not operate on a slice of `data`.
filtered_data = data[
    data["State/Area"].isin(state_borders.keys())
    & ~data["State/Area"].isin(["Alaska", "Hawaii"])
].copy()
filtered_data = filtered_data.set_index(filtered_data["State/Area"])

x_cols = ["Total Civilian Labor Force in State/Area",
          "Total Civilian Non-Institutional Population in State/Area",
          "Percent (%) of State/Area's Population"
          ]
y_col = "Percent (%) of Labor Force Employed in State/Area"
# The two count columns contain comma thousands separators ("1,234,567"):
# strip the commas and parse as numbers; unparseable entries become NaN.
# `astype(str)` keeps this working if a column is already numeric.
filtered_data[x_cols[:2]] = filtered_data[x_cols[:2]].apply(
    lambda col: pd.to_numeric(col.astype(str).str.replace(",", ""), errors="coerce")
)


def subset_month(year: int, month: int) -> pd.DataFrame:
    data_year = filtered_data[filtered_data["Year"] == year]
    return data_year[data_year["Month"] == month]



# Rows for the example month; shown as the cell output.
data_month = subset_month(example_year, example_month)
data_month

In [None]:
plot_us_state_choropleth(df=data_month, value_col=y_col)  # employment map for the example month

In [None]:
import seaborn as sns 

# Swap the long feature names for their positions (0, 1, 2) so the pairplot
# axis labels stay readable.
x_cols_short = [*range(len(x_cols))]
renamed_data = data_month.rename(columns=dict(zip(x_cols, x_cols_short)))
sns.pairplot(renamed_data, vars=x_cols_short)

The first two features are almost perfectly correlated, so we drop the first one (the total civilian labor force).

In [None]:
# Drop the first feature (total civilian labor force) from both name lists.
x_cols_reduced, x_cols_short_reduced = x_cols[1:], x_cols_short[1:]
sns.scatterplot(data=data_month, x=x_cols_reduced[0], y=x_cols_reduced[1], hue=y_col)

In [None]:
# Test split: northwestern and central US states
test_states = [
    "Washington", "Oregon", "Idaho", "Montana", "Wyoming",
    "North Dakota", "South Dakota", "Nebraska", "Minnesota", "Iowa",
    "Colorado", "Utah", "Nevada", "Kansas", "Missouri"
]

# Validation split: eastern US states
val_states = [
    "Maine", "New Hampshire", "Vermont", "Massachusetts", "Connecticut",
    "New York", "New Jersey", "Pennsylvania", "Delaware", "Maryland",
    "Rhode Island", "Virginia", "North Carolina", "South Carolina", "Georgia"
]

set_col: Final[str] = "set"
# data_month is the result of boolean-mask filtering; label an explicit copy so
# the new column is added to an independent frame (no SettingWithCopyWarning).
data_month = data_month.copy()
data_month[set_col] = "train"
data_month.loc[test_states, set_col] = "test"
data_month.loc[val_states, set_col] = "val"

plot_us_state_choropleth(
    data_month,
    value_col=y_col,
    set_col=set_col,
    out_path=Path(f"employment_{example_year}_{example_month}.png"),
)

In [None]:
from itertools import product
from collections import defaultdict
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
import torch
from tqdm import tqdm
from tabrel.benchmark.nw_regr import NwModelConfig, NwTrainConfig, RelNwRegr

train_cfg, model_cfg = NwTrainConfig(lr=0.006, n_epochs=100), NwModelConfig()

# results["rel=True"] / results["rel=False"]: one (val MSE, val R^2) tuple per
# (year, month) evaluated with / without the state-border relation matrix.
results = defaultdict(list)
# (year, month, repr(error)) for each month whose evaluation raised. These are
# recorded rather than silently discarded so failures are visible after the sweep.
skipped = []
# Border matrices keyed by the ordered tuple of state names, so the O(n^2)
# border lookups run once per distinct state ordering instead of every month.
border_cache: dict[tuple[str, ...], np.ndarray] = {}

for year in tqdm(filtered_data["Year"].unique()):
    for month in range(1, 13):
        data_month = subset_month(year, month)
        if data_month.empty:
            continue

        try:
            test_indices = data_month.index.get_indexer(test_states)
            val_indices = data_month.index.get_indexer(val_states)
            # get_indexer returns -1 for absent labels: skip months that are
            # missing any held-out state.
            if np.any(test_indices == -1) or np.any(val_indices == -1):
                continue

            n_samples = len(data_month)
            all_indices = np.arange(n_samples)
            nontrain_indices = np.array([*test_indices, *val_indices])
            # Exclude BOTH held-out splits from training. Excluding only the test
            # states kept the validation states in the training set, so their own
            # targets were part of the context used to predict them (leakage).
            train_indices = np.setdiff1d(all_indices, nontrain_indices)

            x = data_month[x_cols_reduced].to_numpy()
            y = data_month[y_col].to_numpy()

            x_train = x[train_indices]
            y_train = y[train_indices]
            x_test = x[test_indices]
            y_test = y[test_indices]
            x_val = x[val_indices]
            y_val = y[val_indices]

            # Standardize features using training-split statistics only.
            x_mean = np.mean(x_train, axis=0, keepdims=True)
            x_std = np.std(x_train, axis=0, keepdims=True) + 1e-8

            x_train_norm = (x_train - x_mean) / x_std
            x_test_norm = (x_test - x_mean) / x_std
            x_val_norm = (x_val - x_mean) / x_std

            # r[i, j] = 1.0 if state i borders state j, else 0.0; rows/columns
            # follow the order of data_month.index.
            states_key = tuple(data_month.index)
            r = border_cache.get(states_key)
            if r is None:
                r = np.array(
                    [
                        [share_common_border_us_states(state_i, state_j) for state_j in states_key]
                        for state_i in states_key
                    ],
                    dtype=float,
                )
                border_cache[states_key] = r

            r_test_train = r[np.ix_(test_indices, train_indices)]

            # At validation time the context is train followed by test; the
            # relation columns are taken in that same order.
            x_nonval_norm = np.concatenate((x_train_norm, x_test_norm))
            y_nonval = np.concatenate((y_train, y_test))
            r_val_nonval = r[np.ix_(val_indices, np.concatenate((train_indices, test_indices)))]

            # Convert to torch
            x_train_norm = torch.tensor(x_train_norm, dtype=torch.float32)
            y_train = torch.tensor(y_train, dtype=torch.float32)
            x_test_norm = torch.tensor(x_test_norm, dtype=torch.float32)
            y_test = torch.tensor(y_test, dtype=torch.float32)
            r_test_train = torch.tensor(r_test_train, dtype=torch.float32)

            x_val_norm = torch.tensor(x_val_norm, dtype=torch.float32)
            x_nonval_norm = torch.tensor(x_nonval_norm, dtype=torch.float32)
            y_val = torch.tensor(y_val, dtype=torch.float32)
            y_nonval = torch.tensor(y_nonval, dtype=torch.float32)
            r_val_nonval = torch.tensor(r_val_nonval, dtype=torch.float32)

            # use_rel=False feeds an all-zero relation matrix (no border information).
            for use_rel in (True, False):
                model = RelNwRegr(model_cfg)
                optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg.lr)
                loss_fn = torch.nn.MSELoss()

                # Fit by predicting the test states from the training-state context.
                model.train()
                for _ in range(train_cfg.n_epochs):
                    optimizer.zero_grad()
                    y_pred = model(x_train_norm, y_train, x_test_norm,
                                   r_test_train if use_rel else torch.zeros_like(r_test_train))
                    loss = loss_fn(y_pred, y_test)
                    loss.backward()
                    optimizer.step()

                # Evaluate on the held-out validation states, using train + test as context.
                model.eval()
                with torch.no_grad():
                    y_pred_val = model(x_nonval_norm, y_nonval, x_val_norm,
                                       r_val_nonval if use_rel else torch.zeros_like(r_val_nonval))
                    y_pred_val_np = y_pred_val.numpy()
                    y_val_np = y_val.numpy()

                    mse = mean_squared_error(y_val_np, y_pred_val_np)
                    r2 = r2_score(y_val_np, y_pred_val_np)

                    results[f"rel={use_rel}"].append((mse, r2))
        except Exception as e:  # best-effort sweep: record the failure and move on
            skipped.append((year, month, repr(e)))
            continue

if skipped:
    print(f"Skipped {len(skipped)} (year, month) combination(s) due to errors; first: {skipped[0]}")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), sharex=True)

labels = list(results.keys())
bar_locs = np.arange(len(labels))

# Mean and standard deviation of each validation metric over all evaluated months.
mse_means = [np.mean([x[0] for x in results[k]]) for k in labels]
mse_stds = [np.std([x[0] for x in results[k]]) for k in labels]

r2_means = [np.mean([x[1] for x in results[k]]) for k in labels]
r2_stds = [np.std([x[1] for x in results[k]]) for k in labels]

# Left: MSE, right: R squared. One bar per model variant on each axis.
# The R^2 title is a raw string: "\p" in a normal literal is an invalid escape
# sequence (SyntaxWarning on Python 3.12+); the resulting text is unchanged.
for ax, means, stds, ylabel, title, color in (
    (ax1, mse_means, mse_stds, "MSE (val)", "Validation MSE (mean ± std)", "skyblue"),
    (ax2, r2_means, r2_stds, "$R^2$ (val)", r"Validation $R^2$ (mean $\pm$ std)", "lightgreen"),
):
    ax.bar(bar_locs, means, yerr=stds, capsize=5, color=color)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(bar_locs)
    ax.set_xticklabels(labels)

fig.tight_layout()
plt.show()