In [None]:
import warnings
import sys
import os

# NOTE(review): this silences *all* warnings (including JAX/NumPy deprecations);
# narrow it if something unexpected needs debugging.
warnings.filterwarnings('ignore')

# GPU selection and XLA memory fraction only take effect if set before JAX
# initializes its backend, so they live in this first cell (JAX is imported below).
os.environ["CUDA_VISIBLE_DEVICES"] = '2'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'

# Project root that is expected to provide the `data`, `linsolve` and `utils`
# packages imported in the next cell. Override it on other machines via the
# PREC_LEARNING_ROOT environment variable instead of editing this path.
PROJECT_ROOT = os.environ.get('PREC_LEARNING_ROOT', '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/')
# Append only once so re-running this cell does not grow sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import jax.experimental.sparse as jsparse
import jax.numpy as jnp
from jax import device_put, random
from jax.lax import scan
import numpy as np

import matplotlib.pyplot as plt
from functools import partial
from time import perf_counter
import ilupp

from data.dataset import dataset_qtt
from linsolve.cg import ConjGrad
from linsolve.precond import llt_prec_trig_solve
from utils import iter_per_residual, jBCOO_to_scipyCSR

# Notebook-wide Matplotlib defaults: large figures, large fonts, thick lines.
plt.rcParams.update({
    'figure.figsize': (11, 7),
    'font.size': 20,
    'lines.linewidth': 3,
})

# Benchmark incomplete-Cholesky preconditioned CG on the dataset

In [3]:
def _fill_density(L):
    """Return the number of stored entries of a square sparse factor as a percent of n**2."""
    return (L.nse * 100) / (L.shape[-1] ** 2)


def _stack_factors(A_csr, factorize):
    """Factorize every SciPy CSR matrix in ``A_csr`` and stack the factors.

    Returns a tuple ``(L, density)``: ``L`` is one batched BCOO (leading axis =
    sample) passed through ``device_put``; ``density`` is the mean of
    ``_fill_density`` over all factors.
    """
    factors, densities = [], []
    for A_i in A_csr:
        L_i = jsparse.BCOO.from_scipy_sparse(factorize(A_i))
        factors.append(L_i[None, ...])
        densities.append(_fill_density(L_i))
    return device_put(jsparse.bcoo_concatenate(factors, dimension=0)), float(np.mean(densities))


def _iters_per_tol(A_loc, b_loc, N_iter, prec_func):
    """Run batched (P)CG and summarize the batch-mean residual curve with ``iter_per_residual``."""
    _, res = ConjGrad(A_loc, b_loc, N_iter=N_iter, prec_func=prec_func, seed=42)
    return iter_per_residual(jnp.linalg.norm(res, axis=1).mean(0))


def benchmark_linsys(A, b, N=50, threshold=1e-4, N_iter_cg=2000, N_iter_pcg=500):
    """Compare plain CG with IC(0)- and thresholded-IC-preconditioned CG.

    Parameters
    ----------
    A : batched sparse (BCOO) system matrices, one system per index of the first axis.
    b : right-hand sides, batched along the first axis like ``A``.
    N : number of leading systems to benchmark; must satisfy ``1 <= N <= A.shape[0]``.
    threshold : drop threshold passed to ``ilupp.icholt``.
    N_iter_cg : iteration budget for unpreconditioned CG.
    N_iter_pcg : iteration budget for every preconditioned run.

    Returns
    -------
    ``(i_tol, (L0nnz, l0_tol), (Lt1nnz, lt1_tol), (Lt5nnz, lt5_tol), (Lt10nnz, lt10_tol))``
    where each ``*_tol`` is ``iter_per_residual`` of the batch-mean residual curve
    and each ``*nnz`` is the mean fill density (percent of n**2) of that factor
    over the N systems. Previously only the last system's density was reported.

    Raises
    ------
    ValueError
        If ``N`` is outside ``[1, A.shape[0]]``.
    """
    n_total = A.shape[0]
    if not 1 <= N <= n_total:
        raise ValueError(f'N must be in [1, {n_total}], got {N}')

    # Convert each system to SciPy CSR once; all four factorizations reuse it.
    A_csr = [jBCOO_to_scipyCSR(A[i, ...]) for i in range(N)]
    factorizations = (
        ilupp.ichol0,
        partial(ilupp.icholt, add_fill_in=1, threshold=threshold),
        partial(ilupp.icholt, add_fill_in=5, threshold=threshold),
        partial(ilupp.icholt, add_fill_in=10, threshold=threshold),
    )
    precs = [_stack_factors(A_csr, factorize) for factorize in factorizations]
    print('  Precs are combined')

    A_loc = A[:N, ...]
    b_loc = b[:N, ...]

    # Baseline: identity preconditioner (plain CG).
    i_tol = _iters_per_tol(A_loc, b_loc, N_iter_cg, None)
    print('  I is done')

    results = [i_tol]
    for name, (L, nnz) in zip(('L0', 'Lt1', 'Lt5', 'Lt10'), precs):
        tol = _iters_per_tol(A_loc, b_loc, N_iter_pcg, partial(llt_prec_trig_solve, L=L))
        print(f'  {name} is done')
        results.append((nnz, tol))
    return tuple(results)

