In [None]:
%load_ext autoreload
%autoreload 2

import os
from itertools import combinations
from pathlib import Path
from typing import Final

import kagglehub
from pycountry import countries

import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import torch
from tqdm import tqdm
import geopandas as gpd

from tabrel.benchmark.nw_regr import run_training
from tabrel.train import train_relnet
from tabrel.utils.geo import get_connected_country_set, build_border_map, build_r_countries



# Download (or reuse the locally cached copy of) the Kaggle life-expectancy dataset.
path = kagglehub.dataset_download("amirhosseinmirzaie/countries-life-expectancy")
# Pick the CSV deterministically: glob order is filesystem-dependent, and an empty match
# should fail with a readable message rather than a bare IndexError.
csv_files = sorted(Path(path).glob("*.csv"))
if not csv_files:
    raise FileNotFoundError(f"No CSV file found in {path}")
df = pd.read_csv(csv_files[0])
features: Final[list[str]] = ["Hepatitis B", "Polio", "Diphtheria", "HIV/AIDS", "BMI"]
response: Final[str] = "Life expectancy"


class NoneCountry:
    """Sentinel passed as `default` to `countries.get`; its `alpha_3` is None for unmatched names."""

    alpha_3 = None


def iso_alpha3(country_name):
    """Return the ISO 3166-1 alpha-3 code for `country_name`, or None if it cannot be resolved.

    First tries an exact match on pycountry's `name` attribute (the original behaviour), then
    falls back to `countries.lookup`, which also accepts official and common names.
    """
    match = countries.get(name=country_name, default=NoneCountry)
    if match.alpha_3 is not None:
        return match.alpha_3
    if not isinstance(country_name, str):
        # e.g. NaN in the Country column: nothing to look up.
        return None
    try:
        return countries.lookup(country_name.strip()).alpha_3
    except LookupError:
        return None


df["ISO_alpha"] = df["Country"].map(iso_alpha3)
df.head()

In [None]:
# Distinct survey years, in order of first appearance.
pd.unique(df["Year"])

In [None]:
# Cross-sectional view of a single year.
df_2015 = df[df["Year"] == 2015]

# NOTE(review): `hue=response` colours points by a continuous variable, so the legend and the
# diagonal densities are likely split per distinct value; consider binning (e.g. pd.qcut) if that is noisy.
sns.pairplot(df_2015[features + [response]], hue=response)

# Use the `response` constant instead of repeating the column name, so the map follows any change to it.
fig_choropleth = px.choropleth(
    df_2015,
    locations="ISO_alpha",
    color=response,
)
# Static image export requires plotly's optional `kaleido` dependency.
fig_choropleth.write_image("life_expectancy_choropleth.png", width=1000, height=600, scale=3)

In [None]:
# Index the 2015 frame by ISO code. Reassign instead of `inplace=True` on a filtered slice
# (a chained-assignment pitfall); the guard makes re-running this cell a no-op, not a KeyError.
if "ISO_alpha" in df_2015.columns:
    df_2015 = df_2015.set_index("ISO_alpha")

In [None]:
# Natural Earth 1:50m admin-0 country polygons. Set TABREL_DATA_DIR to the repo's `data`
# directory to run on another machine; the fallback is the original author's checkout.
data_dir = Path(os.environ.get("TABREL_DATA_DIR", "/Users/vzuev/Documents/git/gh_zuevval/tabrel/data"))
world = gpd.read_file(data_dir / "ne_50m_admin_0_countries" / "ne_50m_admin_0_countries.shp")

# Drop rows whose ISO_A3_EH is '-99' (presumably Natural Earth's "no ISO code" placeholder)
# before building the country adjacency map.
world = world[world["ISO_A3_EH"] != "-99"]
border_map = build_border_map(world)

In [None]:
from collections import defaultdict
from itertools import product

from sklearn.metrics import mean_squared_error, r2_score
import torch.nn as nn

from tabrel.benchmark.nw_regr import NwModelConfig, RelNwRegr



# Experiment: repeatedly sample a year plus two disjoint, geographically connected groups of
# countries (query and validation); all remaining countries form the background set.
# RelNet is fitted on each split and its MSE / R^2 are collected in `metrics`.

# Region sizes, in countries.
max_query_size = max_val_size = 40
min_query_size = min_val_size = 10
n_target_runs: Final[int] = 15
# Upper bound on sampling attempts so an unsatisfiable configuration cannot loop forever.
max_attempts: Final[int] = 1_000
# The run_training baselines used to be unreachable (unconditional `continue`); set True to run them.
run_nw_baselines: bool = False
# Features fed to RelNet. Other variants tried: ["Alcohol"], ["thinness  1-19 years"], ["BMI"].
relnet_features: Final[list[str]] = ["Polio", "HIV/AIDS", "Diphtheria", "under-five deaths"]


