In [8]:
# Scratch check of dict.values() ordering; the result is not assigned and no other cell uses it.
list({'a': 1, 'b': 2, 'c': 3}.values())

[1, 2, 3]

In [1]:
import warnings
import sys
import os
import jax

# Switch JAX to 64-bit precision.
jax.config.update("jax_enable_x64", True)
# Silences every Python warning for the rest of the session, including ones that may matter.
warnings.filterwarnings('ignore')
# NOTE(review): both variables are set after `import jax`. Presumably they are still
# honoured because no JAX computation has run yet -- confirm, or move them above the import.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'
# Machine-specific absolute path; presumably the project root that provides the
# `data`, `linsolve` and `utils` modules imported in the next cell -- confirm.
sys.path.append('/mnt/local/data/vtrifonov/prec-learning-Notay-loss/')

In [2]:
import jax.experimental.sparse as jsparse
import jax.numpy as jnp
from jax import device_put, random
from jax.lax import scan
import numpy as np

import matplotlib.pyplot as plt
from functools import partial
from time import perf_counter
import ilupp

from data.dataset import dataset_qtt
from linsolve.cg import ConjGrad, ConjGradReduced
from linsolve.precond import llt_prec_trig_solve
from utils import iter_per_residual, jBCOO_to_scipyCSR

# Notebook-wide matplotlib defaults: figure size, font size and line width.
plt.rcParams.update({
    'figure.figsize': (11, 7),
    'font.size': 20,
    'lines.linewidth': 3,
})

# Benchmark dataset

In [3]:
def _timed_ic_factor(factorize, A_i):
    """Run one incomplete-Cholesky factorization and time that call alone.

    Parameters
    ----------
    factorize : callable
        Maps the SciPy matrix ``A_i`` to its sparse factor ``L``.
    A_i : SciPy sparse matrix
        A single system matrix.

    Returns
    -------
    tuple
        ``(L, density, seconds)``: the factor as a BCOO array with a new
        leading axis of length 1, its number of stored elements as a
        percentage of ``n ** 2``, and the wall-clock time of ``factorize``.
    """
    start = perf_counter()
    L_scipy = factorize(A_i)
    elapsed = perf_counter() - start  # the BCOO conversion below is deliberately not timed

    L_i = jsparse.BCOO.from_scipy_sparse(L_scipy)
    density = (L_i.nse * 100) / (L_i.shape[-1] ** 2)
    return L_i[None, ...], density, elapsed


def _mean_residual_norm(res):
    """Norm of ``res`` over axis 1, averaged over axis 0.

    Presumably ``res`` is (batch, n, iterations), giving one mean residual
    norm per iteration -- TODO confirm against ConjGradReduced.
    """
    return jnp.linalg.norm(res, axis=1).mean(0)


def _iters_to_tolerance(A, b, n_iter, prec_func):
    """Run ConjGradReduced once and summarise its history with iter_per_residual."""
    _, res = ConjGradReduced(A, b, N_iter=n_iter, prec_func=prec_func, seed=42)
    # `res` is local, so the residual history is released when this function returns.
    return iter_per_residual(_mean_residual_norm(res))


