In [1]:
import warnings
import sys
import os

warnings.filterwarnings('ignore')
os.environ["CUDA_VISIBLE_DEVICES"] = ''
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'
sys.path.append('/mnt/local/data/vtrifonov/prec-learning-Notay-loss/')

In [2]:
import jax.numpy as jnp
from jax import random, vmap
from jax.experimental import sparse as jsparse
from jax.scipy.sparse.linalg import cg as jcg

import numpy as np
from scipy.sparse.linalg import LinearOperator, cg
from scipy.sparse import tril, triu, eye, csr_matrix

import matplotlib.pyplot as plt
from functools import partial
from time import perf_counter
import ilupp

from linsolve.precond import llt_prec_trig_solve, llt_inv_prec
from utils import factorsILUp, batchedILUp, ILU2, jILU2, iter_per_residual, jBCOO_to_scipyCSR
from linsolve.cg import ConjGrad
from data.qtt import div_k_grad, scipy_validation, solve_precChol, solve_precLU

In [3]:
A, b, x, k_stats = div_k_grad(1, 32, 'gaussian', 1.)
print(k_stats)

A_ = jBCOO_to_scipyCSR(A[0, ...])
b_ = np.asarray(b[0, ...])

CUDA backend failed to initialize: jaxlib/cuda/versions_helpers.cc:98: operation cuInit(0) failed: CUDA_ERROR_NO_DEVICE (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


{'min': 462.0, 'max': 462.0, 'mean': 462.0}


In [4]:
d_tol = scipy_validation(A_, b_, prec_f=lambda x: x)
print(f'CG: {d_tol}')

CG: {'1e-3': 262, '1e-6': 333, '1e-12': 450}


# ILU with recursive call of ILU(0)

In [5]:
L0, U0 = factorsILUp(A_, p=0)
L1, U1 = factorsILUp(A_, p=1)
L2, U2 = factorsILUp(A_, p=2)

In [6]:
print(f'L0.nnz = {L0.nnz}')
print(f'L1.nnz = {L1.nnz}')
print(f'L2.nnz = {L2.nnz}')
print()
print(f'(L1 - L0).nnz = {(L1 - L0).nnz}')
print(f'(L2 - L1).nnz = {(L2 - L1).nnz}')
print()
print(f'max((L1 - L0)) = {np.abs(L1 - L0).max()}')
print(f'max((L2 - L1)) = {np.abs(L2 - L1).max()}')

L0.nnz = 3008
L1.nnz = 3969
L2.nnz = 4046

(L1 - L0).nnz = 254
(L2 - L1).nnz = 306

max((L1 - L0)) = 1.1102230246251565e-16
max((L2 - L1)) = 1.1102230246251565e-16


In [7]:
L, U = factorsILUp(A_, p=0)
prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 43, '1e-6': 54, '1e-12': 72}
NNZ(L) = 0.287%


In [8]:
L, U = factorsILUp(A_, p=1)
prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 43, '1e-6': 54, '1e-12': 72}
NNZ(L) = 0.379%


In [9]:
L, U = factorsILUp(A_, p=2)
prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 43, '1e-6': 54, '1e-12': 72}
NNZ(L) = 0.386%


# ILU(2) from Fortran

In [26]:
import jax
import jax.numpy as jnp
from jax import lax

def incomplete_lu_factorization(A, B, p):
    """
    Incomplete LU factorization of a matrix A in CSR format.

    Parameters:
    A (csr_matrix): Input matrix in CSR format
    p (int): Level of fill-in

    Returns:
    B (csr_matrix): Matrix B in CSR format after incomplete LU factorization
    """
    n = A.shape[0]

    def body(carry, i):
        k = jnp.arange(1, i)
        mask = (B[k, i] != 0) & (lev(B[k, i]) <= p)
        B_i = B[i, :]
        B_k = B[k, :]
        B_ik = B[k, i]
        B_i /= B_k[:, jnp.newaxis] * B_ik
        B_i -= jnp.sum(B_k[:, jnp.newaxis] * B_ik, axis=0)
        update_levels(B, i, k, p)
        return carry, B_i

    init = jnp.zeros((n,))
    B, _ = lax.scan(body, init, jnp.arange(2, n))
    B = jnp.where(lev(B) <= p, B, 0)
    return B

def lev(aij):
    # Define a function to compute the level of fill-in for an element aij
    # This function is not specified in the algorithm, so I assume it's a simple function
    # that returns 0 for all nonzero elements
    return 0 if aij != 0 else float('inf')

def update_levels(B, i, k, p):
    # Define a function to update the levels of fill-in for the nonzero elements in row i
    # This function is not specified in the algorithm, so I assume it's a simple function
    # that updates the levels based on the formula levij = min{levij, levik + levkj + 1}
    for j in range(B.shape[1]):
        if B[i, j] != 0:
            B[i, j] = jnp.minimum(B[i, j], lev(B[k, j]) + lev(B[k, i]) + 1)
    return B

