In [None]:
import warnings
import sys
import os

# Session-wide runtime setup. The environment variables are assigned here,
# in the cell that precedes the one importing jax.
warnings.filterwarnings('ignore')  # hide every Python warning from here on

os.environ.update({
    "CUDA_VISIBLE_DEVICES": '2',               # GPU index exposed to this process
    'XLA_PYTHON_CLIENT_MEM_FRACTION': '.95',   # GPU memory fraction setting read by XLA/JAX
})

# Make the project checkout importable (data, linsolve, utils, train, ...).
sys.path.append('/mnt/local/data/vtrifonov/prec-learning-Notay-loss/')

In [2]:
import jax.numpy as jnp
from jax import random, vmap, clear_caches, jit
import numpy as np

import optax
from equinox.nn import Conv1d
import matplotlib.pyplot as plt
from functools import partial
from time import perf_counter
import cloudpickle

from data.dataset import dataset_qtt
from linsolve.cg import ConjGrad
from linsolve.precond import llt_prec_trig_solve, llt_inv_prec
# from model import MessagePassing, FullyConnectedNet, PrecNet, ConstantConv1d, MessagePassingWithDot, CorrectionNet

from utils import params_count, asses_cond, iter_per_residual, batch_indices
from data.graph_utils import direc_graph_from_linear_system_sparse
from synthetic_utils import load_synthetic_dataset
from train import train

# Default (width, height) for every matplotlib figure created in this notebook.
plt.rcParams.update({'figure.figsize': (11, 7)})

# Setup experiment

In [3]:
# Checkpoint files whose name contains 'matrix', in the order os.listdir yields them.
trained_models = [
    fname
    for fname in os.listdir('/mnt/local/data/vtrifonov/prec-learning-Notay-loss/trained_models')
    if 'matrix' in fname
]
trained_models

['matrix1_sol_ones_nnz0.5%_lr1e-4_epoch_num600.pkl',
 'matrix3_rhs_randn_size1e3_nnz0.5%_lr1e-4_epoch_num800.pkl',
 'matrix3_sol_ones_nnz0.5%_lr1e-4_epoch_num1200.pkl',
 'matrix1_rhs_randn_size2e3_nnz0.5%_5rounds_lr1e-4_epoch_num800.pkl',
 'matrixGersh_sol_ones_nnz5%_size1e3_5rounds_lr1e-4_epoch_num200.pkl',
 'matrix1_sol_ones_nnz0.5%_lr1e-4_epoch_num400.pkl',
 'matrix3_sol_ones_nnz0.5%_lr1e-4_epoch_num600.pkl',
 'matrix1_sol_ones_nnz0.5%.pkl',
 'check_synth_matrix1.pkl',
 'matrix1_rhs_randn_size2e3_nnz0.5%_lr1e-4_epoch_num800.pkl',
 'matrix3_rhs_randn_size1e3_nnz0.5%_5rounds_lr1e-4_epoch_num800.pkl',
 'matrix3_sol_ones_nnz0.5%_lr1e-4_epoch_num400.pkl',
 'matrixGersh_sol_ones_nnz5%_size1e3_rhsNormalization_5rounds_lr1e-4_epoch_num200.pkl',
 'matrix3_sol_ones_nnz0.5%_lr1e-4_epoch_num1800.pkl',
 'matrix3_sol_ones_nnz0.5%.pkl',
 'matrix3_rhs_randn_size1e3_nnz0.5%_5rounds_lr1e-4_epoch_num200.pkl',
 'matrix1_sol_ones_nnz0.5%_lr1e-4_epoch_num1200.pkl']

In [4]:
# Narrow the checkpoint list down to the 'matrixGersh' runs.
[fname for fname in trained_models if 'matrixGersh' in fname]

['matrixGersh_sol_ones_nnz5%_size1e3_5rounds_lr1e-4_epoch_num200.pkl',
 'matrixGersh_sol_ones_nnz5%_size1e3_rhsNormalization_5rounds_lr1e-4_epoch_num200.pkl']