def benchmark_linsys(A, b, cg_iter, pcg_iter, small_linsys, cg_chunk=500):
    """Benchmark plain CG against IC(0), ICt(1) and ICt(5) preconditioned CG.

    For every system ``A[i, ...]`` the three ilupp factorizations are built
    and timed. ConjGradReduced is then run without a preconditioner and with
    each factor, and the batch-averaged residual-norm history is converted
    to iteration counts by ``iter_per_residual``.

    Parameters
    ----------
    A : batched sparse matrix
        Indexed as ``A[i, ...]`` and converted with ``jBCOO_to_scipyCSR``
        (presumably a JAX BCOO of shape (batch, n, n) -- TODO confirm).
    b : array
        Right-hand sides, passed unchanged to ConjGradReduced.
    cg_iter : int
        Iterations for unpreconditioned CG. When ``small_linsys`` is false it
        must be a multiple of ``cg_chunk``.
    pcg_iter : int
        Iterations for each preconditioned run.
    small_linsys : bool
        True: unpreconditioned CG runs in a single call. False: it runs in
        restarted chunks of ``cg_chunk`` iterations, each started from the
        previous chunk's solution via ``x0``.
    cg_chunk : int, optional
        Iterations per chunk in the chunked branch (default 500, the value
        that was previously hard-coded).

    Returns
    -------
    tuple
        ``(Annz, i_tol), (L0nnz, l0_tol, L0_time), (Lt1nnz, lt1_tol, Lt1_time),
        (Lt5nnz, lt5_tol, Lt5_time)``. The nnz values are batch-mean fill
        percentages, the ``*_tol`` values are ``iter_per_residual`` results and
        the ``*_time`` values are batch-mean factorization times in seconds.
    """
    # Insertion order (L0, Lt1, Lt5) fixes the order of the runs, the log lines and the result.
    ic_variants = {
        'L0': ilupp.ichol0,                                            # IC(0)
        'Lt1': partial(ilupp.icholt, add_fill_in=1, threshold=1e-4),   # ICt(1)
        'Lt5': partial(ilupp.icholt, add_fill_in=5, threshold=1e-4),   # ICt(5)
    }
    factors = {name: [] for name in ic_variants}
    densities = {name: [] for name in ic_variants}
    timings = {name: [] for name in ic_variants}
    A_densities = []

    for i in range(A.shape[0]):
        A_i = jBCOO_to_scipyCSR(A[i, ...])
        A_densities.append((A_i.nnz * 100) / (A_i.shape[-1] ** 2))

        for name, factorize in ic_variants.items():
            L_i, density, elapsed = _timed_ic_factor(factorize, A_i)
            factors[name].append(L_i)
            densities[name].append(density)
            timings[name].append(elapsed)

    # Average time and nnz over the batch
    Annz = np.mean(A_densities)
    mean_density = {name: np.mean(values) for name, values in densities.items()}
    mean_time = {name: np.mean(values) for name, values in timings.items()}

    # Assemble precs: one batched factor per variant, moved to the device
    precs = {
        name: device_put(jsparse.bcoo_concatenate(per_system, dimension=0))
        for name, per_system in factors.items()
    }
    print('  Precs are combined')

    ## Unpreconditioned CG
    if small_linsys:
        # Can fit everything in a single run
        i_tol = _iters_to_tolerance(A, b, cg_iter, None)
    else:
        # Cannot fit into GPU memory: run CG in restarted chunks
        print('  Chosen big linsys')
        assert cg_iter % cg_chunk == 0, 'cg_iter must be a multiple of cg_chunk'
        n_cg = int(cg_iter // cg_chunk)

        # NOTE(review): restarting from x0 presumably resets CG's search directions, so
        # this curve may differ from an uninterrupted run of cg_iter iterations -- confirm.
        x_last, res = ConjGradReduced(A, b, N_iter=cg_chunk, prec_func=None, seed=42)
        chunk_norms = [_mean_residual_norm(res)]
        del res  # free the full residual history before the next chunk is computed
        print('  Loop starts')
        for _ in range(n_cg - 1):
            x_last, res = ConjGradReduced(A, b, N_iter=cg_chunk, prec_func=None, seed=42, x0=x_last)
            chunk_norms.append(_mean_residual_norm(res))
            del res
        print('  Loop ends')

        # Every chunk but the last drops its final entry (presumably the next chunk's
        # history starts with that same residual -- confirm against ConjGradReduced).
        # The last chunk is always kept whole. Previously, a single-chunk run
        # (cg_iter == chunk size) lost its final residual as well.
        res_ls = jnp.concatenate(
            [norms[:-1] for norms in chunk_norms[:-1]] + [chunk_norms[-1]], axis=0
        )
        i_tol = iter_per_residual(res_ls)
        del res_ls, chunk_norms
    print('  I is done')

    ## Preconditioned CG, identical for both branches
    pcg_tol = {}
    for name, L in precs.items():
        pcg_tol[name] = _iters_to_tolerance(A, b, pcg_iter, partial(llt_prec_trig_solve, L=L))
        print(f'  {name} is done')

    return (
        (Annz, i_tol),
        (mean_density['L0'], pcg_tol['L0'], mean_time['L0']),
        (mean_density['Lt1'], pcg_tol['Lt1'], mean_time['Lt1']),
        (mean_density['Lt5'], pcg_tol['Lt5'], mean_time['Lt5']),
    )

In [4]:
def benchmark_all(params_grid, N, cg_iter, pcg_iter, small_linsys):
    """Run ``benchmark_linsys`` for every configuration and print a report.

    Parameters
    ----------
    params_grid : iterable
        Entries indexed as ``[pde, grid, variance]``; each entry is printed
        and then forwarded to ``dataset_qtt``.
    N : int
        Passed to ``dataset_qtt`` as ``N_samples``.
    cg_iter, pcg_iter, small_linsys
        Forwarded unchanged to ``benchmark_linsys``.

    Returns
    -------
    None
        Results are only printed.
    """
    separator = '\n------------------------------------------------------------\n\n'

    for p in params_grid:
        print(p)
        pde, grid, variance = p[0], p[1], p[2]
        A, _, b, _, _ = dataset_qtt(
            pde=pde, grid=grid, variance=variance, lhs_type='fd',
            return_train=False, N_samples=N, precision='f64',
        )
        out = benchmark_linsys(A, b, cg_iter=cg_iter, pcg_iter=pcg_iter, small_linsys=small_linsys)

        # out[0] describes plain CG; out[1:4] describe the L0, Lt1 and Lt5 preconditioners.
        a_nnz, cg_tol = out[0][0], out[0][1]
        print(f'\nCG:  {cg_tol}, nnz(A) = {a_nnz:.6f} %')

        prec_lines = []
        for label, row in (('L0:  ', out[1]), ('Lt1: ', out[2]), ('Lt5: ', out[3])):
            l_nnz, l_tol, l_time = row[0], row[1], row[2]
            prec_lines.append(
                f'{label}{l_tol}, nnz(L) = {l_nnz:.6f} %, prec creation time = {l_time:.8f} sec'
            )
        print('\n'.join(prec_lines), end=separator)
    return None

In [5]:
# Each entry is [pde, grid, variance]. All runs here use grid=128 with small_linsys=False,
# so benchmark_linsys runs unpreconditioned CG in restarted chunks.
params_grid = [
    ['poisson', 128, .1],
    *(['div_k_grad', 128, var] for var in (.1, .5, .7)),
]
benchmark_all(params_grid, N=200, cg_iter=500, pcg_iter=300, small_linsys=False)

['poisson', 128, 0.1]


2024-07-12 05:16:17.802287: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


  Precs are combined
  Chosen big linsys


2024-07-12 05:16:25.942243: W external/xla/xla/service/hlo_rematerialization.cc:2948] Can't reduce memory use below 16.94GiB (18192389843 bytes) by rematerialization; only reduced to 24.49GiB (26293043201 bytes), down from 24.49GiB (26293043201 bytes) originally
2024-07-12 05:16:26.444054: W external/xla/xla/service/hlo_rematerialization.cc:2948] Can't reduce memory use below 16.60GiB (17820387304 bytes) by rematerialization; only reduced to 36.88GiB (39596443344 bytes), down from 36.88GiB (39596443344 bytes) originally


  Loop starts
  Loop ends
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done

