In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import kagglehub

# Resolve the dataset directory through kagglehub instead of a machine-specific
# absolute path, mirroring how the taxonomy dataset is fetched further down.
# NOTE(review): kagglehub is expected to reuse its local cache when the dataset
# was downloaded before — confirm for offline runs.
path_species = Path(kagglehub.dataset_download("mexwell/bird-genetic-diversity"))
print(path_species)

# Load the first entry found in the dataset directory.
# NOTE(review): assumes the directory holds a single CSV file — glob order is
# otherwise arbitrary; TODO confirm.
birds_df = pd.read_csv(next(path_species.glob("*")))
# sns.pairplot(birds_df)

In [None]:
# sns.scatterplot(birds_df, x="Body mass", y="Breeding range size", hue="Allelic richness")

In [None]:
import plotly.express as px

# Interactive 3-D scatter of the three raw measurements; all three scene axes
# are switched to a logarithmic scale before the figure is shown.
fig = px.scatter_3d(
    birds_df,
    x="Body mass",
    y="Breeding range size",
    z="Allelic richness",
    color="Allelic richness",  # You can use another column if you prefer
    opacity=0.7,
).update_layout(
    scene={
        "xaxis_type": "log",
        "yaxis_type": "log",
        "zaxis_type": "log",
    }
)

fig.show()

In [None]:
# Replace the three raw measurements with their natural logarithms and keep only
# the species name plus the log columns.
# Note: `birds_df` is rebound here without the raw columns, so this cell cannot
# be re-run without first re-running the cell that loads the CSV.
birds_df = birds_df.assign(
    bm_log=np.log(birds_df["Body mass"]),
    range_log=np.log(birds_df["Breeding range size"]),
    richness_log=np.log(birds_df["Allelic richness"]),
)[["Species", "bm_log", "range_log", "richness_log"]]
birds_df.head()

In [None]:
# Fetch the taxonomy table and attach Order / Family / Genus to every species
# that appears in both datasets (inner join on the species name).
path_taxa = Path(kagglehub.dataset_download("willianoliveiragibin/animal-analyzing"))
df_taxa = pd.read_csv(next(path_taxa.glob("*")))[
    # "Kingdom", "Subphylum", "Class" are skipped: all birds belong to the same class - Aves
    ["Order", "Family", "Genus", "Species"]
]
birds_df_merged = birds_df.merge(df_taxa, on="Species", how="inner")
birds_df_merged.head()

In [None]:
import numpy as np
from pycirclize import Circos
from matplotlib.patches import Patch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import to_hex

# --- Preprocessing ---
# Keep only orders represented by at least `min_species_per_order` species.
# (The previous comment said ">=4" while the code filtered at 9; the code's
# threshold is kept.)
min_species_per_order = 9
species_counts = birds_df_merged.groupby("Order")["Species"].count()
valid_orders = species_counts[species_counts >= min_species_per_order].index

# Filter to the retained orders and sort species by Order, Family, Genus and
# Species. The explicit copy makes `merged_df` independent of `birds_df_merged`,
# so the in-place shifts below never touch the merged source frame.
merged_df = (
    birds_df_merged.loc[birds_df_merged["Order"].isin(valid_orders)]
    .sort_values(["Order", "Family", "Genus", "Species"])
    .copy()
)

# Shift every log column so that its minimum over the retained species is 0.
for col in ("bm_log", "range_log", "richness_log"):
    merged_df[col] -= merged_df[col].min()

# # Shared value scales
# bm_min, bm_max = merged_df['bm_log'].min(), merged_df['bm_log'].max()
# range_min, range_max = merged_df['range_log'].min(), merged_df['range_log'].max()
# richness_min, richness_max = merged_df['richness_log'].min(), merged_df['richness_log'].max()

# # Sector = Order
# sector_sizes = merged_df.groupby('Order').size().to_dict()
# circos = Circos(sector_sizes, space=2)

# # Use local color maps per sector for better contrast
# from collections import defaultdict
# from matplotlib.cm import get_cmap

# # Store maps: (order -> {family: color}), etc.
# order_family_colors = defaultdict(dict)
# order_genus_colors = defaultdict(dict)

