In [1]:
import warnings
import sys
import os

# NOTE(review): this silences *all* warnings for the whole session, which can hide
# numerical problems (e.g. overflow / NaN warnings); narrow the filter if needed.
warnings.filterwarnings('ignore')

# Device / memory settings are read when JAX initializes its GPU backend,
# so they must be set before JAX is imported (next cell).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'

# Root of the project providing `data`, `linsolve` and `utils`.
# Override with the PREC_LEARNING_ROOT environment variable on other machines.
PROJECT_ROOT = os.environ.get('PREC_LEARNING_ROOT', '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/')
if PROJECT_ROOT not in sys.path:  # keep the cell idempotent when re-run
    sys.path.append(PROJECT_ROOT)

In [2]:
import jax.experimental.sparse as jsparse
import jax.numpy as jnp
from jax import device_put, random
from jax.lax import scan
import numpy as np

import matplotlib.pyplot as plt
from functools import partial
from time import perf_counter
import ilupp

from data.dataset import dataset_qtt
from linsolve.cg import ConjGrad, ConjGradReduced
from linsolve.precond import llt_prec_trig_solve
from utils import iter_per_residual, jBCOO_to_scipyCSR

# Global plotting defaults: large figures and fonts, thick lines.
plt.rcParams.update({
    'figure.figsize': (11, 7),
    'font.size': 20,
    'lines.linewidth': 3,
})

# Benchmark dataset

In [3]:
def _density_percent(nnz, n):
    """Return `nnz` as a percentage of the n*n entries of a dense square matrix."""
    return (nnz * 100) / (n ** 2)


def _mean_res_norm(res):
    """Batch-averaged residual norm at each recorded CG step.

    NOTE(review): assumes `res` is (batch, n, steps) as returned by the CG
    solvers, so the norm is taken over the vector axis and the mean over the batch.
    """
    return jnp.linalg.norm(res, axis=1).mean(0)


def _timed_factor(factorize, A_csr):
    """Factorize `A_csr`; return (factor as a batch-1 BCOO, seconds, nnz %).

    Only the factorization call is timed, not the scipy -> BCOO conversion.
    """
    start = perf_counter()
    L_scipy = factorize(A_csr)
    elapsed = perf_counter() - start
    L_bcoo = jsparse.BCOO.from_scipy_sparse(L_scipy)
    return L_bcoo[None, ...], elapsed, _density_percent(L_bcoo.nse, L_bcoo.shape[-1])


def _pcg_tol(A, b, L, pcg_iter):
    """Run PCG with the LL^T preconditioner `L`; return `iter_per_residual` of its curve."""
    _, res = ConjGradReduced(A, b, N_iter=pcg_iter, prec_func=partial(llt_prec_trig_solve, L=L), seed=42)
    return iter_per_residual(_mean_res_norm(res))


def _chunked_cg_curve(A, b, cg_iter, chunk_size):
    """Unpreconditioned CG residual curve computed in `chunk_size`-iteration pieces.

    Used when a single `cg_iter`-long run does not fit into GPU memory. Each
    piece is warm-started from the previous iterate via `x0`.
    NOTE(review): presumably every call starts a fresh Krylov sequence, i.e. this
    is *restarted* CG, whose iteration counts can differ from one uninterrupted run.

    Raises:
        ValueError: if `cg_iter` is not a positive multiple of a positive `chunk_size`.
    """
    if chunk_size <= 0 or cg_iter <= 0 or cg_iter % chunk_size != 0:
        raise ValueError(f'cg_iter ({cg_iter}) must be a positive multiple of chunk_size ({chunk_size})')
    n_chunks = cg_iter // chunk_size

    x_last, res = ConjGradReduced(A, b, N_iter=chunk_size, prec_func=None, seed=42)
    curves = [_mean_res_norm(res)]
    del res
    print('  Loop starts')
    for _ in range(n_chunks - 1):
        # Drop the last entry of the previous chunk: presumably it coincides with the
        # initial residual of the restart from x_last, so keeping it would double-count.
        # The final chunk keeps its last entry (also when there is only one chunk).
        curves[-1] = curves[-1][:-1]
        x_last, res = ConjGradReduced(A, b, N_iter=chunk_size, prec_func=None, seed=42, x0=x_last)
        curves.append(_mean_res_norm(res))
        del res
    print('  Loop ends')
    return jnp.concatenate(curves, axis=0)