In [25]:
s = perf_counter()
# sILU = ILU2(A_)
ILUp = incomplete_lu_factorization(jsparse.BCSR.from_scipy_sparse(A_), jsparse.BCSR.from_scipy_sparse(A_), 0)
print(perf_counter() - s)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[].
The error occurred while tracing the function body at /tmp/ipykernel_708849/4232006813.py:18 for scan. This concrete value was not available in Python because it depends on the value of the argument i.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

In [None]:
plt.imshow(ILUp.todense())

In [12]:
# LU = jBCOO_to_scipyCSR(A[0, ...])
# LU.data = sILU
# L = tril(LU, k=-1) + eye(A_.shape[0])
# U = triu(LU)

In [10]:
# Do not converge

# prec_f = partial(solve_precLU, L=L, U=U)
# d_tol = scipy_validation(A_, b_, prec_f)
# print(f'CG with from matlab code ILU(2): {d_tol}')

# IC(0)

In [11]:
L = ilupp.ichol0(A_)

prec_f = partial(solve_precChol, L=L, U=L)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 45, '1e-6': 57, '1e-12': 74}
NNZ(L) = 0.287%


# ICt

In [12]:
L = ilupp.icholt(A_, add_fill_in=1, threshold=1e-4)

prec_f = partial(solve_precChol, L=L, U=L)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 27, '1e-6': 34, '1e-12': 44}
NNZ(L) = 0.378%


In [13]:
L = ilupp.icholt(A_, add_fill_in=2, threshold=1e-4)

prec_f = partial(solve_precChol, L=L, U=L)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 20, '1e-6': 25, '1e-12': 33}
NNZ(L) = 0.475%


In [14]:
L = ilupp.icholt(A_, add_fill_in=5, threshold=1e-4)

prec_f = partial(solve_precChol, L=L, U=L)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 12, '1e-6': 15, '1e-12': 22}
NNZ(L) = 0.755%


In [15]:
L = ilupp.icholt(A_, add_fill_in=10, threshold=1e-4)

prec_f = partial(solve_precChol, L=L, U=L)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 8, '1e-6': 10, '1e-12': 15}
NNZ(L) = 1.178%


In [16]:
L = ilupp.icholt(A_, add_fill_in=15, threshold=1e-4)

prec_f = partial(solve_precChol, L=L, U=L)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 5, '1e-6': 7, '1e-12': 11}
NNZ(L) = 1.529%


In [17]:
L = ilupp.icholt(A_, add_fill_in=20, threshold=1e-4)

prec_f = partial(solve_precChol, L=L, U=L)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 5, '1e-6': 6, '1e-12': 10}
NNZ(L) = 1.776%


In [18]:
L = ilupp.icholt(A_, add_fill_in=30, threshold=1e-4)

prec_f = partial(solve_precChol, L=L, U=L)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 4, '1e-6': 6, '1e-12': 9}
NNZ(L) = 1.908%


# ILUt

In [19]:
L, U = ilupp.ilut(A_, fill_in=1, threshold=1e-4)

prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 143, '1e-6': 179, '1e-12': 230}
NNZ(L) = 0.098%


In [20]:
L, U = ilupp.ilut(A_, fill_in=2, threshold=1e-4)

prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 10240, '1e-6': 10240, '1e-12': 10240}
NNZ(L) = 0.195%


In [21]:
L, U = ilupp.ilut(A_, fill_in=3, threshold=1e-4)

prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 10240, '1e-6': 10240, '1e-12': 10240}
NNZ(L) = 0.290%


In [22]:
L, U = ilupp.ilut(A_, fill_in=4, threshold=1e-4)

prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 10240, '1e-6': 10240, '1e-12': 10240}
NNZ(L) = 0.384%


In [23]:
L, U = ilupp.ilut(A_, fill_in=5, threshold=1e-4)

prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 24, '1e-6': 37, '1e-12': 73}
NNZ(L) = 0.479%


In [24]:
L, U = ilupp.ilut(A_, fill_in=6, threshold=1e-4)

prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 19, '1e-6': 32, '1e-12': 86}
NNZ(L) = 0.574%


In [25]:
L, U = ilupp.ilut(A_, fill_in=7, threshold=1e-4)

prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 16, '1e-6': 26, '1e-12': 87}
NNZ(L) = 0.668%


In [26]:
L, U = ilupp.ilut(A_, fill_in=10, threshold=1e-4)

prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 11, '1e-6': 17, '1e-12': 33}
NNZ(L) = 0.950%


In [27]:
L, U = ilupp.ilut(A_, fill_in=15, threshold=1e-4)

prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 7, '1e-6': 10, '1e-12': 18}
NNZ(L) = 1.392%


In [28]:
L, U = ilupp.ilut(A_, fill_in=20, threshold=1e-4)

prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 5, '1e-6': 7, '1e-12': 12}
NNZ(L) = 1.760%


In [29]:
L, U = ilupp.ilut(A_, fill_in=30, threshold=1e-4)

prec_f = partial(solve_precLU, L=L, U=U)
d_tol = scipy_validation(A_, b_, prec_f)
print(f'PCG: {d_tol}')
print(f'NNZ(L) = {L.nnz*100 / L.shape[-1] ** 2:.3f}%')

PCG: {'1e-3': 3, '1e-6': 5, '1e-12': 8}
NNZ(L) = 2.165%
