In [None]:
%load_ext autoreload
%autoreload 2

import geopandas as gpd
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd

from typing import Final
from pathlib import Path

# Directory used to cache downloaded map data.
# "~" must be expanded explicitly: Path.resolve() does not interpret it and would
# otherwise point at a literal "~" directory below the current working directory.
CACHE_DIR: Final[Path] = Path('~/.cache/us_maps').expanduser().resolve()
CACHE_DIR.mkdir(exist_ok=True, parents=True)


def plot_us_state_choropleth(
    df: pd.DataFrame,
    value_col: str,
    cmap: str = 'OrRd',
    set_col: str | None = None,
    cache_dir: Path = CACHE_DIR,
    out_path: Path | None = None,
):
    """
    Plots a U.S. state choropleth map with optional border color coding by set.

    The state geometries are read from a GeoJSON file in `cache_dir`; the file is
    downloaded on first use. The figure is shown and, if `out_path` is given, saved.

    Parameters:
    - df: Pandas DataFrame with state names (index) and a statistics column.
    - value_col: Name of the statistics column.
    - cmap: Matplotlib colormap to use (default: 'OrRd').
    - set_col: Optional column name for border coloring with values: 'backgnd', 'trial', 'val'.
    - cache_dir: Directory to store or load the cached GeoJSON.
    - out_path: Optional output path (pdf, png, etc) for the figure
    """
    geojson_path = cache_dir / 'us_states.geojson'

    if not geojson_path.exists():
        url = 'https://raw.githubusercontent.com/jgoodall/us-maps/master/geojson/state.geo.json'
        print("Downloading U.S. states GeoJSON...")
        import requests
        # A caller-supplied cache_dir may not exist yet.
        cache_dir.mkdir(parents=True, exist_ok=True)
        # Without a timeout a stalled connection would block the notebook forever.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        geojson_path.write_bytes(response.content)
        print("Done downloading")

    usa_states = gpd.read_file(geojson_path)

    # Normalize index casing so it matches the state names in the GeoJSON
    df = df.copy()
    df.index = df.index.str.title()

    # Rename and merge: geometries on the left, statistics joined by state name
    usa_states = usa_states.rename(columns={'NAME10': 'state'})
    merged = usa_states.set_index('state').join(df)

    fig, ax = plt.subplots(1, 1, figsize=(15, 10))

    # Base plot, without the automatic legend (a smaller colorbar is added below)
    merged.plot(column=value_col, cmap=cmap, linewidth=0.8, ax=ax, edgecolor='0.8', legend=False)

    # Scalar mappable spanning the plotted value range, used only for the colorbar
    norm = mpl.colors.Normalize(vmin=merged[value_col].min(), vmax=merged[value_col].max())
    sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    fig.colorbar(sm, ax=ax, fraction=0.025, pad=0.04)

    linewidth: Final[float] = 2.5

    if set_col and set_col in merged.columns:
        # set value -> (border color, legend text, legend y position in axes coordinates)
        set_styles = {
            'backgnd': ('blue', 'background', 0.24),
            'trial': ('green', 'trial', 0.17),
            'val': ('orange', 'validate', 0.1),
        }

        # Outline the states of each set in that set's color
        for set_value, (color, _, _) in set_styles.items():
            subset = merged[merged[set_col] == set_value]
            if not subset.empty:
                subset.boundary.plot(ax=ax, edgecolor=color, linewidth=linewidth)

        # Legend boxes as text annotations whose frame uses the set's border color
        for color, legend_text, y_pos in set_styles.values():
            ax.text(
                0.1,
                y_pos,
                legend_text,
                transform=ax.transAxes,
                fontsize=14,
                bbox=dict(
                    facecolor="white",
                    edgecolor=color,
                    boxstyle="round,pad=0.3",
                    alpha=0.8,
                    linewidth=linewidth,
                ),
            )

    ax.set_title(f'U.S. States Colored by {value_col}', fontsize=16)
    ax.axis('off')
    if out_path:
        # Save through the figure object rather than pyplot's implicit "current figure"
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.show()