In [4]:
def benchmark_all(params_grid, N=50):
    """Benchmark every ``[pde, grid, variance]`` entry of ``params_grid`` and print a report.

    Each entry's finite-difference systems are loaded with ``dataset_qtt`` and
    passed to ``benchmark_linsys`` with the given ``N``. Returns None.
    """
    labels = ('L0:  ', 'Lt1: ', 'Lt5: ', 'Lt10: ')
    separator = '\n------------------------------------------------------------\n\n'
    for params in params_grid:
        print(params)
        pde, grid, variance = params[0], params[1], params[2]
        A, _, b, _, _ = dataset_qtt(pde=pde, grid=grid, variance=variance, lhs_type='fd', return_train=False)
        out = benchmark_linsys(A, b, N=N)
        print(f'\nCG:  {out[0]}')
        for k, label in enumerate(labels, start=1):
            stats = out[k]
            tail = separator if k == len(labels) else '\n'
            print(f'{label}{stats[1]}, nnz(L) = {stats[0]:.6f}', end=tail)

In [5]:
# Trivial JAX op; running it surfaces backend start-up messages (see output) here.
jnp.asarray([1])

2024-05-11 07:30:08.334728: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


Array([1], dtype=int32)

In [None]:
# Benchmark configurations as [pde, grid, variance].
# Poisson at grids 32 and 64 is currently disabled; only grid 128 is run.
params_grid = [['poisson', 128, .1]]
params_grid += [
    ['div_k_grad', grid_size, var]
    for var in (.1, .5, 1., 1.5)
    for grid_size in (32, 64, 128)
]
benchmark_all(params_grid, N=50)

['poisson', 128, 0.1]
  Precs are combined


2024-05-11 07:30:16.390874: W external/xla/xla/service/hlo_rematerialization.cc:2948] Can't reduce memory use below 16.95GiB (18202411941 bytes) by rematerialization; only reduced to 24.46GiB (26260173148 bytes), down from 24.46GiB (26260173148 bytes) originally


In [None]:
# (intentionally empty: leftover scratch expression removed)

---

---

# Check the number of CG iterations on a single system

In [None]:
# Problem to inspect below.
#   pde:      'poisson' or 'div_k_grad'
#   grid:     32, 64 or 128
#   variance: 0.1, 0.5, 1.0 or 1.5
pde, grid, variance = 'div_k_grad', 64, 1.5

In [None]:
# Load the finite-difference systems for the selected problem and report how long it took.
s1 = perf_counter()
A, _, b, _, _ = dataset_qtt(pde=pde, grid=grid, variance=variance, lhs_type='fd', return_train=False)
dt_data = perf_counter() - s1
print(f'Data loaded in {dt_data:.2f} s')

In [None]:
# Index of the system to inspect, plus fill-in and drop threshold for ilupp.icholt.
i, fill_factor, threshold = 101, 10, 1e-4

In [None]:
# IC(0) factor of system i; density = stored entries as a percent of n**2.
L0 = ilupp.ichol0(jBCOO_to_scipyCSR(A[i, ...]))
ic0_density = 100 * L0.nnz / L0.shape[-1] ** 2
print(f'IC0 NNZ(L) = {ic0_density:.3f}%')

In [None]:
# Baseline: unpreconditioned CG on system i (as a batch of size 1).
_, res_I = ConjGrad(A[i:i+1, ...], b[i:i+1, ...], N_iter=1000, prec_func=None, seed=42)
print('CG is done')

# Preconditioner: IC(0) from the previous cell by default. Set use_icholt = True
# to use ilupp.icholt with the fill_factor / threshold configured above.
use_icholt = False
if use_icholt:
    L = ilupp.icholt(jBCOO_to_scipyCSR(A[i, ...]), add_fill_in=fill_factor, threshold=threshold)
else:
    L = L0
prec = partial(llt_prec_trig_solve, L=jsparse.BCOO.from_scipy_sparse(L)[None, ...])
print('Prec is generated', end='\n\n')

s_prec = perf_counter()
_, res = ConjGrad(A[i:i+1, ...], b[i:i+1, ...], N_iter=1000, prec_func=prec, seed=42)
# JAX dispatches asynchronously: wait for the result so dt_cg covers the actual
# solve (including any compilation), not just the dispatch.
jnp.asarray(res).block_until_ready()
dt_cg = perf_counter() - s_prec

# Residual norm per iteration: norm over axis 1, then mean over axis 0
# (presumably the batch axis, size 1 here). Computed once, reused below.
res_norm_I = jnp.linalg.norm(res_I, axis=1).mean(0)
res_norm = jnp.linalg.norm(res, axis=1).mean(0)

fig, ax = plt.subplots()
ax.plot(range(res_I.shape[-1]), res_norm_I, label="CG")
ax.plot(range(res.shape[-1]), res_norm, label="PCG")
ax.set_title(f'{pde}, grid={grid}, variance={variance}, system {i}')
ax.set_xlabel('Iteration')
ax.set_ylabel(r'$\|res\|$')  # raw string: '\|' is not a valid Python escape
ax.set_yscale('log')
ax.legend()
ax.grid()
plt.show()

cg_tol = iter_per_residual(res_norm_I)
pcg_tol = iter_per_residual(res_norm)

print(f'CG: {cg_tol}')
print(f'PCG: {pcg_tol}')
print(f'PCG time: {dt_cg:.3f} s')
print(f'NNZ(L) = {(L.nnz*100) / (L.shape[-1] ** 2):.3f}%')