In [1]:
import equinox as eqx
import optax 

import jax.numpy as jnp
from jax.nn import relu
from jax import random, jit, lax
import jraph

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

import os
import sys
from functools import partial

from data import direc_graph_from_linear_system, bi_direc_indx
from train import TrainerLLT
from loss import LLT_loss

from model import MLP, PrecNet
from mp_architecture import GraphNetwork_NodeEdge

In [2]:
# Make the PNO project importable. The location can be overridden with the
# PNO_DIR environment variable instead of editing this machine-specific path.
PNO_DIR = os.environ.get('PNO_DIR', '/mnt/local/data/vtrifonov/PNO')
if PNO_DIR not in sys.path:  # keep the cell idempotent on re-run
    sys.path.append(PNO_DIR)
from datasets.Elliptic import solvers

# Hide all GPUs so JAX falls back to CPU. This is set after `jax` was imported,
# so it presumably only works because JAX initialises its backend lazily (on the
# first array operation); the "No visible GPU devices" message printed by the
# first JAX computation is the expected consequence.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [3]:
def random_lin_system(seed, n=10):
    """Generate a random tridiagonal linear system ``A x = b`` from an integer seed.

    The diagonal of ``A`` holds integers drawn from [1, 20). The sub- and
    super-diagonals hold the negatives of integers drawn from [1, 20). All other
    entries are zero. ``b`` holds integers drawn from [0, 10).

    Args:
        seed: integer seed passed to ``jax.random.PRNGKey``.
        n: size of the system (defaults to 10, the previously hard-coded size).

    Returns:
        Tuple ``(A, b)``. ``A`` has shape ``(n, n)`` and a floating dtype
        (it is built from ``jnp.eye``). ``b`` has shape ``(n,)`` and an
        integer dtype.
    """
    key_diag, key_lower, key_upper, key_rhs = random.split(random.PRNGKey(seed), 4)
    diag = random.randint(key_diag, (n,), 1, 20)
    lower = random.randint(key_lower, (n,), 1, 20)
    upper = random.randint(key_upper, (n,), 1, 20)
    # jnp.eye(n, k=k) * v broadcasts v along rows (column j is scaled by v[j]),
    # so only the entries on diagonal k survive.
    A = jnp.eye(n, k=0) * diag - jnp.eye(n, k=-1) * lower - jnp.eye(n, k=1) * upper
    # Previously b used a fixed PRNGKey(40), so every seed produced the same
    # right-hand side. It is now derived from `seed` like A.
    b = random.randint(key_rhs, (n,), 0, 10)
    return A, b

def make_random_dataset(seeds: list):
    """Build parallel lists of random linear systems and their graph forms.

    For every seed, a system ``(A, b)`` is generated with ``random_lin_system``,
    converted to a graph with ``direc_graph_from_linear_system``, and passed to
    ``bi_direc_indx`` (presumably bidirectional-edge indices; see ``data``).

    Args:
        seeds: iterable of integer seeds (the notebook passes ``jnp.arange(20)``).

    Returns:
        Tuple of four lists ``(A_ls, b_ls, graph_ls, bi_edges_ls)`` with one
        entry per seed, in the order the seeds were given.
    """
    matrices, rhs_vectors, graphs, bi_edge_indices = [], [], [], []
    for seed in seeds:
        matrix, rhs = random_lin_system(seed)
        system_graph = direc_graph_from_linear_system(matrix, rhs)
        edge_indices = bi_direc_indx(system_graph)
        matrices.append(matrix)
        rhs_vectors.append(rhs)
        graphs.append(system_graph)
        bi_edge_indices.append(edge_indices)
    return matrices, rhs_vectors, graphs, bi_edge_indices

In [4]:
# One random tridiagonal system per seed 0 .. N_SYSTEMS-1; each name below is a list.
N_SYSTEMS = 20
A, b, graph, bi_edges = make_random_dataset(jnp.arange(N_SYSTEMS))

2024-02-27 12:53:30.666901: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
CUDA backend failed to initialize: FAILED_PRECONDITION: No visible GPU devices. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
# NOTE(review): x is only a placeholder equal to b here (presumably it should be
# the solution of A x = b) — TODO confirm what LLT_loss expects.
x = b
# Train and test are the very same list object: there is no held-out data yet.
X_train = [graph, x, b, bi_edges]
X_test = X_train

