In [None]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
from typing import Final

import kagglehub
import pandas as pd
import plotly.express as px
import seaborn as sns
from pycountry import countries

# Download the Kaggle dataset and load its CSV file into a DataFrame.
path = kagglehub.dataset_download("amirhosseinmirzaie/countries-life-expectancy")

# Sort the matches so the same file is chosen on every run (glob order is
# filesystem-dependent), and fail with a clear message when there is none.
csv_files = sorted(Path(path).glob("*.csv"))
if not csv_files:
    raise FileNotFoundError(f"no CSV file found in downloaded dataset directory {path!r}")
df = pd.read_csv(csv_files[0])


class NoneCountry:
    """Fallback for the pycountry lookup below: a stand-in whose ``alpha_3`` is None."""

    alpha_3 = None


# Look each country name up in pycountry and store its alpha_3 code;
# names that pycountry does not know get None via the NoneCountry fallback.
df["ISO_alpha"] = df["Country"].apply(
    lambda x: countries.get(name=x, default=NoneCountry).alpha_3
)
df

In [None]:
# Work on the rows for 2015 only. The explicit .copy() makes df_2015 an
# independent frame, so later modifications neither touch df nor raise
# pandas' SettingWithCopyWarning.
df_2015 = df[df["Year"] == 2015].copy()
features: Final[list[str]] = ["Hepatitis B", "Polio", "Diphtheria", "HIV/AIDS", "BMI"]
response: Final[str] = "Life expectancy"

# Pairwise scatter plots of the selected features, coloured by the response.
sns.pairplot(df_2015[features + [response]], hue=response)

# World map of the response; countries are located through their ISO code column.
px.choropleth(
    df_2015,
    locations="ISO_alpha",
    color=response,
)

In [None]:
# Index the 2015 frame by ISO code. The guard keeps the cell re-runnable:
# once "ISO_alpha" is the index it is no longer a column, and a second
# set_index call would raise KeyError.
if df_2015.index.name != "ISO_alpha":
    df_2015 = df_2015.set_index("ISO_alpha")

In [None]:
import geopandas as gpd
from tabrel.utils.geo import build_border_map, build_r_countries

# Location of the country shapefile. Defaults to the original author's local
# checkout; set the TABREL_NE_SHAPEFILE environment variable to use another copy.
ne_shapefile = os.environ.get(
    "TABREL_NE_SHAPEFILE",
    "/Users/vzuev/Documents/git/gh_zuevval/tabrel/data/ne_50m_admin_0_countries/ne_50m_admin_0_countries.shp",
)
world = gpd.read_file(ne_shapefile)

# Drop rows whose ISO_A3_EH is '-99'
# (NOTE(review): presumably the "no ISO code" marker in this dataset — confirm).
world = world[world['ISO_A3_EH'] != '-99']
border_map = build_border_map(world)

In [None]:
# NOTE(review): assumes iso_list is returned in the same order as the rows of R
# (and of df_2015, from which X and y are taken) — confirm against build_r_countries.
R, iso_list = build_r_countries(df_2015, border_map)
y = df_2015[response].to_numpy()

# Set of valid ISO codes, used for set arithmetic when splitting countries.
# Non-string entries (None / NaN for countries without a code) are excluded here.
# iso_list itself is deliberately left untouched: later cells turn ISO codes into
# row positions with enumerate(iso_list), so rebuilding the list from a set would
# scramble its order and break the correspondence with R, X and y.
all_isos = {iso for iso in iso_list if isinstance(iso, str)}

In [None]:
from itertools import combinations

import numpy as np
import torch
from tabrel.benchmark.nw_regr import run_training
from tabrel.train import train_relnet
from tabrel.utils.geo import get_connected_country_set
from tqdm import tqdm


# Size bounds for the connected country regions used as query / validation sets.
max_query_size = max_val_size = 40
min_query_size = min_val_size = 10

# Feature matrix, rows in df_2015 order.
# Alternative single-feature sets left by the author:
#   ["Alcohol"], ["thinness  1-19 years"], ["BMI"]
# A random split (np.random.seed(42) + np.random.permutation(iso_list)) was the
# earlier alternative to the connected-region split used below.
X = df_2015[["Polio", "HIV/AIDS", "Diphtheria", "under-five deaths"]].to_numpy()


def _positions_in(iso_subset):
    """Return the positions in ``iso_list`` whose ISO code is in ``iso_subset``.

    The result is always an integer array, so it stays usable as an index
    even when no code matches (a bare ``np.array([])`` would be float64).
    """
    return np.array(
        [pos for pos, iso in enumerate(iso_list) if iso in iso_subset],
        dtype=int,
    )


# Maps (query seed ISO, validation seed ISO) -> results of all models for that split.
results = {}

# Hand-picked seed pairs. Iterating over every combinations(all_isos, 2) was
# rejected by the author because the two regions may then overlap.
seed_pairs = (
    ("MNE", "EGY"),
    ("TCD", "PER"),
    ("AFG", "ESP"),
)

for query_iso, val_iso in seed_pairs:
    # Grow a connected region of neighbouring countries around each seed.
    query_set = get_connected_country_set(query_iso, border_map, max_size=max_query_size)
    val_set = get_connected_country_set(val_iso, border_map, max_size=max_val_size)

    # Skip splits whose regions are too small to be meaningful.
    if len(query_set) < min_query_size or len(val_set) < min_val_size:
        continue

    # NOTE(review): nothing here checks that query_set and val_set are disjoint;
    # if the two regions can touch, validation countries would also be queried.
    # Confirm the seed pairs above cannot overlap at the configured max sizes.

    # Every remaining country serves as background.
    backgnd_set = all_isos - query_set - val_set

    backgnd_indices = _positions_in(backgnd_set)
    query_indices = _positions_in(query_set)
    val_indices = _positions_in(val_set)

    # Baseline models from the benchmark module.
    res = run_training(
        x=X,
        y=y,
        r=R,
        backgnd_indices=backgnd_indices,
        query_indices=query_indices,
        val_indices=val_indices,
        lr=0.006,
        n_epochs=100,
    )

    # Fix the torch RNG so the TabRel run is repeatable.
    torch.manual_seed(42)
    results_relnet = train_relnet(
        x=X,
        y=y,
        r=R,
        backgnd_indices=backgnd_indices,
        query_indices=query_indices,
        val_indices=val_indices,
        lr=.01,
        n_epochs=1500,
        periodic_embed_dim=None,
        progress_bar=False,
        print_loss=False,
        n_layers=2,
        num_heads=2,
        embed_dim=8,
    )

    # Store the TabRel result next to the baselines for this split.
    res["TabRel"] = results_relnet
    results[(query_iso, val_iso)] = res


In [None]:
# For each model, list the value stored at position 1 of its result entry
# (printed under the label "r2") across all evaluated splits.
for model_name in ("rel=True", "rel=False", "lgb", "TabRel"):
    r2_per_split = [split_res[model_name][1] for split_res in results.values()]
    print(model_name, "r2:", r2_per_split)

In [None]:
# TODO: try a grid of hyperparameters