In [None]:
# Adjacency list of U.S. states: each state name maps to the list of states it is
# recorded as bordering. Alaska and Hawaii have empty lists.
# Looked up by share_common_border_us_states below.
state_borders = {
        'Alabama': ['Florida', 'Georgia', 'Mississippi', 'Tennessee'],
        'Alaska': [],
        'Arizona': ['California', 'Colorado', 'Nevada', 'New Mexico', 'Utah'],
        'Arkansas': ['Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'Tennessee', 'Texas'],
        'California': ['Arizona', 'Nevada', 'Oregon'],
        'Colorado': ['Arizona', 'Kansas', 'Nebraska', 'New Mexico', 'Oklahoma', 'Utah', 'Wyoming'],
        'Connecticut': ['Massachusetts', 'New York', 'Rhode Island'],
        'Delaware': ['Maryland', 'New Jersey', 'Pennsylvania'],
        'Florida': ['Alabama', 'Georgia'],
        'Georgia': ['Alabama', 'Florida', 'North Carolina', 'South Carolina', 'Tennessee'],
        'Hawaii': [],
        'Idaho': ['Montana', 'Nevada', 'Oregon', 'Utah', 'Washington', 'Wyoming'],
        'Illinois': ['Indiana', 'Iowa', 'Michigan', 'Kentucky', 'Missouri', 'Wisconsin'],
        'Indiana': ['Illinois', 'Kentucky', 'Michigan', 'Ohio'],
        'Iowa': ['Illinois', 'Minnesota', 'Missouri', 'Nebraska', 'South Dakota', 'Wisconsin'],
        'Kansas': ['Colorado', 'Missouri', 'Nebraska', 'Oklahoma'],
        'Kentucky': ['Illinois', 'Indiana', 'Missouri', 'Ohio', 'Tennessee', 'Virginia', 'West Virginia'],
        'Louisiana': ['Arkansas', 'Mississippi', 'Texas'],
        'Maine': ['New Hampshire'],
        'Maryland': ['Delaware', 'Pennsylvania', 'Virginia', 'West Virginia'],
        'Massachusetts': ['Connecticut', 'New Hampshire', 'New York', 'Rhode Island', 'Vermont'],
        'Michigan': ['Illinois', 'Indiana', 'Minnesota', 'Ohio', 'Wisconsin'],
        'Minnesota': ['Iowa', 'Michigan', 'North Dakota', 'South Dakota', 'Wisconsin'],
        'Mississippi': ['Alabama', 'Arkansas', 'Louisiana', 'Tennessee'],
        'Missouri': ['Arkansas', 'Illinois', 'Iowa', 'Kansas', 'Kentucky', 'Nebraska', 'Oklahoma', 'Tennessee'],
        'Montana': ['Idaho', 'North Dakota', 'South Dakota', 'Wyoming'],
        'Nebraska': ['Colorado', 'Iowa', 'Kansas', 'Missouri', 'South Dakota', 'Wyoming'],
        'Nevada': ['Arizona', 'California', 'Idaho', 'Oregon', 'Utah'],
        'New Hampshire': ['Maine', 'Massachusetts', 'Vermont'],
        'New Jersey': ['Delaware', 'New York', 'Pennsylvania'],
        'New Mexico': ['Arizona', 'Colorado', 'Oklahoma', 'Texas', 'Utah'],
        'New York': ['Connecticut', 'Massachusetts', 'New Jersey', 'Pennsylvania', 'Vermont'],
        'North Carolina': ['Georgia', 'South Carolina', 'Tennessee', 'Virginia'],
        'North Dakota': ['Minnesota', 'Montana', 'South Dakota'],
        'Ohio': ['Indiana', 'Kentucky', 'Michigan', 'Pennsylvania', 'West Virginia'],
        'Oklahoma': ['Arkansas', 'Colorado', 'Kansas', 'Missouri', 'New Mexico', 'Texas'],
        'Oregon': ['California', 'Idaho', 'Nevada', 'Washington'],
        'Pennsylvania': ['Delaware', 'Maryland', 'New Jersey', 'New York', 'Ohio', 'West Virginia'],
        'Rhode Island': ['Connecticut', 'Massachusetts'],
        'South Carolina': ['Georgia', 'North Carolina'],
        'South Dakota': ['Iowa', 'Minnesota', 'Montana', 'Nebraska', 'North Dakota', 'Wyoming'],
        'Tennessee': ['Alabama', 'Arkansas', 'Georgia', 'Kentucky', 'Mississippi', 'Missouri', 'North Carolina', 'Virginia'],
        'Texas': ['Arkansas', 'Louisiana', 'New Mexico', 'Oklahoma'],
        'Utah': ['Arizona', 'Colorado', 'Idaho', 'Nevada', 'New Mexico', 'Wyoming'],
        'Vermont': ['Massachusetts', 'New Hampshire', 'New York'],
        'Virginia': ['Kentucky', 'Maryland', 'North Carolina', 'Tennessee', 'West Virginia'],
        'Washington': ['Idaho', 'Oregon'],
        'West Virginia': ['Kentucky', 'Maryland', 'Ohio', 'Pennsylvania', 'Virginia'],
        'Wisconsin': ['Illinois', 'Iowa', 'Michigan', 'Minnesota'],
        'Wyoming': ['Colorado', 'Idaho', 'Montana', 'Nebraska', 'South Dakota', 'Utah']
}

def share_common_border_us_states(s1: str, s2: str) -> bool:
    """Tell whether `s2` is listed as a neighbour of `s1` in `state_borders`.

    Both names are stripped of surrounding whitespace and title-cased before the lookup.

    Raises:
        ValueError: if either normalized name is not a key of `state_borders`.
    """
    first, second = (name.strip().title() for name in (s1, s2))

    if not (first in state_borders and second in state_borders):
        raise ValueError(f"State '{first}' or '{second}' is not a valid U.S. state")

    neighbours = state_borders[first]
    return second in neighbours

In [None]:
import kagglehub

# Prefer the local kagglehub cache (version 3 of the dataset), located relative to
# the current user's home directory instead of a machine-specific absolute path.
path = Path.home() / ".cache/kagglehub/datasets/justin2028/unemployment-in-america-per-us-state/versions/3"
if not path.exists():
    # Not cached on this machine: let kagglehub fetch the dataset and report where it is.
    path = Path(kagglehub.dataset_download("justin2028/unemployment-in-america-per-us-state"))
print("Path to dataset files:", path)

In [None]:
# Read the first CSV file found in the dataset directory.
csv_candidates = list(Path(path).glob("*.csv"))
data = pd.read_csv(csv_candidates[0])
example_year: Final[int] = 1976
example_month: Final[int] = 1

# Keep rows for names present in state_borders, excluding Alaska and Hawaii.
is_known_state = data["State/Area"].isin(state_borders.keys())
is_excluded_state = data["State/Area"].isin(["Alaska", "Hawaii"])
filtered_data = data[is_known_state & ~is_excluded_state]
# Index by state name while keeping the "State/Area" column itself.
filtered_data = filtered_data.set_index("State/Area", drop=False)

x_cols = [
    "Total Civilian Labor Force in State/Area",
    "Total Civilian Non-Institutional Population in State/Area",
    "Percent (%) of State/Area's Population",
]
y_col = "Percent (%) of Labor Force Employed in State/Area"

# The first two feature columns are strings with thousands separators: strip the
# commas and parse them as numbers (unparseable values become NaN).
for count_col in x_cols[:2]:
    filtered_data[count_col] = pd.to_numeric(
        filtered_data[count_col].str.replace(",", ""), errors="coerce"
    )


def subset_month(year: int, month: int) -> pd.DataFrame:
    """Return the rows of `filtered_data` for the given year and month.

    An independent copy is returned, so callers can add or modify columns without
    writing into a slice of `filtered_data` (which triggers SettingWithCopyWarning).
    """
    in_month = (filtered_data["Year"] == year) & (filtered_data["Month"] == month)
    return filtered_data[in_month].copy()



# Example slice (example_year / example_month) used for the exploratory plots below.
data_month = subset_month(year=example_year, month=example_month)
data_month.head()

In [None]:
# Map of the target column (y_col) for the example month.
plot_us_state_choropleth(data_month, value_col=y_col)

In [None]:
import seaborn as sns 

# Short display names, in the same order as x_cols followed by y_col.
x_cols_short = ["Labor Force", "Population", "Percent Eligible"]
y_col_short = "Percent Employed"
cols_short = [*x_cols_short, y_col_short]
long_to_short = dict(zip([*x_cols, y_col], cols_short))
renamed_data = data_month.rename(columns=long_to_short)
# sns.pairplot(renamed_data, vars=cols_short)
# plt.savefig("employment_pairplot.pdf")

The first two features are almost perfectly correlated, so we drop the first one (the labor force count).

In [None]:
# Feature lists without their first entry (full and short names kept in step).
x_cols_reduced = x_cols[1:]
x_cols_short_reduced = x_cols_short[1:]
# sns.scatterplot(data_month, x=x_cols_reduced[0], y=x_cols_reduced[1], hue=y_col)

In [None]:
from itertools import product
from collections import defaultdict
from typing import Final
import numpy as np
import pandas as pd

# # Northwestern US states
# query_states = [
#     "Washington", "Oregon", "Idaho", "Montana", "Wyoming",
#     "North Dakota", "South Dakota", "Nebraska", "Minnesota", "Iowa",
#     "Colorado", "Utah", "Nevada", "Kansas", "Missouri"
# ]

# # Eastern US states
# val_states = [
#     "Maine", "New Hampshire", "Vermont", "Massachusetts", "Connecticut",
#     "New York", "New Jersey", "Pennsylvania", "Delaware", "Maryland",
#     "Rhode Island", "Virginia", "North Carolina", "South Carolina", "Georgia"
# ]

np.random.seed(42)
# Sort before shuffling: a set of strings has no stable iteration order across
# interpreter runs (string hashing is randomized per process), so shuffling the raw
# set would give a different split after every kernel restart despite the seed.
states_shuffled = sorted(set(filtered_data.index))
np.random.shuffle(states_shuffled)
# One third of the states each for the "trial" and "val" sets; the rest stay background.
n_query = n_val = len(states_shuffled) // 3

query_states = states_shuffled[:n_query]
val_states = states_shuffled[n_query : n_query + n_val]


set_col: Final[str] = "set"
# Work on an explicit copy so the new column is never written into a slice of filtered_data.
data_month = data_month.copy()
data_month[set_col] = "backgnd"
data_month.loc[query_states, set_col] = "trial"
data_month.loc[val_states, set_col] = "val"

# plot_us_state_choropleth(data_month, value_col=y_col, set_col=set_col, out_path=Path(f"emloyment_{example_year}_{example_month}.png"))

In [None]:
from tqdm import tqdm
from tabrel.benchmark.nw_regr import MlpConfig, run_training
from tabrel.train import train_relnet

# Learning rate and number of epochs passed to run_training below.
LR: Final[float] = 0.006
N_EPOCHS: Final[int] = 5000

def compute_relation_matrix(states: list[str]) -> np.ndarray:
    """Build the square 0/1 adjacency matrix of `states`.

    Entry (i, j) is 1 when share_common_border_us_states(states[i], states[j]) is true
    and 0 otherwise; rows and columns follow the order of `states`.
    """
    size = len(states)
    adjacency = np.zeros((size, size))
    for row, first in enumerate(states):
        for col, second in enumerate(states):
            if share_common_border_us_states(first, second):
                adjacency[row, col] = 1
    return adjacency


# Metrics collected per method label: one entry appended per trained month.
results = defaultdict(list)

# For each year, train on the first month that has data for every query and
# validation state (see the `break` at the end of the inner loop).
for year in tqdm(filtered_data["Year"].unique()):
    for month in range(1, 13):
        data_month: pd.DataFrame = subset_month(year, month)
        if data_month.empty:
            continue

        # try:
        # Positional row indices of the query / validation states in this month's frame.
        query_indices = data_month.index.get_indexer(query_states)
        val_indices = data_month.index.get_indexer(val_states)
        # Skip months in which any of those states is missing (-1 entries).
        if np.any(query_indices == -1) or np.any(val_indices == -1):
            continue

        # Every row that is neither query nor validation is background.
        n_samples = len(data_month)
        all_indices = np.arange(n_samples)
        nonbackgnd_indices = np.array([*query_indices, *val_indices])
        backgnd_indices = np.setdiff1d(all_indices, nonbackgnd_indices)

        x = data_month[x_cols_reduced].to_numpy()
        y = data_month[y_col].to_numpy()

        # State adjacency matrix in the row order of this month's frame.
        r = compute_relation_matrix(list(data_month.index))

        month_results = run_training(
            x=x,
            y=y,
            r=r,
            backgnd_indices=backgnd_indices,
            query_indices=query_indices,
            val_indices=val_indices,
            lr=LR,
            n_epochs=N_EPOCHS,
            rel_as_feats=r,
            mlp_config=MlpConfig(
                        in_dim=x.shape[1],
                        hidden_dim=10,
                        out_dim=20,
                        dropout=.19,
                    ),
        )
        for k, v in month_results.items():
            results[k].append(v)
        
        # NOTE(review): this helper is redefined on every iteration but never called —
        # its only call sites are the commented-out lines below. It also passes
        # query_indices as backgnd_indices and backgnd_indices as query_indices,
        # the reverse of the run_training call above — confirm whether that is intended.
        def train_relnet_shorthand(n_layers: int) -> tuple[float, float]:
            return train_relnet(
                    x=x,
                    y=y,
                    r=r,
                    backgnd_indices=query_indices,
                    query_indices=backgnd_indices,
                    val_indices=val_indices,
                    lr=0.003,
                    n_epochs=5000,
                    n_layers=n_layers,
                    progress_bar=False,
                    print_loss=False,
                )[:2]
            # relnet_results = train_relnet_shorthand(2)
            # if relnet_results[0] > 100:
            #     relnet_results = train_relnet_shorthand(1)
            # results["relnet"].append([*relnet_results, None])
            

        # except Exception as e:
        #     print(f"Skipping year={year}, month={month}: {e}")
        #     continue
        # Stop after the first month that was actually trained on, so each year contributes at most one run.
        break
    # break  # to make execution faster; remove to reproduce results


In [None]:
def to_df(metrics_results: dict) -> pd.DataFrame:
    """Summarize per-run metrics as one row per label.

    For every label, only the first two entries of each run are used; the "means"
    and "std" cells hold the element-wise mean and standard deviation over the runs.
    """
    rows = []
    for label, runs in metrics_results.items():
        leading_pairs = np.array([run[:2] for run in runs])
        rows.append(
            {
                "label": label,
                "means": leading_pairs.mean(axis=0),
                "std": leading_pairs.std(axis=0),
            }
        )
    return pd.DataFrame(rows)

# Persist the per-label summary of the US unemployment runs.
to_df(results).to_csv("usUnemployment_results_firstMonth_ave.csv")

# World

In [None]:
import kagglehub
import pycountry
from pathlib import Path

# Download latest version
# NOTE: this rebinds `path`, which earlier pointed at the unemployment dataset.
path = kagglehub.dataset_download("mlippo/average-global-iq-per-country-with-other-stats")

# Read the first CSV file found in the downloaded dataset directory.
df = pd.read_csv(list(Path(path).glob("*.csv"))[0])

def clean_data(df):
    """Return a copy of `df` whose 'Population - 2023' column is parsed as int64.

    The column is expected to hold strings; commas and dots are removed before the
    conversion. The input frame is left unmodified.
    """
    population = df['Population - 2023'].str.replace('[,.]', '', regex=True)
    # assign() builds a new frame instead of writing into the caller's one.
    return df.assign(**{'Population - 2023': population.astype('int64')})

def _iso_alpha3(country_name):
    """Return the alpha-3 code pycountry has for `country_name`, or None if there is no match."""
    match = pycountry.countries.get(name=country_name)
    return match.alpha_3 if match else None


df_clean = clean_data(df.copy())
df_clean['ISO_alpha'] = df_clean['Country'].apply(_iso_alpha3)


In [None]:
import plotly.express as px

# World choropleth of the 'Literacy Rate' column, located by the ISO alpha-3 codes in df_clean.
fig_geo = px.choropleth(
    df_clean,
    locations='ISO_alpha',
    color='Literacy Rate',
    color_continuous_scale=px.colors.sequential.YlOrRd,
    # labels={'Average IQ': 'Average IQ'},
    # title='Average IQ by Country',
)
# fig_geo.update_layout(title_x=0.5)
fig_geo.show()


In [None]:
# Export the choropleth as a static PNG.
fig_geo.write_image("literacy_rate_by_country.png", width=1000, height=600, scale=3)

In [None]:
import geopandas as gpd

# Load shapefile (adjust the path to where you extracted the data)
# NOTE(review): absolute, machine-specific path — this cell fails on any other machine.
world = gpd.read_file("/Users/user/Documents/git/gh_zuevval/tabrel/data/ne_50m_admin_0_countries/ne_50m_admin_0_countries.shp")

# Drop rows whose ISO_A3_EH code is the '-99' placeholder
world = world[world['ISO_A3_EH'] != '-99']  # Remove invalid entries
len(world)

In [None]:
from tabrel.utils.geo import build_border_map, share_common_border, get_connected_country_set, build_r_countries
# Border lookup built from the country geometries; passed to the tabrel geo helpers below.
border_map = build_border_map(world)

In [None]:

# Sanity checks of the border map; the expected result is given after each call.
print(share_common_border("FRA", "DEU", border_map))  # True
print(share_common_border("FRA", "USA", border_map))  # False
print(share_common_border("MEX", "GTM", border_map))  # True


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
import torch
from tqdm import tqdm

feature_cols = ["Mean years of schooling - 2021"]
target_col = "Literacy Rate"

# Complete rows only (this also drops countries without an ISO code), indexed by ISO alpha-3.
df_filtered = df_clean.dropna().set_index("ISO_alpha")

X = df_filtered[feature_cols].to_numpy()
y = df_filtered[target_col].to_numpy()

# Relation matrix between countries plus the ISO code list it refers to.
R, iso_list = build_r_countries(df_filtered, border_map)
all_isos = set(iso_list)

# Metrics per method label; this rebinds `results` from the US experiment above.
results = defaultdict(list)
# Try every unordered pair of countries as (query seed, validation seed): grow a
# connected set of countries around each seed and train with the rest as background.
# NOTE(review): the positional indices below come from iso_list but are applied to
# X / y, which follow the row order of df_filtered — confirm build_r_countries
# returns iso_list in that same order.
for query_iso, val_iso in tqdm(
    list(combinations(iso_list, 2))
    # [("MNE", "EGY")],  # for execution speed
    ):
    max_query_size, max_val_size = 40, 30
    query_set = get_connected_country_set(query_iso, border_map, max_size=max_query_size)
    validate_set = get_connected_country_set(val_iso, border_map, max_size=max_val_size)

    if query_set & validate_set or len(query_set) < max_query_size - 5 or len(validate_set) < max_val_size - 5:
        continue  # skip overlapping sets and those with too small connected countries

    backgnd_set = all_isos - query_set - validate_set

    # Positions in iso_list of the countries belonging to each set.
    backgnd_indices = [i for i, iso in enumerate(iso_list) if iso in backgnd_set]
    query_indices = [i for i, iso in enumerate(iso_list) if iso in query_set]
    val_indices = [i for i, iso in enumerate(iso_list) if iso in validate_set]

    # Any training failure is printed and the pair is skipped.
    try:
        res = run_training(
            x=X, y=y, r=R,
            backgnd_indices=np.array(backgnd_indices),
            query_indices=np.array(query_indices),
            val_indices=np.array(val_indices),
            lr=LR,
            n_epochs=50,
            rel_as_feats=R,
        )
        # res["relnet"] = [*train_relnet(
        #     x=X,
        #     y=y,
        #     r=R,
        #     backgnd_indices=np.array(backgnd_indices),
        #     query_indices=np.array(query_indices),
        #     val_indices=val_indices,
        #     lr=.01,
        #     n_epochs=1000,
        #     progress_bar=False,
        #     print_loss=False,
        #     n_layers=1,
        #     num_heads=2,
        #     embed_dim=8,
        # )[:2], None]
    except Exception as e:
        print(e)
        continue
    
    # Second entry of the no-relation result is used as its R^2 (see the message below).
    r2_norel = res["rel=False;trainable_w=False;mlp=False"][1]

    # Discard splits on which even the no-relation model scores below 0.2.
    if r2_norel < .2:
        print(f"bad R2 for no-rel: {r2_norel}")
        continue

    for k, v in res.items():
            results[k].append(v)

    label = f"\nQuery seed: {query_iso}, Validation seed: {val_iso}"
    print(label)

    # Last entry of each result; only referenced by the commented-out plotting code.
    fitted_rel = res["rel=True;trainable_w=False;mlp=False"][-1]
    fitted_norel = res["rel=False;trainable_w=False;mlp=False"][-1]

    # # Plot
    # x_feat = X[:, 0]
    # plt.figure(figsize=(6, 4))
    # plt.scatter(x_feat[query_indices], y[query_indices], label="Trial", color="blue")
    # plt.scatter(x_feat[val_indices], y[val_indices], label="Validate", color="orange")
    # plt.scatter(x_feat[backgnd_indices], y[backgnd_indices], label="Background", color="gray")

    # # Model prediction line (from min to max of x)
    # x_min, x_max = x_feat.min(), x_feat.max()
    # x_grid = torch.linspace(x_min, x_max, steps=200).unsqueeze(1)

    # x_back = torch.tensor(X[backgnd_indices], dtype=torch.float32)
    # y_back = torch.tensor(y[backgnd_indices], dtype=torch.float32)

    # y_grid_norel = fitted_norel.model(x_back, y_back, x_grid, torch.zeros((len(x_grid), len(backgnd_indices))))
    # # plt.plot(x_grid.numpy(), y_grid_norel.detach().numpy(), "--", label="Model fit (no rel)", color="red")

    # plt.xlabel(feature_cols[0])
    # plt.ylabel(target_col)
    # # plt.title(f"{query_iso} as Query, {val_iso} as Val")
    # plt.legend()
    # plt.tight_layout()
    # plt.savefig(f"q{query_iso}_val{val_iso}_literacy.pdf")
    # plt.show()


In [None]:
# Per-label mean / std summary of the world literacy runs.
to_df(results)

In [None]:
# NOTE(review): mses_rel, mses_norel, r2s_rel, r2s_norel, plot_violin and colors are
# not defined anywhere in the visible part of this notebook — this cell relies on
# them existing in the kernel and will fail on a fresh "Restart & Run All".
fig, axes = plt.subplots(1, 2, figsize=(5, 5), sharex=True)
labels = ("rel", 
          "norel", 
        #   "rel feats"
          )
# The loop variables are named so that they do not overwrite the notebook-level
# `data` frame loaded earlier.
for ax, (panel_title, panel_values) in zip(
    axes,
    (
        ("Validation MSE", [np.array(mses_rel), 
                            np.array(mses_norel), 
                            # np.array(mse_rel_feats)
                            ]),
        ("Validation $R^2$", [np.array(r2s_rel), 
                              np.array(r2s_norel),
                            #   np.array(r2s_rel_as_feats)
                                ]),
    ),
):
    plot_violin(ax, labels, panel_values, title=panel_title, color=colors[panel_title])

# Lay out and save once, after both panels are drawn (previously done on every iteration).
plt.tight_layout()
plt.savefig("literacy_rate_violinplots.pdf")


In [None]:
# Element-wise: does the relation-aware run reach a higher R^2 than the no-relation run?
rel_scores, norel_scores = np.array(r2s_rel), np.array(r2s_norel)
cases_win = rel_scores > norel_scores
print(cases_win.sum(), len(cases_win))

# Birds

In [None]:
# Prefer the local kagglehub cache (version 1 of the dataset), located relative to
# the current user's home directory instead of a machine-specific absolute path.
path_species = Path.home() / ".cache/kagglehub/datasets/mexwell/bird-genetic-diversity/versions/1"
if not path_species.exists():
    # Not cached on this machine: let kagglehub fetch the dataset and report where it is.
    path_species = Path(kagglehub.dataset_download("mexwell/bird-genetic-diversity"))
print(path_species)

# Read the first entry of the dataset directory and look at all pairwise relationships.
birds_df = pd.read_csv(next(path_species.glob("*")))
sns.pairplot(birds_df)

In [None]:
# The two predictors against each other, colored by the target (allelic richness).
sns.scatterplot(birds_df, x="Body mass", y="Breeding range size", hue="Allelic richness")

In [None]:
import plotly.express as px

# 3D view of the two predictors and the target, colored by the target.
fig = px.scatter_3d(
    birds_df,
    x="Body mass",
    y="Breeding range size",
    z="Allelic richness",
    color="Allelic richness",  # You can use another column if you prefer
    opacity=0.7,
)

# Put all three scene axes on a log scale.
log_axes = {f"{axis}axis_type": "log" for axis in ("x", "y", "z")}
fig.update_layout(scene=log_axes)

fig.show()

In [None]:
# Replace the raw measurements by their natural logarithms, keeping the species name.
birds_df = birds_df[["Species"]].assign(
    bm_log=np.log(birds_df["Body mass"]),
    range_log=np.log(birds_df["Breeding range size"]),
    richness_log=np.log(birds_df["Allelic richness"]),
)
birds_df

In [None]:
# Taxonomy table: read the first entry of the downloaded dataset directory.
path_taxa = Path(kagglehub.dataset_download("willianoliveiragibin/animal-analyzing"))
taxa_file = next(path_taxa.glob("*"))
# "Kingdom", "Subphylum", "Class" are left out: all birds belong to the same class - Aves
taxonomy_cols = ["Order", "Family", "Genus", "Species"]
df_taxa = pd.read_csv(taxa_file)[taxonomy_cols]
# Keep only the species present in both tables.
birds_df_merged = birds_df.merge(df_taxa, on="Species", how="inner")
birds_df_merged

In [None]:
# Distinct taxonomic orders present after the merge.
orders = birds_df_merged["Order"].unique()
orders

In [None]:
import numpy as np
from pycirclize import Circos
from matplotlib.patches import Patch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import to_hex

# --- Preprocessing ---
# Keep only orders represented by at least 9 species
species_counts = birds_df_merged.groupby("Order")["Species"].count()
valid_orders = species_counts[species_counts >= 9].index

# Filter to those orders, then sort so related species sit next to each other
# (by Family, then Genus, then Species) within every order
merged_df = (
    birds_df_merged.copy()
    .loc[lambda frame: frame["Order"].isin(valid_orders)]
    .sort_values(["Order", "Family", "Genus", "Species"])
)

# Shift every log variable so that its minimum becomes zero
for log_col in ("bm_log", "range_log", "richness_log"):
    merged_df[log_col] -= merged_df[log_col].min()

# Value ranges shared by all sectors, so bars are comparable between orders
bm_min, bm_max = merged_df['bm_log'].min(), merged_df['bm_log'].max()
range_min, range_max = merged_df['range_log'].min(), merged_df['range_log'].max()
richness_min, richness_max = merged_df['richness_log'].min(), merged_df['richness_log'].max()

# One sector per order, sized by its number of species
sector_sizes = merged_df.groupby('Order').size().to_dict()
circos = Circos(sector_sizes, space=2)

# Use local color maps per sector for better contrast
from collections import defaultdict
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is the equivalent lookup that is still available.
from matplotlib.pyplot import get_cmap

# Store maps: (order -> {family: color}) and (order -> {genus: color})
order_family_colors = defaultdict(dict)
order_genus_colors = defaultdict(dict)

family_cmaps = [get_cmap(cmap_name) for cmap_name in ("tab10", "Set1", "Dark2")]
genus_cmaps = [get_cmap(cmap_name) for cmap_name in ("tab20", "Paired", "tab20c")]

# Cycle through a few high-contrast maps
family_map_count = len(family_cmaps)
genus_map_count = len(genus_cmaps)

for sector_idx, order in enumerate(merged_df["Order"].unique()):
    sub_df = merged_df[merged_df["Order"] == order]

    families = sub_df["Family"].unique()
    genera = sub_df["Genus"].unique()

    # Neighbouring sectors get different color maps
    fam_cmap = family_cmaps[sector_idx % family_map_count]
    gen_cmap = genus_cmaps[sector_idx % genus_map_count]

    # Colors wrap around when an order has more families / genera than the map has entries
    for i, fam in enumerate(families):
        order_family_colors[order][fam] = to_hex(fam_cmap(i % fam_cmap.N))

    for i, gen in enumerate(genera):
        order_genus_colors[order][gen] = to_hex(gen_cmap(i % gen_cmap.N))


# --- Build Circos Plot ---
# Thin outer rings that color every species by taxonomic level:
# (radial range, column, per-order color lookup)
label_tracks = (
    ((98, 100), "Family", order_family_colors),
    ((95, 97), "Genus", order_genus_colors),
)
# Bar tracks for the three log variables, drawn on scales shared across sectors:
# (radial range, column, bar color, shared minimum, shared maximum)
value_tracks = (
    ((30, 49), 'bm_log', 'blue', bm_min, bm_max),
    ((50, 69), 'range_log', 'green', range_min, range_max),
    ((70, 94), 'richness_log', 'red', richness_min, richness_max),
)

for sector in circos.sectors:
    order = sector.name
    sub_df = merged_df[merged_df['Order'] == order].reset_index(drop=True)
    # One bar per species, centered in its unit-wide slot
    x = np.arange(len(sub_df)) + 0.5

    for radii, level, palette in label_tracks:
        level_colors = sub_df[level].map(palette[order]).values
        track = sector.add_track(radii)
        track.axis()
        track.bar(x, np.ones_like(x), color=level_colors, width=1)

    for radii, column, bar_color, lowest, highest in value_tracks:
        track = sector.add_track(radii)
        track.axis()
        track.bar(x, sub_df[column].values, color=bar_color, width=0.6, vmin=lowest, vmax=highest)

    # Add order label just outside the outermost ring
    sector.text(order, r=102, size=8)

# Render the circos layout into a matplotlib figure
fig = circos.plotfig(dpi=300)

# Legend for the three bar tracks: bar color -> variable shown
legend_entries = {
    'blue': 'Body mass (log)',
    'green': 'Breeding range (log)',
    'red': 'Genetic richness (log)',
}
legend_patches = [Patch(color=bar_color) for bar_color in legend_entries]
fig.legend(
    legend_patches,
    list(legend_entries.values()),
    loc='upper left',
    fontsize=12,
)
plt.tight_layout()

fig.savefig("circos_species_by_order.pdf", bbox_inches='tight')


In [None]:
import numpy as np

# Order / family of every species as strings, plus integer codes of the same labels.
orders = merged_df["Order"].astype(str).values
orders_int = pd.factorize(orders)[0]
families = merged_df["Family"].astype(str).values
families_int = pd.factorize(families)[0]
n = len(merged_df)

# Build r_birds matrix: 1.0 for species of the same family, 0.5 for the same order
# (but different family), 0.0 otherwise. Broadcasting a column against a row
# compares every pair at once instead of looping over n * n pairs in Python.
family_labels = np.asarray(families, dtype=object)
order_labels = np.asarray(orders, dtype=object)
same_family = family_labels[:, None] == family_labels[None, :]
same_order = order_labels[:, None] == order_labels[None, :]
r_birds = np.where(same_family, 1.0, np.where(same_order, 0.5, 0.0))

# Alternative encoding of the relations: one (order code, family code) row per species.
r_as_feats = np.array([orders_int, families_int]).T

# Metrics per method label, one entry appended per seed.
birds_metrics = defaultdict(list)
for seed in tqdm(range(30), "seeds"):
    # Seeded random split: 30% validation ("test"), 30% query, remainder background.
    np.random.seed(seed)
    indices = np.random.permutation(n)
    n_test = int(0.3 * n)
    n_query = int(0.3 * n)
    test_indices = indices[:n_test]
    query_indices = indices[n_test:n_test + n_query]
    back_indices = indices[n_test + n_query:]

    # Prepare data: two log-scaled predictors, log richness as the target
    X = merged_df[["bm_log", "range_log"]].to_numpy()
    y = merged_df["richness_log"].to_numpy()

    res_birds = run_training(
        x=X,
        y=y,
        r=r_birds,
        backgnd_indices=back_indices,
        query_indices=query_indices,
        val_indices=test_indices,
        lr=0.05,
        n_epochs=200,
        rel_as_feats=r_as_feats,
    )
    # Seed torch separately before the relnet run.
    torch.manual_seed(seed)
    res_birds["relnet"] = train_relnet(
        x=X,
        y=y,
        r=r_birds,
        backgnd_indices=back_indices,
        query_indices=query_indices,
        val_indices=test_indices,
        lr=0.002,
        n_epochs=2000,
        progress_bar=False,
        print_loss=False,
        n_layers=2,
        num_heads=2,
        embed_dim=32,
        periodic_embed_dim=None,
    )

    for k, v in res_birds.items():
        birds_metrics[k].append(v)

In [None]:
# Display labels, expected to line up with the insertion order of birds_metrics.
birds_labels = ("rel", "nrel", "lgb", "rel-ft", "lgb-r", "relnet")

# Per panel: one array of per-seed values for every method (entry 0 = MSE, entry 1 = R^2).
birds_metrics_all = {
    "Validation MSE": [np.array([run[0] for run in runs]) for runs in birds_metrics.values()],
    "Validation $R^2$": [np.array([run[1] for run in runs]) for runs in birds_metrics.values()],
}

fig, axes = plt.subplots(1, len(birds_metrics_all), figsize=(5, 5), sharex=True)

for ax, (label, data) in zip(axes, birds_metrics_all.items()):
    plot_violin(ax, birds_labels, data, title=label, color=colors[label])

plt.tight_layout()
plt.savefig("birds_violinplots.pdf")
plt.show()

In [None]:
# NOTE(review): metrics_mean is not defined in the visible part of this notebook — confirm where it comes from.
metrics_mean(birds_metrics_all, birds_labels)

In [None]:
import statsmodels.api as sm
import statsmodels.formula.api as smf

# Make sure the taxonomic levels are treated as categorical variables
for level in ("Order", "Family", "Genus"):
    merged_df[level] = merged_df[level].astype("category")


def _one_way_anova(level):
    """Fit richness_log on a single taxonomic level; return the fit and its type-2 ANOVA table."""
    fitted_model = smf.ols(f"richness_log ~ C({level})", data=merged_df).fit()
    return fitted_model, sm.stats.anova_lm(fitted_model, typ=2)


# Order-level model
model_order, anova_order = _one_way_anova("Order")
print(anova_order)

# Family-level model
model_family, anova_family = _one_way_anova("Family")
print(anova_family)

# Genus-level model
model_genus, anova_genus = _one_way_anova("Genus")
print(anova_genus)

In [None]:
# Mixed-effects model: log richness explained by the two log predictors, with
# Order as the grouping factor and Family / Genus as variance components.
model = smf.mixedlm(
    "richness_log ~ bm_log + range_log",         # fixed effects
    data=merged_df,
    groups=merged_df["Order"],                   # main grouping
    re_formula="1",                              # random intercepts
    vc_formula={                                 # variance components
        "Family": "0 + C(Family)",
        "Genus": "0 + C(Genus)"
    }
)

# Fit and show the full summary table.
result = model.fit()
print(result.summary())

In [None]:
import matplotlib.pyplot as plt

coefs = result.fe_params
conf_int = result.conf_int()

# Pick the interval rows by fixed-effect name rather than assuming they are the
# first three rows; this stays correct if the fixed-effects formula changes.
upper_bound = conf_int.loc[coefs.index, 1]

plt.figure(figsize=(6, 4))
# Error bar length = distance from the estimate to the upper confidence bound.
plt.errorbar(coefs.index, coefs.values,
             yerr=(upper_bound - coefs).values,
             fmt='o', capsize=5, color='black')
plt.axhline(0, color='gray', linestyle='--')
plt.title("Fixed Effects Estimates")
plt.ylabel("Coefficient Estimate")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


In [None]:
# Random-effects covariance (not used further in this cell) and variance components of the fit.
var_comps = result.cov_re
vc = result.vcomp

# Print variance estimates by group
print("Random effect variances:")
print(vc)

In [None]:
import seaborn as sns

fitted = result.fittedvalues
resid = result.resid

# Residuals against fitted values, with a dashed reference line at zero.
fig_resid, ax_resid = plt.subplots(figsize=(6, 4))
sns.scatterplot(x=fitted, y=resid, ax=ax_resid)
ax_resid.axhline(0, linestyle="--", color="gray")
ax_resid.set_xlabel("Fitted values")
ax_resid.set_ylabel("Residuals")
ax_resid.set_title("Residuals vs Fitted")
fig_resid.tight_layout()
plt.show()


In [None]:
import scipy.stats as stats
import statsmodels.api as sm

# QQ plot of the mixed-model residuals (`resid` comes from the residuals cell above)
# with a standardized reference line.
sm.qqplot(resid, line='s')
plt.title("QQ Plot of Residuals")
plt.tight_layout()
plt.show()
