In [1]:
import warnings
import sys
import os

warnings.filterwarnings('ignore')
os.environ["CUDA_VISIBLE_DEVICES"] = ''
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.'
sys.path.append('/mnt/local/data/vtrifonov/prec-learning-Notay-loss/')

In [2]:
import jax.numpy as jnp
from jax import random, vmap, clear_caches, jit
import numpy as np

import optax
from equinox.nn import Conv1d
import matplotlib.pyplot as plt
from functools import partial
from time import perf_counter

from data.dataset import dataset_Krylov, dataset_FD
from linsolve.cg import ConjGrad
from linsolve.precond import llt_prec_trig_solve
from model import MessagePassing, FullyConnectedNet, PrecNet, ConstantConv1d, MessagePassingWithDot

from utils import params_count, asses_cond, iter_per_residual, batch_indices
from data.utils import direc_graph_from_linear_system_sparse
from train import train

# Setup experiment

In [3]:
dataset = 'krylov'                # 'krylov', 'simple'
grid = 64
N_samples_train = 30
N_samples_test = 5

In [4]:
rhs_train = rhs_test = 'random'           # 'random', 'laplace', [5, 5, 2]
k_train = k_test = 'poisson'           # 'random', 'poisson', [5, 5, 2]
rhs_offset_train = rhs_offset_test = 0
k_offset_train = k_offset_test = 10
lhs_type = 'ilu2'                        # 'fd', 'ilu0', 'ilu1', 'ilu2'

cg_repeats = 100
if dataset == 'simple': cg_repeats = 1

In [5]:
layer_ = ConstantConv1d         # 'ConstantConv1d' to make a "zero" NN initialization; 'Conv1d' to make a random initialization
loss_type = 'llt-res-norm'               # 'llt', 'llt-norm', 'notay', 'llt-res', 'llt-res-norm'
with_cond = True               # If True will calculate cond during training. Extremly bad scaling (materialization of matrix)
# with_final_cond = False         # If True will calculate cond with final L. Also bad scaling

loss_reduction = jnp.mean
# loss_reduction = jnp.mean if loss_type == 'notay' else jnp.sum

In [6]:
if (loss_type in {'notay', 'llt-res', 'llt-res-norm'} and dataset == 'simple') or (loss_type in {'llt', 'llt-norm'} and dataset == 'krylov'):
    raise ValueError('Not valid dataset for a chosen loss')

In [7]:
batch_size = 64
lr = 1e-3
epoch_num = 300

In [8]:
# # Uncomment and setup to make steps in learning rate

# steps_per_batch = N_samples_train * cg_repeats // batch_size
# start, stop, step = 45*steps_per_batch, 101*steps_per_batch, 45*steps_per_batch
# decay_size = 1e-1
# lr = optax.piecewise_constant_schedule(
#     lr,
# #     {k: v for k, v in zip([37], [1e-1])}
#     {k: v for k, v in zip(np.arange(start, stop, step), [decay_size, ] * len(jnp.arange(start, stop, step)))}
# )

# Make dataset

In [9]:
s1 = perf_counter()
if dataset == 'krylov':
    A_train, A_pad_train, b_train, u_exact_train, bi_edges_train, res_train, u_app_train = dataset_Krylov(grid, N_samples_train, seed=42, rhs_distr=rhs_train, rhs_offset=rhs_offset_train,
                                                                                             k_distr=k_train, k_offset=k_offset_train, cg_repeats=cg_repeats, lhs_type=lhs_type)
    A_test, A_pad_test, b_test, u_exact_test, bi_edges_test, res_test, u_app_test = dataset_Krylov(grid, N_samples_test, seed=43, rhs_distr=rhs_test, rhs_offset=rhs_offset_test,
                                                                                       k_distr=k_test, k_offset=k_offset_test, cg_repeats=cg_repeats, lhs_type=lhs_type)
elif dataset == 'simple':
    A_train, A_pad_train, b_train, u_exact_train, bi_edges_train = dataset_FD(grid, N_samples_train, seed=42, rhs_distr=rhs_train, rhs_offset=rhs_offset_train,
                                                                 k_distr=k_train, k_offset=k_offset_train, lhs_type=lhs_type)
    A_test, A_pad_test, b_test, u_exact_test, bi_edges_test = dataset_FD(grid, N_samples_test, seed=43, rhs_distr=rhs_test, rhs_offset=rhs_offset_test,
                                                             k_distr=k_test, k_offset=k_offset_test, lhs_type=lhs_type)