def benchmark_linsys(A, b, cg_iter, pcg_iter, small_linsys, chunk_size=500):
    """Benchmark CG against IC-preconditioned PCG on a batch of linear systems.

    Three preconditioners are built per system with ilupp: IC(0) (`ichol0`) and
    threshold IC (`icholt`, threshold=1e-4) with add_fill_in=1 and 5.

    Args:
        A: batched sparse matrices (BCOO); system i is `A[i, ...]`.
        b: right-hand sides matching `A`.
        cg_iter: iterations of unpreconditioned CG.
        pcg_iter: iterations of each PCG run.
        small_linsys: if True, run CG in one call; otherwise run it in
            `chunk_size`-iteration chunks to bound GPU memory.
        chunk_size: chunk length for the big-system path; `cg_iter` must be a
            positive multiple of it.

    Returns:
        ((nnz(A) %, CG tol), (nnz(L) %, tol, time) for IC(0), ICt(1), ICt(5)),
        where nnz percentages and factorization times (seconds) are batch means
        and each tol is the output of `iter_per_residual` on the mean residual curve.

    Raises:
        ValueError: on an empty batch, or invalid `cg_iter` / `chunk_size` in the chunked path.
    """
    if A.shape[0] == 0:
        raise ValueError('benchmark_linsys needs at least one linear system')

    # Preconditioners to compare, in the order they are reported.
    factorizations = {
        'L0': ilupp.ichol0,                                           # IC(0)
        'Lt1': partial(ilupp.icholt, add_fill_in=1, threshold=1e-4),  # ICt(1)
        'Lt5': partial(ilupp.icholt, add_fill_in=5, threshold=1e-4),  # ICt(5)
    }
    factors = {name: [] for name in factorizations}
    times = {name: [] for name in factorizations}
    densities = {name: [] for name in factorizations}
    A_densities = []

    for i in range(A.shape[0]):
        A_i = jBCOO_to_scipyCSR(A[i, ...])
        A_densities.append(_density_percent(A_i.nnz, A_i.shape[-1]))
        for name, factorize in factorizations.items():
            L_i, elapsed, density = _timed_factor(factorize, A_i)
            factors[name].append(L_i)
            times[name].append(elapsed)
            densities[name].append(density)

    # Stack the per-system factors into batched preconditioners on the device.
    precs = {name: device_put(jsparse.bcoo_concatenate(Ls, dimension=0)) for name, Ls in factors.items()}
    del factors
    print('  Precs are combined')

    ## Unpreconditioned CG
    if small_linsys:
        # Everything fits in a single run.
        _, res = ConjGradReduced(A, b, N_iter=cg_iter, prec_func=None, seed=42)
        cg_curve = _mean_res_norm(res)
        del res
    else:
        # Cannot fit into GPU memory in one run: solve in chunks.
        print('  Chosen big linsys')
        cg_curve = _chunked_cg_curve(A, b, cg_iter, chunk_size)
    i_tol = iter_per_residual(cg_curve)
    del cg_curve
    print('  I is done')

    ## PCG with each preconditioner
    tols = {}
    for name, L in precs.items():
        tols[name] = _pcg_tol(A, b, L, pcg_iter)
        print(f'  {name} is done')

    prec_summary = tuple(
        (np.mean(densities[name]), tols[name], np.mean(times[name])) for name in factorizations
    )
    return ((np.mean(A_densities), i_tol),) + prec_summary

In [4]:
def benchmark_all(params_grid, N, cg_iter, pcg_iter, small_linsys):
    """Run `benchmark_linsys` for every dataset configuration and print a summary.

    Each entry of `params_grid` is indexed as (pde, grid, variance) and passed to
    `dataset_qtt`, loading `N` samples. `cg_iter`, `pcg_iter` and `small_linsys`
    are forwarded unchanged to `benchmark_linsys`. Returns None.
    """
    for config in params_grid:
        print(config)
        pde, grid, variance = config[0], config[1], config[2]
        A, _, b, _, _ = dataset_qtt(pde=pde, grid=grid, variance=variance, lhs_type='fd', return_train=False, N_samples=N)

        cg, l0, lt1, lt5 = benchmark_linsys(A, b, cg_iter=cg_iter, pcg_iter=pcg_iter, small_linsys=small_linsys)
        a_nnz, cg_tol = cg

        print(f'\nCG:  {cg_tol}, nnz(A) = {a_nnz:.6f} %')
        print(f'L0:  {l0[1]}, nnz(L) = {l0[0]:.6f} %, prec creation time = {l0[2]:.8f} sec')
        print(f'Lt1: {lt1[1]}, nnz(L) = {lt1[0]:.6f} %, prec creation time = {lt1[2]:.8f} sec')
        print(f'Lt5: {lt5[1]}, nnz(L) = {lt5[0]:.6f} %, prec creation time = {lt5[2]:.8f} sec', end='\n------------------------------------------------------------\n\n')