CG:  {0.001: 382, 1e-06: 472, 1e-09: nan, 1e-12: nan}, nnz(A) = 0.030327 %
L0:  {0.001: 115, 1e-06: 156, 1e-09: 183, 1e-12: 212}, nnz(L) = 0.018215 %, prec creation time = 0.00046763 sec
Lt1: {0.001: 75, 1e-06: 98, 1e-09: 117, 1e-12: 134}, nnz(L) = 0.024224 %, prec creation time = 0.00308485 sec
Lt5: {0.001: 42, 1e-06: 55, 1e-09: 65, 1e-12: 74}, nnz(L) = 0.048488 %, prec creation time = 0.00708209 sec
------------------------------------------------------------

['div_k_grad', 128, 0.1]
  Precs are combined
  Chosen big linsys


2024-07-12 09:45:40.230262: W external/xla/xla/service/hlo_rematerialization.cc:2948] Can't reduce memory use below 16.60GiB (17820387304 bytes) by rematerialization; only reduced to 36.88GiB (39596443344 bytes), down from 36.88GiB (39596443344 bytes) originally


  Loop starts
  Loop ends
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done

CG:  {0.001: nan, 1e-06: nan, 1e-09: nan, 1e-12: nan}, nnz(A) = 0.030327 %
L0:  {0.001: 133, 1e-06: 169, 1e-09: 198, 1e-12: 225}, nnz(L) = 0.018215 %, prec creation time = 0.00055635 sec
Lt1: {0.001: 81, 1e-06: 103, 1e-09: 121, 1e-12: 138}, nnz(L) = 0.024224 %, prec creation time = 0.00324600 sec
Lt5: {0.001: 43, 1e-06: 55, 1e-09: 65, 1e-12: 74}, nnz(L) = 0.048458 %, prec creation time = 0.00773732 sec
------------------------------------------------------------

