# Real Epidemics

In [None]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import numpy as np

import torch
from main import set_pytorch_seed
from post_processing import get_model, make_callable, plot_predictions
import pandas as pd
import sympy as sp
from sklearn.metrics import mean_absolute_error
from post_processing import build_model_from_file as build_kan
from post_processing_mpnn import build_model_from_file_mpnn as build_mpnn
from post_processing_mpnn import build_model_from_file_llc as build_llc

# Seed once, before any model is built, so repeated notebook runs are comparable.
# NOTE(review): set_pytorch_seed comes from main.py; presumably it seeds torch (and
# possibly numpy/random) — confirm there.
set_pytorch_seed(0)

In [None]:
from fine_tuning_coefficients import get_scaler

def eval_real_epid_int(data, countries_dict, build_symb_model, inferred_coeffs, scaler=None,
                       use_euler=False, tr_perc=0.8, mask=None, device='cuda:0'):
    """Evaluate per-country symbolic models by integrating from the first snapshot.

    For every entry of ``countries_dict`` a model is built with
    ``build_symb_model(country_name, inferred_coeffs)`` and called on ``data[0]``.
    Only the slice ``[:, node_idx, :]`` of that model's output is kept, so the
    returned prediction is assembled country by country.

    Args:
        data: Indexable collection whose first element exposes ``x``, ``y`` and
            (when ``use_euler`` is set) ``t_span``.
        countries_dict: Mapping ``country name -> node index``.
        build_symb_model: Callable ``(country_name, inferred_coeffs) -> model``.
        inferred_coeffs: Coefficients forwarded untouched to ``build_symb_model``.
        scaler: Optional object with ``transform`` / ``inverse_transform``. When
            given, the input features are transformed before the model call and
            the prediction is inverse-transformed afterwards.
        use_euler: If true, set ``integration_method = "euler"`` on the model and
            replace ``t_span`` with ``0 .. y_true.shape[0]`` (inclusive).
        tr_perc: Fraction of the first axis treated as training; the remainder is
            the validation/test tail.
        mask: Optional index applied to the node axis of both target and prediction.
        device: Device on which the model and the input snapshot are placed.

    Returns:
        Tuple ``(y_true, y_pred, y_true_val, y_pred_val)`` of numpy arrays, where
        the ``*_val`` arrays are the tail starting at ``int(tr_perc * len)``.

    Note:
        Countries whose model call raises ``AssertionError`` are reported and
        skipped; their rows in ``y_pred`` stay at zero and still enter the MAE.
    """
    import copy  # local import keeps the block self-contained

    y_true = data[0].y.detach().cpu().numpy()
    if mask is not None:
        y_true = y_true[:, mask, :]
    y_pred = np.zeros_like(y_true)

    for country_name, node_idx in countries_dict.items():
        symb_model = build_symb_model(country_name, inferred_coeffs).to(device)

        # Work on a shallow copy: rebinding ``x`` / ``t_span`` below must never leak
        # into the caller's snapshot, otherwise a list-backed ``data`` would be
        # scaled again on every country iteration.
        data_0 = copy.copy(data[0])
        if scaler is None:
            data_0 = data_0.cpu()
        else:
            data_0.x = scaler.transform(data_0.x)
        data_0 = data_0.to(device)

        if use_euler:
            symb_model.integration_method = "euler"
            data_0.t_span = torch.arange(
                y_true.shape[0] + 1, device=data_0.x.device, dtype=data_0.t_span.dtype
            )

        try:
            pred = symb_model(data_0).detach().cpu().numpy()
        except AssertionError as err:
            # Name the country: its rows remain zero and distort the reported MAE.
            print(f"Failed for {country_name}: {err}")
            continue

        if mask is not None:
            pred = pred[:, mask, :]
        if scaler is not None:
            pred = scaler.inverse_transform(pred)

        y_pred[:, node_idx, :] = pred[:, node_idx, :]

    # Split along the first axis: everything from ``tr_end`` onward is the test tail.
    tr_end = int(tr_perc * y_true.shape[0])
    y_true_val = y_true[tr_end:, :, :]
    y_pred_val = y_pred[tr_end:, :, :]

    print(f"Test MAE: {mean_absolute_error(y_true_val.flatten(), y_pred_val.flatten())}")
    print(f"Overall MAE: {mean_absolute_error(y_true.flatten(), y_pred.flatten())}")

    return y_true, y_pred, y_true_val, y_pred_val


