In [None]:
%load_ext autoreload
%autoreload 2

import os
from itertools import combinations
from pathlib import Path
from typing import Final

import geopandas as gpd
import kagglehub
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import torch
from pycountry import countries
from tqdm import tqdm

from tabrel.benchmark.nw_regr import run_training
from tabrel.train import train_relnet
from tabrel.utils.geo import get_connected_country_set, build_border_map, build_r_countries



# Download the Kaggle dataset and read the first CSV file found in the
# returned download directory.
path = kagglehub.dataset_download("amirhosseinmirzaie/countries-life-expectancy")
csv_files = list(Path(path).glob("*.csv"))
df = pd.read_csv(csv_files[0])

# Column names used as model inputs / target further down the notebook.
features: Final[list[str]] = [
    "Hepatitis B",
    "Polio",
    "Diphtheria",
    "HIV/AIDS",
    "BMI",
]
response: Final[str] = "Life expectancy"

class NoneCountry:
    """Fallback object for country names that the pycountry lookup cannot resolve.

    It exposes the same `alpha_3` attribute that is read from the lookup
    result, set to `None`, so unresolved names yield a missing ISO code
    instead of raising.
    """

    alpha_3: None = None
    
def _country_to_alpha3(country_name):
    """Return the `alpha_3` code for `country_name`, or None when the lookup falls back to `NoneCountry`."""
    match = countries.get(name=country_name, default=NoneCountry)
    return match.alpha_3


df["ISO_alpha"] = df["Country"].apply(_country_to_alpha3)
df

In [None]:
# Distinct values of the "Year" column, in order of first appearance.
df.loc[:, "Year"].unique()

In [None]:
# Keep only the rows recorded for the year 2015.
is_2015 = df["Year"] == 2015
df_2015 = df.loc[is_2015]

# Disabled exploratory plots, kept for reference:
# sns.pairplot(df_2015[features + [response]],hue=response)

# fig_choropleth = px.choropleth(
#     df_2015,
#     locations="ISO_alpha",
#     color="Life expectancy"
# )
# fig_choropleth.write_image("life_expectancy_choropleth.png", width=1000, height=600, scale=3)

In [None]:
# Build the ISO-indexed 2015 frame from `df` in one non-mutating expression.
# The previous in-place `set_index` consumed the "ISO_alpha" column, so running
# this cell a second time raised KeyError; rebuilding from `df` keeps the cell
# idempotent and avoids mutating a slice of `df`.
df_2015 = df.loc[df["Year"] == 2015].set_index("ISO_alpha")

In [None]:
# Location of the country-polygons shapefile. The default is the original
# machine-specific path; set the TABREL_WORLD_SHP environment variable to point
# at the file on any other machine instead of editing the notebook.
world_shp_path = Path(
    os.environ.get(
        "TABREL_WORLD_SHP",
        "/Users/user/Documents/git/gh_zuevval/tabrel/data/ne_50m_admin_0_countries/ne_50m_admin_0_countries.shp",
    )
)
if not world_shp_path.exists():
    # Fail early with an actionable message rather than a low-level driver error.
    raise FileNotFoundError(
        f"Shapefile not found at {world_shp_path}; set TABREL_WORLD_SHP to its location."
    )

world_raw = gpd.read_file(world_shp_path)

# Drop rows whose `ISO_A3_EH` is '-99' (presumably a missing-code placeholder — TODO confirm).
world = world_raw[world_raw["ISO_A3_EH"] != "-99"]
border_map = build_border_map(world)

In [None]:
from collections import defaultdict
from itertools import product

from sklearn.metrics import mean_squared_error, r2_score
import torch.nn as nn

from tabrel.benchmark.nw_regr import MlpConfig, NwModelConfig, RelNwRegr



# Size limits for the connected country sets grown around the sampled seeds.
max_query_size = max_val_size = 40
min_query_size = min_val_size = 10

# Number of accepted runs to collect before the loop stops.
n_target_runs: Final[int] = 15

n_runs = 0
# Maps "<model name>_mse" / "<model name>_r2" to the list of per-run values.
metrics = defaultdict(list)

# Columns used as model inputs; alternatives tried earlier are kept for reference.
feature_cols = ["Polio", "HIV/AIDS", "Diphtheria", "under-five deaths"]
# ["Alcohol"]
# ["thinness  1-19 years"],
# ["BMI"],

seed: Final[int] = 42
# Seed both NumPy (year / seed-country sampling) and torch (used by the
# commented-out `torch.manual_seed` call this replaces) so a fresh kernel
# reproduces the same sequence of runs.
np.random.seed(seed)
torch.manual_seed(seed)

