In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from typing import Final

import kagglehub
from pycountry import countries

import numpy as np
import pandas as pd
import geopandas as gpd

from tabrel.benchmark.nw_regr import run_training
from tabrel.utils.geo import get_connected_country_set, build_border_map, build_r_countries



# Download the Kaggle life-expectancy dataset and load its CSV file.
path = kagglehub.dataset_download("amirhosseinmirzaie/countries-life-expectancy")
# Sort the matches so the selected file does not depend on filesystem listing order,
# and fail with a readable message (instead of a bare IndexError) when nothing matches.
csv_files = sorted(Path(path).glob("*.csv"))
if not csv_files:
    raise FileNotFoundError(f"No CSV file found in the downloaded dataset directory: {path}")
df = pd.read_csv(csv_files[0])
features: Final[list[str]] = ["Hepatitis B", "Polio", "Diphtheria", "HIV/AIDS", "BMI"]
response: Final[str] = "Life expectancy"


class NoneCountry:
    """Fallback passed as ``default`` to the pycountry lookup; its ``alpha_3`` is ``None``."""

    alpha_3 = None


# Map each country name to its ISO alpha-3 code; names the lookup does not resolve get None.
df["ISO_alpha"] = df["Country"].apply(lambda x: countries.get(name=x, default=NoneCountry).alpha_3)
df

In [None]:
# Show the distinct values of the "Year" column.
df["Year"].unique()

In [None]:
# Keep only the rows whose "Year" is 2015.
df_2015 = df[df["Year"] == 2015]

# Disabled exploratory plots, kept for reference.
# NOTE(review): they rely on `sns` and `px`, which are not imported in the cells shown here
# (presumably seaborn and plotly.express) — import them before re-enabling.
# sns.pairplot(df_2015[features + [response]],hue=response)

# fig_choropleth = px.choropleth(
#     df_2015,
#     locations="ISO_alpha",
#     color="Life expectancy"
# )
# fig_choropleth.write_image("life_expectancy_choropleth.png", width=1000, height=600, scale=3)

In [None]:
# Derive the ISO-indexed 2015 frame directly from `df` without in-place mutation.
# An in-place `set_index` on a filtered slice can trigger pandas' chained-assignment warning
# and fails with KeyError when the cell is executed a second time ("ISO_alpha" is no longer
# a column); this form can be re-run any number of times.
df_2015 = df[df["Year"] == 2015].set_index("ISO_alpha")

In [None]:
import os

# Directory holding the country shapefile. It can be overridden with the TABREL_DATA_DIR
# environment variable so the notebook is not tied to one machine; the default is the
# location used originally.
geo_data_dir = Path(os.environ.get("TABREL_DATA_DIR", "/Users/user/Documents/git/gh_zuevval/tabrel/data"))
world_shp_path = geo_data_dir / "ne_50m_admin_0_countries" / "ne_50m_admin_0_countries.shp"
world = gpd.read_file(str(world_shp_path))

# Drop entries whose ISO_A3_EH code is '-99' (presumably a "no code" placeholder — confirm).
world = world[world['ISO_A3_EH'] != '-99']
border_map = build_border_map(world)

In [None]:
from collections import defaultdict

from tabrel.benchmark.nw_regr import MlpConfig, train_nw_mlp
from tabrel.optuna import RelTrainData

# Upper and lower size bounds shared by the query and validation sets.
# NOTE(review): in the cells shown here these are not referenced by any active code.
max_query_size = max_val_size = 40
min_query_size = min_val_size = 10




