In [None]:
import geopandas as gpd
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd

from typing import Final
from pathlib import Path

# Directory holding the cached GeoJSON used by the map plots.
# expanduser() is required: Path.resolve() alone treats "~" as a literal directory
# name relative to the current working directory instead of the user's home.
CACHE_DIR: Final[Path] = Path('~/.cache/us_maps').expanduser().resolve()
CACHE_DIR.mkdir(exist_ok=True, parents=True)


def plot_us_state_choropleth(
    df: pd.DataFrame,
    value_col: str,
    cmap: str = 'OrRd',
    set_col: str | None = None,
    cache_dir: Path = CACHE_DIR,
    out_path: Path | None = None,
):
    """
    Plots a U.S. state choropleth map with optional border color coding by set.

    The state outlines come from a GeoJSON file cached in `cache_dir`; the file is
    downloaded on first use.

    Parameters:
    - df: Pandas DataFrame with state names (index) and a statistics column.
      The index casing is normalised with `str.title()` before joining.
    - value_col: Name of the statistics column.
    - cmap: Matplotlib colormap to use (default: 'OrRd').
    - set_col: Optional column name for border coloring with values: 'train', 'test', 'val'.
      Ignored when the column is not present in the joined frame.
    - cache_dir: Directory to store or load the cached GeoJSON.
    - out_path: Optional output path (pdf, png, etc) for the figure

    Returns nothing; the figure is displayed with `plt.show()`.
    Download failures surface as `requests` exceptions (`raise_for_status`).
    """
    geojson_path = cache_dir / 'us_states.geojson'

    if not geojson_path.exists():
        url = 'https://raw.githubusercontent.com/jgoodall/us-maps/master/geojson/state.geo.json'
        print("Downloading U.S. states GeoJSON...")
        import requests
        # A timeout keeps a stalled connection from hanging the notebook forever.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # A caller-supplied cache_dir may not exist yet.
        cache_dir.mkdir(parents=True, exist_ok=True)
        with geojson_path.open('wb') as f:
            f.write(response.content)
        print("Done downloading")

    usa_states = gpd.read_file(geojson_path)

    # Normalize index casing (work on a copy so the caller's frame is untouched)
    df = df.copy()
    df.index = df.index.str.title()

    # Rename and merge: the GeoJSON stores the state name in 'NAME10'
    usa_states = usa_states.rename(columns={'NAME10': 'state'})
    merged = usa_states.set_index('state').join(df)

    fig, ax = plt.subplots(1, 1, figsize=(15, 10))

    # Base plot, without geopandas' own legend (a smaller colorbar is added below)
    merged.plot(column=value_col, cmap=cmap, linewidth=0.8, ax=ax, edgecolor='0.8', legend=False)

    # Create scalar mappable for colorbar, spanning the plotted value range
    norm = mpl.colors.Normalize(vmin=merged[value_col].min(), vmax=merged[value_col].max())
    sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)

    # Add smaller colorbar
    fig.colorbar(sm, ax=ax, fraction=0.025, pad=0.04)
    linewidth: Final[float] = 2.5

    if set_col and set_col in merged.columns:
        # Color mappings for 'train', 'test', 'val'
        set_colors = {'train': 'blue', 'test': 'green', 'val': 'orange'}

        for set_value, color in set_colors.items():
            subset = merged[merged[set_col] == set_value]
            if not subset.empty:
                subset.boundary.plot(ax=ax, edgecolor=color, linewidth=linewidth)

        # Legend boxes as text annotations: (set value, displayed label, y position
        # in axes coordinates). Drawn for all three sets, even if one is empty.
        legend_entries = (
            ('train', 'train', 0.24),
            ('test', 'test', 0.17),
            ('val', 'validate', 0.1),
        )
        for set_value, label, y_pos in legend_entries:
            ax.text(
                0.1, y_pos,
                label,
                transform=ax.transAxes,
                fontsize=14,
                bbox=dict(
                    facecolor="white",
                    edgecolor=set_colors[set_value],
                    boxstyle="round,pad=0.3",
                    alpha=0.8,
                    linewidth=linewidth,
                ),
            )

    ax.set_title(f'U.S. States Colored by {value_col}', fontsize=16)
    ax.axis('off')
    if out_path:
        # Save through the figure object rather than pyplot's "current figure" state.
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.show()