while n_runs < n_target_runs:
    year = np.random.choice(df["Year"])
    # Non-mutating `set_index`: build a fresh frame instead of modifying a slice of `df`.
    df_year = df[df["Year"] == year].set_index("ISO_alpha")
    R, iso_list = build_r_countries(df_year, border_map)
    y = df_year[response].to_numpy()

    # Deduplicate in first-seen order and actually drop missing codes.
    # `list(set(iso_list))` kept None and produced a hash-dependent ordering of
    # the string codes, which made `np.random.choice` non-reproducible across
    # kernels despite the fixed seed.
    candidate_isos = [iso for iso in dict.fromkeys(iso_list) if pd.notna(iso)]
    all_isos = set(candidate_isos)
    query_iso = np.random.choice(candidate_isos)  # starting node for query set
    val_iso = np.random.choice(candidate_isos)  # starting node for validation set

    query_set = get_connected_country_set(query_iso, border_map, max_size=max_query_size)
    val_set = get_connected_country_set(val_iso, border_map, max_size=max_val_size)

    # Reject draws where either set is too small or the two sets overlap.
    if len(query_set) < min_query_size or len(val_set) < min_val_size:
        continue

    if query_set & val_set:
        continue

    n_runs += 1
    backgnd_set = all_isos - query_set - val_set

    # Positions are taken from `iso_list` exactly as returned by
    # `build_r_countries` (assumed to follow the row order of `R` — TODO confirm).
    # These arrays are only consumed by the disabled `train_relnet` call below.
    backgnd_indices = np.array([i for i, iso in enumerate(iso_list) if iso in backgnd_set])
    query_indices = np.array([i for i, iso in enumerate(iso_list) if iso in query_set])
    val_indices = np.array([i for i, iso in enumerate(iso_list) if iso in val_set])

    x_initial = df_year[feature_cols].to_numpy()

    # torch.manual_seed(seed)
    # try:
    #     results_relnet = train_relnet(
    #         x=x_initial,
    #         y=y,
    #         r=R,
    #         backgnd_indices=np.array(backgnd_indices),
    #         query_indices=np.array(query_indices),
    #         val_indices=np.array(val_indices),
    #         lr=.01,
    #         n_epochs=1500,
    #         periodic_embed_dim=None,
    #         progress_bar=True,
    #         print_loss=False,
    #         n_layers=2,
    #         num_heads=2,
    #         embed_dim=8,
    #     )
    #     relnet_mse, relnet_r2 = results_relnet[:2]
    #     metrics["relnet_mse"].append(relnet_mse)
    #     metrics["relnet_r2"].append(relnet_r2)
    # except Exception as e:
    #      print(e)

    # Positional split: first third -> background, second third -> query,
    # remainder -> validation.
    # NOTE(review): `run_training` receives this positional split, not the
    # geographic query/validation sets computed above — confirm this is intended.
    n_samples = x_initial.shape[0]
    n_back = n_query = n_samples // 3
    inds_back = np.arange(n_back)
    inds_q = np.arange(n_back, n_back + n_query)
    inds_val = np.arange(n_back + n_query, n_samples)

    try:
        res = run_training(
            x=x_initial,
            y=y,
            r=R,
            backgnd_indices=inds_back,
            query_indices=inds_q,
            val_indices=inds_val,
            lr=.005,
            n_epochs=5000,
            rel_as_feats=R,
            mlp_config=MlpConfig(
                in_dim=x_initial.shape[1],
                hidden_dim=10,
                out_dim=20,
                dropout=.19,
            ),
        )
        for k, v in res.items():
            mse, r2 = v[:2]
            metrics[f"{k}_mse"].append(mse)
            metrics[f"{k}_r2"].append(r2)
    except Exception as e:
        # Best effort: a failed run is reported and still counts towards
        # `n_target_runs`, so the loop always terminates.
        print(f"run {n_runs} (year={year}) failed: {e!r}")

## 100 iterations

<div>
<style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>name</th>
      <th>mean</th>
      <th>std</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>rel=True;trainable_w=True;mlp=False_mse</td>
      <td>26.40</td>
      <td>3.02</td>
    </tr>
    <tr>
      <th>1</th>
      <td>rel=True;trainable_w=True;mlp=False_r2</td>
      <td>0.58</td>
      <td>0.06</td>
    </tr>
    <tr>
      <th>2</th>
      <td>rel=True;trainable_w=False;mlp=False_mse</td>
      <td>26.95</td>
      <td>3.85</td>
    </tr>
    <tr>
      <th>3</th>
      <td>rel=True;trainable_w=False;mlp=False_r2</td>
      <td>0.57</td>
      <td>0.08</td>
    </tr>
    <tr>
      <th>4</th>
      <td>rel=False;trainable_w=True;mlp=False_mse</td>
      <td>35.08</td>
      <td>6.70</td>
    </tr>
    <tr>
      <th>5</th>
      <td>rel=False;trainable_w=True;mlp=False_r2</td>
      <td>0.44</td>
      <td>0.15</td>
    </tr>
    <tr>
      <th>6</th>
      <td>rel=False;trainable_w=False;mlp=False_mse</td>
      <td>34.72</td>
      <td>7.77</td>
    </tr>
    <tr>
      <th>7</th>
      <td>rel=False;trainable_w=False;mlp=False_r2</td>
      <td>0.44</td>
      <td>0.16</td>
    </tr>
    <tr>
      <th>8</th>
      <td>rel=False;trainable_w=False;mlp=True_mse</td>
      <td>33.22</td>
      <td>4.71</td>
    </tr>
    <tr>
      <th>9</th>
      <td>rel=False;trainable_w=False;mlp=True_r2</td>
      <td>0.47</td>
      <td>0.11</td>
    </tr>
    <tr>
      <th>10</th>
      <td>rel=True;trainable_w=False;mlp=True_mse</td>
      <td>27.76</td>
      <td>3.42</td>
    </tr>
    <tr>
      <th>11</th>
      <td>rel=True;trainable_w=False;mlp=True_r2</td>
      <td>0.56</td>
      <td>0.07</td>
    </tr>
    <tr>
      <th>12</th>
      <td>lgb_mse</td>
      <td>22.11</td>
      <td>1.75</td>
    </tr>
    <tr>
      <th>13</th>
      <td>lgb_r2</td>
      <td>0.65</td>
      <td>0.04</td>
    </tr>
    <tr>
      <th>14</th>
      <td>rel-as-feats_mse</td>
      <td>1357.17</td>
      <td>92.54</td>
    </tr>
    <tr>
      <th>15</th>
      <td>rel-as-feats_r2</td>
      <td>-20.45</td>
      <td>2.04</td>
    </tr>
    <tr>
      <th>16</th>
      <td>lgb-rel_mse</td>
      <td>22.11</td>
      <td>1.75</td>
    </tr>
    <tr>
      <th>17</th>
      <td>lgb-rel_r2</td>
      <td>0.65</td>
      <td>0.04</td>
    </tr>
  </tbody>