# family_cmaps = [cm.get_cmap("tab10"), cm.get_cmap("Set1"), cm.get_cmap("Dark2")]
# genus_cmaps = [cm.get_cmap("tab20"), cm.get_cmap("Paired"), cm.get_cmap("tab20c")]

# # Cycle through a few high-contrast maps
# family_map_count = len(family_cmaps)
# genus_map_count = len(genus_cmaps)

# for sector_idx, order in enumerate(merged_df["Order"].unique()):
#     sub_df = merged_df[merged_df["Order"] == order]
    
#     families = sub_df["Family"].unique()
#     genera = sub_df["Genus"].unique()
    
#     fam_cmap = family_cmaps[sector_idx % family_map_count]
#     gen_cmap = genus_cmaps[sector_idx % genus_map_count]
    
#     for i, fam in enumerate(families):
#         order_family_colors[order][fam] = to_hex(fam_cmap(i % fam_cmap.N))
    
#     for i, gen in enumerate(genera):
#         order_genus_colors[order][gen] = to_hex(gen_cmap(i % gen_cmap.N))


# # --- Build Circos Plot ---
# for sector in circos.sectors:
#     order = sector.name
#     sub_df = merged_df[merged_df['Order'] == order].reset_index(drop=True)
#     x = np.arange(len(sub_df)) + 0.5

#     # Family track
#     fam_colors = sub_df["Family"].map(order_family_colors[order]).values
#     fam_track = sector.add_track((98, 100))
#     fam_track.axis()
#     fam_track.bar(x, np.ones_like(x), color=fam_colors, width=1)

#     # Genus track
#     gen_colors = sub_df["Genus"].map(order_genus_colors[order]).values
#     gen_track = sector.add_track((95, 97))
#     gen_track.axis()
#     gen_track.bar(x, np.ones_like(x), color=gen_colors, width=1)

#     bm_track = sector.add_track((30, 49))
#     bm_track.axis()
#     bm_track.bar(x, sub_df['bm_log'].values, color='blue', width=0.6, vmin=bm_min, vmax=bm_max)

#     range_track = sector.add_track((50, 69))
#     range_track.axis()
#     range_track.bar(x, sub_df['range_log'].values, color='green', width=0.6, vmin=range_min, vmax=range_max)

#     richness_track = sector.add_track((70, 94))
#     richness_track.axis()
#     richness_track.bar(x, sub_df['richness_log'].values, color='red', width=0.6, vmin=richness_min, vmax=richness_max)


#     # Add order label
#     sector.text(order, 
#                 r=102, 
#                 size=8)

# # Plot the figure
# fig = circos.plotfig(dpi=300)

# # Add variable legend
# legend_patches = [
#     Patch(color='blue'),
#     Patch(color='green'),
#     Patch(color='red')
# ]
# fig.legend(legend_patches, 
#            ['Body mass (log)', 'Breeding range (log)', 'Genetic richness (log)'],
#            loc='upper left', 
#            fontsize=12)
# plt.tight_layout()

# fig.savefig("circos_species_by_order.pdf", bbox_inches='tight')


In [None]:
from collections import defaultdict

from tqdm import tqdm

from tabrel.train import train_relnet

orders = merged_df["Order"].astype(str).values
orders_int = pd.factorize(orders)[0]
families = merged_df["Family"].astype(str).values
families_int = pd.factorize(families)[0]
n = len(merged_df)

# Build the (n, n) pairwise relationship matrix r_birds:
#   1.0 - same family, 0.5 - same order but different family, 0.0 - otherwise.
# The factorized integer codes are equal exactly when the labels are equal, so
# broadcasting them replaces the former O(n^2) Python double loop.
same_family = families_int[:, None] == families_int[None, :]
same_order = orders_int[:, None] == orders_int[None, :]
r_birds = np.where(same_family, 1.0, np.where(same_order, 0.5, 0.0))

# r_as_feats = np.array([orders_int, families_int]).T

# Features: log body mass and log breeding range; target: log allelic richness.
X = merged_df[["bm_log", "range_log"]].to_numpy()
y = merged_df["richness_log"].to_numpy()
# A third of the samples each for the test and query subsets.
n_test = n_query = n // 3

In [None]:
# Inspect the pairwise relationship matrix built in the previous cell.
r_birds

