In [None]:
import jax.numpy as jnp
import torch
from torch import optim
import numpy as np
import matplotlib.pyplot as plt
from dlpack import asdlpack

import os
import sys
import random
from functools import partial

from data import direc_graph_from_linear_system, bi_direc_indx

In [None]:
sys.path.append('/mnt/local/data/vtrifonov/PNO')
from datasets.Elliptic import solvers

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [None]:
def discret_linear_operator(N_points, F, BCs, M_int=8, conservative=False):
    coeff = solvers.BVP_DD(N_points-2, F, BCs, M_int=M_int, conservative=conservative)
    A = jnp.diag(coeff[:, 1]) + jnp.diag(coeff[1:, 0], k=-1) + jnp.diag(coeff[:-1, 2], k=+1)
    f = coeff[:, -1]
    return A, f

a = lambda x: jnp.ones_like(x)
d = lambda x: jnp.zeros_like(x)
c = lambda x: jnp.zeros_like(x)
f = lambda x: -2*jnp.cos(jnp.pi*x)*jnp.sin(jnp.pi*x)*jnp.pi - jnp.exp(-x)*x**2 + jnp.sin(jnp.pi*x)*0.1

F = [a, d, c, f]
BCs = [0, 0]
N_points = 9

A_jax, f_jax = discret_linear_operator(N_points, F, BCs)
A = torch.from_dlpack(asdlpack(A_jax))
f_ = torch.from_dlpack(asdlpack(f_jax))

sol_fem_jax = solvers.solve_BVP(N_points, F, BCs)
sol_fem = torch.from_dlpack(asdlpack(sol_fem_jax))

In [1]:
import jax.numpy as jnp
# import equinox as eqx

from jax.nn import relu
from jax import random
import jraph

import os
os.environ["CUDA_VISIBLE_DEVICES"] = ''

In [9]:
A = jnp.eye(10, k=0)*random.randint(random.PRNGKey(42), [10], 0, 10) - jnp.eye(10, k=-1)*random.randint(random.PRNGKey(43), [10], 0, 10) - jnp.eye(10, k=1)*random.randint(random.PRNGKey(41), [10], 0, 10)
b = random.randint(random.PRNGKey(40), [10], 0, 10)
A = A.at[1, 7].set(2)
A = A.at[7, 1].set(9)
display(A)
b

Array([[ 5., -1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [-8.,  2., -8.,  0.,  0.,  0.,  0.,  2.,  0.,  0.],
       [ 0., -8.,  3., -5.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0., -2.,  7., -8.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0., -5.,  5.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  5., -6.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0., -1.,  9., -2.,  0.,  0.],
       [ 0.,  9.,  0.,  0.,  0.,  0., -1.,  6., -5.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0., -6.,  2., -1.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -4.,  9.]], dtype=float32)

Array([2, 4, 6, 0, 2, 0, 2, 2, 4, 4], dtype=int32)

In [10]:
from data import *

In [11]:
g = direc_graph_from_linear_system(A, b)

In [12]:
ls = bi_direc_indx(g)

In [13]:
display(g.senders)
g.receivers

Array([0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7, 7,
       7, 8, 8, 8, 9, 9], dtype=int32)

Array([0, 1, 0, 1, 2, 7, 1, 2, 3, 2, 3, 4, 3, 4, 5, 6, 5, 6, 7, 1, 6, 7,
       8, 7, 8, 9, 8, 9], dtype=int32)

In [14]:
ls

Array([[ 1,  2],
       [ 4,  6],
       [ 5, 19],
       [ 8,  9],
       [11, 12],
       [15, 16],
       [18, 20],
       [22, 23],
       [25, 26]], dtype=int32)

In [7]:
display(check.senders)
check.receivers

Array([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7, 7, 8,
       8, 8, 9, 9], dtype=int32)

Array([0, 1, 0, 1, 2, 1, 2, 3, 2, 3, 4, 3, 4, 5, 6, 5, 6, 7, 6, 7, 8, 7,
       8, 9, 8, 9], dtype=int32)

In [25]:
check.receivers

Array([0, 1, 0, 1, 2, 1, 2, 3, 2, 3, 4, 3, 4, 5, 6, 5, 6, 7, 6, 7, 8, 7,
       8, 9, 8, 9], dtype=int32)

In [28]:
a = check.receivers[check.senders == 4]
display(a)
jnp.nonzero(a==3)[0]

Array([3, 4], dtype=int32)

Array([0], dtype=int32)

In [29]:
check.receivers

Array([0, 1, 0, 1, 2, 1, 2, 3, 2, 3, 4, 3, 4, 5, 6, 5, 6, 7, 6, 7, 8, 7,
       8, 9, 8, 9], dtype=int32)

In [17]:
display(check.senders)
jnp.nonzero(check.senders == 4)

Array([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7, 7, 8,
       8, 8, 9, 9], dtype=int32)

(Array([11, 12], dtype=int32),)

In [20]:
# First senders index
jnp.nonzero(check.senders == 4)[0][0]

# To what first edge is connected
abc = check.receivers[check.senders == 4]

# Index within first node connection
jnp.nonzero(a==3)[0]

Array(11, dtype=int32)

In [18]:
a = jnp.vstack([check.senders, check.receivers]).T
a[a == jnp.asarray([0, 1])]

Array([0, 0, 1, 1, 1], dtype=int32)

In [107]:
(check.senders == 3).astype(int)

Array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0], dtype=int32)

In [10]:
check.receivers[check.senders == 3]

Array([2, 3, 4], dtype=int32)

In [101]:
display(jnp.where(check.senders == 3)[0])
display(check.receivers[jnp.where(check.senders == 3)[0]])
2 in check.receivers[jnp.where(check.senders == 3)[0]]

Array([ 8,  9, 10], dtype=int32)

Array([2, 3, 4], dtype=int32)

True

In [4]:
def direc_graph_from_linear_system(A, b):
    '''Matrix `A` should be sparse.'''
    node_features = jnp.asarray(b)
    senders, receivers = jnp.nonzero(A)
    edge_features = A[senders, receivers]
    n_node = jnp.array([len(node_features)])
    n_edge = jnp.array([len(senders)])
    graph = jraph.GraphsTuple(nodes=node_features, edges=edge_features, senders=senders,
                              receivers=receivers, n_node=n_node, n_edge=n_edge, globals=None)
    return graph

In [5]:
check = direc_graph_from_linear_system(A, b)

In [6]:
check

GraphsTuple(nodes=Array([2, 4, 6, 0, 2, 0, 2, 2, 4, 4], dtype=int32), edges=Array([ 5., -1., -8.,  2., -8., -8.,  3., -5., -2.,  7., -8., -5.,  5.,
        5., -6., -1.,  9., -2., -1.,  6., -5., -6.,  2., -1., -4.,  9.],      dtype=float32), receivers=Array([0, 1, 0, 1, 2, 1, 2, 3, 2, 3, 4, 3, 4, 5, 6, 5, 6, 7, 6, 7, 8, 7,
       8, 9, 8, 9], dtype=int32), senders=Array([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7, 7, 8,
       8, 8, 9, 9], dtype=int32), globals=None, n_node=Array([10], dtype=int32), n_edge=Array([26], dtype=int32))

In [21]:
re = jnp.stack(jnp.array([[1,2], [2, 3], [4,5], [2,3]]))
re

Array([[1, 2],
       [2, 3],
       [4, 5],
       [2, 3]], dtype=int32)