def build_data(_seed: int) -> RelTrainData:
    """Assemble one training problem from the rows of a randomly drawn year.

    Seeds NumPy's global RNG with ``_seed`` and draws one element of ``df["Year"]``
    (so a year with more rows is proportionally more likely to be drawn). The rows of
    that year are indexed by ``ISO_alpha``, rows containing missing values are dropped,
    and ``build_r_countries`` is called on the result with the global ``border_map``.

    The samples are split purely by position into three consecutive parts:
    background (first third), query (second third) and validation (the remainder).
    A border-connected split based on ``get_connected_country_set`` was explored
    earlier and is currently disabled.

    Args:
        _seed: Seed passed to ``np.random.seed``; it determines the drawn year.

    Returns:
        ``RelTrainData`` holding ``r`` from ``build_r_countries``, the feature matrix
        ``x``, the response vector ``y`` and the three index arrays.
    """
    np.random.seed(_seed)
    year = np.random.choice(df["Year"])

    # Chain the steps so a fresh frame is produced; mutating a filtered slice of `df`
    # in place can trigger pandas' chained-assignment warning.
    df_year = df[df["Year"] == year].set_index("ISO_alpha").dropna()

    # The second return value (list of ISO codes) is not needed by the positional split.
    r, _ = build_r_countries(df_year, border_map)
    y = df_year[response].to_numpy()

    # Alternative feature sets tried earlier: ["Alcohol"], ["thinness  1-19 years"], ["BMI"].
    x = df_year[["Polio", "HIV/AIDS", "Diphtheria", "under-five deaths"]].to_numpy()

    n_samples = x.shape[0]
    n_back = n_query = n_samples // 3

    # np.arange yields integer index arrays even when a part is empty
    # (np.array(range(0)) would produce an empty float64 array).
    inds_back = np.arange(n_back)
    inds_q = np.arange(n_back, n_back + n_query)
    inds_val = np.arange(n_back + n_query, n_samples)

    return RelTrainData(
        r=r,
        x=x,
        y=y,
        back_ids=inds_back,
        query_ids=inds_q,
        val_ids=inds_val,
    )

def train_nw_mlp_life_expectancy(
        _seed: int,
        mlp_hid_dim: int,
        mlp_out_dim: int,
        dropout: float,
        weight_decay: float,
        n_epochs: int,
        ) -> float:
    """Run ``train_nw_mlp`` on the data built for ``_seed`` and return its R2 value.

    Args:
        _seed: Seed used both for ``build_data`` and for ``train_nw_mlp``.
        mlp_hid_dim: Forwarded to ``train_nw_mlp`` as ``mlp_hid_dim``.
        mlp_out_dim: Forwarded to ``train_nw_mlp`` as ``mlp_out_dim``.
        dropout: Forwarded to ``train_nw_mlp`` as ``dropout``.
        weight_decay: Forwarded to ``train_nw_mlp`` as ``weight_decay``.
        n_epochs: Forwarded to ``train_nw_mlp`` as ``_n_epochs``.

    Returns:
        The second element of the pair returned by ``train_nw_mlp``, or ``0.`` when
        ``build_data`` yields ``None``.
    """
    dataset = build_data(_seed)

    # Guard: without data there is nothing to train on; report it and return a zero score.
    if dataset is None:
        print("data is None")
        return 0.

    data_kwargs = dict(
        x=dataset.x,
        y=dataset.y,
        r=dataset.r,
        back_ids=dataset.back_ids,
        query_ids=dataset.query_ids,
        val_ids=dataset.val_ids,
    )
    model_kwargs = dict(
        mlp_hid_dim=mlp_hid_dim,
        mlp_out_dim=mlp_out_dim,
        dropout=dropout,
        weight_decay=weight_decay,
        _n_epochs=n_epochs,
        writer=None,
        trainable_weights=False,
        seed=_seed,
    )

    # Only the second element of the returned pair is needed.
    _, score_r2 = train_nw_mlp(**data_kwargs, **model_kwargs)
    return score_r2


## NW regression + relations + MLP embeddings: Optuna search

In [None]:
from datetime import datetime

import optuna

# Module-level counter: each call of `objective` increments it and uses the new value as seed.
seed_current = 0


def objective(trial: optuna.Trial) -> float:
    """Optuna objective for the NW + MLP model.

    Asks ``trial`` for ``weight_decay``, ``mlp_hid_dim``, ``mlp_out_dim`` and ``dropout``
    (in that order), advances the global ``seed_current`` and returns the value of
    ``train_nw_mlp_life_expectancy`` trained for 1000 epochs with the new seed.
    """
    global seed_current

    # Dict literals evaluate their values top to bottom, so the suggestion order is fixed.
    suggested = {
        "weight_decay": trial.suggest_float("weight_decay", 0., 1e-1),
        "mlp_hid_dim": trial.suggest_int("mlp_hid_dim", 4, 100),
        "mlp_out_dim": trial.suggest_int("mlp_out_dim", 1, 40),
        "dropout": trial.suggest_float("dropout", 0., .6),
    }

    seed_current += 1
    return train_nw_mlp_life_expectancy(_seed=seed_current, n_epochs=1000, **suggested)