In [None]:
from tabrel.utils.plot import calc_diffs

# Run tabrel's calc_diffs on the targets and the relationship matrix with plotting on.
# NOTE(review): presumably splits pairwise target differences at r = 0.6 and saves
# histograms to birds_hists.png — confirm against tabrel.utils.plot.calc_diffs.
calc_diffs(y, r_birds, threshold=.6, plot=True, out_fname="birds_hists.png")

In [None]:
from datetime import datetime
from typing import Final

import optuna

from tabrel.optuna import RelTrainData, build_objective_relnet

# Trial k is seeded with k + seed_offset, i.e. 1, 2, 3, ... for a fresh study —
# the same sequence the former global counter produced.
seed_offset: Final[int] = 1
n_epochs_optuna: Final[int] = 1000
sqlite_path: Final[str] = "sqlite:///db.sqlite3"


def objective(trial: optuna.Trial) -> float:
    """Optuna objective: evaluate one RelNet configuration on a random split.

    The seed is derived from ``trial.number`` rather than from a mutable global
    counter, so it does not depend on how often the objective has been called
    or on other cells rebinding a global ``seed``.

    Args:
        trial: Optuna trial; passed on to ``build_objective_relnet`` and used to
            derive the split / training seed.

    Returns:
        The value returned by ``build_objective_relnet`` (the study below is
        created with ``direction="maximize"``).
    """
    trial_seed = trial.number + seed_offset

    # Seed-specific permutation of all samples: the first n_test ids go to
    # validation, the next n_query ids are queries, the rest is background.
    np.random.seed(trial_seed)
    indices = np.random.permutation(n)

    data = RelTrainData(
        r=r_birds,
        x=X,
        y=y,
        val_ids=indices[:n_test],
        query_ids=indices[n_test:n_test + n_query],
        back_ids=indices[n_test + n_query:],
    )
    return build_objective_relnet(
        trial,
        data,
        n_epochs_optuna,
        seed=trial_seed,
    )

# Each run gets its own timestamped study inside the shared SQLite storage.
study_label = f"birdsRelnet_{datetime.now()}"
study = optuna.create_study(
    storage=sqlite_path,
    study_name=study_label,
    direction="maximize",
)

# Run 100 trials of the objective defined above.
study.optimize(objective, n_trials=100)

In [None]:
# Hyper-parameters passed to train_relnet, gathered in one place for auditing.
# NOTE(review): presumably taken from the Optuna study above — confirm.
relnet_params = dict(
    lr=0.004,
    n_epochs=1000,
    progress_bar=False,
    print_loss=False,
    n_layers=1,
    num_heads=17,
    embed_dim=17 * 18,
    periodic_embed_dim=18,
    dropout=0.444,
)

# An earlier baseline call, `run_training(...)` with lr=0.05, n_epochs=200,
# rel_as_feats=r_as_feats and MlpConfig(in_dim=X.shape[1], hidden_dim=10,
# out_dim=20, dropout=.19), used to sit here as commented-out code. It was
# condensed to this note because `run_training`, `MlpConfig` and `r_as_feats`
# are not defined in the visible cells.

# Repeat training over 15 seeded random splits and collect the result of each
# run per model label.
birds_metrics = defaultdict(list)
# The loop variable is deliberately not called `seed`: `objective` declares a
# global of that name, and rebinding it here would silently change its state.
for split_seed in tqdm(range(15), "seeds"):
    np.random.seed(split_seed)
    indices = np.random.permutation(n)
    test_indices = indices[:n_test]
    query_indices = indices[n_test:n_test + n_query]
    back_indices = indices[n_test + n_query:]

    # Seed torch right before training so every run is reproducible.
    torch.manual_seed(split_seed)
    res_birds = {
        "relnet": train_relnet(
            x=X,
            y=y,
            r=r_birds,
            backgnd_indices=back_indices,
            query_indices=query_indices,
            val_indices=test_indices,
            **relnet_params,
        )
    }

    for k, v in res_birds.items():
        birds_metrics[k].append(v)

In [None]:
from tabrel.utils.misc import to_df


