In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import equinox as eqx
import optax 

import jax
import jax.numpy as jnp
from jax.nn import relu
from jax import random, jit, lax
import jraph

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

import sys
from functools import partial
from itertools import chain

from data import direc_graph_from_linear_system, bi_direc_indx
from train import TrainerLLT
from loss import LLT_loss

from model import MessagePassing, FullyConnectedNet, PrecNet



from train_func import train, compute_loss

In [2]:
sys.path.append('/mnt/local/data/vtrifonov/PNO')
from datasets.Elliptic import solvers

from jax import config
config.update('jax_enable_x64', True)
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [3]:
def random_lin_system(seed):
    keys = random.split(random.PRNGKey(seed), 3)
    A = jnp.eye(10, k=0)*random.randint(keys[0], [10], 1, 20) - jnp.eye(10, k=-1)*random.randint(keys[1], [10], 1, 20) - jnp.eye(10, k=1)*random.randint(keys[2], [10], 1, 20)
    b = random.randint(random.PRNGKey(seed), [10], 0, 10)
    return A * 1., b * 1.

def make_random_dataset(seeds: list):
    A_ls, b_ls, bi_edges_ls = [], [], []
    nodes_ls, edges_ls, receivers_ls, senders_ls = [], [], [], []
    ls_of_ls = [A_ls, b_ls, bi_edges_ls, nodes_ls, edges_ls, receivers_ls, senders_ls]
    
    for i, s in enumerate(seeds):
        A, b = random_lin_system(s)
        nodes, edges, receivers, senders, n_node, n_edge = direc_graph_from_linear_system(A, b)
        indices = bi_direc_indx(receivers, senders, n_node)
        
        ls_of_val = [A, b, indices, nodes, edges, receivers, senders]
        for ls, val in zip(ls_of_ls, ls_of_val):
            ls.append(val)
            
    for i, ls in enumerate(ls_of_ls):
        ls_of_ls[i] = jnp.stack(ls)
    return ls_of_ls

In [4]:
A, b, bi_edges, nodes, edges, receivers, senders = make_random_dataset(jnp.arange(20))

2024-03-01 11:32:02.430868: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
CUDA backend failed to initialize: FAILED_PRECONDITION: No visible GPU devices. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
# x = b
# X_train = X_test = (nodes, edges, receivers, senders, bi_edges, x, b)

In [6]:
# training_params = {'optimizer': optax.adamw, 'lr': 1e-3,
#                    'optim_params': {'weight_decay': 2e-4}, 
#                    'epoch_num': 10}

# trainer = TrainerLLT(X_train, X_test, jnp.array([1]), jnp.array([1]), training_params, loss_function=LLT_loss)

In [7]:
seed = 42
NodeEncoder = FullyConnectedNet(features=[1, 15, 15], N_layers=2, key=random.PRNGKey(seed), act=relu)
EdgeEncoder = FullyConnectedNet(features=[1, 15, 15], N_layers=2, key=random.PRNGKey(seed), act=relu)
EdgeDecoder = FullyConnectedNet(features=[15, 15, 1], N_layers=2, key=random.PRNGKey(seed), act=relu)
mp_rounds = 3
# MessagePass = MessagePassing(
#     update_edge_fn = lambda e, sent, rec: FullyConnectedNet(
#         features=[45, 15, 15], N_layers=2, key=random.PRNGKey(seed), act=relu
#     )(jnp.concatenate([e, sent, rec], axis=0)),
    
#     update_node_fn = lambda n, sent, rec: FullyConnectedNet(
#         features=[45, 15, 15], N_layers=2, key=random.PRNGKey(seed), act=relu
#     )(jnp.concatenate([n, sent, rec], axis=0)),
#     mp_rounds=mp_rounds
# )

MessagePass = MessagePassing(
    update_edge_fn = FullyConnectedNet(features=[45, 15, 15], N_layers=2, key=random.PRNGKey(seed), act=relu),    
    update_node_fn = FullyConnectedNet(features=[45, 15, 15], N_layers=2, key=random.PRNGKey(seed), act=relu),
    mp_rounds=mp_rounds
)

model = PrecNet(NodeEncoder=NodeEncoder, EdgeEncoder=EdgeEncoder, 
                EdgeDecoder=EdgeDecoder, MessagePass=MessagePass)

In [8]:
# model_check = FullyConnectedNet(features=[1, 15, 1], N_layers=3, key=random.PRNGKey(seed), act=relu)

In [9]:
x = b
X_train = X_test = (nodes, edges, receivers, senders, bi_edges, x, b)
data = (X_train, X_test, jnp.array([1]), jnp.array([1]))

train_config = {'optimizer': optax.adamw, 'lr': 1e-3,
                'optim_params': {'weight_decay': 2e-4}, 
                'epoch_num': 10}

In [10]:
model = train(model, data, train_config, compute_loss)
# losses_ls = trainer(model)

  0%|          | 0/10 [00:00<?, ?it/s]

183.6775108829826 182.04373422280298
183.6775108829826 182.04373422280298
183.6775108829826 182.04373422280298
183.6775108829826 182.04373422280298
183.6775108829826 182.04373422280298
183.6775108829826 182.04373422280298
183.6775108829826 182.04373422280298
183.6775108829826 182.04373422280298
183.6775108829826 182.04373422280298
183.6775108829826 182.04373422280298


In [11]:
data = jnp.arange(15 * 28).reshape(15, 28)
segment_ids = jnp.array(list(chain.from_iterable([[i, i] for i in range(0, 14)])))
display(data.shape)
display(segment_ids.shape)

res = jax.ops.segment_sum(data.T, segment_ids)
res.T.shape

(15, 28)

(28,)

(15, 14)