In [1]:

import os
import random

import networkx as nx
import numpy as np
import torch
from torch_geometric.utils import from_networkx

from utils.utils import create_datasets, load_config, pre_processing

In [2]:
# Load the Kuramoto experiment settings used to simulate the training trajectories.
config = load_config(config_path='./configs/config_kuramoto.yml')

# Seed every RNG that dataset simulation, model initialisation and activation
# sampling might draw from, so that "Restart & Run All" is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Interaction graph: a 10-node complete graph. The commented line is an
# alternative 7x10 grid topology.
G = nx.complete_graph(10)
# G = nx.grid_2d_graph(7, 10)
train_data, t_train, valid_data, t_valid, test_data, t_test = create_datasets(config=config, graph=G)

In [3]:
# Run the same pre-processing over each split (train, then valid, then test).
# NOTE: the split names are rebound to the processed data, so re-running this
# cell without re-running the dataset cell would pre-process the data twice.
train_data, valid_data, test_data = (
    pre_processing(split) for split in (train_data, valid_data, test_data)
)

In [4]:
from models.KanGDyn import KanGDyn
from models.NetWrapper import NetWrapper

In [5]:
# PyG connectivity of the same graph the data was simulated on.
edge_index = from_networkx(G).edge_index

# Architecture / KAN settings consumed by KanGDyn via NetWrapper.
# NOTE(review): layer lists presumably give the widths of the h_net / g_net
# sub-networks, and grid_range / grid_size the spline grid; confirm in models.KanGDyn.
model_config = {
    'h_hidden_layers': [2, 1],
    'g_hidden_layers': [2, 1, 1],
    'grid_range': [-3, 3],
    'grid_size': 7,
    'model_path': './saved_models/higher-order-kuramoto',
    'device': 'cpu',
    'store_acts': True,
}

model = NetWrapper(KanGDyn, model_config, edge_index, update_grid=False)

# Training hyperparameters forwarded to `fit` in the training cell.
# mu_1 / mu_2 / lmbd are regularisation weights interpreted by train_and_eval.fit.
criterion = torch.nn.MSELoss()
lr = 0.001
mu_1 = 1.
mu_2 = 1.
lmbd = 0.0001
epochs = 300
# No optimizer is built here: `fit` receives `lr`, not an optimizer object, so a
# local Adam instance would never be used.

In [6]:
from train_and_eval import fit

In [7]:
# Train the model. Each data split is passed next to its time grid; with
# save_updates=True, `fit` is expected to checkpoint under model_path, which
# the following cells reload.
_ = fit(
    model,
    train_data, t_train,
    valid_data, t_valid,
    test_data, t_test,
    criterion=criterion,
    epochs=epochs,
    lr=lr,
    patience=100,
    mu_1=mu_1,
    mu_2=mu_2,
    lmbd=lmbd,
    use_orig_reg=True,
    save_updates=True,
)

Epoch: 0 	 Training loss: 0.07325 	 Val Loss: 0.05443 	 Tot Loss: 0.07358
Epoch: 10 	 Training loss: 0.06892 	 Val Loss: 0.05110 	 Tot Loss: 0.06925
Epoch: 20 	 Training loss: 0.06490 	 Val Loss: 0.04803 	 Tot Loss: 0.06523
Epoch: 30 	 Training loss: 0.06117 	 Val Loss: 0.04521 	 Tot Loss: 0.06149


KeyboardInterrupt: 

In [8]:
# Reload the best checkpoint written during training.
# os.path.join (instead of prefixing './') keeps an absolute model_path absolute.
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine;
# load_state_dict later copies the values into the model's existing parameters.
# NOTE(review): weights_only=False unpickles arbitrary objects. Only load
# checkpoints you produced yourself; prefer weights_only=True if the file holds only tensors.
best_model_state = torch.load(
    os.path.join(model.model.model_path, 'best_state_dict.pth'),
    map_location='cpu',
    weights_only=False,
)

In [9]:
from utils.utils import plot, save_acts, sample_from_spatio_temporal_graph

In [10]:
# Restore the best weights and switch to inference mode before probing activations.
model.load_state_dict(best_model_state)
model.eval()

net = model.model

# Presumably enables activation caching on the next forward pass (confirm in KanGDyn).
net.h_net.store_act = True
net.g_net.store_act = True

# Push a small sample (32) through the network only to populate the stored activations.
dummy_x, dummy_edge_index = sample_from_spatio_temporal_graph(train_data, edge_index, sample_size=32)

with torch.no_grad():
    _ = net(dummy_x, dummy_edge_index)

# Same handling for both sub-networks. Keep plotting before caching to
# preserve the original call order.
sub_nets = (net.h_net, net.g_net)
for sub_net in sub_nets:
    plot(folder_path=f'{sub_net.model_path}/figures', layers=sub_net.layers, show_plots=False)
for sub_net in sub_nets:
    save_acts(layers=sub_net.layers, folder_path=f'{sub_net.model_path}/cached_acts')


<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>