['div_k_grad', 128, 0.5]
  Precs are combined
  Chosen big linsys


2024-07-12 14:17:47.729330: W external/xla/xla/service/hlo_rematerialization.cc:2948] Can't reduce memory use below 16.60GiB (17820387304 bytes) by rematerialization; only reduced to 36.88GiB (39596443344 bytes), down from 36.88GiB (39596443344 bytes) originally


  Loop starts
  Loop ends
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done

CG:  {0.001: nan, 1e-06: nan, 1e-09: nan, 1e-12: nan}, nnz(A) = 0.030327 %
L0:  {0.001: 149, 1e-06: 188, 1e-09: 220, 1e-12: 248}, nnz(L) = 0.018215 %, prec creation time = 0.00057650 sec
Lt1: {0.001: 91, 1e-06: 114, 1e-09: 134, 1e-12: 151}, nnz(L) = 0.024224 %, prec creation time = 0.00346401 sec
Lt5: {0.001: 46, 1e-06: 59, 1e-09: 69, 1e-12: 78}, nnz(L) = 0.048431 %, prec creation time = 0.00824350 sec
------------------------------------------------------------

['div_k_grad', 128, 0.7]
  Precs are combined
  Chosen big linsys


2024-07-12 18:47:45.630776: W external/xla/xla/service/hlo_rematerialization.cc:2948] Can't reduce memory use below 16.60GiB (17820387304 bytes) by rematerialization; only reduced to 36.88GiB (39596443344 bytes), down from 36.88GiB (39596443344 bytes) originally


  Loop starts
  Loop ends
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done

CG:  {0.001: nan, 1e-06: nan, 1e-09: nan, 1e-12: nan}, nnz(A) = 0.030327 %
L0:  {0.001: 159, 1e-06: 199, 1e-09: 233, 1e-12: 261}, nnz(L) = 0.018215 %, prec creation time = 0.00051987 sec
Lt1: {0.001: 97, 1e-06: 121, 1e-09: 142, 1e-12: 159}, nnz(L) = 0.024224 %, prec creation time = 0.00340330 sec
Lt5: {0.001: 49, 1e-06: 62, 1e-09: 72, 1e-12: 81}, nnz(L) = 0.048429 %, prec creation time = 0.00804459 sec
------------------------------------------------------------



In [7]:
# Scratch cell; its result is not assigned and no other cell uses it.
2+3

5

In [5]:
# Each entry is [pde, grid, variance]; every (pde, variance) pair is run on grids 32 and 64.
# NOTE(review): cg_iter=2 gives unpreconditioned CG only two iterations, which matches the
# all-nan CG rows in the recorded output. Presumably a deliberate way to skip the baseline -- confirm.
params_grid = [
    [pde_name, grid_size, var]
    for pde_name, var in (
        ('poisson', .1),
        ('div_k_grad', .1),
        ('div_k_grad', .5),
        ('div_k_grad', .7),
    )
    for grid_size in (32, 64)
]
benchmark_all(params_grid, N=200, cg_iter=2, pcg_iter=150, small_linsys=True)

['poisson', 32, 0.1]


2024-07-11 17:09:10.849808: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


  Precs are combined
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done

CG:  {0.001: nan, 1e-06: nan, 1e-09: nan, 1e-12: nan}, nnz(A) = 0.476074 %
L0:  {0.001: 32, 1e-06: 41, 1e-09: 50, 1e-12: 57}, nnz(L) = 0.286865 %, prec creation time = 0.00008857 sec
Lt1: {0.001: 20, 1e-06: 27, 1e-09: 32, 1e-12: 36}, nnz(L) = 0.378513 %, prec creation time = 0.00029969 sec
Lt5: {0.001: 12, 1e-06: 15, 1e-09: 18, 1e-12: 21}, nnz(L) = 0.758362 %, prec creation time = 0.00058063 sec
------------------------------------------------------------

