In [None]:
import geopandas as gpd
import matplotlib.pyplot as plt
import pandas as pd

from typing import Final
from pathlib import Path

# Per-user cache directory for downloaded map data.
# expanduser() is required: resolve() alone does not expand "~", which would
# otherwise create a literal "~" directory under the current working directory.
CACHE_DIR: Final[Path] = Path('~/.cache/us_maps').expanduser().resolve()
CACHE_DIR.mkdir(exist_ok=True, parents=True)

def plot_us_state_choropleth(df, value_col, cmap='OrRd', cache_dir: Path = CACHE_DIR):
    """
    Plots a colored U.S. map based on a continuous statistics value.
    Downloads and caches the U.S. states GeoJSON if not present locally.

    Parameters:
    - df: Pandas DataFrame with state names (index) and a statistics column.
      The index is title-cased on a copy; the caller's DataFrame is not modified.
    - value_col: Name of the statistics column.
    - cmap: Matplotlib colormap to use (default: 'OrRd').
    - cache_dir: Directory holding the cached 'us_states.geojson'
      (default: CACHE_DIR). Created if it does not exist yet.

    Raises:
    - requests.HTTPError: if the GeoJSON download returns an error status.
    """
    geojson_path = cache_dir / 'us_states.geojson'

    # Download and cache if needed
    if not geojson_path.exists():
        # Imported lazily so `requests` is only needed when the cache is cold.
        import requests

        url = 'https://raw.githubusercontent.com/jgoodall/us-maps/master/geojson/state.geo.json'
        # A caller-supplied cache_dir may not exist yet.
        cache_dir.mkdir(parents=True, exist_ok=True)
        print("Downloading U.S. states GeoJSON...")
        # Without a timeout a stalled connection would block forever.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # Write to a temporary name and rename afterwards, so an interrupted
        # write never leaves a truncated file that later calls treat as a
        # valid cache entry.
        tmp_path = geojson_path.with_name(geojson_path.name + '.part')
        tmp_path.write_bytes(response.content)
        tmp_path.replace(geojson_path)
        print("Done downloading")

    # Load cached GeoJSON
    usa_states = gpd.read_file(geojson_path)

    # Normalize index casing (on a copy, to leave the caller's frame untouched)
    df = df.copy()
    df.index = df.index.str.title()

    # Merge user data with geometry. The state name is expected in the
    # GeoJSON's 'NAME10' column; the join keeps every geometry row, so states
    # absent from df end up without a value.
    usa_states = usa_states.rename(columns={'NAME10': 'state'})
    merged = usa_states.set_index('state').join(df)

    # Plot
    fig, ax = plt.subplots(1, 1, figsize=(15, 10))
    merged.plot(column=value_col, cmap=cmap, linewidth=0.8, ax=ax, edgecolor='0.8', legend=True)
    ax.set_title(f'U.S. States Colored by {value_col}', fontsize=16)
    ax.axis('off')
    plt.show()