# Render the per-seed metrics collected in birds_metrics as a table.
# NOTE(review): presumably aggregates into per-label means / std, as in the tables
# pasted below — confirm against tabrel.utils.misc.to_df.
to_df(birds_metrics, decimal_places=3)

## 200 epochs

Each `means` / `std` entry below is a pair: [MSE, $R^2$].

<div>
<style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>label</th>
      <th>means</th>
      <th>std</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>rel=True;trainable_w=True;mlp=False</td>
      <td>[0.236, 0.176]</td>
      <td>[0.032, 0.09]</td>
    </tr>
    <tr>
      <th>1</th>
      <td>rel=True;trainable_w=False;mlp=False</td>
      <td>[0.235, 0.178]</td>
      <td>[0.033, 0.096]</td>
    </tr>
    <tr>
      <th>2</th>
      <td>rel=False;trainable_w=True;mlp=False</td>
      <td>[0.241, 0.16]</td>
      <td>[0.035, 0.097]</td>
    </tr>
    <tr>
      <th>3</th>
      <td>rel=False;trainable_w=False;mlp=False</td>
      <td>[0.24, 0.162]</td>
      <td>[0.035, 0.088]</td>
    </tr>
    <tr>
      <th>4</th>
      <td>rel=False;trainable_w=False;mlp=True</td>
      <td>[0.242, 0.157]</td>
      <td>[0.039, 0.099]</td>
    </tr>
    <tr>
      <th>5</th>
      <td>rel=True;trainable_w=False;mlp=True</td>
      <td>[0.237, 0.172]</td>
      <td>[0.032, 0.097]</td>
    </tr>
    <tr>
      <th>6</th>
      <td>lgb</td>
      <td>[0.257, 0.102]</td>
      <td>[0.032, 0.083]</td>
    </tr>
    <tr>
      <th>7</th>
      <td>rel-as-feats</td>
      <td>[0.226, 0.214]</td>
      <td>[0.036, 0.083]</td>
    </tr>
    <tr>
      <th>8</th>
      <td>lgb-rel</td>
      <td>[0.257, 0.104]</td>
      <td>[0.032, 0.08]</td>
    </tr>
  </tbody>
</table>
</div>

## 5000 epochs

<div>
<style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>label</th>
      <th>means</th>
      <th>std</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>rel=True;trainable_w=True;mlp=False</td>
      <td>[0.2427419444467154, 0.1524371062224336]</td>
      <td>[0.03563937544733425, 0.101662067855152]</td>
    </tr>
    <tr>
      <th>1</th>
      <td>rel=True;trainable_w=False;mlp=False</td>
      <td>[0.24646459393528103, 0.13737796495820342]</td>
      <td>[0.033744847264236495, 0.11001813251445133]</td>
    </tr>
    <tr>
      <th>2</th>
      <td>rel=False;trainable_w=True;mlp=False</td>
      <td>[0.25349090172246797, 0.1168885225667323]</td>
      <td>[0.03916675879074428, 0.09949918383262824]</td>
    </tr>
    <tr>
      <th>3</th>
      <td>rel=False;trainable_w=False;mlp=False</td>
      <td>[0.24547522287283957, 0.1447009572164006]</td>
      <td>[0.04071987844372341, 0.10971509173397336]</td>
    </tr>
    <tr>
      <th>4</th>
      <td>rel=False;trainable_w=False;mlp=True</td>
      <td>[0.250158587339188, 0.12678941098391033]</td>
      <td>[0.0381792052692715, 0.10902320463022831]</td>
    </tr>
    <tr>
      <th>5</th>
      <td>rel=True;trainable_w=False;mlp=True</td>
      <td>[0.2476819328807885, 0.1366361133491969]</td>
      <td>[0.04024182669239522, 0.10811121681143643]</td>
    </tr>
    <tr>
      <th>6</th>
      <td>lgb</td>
      <td>[0.25725628177916493, 0.10161220322981238]</td>
      <td>[0.03158159362994203, 0.08296224101776432]</td>
    </tr>
    <tr>
      <th>7</th>
      <td>rel-as-feats</td>
      <td>[0.22635312455947051, 0.21323706406913023]</td>
      <td>[0.035756882051606646, 0.08270302382834073]</td>
    </tr>
    <tr>
      <th>8</th>
      <td>lgb-rel</td>
      <td>[0.25652954798214084, 0.10432907175547337]</td>
      <td>[0.031505477393880654, 0.08041760906556603]</td>
    </tr>
  </tbody>