# Optuna storage URL shared by the studies in this notebook.
sqlite_path: Final[str] = "sqlite:///db.sqlite3"
# Upper bound on the number of trials. Without `n_trials` (or `timeout`) `study.optimize`
# keeps creating trials until it is interrupted, so the notebook could never be run
# top to bottom. Tune this value to the available time budget.
n_trials_nw: Final[int] = 100
# The timestamp in the study name gives every execution of this cell its own study.
study = optuna.create_study(direction="maximize",
                            study_name=f"lifeExpectancy_{datetime.now()}",
                            storage=sqlite_path)
# A ValueError raised inside a trial marks that trial as failed instead of aborting the study.
study.optimize(objective, n_trials=n_trials_nw, catch=(ValueError,))

In [None]:
# Display the best hyperparameters found by `study`.
study.best_params

# Values recorded from an earlier run:
# {'weight_decay': 0.08806794264604936,
#  'mlp_hid_dim': 61,
#  'mlp_out_dim': 16,
#  'dropout': 0.4438266049177131}

## TabRel + Optuna

In [None]:
from tabrel.optuna import build_objective_relnet


# Module-level counter: each call of `objective_relnet` increments it and uses the new value.
seed = 0
# Number of epochs handed to `build_objective_relnet` for every trial.
n_epochs_optuna: Final[int] = 1000


def objective_relnet(trial: optuna.Trial) -> float:
    """Optuna objective for TabRel.

    Advances the global ``seed``, builds the data for that seed with ``build_data`` and
    returns the result of ``build_objective_relnet`` for this trial.
    """
    global seed
    seed += 1
    return build_objective_relnet(trial, build_data(seed), n_epochs_optuna, seed)


# Upper bound on the number of trials. Without `n_trials` (or `timeout`) `study.optimize`
# keeps creating trials until it is interrupted, so the notebook could never be run
# top to bottom. Tune this value to the available time budget.
n_trials_relnet: Final[int] = 100

# NOTE: this rebinds `study`; the study object created for the NW + MLP search earlier in
# the notebook is no longer reachable through that name after this cell runs.
study = optuna.create_study(
    storage=sqlite_path,
    study_name=f"LifeExpectancyRelnet_{datetime.now()}",
    direction="maximize",
)

study.optimize(objective_relnet, n_trials=n_trials_relnet)

In [None]:
import torch
from tabrel.train import train_relnet

# Number of independent evaluation runs; the run index doubles as the random seed.
n_runs_total: Final[int] = 15
# Maps "<model>_mse" / "<model>_r2" to the list of values collected over the runs.
metrics = defaultdict(list)

for run_seed in range(n_runs_total):
    np.random.seed(run_seed)
    data = build_data(_seed=run_seed)
    x_initial, R, y = data.x, data.r, data.y
    inds_back, inds_q, inds_val = data.back_ids, data.query_ids, data.val_ids

    # --- TabRel ---
    torch.manual_seed(run_seed)
    try:
        # NOTE(review): hyperparameters presumably come from the TabRel Optuna study — confirm.
        results_relnet = train_relnet(
            x=x_initial,
            y=y,
            r=R,
            backgnd_indices=inds_back,
            query_indices=inds_q,
            val_indices=inds_val,
            lr=.0063,
            n_epochs=1000,
            periodic_embed_dim=69,
            n_layers=1,
            num_heads=3,
            embed_dim=3 * 10,
            dropout=.12,
            progress_bar=True,
            print_loss=False,
        )
        relnet_mse, relnet_r2 = results_relnet[:2]
        metrics["relnet_mse"].append(relnet_mse)
        metrics["relnet_r2"].append(relnet_r2)
    except Exception as e:
        # Best effort: keep going, but say which run and model failed so that a metric
        # averaged over fewer runs does not go unnoticed.
        print(f"[run {run_seed}] train_relnet failed: {type(e).__name__}: {e}")

    # --- Models evaluated by run_training ---
    try:
        res = run_training(
            x=x_initial, y=y, r=R,
            backgnd_indices=inds_back,
            query_indices=inds_q,
            val_indices=inds_val,
            lr=.005,
            n_epochs=5000,
            # rel_as_feats=R,
            # MLP sizes and dropout match the rounded best parameters recorded from the
            # NW + MLP Optuna study (mlp_hid_dim=61, mlp_out_dim=16, dropout~0.44).
            mlp_config=MlpConfig(
                in_dim=x_initial.shape[1],
                hidden_dim=61,
                out_dim=16,
                dropout=.44,
            ),
        )
        for k, v in res.items():
            mse, r2 = v[:2]
            metrics[f"{k}_mse"].append(mse)
            metrics[f"{k}_r2"].append(r2)
    except Exception as e:
        print(f"[run {run_seed}] run_training failed: {type(e).__name__}: {e}")