In [5]:
# 128x128 grids, run through the chunked CG path (small_linsys=False) because the
# unpreconditioned solve does not fit into GPU memory in one go.
# Other configurations tried: ['poisson', 128, .1] and ['div_k_grad', 128, .1].
params_grid = [['div_k_grad', 128, var] for var in (.5, .7)]
benchmark_all(params_grid, N=200, cg_iter=4000, pcg_iter=350, small_linsys=False)

['div_k_grad', 128, 0.5]
  Precs are combined
  Chosen big linsys
  Loop starts
  Loop ends
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done

CG:  {0.001: 1759, 1e-06: nan, 1e-09: nan, 1e-12: nan}, nnz(A) = 0.030327 %
L0:  {0.001: 149, 1e-06: 194, 1e-09: 246, 1e-12: 290}, nnz(L) = 0.018215 %, prec creation time = 0.00048764 sec
Lt1: {0.001: 91, 1e-06: 117, 1e-09: 150, 1e-12: 177}, nnz(L) = 0.024224 %, prec creation time = 0.00317652 sec
Lt5: {0.001: 47, 1e-06: 61, 1e-09: 78, 1e-12: 92}, nnz(L) = 0.048433 %, prec creation time = 0.00762909 sec
------------------------------------------------------------

['div_k_grad', 128, 0.7]
  Precs are combined
  Chosen big linsys
  Loop starts
  Loop ends
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done

