## Utils

In [None]:
from tsl.data.preprocessing.scalers import MinMaxScaler
from models.utils.MPNN import MPNN
from models.baseline.MPNN_ODE import MPNN_ODE
import torch
import sympytorch
from torch_geometric.data import Data
import sympy as sp
import copy
from sklearn.metrics import mean_absolute_error, mean_squared_error
import numpy as np
from post_processing import get_model, make_callable, plot_predictions


def get_scaler(data, tr_perc = 0.8, scale_range = (-1, 1)):
    """Fit a ``MinMaxScaler`` on the leading (training) part of the first trajectory.

    Args:
        data: object exposing ``raw_data_sampled`` (a tensor; the original code
            notes its shape as (IC, T, N, 1)).
        tr_perc: fraction of the time axis (axis 1) used for fitting.
        scale_range: output range handed to ``MinMaxScaler(out_range=...)``.

    Returns:
        The fitted scaler. Only trajectory index 0 is used, flattened to 1-D.
    """
    series = data.raw_data_sampled.detach().cpu().numpy()  # shape (IC, T, N, 1)
    n_steps = series.shape[1]
    n_train = int(tr_perc*n_steps)
    train_values = series[0, :n_train, :, :].flatten()

    fitted = MinMaxScaler(out_range=scale_range)
    fitted.fit(train_values)
    return fitted


  
def eval_real_epid_int(data, countries_dict, build_symb_model, scaler=None, use_euler=False):
    """Roll out one symbolic model per country and gather the per-node predictions.

    For every ``(country_name, node_idx)`` pair in ``countries_dict`` a model is
    built with ``build_symb_model(country_name)`` and called on the first
    snapshot ``data[0]``. Only row ``node_idx`` of that model's prediction is kept.

    Args:
        data: indexable dataset; ``data[0]`` must expose ``x`` and ``y`` tensors,
            plus ``t_span`` when ``use_euler`` is True.
        countries_dict: mapping ``country_name -> node_idx``.
        build_symb_model: callable ``country_name -> model``. The model is called
            on the snapshot and must return a tensor indexable like ``data[0].y``.
        scaler: optional fitted scaler. The input ``x`` is transformed before the
            rollout and every prediction is inverse-transformed afterwards.
        use_euler: if True, switch each model to Euler integration on the time
            grid ``0..T`` with ``T = len(data[0].y)``.

    Returns:
        ``(y_true, y_pred)`` numpy arrays shaped like ``data[0].y``. ``y_pred``
        stays zero for nodes whose model raised ``AssertionError``.
    """
    snapshot = data[0]
    y_true = snapshot.y.detach().cpu().numpy()
    y_pred = np.zeros_like(y_true)

    # Scale the input once, from the untouched snapshot. Re-reading data[0].x
    # and writing the scaled tensor back into data[0] on every iteration would
    # compound the scaling whenever the dataset returns the same object.
    scaled_x = scaler.transform(snapshot.x) if scaler is not None else None

    for country_name, node_idx in countries_dict.items():
        symb_model = build_symb_model(country_name)

        # Shallow copy so the attribute assignments below never leak into the dataset.
        data_0 = copy.copy(snapshot)
        if scaled_x is not None:
            data_0.x = scaled_x

        if use_euler:
            symb_model.integration_method = "euler"
            data_0.t_span = torch.arange(y_true.shape[0] + 1, device=data_0.x.device, dtype=data_0.t_span.dtype)

        try:
            pred = symb_model(data_0).detach().cpu().numpy()
        except AssertionError:
            # Best effort: keep evaluating the remaining countries; this node stays zero.
            print(f"Failed: {country_name}")
            continue

        if scaler is not None:
            pred = scaler.inverse_transform(pred)

        y_pred[:, node_idx, :] = pred[:, node_idx, :]

    return y_true, y_pred


import pandas as pd
# from torch_geometric.loader import DataLoader
# from sklearn.linear_model import LinearRegression 
from scipy.optimize import minimize