In [5]:
# Split sizes that are forwarded to load_synthetic_dataset.
N_train, N_test = 1000, 200

# Stem of the synthetic dataset archive and its full location on disk.
dataset_name = 'matrixGersh_sol_ones_nnz5%_size1e3'
dataset_path = '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/synthetic_datasets/{dataset_name}.npz'.format(
    dataset_name=dataset_name
)

In [6]:
# Directory holding the pickled models and the file stem (no '.pkl') of the
# checkpoint to evaluate.
# Alternative naming kept from the original cell: dataset_name + '_5rounds_lr1e-4_epoch_num800'
save_path, model_name = (
    '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/trained_models/',
    'matrixGersh_sol_ones_nnz5%_size1e3_5rounds_lr1e-4_epoch_num200',
)

# Make dataset

In [7]:
# Time the dataset load. Only the second value returned by load_synthetic_dataset
# is kept (presumably the test split, going by its name); the first is bound to `_`.
s1 = perf_counter()
_, test_set = load_synthetic_dataset(dataset_path, N_train=N_train, N_test=N_test)
print(perf_counter() - s1)  # elapsed seconds since s1 was taken

Dataset is loaded
Padding: no


2024-11-27 09:50:23.418343: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


12.898306872695684


In [8]:
# Unpack the five components of test_set into notebook-level names.
A_test, A_pad_test, L_test, b_test, u_exact_test = test_set
# Drop the references that are no longer needed; `_` still holds the first
# value returned by load_synthetic_dataset in the loading cell.
del _, test_set

In [9]:
# Restore the trained model from its checkpoint file.
# NOTE(review): unpickling can execute arbitrary code stored in the file;
# only load checkpoints from a trusted source.
with open(f'{save_path}{model_name}.pkl', 'rb') as f:
    model = cloudpickle.load(f)

## Run batched CG for PreCorrector

In [10]:
from linsolve.scipy_linsolve import batched_cg_scipy, make_Chol_prec_from_bcoo
import jax.experimental.sparse as jsparse

In [11]:
# model()

In [12]:
# Build the graph inputs for the whole test batch, evaluate the model on each
# system separately, give every result a leading batch axis of size 1 and
# concatenate the results along that axis.
nodes, edges, receivers, senders, _ = direc_graph_from_linear_system_sparse(A_pad_test, b_test)
L_preccor = jsparse.bcoo_concatenate(
    [
        model((nodes[k, ...], edges[k, ...], receivers[k, ...], senders[k, ...]))[None, ...]
        for k in range(len(A_pad_test))
    ],
    dimension=0,
)

In [13]:
from linsolve.scipy_linsolve import cg_scipy

def batched_cg_scipy(A, b, P=None, atol=1e-12, maxiter=1000, x0='random'):
    """Solve a batch of linear systems one at a time with ``cg_scipy`` and summarise convergence.

    This notebook-local definition replaces the ``batched_cg_scipy`` that an
    earlier cell imports from ``linsolve.scipy_linsolve``. Only the ``1e-3``
    entry of the mapping returned by ``iter_per_residual`` is tracked.

    Parameters
    ----------
    A : batch of system matrices
        ``A.shape[0]`` is the batch size and ``A[i]`` the i-th system.
    b : batch of right-hand sides
        ``b[i, ...]`` is solved against ``A[i]``.
    P : sequence of preconditioners, optional
        ``P[i]`` is handed to ``cg_scipy`` for the i-th system. ``None`` or an
        empty sequence means every system is solved without a preconditioner.
    atol, maxiter, x0
        Forwarded unchanged to ``cg_scipy``. ``x0`` must be ``'random'`` or ``None``.

    Returns
    -------
    tuple
        ``(sol, iters_mean, iters_std, time_mean, time_std)``.
        ``sol`` is the solution of the LAST system only (``None`` for an empty
        batch). The other four are one-element lists with the mean / standard
        deviation of the iteration index and of the time at that index, taken
        over the systems whose ``1e-3`` entry is not NaN. They are NaN when no
        system qualifies.

    Raises
    ------
    AssertionError
        If ``x0`` is neither ``'random'`` nor ``None``.
    """
    assert x0 is None or x0 == 'random', "x0 must be 'random' or None"

    n_systems = A.shape[0]
    # Explicit None/empty test: plain truthiness raises for ndarray inputs.
    if P is None or len(P) == 0:
        P = [None] * n_systems

    sol = None  # stays None when the batch is empty instead of being unbound
    iters_ls, time_ls = [], []

    for i in range(n_systems):
        sol, res_i, time_i = cg_scipy(A[i], b[i, ...], P[i], atol=atol, maxiter=maxiter, x0=x0)
        n_iter = iter_per_residual(res_i)[1e-3]
        iters_ls.append(n_iter)

        # NaN marks a system without a usable 1e-3 entry: report it and
        # record no time for it.
        if np.isnan(n_iter):
            print(f'{i} - alert')
            continue

        time_ls.append(time_i[n_iter])

    if np.isnan(iters_ls).any():
        print(f'All nans to 1e-3? {np.isnan(iters_ls).all()}')

    # Both statistics are taken over the same set of systems: NaN iteration
    # counts are dropped here just as their times were skipped in the loop.
    iters_mean, iters_std = _mean_std_ignoring_nan(iters_ls)
    time_mean, time_std = _mean_std_ignoring_nan(time_ls)
    return sol, [iters_mean], [iters_std], [time_mean], [time_std]