In [None]:
# Adjacency list of the 50 U.S. states: each state name maps to the list of states
# it shares a border with. Alaska and Hawaii have no neighbours. States meeting
# only at the Four Corners point (e.g. Arizona/Colorado, New Mexico/Utah) are
# listed as neighbours. Looked up by share_common_border_us_states().
state_borders = {
        'Alabama': ['Florida', 'Georgia', 'Mississippi', 'Tennessee'],
        'Alaska': [],
        'Arizona': ['California', 'Colorado', 'Nevada', 'New Mexico', 'Utah'],
        'Arkansas': ['Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'Tennessee', 'Texas'],
        'California': ['Arizona', 'Nevada', 'Oregon'],
        'Colorado': ['Arizona', 'Kansas', 'Nebraska', 'New Mexico', 'Oklahoma', 'Utah', 'Wyoming'],
        'Connecticut': ['Massachusetts', 'New York', 'Rhode Island'],
        'Delaware': ['Maryland', 'New Jersey', 'Pennsylvania'],
        'Florida': ['Alabama', 'Georgia'],
        'Georgia': ['Alabama', 'Florida', 'North Carolina', 'South Carolina', 'Tennessee'],
        'Hawaii': [],
        'Idaho': ['Montana', 'Nevada', 'Oregon', 'Utah', 'Washington', 'Wyoming'],
        'Illinois': ['Indiana', 'Iowa', 'Michigan', 'Kentucky', 'Missouri', 'Wisconsin'],
        'Indiana': ['Illinois', 'Kentucky', 'Michigan', 'Ohio'],
        'Iowa': ['Illinois', 'Minnesota', 'Missouri', 'Nebraska', 'South Dakota', 'Wisconsin'],
        'Kansas': ['Colorado', 'Missouri', 'Nebraska', 'Oklahoma'],
        'Kentucky': ['Illinois', 'Indiana', 'Missouri', 'Ohio', 'Tennessee', 'Virginia', 'West Virginia'],
        'Louisiana': ['Arkansas', 'Mississippi', 'Texas'],
        'Maine': ['New Hampshire'],
        'Maryland': ['Delaware', 'Pennsylvania', 'Virginia', 'West Virginia'],
        'Massachusetts': ['Connecticut', 'New Hampshire', 'New York', 'Rhode Island', 'Vermont'],
        'Michigan': ['Illinois', 'Indiana', 'Minnesota', 'Ohio', 'Wisconsin'],
        'Minnesota': ['Iowa', 'Michigan', 'North Dakota', 'South Dakota', 'Wisconsin'],
        'Mississippi': ['Alabama', 'Arkansas', 'Louisiana', 'Tennessee'],
        'Missouri': ['Arkansas', 'Illinois', 'Iowa', 'Kansas', 'Kentucky', 'Nebraska', 'Oklahoma', 'Tennessee'],
        'Montana': ['Idaho', 'North Dakota', 'South Dakota', 'Wyoming'],
        'Nebraska': ['Colorado', 'Iowa', 'Kansas', 'Missouri', 'South Dakota', 'Wyoming'],
        'Nevada': ['Arizona', 'California', 'Idaho', 'Oregon', 'Utah'],
        'New Hampshire': ['Maine', 'Massachusetts', 'Vermont'],
        'New Jersey': ['Delaware', 'New York', 'Pennsylvania'],
        'New Mexico': ['Arizona', 'Colorado', 'Oklahoma', 'Texas', 'Utah'],
        'New York': ['Connecticut', 'Massachusetts', 'New Jersey', 'Pennsylvania', 'Vermont'],
        'North Carolina': ['Georgia', 'South Carolina', 'Tennessee', 'Virginia'],
        'North Dakota': ['Minnesota', 'Montana', 'South Dakota'],
        'Ohio': ['Indiana', 'Kentucky', 'Michigan', 'Pennsylvania', 'West Virginia'],
        'Oklahoma': ['Arkansas', 'Colorado', 'Kansas', 'Missouri', 'New Mexico', 'Texas'],
        'Oregon': ['California', 'Idaho', 'Nevada', 'Washington'],
        'Pennsylvania': ['Delaware', 'Maryland', 'New Jersey', 'New York', 'Ohio', 'West Virginia'],
        'Rhode Island': ['Connecticut', 'Massachusetts'],
        'South Carolina': ['Georgia', 'North Carolina'],
        'South Dakota': ['Iowa', 'Minnesota', 'Montana', 'Nebraska', 'North Dakota', 'Wyoming'],
        'Tennessee': ['Alabama', 'Arkansas', 'Georgia', 'Kentucky', 'Mississippi', 'Missouri', 'North Carolina', 'Virginia'],
        'Texas': ['Arkansas', 'Louisiana', 'New Mexico', 'Oklahoma'],
        'Utah': ['Arizona', 'Colorado', 'Idaho', 'Nevada', 'New Mexico', 'Wyoming'],
        'Vermont': ['Massachusetts', 'New Hampshire', 'New York'],
        'Virginia': ['Kentucky', 'Maryland', 'North Carolina', 'Tennessee', 'West Virginia'],
        'Washington': ['Idaho', 'Oregon'],
        'West Virginia': ['Kentucky', 'Maryland', 'Ohio', 'Pennsylvania', 'Virginia'],
        'Wisconsin': ['Illinois', 'Iowa', 'Michigan', 'Minnesota'],
        'Wyoming': ['Colorado', 'Idaho', 'Montana', 'Nebraska', 'South Dakota', 'Utah']
}

def share_common_border_us_states(s1: str, s2: str) -> bool:
    """
    Tell whether state `s2` is listed among the neighbours of state `s1`.

    Both names are stripped of surrounding whitespace and title-cased before
    the lookup in `state_borders`.

    Raises:
        ValueError: if either normalised name is not a key of `state_borders`.
    """
    first, second = (name.strip().title() for name in (s1, s2))

    both_known = first in state_borders and second in state_borders
    if not both_known:
        raise ValueError(f"State '{first}' or '{second}' is not a valid U.S. state")

    neighbours = state_borders[first]
    return second in neighbours

In [None]:
import kagglehub

# Download latest version
# The returned location is printed below and globbed for CSV files by the next cell.
path = kagglehub.dataset_download("justin2028/unemployment-in-america-per-us-state")

print("Path to dataset files:", path)

In [None]:
# Load the first CSV file found in the downloaded dataset directory.
csv_files = list(Path(path).glob("*.csv"))
data = pd.read_csv(csv_files[0])
example_year: Final[int] = 1976
example_month: Final[int] = 1

# Keep rows for known states only, minus Alaska and Hawaii, and index them by
# state name (the "State/Area" column is kept as a regular column too).
is_known_state = data["State/Area"].isin(state_borders.keys())
is_alaska_or_hawaii = data["State/Area"].isin(["Alaska", "Hawaii"])
filtered_data = data[is_known_state & ~is_alaska_or_hawaii]
filtered_data = filtered_data.set_index(filtered_data["State/Area"])

# Feature columns and regression target.
x_cols = [
    "Total Civilian Labor Force in State/Area",
    "Total Civilian Non-Institutional Population in State/Area",
    "Percent (%) of State/Area\'s Population",
]
y_col = "Percent (%) of Labor Force Employed in State/Area"

# The two count columns are stored as text with thousands separators ("1,234");
# strip the commas and convert, turning unparsable entries into NaN.
count_cols = x_cols[:2]
filtered_data[count_cols] = filtered_data[count_cols].apply(
    lambda col: pd.to_numeric(col.str.replace(",", ""), errors="coerce")
)


def subset_month(year: int, month: int) -> pd.DataFrame:
    """
    Return the rows of `filtered_data` for one calendar month.

    The result is an independent copy, so callers may add or overwrite columns
    (e.g. a train/test/val label) without triggering pandas'
    SettingWithCopyWarning and without touching `filtered_data`.

    Parameters:
    - year: value matched against the "Year" column.
    - month: value matched against the "Month" column.
    """
    in_month = (filtered_data["Year"] == year) & (filtered_data["Month"] == month)
    return filtered_data[in_month].copy()



data_month = subset_month(year=example_year, month=example_month)
data_month

In [None]:
# Map the target column (`y_col`) of every state for the example month.
plot_us_state_choropleth(data_month, value_col=y_col)

In [None]:
import seaborn as sns 

# Pair plot of the features. The columns are renamed to their position in
# `x_cols` (0, 1, 2) so that the long original names do not clutter the axes.
x_cols_short = [position for position, _ in enumerate(x_cols)]
renamed_data = data_month.rename(columns=dict(zip(x_cols, x_cols_short)))
sns.pairplot(renamed_data, vars=x_cols_short)

The first two features are almost perfectly correlated, so we drop the first one.

In [None]:
# Keep every feature except the first one (both the full and the short names).
x_cols_reduced = x_cols[1:]
x_cols_short_reduced = x_cols_short[1:]
# Scatter the two remaining features against each other, coloured by the target.
sns.scatterplot(data_month, x=x_cols_reduced[0], y=x_cols_reduced[1], hue=y_col)

In [None]:
# Test split: north-western and central US states
test_states = [
    "Washington", "Oregon", "Idaho", "Montana", "Wyoming",
    "North Dakota", "South Dakota", "Nebraska", "Minnesota", "Iowa",
    "Colorado", "Utah", "Nevada", "Kansas", "Missouri"
]

# Validation split: eastern US states
val_states = [
    "Maine", "New Hampshire", "Vermont", "Massachusetts", "Connecticut",
    "New York", "New Jersey", "Pennsylvania", "Delaware", "Maryland",
    "Rhode Island", "Virginia", "North Carolina", "South Carolina", "Georgia"
]

set_col: Final[str] = "set"
# Label every state with its split. `assign` builds a new frame, so the labels
# are never written into a slice of another frame (no SettingWithCopyWarning)
# and re-running this cell gives the same result.
data_month = data_month.assign(**{set_col: "train"})
data_month.loc[test_states, set_col] = "test"
data_month.loc[val_states, set_col] = "val"

plot_us_state_choropleth(
    data_month,
    value_col=y_col,
    set_col=set_col,
    # File name typo fixed ("emloyment" -> "employment").
    out_path=Path(f"employment_{example_year}_{example_month}.png"),
)

In [None]:
from itertools import product
from collections import defaultdict
from typing import Final
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import torch
from tqdm import tqdm
from tabrel.benchmark.nw_regr import NwModelConfig, RelNwRegr

# Hyper-parameters passed to run_training() by the cells below.
LR: Final[float] = 0.006  # learning rate handed to the Adam optimizer
N_EPOCHS: Final[int] = 100  # number of optimizer steps per model

def compute_relation_matrix(states: list[str]) -> np.ndarray:
    """
    Build the (n, n) adjacency matrix of the given states.

    Entry [i, j] is 1.0 when `share_common_border_us_states(states[i], states[j])`
    is true and 0.0 otherwise. Unknown state names make that helper raise
    ValueError.
    """
    n_states = len(states)
    adjacency = np.zeros((n_states, n_states))
    for row in range(n_states):
        for col in range(n_states):
            if share_common_border_us_states(states[row], states[col]):
                adjacency[row, col] = 1
    return adjacency

def run_training(
    x: np.ndarray,
    y: np.ndarray,
    r: np.ndarray,
    train_indices: np.ndarray,
    test_indices: np.ndarray,
    val_indices: np.ndarray,
    lr: float,
    n_epochs: int,
) -> dict[str, tuple[float, float, float, RelNwRegr]]:
    """
    Fit two `RelNwRegr` models -- one fed the relation matrix, one fed zeros
    instead -- and score both on the validation rows.

    Features are standardised with the mean / std of the training rows only.
    While fitting, the model is called with the train features and targets,
    the test features and the test-vs-train block of `r`; the MSE between its
    output and the test targets is minimised with Adam for `n_epochs` steps.
    For scoring, train and test rows together take the place of the train rows
    and the validation rows take the place of the test rows.

    Parameters:
    - x, y: feature matrix and target vector, indexed by the index arrays.
    - r: square relation matrix over all rows of `x`.
    - train_indices, test_indices, val_indices: row positions of each split.
    - lr: Adam learning rate.
    - n_epochs: number of optimisation steps per model.

    Returns a dict keyed "rel=True" / "rel=False"; each value is
    (MSE, R^2, MAE, fitted model), the metrics being computed on the
    validation rows.
    """
    def to_tensor(array: np.ndarray) -> torch.Tensor:
        return torch.tensor(array, dtype=torch.float32)

    # Standardisation statistics come from the training rows only.
    features_train = x[train_indices]
    mean = np.mean(features_train, axis=0, keepdims=True)
    scale = np.std(features_train, axis=0, keepdims=True) + 1e-8  # epsilon avoids division by zero

    def standardise(rows: np.ndarray) -> np.ndarray:
        return (rows - mean) / scale

    train_np = standardise(features_train)
    test_np = standardise(x[test_indices])
    nonval_indices = np.concatenate((train_indices, test_indices))

    # Features
    x_train_t = to_tensor(train_np)
    x_test_t = to_tensor(test_np)
    x_val_t = to_tensor(standardise(x[val_indices]))
    x_nonval_t = to_tensor(np.concatenate((train_np, test_np)))

    # Targets
    y_train_t = to_tensor(y[train_indices])
    y_test_t = to_tensor(y[test_indices])
    y_nonval_t = to_tensor(np.concatenate((y[train_indices], y[test_indices])))
    y_val_np = to_tensor(y[val_indices]).numpy()

    # Relation blocks: rows are the predicted samples, columns the samples passed alongside
    rel_test_train = to_tensor(r[np.ix_(test_indices, train_indices)])
    rel_val_nonval = to_tensor(r[np.ix_(val_indices, nonval_indices)])

    model_cfg = NwModelConfig()
    scores = {}

    for use_rel in (True, False):
        model = RelNwRegr(model_cfg)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        loss_fn = torch.nn.MSELoss()

        model.train()
        for _ in range(n_epochs):
            optimizer.zero_grad()
            rel_fit = rel_test_train if use_rel else torch.zeros_like(rel_test_train)
            prediction = model(x_train_t, y_train_t, x_test_t, rel_fit)
            loss_fn(prediction, y_test_t).backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            rel_score = rel_val_nonval if use_rel else torch.zeros_like(rel_val_nonval)
            val_prediction = model(x_nonval_t, y_nonval_t, x_val_t, rel_score).numpy()

        scores[f"rel={use_rel}"] = (
            mean_squared_error(y_val_np, val_prediction),
            r2_score(y_val_np, val_prediction),
            mean_absolute_error(y_val_np, val_prediction),
            model,
        )

    return scores


# Per-model list of (mse, r2, mae, model) tuples, one entry per processed month.
results = defaultdict(list)

for year in tqdm(filtered_data["Year"].unique()):
    for month in range(1, 13):
        data_month: pd.DataFrame = subset_month(year, month)
        if data_month.empty:
            continue

        try:
            test_indices = data_month.index.get_indexer(test_states)
            val_indices = data_month.index.get_indexer(val_states)
            # get_indexer marks absent labels with -1; skip months lacking a split state.
            if np.any(test_indices == -1) or np.any(val_indices == -1):
                continue

            n_samples = len(data_month)
            all_indices = np.arange(n_samples)
            # Train on the states that are neither test nor validation. Excluding
            # only the test states (as before) leaked the validation states into
            # the training set.
            nontrain_indices = np.concatenate((test_indices, val_indices))
            train_indices = np.setdiff1d(all_indices, nontrain_indices)

            x = data_month[x_cols_reduced].to_numpy()
            y = data_month[y_col].to_numpy()

            r = compute_relation_matrix(data_month.index.tolist())

            month_results = run_training(
                x=x,
                y=y,
                r=r,
                train_indices=train_indices,
                test_indices=test_indices,
                val_indices=val_indices,
                lr=LR,
                n_epochs=N_EPOCHS,
            )
            for k, v in month_results.items():
                results[k].append(v)

        except Exception as e:
            # Best-effort: one failing month must not abort the sweep, but the
            # failure is reported so an empty `results` is never a mystery.
            # tqdm.write prints without breaking the progress bar.
            tqdm.write(f"Skipping year={year}, month={month}: {e}")
            continue


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Extract labels
# One label per key of `results` (the "rel=True" / "rel=False" model variants).
labels = list(results.keys())
bar_locs = np.arange(len(labels))

# Prepare data arrays
# Every entry of results[k] is a (mse, r2, mae, model) tuple, so index 0 is the
# MSE and index 1 the R^2 of one run; each metric holds one array per label.
metrics_all = {
    "Validation MSE": [np.array([x[0] for x in results[k]]) for k in labels],
    "Validation $R^2$": [np.array([x[1] for x in results[k]]) for k in labels],
    # "Validation MAE": [np.array([x[2] for x in results[k]]) for k in labels],
}
# Violin fill colour per metric (keys must match `metrics_all`).
colors = {
    "Validation MSE": "skyblue",
    "Validation $R^2$": "lightgreen",
    # "Validation MAE": "lightcoral",
}

def plot_violin(ax, data, title, color):
    """
    Draw one violin per entry of `data` on `ax` and overlay the raw points.

    Violins are positioned at the module-level `bar_locs` and the x ticks are
    labelled with the module-level `labels`. Each point is shifted horizontally
    by Gaussian jitter (standard deviation 0.05) drawn from NumPy's global
    random state.

    Parameters:
    - ax: Matplotlib axes to draw on.
    - data: sequence of 1-D arrays, one per violin.
    - title: axes title.
    - color: face colour of the violin bodies.
    """
    violins = ax.violinplot(data, positions=bar_locs, showmeans=True, showmedians=False)

    for loc, values in zip(bar_locs, data):
        n_points = len(values)
        offsets = np.random.normal(0, 0.05, size=n_points)
        ax.scatter(np.full(n_points, loc) + offsets, values, color='black', s=10, alpha=0.6)

    for body in violins['bodies']:
        body.set_facecolor(color)
        body.set_edgecolor('black')
        body.set_alpha(0.7)

    ax.set_title(title)
    ax.set_xticks(bar_locs)
    ax.set_xticklabels(labels)

# Create subplots: one panel per metric, side by side, sharing the x axis
fig, axes = plt.subplots(1, len(metrics_all), figsize=(5, 5), sharex=True)

# The loop variables are deliberately not called `data` / `label`: a top-level
# loop variable named `data` would overwrite the unemployment DataFrame `data`
# loaded earlier in the notebook.
for ax, (metric_name, metric_values) in zip(axes, metrics_all.items()):
    plot_violin(ax, metric_values, title=metric_name, color=colors[metric_name])

fig.tight_layout()
fig.savefig("employment_violinplots.pdf")
plt.show()

# Custom dataset

In [None]:
# NOTE(review): absolute, machine-specific path -- consider a configurable data directory.
df = pd.read_csv("/Users/vzuev/Documents/git/gh_zuevval/tabrel/data/birthrate_usStates.csv", sep=";")
# Drop the last two rows (presumably non-state footer rows -- TODO confirm). The
# copy keeps the column assignments below from targeting a slice of the frame.
df = df[:-2].copy()
for col in ("Urban", "Birth rate", "PCPI"):
    # Values are text with a decimal comma; convert the whole column at once.
    df[col] = df[col].str.replace(",", ".", regex=False).astype(float)
df = df.set_index(df["State"])
# Copy again after filtering so that adding the split column does not write into a slice.
df = df[~df["State"].isin(["Alaska", "Hawaii"])].copy()

df[set_col] = "train"
plot_us_state_choropleth(df, value_col="Birth rate")
sns.pairplot(df)

# United States Energy, Census, and GDP 2010-2014

In [None]:
# Download the dataset and load the first CSV file found in it.
path = kagglehub.dataset_download("lislejoem/us_energy_census_gdp_10-14")
gdp_df = pd.read_csv(next(Path(path).glob("*.csv")))
# Drop Alaska, Hawaii, the District of Columbia and the "United States" row.
gdp_df = gdp_df[~gdp_df["State"].isin(("Alaska", "Hawaii", "District of Columbia", "United States"))]

gdp_year: Final[int] = 2014
# Keep only the columns whose name ends with the selected year, indexed by state name.
data_year = gdp_df.loc[:, gdp_df.columns.str.endswith(str(gdp_year))]
data_year.index = gdp_df["State"]

def compute_gdp_per_capita(year: int) -> list[float]:
    """
    Divide each row's `GDP{year}` by its `CENSUS2010POP` in `gdp_df`.

    Returns one ratio per row of `gdp_df`, in row order.

    NOTE(review): the 2010 census population is the denominator for every
    year -- confirm this is intended.
    """
    gdp_values = gdp_df[f"GDP{year}"]
    population = gdp_df["CENSUS2010POP"]
    ratios = []
    for gdp_value, pop_value in zip(gdp_values, population):
        ratios.append(gdp_value / pop_value)
    return ratios

# Columns whose name starts with "R", for the selected year. The explicit copy
# makes `data_r` an independent frame, so adding a column below does not write
# into a slice of `data_year` (SettingWithCopyWarning).
data_r = data_year.loc[:, data_year.columns.str.startswith("R")].copy()
gdp_pc_col: Final[str] = "gdpPerCap"
# The list returned by compute_gdp_per_capita is in gdp_df row order, which is
# also the row order of data_r.
data_r[gdp_pc_col] = compute_gdp_per_capita(gdp_year)


# data_r.loc[:, "rnatbd"] = data_r["RBIRTH2014"] - data_r["RDEATH2014"]
# data_r.loc[:, "rdomintmig"] = data_r["RINTERNATIONALMIG2014"] + data_r["RDOMESTICMIG2014"]
# sns.pairplot(data_r)

mig_col: Final[str] = f"RINTERNATIONALMIG{gdp_year}"
sns.scatterplot(data_r, x=gdp_pc_col, y=mig_col)
plot_us_state_choropleth(data_r, value_col=mig_col)

In [None]:
# Index the GDP frame by state name and build the state adjacency matrix.
states = gdp_df["State"]
gdp_df = gdp_df.set_index(states)
states = states.tolist()
r = compute_relation_matrix(states)

all_indices = np.arange(len(states))
val_indices = gdp_df.index.get_indexer(val_states)
test_indices = gdp_df.index.get_indexer(test_states)
# get_indexer returns -1 for absent labels; used as a position, -1 would
# silently select the last row, so fail loudly instead.
if np.any(val_indices == -1) or np.any(test_indices == -1):
    raise ValueError("Some validation/test states are missing from gdp_df")
train_indices = np.setdiff1d(all_indices, np.concatenate((val_indices,  test_indices)))
print(val_indices, test_indices, train_indices)

for year in range(2011, 2015):
    # Separate name: `mig_col` is declared Final in an earlier cell.
    year_mig_col = f"RINTERNATIONALMIG{year}"
    gdp_df.loc[:, gdp_pc_col] = compute_gdp_per_capita(year)

    x = gdp_df[[gdp_pc_col]].to_numpy()
    y = gdp_df[year_mig_col].to_numpy()

    # Run training. The outcome is stored under its own name so that the
    # `results` dict collected by the monthly unemployment sweep is not overwritten.
    year_results = run_training(
        x=x,
        y=y,
        r=r,
        train_indices=train_indices,
        test_indices=test_indices,
        val_indices=val_indices,
        lr=LR,
        n_epochs=N_EPOCHS,
    )

    # Print results (the last tuple element is the fitted model, not a metric)
    print(f"=== Year {year} ===")
    for key, values in year_results.items():
        mse, r2, mae = values[:-1]
        print(f"{key}: MSE={mse:.3f}, R^2={r2:.3f}, MAE={mae:.3f}")
    print()

# World

In [None]:
import kagglehub
import pycountry
from pathlib import Path

# Download latest version
path = kagglehub.dataset_download("mlippo/average-global-iq-per-country-with-other-stats")

# Load the first CSV file found in the downloaded dataset directory.
df = pd.read_csv(list(Path(path).glob("*.csv"))[0])

def clean_data(df):
    """
    Return a copy of `df` whose 'Population - 2023' column is an int64.

    The column is expected to hold strings that use ',' or '.' as thousands
    separators (e.g. '1,234,567'); both characters are stripped before the
    cast. The input frame is left untouched (the previous version overwrote
    the column on the caller's frame).
    """
    population = df['Population - 2023'].str.replace('[,.]', '', regex=True)
    cleaned = df.assign(**{'Population - 2023': population})
    return cleaned.astype({'Population - 2023': 'int64'})

df_clean = clean_data(df.copy())
# Map each country name to its ISO alpha-3 code via pycountry; names for which
# the lookup returns a falsy result get None.
df_clean['ISO_alpha'] = df_clean['Country'].apply(
    lambda name: match.alpha_3 if (match := pycountry.countries.get(name=name)) else None
)


In [None]:
import plotly.express as px

# World choropleth of the 'Average IQ' column, with countries located through
# the 'ISO_alpha' column of df_clean.
fig_geo = px.choropleth(
    df_clean,
    locations='ISO_alpha',
    color='Average IQ',
    color_continuous_scale=px.colors.sequential.YlOrRd,
    labels={'Average IQ': 'Average IQ'},
    title='Average IQ by Country'
)

fig_geo.show()

In [None]:
import geopandas as gpd
from shapely.geometry import Polygon, MultiPolygon

# Load shapefile (adjust the path to where you extracted the data)
# NOTE(review): absolute, machine-specific path -- it will not resolve on another machine.
world = gpd.read_file("/Users/vzuev/Documents/git/gh_zuevval/tabrel/data/ne_50m_admin_0_countries/ne_50m_admin_0_countries.shp")

# Drop rows whose ISO_A3_EH code is the placeholder '-99'
world = world[world['ISO_A3_EH'] != '-99']  # Remove invalid entries
len(world)

In [None]:
from tqdm import tqdm
from shapely.ops import unary_union

from shapely.geometry import Polygon, MultiPolygon
from shapely.ops import unary_union

def preprocess_geometry(geom):
    """
    Reduce a country geometry to its largest part plus the parts touching it.

    A Polygon is returned unchanged. For a MultiPolygon the part with the
    largest area is taken as the mainland and merged with every part that
    touches it; the intent is to leave out remote territories such as French
    Guiana (France). Any other input yields None.
    """
    if isinstance(geom, Polygon):
        return geom
    if not isinstance(geom, MultiPolygon):
        return None

    parts = list(geom.geoms)
    # The largest part is assumed to be the mainland
    mainland = max(parts, key=lambda part: part.area)

    kept = []
    for part in parts:
        if part == mainland or part.touches(mainland):
            kept.append(part)

    return unary_union(kept)

def build_border_map(world):
    """
    Map every country's ISO_A3_EH code to the set of codes of the countries it touches.

    Parameters:
    - world: GeoDataFrame with an 'ISO_A3_EH' column and a geometry per row.

    Each geometry is first reduced with `preprocess_geometry`; rows for which
    that does not yield a Polygon or MultiPolygon are left out entirely (both
    as keys and as neighbours). Two countries are neighbours when their
    reduced geometries `touch`.
    """
    # Pre-process every geometry exactly once. Doing it inside the pairwise
    # loop repeated the (union-based) pre-processing for every pair of rows.
    candidates = []
    for _, country in world.iterrows():
        geom = preprocess_geometry(country.geometry)
        if isinstance(geom, (Polygon, MultiPolygon)):
            candidates.append((country['ISO_A3_EH'], geom))

    border_map = {}
    for iso_a3, geom in tqdm(candidates):
        # Rows carrying the same code as the current one are never compared with it.
        border_map[iso_a3] = {
            other_iso
            for other_iso, other_geom in candidates
            if other_iso != iso_a3 and geom.touches(other_geom)
        }

    return border_map

border_map = build_border_map(world)


In [None]:
def share_common_border(iso3_country1: str, iso3_country2: str, bm: dict = border_map) -> bool:
    """
    Tell whether the second country is listed among the neighbours of the first.

    Both ISO alpha-3 codes are upper-cased before the lookup in `bm`; a code
    that is not a key of `bm` has no neighbours. The default `bm` is the
    `border_map` object that existed when this function was defined.
    """
    first, second = iso3_country1.upper(), iso3_country2.upper()
    neighbours = bm.get(first, set())
    return second in neighbours

print(share_common_border("FRA", "DEU"))  # True
print(share_common_border("FRA", "USA"))  # False
print(share_common_border("MEX", "GTM"))  # True


In [None]:
from collections import deque

def get_connected_country_set(seed_iso3: str, bm: dict, max_size=10) -> set:
    """
    Collect up to `max_size` countries reachable from `seed_iso3`, breadth-first.

    Parameters:
    - seed_iso3: code of the starting country (always included when max_size > 0,
      even if it is not a key of `bm`).
    - bm: mapping from a country code to a collection of neighbouring codes.
    - max_size: maximum number of countries in the result.

    Neighbours are visited in sorted order. The neighbour collections are sets
    of strings, whose iteration order changes between interpreter sessions
    (string hash randomisation); without sorting, a size-limited result would
    differ from run to run.
    """
    visited = set()
    queue = deque([seed_iso3])

    while queue and len(visited) < max_size:
        country = queue.popleft()
        if country in visited:
            continue
        visited.add(country)

        # Enqueue unvisited neighbours in a reproducible order
        queue.extend(
            neighbor for neighbor in sorted(bm.get(country, ())) if neighbor not in visited
        )

    return visited

# Two connected groups of at most 10 countries each (the default max_size):
# one grown from Brazil ("BRA"), one from the Central African Republic ("CAF").
query_set = get_connected_country_set("BRA", border_map)
validate_set = get_connected_country_set("CAF", border_map)

print(f"Query set (starts from Brazil): {len(query_set)} countries")
print(query_set)

print(f"\nValidate set (starts from Central Africa): {len(validate_set)} countries")
print(validate_set)

In [None]:
import numpy as np
import pandas as pd

# 1-2. Drop rows with a missing value in any column, then index the remaining
#      rows by their ISO alpha-3 code
df_filtered = df_clean.dropna().set_index("ISO_alpha")
iso_list = list(df_filtered.index)

# 3. Extract feature matrix X and target y
feature_cols = [
    # " GNI - 2021",
    "Mean years of schooling - 2021",
    "Literacy Rate",
]
X = df_filtered[feature_cols].to_numpy()
y = df_filtered["Average IQ"].to_numpy()

# 4. Split the rows: query and validation countries come from query_set and
#    validate_set, every other country is used for training
all_isos = set(df_filtered.index)
train_set = all_isos - query_set - validate_set

train_indices, query_indices, val_indices = [], [], []
for position, iso in enumerate(iso_list):
    if iso in train_set:
        train_indices.append(position)
    if iso in query_set:
        query_indices.append(position)
    if iso in validate_set:
        val_indices.append(position)

# 5. Build adjacency matrix R over the rows of df_filtered
N = len(df_filtered)
R = np.zeros((N, N), dtype=int)

# Fast lookup: map ISO to row index
iso_to_idx = dict(zip(iso_list, range(N)))

for row, iso in enumerate(iso_list):
    for neighbour in border_map.get(iso, set()):
        col = iso_to_idx.get(neighbour)
        if col is None:
            # The neighbour has no row in df_filtered
            continue
        # Set both directions to ensure symmetry
        R[row, col] = 1
        R[col, row] = 1


In [None]:
# NOTE(review): `test_indices` is not assigned anywhere in this section; the value
# shown is whatever an earlier cell left in the kernel -- confirm whether
# `query_indices` was intended.
test_indices

In [None]:
# Fit on the world data: the query countries are passed as the test split and
# the validate countries as the validation split.
run_training(
    x=X,
    y=y,
    r=R,
    train_indices=np.array(train_indices),
    test_indices=np.array(query_indices),
    val_indices=np.array(val_indices),
    lr=LR,
    n_epochs=50,
)