</table>
</div>

## 1000 epochs, RelNet

<div>
<style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>label</th>
      <th>means</th>
      <th>std</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>relnet</td>
      <td>[0.28, 0.034]</td>
      <td>[0.032, 0.137]</td>
    </tr>
  </tbody>
</table>
</div>

In [None]:
import statsmodels.api as sm
import statsmodels.formula.api as smf

# Make sure the taxonomic levels are treated as categorical
for level in ("Order", "Family", "Genus"):
    merged_df[level] = merged_df[level].astype("category")


def _fit_anova(level: str):
    """Fit the one-way OLS model ``richness_log ~ C(<level>)`` on ``merged_df``.

    Args:
        level: Name of the categorical column used as the single factor.

    Returns:
        Tuple ``(fitted OLS results, type-II ANOVA table)``.
    """
    fitted_model = smf.ols(f"richness_log ~ C({level})", data=merged_df).fit()
    return fitted_model, sm.stats.anova_lm(fitted_model, typ=2)


# Order-level model
model_order, anova_order = _fit_anova("Order")
print(anova_order)

# Family-level model
model_family, anova_family = _fit_anova("Family")
print(anova_family)

# Genus-level model
model_genus, anova_genus = _fit_anova("Genus")
print(anova_genus)

In [None]:
# Mixed-effects model: log richness explained by log body mass and log range,
# grouped by Order, with Family and Genus entering as variance components.
variance_components = {
    "Family": "0 + C(Family)",
    "Genus": "0 + C(Genus)",
}
model = smf.mixedlm(
    "richness_log ~ bm_log + range_log",  # fixed effects
    data=merged_df,
    groups=merged_df["Order"],            # main grouping
    re_formula="1",                       # random intercepts
    vc_formula=variance_components,       # variance components
)

result = model.fit()
print(result.summary())

In [None]:
import matplotlib.pyplot as plt

coefs = result.fe_params
conf_int = result.conf_int()

# conf_int() also contains rows for the random-effect variance parameters, so
# pick the fixed-effect rows by name. The former `[:3]` slice only worked while
# the model had exactly three fixed effects.
upper_bounds = conf_int.loc[coefs.index, 1]
# Error-bar half-width: distance from each estimate to its upper CI bound.
half_widths = (upper_bounds - coefs).to_numpy()

plt.figure(figsize=(6, 4))
plt.errorbar(
    coefs.index,
    coefs.values,
    yerr=half_widths,
    fmt='o',
    capsize=5,
    color='black',
)
plt.axhline(0, color='gray', linestyle='--')  # zero-effect reference line
plt.title("Fixed Effects Estimates")
plt.ylabel("Coefficient Estimate")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


In [None]:
# Pull the random-effects estimates from the fitted mixed model.
# `cov_re` is the random-effects covariance reported by statsmodels (only the
# Order intercept here, since re_formula="1"); it is not used further in the
# visible cells.
var_comps = result.cov_re
# `vcomp` holds the estimated variance components.
# NOTE(review): presumably ordered as the vc_formula keys (Family, Genus) — confirm.
vc = result.vcomp

# Print variance estimates by group
print("Random effect variances:")
print(vc)

In [None]:
import seaborn as sns

# Residuals-vs-fitted diagnostic for the mixed model.
fitted = result.fittedvalues
resid = result.resid

plt.figure(figsize=(6, 4))
resid_ax = sns.scatterplot(x=fitted, y=resid)
resid_ax.axhline(0, linestyle="--", color="gray")  # zero-residual reference
resid_ax.set(
    xlabel="Fitted values",
    ylabel="Residuals",
    title="Residuals vs Fitted",
)
plt.tight_layout()
plt.show()


In [None]:
import scipy.stats as stats
import statsmodels.api as sm

# Normal Q-Q plot of the mixed-model residuals (`resid` comes from the previous
# cell); line='s' asks statsmodels for its standardized reference line.
sm.qqplot(resid, line='s')
plt.title("QQ Plot of Residuals")
plt.tight_layout()
plt.show()