In [None]:
# Adjacency list of U.S. states: maps each state's full, title-cased name to
# the list of states treated as sharing a border with it.
# - Every listed border appears in both states' lists.
# - Alaska and Hawaii have no neighbors (empty lists).
# - States meeting only at a single point are counted as neighbors
#   (Arizona/Colorado and New Mexico/Utah are listed).
# - Only the 50 states are included as keys.
state_borders = {
        'Alabama': ['Florida', 'Georgia', 'Mississippi', 'Tennessee'],
        'Alaska': [],
        'Arizona': ['California', 'Colorado', 'Nevada', 'New Mexico', 'Utah'],
        'Arkansas': ['Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'Tennessee', 'Texas'],
        'California': ['Arizona', 'Nevada', 'Oregon'],
        'Colorado': ['Arizona', 'Kansas', 'Nebraska', 'New Mexico', 'Oklahoma', 'Utah', 'Wyoming'],
        'Connecticut': ['Massachusetts', 'New York', 'Rhode Island'],
        'Delaware': ['Maryland', 'New Jersey', 'Pennsylvania'],
        'Florida': ['Alabama', 'Georgia'],
        'Georgia': ['Alabama', 'Florida', 'North Carolina', 'South Carolina', 'Tennessee'],
        'Hawaii': [],
        'Idaho': ['Montana', 'Nevada', 'Oregon', 'Utah', 'Washington', 'Wyoming'],
        'Illinois': ['Indiana', 'Iowa', 'Michigan', 'Kentucky', 'Missouri', 'Wisconsin'],
        'Indiana': ['Illinois', 'Kentucky', 'Michigan', 'Ohio'],
        'Iowa': ['Illinois', 'Minnesota', 'Missouri', 'Nebraska', 'South Dakota', 'Wisconsin'],
        'Kansas': ['Colorado', 'Missouri', 'Nebraska', 'Oklahoma'],
        'Kentucky': ['Illinois', 'Indiana', 'Missouri', 'Ohio', 'Tennessee', 'Virginia', 'West Virginia'],
        'Louisiana': ['Arkansas', 'Mississippi', 'Texas'],
        'Maine': ['New Hampshire'],
        'Maryland': ['Delaware', 'Pennsylvania', 'Virginia', 'West Virginia'],
        'Massachusetts': ['Connecticut', 'New Hampshire', 'New York', 'Rhode Island', 'Vermont'],
        'Michigan': ['Illinois', 'Indiana', 'Minnesota', 'Ohio', 'Wisconsin'],
        'Minnesota': ['Iowa', 'Michigan', 'North Dakota', 'South Dakota', 'Wisconsin'],
        'Mississippi': ['Alabama', 'Arkansas', 'Louisiana', 'Tennessee'],
        'Missouri': ['Arkansas', 'Illinois', 'Iowa', 'Kansas', 'Kentucky', 'Nebraska', 'Oklahoma', 'Tennessee'],
        'Montana': ['Idaho', 'North Dakota', 'South Dakota', 'Wyoming'],
        'Nebraska': ['Colorado', 'Iowa', 'Kansas', 'Missouri', 'South Dakota', 'Wyoming'],
        'Nevada': ['Arizona', 'California', 'Idaho', 'Oregon', 'Utah'],
        'New Hampshire': ['Maine', 'Massachusetts', 'Vermont'],
        'New Jersey': ['Delaware', 'New York', 'Pennsylvania'],
        'New Mexico': ['Arizona', 'Colorado', 'Oklahoma', 'Texas', 'Utah'],
        'New York': ['Connecticut', 'Massachusetts', 'New Jersey', 'Pennsylvania', 'Vermont'],
        'North Carolina': ['Georgia', 'South Carolina', 'Tennessee', 'Virginia'],
        'North Dakota': ['Minnesota', 'Montana', 'South Dakota'],
        'Ohio': ['Indiana', 'Kentucky', 'Michigan', 'Pennsylvania', 'West Virginia'],
        'Oklahoma': ['Arkansas', 'Colorado', 'Kansas', 'Missouri', 'New Mexico', 'Texas'],
        'Oregon': ['California', 'Idaho', 'Nevada', 'Washington'],
        'Pennsylvania': ['Delaware', 'Maryland', 'New Jersey', 'New York', 'Ohio', 'West Virginia'],
        'Rhode Island': ['Connecticut', 'Massachusetts'],
        'South Carolina': ['Georgia', 'North Carolina'],
        'South Dakota': ['Iowa', 'Minnesota', 'Montana', 'Nebraska', 'North Dakota', 'Wyoming'],
        'Tennessee': ['Alabama', 'Arkansas', 'Georgia', 'Kentucky', 'Mississippi', 'Missouri', 'North Carolina', 'Virginia'],
        'Texas': ['Arkansas', 'Louisiana', 'New Mexico', 'Oklahoma'],
        'Utah': ['Arizona', 'Colorado', 'Idaho', 'Nevada', 'New Mexico', 'Wyoming'],
        'Vermont': ['Massachusetts', 'New Hampshire', 'New York'],
        'Virginia': ['Kentucky', 'Maryland', 'North Carolina', 'Tennessee', 'West Virginia'],
        'Washington': ['Idaho', 'Oregon'],
        'West Virginia': ['Kentucky', 'Maryland', 'Ohio', 'Pennsylvania', 'Virginia'],
        'Wisconsin': ['Illinois', 'Iowa', 'Michigan', 'Minnesota'],
        'Wyoming': ['Colorado', 'Idaho', 'Montana', 'Nebraska', 'South Dakota', 'Utah']
}

def share_common_border_us_states(s1: str, s2: str) -> bool:
    """
    Tells whether two U.S. states are neighbors according to `state_borders`.

    Both names are stripped of surrounding whitespace and title-cased before
    the lookup, so ' new york ' matches 'New York'. A state is not listed as
    its own neighbor, so passing the same state twice returns False.

    Parameters:
    - s1: Name of the first state.
    - s2: Name of the second state.

    Returns:
    - True if s2 appears in the neighbor list of s1, False otherwise.

    Raises:
    - ValueError: if either normalized name is not a key of `state_borders`.
    """
    first, second = (name.strip().title() for name in (s1, s2))

    # Reject unknown names up front instead of failing on the dict lookup.
    if not (first in state_borders and second in state_borders):
        raise ValueError(f"State '{first}' or '{second}' is not a valid U.S. state")

    return second in state_borders[first]

In [None]:
import kagglehub

# Fetch the latest version of the Kaggle unemployment dataset and keep the
# value returned by kagglehub (printed below as the dataset files' location).
dataset_handle = "justin2028/unemployment-in-america-per-us-state"
path = kagglehub.dataset_download(dataset_handle)

print("Path to dataset files:", path)

In [None]:
# Locate the dataset CSV. glob() yields files in arbitrary order, so sort for
# a deterministic choice, and fail with a clear message if nothing was found
# (instead of a bare IndexError).
csv_files = sorted(Path(path).glob("*.csv"))
if not csv_files:
    raise FileNotFoundError(f"No CSV file found in dataset directory: {path}")
data = pd.read_csv(csv_files[0])

# Keep a single snapshot: January 1976.
data_year = data[data["Year"] == 1976]
data_month = data_year[data_year["Month"] == 1]

# Keep only rows for states known to `state_borders`, minus Alaska and Hawaii
# (which have no neighbors there).
filtered_data = data_month[
    data_month["State/Area"].isin(state_borders.keys()) &
    ~data_month["State/Area"].isin(["Alaska", "Hawaii"])
]
# Index by state name; passing the Series keeps "State/Area" as a column too.
filtered_data = filtered_data.set_index(filtered_data["State/Area"])

# Feature columns (x) and regression target (y).
x_cols = ["Total Civilian Labor Force in State/Area",
          "Total Civilian Non-Institutional Population in State/Area",
          "Percent (%) of State/Area's Population"
          ]
y_col = "Percent (%) of Labor Force Employed in State/Area"
# The first two features are stored as strings with thousands separators
# ("1,234,567"); strip the commas literally and convert to numbers.
# Unparseable values become NaN.
filtered_data[x_cols[:2]] = filtered_data[x_cols[:2]].apply(
    lambda col: pd.to_numeric(col.str.replace(",", "", regex=False), errors="coerce")
)

filtered_data

In [None]:
# Map of the target column (y_col) for every state kept in filtered_data.
plot_us_state_choropleth(filtered_data, value_col=y_col)

In [None]:
import seaborn as sns 

# Replace the feature column names with their positions (0, 1, 2) before
# drawing the pairwise plots -- presumably to keep the axis labels compact.
x_cols_short = list(range(len(x_cols)))
short_name_by_col = dict(zip(x_cols, x_cols_short))
renamed_data = filtered_data.rename(columns=short_name_by_col)
sns.pairplot(renamed_data, vars=x_cols_short)

The first two features are almost perfectly correlated, so we drop the first one.

In [None]:
# Drop the first feature; keep the reduced name lists for the later cells.
x_cols_reduced = x_cols[1:]
x_cols_short_reduced = x_cols_short[1:]
# Pass the frame as `data=`: older seaborn releases do not accept the
# DataFrame as the first positional argument, the keyword works on all of them.
sns.scatterplot(data=filtered_data, x=x_cols_reduced[0], y=x_cols_reduced[1], hue=y_col)

In [None]:
import numpy as np
from itertools import product

# Build the state-to-state relation matrix r:
# r[i, j] is 1 when the i-th and j-th states of filtered_data share a border
# (per share_common_border_us_states) and 0 otherwise, including i == j.
n_samples = len(filtered_data)
r = np.zeros((n_samples, n_samples))
state_names = filtered_data.index
for i, state_i in enumerate(state_names):
    for j, state_j in enumerate(state_names):
        if share_common_border_us_states(state_i, state_j):
            r[i, j] = 1

In [None]:
# Features and target as plain arrays, in filtered_data row order.
x = filtered_data[x_cols_reduced].to_numpy()
y = filtered_data[y_col].to_numpy()

# Hand-picked hold-out set; every other state in filtered_data is used for training.
test_states = [
    'New York', 'New Jersey', 'Connecticut', 'Massachusetts',
    'Rhode Island', 'Vermont', 'New Hampshire', 'Maine',
    'Pennsylvania', 'Delaware', 'Maryland', 'Virginia',
    'West Virginia', 'North Carolina', 'South Carolina', 'Georgia',
    'Florida', 'Ohio', 'Indiana', 'Kentucky',
]
test_indices = filtered_data.index.get_indexer(test_states)

# get_indexer marks labels it cannot find with -1; refuse to continue if any.
not_found = test_indices == -1
if not_found.any():
    missing = [state for state, absent in zip(test_states, not_found) if absent]
    raise ValueError(f"Some test states were not found in the DataFrame index: {missing}")

# Training rows are all positions that are not test positions.
all_indices = np.arange(n_samples)
train_indices = np.setdiff1d(all_indices, test_indices)

# Split features, targets and the train-vs-train block of the relation matrix.
x_train, x_test = x[train_indices], x[test_indices]
y_train, y_test = y[train_indices], y[test_indices]
r_train = r[np.ix_(train_indices, train_indices)]

# Per-feature statistics come from the training rows only.
x_mean = x_train.mean(axis=0, keepdims=True)
x_std = x_train.std(axis=0, keepdims=True) + 1e-8  # avoid division by zero

# Standardize both splits with the training statistics.
x_train_norm = (x_train - x_mean) / x_std
x_test_norm = (x_test - x_mean) / x_std


In [None]:
# Test-vs-train block of the relation matrix: rows are test states, columns
# are training states, and an entry is 1 where the two states share a border.
r_test_train = r[np.ix_(test_indices, train_indices)]

In [None]:
# Map of the target for the training states only; states outside this subset
# get no value from the join inside plot_us_state_choropleth.
plot_us_state_choropleth(filtered_data.iloc[train_indices], value_col=y_col)

In [None]:
# Map of the target for the held-out test states only; states outside this
# subset get no value from the join inside plot_us_state_choropleth.
plot_us_state_choropleth(filtered_data.iloc[test_indices], value_col=y_col)

In [None]:
import torch

from tabrel.benchmark.nw_regr import NwModelConfig, NwTrainConfig, RelNwRegr

train_cfg, model_cfg = NwTrainConfig(lr=.1), NwModelConfig()
loss_fn = torch.nn.MSELoss()

# Convert the NumPy splits to tensors (dtype follows the NumPy arrays).
x_train_norm = torch.tensor(x_train_norm)
y_train = torch.tensor(y_train)
x_test_norm = torch.tensor(x_test_norm)
y_test = torch.tensor(y_test)
r_test_train = torch.tensor(r_test_train)

# All-zero relation matrix for the "no relations" baseline. Built once, with
# the same shape and dtype as the real one, so both runs feed the model
# identically typed inputs.
r_zeros = torch.zeros(r_test_train.shape, dtype=r_test_train.dtype)

# Compare training with (True) and without (False) the border relations.
for use_rel in (True, False):
    # Fresh model and optimizer per setting: reusing them would start the
    # second run from the first run's trained weights and Adam state, which
    # makes the two printed losses incomparable.
    model = RelNwRegr(model_cfg)
    optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg.lr)
    r_input = r_test_train if use_rel else r_zeros

    model.train()
    for epoch in range(train_cfg.n_epochs):
        optimizer.zero_grad()
        y_pred = model(x_train_norm, y_train, x_test_norm, r_input)
        # NOTE(review): the loss is computed against y_test, so the model is
        # optimized directly on the held-out states and r_train is never used.
        # Confirm this is intended; otherwise the printed value is not an
        # unbiased test error.
        loss = loss_fn(y_pred, y_test)

        loss.backward()
        optimizer.step()

    # Loss of the last epoch (computed before the final optimizer step).
    print(loss.item(), use_rel)