## Utils

In [1]:
import os
# Set before any CUDA/cuBLAS work happens: with torch.use_deterministic_algorithms(True)
# (enabled in the seeding cell) PyTorch requires a fixed cuBLAS workspace configuration
# for reproducible CUDA matmuls.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})

In [2]:
import torch
import random
import numpy as np

def set_pytorch_seed(seed=42):
    """Seed every RNG the notebook touches and force deterministic PyTorch kernels.

    Seeds Python's ``random``, NumPy's global RNG, the PyTorch CPU RNG and all
    CUDA devices, then switches cuDNN and PyTorch to deterministic algorithms.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable the cuDNN autotuner (it may pick different kernels between runs)
    # and restrict cuDNN to deterministic algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)


set_pytorch_seed(0)

In [None]:
from tsl.data.preprocessing.scalers import MinMaxScaler
from models.utils.MPNN import MPNN
from models.baseline.MPNN_ODE import MPNN_ODE
import torch
import sympytorch
from torch_geometric.data import Data
import sympy as sp
import copy
from sklearn.metrics import mean_absolute_error, mean_squared_error
import numpy as np
from post_processing import get_model, make_callable, plot_predictions
from torch.optim import LBFGS
import pandas as pd
from typing import Dict
from datasets.RealEpidemics import RealEpidemics
import json


def get_scaler(data, tr_perc = 0.8, scale_range = (-1, 1)):
    """Fit a MinMaxScaler on the training prefix of the first trajectory.

    Args:
        data: dataset exposing ``raw_data_sampled`` (a tensor; the original
            shape note says (IC, T, N, 1)).
        tr_perc: fraction of the time axis treated as training data.
        scale_range: output range passed to ``MinMaxScaler(out_range=...)``.

    Returns:
        The fitted scaler. Only initial condition 0 and its first
        ``int(tr_perc * T)`` steps are used, flattened into one vector before
        fitting.
    """
    trajectories = data.raw_data_sampled.detach().cpu().numpy()  # (IC, T, N, 1)
    n_steps = trajectories.shape[1]
    train_prefix = trajectories[0, :int(tr_perc * n_steps), :, :]

    fitted = MinMaxScaler(out_range=scale_range)
    fitted.fit(train_prefix.flatten())
    return fitted


  
def eval_real_epid_int(data, countries_dict, build_symb_model, inferred_coeffs, scaler=None, use_euler=False, tr_perc = 0.8, remove_germany=False):
    """Roll out one symbolic model per country from the first sample and collect predictions.

    For every ``(country_name, node_idx)`` in ``countries_dict`` a model is built with
    ``build_symb_model(country_name, inferred_coeffs)`` and called on ``data[0]``.
    Only that country's node column of the prediction is kept.

    Args:
        data: indexable dataset; ``data[0]`` must expose ``x`` and ``y`` tensors
            (and ``t_span`` when ``use_euler`` is set).
        countries_dict: mapping country name -> node index.
        build_symb_model: factory ``(country_name, inferred_coeffs) -> model``.
        inferred_coeffs: coefficients forwarded unchanged to ``build_symb_model``.
        scaler: optional scaler. The model input is ``transform``-ed before the call
            and the prediction is ``inverse_transform``-ed afterwards.
        use_euler: set ``integration_method = "euler"`` on each model and use the
            time grid ``0..T`` (``T = len(y)``).
        tr_perc: fraction of time steps treated as training; the rest is validation.
        remove_germany: drop the ``"Germany"`` node (KeyError if it is absent).

    Returns:
        ``(y_true, y_pred, y_true_val, y_pred_val)`` as numpy arrays. The ``*_val``
        arrays hold the time steps from ``int(tr_perc * T)`` onward. Nodes whose
        model raises ``AssertionError`` keep zero predictions.
    """
    sample = data[0]
    y_true = sample.y.detach().cpu().numpy()
    y_pred = np.zeros_like(y_true)

    # The model input does not depend on the country, so scale it once.
    x_in = scaler.transform(sample.x) if scaler is not None else sample.x

    for country_name, node_idx in countries_dict.items():
        symb_model = build_symb_model(country_name, inferred_coeffs)

        # Work on a shallow copy so the dataset sample is never mutated. The original
        # assigned to data[0].x in place, which re-scales x once per country whenever
        # the dataset returns the same object for data[0].
        model_input = copy.copy(sample)
        model_input.x = x_in

        if use_euler:
            symb_model.integration_method = "euler"
            model_input.t_span = torch.arange(y_true.shape[0] + 1, device=x_in.device, dtype=sample.t_span.dtype)

        try:
            pred = symb_model(model_input).detach().cpu().numpy()
        except AssertionError:
            # Best effort: this node keeps zero predictions, but say which one failed.
            print(f"Failed for {country_name}")
            continue

        if scaler is not None:
            pred = scaler.inverse_transform(pred)

        y_pred[:, node_idx, :] = pred[:, node_idx, :]

    if remove_germany:
        germany_idx = countries_dict["Germany"]
        y_true = np.delete(y_true, germany_idx, axis=1)
        y_pred = np.delete(y_pred, germany_idx, axis=1)

    tr_end = int(tr_perc * y_true.shape[0])
    return y_true, y_pred, y_true[tr_end:], y_pred[tr_end:]


def eval_real_epid_journal(data, countries_dict, build_symb_model, inferred_coeffs, tr_perc = 0.8, step_size=1.0, scaler = None,
                           remove_germany = False):
    """Evaluate per-country models by accumulating their predicted derivatives.

    Each country's model is switched to ``predict_deriv = True`` and evaluated on
    every observed snapshot. The derivatives are scaled by ``step_size``,
    cumulatively summed over time and added to the first state ``data[0].x``.

    Args:
        data: sequence of snapshots, each exposing ``x``.
        countries_dict: mapping country name -> node index.
        build_symb_model: factory ``(country_name, inferred_coeffs) -> model``.
        inferred_coeffs: coefficients forwarded unchanged to ``build_symb_model``.
        tr_perc: fraction of time steps treated as training; the rest is validation.
        step_size: time step multiplying each derivative (scalar or scalar tensor).
        scaler: optional scaler; both arrays are ``inverse_transform``-ed at the end.
        remove_germany: drop the ``"Germany"`` node (KeyError if it is absent).

    Returns:
        ``(y_true, y_pred, y_true_val, y_pred_val)``. ``y_true`` stacks ``d.x`` for
        ``data[1:]``; the ``*_val`` arrays start at ``int(tr_perc * T)``.
    """
    def get_dxdt_pred(data, symb_model):
        # One derivative prediction per snapshot, evaluated on the observed state.
        return torch.stack([symb_model(snapshot) for snapshot in data], dim=0)

    def sum_over_dxdt(dxdt_pred):
        # Entry i is step_size * sum_{k<=i} dxdt_pred[k]. torch.cumsum replaces the
        # original O(T^2) re-summation of every prefix.
        return torch.cumsum(step_size * dxdt_pred, dim=0)

    def integrate(out, x0):
        # Broadcast x0 over time and skip out[0] (same terms as the original loop).
        # NOTE(review): y_true[j] is data[j + 1].x, but out[j + 1] accumulates j + 2
        # derivative terms (k = 0..j+1), whereas an explicit-Euler step count from x0
        # would use out[j]. Kept unchanged for comparability with reported numbers.
        # Confirm this is intended.
        return x0 + out[1:]

    x0 = data[0].x
    y_true = torch.stack([d.x for d in data[1:]], dim=0).detach().cpu().numpy()
    y_pred = np.zeros_like(y_true)

    for country_name, node_idx in countries_dict.items():
        symb_model = build_symb_model(country_name, inferred_coeffs)
        symb_model.predict_deriv = True
        dxdt_pred = get_dxdt_pred(data, symb_model)
        out = sum_over_dxdt(dxdt_pred)
        pred = integrate(out, x0).detach().cpu().numpy()
        y_pred[:, node_idx, :] = pred[:, node_idx, :]

    if scaler is not None:
        y_pred = scaler.inverse_transform(y_pred)
        y_true = scaler.inverse_transform(y_true)

    if remove_germany:
        germany_idx = countries_dict["Germany"]
        y_true = np.delete(y_true, germany_idx, axis=1)
        y_pred = np.delete(y_pred, germany_idx, axis=1)

    tr_end = int(tr_perc * y_true.shape[0])
    return y_true, y_pred, y_true[tr_end:], y_pred[tr_end:]
    



Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython


In [None]:
# Unscaled COVID data used for full-trajectory rollouts: one history step and a
# 44-step horizon, with states as targets rather than derivatives.
real_epid_data = RealEpidemics(
    name='RealEpid',
    root='./data_real_epid_covid_int',
    scale=False,
    predict_deriv=False,
    history=1,
    horizon=44,
)

In [None]:
# Country name -> node index mapping written alongside the processed dataset.
with open('./data_real_epid_covid_int/RealEpid/countries_dict.json', mode='r') as f:
    countries_dict = json.load(fp=f)

## TSS

In [198]:
x_i, x_j = sp.symbols('x_i x_j')
# NOTE(review): read_csv consumes the first row as a header and `.values` is then indexed
# positionally by node index; confirm the CSV column order matches countries_dict.
inf_coeff_covid = pd.read_csv("./saved_models_optuna/tss/real_epid_covid/Inferred_coefficients_covid.csv").values


def build_symb_model_tss(country, inf_coeff_covid):
    """Build the TSS symbolic model for one country.

    h(x_i) = k0 * x_i and g(x_i, x_j) = k1 / (1 + exp(-(x_j - x_i))), where k0 and k1
    are rows 0 and 1 of ``inf_coeff_covid`` at the column given by the global
    ``countries_dict[country]``.
    """
    sym_i, sym_j = sp.symbols('x_i x_j')
    col = countries_dict[country]
    self_coeff = inf_coeff_covid[0, col]
    coupling_coeff = inf_coeff_covid[1, col]

    coupling = coupling_coeff * (1 / (1 + sp.exp(-(sym_j - sym_i))))
    self_term = self_coeff * sym_i

    return get_model(
        g=make_callable(coupling),
        h=make_callable(self_term),
        message_passing=False,
        include_time=False,
        integration_method='rk4',
    )


# use_euler=True makes eval_real_epid_int switch each model to "euler", overriding 'rk4'.
y_true, y_pred, y_true_val, y_pred_val = eval_real_epid_int(
    data=real_epid_data,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_covid,
    build_symb_model=build_symb_model_tss,
    use_euler=True,
)

In [199]:
# MAE over the validation window (time steps from int(tr_perc * T) onward).
print(f"Validation MAE: {mean_absolute_error(y_true=y_true_val.ravel(), y_pred=y_pred_val.ravel())}")

Validation MAE: 50494.5703125


In [200]:
# MAE over the full trajectory (training and validation windows together).
print(f"Overall MAE: {mean_absolute_error(y_true=y_true.ravel(), y_pred=y_pred.ravel())}")

Overall MAE: 12133.45703125


In [9]:
# Save one prediction-vs-truth figure per country (not shown inline).
for country in countries_dict:
    node_idx = countries_dict[country]
    plot_predictions(
        y_true=y_true,
        y_pred=y_pred,
        node_index=node_idx,
        title=country,
        show=False,
        save_path="./saved_models_optuna/tss/real_epid_covid/figures",
    )

In [47]:
# Unscaled data with derivative targets, kept on CPU for the journal-style evaluation.
data_real_epid_orig = RealEpidemics(
    name='RealEpid',
    root='./data_real_epid_covid_orig',
    device='cpu',
    scale=False,
    predict_deriv=True,
)

# Two snapshots are dropped at each end before evaluation
# (presumably where the derivative estimate is unreliable; confirm).
y_true, y_pred, y_true_val, y_pred_val = eval_real_epid_journal(
    data=data_real_epid_orig[2:-2],
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_covid,
    build_symb_model=build_symb_model_tss,
    tr_perc=0.8,
)

In [48]:
# Validation-window and full-trajectory MAE for the journal-style evaluation.
print(f"Validation MAE: {mean_absolute_error(y_true=y_true_val.ravel(), y_pred=y_pred_val.ravel())}")
print(f"Overall MAE: {mean_absolute_error(y_true=y_true.ravel(), y_pred=y_pred.ravel())}")

Validation MAE: 757.14208984375
Overall MAE: 413.9864807128906


In [18]:
# Save one journal-evaluation figure per country (not shown inline).
for country in countries_dict:
    node_idx = countries_dict[country]
    plot_predictions(
        y_true=y_true,
        y_pred=y_pred,
        node_index=node_idx,
        title=country,
        show=False,
        save_path="./saved_models_optuna/tss/real_epid_covid/figures_journal",
    )

## GKAN

### GD: per-country coefficient fine-tuning with gradient-based optimisation

In [6]:
class symb_wrapper(torch.nn.Module):
    """Adapt a keyword-symbol callable (e.g. a sympytorch module) to a column-stacked tensor.

    Column 0 of the input is passed as ``x_i``. When ``is_self_interaction`` is
    False, column 1 is additionally passed as ``x_j``.
    """

    def __init__(self, symb, is_self_interaction = True):
        super().__init__()
        self.symb = symb
        self.is_self_interaction = is_self_interaction

    def forward(self, x):
        symbol_values = {"x_i": x[:, 0]}
        if not self.is_self_interaction:
            symbol_values["x_j"] = x[:, 1]
        return self.symb(**symbol_values)
        

In [None]:
model_path = "./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0"


def build_symb_gkan():
    """Build the trainable GKAN symbolic model used as the fine-tuning starting point.

    g(x_j) = min(max(exp(c * x_j), 1e-6), 1e6) and h(x_i) = a * x_i + b, each
    wrapped as a sympytorch module. The model is put in train mode on cuda:0.
    """
    sym_i, sym_j = sp.symbols('x_i x_j')
    a, b, c = 2.4682064, 2.4648788, -0.0039747115

    coupling = symb_wrapper(
        sympytorch.SymPyModule(expressions=[sp.Min(sp.Max(sp.exp(c * sym_j), 1e-6), 1e6)]),
        is_self_interaction=False,
    )
    self_term = symb_wrapper(
        sympytorch.SymPyModule(expressions=[a * sym_i + b]),
        is_self_interaction=True,
    )

    model = get_model(
        g=coupling,
        h=self_term,
        message_passing=False,
        include_time=False,
        integration_method='rk4',
        eval=False,
        all_t=True,
    )
    return model.train().to('cuda:0')


# Scaled windows (1 history step, 9-step horizon, stride 7) for per-country fine-tuning.
fine_tune_data = RealEpidemics(
    name='RealEpid',
    root='./data_real_epid_covid_ft_scaled',
    device="cuda:0",
    scale=True,
    scale_range=(-1, 1),
    train_perc=0.8,
    predict_deriv=False,
    history=1,
    horizon=9,
    stride=7,
)

# Chronological 80/20 split of the windows.
tr_len = len(fine_tune_data)
tr_end = int(0.8 * tr_len)
train_set, valid_set = fine_tune_data[:tr_end], fine_tune_data[tr_end:]

# NOTE(review): fit_param_per_country_gd is not defined in any visible cell
# (hidden kernel state); this cell fails on a fresh kernel until it is defined or imported.
fit_param_per_country_gd(
    data=train_set,
    valid_data=valid_set,
    build_symb_model=build_symb_gkan,
    countries_dict=countries_dict,
    model_path=model_path,
    patience=30,
    log=10,
    scaler=None,
    tr_perc=0.8,
    save_file="prova_3.csv",
    optimizer_type="adam",
)

### Plot predictions

In [89]:
scaler = get_scaler(data=real_epid_data, tr_perc=0.8)
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_covid = pd.read_csv("./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0/prova.csv")


def build_symb_model_gkan(country, inf_coeff_covid):
    """Build the evaluation-time GKAN symbolic model for one country.

    g(x_j) = exp(rate * x_j) and h(x_i) = gain * x_i + offset, with offset, gain
    and rate read from rows 0, 1 and 4 of the ``country`` column. Rows 2 and 3
    are unused. The model is integrated with dopri5.
    """
    sym_i, sym_j = sp.symbols('x_i x_j')
    column = inf_coeff_covid[country]
    offset, gain, rate = column.iloc[0], column.iloc[1], column.iloc[4]

    coupling = sp.exp(rate * sym_j)
    self_term = gain * sym_i + offset

    return get_model(
        g=make_callable(coupling),
        h=make_callable(self_term),
        message_passing=False,
        include_time=False,
        integration_method='dopri5',
    )


y_true, y_pred, y_true_val, y_pred_val = eval_real_epid_int(
    data=real_epid_data,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_gkan,
    inferred_coeffs=inf_coeff_covid,
    scaler=scaler,
    use_euler=False,
    tr_perc=0.8,
    remove_germany=True,
)

In [90]:
# MAE over the validation window (time steps from int(tr_perc * T) onward).
print(f"Validation MAE: {mean_absolute_error(y_true=y_true_val.ravel(), y_pred=y_pred_val.ravel())}")

Validation MAE: 424.1941833496094


In [91]:
# MAE over the full trajectory.
print(f"Overall MAE: {mean_absolute_error(y_true=y_true.ravel(), y_pred=y_pred.ravel())}")

Overall MAE: 419.3455810546875


In [85]:
# Save one prediction-vs-truth figure per country (not shown inline).
for country in countries_dict:
    node_idx = countries_dict[country]
    plot_predictions(
        y_true=y_true,
        y_pred=y_pred,
        node_index=node_idx,
        title=country,
        show=False,
        save_path="./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0/figures",
    )

#### Journal evaluation (accumulated predicted derivatives)

In [80]:
# Scaled data with derivative targets for the journal-style evaluation.
data_real_epid_scaled = RealEpidemics(
    name='RealEpid',
    root='./data_real_epid_covid_scaled',
    scale=True,
    scale_range=(-1, 1),
    train_perc=0.8,
    predict_deriv=True,
)
# Integration step = spacing of the first two sampled time points.
t = real_epid_data.t_sampled
epsilon = t[0][1] - t[0][0]
scaler = get_scaler(data=real_epid_data, tr_perc=0.8)

y_true, y_pred, y_true_val, y_pred_val = eval_real_epid_journal(
    data=data_real_epid_scaled[2:-2],
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_gkan,
    inferred_coeffs=inf_coeff_covid,
    tr_perc=0.8,
    step_size=epsilon,
    scaler=scaler,
    remove_germany=False,
)

In [81]:
# Validation-window and full-trajectory MAE for the journal-style evaluation.
print(f"Validation MAE: {mean_absolute_error(y_true=y_true_val.ravel(), y_pred=y_pred_val.ravel())}")
print(f"Overall MAE: {mean_absolute_error(y_true=y_true.ravel(), y_pred=y_pred.ravel())}")

Validation MAE: 737.4694213867188
Overall MAE: 378.0634765625


In [38]:
# Save one journal-evaluation figure per country (not shown inline).
for country in countries_dict:
    node_idx = countries_dict[country]
    plot_predictions(
        y_true=y_true,
        y_pred=y_pred,
        node_index=node_idx,
        title=country,
        show=False,
        save_path="./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0/figures_journal",
    )

## MPNN

In [None]:
model_path = "./saved_models_optuna/model-real-epid-mpnn/real_epid_mpnn_7/0"


def build_symb_mpnn_to_opt():
    """Build the trainable MPNN-derived symbolic model used as the fine-tuning starting point.

    g = ln(max(tan(x_i + c)^2 + 1, eps)) and h = a * ln(max(x_i + b, eps)) with
    eps = 1e-8, each wrapped as a sympytorch module. The model is put in train
    mode on cuda:0.

    NOTE(review): g depends only on x_i even though it is wrapped with
    is_self_interaction=False (x_j is passed but unused); confirm this is the intended form.
    """
    sym_i, sym_j = sp.symbols('x_i x_j')
    a, b, c = 3.7716758, 1.9867662, 1.2657967
    eps = 1e-8

    coupling_expr = sp.ln(sp.Max(sp.tan(sym_i + c)**2 + 1, eps))
    self_expr = a * sp.ln(sp.Max(sym_i + b, eps))

    coupling = symb_wrapper(sympytorch.SymPyModule(expressions=[coupling_expr]), is_self_interaction=False)
    self_term = symb_wrapper(sympytorch.SymPyModule(expressions=[self_expr]), is_self_interaction=True)

    model = get_model(
        g=coupling,
        h=self_term,
        message_passing=False,
        include_time=False,
        integration_method='rk4',
        eval=False,
        all_t=True,
    )
    return model.train().to('cuda:0')


# Scaled windows (1 history step, 9-step horizon, stride 7) for per-country fine-tuning.
fine_tune_data = RealEpidemics(
    name='RealEpid',
    root='./data_real_epid_covid_ft_scaled',
    scale=True,
    scale_range=(-1, 1),
    train_perc=0.8,
    predict_deriv=False,
    history=1,
    horizon=9,
    stride=7,
)

# Chronological 80/20 split of the windows.
tr_len = len(fine_tune_data)
tr_end = int(0.8 * tr_len)
train_set, valid_set = fine_tune_data[:tr_end], fine_tune_data[tr_end:]

# NOTE(review): fit_param_per_country_gd is not defined in any visible cell (hidden kernel state).
fit_param_per_country_gd(
    data=train_set,
    valid_data=valid_set,
    build_symb_model=build_symb_mpnn_to_opt,
    countries_dict=countries_dict,
    model_path=model_path,
    patience=30,
    log=10,
    scaler=None,
    tr_perc=0.8,
    save_file="prova_2.csv",
    optimizer_type="adam",
)



Processing country Iceland
Testing config: lr=0.01, epochs=50


Epoch 0, train loss: 0.0092, val loss: 0.0084
Epoch 10, train loss: 0.0013, val loss: 0.0119
Epoch 20, train loss: 0.0011, val loss: 0.0112
Epoch 30, train loss: 0.0012, val loss: 0.0106
Early stopping at epoch 38
Testing config: lr=0.01, epochs=100
Epoch 0, train loss: 0.0092, val loss: 0.0084
Epoch 10, train loss: 0.0013, val loss: 0.0119
Epoch 20, train loss: 0.0011, val loss: 0.0112


KeyboardInterrupt: 

### Plot predictions

In [74]:
scaler = get_scaler(data = real_epid_data, tr_perc=0.8)
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_covid = pd.read_csv("./saved_models_optuna/model-real-epid-mpnn/real_epid_mpnn_7/0/prova.csv")


def build_symb_model_mpnn(country, inf_coeff_covid):
    """Build the evaluation-time MPNN-derived symbolic model for one country.

    g = ln(max(tan(x_i + c)^2 + 1, 1e-6)) and h = a * ln(max(x_i + b, 1e-6)),
    with a, b, c read from rows 0, 2 and 4 of the ``country`` column of
    ``inf_coeff_covid``. Rows 1 and 3 are unused. The model is integrated with dopri5.

    NOTE(review): build_symb_mpnn_to_opt clamps with eps=1e-8, while this builder
    uses 1e-6; confirm the mismatch is intended.
    """
    # Define the symbol locally. The original read the module-level ``x_i``, so the
    # builder silently depended on this cell's global (it is reused later in the
    # H1N1 section) and broke if that name was rebound.
    x_i = sp.Symbol('x_i')

    coeffs = inf_coeff_covid[country]
    a, b, c = coeffs.iloc[0], coeffs.iloc[2], coeffs.iloc[4]

    expr1 = sp.ln(sp.Max(sp.tan(x_i + c)**2 + 1, 1e-6))
    expr2 = a * sp.ln(sp.Max(x_i + b, 1e-6))

    symb_model = get_model(
        g = make_callable(expr1),
        h = make_callable(expr2),
        message_passing=False,
        include_time=False,
        integration_method='dopri5'
    )

    return symb_model


y_true, y_pred, y_true_val, y_pred_val = eval_real_epid_int(
    data = real_epid_data,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_mpnn,
    scaler=scaler,
    use_euler=False,
    inferred_coeffs=inf_coeff_covid,
    tr_perc=0.8,
    remove_germany=False
)

In [75]:
# MAE over the validation window (time steps from int(tr_perc * T) onward).
print(f"Validation MAE: {mean_absolute_error(y_true=y_true_val.ravel(), y_pred=y_pred_val.ravel())}")

Validation MAE: 304.4453125


In [76]:
# MAE over the full trajectory.
print(f"Overall MAE: {mean_absolute_error(y_true=y_true.ravel(), y_pred=y_pred.ravel())}")

Overall MAE: 309.2868957519531


In [58]:
# Save one prediction-vs-truth figure per country (not shown inline).
for country in countries_dict:
    node_idx = countries_dict[country]
    plot_predictions(
        y_true=y_true,
        y_pred=y_pred,
        node_index=node_idx,
        title=country,
        show=False,
        save_path="./saved_models_optuna/model-real-epid-mpnn/real_epid_mpnn_7/0/figures",
    )

#### Journal evaluation (accumulated predicted derivatives)

In [59]:
# Per-country fine-tuned MPNN coefficients (one column per country).
inf_coeff_covid = pd.read_csv(filepath_or_buffer="./saved_models_optuna/model-real-epid-mpnn/real_epid_mpnn_7/0/prova.csv")

In [60]:
# Scaled data with derivative targets for the journal-style evaluation.
data_real_epid_scaled = RealEpidemics(
    name='RealEpid',
    root='./data_real_epid_covid_scaled',
    scale=True,
    scale_range=(-1, 1),
    train_perc=0.8,
    predict_deriv=True,
)
# Integration step = spacing of the first two sampled time points.
t = real_epid_data.t_sampled
epsilon = t[0][1] - t[0][0]
scaler = get_scaler(data=real_epid_data, tr_perc=0.8)

y_true, y_pred, y_true_val, y_pred_val = eval_real_epid_journal(
    data=data_real_epid_scaled[2:-2],
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_mpnn,
    inferred_coeffs=inf_coeff_covid,
    tr_perc=0.8,
    step_size=epsilon,
    scaler=scaler,
    remove_germany=False,
)

In [61]:
# Validation-window and full-trajectory MAE for the journal-style evaluation.
print(f"Validation MAE: {mean_absolute_error(y_true=y_true_val.ravel(), y_pred=y_pred_val.ravel())}")
print(f"Overall MAE: {mean_absolute_error(y_true=y_true.ravel(), y_pred=y_pred.ravel())}")

Validation MAE: 583.709716796875
Overall MAE: 258.97357177734375


In [62]:
# Save one journal-evaluation figure per country (not shown inline).
for country in countries_dict:
    node_idx = countries_dict[country]
    plot_predictions(
        y_true=y_true,
        y_pred=y_pred,
        node_index=node_idx,
        title=country,
        show=False,
        save_path="./saved_models_optuna/model-real-epid-mpnn/real_epid_mpnn_7/0/figures_journal",
    )

## Generalization H1N1

In [170]:
from datasets.RealEpidemics import RealEpidemics

# Unscaled H1N1 data for the generalization rollouts. inf_threshold=100 presumably
# filters countries/time by infection count; confirm against RealEpidemics.
real_epid_data_h1n1 = RealEpidemics(
    name='RealEpid',
    root='./data_real_epid_h1n1_int',
    infection_data="./data/RealEpidemics/infected_numbers_H1N1.csv",
    inf_threshold=100,
    scale=False,
    predict_deriv=False,
    history=1,
    horizon=44,
)

import json

# NOTE: rebinds the global countries_dict; every later cell now uses the H1N1 node mapping.
with open('./data_real_epid_h1n1_int/RealEpid/countries_dict.json', mode='r') as f:
    countries_dict = json.load(fp=f)

### GKAN

In [171]:
model_path = "./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0"


# NOTE(review): this redefines (and shadows) build_symb_gkan from the COVID GD cell;
# the only difference is the target device (cuda:1 instead of cuda:0).
def build_symb_gkan():
    """Build the trainable GKAN symbolic model for H1N1 fine-tuning.

    g(x_j) = min(max(exp(c * x_j), 1e-6), 1e6) and h(x_i) = a * x_i + b, each
    wrapped as a sympytorch module. The model is put in train mode on cuda:1.
    """
    sym_i, sym_j = sp.symbols('x_i x_j')
    a, b, c = 2.4682064, 2.4648788, -0.0039747115

    coupling = symb_wrapper(
        sympytorch.SymPyModule(expressions=[sp.Min(sp.Max(sp.exp(c * sym_j), 1e-6), 1e6)]),
        is_self_interaction=False,
    )
    self_term = symb_wrapper(
        sympytorch.SymPyModule(expressions=[a * sym_i + b]),
        is_self_interaction=True,
    )

    model = get_model(
        g=coupling,
        h=self_term,
        message_passing=False,
        include_time=False,
        integration_method='rk4',
        eval=False,
        all_t=True,
    )
    return model.train().to('cuda:1')


scaler = get_scaler(data=real_epid_data_h1n1, tr_perc=0.8)

# NOTE(review): these keyword arguments (epochs, lr, device; no valid_data) differ from
# the COVID GD calls (valid_data, optimizer_type); confirm fit_param_per_country_gd
# (not defined in any visible cell) accepts both.
fit_param_per_country_gd(
    data=real_epid_data_h1n1,
    build_symb_model=build_symb_gkan,
    countries_dict=countries_dict,
    model_path=model_path,
    epochs=3,
    lr=1e-3,
    patience=3,
    log=1,
    scaler=scaler,
    tr_perc=0.8,
    device="cuda:1",
    save_file="inf_coeff_h1n1.csv",
)


Processing country Canada
[Parameter containing:
tensor(1000000., device='cuda:1', requires_grad=True), Parameter containing:
tensor(1.0000e-06, device='cuda:1', requires_grad=True), Parameter containing:
tensor(0.0466, device='cuda:1', requires_grad=True), Parameter containing:
tensor(2.1446, device='cuda:1', requires_grad=True), Parameter containing:
tensor(2.6572, device='cuda:1', requires_grad=True)]
Epoch 0, train Loss: 0.4583, valid loss: 0.0301
[Parameter containing:
tensor(1000000., device='cuda:1', requires_grad=True), Parameter containing:
tensor(1.0000e-06, device='cuda:1', requires_grad=True), Parameter containing:
tensor(-0.0540, device='cuda:1', requires_grad=True), Parameter containing:
tensor(1.9905, device='cuda:1', requires_grad=True), Parameter containing:
tensor(2.5256, device='cuda:1', requires_grad=True)]
Epoch 1, train Loss: 0.0101, valid loss: 0.0208
[Parameter containing:
tensor(1000000., device='cuda:1', requires_grad=True), Parameter containing:
tensor(1.000

In [172]:
scaler = get_scaler(data=real_epid_data_h1n1, tr_perc=0.8)
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_h1n1 = pd.read_csv("./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0/inf_coeff_h1n1.csv")

# Reuses build_symb_model_gkan from the COVID "Plot predictions" cell.
y_true, y_pred, y_true_val, y_pred_val = eval_real_epid_int(
    data=real_epid_data_h1n1,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_gkan,
    inferred_coeffs=inf_coeff_h1n1,
    scaler=scaler,
    use_euler=False,
    tr_perc=0.8,
    remove_germany=False,
)

In [173]:
# Validation-window and full-trajectory MAE on H1N1.
print(f"Validation MAE: {mean_absolute_error(y_true=y_true_val.ravel(), y_pred=y_pred_val.ravel())}")
print(f"Overall MAE: {mean_absolute_error(y_true=y_true.ravel(), y_pred=y_pred.ravel())}")

Validation MAE: 430.30413818359375
Overall MAE: 151.86192321777344


### MPNN

In [None]:
model_path = "./saved_models_optuna/model-real-epid-mpnn/real_epid_mpnn_7/0"

scaler = get_scaler(data=real_epid_data_h1n1, tr_perc=0.8)

# NOTE(review): build_symb_mpnn_to_opt moves its model to 'cuda:0' while this call passes
# device="cuda:1"; confirm fit_param_per_country_gd relocates the model (the GKAN H1N1 cell
# redefined its builder for cuda:1 instead).
fit_param_per_country_gd(
    data=real_epid_data_h1n1,
    build_symb_model=build_symb_mpnn_to_opt,
    countries_dict=countries_dict,
    model_path=model_path,
    epochs=3,
    lr=1e-3,
    patience=3,
    log=1,
    scaler=scaler,
    tr_perc=0.8,
    device="cuda:1",
    save_file="inf_coeff_h1n1.csv",
)


Processing country Canada
Epoch 0, train Loss: 0.1191, valid loss: 0.1693
Epoch 1, train Loss: 0.0204, valid loss: 0.1693
Epoch 2, train Loss: 0.0204, valid loss: 0.1693
Inferred coeffs for Canada: [3.7716115e+00 1.0000000e-03 2.0046334e+00 1.0000000e-03 1.2674485e+00]

Processing country United Kingdom
Epoch 0, train Loss: 0.0574, valid loss: 0.0079
Epoch 1, train Loss: 0.0045, valid loss: 0.0079
Epoch 2, train Loss: 0.0045, valid loss: 0.0079
Inferred coeffs for United Kingdom: [3.7716711e+00 1.0000000e-03 1.9791098e+00 1.0000000e-03 1.2643467e+00]

Processing country Spain
Epoch 0, train Loss: 0.0108, valid loss: 0.0409
Epoch 1, train Loss: 0.0100, valid loss: 0.0409
Epoch 2, train Loss: 0.0100, valid loss: 0.0409
Inferred coeffs for Spain: [3.7716761e+00 1.0000000e-03 1.9861017e+00 1.0000000e-03 1.2657113e+00]

Processing country Greece
Epoch 0, train Loss: 0.0657, valid loss: 0.0018
Epoch 1, train Loss: 0.0010, valid loss: 0.0018
Epoch 2, train Loss: 0.0010, valid loss: 0.0018
In

In [179]:
scaler = get_scaler(data=real_epid_data_h1n1, tr_perc=0.8)
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_h1n1 = pd.read_csv("./saved_models_optuna/model-real-epid-mpnn/real_epid_mpnn_7/0/inf_coeff_h1n1.csv")

# Reuses build_symb_model_mpnn from the COVID MPNN "Plot predictions" cell.
y_true, y_pred, y_true_val, y_pred_val = eval_real_epid_int(
    data=real_epid_data_h1n1,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_mpnn,
    inferred_coeffs=inf_coeff_h1n1,
    scaler=scaler,
    use_euler=False,
    tr_perc=0.8,
    remove_germany=False,
)

In [180]:
# Validation-window and full-trajectory MAE on H1N1.
print(f"Validation MAE: {mean_absolute_error(y_true=y_true_val.ravel(), y_pred=y_pred_val.ravel())}")
print(f"Overall MAE: {mean_absolute_error(y_true=y_true.ravel(), y_pred=y_pred.ravel())}")

Validation MAE: 359.6402893066406
Overall MAE: 169.16229248046875
