In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from typing import Final
import kagglehub
from pycountry import countries

import pandas as pd
import plotly.express as px
import seaborn as sns

# Fetch the Kaggle dataset; the returned value is used below as the directory
# in which to look for the CSV file.
path = kagglehub.dataset_download("amirhosseinmirzaie/countries-life-expectancy")

# Path.glob yields entries in no guaranteed order, so sort to make the selected
# file deterministic, and fail with a readable message instead of a bare
# IndexError when the directory contains no CSV at all.
csv_files = sorted(Path(path).glob("*.csv"))
if not csv_files:
    raise FileNotFoundError(f"No CSV file found in downloaded dataset directory: {path}")
df = pd.read_csv(csv_files[0])

class NoneCountry:
    """Stand-in passed as the lookup default so that `.alpha_3` is always readable."""

    # Code reported for a country name that the lookup could not resolve.
    alpha_3 = None


def _iso_alpha3(country_name):
    """Return the alpha-3 code found for `country_name` via the `name` field, else None."""
    match = countries.get(name=country_name, default=NoneCountry)
    return match.alpha_3


# NOTE(review): the lookup uses only the `name` field, so dataset spellings that
# differ from pycountry's names end up with a null code — confirm this is intended.
df["ISO_alpha"] = df["Country"].apply(_iso_alpha3)
df

In [None]:
# Keep only the rows whose Year is 2015.
df_2015 = df.loc[df["Year"] == 2015]

# Candidate explanatory columns and the target column used throughout the notebook.
features: Final[list[str]] = ["Hepatitis B", "Polio", "Diphtheria", "HIV/AIDS"]
response: Final[str] = "Life expectancy"

# Pairwise plots of the selected columns, with the response passed as the hue variable.
pairplot_columns = [*features, response]
sns.pairplot(df_2015[pairplot_columns], hue=response)

# Map coloured by the response; locations are taken from the ISO_alpha column.
px.choropleth(df_2015, locations="ISO_alpha", color=response)

In [None]:
import os

import geopandas as gpd
from tabrel.utils.geo import build_border_map, build_r_countries

# Location of the countries shapefile (the file name suggests the Natural Earth
# 1:50m admin-0 dataset — confirm). The default is the original author's local
# checkout; set the NE_COUNTRIES_SHP environment variable to point at your own
# copy so the notebook can run on other machines without editing this cell.
NE_COUNTRIES_SHP = os.environ.get(
    "NE_COUNTRIES_SHP",
    "/Users/vzuev/Documents/git/gh_zuevval/tabrel/data/ne_50m_admin_0_countries/ne_50m_admin_0_countries.shp",
)
world = gpd.read_file(NE_COUNTRIES_SHP)

# Drop rows whose ISO_A3_EH value is '-99' (presumably the "no code assigned"
# marker in this dataset — confirm), then build the border map from the rest.
world = world[world['ISO_A3_EH'] != '-99']
border_map = build_border_map(world)

In [None]:
# Index the 2015 rows by ISO_alpha. The guard keeps this cell idempotent: an
# unconditional in-place set_index raises KeyError on a second run because
# "ISO_alpha" has already been moved from the columns into the index. Rebinding
# the name (instead of inplace=True) also avoids mutating a filtered slice of `df`.
if df_2015.index.name != "ISO_alpha":
    df_2015 = df_2015.set_index("ISO_alpha")

# NOTE(review): rows whose country name was not resolved carry a null ISO_alpha
# and therefore a null index label — confirm build_r_countries tolerates that.
R, iso_list = build_r_countries(df_2015, border_map)

# Response values taken from the ISO-indexed frame, and the set of ISO codes
# returned alongside R.
y = df_2015[response].to_numpy()
all_isos = set(iso_list)

In [None]:
import numpy as np
from tabrel.train import train_relnet
from tabrel.utils.geo import get_connected_country_set

# Countries used to seed the query and validation sets, and the size limit
# passed to the set builder for each of them.
query_iso = "MNE"
val_iso = "EGY"
max_query_size = 35
max_val_size = 35

query_set = get_connected_country_set(query_iso, border_map, max_size=max_query_size)
val_set = get_connected_country_set(val_iso, border_map, max_size=max_val_size)

# Every ISO code that is in neither the query set nor the validation set.
backgnd_set = (all_isos - query_set) - val_set


def _positions_of(members):
    """Return the positions in `iso_list` whose ISO code belongs to `members`."""
    return [pos for pos, code in enumerate(iso_list) if code in members]


backgnd_indices = _positions_of(backgnd_set)
query_indices = _positions_of(query_set)
val_indices = _positions_of(val_set)

# Feature matrix built from three of the dataset's columns.
X = df_2015.loc[:, ["Polio", "HIV/AIDS", "Diphtheria"]].to_numpy()

# Settings forwarded unchanged to train_relnet.
relnet_settings = dict(
    lr=0.01,
    n_epochs=1000,
    periodic_embed_dim=None,
    progress_bar=False,
    print_loss=True,
    n_layers=2,
    num_heads=2,
    embed_dim=8,
)
results_relnet = train_relnet(
    x=X,
    y=y,
    r=R,
    backgnd_indices=np.array(backgnd_indices),
    query_indices=np.array(query_indices),
    val_indices=np.array(val_indices),
    **relnet_settings,
)

In [None]:
# Display the value returned by train_relnet in the previous cell.
results_relnet

# TODO: try randomly chosen query and validation sets
# TODO: compare with LightGBM; select important features using LightGBM