In [1]:
from utils.utils import *

Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython


## Epidemics

### GKAN

In [7]:
# Saved GKAN checkpoint that is distilled into a symbolic expression below.
# NOTE(review): "2" is assumed to identify a specific saved run — confirm.
model_path = (
    './saved_models_optuna/model-epidemics-gkan'
    '/epidemics-new-data/2/gkan'
)

In [8]:
def pysr_model():
    """Build a new, reproducibly configured PySR regressor.

    Passed to the fitting helpers as a zero-argument factory rather than an
    instance; presumably they call it to get a fresh regressor per fit —
    TODO confirm in `fit_black_box_from_kan` / `fit_mpnn`.
    """
    # A fixed random_state, deterministic=True and serial parallelism are what
    # PySR needs for repeatable searches (assuming get_pysr_model forwards
    # these keywords to PySRRegressor).
    return get_pysr_model(
        model_selection="score",
        random_state=0,
        deterministic=True,
        parallelism='serial',
        n_iterations=150,
    )


# Fit a closed-form expression to the trained GKAN, treating it as a black box.
# NOTE(review): theta=-np.inf presumably disables any pruning threshold —
# confirm against fit_black_box_from_kan.
symb_model_black_box = fit_black_box_from_kan(
    model_path=model_path,
    n_g_hidden_layers=2,
    n_h_hidden_layers=2,
    theta=-np.inf,
    pysr_model=pysr_model,
    sample_size=9000,
    message_passing=False,
)



In [9]:
# Display the symbolic expression fitted to the GKAN model in the previous cell.
symb_model_black_box

\sum_{j}( 0.500805746365327*x_j*(0.9996049 - x_i)) - 0.49989313*x_i

### MLP-based baseline

In [11]:
# Saved MPNN (MLP-based) baseline checkpoint to distil for comparison with GKAN.
# NOTE(review): "1" is assumed to identify a specific saved run — confirm.
model_path_mpnn = (
    './saved_models_optuna/model-epidemics-mpnn'
    '/epidemics-new-data/1'
)

In [12]:
# Distil the MPNN baseline with the same `pysr_model` factory defined in the
# GKAN section above, instead of a copy-pasted duplicate, so both models are
# compared under identical symbolic-regression settings.
mpnn_symb = fit_mpnn(
    model_path=model_path_mpnn,
    pysr_model=pysr_model,
    sample_size=9000,
    message_passing=False,
)




In [13]:
# Display the symbolic expression fitted to the MPNN baseline in the previous cell.
mpnn_symb

\sum_{j}( x_i*(0.49942818 - 0.49865368*x_j)) - 0.5019439*x_i

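## Test Experiments

Quick smoke tests: a two-trial Optuna study for the GKAN epidemics model, followed by a first look at the METR-LA traffic dataset.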

In [1]:
import optuna
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend

# Optuna studies are kept in a journal log file; the relative path resolves
# against the kernel's current working directory.
JOURNAL_LOG_PATH = "optuna_journal_storage.log"
storage = JournalStorage(JournalFileBackend(JOURNAL_LOG_PATH))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Start the smoke test from a clean slate. The name must match the study the
# ExperimentsGKAN cell creates ("model-epidemics-gkan-" + its study_name).
# optuna.delete_study raises KeyError when the study does not exist (e.g. a
# fresh journal file), which would abort Restart & Run All, so that case is
# reported instead of raised.
TEST_STUDY_NAME = "model-epidemics-gkan-test-new-dataset"
try:
    optuna.delete_study(study_name=TEST_STUDY_NAME, storage=storage)
except KeyError:
    print(f"No study named {TEST_STUDY_NAME!r} in storage; nothing to delete.")

In [4]:
%time
from utils.utils import load_config
from experiments.experiments_gkan import ExperimentsGKAN


config_path = './configs/config_epidemics.yml'
config = load_config(config_path)


exp = ExperimentsGKAN(
    config=config,
    n_trials=2,
    study_name='test-new-dataset',
    process_id=0,
    store_to_sqlite = False
)

exp.run()


CPU times: user 1 μs, sys: 0 ns, total: 1 μs
Wall time: 5.25 μs
Builing the dataset...


Processing...
Done!
[I 2025-04-04 16:52:29,119] A new study created in Journal with name: model-epidemics-gkan-test-new-dataset
[I 2025-04-04 16:52:54,623] Trial 0 finished with value: 1.7817941625253297e-05 and parameters: {'lr': 0.004328450221293881, 'lamb': 0.0, 'batch_size': 32, 'use_orig_reg': False, 'lamb_g_net': 0.0002481040974867811, 'lamb_h_net': 4.2079886696066345e-06, 'grid_size_g_net': 5, 'spline_order_g_net': 1, 'range_limit_g_net': 3, 'mu_1_g_net': 0.8, 'mu_2_g_net': 0.9, 'hidden_dim_g_net': 1, 'grid_size_h_net': 10, 'spline_order_h_net': 3, 'range_limit_h_net': 1, 'mu_1_h_net': 0.6, 'mu_2_h_net': 0.6, 'hidden_dim_h_net': 1}. Best is trial 0 with value: 1.7817941625253297e-05.
[I 2025-04-04 16:53:16,575] Trial 1 finished with value: 2.9958915547467768e-05 and parameters: {'lr': 0.0032877474139911193, 'lamb': 0.0, 'batch_size': 32, 'use_orig_reg': False, 'lamb_g_net': 1.461896279370496e-05, 'lamb_h_net': 0.0002801635158716264, 'grid_size_g_net': 5, 'spline_order_g_net': 1,

In [1]:
from tsl.datasets import MetrLA, PemsBay

In [2]:
# Load METR-LA from DATA_ROOT (the first run downloads it there).
# impute_zeros=True makes tsl forward-fill zero readings.
DATA_ROOT = './data/'
dataset = MetrLA(DATA_ROOT, impute_zeros=True)

download: 13.1MB [00:01, 6.73MB/s]                            
  date_range = pd.date_range(df.index[0], df.index[-1], freq='5T')
  df = df.replace(to_replace=0., method='ffill')


In [4]:
# Sparse sensor-graph connectivity as an edge list plus per-edge weights.
# NOTE(review): threshold=0.1 presumably drops weak edges and normalize_axis=1
# presumably row-normalises the weights — confirm against tsl docs.
edge_index, edge_attr = dataset.get_connectivity(
    threshold=0.1,
    include_self=False,
    normalize_axis=1,
    layout="edge_index",
)

# Invariants of the edge_index layout: a 2 x num_edges index array and
# exactly one weight per edge.
assert edge_index.shape[0] == 2, edge_index.shape
assert edge_index.shape[1] == edge_attr.shape[0], (edge_index.shape, edge_attr.shape)

In [7]:
# Expected (2, num_edges) for the "edge_index" layout.
edge_index.shape

(2, 1515)

In [8]:
# One weight per edge: length should equal edge_index.shape[1].
edge_attr.shape

(1515,)

In [9]:
# Sensor readings as a table: one row per timestamp, columns indexed by
# (node, channel) judging from the displayed header.
df = dataset.dataframe()

In [10]:
# Peek at the first rows (the index advances in 5-minute steps).
df.head()

nodes,773869,767541,767542,717447,717446,717445,773062,767620,737529,717816,...,772167,769372,774204,769806,717590,717592,717595,772168,718141,769373
channels,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2012-03-01 00:00:00,64.375,67.625,67.125,61.5,66.875,68.75,65.125,67.125,59.625,62.75,...,45.625,65.5,64.5,66.428574,66.875,59.375,69.0,59.25,69.0,61.875
2012-03-01 00:05:00,62.666668,68.555557,65.444443,62.444443,64.444443,68.111115,65.0,65.0,57.444443,63.333332,...,50.666668,69.875,66.666664,58.555557,62.0,61.111111,64.444443,55.888889,68.444443,62.875
2012-03-01 00:10:00,64.0,63.75,60.0,59.0,66.5,66.25,64.5,64.25,63.875,65.375,...,44.125,69.0,56.5,59.25,68.125,62.5,65.625,61.375,69.85714,62.0
2012-03-01 00:15:00,64.0,63.75,60.0,59.0,66.5,66.25,64.5,64.25,63.875,65.375,...,44.125,69.0,56.5,59.25,68.125,62.5,65.625,61.375,69.85714,62.0
2012-03-01 00:20:00,64.0,63.75,60.0,59.0,66.5,66.25,64.5,64.25,63.875,65.375,...,44.125,69.0,56.5,59.25,68.125,62.5,65.625,61.375,69.85714,62.0