</table>
</div>

## 5000 iterations

<div>
<style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>name</th>
      <th>mean</th>
      <th>std</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>rel=True;trainable_w=True;mlp=False_mse</td>
      <td>23.56</td>
      <td>3.42</td>
    </tr>
    <tr>
      <th>1</th>
      <td>rel=True;trainable_w=True;mlp=False_r2</td>
      <td>0.63</td>
      <td>0.07</td>
    </tr>
    <tr>
      <th>2</th>
      <td>rel=True;trainable_w=False;mlp=False_mse</td>
      <td>25.23</td>
      <td>2.79</td>
    </tr>
    <tr>
      <th>3</th>
      <td>rel=True;trainable_w=False;mlp=False_r2</td>
      <td>0.60</td>
      <td>0.06</td>
    </tr>
    <tr>
      <th>4</th>
      <td>rel=False;trainable_w=True;mlp=False_mse</td>
      <td>30.63</td>
      <td>5.18</td>
    </tr>
    <tr>
      <th>5</th>
      <td>rel=False;trainable_w=True;mlp=False_r2</td>
      <td>0.52</td>
      <td>0.08</td>
    </tr>
    <tr>
      <th>6</th>
      <td>rel=False;trainable_w=False;mlp=False_mse</td>
      <td>26.26</td>
      <td>2.32</td>
    </tr>
    <tr>
      <th>7</th>
      <td>rel=False;trainable_w=False;mlp=False_r2</td>
      <td>0.59</td>
      <td>0.02</td>
    </tr>
    <tr>
      <th>8</th>
      <td>rel=False;trainable_w=False;mlp=True_mse</td>
      <td>28.94</td>
      <td>4.07</td>
    </tr>
    <tr>
      <th>9</th>
      <td>rel=False;trainable_w=False;mlp=True_r2</td>
      <td>0.54</td>
      <td>0.10</td>
    </tr>
    <tr>
      <th>10</th>
      <td>rel=True;trainable_w=False;mlp=True_mse</td>
      <td>39.28</td>
      <td>25.24</td>
    </tr>
    <tr>
      <th>11</th>
      <td>rel=True;trainable_w=False;mlp=True_r2</td>
      <td>0.37</td>
      <td>0.44</td>
    </tr>
    <tr>
      <th>12</th>
      <td>lgb_mse</td>
      <td>22.11</td>
      <td>1.75</td>
    </tr>
    <tr>
      <th>13</th>
      <td>lgb_r2</td>
      <td>0.65</td>
      <td>0.04</td>
    </tr>
    <tr>
      <th>14</th>
      <td>rel-as-feats_mse</td>
      <td>1232.02</td>
      <td>68.89</td>
    </tr>
    <tr>
      <th>15</th>
      <td>rel-as-feats_r2</td>
      <td>-18.57</td>
      <td>2.63</td>
    </tr>
    <tr>
      <th>16</th>
      <td>lgb-rel_mse</td>
      <td>22.11</td>
      <td>1.75</td>
    </tr>
    <tr>
      <th>17</th>
      <td>lgb-rel_r2</td>
      <td>0.65</td>
      <td>0.04</td>
    </tr>
  </tbody>
</table>
</div>

In [None]:
# One summary row per metric series: mean and standard deviation over the
# collected runs, both rounded to two decimals.
results_stats = [
    {
        "name": metric_name,
        "mean": round(np.mean(run_values), 2),
        "std": round(np.std(run_values), 2),
    }
    for metric_name, run_values in metrics.items()
]
pd.DataFrame(results_stats)

In [None]:
# TODO: try a parameter grid search