In [6]:
training_params = dict(
    optimizer=optax.adamw,
    lr=1e-3,
    optim_params=dict(weight_decay=2e-4),
    batch_size_train=32,
    batch_size_test=32,
    epoch_num=100,
)

# NOTE(review): only 20 systems are generated, fewer than the batch size of 32 —
# confirm TrainerLLT copes with a batch larger than the dataset.
trainer = TrainerLLT(X_train, X_test, None, None, training_params,
                     loss_function=LLT_loss)

In [7]:
seed = 42
# Give every sub-network its own PRNG key. Passing the same PRNGKey(seed) to all
# of them presumably makes identically shaped networks (NodeEncoder and
# EdgeDecoder are both MLP([10, 16, 16])) start from identical weights, and
# correlates the initialisations of the others.
(key_node_enc, key_edge_enc, key_edge_dec,
 key_edge_update, key_node_update) = random.split(random.PRNGKey(seed), 5)

NodeEncoder = MLP(features=[10, 16, 16], N_layers=2, key=key_node_enc, act=relu)
EdgeEncoder = MLP(features=[28, 16, 16], N_layers=2, key=key_edge_enc, act=relu)
EdgeDecoder = MLP(features=[10, 16, 16], N_layers=2, key=key_edge_dec, act=relu)
mp_rounds = 3  # presumably the number of message-passing rounds in PrecNet
MessagePass = GraphNetwork_NodeEdge(
    update_edge_fn=MLP(features=[48, 16, 16], N_layers=2, key=key_edge_update, act=relu),
    update_node_fn=MLP(features=[33, 16, 16], N_layers=2, key=key_node_update, act=relu),
)

model = PrecNet(NodeEncoder=NodeEncoder, EdgeEncoder=EdgeEncoder,
                EdgeDecoder=EdgeDecoder, MessagePass=MessagePass, mp_rounds=mp_rounds)

In [8]:
# NOTE(review): this call raised "TypeError: list indices must be integers or
# slices, not tuple". The same error is reproduced further down by indexing a
# Python list with a tuple of index lists. That suggests the trainer indexes the
# entries of X_train (graph, x, b, bi_edges are all Python lists) with a batch of
# indices. TODO: convert them to arrays / a batched graph before training.
# Note: ((0, 0, 0, 0)) is simply the tuple (0, 0, 0, 0); the extra parentheses
# do not create a nested tuple.
trainer(model, in_axes=((0, 0, 0, 0)), out_axes=((0, 0, 0, 0)))

0


TypeError: list indices must be integers or slices, not tuple

In [None]:
# Shape of the first system matrix.
jnp.shape(A[0])

In [10]:
# Reproduces the trainer's TypeError in isolation: a Python list cannot be
# indexed with a tuple of index lists (that kind of fancy indexing needs arrays).
# The error is caught and shown so Restart & Run All does not stop here.
a = list(range(10))
try:
    a[[0, 3, 4, 5, 9], [1, 2, 6, 7, 8]]
except TypeError as err:
    print(f"TypeError: {err}")

TypeError: list indices must be integers or slices, not tuple

In [13]:
# Exploration: the list of graphs cannot be stacked into one array. It raised
# "ValueError: All input arrays must have the same shape.", so the error is
# caught and shown to keep Restart & Run All working. jraph.batch(graph) is
# used instead further down.
try:
    jnp.asarray(graph)
except ValueError as err:
    print(f"jnp.asarray(graph) failed: {err}")

  return array(a, dtype=dtype, copy=bool(copy), order=order)  # type: ignore


ValueError: All input arrays must have the same shape.

In [16]:
# Merge the per-system graphs into a single jraph batch.
batched_graph = jraph.batch(graph)

In [22]:
# Scratch check: shuffle a small array with a fixed key.
perm_key = random.PRNGKey(42)
random.permutation(perm_key, jnp.asarray([1, 2, 3]))

Array([3, 1, 2], dtype=int64)

In [29]:
from itertools import permutations

# All orderings of [1, 2, 3], for comparison with the JAX permutation above.
[*permutations([1, 2, 3])]

[(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]

In [30]:
# Number of edges per graph in the batch (one entry per system).
batched_graph.n_edge

Array([28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
       28, 28, 28], dtype=int64)