def eval_real_epid_journal(data, countries_dict, build_symb_model, inferred_coeffs, tr_perc=0.8,
                           step_size=1.0, scaler=None, device='cpu', mask=None):
    """Evaluate per-country symbolic models by accumulating per-snapshot outputs.

    Each country's model is switched to ``predict_deriv = True`` and called on
    every snapshot of ``data``. The outputs are accumulated along the snapshot
    axis and turned into a trajectory: row 0 is the initial state ``x0`` and row
    ``k >= 1`` is ``x0 + step_size * sum(outputs[0:k])``.

    Args:
        data: Iterable, indexable collection of snapshots exposing ``x``.
        countries_dict: Mapping ``country name -> node index``.
        build_symb_model: Callable ``(country_name, inferred_coeffs) -> model``.
        inferred_coeffs: Coefficients forwarded untouched to ``build_symb_model``.
        tr_perc: Fraction of the first axis treated as training; the remainder is
            the validation/test tail.
        step_size: Multiplier applied to the accumulated model outputs.
        scaler: Optional object with ``transform`` / ``inverse_transform``. When
            given, inputs are transformed before each model call and the final
            prediction array is inverse-transformed.
        device: Device on which models and snapshots are placed.
        mask: Optional index applied to the node axis of target and prediction.

    Returns:
        Tuple ``(y_true, y_pred, y_true_val, y_pred_val)`` of numpy arrays, where
        the ``*_val`` arrays are the tail starting at ``int(tr_perc * len)``.
    """
    import copy  # local import keeps the block self-contained

    def get_dxdt_pred(symb_model):
        # Call the model on every snapshot and stack the outputs along a new first axis.
        dxdt_pred = []
        for snapshot in data:
            # Shallow copy so that scaling never accumulates on the caller's
            # snapshots across countries (or across calls).
            snapshot = copy.copy(snapshot)
            if scaler is not None:
                snapshot.x = scaler.transform(snapshot.x)
            snapshot = snapshot.to(device)
            dxdt_pred.append(symb_model(snapshot))

        return torch.stack(dxdt_pred, dim=0)

    def integrate(dxdt_pred, x0):
        # Running sum in a single pass instead of re-summing every prefix (O(T) vs O(T^2)).
        cumulative = torch.cumsum(dxdt_pred, dim=0)
        pred = [x0]
        # The last cumulative entry is intentionally unused so that the output has
        # as many rows as there are snapshots.
        for i in range(cumulative.shape[0] - 1):
            pred.append(x0 + step_size * cumulative[i, :, :])
        return torch.stack(pred, dim=0)

    x0 = data[0].x
    if scaler is not None:
        x0 = scaler.transform(x0)
    x0 = x0.to(device)

    # Ground truth is the sequence of (unscaled) snapshot features.
    y_true = torch.stack([d.x for d in data], dim=0).detach().cpu().numpy()
    if mask is not None:
        y_true = y_true[:, mask, :]
    y_pred = np.zeros_like(y_true)

    for country_name, node_idx in countries_dict.items():
        symb_model = build_symb_model(country_name, inferred_coeffs).to(device)
        symb_model.predict_deriv = True

        dxdt_pred = get_dxdt_pred(symb_model)
        pred = integrate(dxdt_pred, x0).detach().cpu().numpy()
        if mask is not None:
            pred = pred[:, mask, :]
        # Keep only this country's slice of its own model's prediction.
        y_pred[:, node_idx, :] = pred[:, node_idx, :]

    if scaler is not None:
        y_pred = scaler.inverse_transform(y_pred)

    # Split along the first axis: everything from ``tr_end`` onward is the test tail.
    tr_end = int(tr_perc * y_true.shape[0])
    y_true_val = y_true[tr_end:, :, :]
    y_pred_val = y_pred[tr_end:, :, :]

    print(f"Test MAE: {mean_absolute_error(y_true_val.flatten(), y_pred_val.flatten())}")
    print(f"Overall MAE: {mean_absolute_error(y_true.flatten(), y_pred.flatten())}")

    return y_true, y_pred, y_true_val, y_pred_val

In [None]:
import matplotlib.pyplot as plt
import json

