In [1]:
import warnings
import sys
import os
import jax

# Process-level configuration for JAX/XLA.
# Uncomment to force the CPU backend (the workaround suggested by the CUDA init error recorded below).
# os.environ['JAX_PLATFORMS'] = 'cpu'
# NOTE(review): machine-specific GPU index; presumably the cause of the recorded
# "no supported devices found for platform CUDA" error on hosts without that device.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'
# Presumably the fraction of GPU memory XLA is allowed to preallocate — confirm in the JAX docs.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'
# Enable 64-bit types in JAX; the benchmark datasets are requested with precision='f64'.
jax.config.update("jax_enable_x64", True)
# Silences *all* Python warnings for the whole session (including deprecation/syntax warnings).
warnings.filterwarnings('ignore')
# NOTE(review): hard-coded absolute path to the project checkout; presumably what makes the
# `data`, `linsolve` and `utils` imports in the next cell resolvable.
sys.path.append('/mnt/local/data/vtrifonov/prec-learning-Notay-loss/')

In [2]:
import jax.experimental.sparse as jsparse
import jax.numpy as jnp
from jax import device_put, random
from jax.lax import scan
import numpy as np

import matplotlib.pyplot as plt
from functools import partial
from time import perf_counter
import ilupp

from data.dataset import dataset_qtt
from linsolve.scipy_linsolve import batched_cg_scipy, make_Chol_prec
from utils import iter_per_residual, jBCOO_to_scipyCSR

# Notebook-wide matplotlib defaults: large figures, large fonts, thick lines.
plt.rcParams.update({
    'figure.figsize': (11, 7),
    'font.size': 20,
    'lines.linewidth': 3,
})

In [3]:
# Smoke test: allocating an array makes JAX initialize its backend (the recorded output of this
# cell is a CUDA initialization error).
# NOTE(review): `A` is only a scratch value here; the benchmark code uses its own `A`.
A = jnp.zeros(10)