def _run_nw_baselines(x_initial, y, R, metrics) -> None:
    """Fit the `run_training` baselines on a positional thirds split and append their metrics.

    Rows [0, n//3) are background, the next n//3 rows are query, the rest are validation.
    Metric keys are prefixed with the feature-set label (``xInit_<k>_mse``) so that several
    feature variants do not collide in `metrics`.
    """
    n_samples = x_initial.shape[0]
    n_back = n_query = n_samples // 3
    inds_back = np.arange(n_back)
    inds_q = np.arange(n_back, n_back + n_query)
    inds_val = np.arange(n_back + n_query, n_samples)

    # Add (np.concatenate([x_initial, R], axis=1), "xExtended") to also try relations-as-features.
    for x, x_label in ((x_initial, "xInit"),):
        # NOTE(review): an earlier version standardized x but still passed the raw x;
        # confirm whether run_training normalizes internally.
        try:
            res = run_training(
                x=x, y=y, r=R,
                backgnd_indices=inds_back,
                query_indices=inds_q,
                val_indices=inds_val,
                lr=.005,
                n_epochs=100,
                rel_as_feats=R,
            )
            for k, v in res.items():
                mse, r2 = v[:2]
                metrics[f"{x_label}_{k}_mse"].append(mse)
                metrics[f"{x_label}_{k}_r2"].append(r2)
        except Exception as e:
            # Best effort: one failed baseline should not abort the sweep, but say which one.
            print(f"NW baseline {x_label} failed: {e!r}")


n_runs = 0
n_attempts = 0
metrics = defaultdict(list)

seed: Final[int] = 42
np.random.seed(seed)
while n_runs < n_target_runs:
    n_attempts += 1
    if n_attempts > max_attempts:
        print(f"Giving up after {max_attempts} attempts with {n_runs} valid splits.")
        break

    year = np.random.choice(df["Year"])
    # Fresh indexed frame rather than `set_index(inplace=True)` on a filtered slice.
    df_year = df[df["Year"] == year].set_index("ISO_alpha")
    R, iso_list = build_r_countries(df_year, border_map)
    y = df_year[response].to_numpy()

    # Positions in iso_list are used below as row indices into x_initial / y / R, so its order
    # must be kept (assumes build_r_countries returns it in df_year row order — TODO confirm).
    # The previous `list(set(iso_list))` kept None, broke that alignment, and its order depended
    # on per-process string hashing, so the seed did not make runs reproducible.
    all_isos = {iso for iso in iso_list if iso is not None}
    candidate_isos = sorted(all_isos)  # sorted => choice is deterministic given the seed
    if not candidate_isos:
        continue
    query_iso = np.random.choice(candidate_isos)  # starting node for query set
    val_iso = np.random.choice(candidate_isos)  # starting node for validation set

    query_set = get_connected_country_set(query_iso, border_map, max_size=max_query_size)
    val_set = get_connected_country_set(val_iso, border_map, max_size=max_val_size)

    # NOTE(review): if these sets can contain ISO codes absent from df_year, the size checks
    # overcount; consider intersecting with all_isos first.
    if len(query_set) < min_query_size or len(val_set) < min_val_size:
        continue
    if query_set & val_set:
        continue

    n_runs += 1
    backgnd_set = all_isos - query_set - val_set

    # flatnonzero always yields an integer array (np.array([]) would be float and unusable as an index).
    backgnd_indices = np.flatnonzero([iso in backgnd_set for iso in iso_list])
    query_indices = np.flatnonzero([iso in query_set for iso in iso_list])
    val_indices = np.flatnonzero([iso in val_set for iso in iso_list])

    x_initial = df_year[relnet_features].to_numpy()

    torch.manual_seed(seed)
    try:
        results_relnet = train_relnet(
            x=x_initial,
            y=y,
            r=R,
            backgnd_indices=backgnd_indices,
            query_indices=query_indices,
            val_indices=val_indices,
            lr=.01,
            n_epochs=1500,
            periodic_embed_dim=None,
            progress_bar=True,
            print_loss=False,
            n_layers=2,
            num_heads=2,
            embed_dim=8,
        )
        relnet_mse, relnet_r2 = results_relnet[:2]
        metrics["relnet_mse"].append(relnet_mse)
        metrics["relnet_r2"].append(relnet_r2)
    except Exception as e:
        # Best effort: one failed split should not abort the sweep, but report which one failed.
        print(f"RelNet failed (year={year}, query={query_iso}, val={val_iso}): {e!r}")

    if run_nw_baselines:
        _run_nw_baselines(x_initial, y, R, metrics)


In [None]:
# Summarise each collected metric as mean and std; the printed "mean & std" lines
# paste directly into a LaTeX table, and the DataFrame is the rich-display version.
results_stats = []
for name, values in metrics.items():
    if not values:
        # Keys can exist with no values (e.g. read from the defaultdict); np.mean([]) would be nan.
        continue
    mean, std = np.mean(values), np.std(values)
    results_stats.append({"name": name, "mean": round(mean, 2), "std": round(std, 2)})
    print(f"{name}: {mean:.2f} & {std:.2f}")
pd.DataFrame(results_stats)

In [None]:
# Mean MSE for every metric recorded on the initial feature set ("xInit" prefix).
for metric_name, metric_values in metrics.items():
    if not (metric_name.startswith("xInit") and metric_name.endswith("mse")):
        continue
    print(metric_name, np.mean(metric_values))

In [None]:
# Use .get: indexing the defaultdict would silently insert an empty list into `metrics`
# when this key was never recorded, polluting the summary cell on re-run.
metrics.get("xInit_sigma0.5_rscale5_val_r2", [])

In [None]:
# TODO: try a hyperparameter grid (e.g. lr, n_epochs, n_layers, num_heads, embed_dim).