def fit_param_per_country(data, countries_dict, model_path, init_params, build_symb_model, tr_perc = 0.8, scaler=None,
                          use_euler=False):
    """Fit the symbolic-model parameters independently for every country.

    For each ``(country_name, node_idx)`` in ``countries_dict``, ``init_params``
    is optimised with L-BFGS-B to minimise the mean squared error between the
    model rollout and ``data[0].y`` for node ``node_idx``. Only the first
    ``int(T * tr_perc)`` time steps are used, with ``T = data.raw_data_sampled.shape[1]``.

    Args:
        data: dataset exposing ``raw_data_sampled`` and ``data[0]`` with ``x``,
            ``y`` (and ``t_span`` when ``use_euler``).
        countries_dict: mapping ``country_name -> node_idx``.
        model_path: directory where ``inferred_coeff.csv`` is written.
        init_params: initial parameter vector passed to ``scipy.optimize.minimize``.
        build_symb_model: callable ``params -> (g, h)``; both are passed to
            ``get_model`` as the pairwise and self terms.
        tr_perc: fraction of the time axis used for fitting.
        scaler: optional fitted scaler. Inputs and targets are compared in scaled space.
        use_euler: if True, integrate with Euler on the grid ``0..T-1``.

    Returns:
        The DataFrame of fitted parameters (one column per country), which is also
        saved to ``{model_path}/inferred_coeff.csv``.
    """
    tr_len = data.raw_data_sampled.shape[1]
    n_train = int(tr_len*tr_perc)

    # Build the model input once. A shallow copy keeps data[0] untouched, so the
    # scaler is applied exactly once instead of on every objective evaluation.
    base_input = copy.copy(data[0])
    if scaler is not None:
        base_input.x = scaler.transform(base_input.x)
    if use_euler:
        base_input.t_span = torch.arange(tr_len, device=base_input.x.device, dtype=base_input.t_span.dtype)

    def optim_fun(params, node_idx, y_true):
        """Training MSE for one node; ``node_idx`` is passed explicitly (no late binding)."""
        g_symb, h_symb = build_symb_model(params)
        symb_model = get_model(
            g = g_symb,
            h = h_symb,
            message_passing=False,
            include_time=False,
            integration_method='dopri5'
        )
        if use_euler:
            symb_model.integration_method = "euler"

        # Fresh copy per call in case the model writes attributes on its input.
        out = symb_model(copy.copy(base_input))
        y_pred = out[:n_train, node_idx, 0].detach().cpu().numpy()
        return mean_squared_error(y_true.flatten(), y_pred.flatten())

    results_df = pd.DataFrame()
    # Optimization
    for country_name, node_idx in countries_dict.items():
        print(f"Processing {country_name}")
        # Targets do not depend on the parameters: compute (and scale) them once per country.
        y_true = base_input.y[:n_train, node_idx, 0].detach().cpu().numpy()
        if scaler is not None:
            y_true = scaler.transform(y_true)
        result = minimize(optim_fun, init_params, args=(node_idx, y_true), method='L-BFGS-B')
        results_df[country_name] = result.x
        print(result.x)

    results_df.to_csv(f"{model_path}/inferred_coeff.csv")
    return results_df


# def fit_coeffs_per_country(data, countries_dict, model_path, symb_model:MPNN_ODE):
    
#     assert not symb_model.conv.model.message_passing, "This function works only for models of type H(x_i) + sum(G(x_i, x_j))"
    
#     T = len(data)
#     N = data[0].x.shape[0]
    
#     collate_fn = lambda samples_list: samples_list
#     train_loader = DataLoader(data, batch_size=T, shuffle=False, collate_fn=collate_fn)
#     all_data = next(iter(train_loader))
    
#     Y = torch.stack(
#         [d.y for d in data],
#         dim=0
#     )   # shape (T, N, 1)
    
#     # forward pass
#     _ = symb_model(all_data)
#     self_int_out = symb_model.conv.model.upduate_out
#     pairwise_int_out = symb_model.conv.model.message_out
    
#     self_int_out = torch.reshape(self_int_out, (T, N, -1)) 
#     pairwise_int_out = torch.reshape(pairwise_int_out, (T, N, -1))
    
#     def optim_fun(params):
#         a, b, = params
        
#         X_in = torch.stack([
#             self_int_out[:, node_idx, 0],
#             pairwise_int_out[:, node_idx, 0]
#         ], dim = 1)
        