print(perf_counter() - s1)

CUDA backend failed to initialize: jaxlib/cuda/versions_helpers.cc:98: operation cuInit(0) failed: CUDA_ERROR_NO_DEVICE (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


34.7735421359539


## Tests

In [112]:
from functools import partial
from jax import jit
from jax.lax import scan, cond
import jax.numpy as jnp

@partial(jit, static_argnums=(2))
def jspsolve_triangular(A, b, lower):
    '''A must be a lower/upper triangular matrix.
       It should be "valid": not singular (have no zeros on diagonal, no empty rows, etc.)'''
    A = A.sort_indices()
    Aval, bval = A.data, b
    Arows, Acols = A.indices[:, 0], A.indices[:, 1]
    x = jnp.zeros_like(bval)    

    diag_edge_indx = jnp.diff(jnp.hstack([Arows[:, None], Acols[:, None]]))
    diag_edge_indx = jnp.where(diag_edge_indx == 0, 1, 0)
    diag_edge_indx = jnp.nonzero(diag_edge_indx, size=bval.shape[0], fill_value=jnp.nan)[0].astype(jnp.int32)

    if lower:
        xs_ = jnp.hstack([
            jnp.arange(x.shape[0])[:, None],
            diag_edge_indx[:, None]
        ])
    else:
        xs_ = jnp.hstack([
            jnp.arange(x.shape[0]-1, -1, -1)[:, None],
            diag_edge_indx[::-1][:, None]
        ])
    def f_(carry, k):
        i, diag_ind = k
        Aval_, Arows_, Acols_, bval_, x_ = carry
        
#         x_i = x_.at[jnp.where(Arows_ == i, Acols_, Acols_.shape[0])].get(mode='fill', fill_value=0)
#         x_i = lax.cond(cond, f_1, f_2, operand) -- should try
        
        x_i = vmap(cond, in_axes=(0, None, None, 0), out_axes=(0))(Arows_ == i, lambda j: x_[j], lambda j: 0., Acols_)
        A_i = jnp.where(Arows_ == i, Aval_, 0)
        
        c = (bval_[i] - jnp.sum(A_i * x_i)) / (Aval_[diag_ind] + 1e-9)
        x_ = x_.at[i].set(c)
        return (Aval_, Arows_, Acols_, bval_, x_), None
    
    carry_ = (Aval, Arows, Acols, bval, x)
    (_, _, _, _, x), _ = scan(f_, carry_, xs_)
    return x

In [113]:
# from data.utils import direc_graph_from_linear_system_sparse, graph_tril, graph_to_low_tri_mat_sparse
# from jax.experimental import sparse as jsparse 

# nodes, edges, receivers, senders, _ = direc_graph_from_linear_system_sparse(A_train, b_train)
# print(edges.shape, receivers.shape, senders.shape)
# nodes, edges, receivers, senders = nodes[0, ...][None, ...], edges[0, ...], receivers[0, ...], senders[0, ...]
# nodes, edges, receivers, senders = graph_tril(nodes, edges, receivers, senders)
# print(edges.shape, receivers.shape, senders.shape)

# # abc_tril = vmap(graph_to_low_tri_mat_sparse, in_axes=(0, 0, 0, 0), out_axes=(0))(nodes[None, ...], edges[None, ...], receivers[None, ...], senders[None, ...])
# abc_tril = graph_to_low_tri_mat_sparse(nodes, edges, receivers, senders)
# # abc_tril.data = abc_tril.data * 0 + 1
# plt.imshow(abc_tril.todense());

In [115]:
clear_caches()

l_ = True
arr = abc_tril if l_ else jsparse.sparsify(lambda A: A.T)(abc_tril)

dt = 0
for _ in range(100):
    s_ = perf_counter()
    x1 = jspsolve_triangular(arr, b_train[0, :], lower=l_)
    dt += perf_counter() - s_

# display(x1)
print(f'time: {dt:.0f} sec')
jnp.abs(nodes[0, :] - arr @ x1).max()

time: 7 sec


Array(2.3841858e-07, dtype=float32)