# Keep the counter name used previously: it ends up equal to the number of runs performed.
n_runs = n_runs_total

In [None]:
# One summary row per collected metric: mean and standard deviation over the runs,
# both rounded to two decimals.
results_stats = [
    {
        "name": metric_name,
        "mean": round(np.mean(values), 2),
        "std": round(np.std(values), 2),
    }
    for metric_name, values in metrics.items()
]
# Plain-text alternative ("mean & std" per metric):
#  print(f"{k}: {np.mean(v):.2f} & {np.std(v):.2f}")
pd.DataFrame(results_stats)

# Results

## 100 iterations (`n_epochs=100`)

<div>
<style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>name</th>
      <th>mean</th>
      <th>std</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>rel=True;trainable_w=True;mlp=False_mse</td>
      <td>26.40</td>
      <td>3.02</td>
    </tr>
    <tr>
      <th>1</th>
      <td>rel=True;trainable_w=True;mlp=False_r2</td>
      <td>0.58</td>
      <td>0.06</td>
    </tr>
    <tr>
      <th>2</th>
      <td>rel=True;trainable_w=False;mlp=False_mse</td>
      <td>26.95</td>
      <td>3.85</td>
    </tr>
    <tr>
      <th>3</th>
      <td>rel=True;trainable_w=False;mlp=False_r2</td>
      <td>0.57</td>
      <td>0.08</td>
    </tr>
    <tr>
      <th>4</th>
      <td>rel=False;trainable_w=True;mlp=False_mse</td>
      <td>35.08</td>
      <td>6.70</td>
    </tr>
    <tr>
      <th>5</th>
      <td>rel=False;trainable_w=True;mlp=False_r2</td>
      <td>0.44</td>
      <td>0.15</td>
    </tr>
    <tr>
      <th>6</th>
      <td>rel=False;trainable_w=False;mlp=False_mse</td>
      <td>34.72</td>
      <td>7.77</td>
    </tr>
    <tr>
      <th>7</th>
      <td>rel=False;trainable_w=False;mlp=False_r2</td>
      <td>0.44</td>
      <td>0.16</td>
    </tr>
    <tr>
      <th>8</th>
      <td>rel=False;trainable_w=False;mlp=True_mse</td>
      <td>33.22</td>
      <td>4.71</td>
    </tr>
    <tr>
      <th>9</th>
      <td>rel=False;trainable_w=False;mlp=True_r2</td>
      <td>0.47</td>
      <td>0.11</td>
    </tr>
    <tr>
      <th>10</th>
      <td>rel=True;trainable_w=False;mlp=True_mse</td>
      <td>27.76</td>
      <td>3.42</td>
    </tr>
    <tr>
      <th>11</th>
      <td>rel=True;trainable_w=False;mlp=True_r2</td>
      <td>0.56</td>
      <td>0.07</td>
    </tr>
    <tr>
      <th>12</th>
      <td>lgb_mse</td>
      <td>22.11</td>
      <td>1.75</td>
    </tr>
    <tr>
      <th>13</th>
      <td>lgb_r2</td>
      <td>0.65</td>
      <td>0.04</td>
    </tr>
    <tr>
      <th>14</th>
      <td>rel-as-feats_mse</td>
      <td>1357.17</td>
      <td>92.54</td>
    </tr>
    <tr>
      <th>15</th>
      <td>rel-as-feats_r2</td>
      <td>-20.45</td>
      <td>2.04</td>
    </tr>
    <tr>
      <th>16</th>
      <td>lgb-rel_mse</td>
      <td>22.11</td>
      <td>1.75</td>
    </tr>
    <tr>
      <th>17</th>
      <td>lgb-rel_r2</td>
      <td>0.65</td>
      <td>0.04</td>
    </tr>
  </tbody>