#         y_true = Y[:, node_idx, 0].detach().cpu().numpy()
#         y_pred = (a*X_in[:, 0] + b*X_in[:, 1]).detach().cpu().numpy()
#         loss = mean_squared_error(y_true, y_pred)
#         return loss
        
    
#     results_df = pd.DataFrame()
#     for country_name, node_idx in countries_dict.items():
#         print(f"Processing {country_name}")
#         result = minimize(optim_fun, [1., 1.], method='L-BFGS-B')        
#         params = result.x  
#         print(params)
#         results_df[country_name] = params
        
#     results_df.to_csv(f"{model_path}/inferred_coeff.csv")
        



# def eval_real_epid_2(symb_model, data, scaler=None):
#     y_pred = []
#     y_true = []
#     raw_data = data.raw_data_sampled
#     edge_index, edge_attr = data[0].edge_index, data[0].edge_attr
#     symb_model.predict_deriv = True
    
#     for t in range(1, raw_data.shape[1] - 1):
        
#         x = raw_data[0, t, :, :] if scaler is None else scaler.transform(raw_data[0, t, :, :])
#         x0 = raw_data[0, 0, :, :] if scaler is None else scaler.transform(raw_data[0, 0, :, :])
        
#         fake_snap = Data(
#             edge_index=edge_index,
#             edge_attr=edge_attr,
#             x = x
#         )
#         x_next = x0 + symb_model(fake_snap)
        
#         y_true.append(raw_data[0, t+1, :, :])
#         y_pred.append(x_next)
    
#     y_true = torch.stack(y_true, dim=0).detach().cpu().numpy()
#     y_pred = torch.stack(y_pred, dim=0).detach().cpu().numpy()
    
#     if scaler is not None:
#         y_pred = scaler.inverse_transform(y_pred)
    
#     return y_true, y_pred
    
    
# def euler_int(symb_model, data_0, T):
# pred = []
# x = data_0.x[-1]
# edge_index, edge_attr = data_0.edge_index, data_0.edge_attr
# symb_model.conv.set_graph_attrs(edge_index, edge_attr)
# for _ in range(T):
#     x = x + symb_model.conv(t=torch.tensor([], device = x.device), x=x)
#     pred.append(x)

# return torch.stack(pred, dim=0)
    
    

In [None]:
from datasets.RealEpidemics import RealEpidemics

# Load the RealEpid dataset stored under ./data_real_epid_covid_int, without
# dataset-side scaling (scale=False) and with predict_deriv disabled.
real_epid_data = RealEpidemics(
    name='RealEpid',
    root='./data_real_epid_covid_int',
    history=1,
    horizon=44,
    predict_deriv=False,
    scale=False,
)

In [None]:
import json

# Mapping country name -> node index, consumed by the evaluation helpers above.
with open('./data_real_epid_covid_int/RealEpid/countries_dict.json') as f:
    countries_dict = json.loads(f.read())

## TSS

In [None]:
# Symbols used by the symbolic expressions of the TSS models.
x_i = sp.Symbol('x_i')
x_j = sp.Symbol('x_j')
# Coefficient matrix, indexed as [row, country_idx] by build_symb_model_2.
inf_coeff_covid = pd.read_csv(
    "./saved_models_optuna/tss/real_epid_covid/Inferred_coefficients_covid.csv"
).values


def build_symb_model_2(country):
    """Build the Euler-integrated symbolic model for a single country.

    The country's column of ``inf_coeff_covid`` (column index taken from
    ``countries_dict``) supplies two coefficients:
      * row 1 -> g = c1 / (1 + exp(-(x_j - x_i)))
      * row 0 -> h = c0 * x_i
    Both are turned into callables and passed to ``get_model`` without message
    passing or time input.
    """
    x_i, x_j = sp.symbols('x_i x_j')
    coeffs = inf_coeff_covid[:, countries_dict[country]]

    pair_expr = coeffs[1] / (1 + sp.exp(- (x_j - x_i)))
    self_expr = coeffs[0] * x_i

    return get_model(
        g=make_callable(pair_expr),
        h=make_callable(self_expr),
        message_passing=False,
        include_time=False,
        integration_method='euler',
    )

# Evaluate the per-country TSS models with Euler integration and no scaler.
y_true, y_pred = eval_real_epid_int(
    real_epid_data,
    countries_dict,
    build_symb_model=build_symb_model_2,
    use_euler=True,
)