CG:  {0.001: 2712, 1e-06: nan, 1e-09: nan, 1e-12: nan}, nnz(A) = 0.030327 %
L0:  {0.001: 160, 1e-06: 209, 1e-09: 261, 1e-12: 307}, nnz(L) = 0.018215 %, prec creation time = 0.00046787 sec
Lt1: {0.001: 97, 1e-06: 128, 1e-09: 159, 1e-12: 189

In [None]:
# 128x128 div_k_grad problems for three coefficient variances, chunked CG path.
# ['poisson', 128, .1] can be added to the grid as well.
params_grid = [['div_k_grad', 128, var] for var in (.1, .5, .7)]
benchmark_all(params_grid, N=200, cg_iter=4000, pcg_iter=350, small_linsys=False)

['div_k_grad', 128, 0.1]
  Precs are combined
  Chosen big linsys
  Loop starts
  Loop ends
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done

CG:  {0.001: 809, 1e-06: 1261, 1e-09: nan, 1e-12: nan}, nnz(A) = 0.030327 %
L0:  {0.001: 133, 1e-06: 170, 1e-09: 216, 1e-12: 258}, nnz(L) = 0.018215 %, prec creation time = 0.00069374 sec
Lt1: {0.001: 81, 1e-06: 104, 1e-09: 131, 1e-12: 158}, nnz(L) = 0.024224 %, prec creation time = 0.00467447 sec
Lt5: {0.001: 43, 1e-06: 55, 1e-09: 70, 1e-12: 84}, nnz(L) = 0.048458 %, prec creation time = 0.01073039 sec
------------------------------------------------------------

['div_k_grad', 128, 0.5]
  Precs are combined
  Chosen big linsys
  Loop starts
  Loop ends
  I is done


In [5]:
# Small grids (32, 64): every CG run fits into GPU memory in a single call.
params_grid = (
    [['poisson', grid_size, .1] for grid_size in (32, 64)]
    + [['div_k_grad', grid_size, var] for var in (.1, .5, .7) for grid_size in (32, 64)]
)
benchmark_all(params_grid, N=200, cg_iter=1800, pcg_iter=200, small_linsys=True)

['poisson', 32, 0.1]
  Precs are combined
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done

CG:  {0.001: 94, 1e-06: 118, 1e-09: 155, 1e-12: 185}, nnz(A) = 0.476074 %
L0:  {0.001: 32, 1e-06: 41, 1e-09: 53, 1e-12: 63}, nnz(L) = 0.286865 %, prec creation time = 0.00009640 sec
Lt1: {0.001: 20, 1e-06: 27, 1e-09: 34, 1e-12: 40}, nnz(L) = 0.378513 %, prec creation time = 0.00033292 sec
Lt5: {0.001: 12, 1e-06: 15, 1e-09: 20, 1e-12: 23}, nnz(L) = 0.758362 %, prec creation time = 0.00064327 sec
------------------------------------------------------------

['poisson', 64, 0.1]
  Precs are combined
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done

CG:  {0.001: 189, 1e-06: 253, 1e-09: 318, 1e-12: 369}, nnz(A) = 0.120544 %
L0:  {0.001: 61, 1e-06: 79, 1e-09: 102, 1e-12: 122}, nnz(L) = 0.072479 %, prec creation time = 0.00019341 sec
Lt1: {0.001: 39, 1e-06: 50, 1e-09: 64, 1e-12: 76}, nnz(L) = 0.096136 %, prec creation time = 0.00108202 sec
Lt5: {0.001: 22, 1e-06: 28, 1e-09: 36, 1e-12: 43}, nnz(

---

---

# Check number of iterations

In [None]:
# Dataset to inspect. Options: pde in {'poisson', 'div_k_grad'}, grid in {32, 64, 128},
# variance in {0.1, 0.5, 1.0, 1.5}.
pde, grid, variance = 'div_k_grad', 64, 1.5

In [None]:
# Load the selected dataset (return_train=False, default N_samples) and time the load.
s1 = perf_counter()
A, _, b, _, _ = dataset_qtt(pde=pde, grid=grid, variance=variance, lhs_type='fd', return_train=False)
dt_data = perf_counter() - s1

In [None]:
# Index of the system to inspect (NOTE(review): must be smaller than the number of
# loaded samples), plus fill-in / drop threshold for `ilupp.icholt` experiments.
i, fill_factor, threshold = 101, 10, 1e-4

In [None]:
# IC(0) factor of system i; density is reported as % of the n*n dense entries.
L0 = ilupp.ichol0(jBCOO_to_scipyCSR(A[i, ...]))
print(f'IC0 NNZ(L) = {100 * L0.nnz / L0.shape[-1] ** 2:.3f}%')

In [None]:
# Compare plain CG and PCG on a single system (batch of one) from the loaded dataset.
_, res_I = ConjGrad(A[i:i+1, ...], b[i:i+1, ...], N_iter=1000, prec_func=None, seed=42)
print('CG is done')

# Preconditioner: the IC(0) factor L0. To try threshold IC instead, use
# L = ilupp.icholt(jBCOO_to_scipyCSR(A[i, ...]), add_fill_in=fill_factor, threshold=threshold)
L = L0
prec = partial(llt_prec_trig_solve, L=jsparse.BCOO.from_scipy_sparse(L)[None, ...])
print('Prec is generated', end='\n\n')

s_prec = perf_counter()
_, res = ConjGrad(A[i:i+1, ...], b[i:i+1, ...], N_iter=1000, prec_func=prec, seed=42)
# JAX dispatches work asynchronously: wait for the result so dt_cg measures the solve.
# NOTE(review): if ConjGrad is jitted, the first call also includes compilation time.
if hasattr(res, 'block_until_ready'):
    res.block_until_ready()
dt_cg = perf_counter() - s_prec

# Batch-averaged residual norm per iteration, shared by the plot and the tolerance table.
res_I_norm = jnp.linalg.norm(res_I, axis=1).mean(0)
res_norm = jnp.linalg.norm(res, axis=1).mean(0)

fig, ax = plt.subplots()
ax.plot(range(res_I.shape[-1]), res_I_norm, label="CG")
ax.plot(range(res.shape[-1]), res_norm, label="PCG")
ax.set_xlabel('Iteration')
# Raw string: '\|' is an invalid escape sequence in a regular string literal.
ax.set_ylabel(r'$\|res\|$')
ax.set_yscale('log')
ax.legend()
ax.grid()

cg_tol = iter_per_residual(res_I_norm)
pcg_tol = iter_per_residual(res_norm)

print(f'CG: {cg_tol}')
print(f'PCG: {pcg_tol}')
print(f'NNZ(L) = {(L.nnz*100) / (L.shape[-1] ** 2):.3f}%')
print(f'PCG solve time = {dt_cg:.4f} sec')