</table>
</div>

## 5000 iterations (`n_epochs=5000`)

<div>
<style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>name</th>
      <th>mean</th>
      <th>std</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>rel=True;trainable_w=True;mlp=False_mse</td>
      <td>23.56</td>
      <td>3.42</td>
    </tr>
    <tr>
      <th>1</th>
      <td>rel=True;trainable_w=True;mlp=False_r2</td>
      <td>0.63</td>
      <td>0.07</td>
    </tr>
    <tr>
      <th>2</th>
      <td>rel=True;trainable_w=False;mlp=False_mse</td>
      <td>25.23</td>
      <td>2.79</td>
    </tr>
    <tr>
      <th>3</th>
      <td>rel=True;trainable_w=False;mlp=False_r2</td>
      <td>0.60</td>
      <td>0.06</td>
    </tr>
    <tr>
      <th>4</th>
      <td>rel=False;trainable_w=True;mlp=False_mse</td>
      <td>30.63</td>
      <td>5.18</td>
    </tr>
    <tr>
      <th>5</th>
      <td>rel=False;trainable_w=True;mlp=False_r2</td>
      <td>0.52</td>
      <td>0.08</td>
    </tr>
    <tr>
      <th>6</th>
      <td>rel=False;trainable_w=False;mlp=False_mse</td>
      <td>26.26</td>
      <td>2.32</td>
    </tr>
    <tr>
      <th>7</th>
      <td>rel=False;trainable_w=False;mlp=False_r2</td>
      <td>0.59</td>
      <td>0.02</td>
    </tr>
    <tr>
      <th>8</th>
      <td>rel=False;trainable_w=False;mlp=True_mse</td>
      <td>28.94</td>
      <td>4.07</td>
    </tr>
    <tr>
      <th>9</th>
      <td>rel=False;trainable_w=False;mlp=True_r2</td>
      <td>0.54</td>
      <td>0.10</td>
    </tr>
    <tr>
      <th>10</th>
      <td>rel=True;trainable_w=False;mlp=True_mse</td>
      <td>39.28</td>
      <td>25.24</td>
    </tr>
    <tr>
      <th>11</th>
      <td>rel=True;trainable_w=False;mlp=True_r2</td>
      <td>0.37</td>
      <td>0.44</td>
    </tr>
    <tr>
      <th>12</th>
      <td>lgb_mse</td>
      <td>22.11</td>
      <td>1.75</td>
    </tr>
    <tr>
      <th>13</th>
      <td>lgb_r2</td>
      <td>0.65</td>
      <td>0.04</td>
    </tr>
    <tr>
      <th>14</th>
      <td>rel-as-feats_mse</td>
      <td>1232.02</td>
      <td>68.89</td>
    </tr>
    <tr>
      <th>15</th>
      <td>rel-as-feats_r2</td>
      <td>-18.57</td>
      <td>2.63</td>
    </tr>
    <tr>
      <th>16</th>
      <td>lgb-rel_mse</td>
      <td>22.11</td>
      <td>1.75</td>
    </tr>
    <tr>
      <th>17</th>
      <td>lgb-rel_r2</td>
      <td>0.65</td>
      <td>0.04</td>
    </tr>
  </tbody>
</table>
</div>

## TabRel, 1000 iterations (`n_epochs=1000`), Optuna-optimized hyperparameters

<div>
<style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>name</th>
      <th>mean</th>
      <th>std</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>relnet_mse</td>
      <td>40.8811</td>
      <td>13.4400</td>
    </tr>
    <tr>
      <th>1</th>
      <td>relnet_r2</td>
      <td>0.4300</td>
      <td>0.2335</td>
    </tr>
  </tbody>
</table>
</div>