RuntimeError: Unable to initialize backend 'cuda': INTERNAL: no supported devices found for platform CUDA (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

# Benchmark dataset

In [None]:
def _nnz_percent(M):
    """Return the stored-entry count of sparse matrix `M` as a percentage of its dense size."""
    return (M.nnz*100) / (M.shape[-1] ** 2)


def _timed_factorization(factorize, A_i):
    """Run `factorize(A_i)` and return `(factor, elapsed_seconds)`; only the call is timed."""
    start = perf_counter()
    L_i = factorize(A_i)
    return L_i, perf_counter() - start


def _mean_std(values):
    """Collapse a list of per-sample scalars into `[mean, std]`."""
    return [np.mean(values), np.std(values)]


def scipy_benchmark_linsys(A, b, cg_iter, pcg_iter, atol):
    """Benchmark CG and incomplete-Cholesky-preconditioned CG on a batch of linear systems.

    For every sample in the batch, three `ilupp` factorizations are built and timed:
    `ichol0` ("L0"), and `icholt` with `threshold=1e-4` and `add_fill_in=1` ("Lt1") or
    `add_fill_in=5` ("Lt5"). Each family is then wrapped with `make_Chol_prec` and passed to
    `batched_cg_scipy`.

    Parameters
    ----------
    A : batch of sparse matrices, indexed as `A[i, ...]` and converted per sample with
        `jBCOO_to_scipyCSR` (presumably a batched JAX BCOO — confirm against `dataset_qtt`).
    b : right-hand sides, forwarded unchanged to `batched_cg_scipy`.
    cg_iter : `maxiter` for the unpreconditioned run. If it is not an `int` (e.g. `None`),
        that run is skipped and its statistics are filled with `-1` placeholders.
    pcg_iter : `maxiter` for the three preconditioned runs.
    atol : tolerance forwarded to `batched_cg_scipy`.

    Returns
    -------
    tuple
        Four rows ordered `(I, L0, Lt1, Lt5)`. Each row is
        `(nnz, iters_mean, iters_std, time_mean, time_std, prec_time)` where
        `nnz` is `[mean, std]` of the percentage of stored entries (of `A` itself for row
        `I`), the four middle items are what `batched_cg_scipy` returns after its first
        output, and `prec_time` is `[mean, std]` of the factorization wall time in seconds
        (`[-1, -1]` for row `I`, which has no preconditioner).
    """
    # Insertion order fixes both the per-sample build order and the row order of the result.
    factorizations = {
        'L0': ilupp.ichol0,
        'Lt1': lambda M: ilupp.icholt(M, add_fill_in=1, threshold=1e-4),
        'Lt5': lambda M: ilupp.icholt(M, add_fill_in=5, threshold=1e-4),
    }
    factors = {name: [] for name in factorizations}
    build_times = {name: [] for name in factorizations}
    factor_nnz = {name: [] for name in factorizations}
    A_nnz = []

    for i in range(A.shape[0]):
        A_i = jBCOO_to_scipyCSR(A[i, ...])
        A_nnz.append(_nnz_percent(A_i))

        for name, factorize in factorizations.items():
            L_i, elapsed = _timed_factorization(factorize, A_i)
            build_times[name].append(elapsed)
            factors[name].append(L_i)
            factor_nnz[name].append(_nnz_percent(L_i))

    ## Assemble precs
    precs = {name: make_Chol_prec(factors[name]) for name in factorizations}
    print('  Precs are combined')

    ## Run PCG
    # I (no preconditioner)
    if isinstance(cg_iter, int):
        _, I_iters_mean, I_iters_std, I_time_mean, I_time_std = batched_cg_scipy(
            A, b, P=None, atol=atol, maxiter=cg_iter)
        print('  I is done')
    else:
        # Placeholders: four entries per statistic, presumably one per reported tolerance.
        I_iters_mean, I_iters_std, I_time_mean, I_time_std = [-1]*4, [-1]*4, [-1]*4, [-1]*4
        print('  I is skipped')

    rows = [(_mean_std(A_nnz), I_iters_mean, I_iters_std, I_time_mean, I_time_std, [-1, -1])]

    # L0, Lt1, Lt5
    for name in factorizations:
        _, iters_mean, iters_std, time_mean, time_std = batched_cg_scipy(
            A, b, P=precs[name], atol=atol, maxiter=pcg_iter)
        print(f'  {name} is done')
        rows.append((
            _mean_std(factor_nnz[name]), iters_mean, iters_std, time_mean, time_std,
            _mean_std(build_times[name]),
        ))

    return tuple(rows)

In [None]:
# def combine_mean_std(arr_mean, arr_std, tol_ls=[1e-3, 1e-6, 1e-9, 1e-12]):
#     for i, t in enumerate(tol_ls):
#         mean_t, std_t = arr_mean[i], arr_std[i]
#         print(f'{t}: {mean_t:.1f}±{std_t:.3f}', end='; ')
#     print()
#     return

def print_results(prec_name, results, tol_ls=(1e-3, 1e-6, 1e-9, 1e-12)):
    """Print a three-line summary of one benchmark row, followed by a blank line.

    Parameters
    ----------
    prec_name : label printed at the start of the first line.
    results : indexable with six items,
        `(nnz, res_mean, res_std, cg_time_mean, cg_time_std, prec_time)`.
        `nnz` and `prec_time` are `[mean, std]` pairs; the other four are indexed by
        position in `tol_ls`. The caller in this notebook passes iteration-count statistics
        as `res_mean` / `res_std`.
    tol_ls : tolerances used as labels; entry `i` labels `res_mean[i]`, `cg_time_mean[i]`, etc.
        The default is an immutable tuple so it cannot be mutated between calls.

    Returns
    -------
    None
    """
    nnz, res_mean, res_std = results[0], results[1], results[2]
    cg_time_mean, cg_time_std, prec_time = results[3], results[4], results[5]

    def _print_per_tol(label, means, stds, mean_fmt, std_fmt):
        """Print `label`, then one `tol: mean±std; ` entry per tolerance (no newline)."""
        print(label, end='')
        for i, t in enumerate(tol_ls):
            print(f'{t}: {means[i]:{mean_fmt}}±{stds[i]:{std_fmt}}', end='; ')

    # `.6f` (precision), not `6f` (minimum width): same text for finite values, and no
    # stray left-padding when the value is nan/inf.
    print(f'{prec_name}, nnz = {nnz[0]:.3f}±{nnz[1]:.3f}, '
          f'prec time = {prec_time[0]:.6f}±{prec_time[1]:.6f}')

    _print_per_tol('Residuals:   ', res_mean, res_std, '.0f', '.2f')
    print()

    _print_per_tol('Time:        ', cg_time_mean, cg_time_std, '.4f', '.5f')
    print(end='\n\n')

In [None]:
def scipy_benchmark_all(params_grid, N, cg_iter, pcg_iter, atol):
    """Run the SciPy benchmark for every problem configuration and print the results.

    Each entry of `params_grid` is indexed as `[pde, grid, variance]`. For every entry the
    entry itself is printed, a dataset with `N` samples is built through `dataset_qtt`, the
    systems are benchmarked with `scipy_benchmark_linsys`, and one `print_results` summary is
    written per preconditioner, followed by a separator line.

    `cg_iter`, `pcg_iter` and `atol` are forwarded unchanged to `scipy_benchmark_linsys`.
    Returns None.
    """
    row_names = ['I', 'L0', 'Lt1', 'Lt5']

    for params in params_grid:
        print(params)

        A, _, b, _, _ = dataset_qtt(
            pde=params[0],
            grid=params[1],
            variance=params[2],
            lhs_type='fd',
            return_train=False,
            N_samples=N,
            precision='f64',
        )
        out = scipy_benchmark_linsys(A, b, cg_iter=cg_iter, pcg_iter=pcg_iter, atol=atol)

        # Rows of `out` are positional and line up with `row_names`.
        for row in range(len(row_names)):
            print_results(row_names[row], out[row])

        print('------------------------------------------------------------------------------------------------\n')

In [None]:
# Problem configurations consumed by scipy_benchmark_all.
# Each entry is [pde, grid, variance]; comment entries in or out to select what gets benchmarked.
params_grid = [
#     ['poisson',    32,  .1],
#     ['poisson',    64,  .1],
#     ['poisson',    128, .1],
# #
#     ['div_k_grad', 32,  .1],
#     ['div_k_grad', 64,  .1],
#     ['div_k_grad', 128, .1],
# #
#     ['div_k_grad', 32,  .5],
#     ['div_k_grad', 64,  .5],
#     ['div_k_grad', 128, .5],
# #
#     ['div_k_grad', 32,  .7],
#     ['div_k_grad', 64,  .7],
    ['div_k_grad', 128, .7]
]

In [7]:
# Recorded output is for ['div_k_grad', 128, 0.7]: plain CG skipped (cg_iter=None is not an int),
# preconditioned runs limited to maxiter=350.
# NOTE(review): several driver cells share execution count 7, so outputs come from different sessions.
scipy_benchmark_all(params_grid, N=200, cg_iter=None, pcg_iter=350, atol=0)

['div_k_grad', 128, 0.7]


2024-09-22 10:46:31.603035: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


  Precs are combined
  I is skipped
  L0 is done
  Lt1 is done
  Lt5 is done
I, nnz = 0.030±0.000, prec time = -1.000000±-1.000000
Residuals:   0.001: -1±-1.00; 1e-06: -1±-1.00; 1e-09: -1±-1.00; 1e-12: -1±-1.00; 
Time:        0.001: -1.0000±-1.00000; 1e-06: -1.0000±-1.00000; 1e-09: -1.0000±-1.00000; 1e-12: -1.0000±-1.00000; 

L0, nnz = 0.018±0.000, prec time = 0.000511±0.000056
Residuals:   0.001: 156±5.64; 1e-06: 196±5.54; 1e-09: 228±5.74; 1e-12: 255±5.71; 
Time:        0.001: 1.0699±0.18899; 1e-06: 1.3378±0.22872; 1e-09: 1.5534±0.26056; 1e-12: 1.7380±0.28744; 

Lt1, nnz = 0.024±0.000, prec time = 0.003794±0.000377
Residuals:   0.001: 95±3.50; 1e-06: 119±3.40; 1e-09: 139±3.48; 1e-12: 156±3.50; 
Time:        0.001: 0.7165±0.09804; 1e-06: 0.8981±0.12170; 1e-09: 1.0442±0.14101; 1e-12: 1.1701±0.15705; 

Lt5, nnz = 0.048±0.000, prec time = 0.008916±0.000888
Residuals:   0.001: 48±1.65; 1e-06: 60±1.60; 1e-09: 70±1.78; 1e-12: 79±1.70; 
Time:        0.001: 0.5609±0.07046; 1e-06: 0.7054±0.0877

In [7]:
# Recorded output is for ['div_k_grad', 128, 0.1] (an entry that is commented out in the current
# params_grid) and ends in an IndexError; plain CG skipped, preconditioned runs limited to maxiter=270.
# NOTE(review): stale cell — its output cannot be reproduced from the notebook as saved.
scipy_benchmark_all(params_grid, N=200, cg_iter=None, pcg_iter=270, atol=0)

['div_k_grad', 128, 0.1]


2024-09-22 09:34:27.168750: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


  Precs are combined
  I is skipped
  L0 is done
  Lt1 is done
  Lt5 is done
I, nnz = 0.030±0.000, prec time = -1.000000±-1.000000
Residuals:   0.001: -1±-1.00; 1e-06: -1±-1.00; 1e-09: -1±-1.00; 1e-12: -1±-1.00; 
Time:        0.001: -1.0000±-1.00000; 1e-06: -1.0000±-1.00000; 1e-09: -1.0000±-1.00000; 1e-12: -1.0000±-1.00000; 

L0, nnz = 0.018±0.000, prec time = 0.000462±0.000083
Residuals:   0.001: 131±3.15; 1e-06: 168±1.93; 1e-09: 198±1.91; 1e-12: 224±2.04; 
Time:        0.001: 0.9912±0.15276; 1e-06: 1.2669±0.19113; 1e-09: 1.4882±0.22132; 1e-12: 1.6878±0.24817; 

Lt1, nnz = 0.024±0.000, prec time = 0.003374±0.000048
Residuals:   0.001: 80±1.87; 1e-06: 102±1.32; 1e-09: 120±1.09; 1e-12: 137±1.20; 
Time:        0.001: 0.5520±0.07420; 1e-06: 0.7057±0.09218; 1e-09: 0.8300±0.10881; 1e-12: 0.9421±0.12329; 

Lt5, nnz = 0.048±0.000, prec time = 0.007940±0.000072
Residuals:   0.001: 42±1.07; 1e-06: 54±0.80; 1e-09: 64±0.68; 1e-12: 73±0.67; 
Time:        0.001: 0.4799±0.07443; 1e-06: 0.6171±0.0948

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

In [None]:
# NOTE(review): scratch expression with no role in the benchmark; candidate for deletion.
2+3

In [7]:
# Recorded output is for ['poisson', 32, 0.1] (commented out in the current params_grid), with
# plain CG enabled (maxiter=700) and preconditioned runs limited to maxiter=270; it ends in an IndexError.
# NOTE(review): stale cell — its output cannot be reproduced from the notebook as saved.
scipy_benchmark_all(params_grid, N=200, cg_iter=700, pcg_iter=270, atol=0)

['poisson', 32, 0.1]


2024-09-21 17:40:46.087786: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


  Precs are combined
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done
I, nnz = 0.476±0.000, prec time = -1.000000±-1.000000
Residuals:   0.001: 93±0.00; 1e-06: 116±0.00; 1e-09: 134±0.00; 1e-12: 150±0.00; 
Time:        0.001: 0.0092±0.06742; 1e-06: 0.0100±0.06742; 1e-09: 0.0106±0.06741; 1e-12: 0.0111±0.06741; 

L0, nnz = 0.287±0.000, prec time = 0.000082±0.000009
Residuals:   0.001: 31±0.00; 1e-06: 39±0.27; 1e-09: 49±0.00; 1e-12: 55±0.00; 
Time:        0.001: 0.0409±0.00163; 1e-06: 0.0509±0.00211; 1e-09: 0.0633±0.00249; 1e-12: 0.0707±0.00277; 

Lt1, nnz = 0.379±0.000, prec time = 0.000302±0.000033
Residuals:   0.001: 20±0.45; 1e-06: 25±0.00; 1e-09: 30±0.00; 1e-12: 35±0.00; 
Time:        0.001: 0.0253±0.00070; 1e-06: 0.0315±0.00056; 1e-09: 0.0373±0.00064; 1e-12: 0.0431±0.00071; 

Lt5, nnz = 0.758±0.000, prec time = 0.000614±0.000076
Residuals:   0.001: 10±0.00; 1e-06: 14±0.00; 1e-09: 17±0.00; 1e-12: 20±0.00; 
Time:        0.001: 0.0161±0.00069; 1e-06: 0.0215±0.00090; 1e-09: 0.0256±0.

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

In [2]:
import numpy as np

# NOTE(review): scratch cell. The recorded output (`nan`) does not match this input, whose
# mean is 2.5, so the output is stale; candidate for deletion.
np.array([1, 2, 3, 4, ]).mean()

nan

In [7]:
# Recorded output comes from an earlier version of print_results (no nnz / prec-time line) and
# from a multi-entry grid: ['poisson', 32, 0.1] completed, ['poisson', 64, 0.1] aborted with a
# ValueError during the Lt1 stage.
# NOTE(review): duplicate of the driver call above; stale output.
scipy_benchmark_all(params_grid, N=200, cg_iter=700, pcg_iter=270, atol=0)

['poisson', 32, 0.1]


2024-09-19 14:46:53.161795: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


  Precs are combined
  I is done
  L0 is done
  Lt1 is done
  Lt5 is done
I
Residuals:   0.001: 93±0.00; 1e-06: 116±0.00; 1e-09: 134±0.00; 1e-12: 150±0.00; 
Time:        0.001: 0.0080±0.06500; 1e-06: 0.0086±0.06501; 1e-09: 0.0091±0.06501; 1e-12: 0.0096±0.06502; 

L0
Residuals:   0.001: 31±0.00; 1e-06: 39±0.00; 1e-09: 49±0.00; 1e-12: 55±0.00; 
Time:        0.001: 0.0420±0.00024; 1e-06: 0.0522±0.00031; 1e-09: 0.0649±0.00038; 1e-12: 0.0726±0.00041; 

Lt1
Residuals:   0.001: 20±0.00; 1e-06: 25±0.00; 1e-09: 30±0.00; 1e-12: 35±0.00; 
Time:        0.001: 0.0294±0.00222; 1e-06: 0.0361±0.00260; 1e-09: 0.0428±0.00298; 1e-12: 0.0495±0.00337; 

Lt5
Residuals:   0.001: 10±0.00; 1e-06: 14±0.00; 1e-09: 17±0.00; 1e-12: 20±0.00; 
Time:        0.001: 0.0182±0.00138; 1e-06: 0.0245±0.00175; 1e-09: 0.0292±0.00201; 1e-12: 0.0339±0.00227; 

------------------------------------------------------------------------------------------------

['poisson', 64, 0.1]
  Precs are combined
  I is done
  L0 is done
  Lt1

ValueError: axis 0 index 1024 exceeds matrix dimension 1024

---

---

# Check number of iterations

In [None]:
# Problem selection for the single-system iteration check in the following cells.
pde = 'div_k_grad'      # options: 'poisson', 'div_k_grad'
grid = 64            # options: 32, 64, 128
variance = 1.5        # options: 0.1, 0.5, 1.0, 1.5

In [None]:
# Build the dataset for the selected problem and record how long that takes in `dt_data`.
# Unlike scipy_benchmark_all, N_samples and precision are not passed here, so dataset_qtt's
# defaults apply. Only the first and third outputs (A, b) are kept.
s1 = perf_counter()
A, _, b, _, _ = dataset_qtt(pde, grid, variance, lhs_type='fd', return_train=False)
dt_data = perf_counter() - s1

In [None]:
# Index of the sample inspected in the following cells (A[i, ...], b[i:i+1, ...]).
i = 101
# Settings for the thresholded factorization; currently only referenced by the commented-out
# `ilupp.icholt` call in the comparison cell.
fill_factor = 10
threshold = 1e-4

In [None]:
# IC(0) factor of sample `i`; reused as the preconditioner factor in the comparison cell.
L0 = ilupp.ichol0(jBCOO_to_scipyCSR(A[i, ...]))
# Stored entries of the factor as a percentage of the dense matrix size.
print(f'IC0 NNZ(L) = {(L0.nnz*100) / (L0.shape[-1] ** 2):.3f}%')

In [None]:
# Compare plain CG against preconditioned CG on the single system `i`.
# NOTE(review): `ConjGrad` and `llt_prec_trig_solve` are not imported in any cell visible in
# this notebook; on a fresh kernel this cell raises NameError until they are imported.

# Baseline without a preconditioner (slicing with i:i+1 keeps the leading batch axis).
_, res_I = ConjGrad(A[i:i+1, ...], b[i:i+1, ...], N_iter=1000, prec_func=None, seed=42)
print('CG is done')

# Swap in the line below to use the thresholded factorization (fill_factor / threshold)
# instead of the IC(0) factor.
# L = ilupp.icholt(jBCOO_to_scipyCSR(A[i, ...]), add_fill_in=fill_factor, threshold=threshold)
L = L0
prec = partial(llt_prec_trig_solve, L=jsparse.BCOO.from_scipy_sparse(L)[None, ...])
print('Prec is generated', end='\n\n')

# Wall time of the preconditioned solve is kept in `dt_cg`.
s_prec = perf_counter()
_, res = ConjGrad(A[i:i+1, ...], b[i:i+1, ...], N_iter=1000, prec_func=prec, seed=42)
dt_cg = perf_counter() - s_prec

# Residual-norm histories, computed once and shared by the plot and the iteration tables.
res_I_norm = jnp.linalg.norm(res_I, axis=1).mean(0)
res_norm = jnp.linalg.norm(res, axis=1).mean(0)

plt.plot(range(res_I.shape[-1]), res_I_norm, label="CG")
plt.plot(range(res.shape[-1]), res_norm, label="PCG")
plt.xlabel('Iteration')
# Raw string: '\|' is not a valid escape sequence in a regular string literal.
plt.ylabel(r'$\|res\|$')
plt.legend()
plt.yscale('log')
plt.grid();

cg_tol = iter_per_residual(res_I_norm)
pcg_tol = iter_per_residual(res_norm)

print(f'CG: {cg_tol}')
print(f'PCG: {pcg_tol}')
print(f'NNZ(L) = {(L.nnz*100) / (L.shape[-1] ** 2):.3f}%')