In [None]:
def save_country_predictions(y_true, preds_dict, countries_dict, save_dir="./outputs/covid"):
    """Write one comparison plot per country to ``save_dir``.

    Each figure shows the ground-truth series for the country's node (feature
    index 0) as a solid black line and one dashed line per model.

    Args:
        y_true: Array indexed as ``[step, node, feature]``.
        preds_dict: Mapping ``model name -> (prediction array, line colour)``;
            prediction arrays are indexed like ``y_true``.
        countries_dict: Mapping ``country name -> node index``.
        save_dir: Output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    for country_name, node_idx in countries_dict.items():
        reference = y_true[:, node_idx, 0]

        fig, ax = plt.subplots(figsize=(10, 6))

        # Reference curve first so that it heads the legend.
        ax.plot(reference, label="True", linewidth=2, color="black")

        # One dashed curve per model, in dictionary order.
        for model_name, (y_pred, color) in preds_dict.items():
            ax.plot(y_pred[:, node_idx, 0], linestyle="--", label=model_name, color=color)

        ax.set_title(f"{country_name} - Model Comparison")
        ax.set_xlabel("Days")
        ax.set_ylabel("Infected Count")
        ax.legend()
        fig.tight_layout()

        fig.savefig(os.path.join(save_dir, f"{country_name}_comparison.png"), dpi=150)
        # Close explicitly so figures do not pile up across countries.
        plt.close(fig)

In [None]:
from datasets.RealEpidemics import RealEpidemics

# COVID data loaded with predict_deriv=False and a 1-step history / 44-step horizon.
# This object is passed to eval_real_epid_int and get_scaler in the cells that follow.
real_epid_data = RealEpidemics(
    root = './data_real_epid_covid_int',
    name = 'RealEpid',
    predict_deriv=False,
    history=1,
    horizon=44,
    scale=False
)

# Same dataset class loaded with predict_deriv=True (separate root directory);
# this object is passed to eval_real_epid_journal.
data_real_epid_orig = RealEpidemics(
    root = './data_real_epid_covid_orig',
    name = 'RealEpid',
    predict_deriv=True,
    scale=False,
)

# Country name -> node index; the evaluation helpers use the values to index the node axis.
with open('./data_real_epid_covid_int/RealEpid/countries_dict.json', 'r') as f:
    countries_dict = json.load(f)
    
# Result registries keyed by model display name; each value is (prediction array, colour).
# "traj" holds eval_real_epid_int outputs, "eul" holds eval_real_epid_journal outputs.
all_res_covid_traj = {}
all_res_covid_eul = {}

### TSS 2

In [None]:
def build_symb_model_tss(country, inf_coeff):
    """Build the TSS symbolic model for one country.

    The self-dynamics term is ``inf_coeff[0, idx] * x_i`` and the interaction term
    is ``inf_coeff[1, idx] * sigmoid(x_j - x_i)``, where ``idx`` is the country's
    entry in the notebook-level ``countries_dict``.

    Args:
        country: Country name; must be a key of ``countries_dict``.
        inf_coeff: 2-D array indexed as ``[row, country index]``.

    Returns:
        The model returned by ``get_model`` (no message passing, no explicit
        time input, ``rk4`` integration).
    """
    x_i, x_j = sp.symbols('x_i x_j')
    node = countries_dict[country]

    # Logistic function of the neighbour/self difference.
    sigmoid_diff = 1 / (1 + sp.exp(- (x_j - x_i)))
    interaction_expr = inf_coeff[1, node] * sigmoid_diff
    self_expr = inf_coeff[0, node] * x_i

    return get_model(
        g=make_callable(interaction_expr),
        h=make_callable(self_expr),
        message_passing=False,
        include_time=False,
        integration_method='rk4',
    )

In [None]:
# Trajectory evaluation: coefficients from inf_coeffs_all_covid.csv.
# `.values` gives a plain array, which build_symb_model_tss indexes as [row, country_idx].
inf_coeff_covid = pd.read_csv("./inferred_coeffs/tpsindy/inf_coeffs_all_covid.csv").values

# use_euler=True switches the model to Euler integration over an integer time grid;
# tr_perc=0.9 makes the last 10% of steps the "Test MAE" window.
y_true_tss, y_pred_tss, y_true_val_tss, y_pred_val_tss = eval_real_epid_int(
    data = real_epid_data,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_covid,
    build_symb_model=build_symb_model_tss,
    use_euler=True,
    tr_perc=0.9,
    device='cpu'
)

print("Mae Eul\n")
# The one-step (accumulated-derivative) evaluation reads a different coefficient file.
inf_coeff_covid = pd.read_csv("./inferred_coeffs/tpsindy/inf_coeffs_test_covid.csv").values

# No scaler is passed here, so inputs and predictions stay in the original units.
y_true_tss_jrn, y_pred_tss_jrn, y_true_val_tss_jrn, y_pred_val_tss_jrn = eval_real_epid_journal(
    data = data_real_epid_orig,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_tss,
    inferred_coeffs=inf_coeff_covid,
    tr_perc=0.9,
    step_size=1.0,
    device='cpu'
)

# Store copies: the y_pred_* names are rebound when the other datasets are evaluated.
all_res_covid_traj["TP-SINDy"] = (y_pred_tss.copy(), "red")
all_res_covid_eul["TP-SINDy"] = (y_pred_tss_jrn.copy(), "red")

### GKAN

In [None]:
def build_symb_model_gkan(country, inf_coeff):
    """Build the GKAN symbolic model for one country.

    The first three entries of the country's coefficient column are read as
    ``(b, a, c)``: the self-dynamics term is ``a * x_i + b`` and the interaction
    term is ``c * exp(x_j)``.

    Args:
        country: Column label selecting the country's coefficients.
        inf_coeff: Table supporting ``inf_coeff[country].iloc[k]`` (e.g. a
            pandas DataFrame with one column per country).

    Returns:
        The model returned by ``get_model`` (no message passing, no explicit
        time input, ``dopri5`` integration).
    """
    x_i, x_j = sp.symbols('x_i x_j')

    column = inf_coeff[country]
    intercept = column.iloc[0]
    slope = column.iloc[1]
    coupling = column.iloc[2]

    interaction_expr = coupling * sp.exp(x_j)
    self_expr = slope * x_i + intercept

    return get_model(
        g=make_callable(interaction_expr),
        h=make_callable(self_expr),
        message_passing=False,
        include_time=False,
        integration_method='dopri5',
    )

In [None]:
# Scaler built from the COVID trajectory dataset.
# NOTE(review): presumably fit on the first 80% of the series (tr_perc=0.8) — confirm in get_scaler.
scaler = get_scaler(data = real_epid_data, tr_perc=0.8)
# Module-level symbols; some builder functions defined in other cells read these globals.
x_i, x_j = sp.symbols('x_i x_j')
# Kept as a DataFrame (no `.values`): build_symb_model_gkan selects a column by country name.
inf_coeff_covid = pd.read_csv("./inferred_coeffs/gkan/inferred_coeffs_covid_ts.csv")

# Trajectory evaluation with the model's own integrator (use_euler=False), in scaled space.
y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(
    data = real_epid_data,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_gkan,
    scaler=scaler,
    use_euler=False,
    inferred_coeffs=inf_coeff_covid,
    tr_perc=0.9,
    device='cpu'
)

# Stored immediately: y_pred_gkan is reused as a variable name by the GKAN SW cell.
all_res_covid_traj["GKAN + GP"] = (y_pred_gkan.copy(), "#5fa2d1")

# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_data.t_sampled
epsilon = t[0][1] - t[0][0]
# NOTE(review): same call as at the top of this cell; looks redundant.
scaler = get_scaler(data = real_epid_data, tr_perc=0.8)

print("Mae Eul\n")

# NOTE(review): this is the only cell that passes epsilon.item(); the others pass epsilon itself.
y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(
    data = data_real_epid_orig,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_gkan,
    tr_perc=0.9,
    step_size=epsilon.item(),
    inferred_coeffs=inf_coeff_covid,
    scaler=scaler,
    device='cpu'
)

all_res_covid_eul["GKAN + GP"] = (y_pred_gkan_jrn.copy(), "#5fa2d1")

### GKAN SW

In [None]:
# Inputs for the GKAN SW evaluation on COVID: scaler, sympy symbols and coefficient table.
# NOTE(review): presumably the scaler is fit on the first 80% (tr_perc=0.8) — confirm in get_scaler.
scaler = get_scaler(data = real_epid_data, tr_perc=0.8)
# build_model_sw below reads these two module-level symbols.
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_covid = pd.read_csv("./inferred_coeffs/gkan_sw/inferred_coeffs_covid_sw.csv")

def build_model_sw(country, inf_coeff):
    """Build the GKAN SW symbolic model for one country.

    The country's coefficient column must contain exactly 18 values, unpacked in
    file order as ``r, h, l, i, k, j, m, q, n, o, p, g, a, c, b, d, f, e``.

    * interaction term: ``a*tanh(b*x_i + c) + d*tanh(e*x_j + f) + g``
    * self-dynamics term:
      ``h*tanh(i*tanh(j*x_i + k) + l) + m*tanh(n*x_i**3 + o*x_i**2 + p*x_i + q) + r``

    Args:
        country: Column label selecting the country's coefficients.
        inf_coeff: Table supporting ``inf_coeff[country].iloc[...]`` (e.g. a
            pandas DataFrame with one column per country).

    Returns:
        The model returned by ``get_model`` (no message passing, no explicit
        time input, ``dopri5`` integration).

    Raises:
        ValueError: If the column does not hold exactly 18 coefficients
            (raised by the tuple unpacking).
    """
    # Create the symbols locally instead of relying on notebook globals defined in
    # another statement; sympy symbols with equal names compare equal, so the
    # resulting expressions are unchanged.
    x_i, x_j = sp.symbols('x_i x_j')

    coeffs = inf_coeff[country]
    r, h, l, i, k, j, m, q, n, o, p, g, a, c, b, d, f, e = coeffs.iloc[0:]

    expr1 = a*sp.tanh(b*x_i + c) + d*sp.tanh(e*x_j + f) + g
    expr2 = h*sp.tanh(i*sp.tanh(j*x_i + k) + l) + m*sp.tanh(n*x_i**3 + o*x_i**2 + p*x_i + q) + r

    g_symb = make_callable(expr1)
    h_symb = make_callable(expr2)

    symb_model = get_model(
        g = g_symb,
        h = h_symb,
        message_passing=False,
        include_time=False,
        integration_method='dopri5'
    )

    return symb_model


# Trajectory evaluation of the GKAN SW model on COVID.
# This rebinds the y_*_gkan names that the GKAN + GP cell also used.
y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(
    data = real_epid_data,
    countries_dict=countries_dict,
    build_symb_model=build_model_sw,
    scaler=scaler,
    use_euler=False,
    inferred_coeffs=inf_coeff_covid,
    tr_perc=0.9,
    device='cpu'
)

# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_data.t_sampled
epsilon = t[0][1] - t[0][0]

print("Mae Eul\n")

# One-step evaluation: accumulates per-snapshot model outputs starting from the first snapshot.
y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(
    data = data_real_epid_orig,
    countries_dict=countries_dict,
    build_symb_model=build_model_sw,
    tr_perc=0.9,
    step_size=epsilon,
    inferred_coeffs=inf_coeff_covid,
    scaler=scaler,
    device='cpu'
)

# Store copies under the SW label (lighter blue than "GKAN + GP").
all_res_covid_traj["GKAN + SW"] = (y_pred_gkan.copy(), "#a2c8e3")
all_res_covid_eul["GKAN + SW"] = (y_pred_gkan_jrn.copy(), "#a2c8e3")

### MPNN

In [None]:
def build_symb_model_mpnn(country, inf_coeff):
    """Build the MPNN symbolic model for one country.

    Entries 0, 2 and 4 of the country's coefficient column are read as
    ``a``, ``b`` and ``c``; entries 1 and 3 are not used.

    * interaction term: ``ln(max(tan(x_i + c)**2 + 1, 1e-6))``
    * self-dynamics term: ``a * ln(max(x_i + b, 1e-6))``

    Note that the interaction expression only involves ``x_i``.

    Args:
        country: Column label selecting the country's coefficients.
        inf_coeff: Table supporting ``inf_coeff[country].iloc[k]`` (e.g. a
            pandas DataFrame with one column per country) with at least five rows.

    Returns:
        The model returned by ``get_model`` (no message passing, no explicit
        time input, ``dopri5`` integration).
    """
    # Create the symbol locally instead of relying on a notebook global defined in
    # another cell; sympy symbols with equal names compare equal, so the
    # resulting expressions are unchanged.
    x_i = sp.Symbol('x_i')

    coeffs = inf_coeff[country]
    a, b, c = coeffs.iloc[0], coeffs.iloc[2], coeffs.iloc[4]

    # The 1e-6 floor keeps the logarithm's argument strictly positive.
    expr1 = sp.ln(sp.Max(sp.tan(x_i + c)**2 + 1, 1e-6))
    expr2 = a * sp.ln(sp.Max(x_i + b, 1e-6))

    g_symb = make_callable(expr1)
    h_symb = make_callable(expr2)

    symb_model = get_model(
        g = g_symb,
        h = h_symb,
        message_passing=False,
        include_time=False,
        integration_method='dopri5'
    )

    return symb_model

In [None]:
# Inputs for the MPNN evaluation on COVID.
# NOTE(review): presumably the scaler is fit on the first 80% (tr_perc=0.8) — confirm in get_scaler.
scaler = get_scaler(data = real_epid_data, tr_perc=0.8)
# build_symb_model_mpnn reads these module-level symbols.
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_covid = pd.read_csv("./inferred_coeffs/mpnn/inferred_coeffs_covid_ts.csv")

# Trajectory evaluation with the model's own integrator, in scaled space.
y_true_mpnn, y_pred_mpnn, y_true_val_mpnn, y_pred_val_mpnn = eval_real_epid_int(
    data = real_epid_data,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_mpnn,
    scaler=scaler,
    use_euler=False,
    inferred_coeffs=inf_coeff_covid,
    tr_perc=0.9,
    device='cpu'
)

all_res_covid_traj["MPNN + GP"] = (y_pred_mpnn.copy(), "#fcb97d")

# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_data.t_sampled
epsilon = t[0][1] - t[0][0]
# NOTE(review): same call as at the top of this cell; looks redundant.
scaler = get_scaler(data = real_epid_data, tr_perc=0.8)

print("Mae Eul\n")

# One-step evaluation: accumulates per-snapshot model outputs starting from the first snapshot.
y_true_mpnn_jrn, y_pred_mpnn_jrn, y_true_val_mpnn_jrn, y_pred_val_mpnn_jrn = eval_real_epid_journal(
    data = data_real_epid_orig,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_mpnn,
    tr_perc=0.9,
    step_size=epsilon,
    inferred_coeffs=inf_coeff_covid,
    scaler=scaler,
    device='cpu'
)

all_res_covid_eul["MPNN + GP"] = (y_pred_mpnn_jrn.copy(), "#fcb97d")


### LLC

In [None]:
def build_symb_model_llc(country, inf_coeff):
    """Build the LLC symbolic model for one country.

    The first three entries of the country's coefficient column are read as
    ``(a, b, c)``.

    * interaction term: ``c * (x_i - x_j) * exp(-x_j)``
    * self-dynamics term: ``a * tanh(x_i + b)``

    Args:
        country: Column label selecting the country's coefficients.
        inf_coeff: Table supporting ``inf_coeff[country].iloc[k]`` (e.g. a
            pandas DataFrame with one column per country).

    Returns:
        The model returned by ``get_model`` (no message passing, no explicit
        time input, ``dopri5`` integration).
    """
    # Create the symbols locally instead of relying on notebook globals defined in
    # another cell; sympy symbols with equal names compare equal, so the
    # resulting expressions are unchanged.
    x_i, x_j = sp.symbols('x_i x_j')

    coeffs = inf_coeff[country]
    a, b, c = coeffs.iloc[0], coeffs.iloc[1], coeffs.iloc[2]

    expr1 = c*((x_i - x_j) * sp.exp(- x_j))
    expr2 = a * sp.tanh(x_i + b)

    g_symb = make_callable(expr1)
    h_symb = make_callable(expr2)

    symb_model = get_model(
        g = g_symb,
        h = h_symb,
        message_passing=False,
        include_time=False,
        integration_method='dopri5'
    )

    return symb_model

In [None]:
# Inputs for the LLC evaluation on COVID (this cell uses its own scaler name, scaler_covid).
# NOTE(review): presumably the scaler is fit on the first 80% (tr_perc=0.8) — confirm in get_scaler.
scaler_covid = get_scaler(data = real_epid_data, tr_perc=0.8)
# build_symb_model_llc reads these module-level symbols.
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_covid = pd.read_csv("./inferred_coeffs/llc/inferred_coeffs_covid_new.csv")

# Trajectory evaluation; use_euler is left at its default (False).
y_true_llc, y_pred_llc, y_true_val_llc, y_pred_val_llc = eval_real_epid_int(
    data = real_epid_data,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_covid,
    build_symb_model=build_symb_model_llc,
    scaler=scaler_covid,
    tr_perc=0.9,
    device='cpu'
)

print("Mae Eul\n")

# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_data.t_sampled
epsilon = t[0][1] - t[0][0]
    
# One-step evaluation: accumulates per-snapshot model outputs starting from the first snapshot.
y_true_llc_jrn, y_pred_llc_jrn, y_true_val_llc_jrn, y_pred_val_llc_jrn = eval_real_epid_journal(
    data = data_real_epid_orig,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_llc,
    inferred_coeffs=inf_coeff_covid,
    scaler=scaler_covid,
    step_size=epsilon,
    tr_perc=0.9,
    device='cpu'
)

all_res_covid_traj["LLC + GP"] = (y_pred_llc.copy(), "#34eb6e")
all_res_covid_eul["LLC + GP"] = (y_pred_llc_jrn.copy(), "#34eb6e")

## Generalization on H1N1 data 

In [None]:
# H1N1 data for trajectory evaluation (predict_deriv=False, 1-step history, 44-step horizon).
# NOTE(review): inf_threshold=100 presumably filters countries/series by infection count —
# confirm in RealEpidemics.
real_epid_h1n1 = RealEpidemics(
    root = './data_real_epid_h1n1_int',
    name = 'RealEpid',
    predict_deriv=False,
    history=1,
    horizon=44,
    scale=False,
    infection_data="./data/RealEpidemics/infected_numbers_H1N1.csv",
    inf_threshold=100
)

# Same infection file loaded with predict_deriv=True; passed to eval_real_epid_journal.
data_real_epid_orig_h1n1 = RealEpidemics(
    root = './data_real_epid_h1n1_orig',
    name = 'RealEpid',
    predict_deriv=True,
    scale=False,
    infection_data="./data/RealEpidemics/infected_numbers_H1N1.csv",
    inf_threshold=100
)

# Rebinds the notebook-level countries_dict to the H1N1 mapping (country name -> node index).
# build_symb_model_tss reads this global, so from here on it resolves H1N1 indices.
with open('./data_real_epid_h1n1_int/RealEpid/countries_dict.json', 'r') as f:
    countries_dict = json.load(f)
    
# Result registries for H1N1; each value is (prediction array, colour).
all_res_h1n1_traj = {}
all_res_h1n1_eul = {}

### TSS 2

In [None]:
# Trajectory evaluation on H1N1: coefficients from inf_coeffs_all_h1n1.csv.
# `.values` gives a plain array, which build_symb_model_tss indexes as [row, country_idx].
inf_coeff_h1n1 = pd.read_csv("./inferred_coeffs/tpsindy/inf_coeffs_all_h1n1.csv").values

# use_euler=True switches the model to Euler integration over an integer time grid.
y_true_tss, y_pred_tss, y_true_val_tss, y_pred_val_tss = eval_real_epid_int(
    data = real_epid_h1n1,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_h1n1,
    build_symb_model=build_symb_model_tss,
    use_euler=True,
    tr_perc=0.9,
    device='cpu'
)

print("Mae Eul\n")
# The one-step evaluation reads a different coefficient file.
inf_coeff_h1n1 = pd.read_csv("./inferred_coeffs/tpsindy/inf_coeffs_test_h1n1.csv").values

# No scaler is passed, so inputs and predictions stay in the original units.
y_true_tss_jrn, y_pred_tss_jrn, y_true_val_tss_jrn, y_pred_val_tss_jrn = eval_real_epid_journal(
    data = data_real_epid_orig_h1n1,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_tss,
    inferred_coeffs=inf_coeff_h1n1,
    tr_perc=0.9,
    step_size=1.0,
    device='cpu'
)


# Store copies: the y_pred_* names are rebound by the evaluations of the other datasets.
all_res_h1n1_traj["TP-SINDy"] = (y_pred_tss.copy(), "red")
all_res_h1n1_eul["TP-SINDy"] = (y_pred_tss_jrn.copy(), "red")

### GKAN

In [None]:
# Inputs for the GKAN evaluation on H1N1.
# NOTE(review): presumably the scaler is fit on the first 80% (tr_perc=0.8) — confirm in get_scaler.
scaler_h1n1 = get_scaler(data = real_epid_h1n1, tr_perc=0.8)
x_i, x_j = sp.symbols('x_i x_j')
# Kept as a DataFrame: build_symb_model_gkan selects a column by country name.
inf_coeff_h1n1 = pd.read_csv("./inferred_coeffs/gkan/inferred_coeffs_h1n1_ts.csv")

# Trajectory evaluation; use_euler is left at its default (False).
y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(
    data = real_epid_h1n1,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_h1n1,
    build_symb_model=build_symb_model_gkan,
    scaler=scaler_h1n1,
    tr_perc=0.9,
    device='cpu'
)

# Stored immediately: y_pred_gkan is reused as a variable name by the GKAN SW cell.
all_res_h1n1_traj["GKAN + GP"] = (y_pred_gkan.copy(), "#5fa2d1")

print("\nMae Eul\n")

# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_h1n1.t_sampled
epsilon = t[0][1] - t[0][0]
scaler_h1n1 = get_scaler(data = real_epid_h1n1, tr_perc=0.8)
print(epsilon)

y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(
    data = data_real_epid_orig_h1n1,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_gkan,
    inferred_coeffs=inf_coeff_h1n1,
    scaler=scaler_h1n1,
    step_size=epsilon,
    tr_perc=0.9,
    device='cpu'
)

# Record the one-step result as the COVID and SARS cells do; without this the
# "GKAN + GP" entry is missing from all_res_h1n1_eul, because the GKAN SW cell
# rebinds y_pred_gkan_jrn.
all_res_h1n1_eul["GKAN + GP"] = (y_pred_gkan_jrn.copy(), "#5fa2d1")

### GKAN SW

In [None]:
# Inputs for the GKAN SW evaluation on H1N1.
# NOTE(review): presumably the scaler is fit on the first 80% (tr_perc=0.8) — confirm in get_scaler.
scaler_h1n1 = get_scaler(data = real_epid_h1n1, tr_perc=0.8)
# build_model_sw reads these module-level symbols.
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_h1n1 = pd.read_csv("./inferred_coeffs/gkan_sw/inferred_coeffs_h1n1_sw.csv")

# Trajectory evaluation; rebinds the y_*_gkan names used by the GKAN + GP cell.
y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(
    data = real_epid_h1n1,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_h1n1,
    build_symb_model=build_model_sw,
    scaler=scaler_h1n1,
    tr_perc=0.9,
    device='cpu'
)

print("\nMae Eul\n")
# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_h1n1.t_sampled
epsilon = t[0][1] - t[0][0]

# One-step evaluation: accumulates per-snapshot model outputs starting from the first snapshot.
y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(
    data = data_real_epid_orig_h1n1,
    countries_dict=countries_dict,
    build_symb_model=build_model_sw,
    inferred_coeffs=inf_coeff_h1n1,
    scaler=scaler_h1n1,
    step_size=epsilon,
    tr_perc=0.9,
    device='cpu'
)

all_res_h1n1_traj["GKAN + SW"] = (y_pred_gkan.copy(), "#a2c8e3")
all_res_h1n1_eul["GKAN + SW"] = (y_pred_gkan_jrn.copy(), "#a2c8e3")

### MPNN

In [None]:
# Inputs for the MPNN evaluation on H1N1.
# NOTE(review): presumably the scaler is fit on the first 80% (tr_perc=0.8) — confirm in get_scaler.
scaler_h1n1 = get_scaler(data = real_epid_h1n1, tr_perc=0.8)
# build_symb_model_mpnn reads these module-level symbols.
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_h1n1 = pd.read_csv("./inferred_coeffs/mpnn/inferred_coeffs_h1n1_ts.csv")

# Trajectory evaluation; use_euler is left at its default (False).
y_true_mpnn, y_pred_mpnn, y_true_val_mpnn, y_pred_val_mpnn = eval_real_epid_int(
    data = real_epid_h1n1,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_h1n1,
    build_symb_model=build_symb_model_mpnn,
    scaler=scaler_h1n1,
    tr_perc=0.9,
    device='cpu'
)

print("Mae Eul\n")

# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_h1n1.t_sampled
epsilon = t[0][1] - t[0][0]
    
# One-step evaluation: accumulates per-snapshot model outputs starting from the first snapshot.
y_true_mpnn_jrn, y_pred_mpnn_jrn, y_true_val_mpnn_jrn, y_pred_val_mpnn_jrn = eval_real_epid_journal(
    data = data_real_epid_orig_h1n1,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_mpnn,
    inferred_coeffs=inf_coeff_h1n1,
    scaler=scaler_h1n1,
    step_size=epsilon,
    tr_perc=0.9,
    device='cpu'
)

all_res_h1n1_traj["MPNN + GP"] = (y_pred_mpnn.copy(), "#fcb97d")
all_res_h1n1_eul["MPNN + GP"] = (y_pred_mpnn_jrn.copy(), "#fcb97d")


### LLC

In [None]:
# Inputs for the LLC evaluation on H1N1.
# NOTE(review): presumably the scaler is fit on the first 80% (tr_perc=0.8) — confirm in get_scaler.
scaler_h1n1 = get_scaler(data = real_epid_h1n1, tr_perc=0.8)
# build_symb_model_llc reads these module-level symbols.
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_h1n1 = pd.read_csv("./inferred_coeffs/llc/inferred_coeffs_h1n1_new.csv")

# Trajectory evaluation; use_euler is left at its default (False).
y_true_llc, y_pred_llc, y_true_val_llc, y_pred_val_llc = eval_real_epid_int(
    data = real_epid_h1n1,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_h1n1,
    build_symb_model=build_symb_model_llc,
    scaler=scaler_h1n1,
    tr_perc=0.9,
    device='cpu'
)


print("Mae Eul\n")

# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_h1n1.t_sampled
epsilon = t[0][1] - t[0][0]
    
# One-step evaluation: accumulates per-snapshot model outputs starting from the first snapshot.
y_true_llc_jrn, y_pred_llc_jrn, y_true_val_llc_jrn, y_pred_val_llc_jrn = eval_real_epid_journal(
    data = data_real_epid_orig_h1n1,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_llc,
    inferred_coeffs=inf_coeff_h1n1,
    scaler=scaler_h1n1,
    step_size=epsilon,
    tr_perc=0.9,
    device='cpu'
)

all_res_h1n1_traj["LLC + GP"] = (y_pred_llc.copy(), "#34eb6e")
all_res_h1n1_eul["LLC + GP"] = (y_pred_llc_jrn.copy(), "#34eb6e")

## Generalization on SARS data

In [None]:
from datasets.RealEpidemics import RealEpidemics

# SARS data for trajectory evaluation (predict_deriv=False, 1-step history, 44-step horizon).
# NOTE(review): inf_threshold=100 presumably filters countries/series by infection count —
# confirm in RealEpidemics.
real_epid_sars = RealEpidemics(
    root = './data_real_epid_sars_int',
    name = 'RealEpid',
    predict_deriv=False,
    history=1,
    horizon=44,
    scale=False,
    infection_data="./data/RealEpidemics/infected_numbers_sars.csv",
    inf_threshold=100
)

# Same infection file loaded with predict_deriv=True; passed to eval_real_epid_journal.
data_real_epid_orig_sars = RealEpidemics(
    root = './data_real_epid_sars_orig',
    name = 'RealEpid',
    predict_deriv=True,
    scale=False,
    infection_data="./data/RealEpidemics/infected_numbers_sars.csv",
    inf_threshold=100
)

# Rebinds the notebook-level countries_dict to the SARS mapping (country name -> node index).
# build_symb_model_tss reads this global, so from here on it resolves SARS indices.
with open('./data_real_epid_sars_int/RealEpid/countries_dict.json', 'r') as f:
    countries_dict = json.load(f)
    
# Result registries for SARS; each value is (prediction array, colour).
all_res_sars_traj = {}
all_res_sars_eul = {}

### TSS 2

In [None]:
# Trajectory evaluation on SARS: coefficients from inf_coeffs_all_sars.csv.
# `.values` gives a plain array, which build_symb_model_tss indexes as [row, country_idx].
inf_coeff_sars = pd.read_csv("./inferred_coeffs/tpsindy/inf_coeffs_all_sars.csv").values

# use_euler=True switches the model to Euler integration over an integer time grid.
y_true_tss, y_pred_tss, y_true_val_tss, y_pred_val_tss = eval_real_epid_int(
    data = real_epid_sars,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_sars,
    build_symb_model=build_symb_model_tss,
    use_euler=True,
    tr_perc=0.9,
    device='cpu'
)


print("Mae Eul\n")
# The one-step evaluation reads a different coefficient file.
inf_coeff_sars = pd.read_csv("./inferred_coeffs/tpsindy/inf_coeffs_test_sars.csv").values

# No scaler is passed, so inputs and predictions stay in the original units.
y_true_tss_jrn, y_pred_tss_jrn, y_true_val_tss_jrn, y_pred_val_tss_jrn = eval_real_epid_journal(
    data = data_real_epid_orig_sars,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_tss,
    inferred_coeffs=inf_coeff_sars,
    tr_perc=0.9,
    step_size=1.0,
    device='cpu'
)

# Store copies so later rebinding of the y_pred_* names cannot affect the registry.
all_res_sars_traj["TP-SINDy"] = (y_pred_tss.copy(), "red")
all_res_sars_eul["TP-SINDy"] = (y_pred_tss_jrn.copy(), "red")


### GKAN

In [None]:
# Inputs for the GKAN evaluation on SARS.
# NOTE(review): presumably the scaler is fit on the first 80% (tr_perc=0.8) — confirm in get_scaler.
scaler_sars = get_scaler(data = real_epid_sars, tr_perc=0.8)
x_i, x_j = sp.symbols('x_i x_j')
# Kept as a DataFrame: build_symb_model_gkan selects a column by country name.
inf_coeff_sars = pd.read_csv("./inferred_coeffs/gkan/inferred_coeffs_sars_ts.csv")

# Trajectory evaluation; use_euler is left at its default (False).
y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(
    data = real_epid_sars,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_sars,
    build_symb_model=build_symb_model_gkan,
    scaler=scaler_sars,
    tr_perc=0.9,
    device='cpu'
)

print("Mae Eul\n")

# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_sars.t_sampled
epsilon = t[0][1] - t[0][0]

# One-step evaluation: accumulates per-snapshot model outputs starting from the first snapshot.
y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(
    data = data_real_epid_orig_sars,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_gkan,
    inferred_coeffs=inf_coeff_sars,
    scaler=scaler_sars,
    step_size=epsilon,
    tr_perc=0.9,
    device='cpu'
)

# Stored before the GKAN SW cell rebinds y_pred_gkan / y_pred_gkan_jrn.
all_res_sars_traj["GKAN + GP"] = (y_pred_gkan.copy(), "#5fa2d1")
all_res_sars_eul["GKAN + GP"] = (y_pred_gkan_jrn.copy(), "#5fa2d1")

### GKAN SW

In [None]:
# Inputs for the GKAN SW evaluation on SARS.
# NOTE(review): presumably the scaler is fit on the first 80% (tr_perc=0.8) — confirm in get_scaler.
scaler_sars = get_scaler(data = real_epid_sars, tr_perc=0.8)
# build_model_sw reads these module-level symbols.
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_sars = pd.read_csv("./inferred_coeffs/gkan_sw/inferred_coeffs_sars_sw.csv")

# Trajectory evaluation; rebinds the y_*_gkan names used by the GKAN + GP cell.
y_true_gkan, y_pred_gkan, y_true_val_gkan, y_pred_val_gkan = eval_real_epid_int(
    data = real_epid_sars,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_sars,
    build_symb_model=build_model_sw,
    scaler=scaler_sars,
    tr_perc=0.9,
    device='cpu'
)

print("Mae Eul\n")

# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_sars.t_sampled
epsilon = t[0][1] - t[0][0]

# One-step evaluation: accumulates per-snapshot model outputs starting from the first snapshot.
y_true_gkan_jrn, y_pred_gkan_jrn, y_true_val_gkan_jrn, y_pred_val_gkan_jrn = eval_real_epid_journal(
    data = data_real_epid_orig_sars,
    countries_dict=countries_dict,
    build_symb_model=build_model_sw,
    inferred_coeffs=inf_coeff_sars,
    scaler=scaler_sars,
    step_size=epsilon,
    tr_perc=0.9,
    device='cpu'
)

all_res_sars_traj["GKAN + SW"] = (y_pred_gkan.copy(), "#a2c8e3")
all_res_sars_eul["GKAN + SW"] = (y_pred_gkan_jrn.copy(), "#a2c8e3")

### MPNN

In [None]:
# Inputs for the MPNN evaluation on SARS.
# NOTE(review): presumably the scaler is fit on the first 80% (tr_perc=0.8) — confirm in get_scaler.
scaler_sars = get_scaler(data = real_epid_sars, tr_perc=0.8)
# build_symb_model_mpnn reads these module-level symbols.
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_sars = pd.read_csv("./inferred_coeffs/mpnn/inferred_coeffs_sars_ts.csv")

# Trajectory evaluation; use_euler is left at its default (False).
y_true_mpnn, y_pred_mpnn, y_true_val_mpnn, y_pred_val_mpnn = eval_real_epid_int(
    data = real_epid_sars,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_sars,
    build_symb_model=build_symb_model_mpnn,
    scaler=scaler_sars,
    tr_perc=0.9,
    device='cpu'
)

print("Mae Eul\n")

# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_sars.t_sampled
epsilon = t[0][1] - t[0][0]
    
# One-step evaluation: accumulates per-snapshot model outputs starting from the first snapshot.
y_true_mpnn_jrn, y_pred_mpnn_jrn, y_true_val_mpnn_jrn, y_pred_val_mpnn_jrn = eval_real_epid_journal(
    data = data_real_epid_orig_sars,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_mpnn,
    inferred_coeffs=inf_coeff_sars,
    scaler=scaler_sars,
    step_size=epsilon,
    tr_perc=0.9,
    device='cpu'
)

    
all_res_sars_traj["MPNN + GP"] = (y_pred_mpnn.copy(), "#fcb97d")
all_res_sars_eul["MPNN + GP"] = (y_pred_mpnn_jrn.copy(), "#fcb97d")

### LLC

In [None]:
# Inputs for the LLC evaluation on SARS.
# NOTE(review): presumably the scaler is fit on the first 80% (tr_perc=0.8) — confirm in get_scaler.
scaler_sars = get_scaler(data = real_epid_sars, tr_perc=0.8)
# build_symb_model_llc reads these module-level symbols.
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_sars = pd.read_csv("./inferred_coeffs/llc/inferred_coeffs_sars_new.csv")

# Trajectory evaluation; use_euler is left at its default (False).
y_true_llc, y_pred_llc, y_true_val_llc, y_pred_val_llc = eval_real_epid_int(
    data = real_epid_sars,
    countries_dict=countries_dict,
    inferred_coeffs=inf_coeff_sars,
    build_symb_model=build_symb_model_llc,
    scaler=scaler_sars,
    tr_perc=0.9,
    device='cpu'
)


print("\nMae Eul\n")

# Step between the first two sampled time points, used as the accumulation step size.
t = real_epid_sars.t_sampled
epsilon = t[0][1] - t[0][0]
    
# One-step evaluation: accumulates per-snapshot model outputs starting from the first snapshot.
y_true_llc_jrn, y_pred_llc_jrn, y_true_val_llc_jrn, y_pred_val_llc_jrn = eval_real_epid_journal(
    data = data_real_epid_orig_sars,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_llc,
    inferred_coeffs=inf_coeff_sars,
    scaler=scaler_sars,
    step_size=epsilon,
    tr_perc=0.9,
    device='cpu'
)

all_res_sars_traj["LLC + GP"] = (y_pred_llc.copy(), "#34eb6e")
all_res_sars_eul["LLC + GP"] = (y_pred_llc_jrn.copy(), "#34eb6e")