['poisson', 64, 0.1]
  Precs are combined
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done

CG:  {0.001: nan, 1e-06: nan, 1e-09: nan, 1e-12: nan}, nnz(A) = 0.120544 %
L0:  {0.001: 62, 1e-06: 79, 1e-09: 94, 1e-12: 109}, nnz(L) = 0.072479 %, prec creation time = 0.00016003 sec
Lt1: {0.001: 39, 1e-06: 50, 1e-09: 60, 1e-12: 69}, nnz(L) = 0.096136 %, prec creation time = 0.00082627 sec
Lt5: {0.001: 22, 1e-06: 28, 1e-09: 34, 1e-12: 38}, nnz(L) = 0.192547 %, prec

---

---

# Check number of iterations

In [None]:
# Configuration of the single experiment inspected in the cells below.
pde = 'div_k_grad'      # alternatives: 'poisson', 'div_k_grad'
grid = 64            # alternatives: 32, 64, 128
variance = 1.5        # alternatives: 0.1, 0.5, 1.0, 1.5

In [None]:
# Load the systems for the chosen (pde, grid, variance) and time the load.
# NOTE(review): unlike the call inside benchmark_all, this one passes neither N_samples
# nor precision='f64', so it relies on dataset_qtt's defaults -- confirm they are intended.
s1 = perf_counter()
A, _, b, _, _ = dataset_qtt(pde, grid, variance, lhs_type='fd', return_train=False)
dt_data = perf_counter() - s1

In [None]:
# Index of the single system taken from the batch.
# NOTE(review): requires at least 102 systems in A -- confirm against dataset_qtt's default size.
i = 101
# Only referenced by the commented-out ilupp.icholt call in the CG/PCG cell.
fill_factor = 10
threshold = 1e-4

In [None]:
# IC(0) factor of system `i`, followed by its stored entries as a percentage of n^2.
L0 = ilupp.ichol0(jBCOO_to_scipyCSR(A[i, ...]))
print('IC0 NNZ(L) = {:.3f}%'.format((L0.nnz*100) / (L0.shape[-1] ** 2)))

In [None]:
# Baseline: unpreconditioned CG on the single system `i` (kept as a batch of one).
_, res_I = ConjGrad(A[i:i+1, ...], b[i:i+1, ...], N_iter=1000, prec_func=None, seed=42)
print('CG is done')

# Preconditioner under test. Swap in the commented line to try ICt with
# `fill_factor` / `threshold` instead of the IC(0) factor `L0`.
# L = ilupp.icholt(jBCOO_to_scipyCSR(A[i, ...]), add_fill_in=fill_factor, threshold=threshold)
L = L0
prec = partial(llt_prec_trig_solve, L=jsparse.BCOO.from_scipy_sparse(L)[None, ...])
print('Prec is generated', end='\n\n')

s_prec = perf_counter()
_, res = ConjGrad(A[i:i+1, ...], b[i:i+1, ...], N_iter=1000, prec_func=prec, seed=42)
# JAX dispatches work asynchronously: wait for the result so dt_cg covers the whole solve.
# NOTE(review): dt_cg presumably also includes any compilation triggered by this call -- confirm.
jax.block_until_ready(res)
dt_cg = perf_counter() - s_prec

# Residual-norm histories, computed once and reused for the plot and the summary.
cg_hist = jnp.linalg.norm(res_I, axis=1).mean(0)
pcg_hist = jnp.linalg.norm(res, axis=1).mean(0)

plt.plot(range(res_I.shape[-1]), cg_hist, label="CG")
plt.plot(range(res.shape[-1]), pcg_hist, label="PCG")
plt.xlabel('Iteration')
# Raw string: `\|` is not a valid escape sequence in a normal string literal.
plt.ylabel(r'$\|res\|$')
plt.legend()
plt.yscale('log')
plt.grid();

cg_tol = iter_per_residual(cg_hist)
pcg_tol = iter_per_residual(pcg_hist)

print(f'CG: {cg_tol}')
print(f'PCG: {pcg_tol}')
print(f'NNZ(L) = {(L.nnz*100) / (L.shape[-1] ** 2):.3f}%')