def _mean_std_ignoring_nan(values):
    """Return ``(mean, std)`` of the non-NaN entries of ``values``; ``(nan, nan)`` if none remain."""
    kept = np.asarray(values, dtype=float)
    kept = kept[~np.isnan(kept)]
    if kept.size == 0:
        return np.nan, np.nan
    return np.mean(kept), np.std(kept)

In [24]:
# Run CG on the test systems with preconditioners built from the model output L_preccor
# (presumably Cholesky-type operators, judging by the helper's name - confirm).
# The statistics are bound to iters_mean / iters_std / time_mean / time_std,
# names that the baseline run in this notebook assigns as well.
P_preccor = make_Chol_prec_from_bcoo(L_preccor)
_, iters_mean, iters_std, time_mean, time_std = batched_cg_scipy(A_test, b_test, P=P_preccor,
                                                                 atol=1e-3, maxiter=1500, x0=None)

In [25]:
# Iteration statistics: the mean is rendered via display(), the std is the cell's output value.
# iters_mean / iters_std hold whatever the most recently executed batched_cg_scipy cell assigned.
print('iters')
display(iters_mean)
iters_std

iters


[5.245]

[0.514757224330072]

In [26]:
# Timing statistics: the mean is rendered via display(), the std is the cell's output value.
# time_mean / time_std hold whatever the most recently executed batched_cg_scipy cell assigned.
print('time')
display(time_mean)
time_std

time


[0.01911698089912534]

[0.003006789347212457]

## Run batched CG for baseline

In [17]:
from linsolve.scipy_linsolve import make_Chol_prec

In [21]:
# Run CG on the same test systems with preconditioners built from L_test, the factors
# that came with the dataset (presumably Cholesky-type operators - confirm).
# This rebinds iters_mean / iters_std / time_mean / time_std, the same names
# the PreCorrector run assigns.
P_baseline = make_Chol_prec(L_test)
_, iters_mean, iters_std, time_mean, time_std = batched_cg_scipy(A_test, b_test, P=P_baseline,
                                                                 atol=1e-3, maxiter=1500, x0=None)

In [22]:
# Iteration statistics: the mean is rendered via display(), the std is the cell's output value.
# iters_mean / iters_std hold whatever the most recently executed batched_cg_scipy cell assigned.
print('iters')
display(iters_mean)
iters_std

iters


[5.245]

[0.514757224330072]

In [23]:
# Timing statistics: the mean is rendered via display(), the std is the cell's output value.
# time_mean / time_std hold whatever the most recently executed batched_cg_scipy cell assigned.
print('time')
display(time_mean)
time_std

time


[0.018118185698986055]

[0.0025012549468588917]