## Utils

In [1]:
import os

# Deterministic cuBLAS kernels (required once torch.use_deterministic_algorithms(True)
# is enabled) need this workspace setting in the environment before CUDA work starts;
# ":4096:8" is one of the two values PyTorch documents for this purpose.
os.environ.update(CUBLAS_WORKSPACE_CONFIG=":4096:8")

In [2]:
import torch
import random
import numpy as np

def set_pytorch_seed(seed=42):
    """Seed every RNG the notebook touches and force deterministic kernels.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module, disables cuDNN autotuning, and enables
    ``torch.use_deterministic_algorithms``.

    Args:
        seed: integer seed shared by all generators.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # cuDNN: pin deterministic kernels and skip benchmark-driven kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    for seed_fn in (np.random.seed, random.seed):
        seed_fn(seed)

    torch.use_deterministic_algorithms(True)


set_pytorch_seed(0)

In [99]:
from tsl.data.preprocessing.scalers import MinMaxScaler
from models.utils.MPNN import MPNN
from models.baseline.MPNN_ODE import MPNN_ODE
import torch
import sympytorch
from torch_geometric.data import Data
import sympy as sp
import copy
from sklearn.metrics import mean_absolute_error, mean_squared_error
import numpy as np
from post_processing import get_model, make_callable, plot_predictions
from torch.optim import LBFGS


def get_scaler(data, tr_perc = 0.8, scale_range = (-1, 1)):
    """Fit a ``MinMaxScaler`` on the training prefix of the first initial condition.

    Args:
        data: dataset exposing a ``raw_data_sampled`` tensor; it is indexed as a
            4-D array here (presumably (IC, T, N, 1) — confirm against the dataset).
        tr_perc: fraction of the time axis (axis 1) used to fit the scaler.
        scale_range: target range forwarded as ``MinMaxScaler(out_range=...)``.

    Returns:
        The fitted scaler. It is fit on a flattened 1-D array, so all nodes share
        a single min/max.
    """
    series = data.raw_data_sampled.detach().cpu().numpy()
    n_train_steps = int(tr_perc * series.shape[1])
    train_values = series[0, :n_train_steps, :, :].flatten()

    fitted_scaler = MinMaxScaler(out_range=scale_range)
    fitted_scaler.fit(train_values)
    return fitted_scaler


  
def eval_real_epid_int(data, countries_dict, build_symb_model, scaler=None, use_euler=False):
    """Roll out one symbolic model per country and collect that country's node.

    For every ``(country_name, node_idx)`` in ``countries_dict``, a model is built
    with ``build_symb_model(country_name)`` and run on the first sample of
    ``data``. Only column ``node_idx`` of its prediction is kept.

    Args:
        data: indexable dataset; ``data[0]`` must expose ``x``, ``y`` and
            ``t_span`` tensors. It is never mutated: each country works on a
            shallow copy of the sample.
        countries_dict: mapping country name -> node index.
        build_symb_model: callable ``country_name -> model``. The model is called
            on the (possibly rescaled) sample and must return a tensor indexable
            as ``pred[:, node_idx, :]``.
        scaler: optional scaler. ``transform`` is applied once to the input
            ``x`` and ``inverse_transform`` to each prediction.
        use_euler: if True, switch every model to ``"euler"`` integration on a
            unit-step time grid of ``len(y) + 1`` points.

    Returns:
        ``(y_true, y_pred)`` numpy arrays of identical shape. Rows of countries
        whose rollout raised ``AssertionError`` stay zero in ``y_pred``.
    """
    sample = data[0]
    y_true = sample.y.detach().cpu().numpy()
    y_pred = np.zeros_like(y_true)

    # Prepare the model input exactly once. Re-scaling inside the loop would
    # compound whenever the dataset hands back the same object for data[0].
    x_in = scaler.transform(sample.x) if scaler is not None else sample.x
    euler_t_span = None
    if use_euler:
        euler_t_span = torch.arange(y_true.shape[0] + 1, device=x_in.device, dtype=sample.t_span.dtype)

    for country_name, node_idx in countries_dict.items():
        symb_model = build_symb_model(country_name)

        # Shallow copy: overriding attributes leaves the dataset's sample intact.
        data_0 = copy.copy(sample)
        data_0.x = x_in
        if use_euler:
            symb_model.integration_method = "euler"
            data_0.t_span = euler_t_span

        try:
            pred = symb_model(data_0).detach().cpu().numpy()
        except AssertionError:
            # Skip this country; its rows in y_pred remain zero.
            print(f"Failed: {country_name}")
            continue

        if scaler is not None:
            pred = scaler.inverse_transform(pred)

        y_pred[:, node_idx, :] = pred[:, node_idx, :]

    return y_true, y_pred


import pandas as pd
from torch.optim import Adam
from torch_geometric.loader import DataLoader
from scipy.optimize import minimize
from typing import Dict
import optuna

    
def fit_param_per_country_gd(data_train, data_valid, countries_dict: Dict[str, int], model_path, build_symb_model, epochs=100, loss_fn=torch.nn.L1Loss(), device='cuda:0', lr=1e-3,
                              batch_size=64, patience=10, log=10):
    """Fit one symbolic model per country with L-BFGS and save its coefficients.

    For each ``(country_name, node)`` in ``countries_dict``, a model from
    ``build_symb_model()`` is trained on ``data_train``. The loss is restricted to
    that node's outputs. Training is early-stopped on ``data_valid`` and the model
    is restored to its best validation state.

    The parameters of ``model.conv.model.h_net``, followed by those of
    ``model.conv.model.g_net`` (each in ``parameters()`` order), are flattened
    into one coefficient vector per country. Each vector is written as a column
    of ``{model_path}/inferred_coeff.csv``.

    Args:
        data_train: list/dataset of graph samples used for optimisation.
        data_valid: list/dataset of graph samples used for early stopping.
        countries_dict: mapping country name -> node index.
        model_path: directory the CSV is written into.
        build_symb_model: zero-argument factory returning a fresh model.
        epochs: maximum number of passes over ``data_train`` per country.
        loss_fn: loss applied to ``(y_pred, y_true)`` of the selected node.
        device: device models and batches are moved to.
        lr: L-BFGS learning rate.
        batch_size: training batch size (validation uses a single batch).
        patience: epochs without improvement tolerated before stopping.
        log: print progress every ``log`` epochs (0 disables logging).

    Returns:
        None. Results are written to disk.
    """

    def get_predictions(batch_data, model, node_idx):
        # Regroup the batched output/targets as (num_graphs, N, features) so a
        # single node can be selected across every graph in the batch.
        y_pred = torch.reshape(model(batch_data), (batch_data.num_graphs, N, -1))
        y_true = torch.reshape(batch_data.y, (batch_data.num_graphs, N, -1))
        return y_true[:, node_idx, :], y_pred[:, node_idx, :]

    def eval_model(model, valid_loader, node_idx, loss_fn):
        # Validation loss on `node_idx`, over all validation batches concatenated.
        model.eval()
        y_pred = []
        y_true = []
        with torch.no_grad():
            for batch_valid in valid_loader:
                # Batches must sit on the model's device, exactly as in training.
                batch_valid = batch_valid.to(device)
                y_true_b, y_pred_b = get_predictions(
                    batch_data=batch_valid,
                    model=model,
                    node_idx=node_idx
                )
                y_true.append(y_true_b)
                y_pred.append(y_pred_b)

            valid_loss = loss_fn(torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0))

        return valid_loss.item()

    def get_node_model(node):
        # One independent model + L-BFGS optimizer per node, created on first use,
        # so nodes absent from countries_dict never allocate a model.
        if node not in node_models:
            model = build_symb_model().to(device)
            for param in model.parameters():
                param.requires_grad_(True)
            node_models[node] = model
            optimizers[node] = LBFGS(
                model.parameters(), lr=lr, line_search_fn="strong_wolfe",
                tolerance_grad=1e-32, tolerance_change=1e-32
            )
        return node_models[node], optimizers[node]

    N, _ = data_train[0].x.shape
    node_models = {}
    optimizers = {}

    # The loaders do not depend on the country; build them once.
    train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(data_valid, batch_size=len(data_valid), shuffle=False)

    coeffs_by_country = {}

    for country_name, node in countries_dict.items():
        best_val_loss = float('inf')
        best_epoch = 0
        best_model_state = None

        model, optimizer = get_node_model(node)

        print(f"\nProcessing country {country_name}")
        for epoch in range(epochs):
            train_node_loss = 0.
            count = 0
            model.train()

            for batch_data in train_loader:
                batch_data = batch_data.to(device)
                count += 1

                # L-BFGS re-evaluates the objective several times per step.
                def closure():
                    optimizer.zero_grad()
                    y_true, y_pred = get_predictions(batch_data, model, node)
                    loss = loss_fn(y_pred, y_true)
                    loss.backward()
                    return loss

                loss = optimizer.step(closure)
                train_node_loss += loss.item()

            val_loss = eval_model(
                model=model,
                valid_loader=valid_loader,
                node_idx=node,
                loss_fn=loss_fn
            )

            if val_loss < best_val_loss:
                best_epoch = epoch
                best_val_loss = val_loss
                best_model_state = copy.deepcopy(model.state_dict())
            elif epoch - best_epoch > patience:
                print(f"Early stopping at epoch {epoch}")
                break

            if log and epoch % log == 0:
                print(f"Epoch {epoch}, train Loss: {train_node_loss / max(count, 1):.4f}, valid loss: {val_loss:.4f}")

        # best_model_state stays None when epochs == 0 or every validation loss
        # was NaN (NaN < inf is False); keep the current weights in that case.
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        else:
            print(f"No best checkpoint for {country_name}; keeping final weights")

        h_net = model.conv.model.h_net
        g_net = model.conv.model.g_net
        self_int_coeffs = torch.cat([p.detach().cpu().flatten() for p in h_net.parameters()]).numpy()
        pairwise_int_coeffs = torch.cat([p.detach().cpu().flatten() for p in g_net.parameters()]).numpy()
        coeffs = np.concatenate([self_int_coeffs, pairwise_int_coeffs])
        coeffs_by_country[country_name] = coeffs
        print(f"Inferred coeffs for {country_name}: {coeffs}")

    pd.DataFrame(coeffs_by_country).to_csv(f"{model_path}/inferred_coeff.csv")
        


# def fit_param_per_country(data, countries_dict, model_path, build_symb_model,init_params,  tr_perc = 0.8, scaler=None,
#                           use_euler=False, min_method = "BFGS"):
#     tr_len = data.raw_data_sampled.shape[1]
#     def optim_fun(params):
#         g_symb, h_symb = build_symb_model(params)
#         symb_model = get_model(
#             g = g_symb,
#             h = h_symb,
#             message_passing=False,
#             include_time=False,
#             integration_method='rk4'
#         )
        
#         data_0 = data[0]
#         if scaler is not None:
#             tmp = scaler.transform(data[0].x)
#             data_0 = data[0]
#             data_0.x = tmp
            
#         if use_euler:
#             symb_model.integration_method = "euler"
#             data_0.t_span = torch.arange(tr_len, device=data_0.x.device, dtype=data_0.t_span.dtype)
            
#         out = symb_model(data_0)
        
#         y_pred = out[:int(tr_len*tr_perc), node_idx, 0].detach().cpu().numpy()
#         y_true = data[0].y[:int(tr_len*tr_perc), node_idx, 0].detach().cpu().numpy()
        
#         if scaler is not None:
#             y_pred = scaler.inverse_transform(y_pred)
#         loss = mean_squared_error(y_true.flatten(), y_pred.flatten())
#         return loss
        
#     results_df = pd.DataFrame()
    
#     # init_params=[-0.0039747115, 2.4683, 2.464987]
#     # Optimization
#     for country_name, node_idx in countries_dict.items():
#         print(f"Processing {country_name}")
#         result = minimize(optim_fun, init_params, method=min_method)
#         results_df[country_name] = result.x
#         print(result.x)
    
#     results_df.to_csv(f"{model_path}/inferred_coeff.csv")     


# def fit_coeffs_per_country(data, countries_dict, model_path, symb_model:MPNN_ODE):
    
#     assert not symb_model.conv.model.message_passing, "This function works only for models of type H(x_i) + sum(G(x_i, x_j))"
    
#     T = len(data)
#     N = data[0].x.shape[0]
    
#     collate_fn = lambda samples_list: samples_list
#     train_loader = DataLoader(data, batch_size=T, shuffle=False, collate_fn=collate_fn)
#     all_data = next(iter(train_loader))
    
#     Y = torch.stack(
#         [d.y for d in data],
#         dim=0
#     )   # shape (T, N, 1)
    
#     # forward pass
#     _ = symb_model(all_data)
#     self_int_out = symb_model.conv.model.upduate_out
#     pairwise_int_out = symb_model.conv.model.message_out
    
#     self_int_out = torch.reshape(self_int_out, (T, N, -1)) 
#     pairwise_int_out = torch.reshape(pairwise_int_out, (T, N, -1))
    
#     def optim_fun(params):
#         a, b, = params
        
#         X_in = torch.stack([
#             self_int_out[:, node_idx, 0],
#             pairwise_int_out[:, node_idx, 0]
#         ], dim = 1)
        
#         y_true = Y[:, node_idx, 0].detach().cpu().numpy()
#         y_pred = (a*X_in[:, 0] + b*X_in[:, 1]).detach().cpu().numpy()
#         loss = mean_squared_error(y_true, y_pred)
#         return loss
        
    
#     results_df = pd.DataFrame()
#     for country_name, node_idx in countries_dict.items():
#         print(f"Processing {country_name}")
#         result = minimize(optim_fun, [1., 1.], method='L-BFGS-B')        
#         params = result.x  
#         print(params)
#         results_df[country_name] = params
        
#     results_df.to_csv(f"{model_path}/inferred_coeff.csv")
        
    
    
# def euler_int(symb_model, data_0, T):
#     pred = []
#     x = data_0.x[-1]
#     edge_index, edge_attr = data_0.edge_index, data_0.edge_attr
#     symb_model.conv.set_graph_attrs(edge_index, edge_attr)
#     for _ in range(T):
#         x = data_0.x[-1] + x + symb_model.conv(t=torch.tensor([], device = x.device), x=x)
#         pred.append(x)

#     return torch.stack(pred, dim=0)
    
    

In [5]:
from datasets.RealEpidemics import RealEpidemics

# Unscaled (scale=False) dataset used to evaluate the symbolic models:
# one step of history, a 44-step horizon, and raw states rather than derivatives as targets.
real_epid_data = RealEpidemics(
    root='./data_real_epid_covid_int',
    name='RealEpid',
    history=1,
    horizon=44,
    predict_deriv=False,
    scale=False,
)

In [6]:
import json

# Mapping country name -> node index in the graph. JSON is UTF-8 by spec, so
# read it explicitly as UTF-8 instead of relying on the platform default encoding.
with open('./data_real_epid_covid_int/RealEpid/countries_dict.json', 'r', encoding='utf-8') as f:
    countries_dict = json.load(f)

## TSS

In [21]:
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_covid = pd.read_csv("./saved_models_optuna/tss/real_epid_covid/Inferred_coefficients_covid.csv").values


def build_symb_model_2(country, coeffs=inf_coeff_covid):
    """Build the TSS symbolic model for ``country`` with Euler integration.

    Uses ``g(x_i, x_j) = coeffs[1, k] / (1 + exp(-(x_j - x_i)))`` and
    ``h(x_i) = coeffs[0, k] * x_i``, where ``k = countries_dict[country]``
    selects the country's column of the coefficient array.

    ``coeffs`` is bound when this cell runs. The builder therefore keeps working
    even after the GKAN section rebinds the global ``inf_coeff_covid`` to a
    DataFrame.
    """
    x_i, x_j = sp.symbols('x_i x_j')
    country_idx = countries_dict[country]

    g_symb = make_callable(coeffs[1, country_idx] / (1 + sp.exp(- (x_j - x_i))))
    h_symb = make_callable(coeffs[0, country_idx] * x_i)

    return get_model(
        g = g_symb,
        h = h_symb,
        message_passing=False,
        include_time=False,
        integration_method='euler'
    )


y_true, y_pred = eval_real_epid_int(
    data = real_epid_data,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model_2,
    use_euler=True
)

In [22]:
# Mean absolute error over every country, time step and feature at once.
print("MAE: {}".format(mean_absolute_error(y_true.ravel(), y_pred.ravel())))

MAE: 12133.45703125


In [23]:
# One figure per country; show=False with a save_path (presumably writes the
# figure into the TSS figures folder rather than displaying it inline).
for country, node_idx in countries_dict.items():
    plot_predictions(
        y_true=y_true,
        y_pred=y_pred,
        node_index=node_idx,
        title=country,
        show=False,
        save_path="./saved_models_optuna/tss/real_epid_covid/figures",
    )

## GKAN

### GD

In [68]:
class symb_wrapper(torch.nn.Module):
    """Adapt a keyword-argument symbolic function to a column-stacked tensor input.

    ``forward(x)`` passes ``x[:, 0]`` as ``x_i``. For pairwise terms
    (``is_self_interaction=False``) it also passes ``x[:, 1]`` as ``x_j``.
    """

    def __init__(self, symb, is_self_interaction = True):
        super().__init__()
        self.symb = symb
        self.is_self_interaction = is_self_interaction

    def forward(self, x):
        symbol_values = {"x_i": x[:, 0]}
        if not self.is_self_interaction:
            symbol_values["x_j"] = x[:, 1]
        return self.symb(**symbol_values)

In [100]:
model_path = "./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0"


def build_symb_gkan(device='cuda:0'):
    """Build the GKAN-derived symbolic model with trainable coefficients.

    ``g(x_i, x_j) = exp(c * x_j)`` and ``h(x_i) = a * x_i + b`` are wrapped in
    ``sympytorch.SymPyModule``, which turns the float constants into trainable
    torch parameters. The order in which sympytorch lists them is not fixed
    here, so confirm it before reading coefficients back by position.

    Args:
        device: device the model is moved to (defaults to the previous hard-coded
            'cuda:0').

    Returns:
        The model in training mode with ``predict_deriv = True``.
    """
    x_i, x_j = sp.symbols('x_i x_j')

    # Starting values for the per-country fit.
    # TODO: document provenance (presumably an earlier global GKAN fit).
    # Random initialisation alternative:
    #   a, b ~ U(0, 3), c ~ U(-0.005, 0.005)
    a = 2.4682064
    b = 2.4648788
    c = -0.0039747115

    g_symb = symb_wrapper(sympytorch.SymPyModule(expressions=[sp.exp(c*x_j)]), is_self_interaction=False)
    h_symb = symb_wrapper(sympytorch.SymPyModule(expressions=[a * x_i + b]), is_self_interaction=True)

    symb_model = get_model(
        g = g_symb,
        h = h_symb,
        message_passing=False,
        include_time=False,
        integration_method='dopri5',
        eval=False
    )

    symb_model.predict_deriv = True
    symb_model = symb_model.train()

    return symb_model.to(device)


data_train = RealEpidemics(
    root = './data_real_epid_covid_scaled',
    name = 'RealEpid',
    predict_deriv=True,
    scale=True,
    scale_range=(-1, 1),
    train_perc=0.8,
    device='cuda:0'
)

# Chronological 80/20 split of the samples into train and validation sets.
tr_len = len(data_train)
tr_end = int(0.8 * tr_len)
train_set = data_train[:tr_end]
valid_set = data_train[tr_end:]

fit_param_per_country_gd(
    data_train=train_set,
    data_valid=valid_set,
    build_symb_model=build_symb_gkan,
    countries_dict=countries_dict,
    model_path=model_path,
    epochs=5,
    lr=1e-3,
    patience=5,
    log=1,
    batch_size=64
)




Processing country Iceland
Epoch 0, train Loss: 0.0180, valid loss: 0.0737
Epoch 1, train Loss: 0.0174, valid loss: 0.0737
Epoch 2, train Loss: 0.0174, valid loss: 0.0737
Epoch 3, train Loss: 0.0174, valid loss: 0.0737
Epoch 4, train Loss: 0.0174, valid loss: 0.0737
Inferred coeffs for Iceland: [ 2.462699    2.4703915  -0.00393136]

Processing country Canada
Epoch 0, train Loss: 0.1884, valid loss: 0.0940
Epoch 1, train Loss: 0.0294, valid loss: 0.0940
Epoch 2, train Loss: 0.0294, valid loss: 0.0940
Epoch 3, train Loss: 0.0294, valid loss: 0.0940
Epoch 4, train Loss: 0.0294, valid loss: 0.0940
Inferred coeffs for Canada: [2.3665886  2.5652099  0.01628246]

Processing country Algeria
Epoch 0, train Loss: 0.0427, valid loss: 0.0577
Epoch 1, train Loss: 0.0143, valid loss: 0.0577
Epoch 2, train Loss: 0.0143, valid loss: 0.0579
Epoch 3, train Loss: 0.0143, valid loss: 0.0579
Epoch 4, train Loss: 0.0143, valid loss: 0.0579
Inferred coeffs for Algeria: [ 2.4395778   2.4934804  -0.00253889]


### BFGS

In [None]:
# model_path = "./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0"
# scaler = get_scaler(data = real_epid_data, tr_perc=0.8)
# x_i, x_j = sp.symbols('x_i x_j')

# def build_symb_gkan(params):
#     a = params[0]
#     g_symb = make_callable(sp.exp(a * x_j))
#     h_symb = make_callable(2.4682064*x_i + 2.4648788)
    
#     return g_symb, h_symb

# fit_param_per_country(
#     data=real_epid_data,
#     countries_dict=countries_dict,
#     model_path=model_path,
#     build_symb_model=build_symb_gkan,
#     tr_perc=0.8,
#     scaler = scaler,
#     use_euler=False,
#     min_method="L-BFGS-B",
#     init_params=[-0.0039747115]
#     )

Processing Iceland
[-0.00397471]
Processing Canada
[-0.00397471]
Processing Algeria
[-0.00397471]
Processing Burkina Faso
[-0.00397471]
Processing Ghana
[-0.00397471]
Processing Cote d'Ivoire
[-0.00397471]
Processing Niger
[-0.00397471]
Processing Tunisia
[-0.00397471]
Processing Belgium
[-0.00397471]
Processing Germany
[-0.00397471]
Processing Estonia
[-0.00397471]
Processing Ireland
[-0.00397471]
Processing Luxembourg
[-0.00397471]
Processing Norway
[-0.00397471]
Processing Poland
[-0.00397471]
Processing Sweden
[-0.00397471]
Processing South Africa
[-0.00397471]
Processing Cameroon
[-0.00397471]
Processing Mali
[-0.00397471]
Processing Spain
[-0.00397471]
Processing Morocco
[-0.00397471]
Processing Guinea
[-0.00397471]
Processing Somalia
[-0.00397471]
Processing Egypt
[-0.00397471]
Processing Albania
[-0.00397471]
Processing Bulgaria
[-0.00397471]
Processing Cyprus
[-0.00397471]
Processing Croatia
[-0.00397471]
Processing Greece
[-0.00397471]
Processing Hungary
[-0.00397471]
Process

### Plot predictions

In [101]:
scaler = get_scaler(data = real_epid_data, tr_perc=0.8)
x_i, x_j = sp.symbols('x_i x_j')
inf_coeff_covid = pd.read_csv("./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0/inferred_coeff.csv")


def build_symb_model(country, coeffs_table=inf_coeff_covid):
    """Build the fitted GKAN symbolic model for ``country`` (dopri5 integration).

    Reads the first three rows of column ``country`` of the coefficient table
    written by ``fit_param_per_country_gd`` as ``a, b, c``.

    ``coeffs_table`` is bound when this cell runs. The builder therefore does
    not break if the TSS cell is re-run and rebinds the global
    ``inf_coeff_covid`` to a numpy array.
    """
    x_i, x_j = sp.symbols('x_i x_j')

    coeffs = coeffs_table[country]
    a, b, c = coeffs.iloc[0], coeffs.iloc[1], coeffs.iloc[2]
    # NOTE(review): h is built as b * x_i + a, whereas build_symb_gkan starts
    # from a * x_i + b. This is only consistent if sympytorch lists the intercept
    # before the slope in h_net.parameters() — confirm before trusting results.
    g_symb = make_callable(sp.exp(c*x_j))
    h_symb = make_callable(b * x_i + a)

    return get_model(
        g = g_symb,
        h = h_symb,
        message_passing=False,
        include_time=False,
        integration_method='dopri5'
    )


y_true, y_pred = eval_real_epid_int(
    data = real_epid_data,
    countries_dict=countries_dict,
    build_symb_model=build_symb_model,
    scaler=scaler,
    use_euler=False
)

In [102]:
# Mean absolute error (original, unscaled units) over all countries and time steps.
print("MAE: {}".format(mean_absolute_error(y_true.ravel(), y_pred.ravel())))

MAE: 1040.9840087890625


In [103]:
# One figure per country; show=False with a save_path (presumably writes into
# this GKAN run's figures folder instead of displaying inline).
for country, node_idx in countries_dict.items():
    plot_predictions(
        y_true=y_true,
        y_pred=y_pred,
        node_index=node_idx,
        title=country,
        show=False,
        save_path="./saved_models_optuna/model-real-epid-gkan/real_epid_gkan_7/0/figures",
    )