In [1]:
import sys
import os
# Restrict the process to GPU 2 and set the fraction of GPU memory XLA may
# preallocate. These must be in place before JAX initializes its GPU backend,
# hence before the JAX imports in the next cell.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '2',
    'XLA_PYTHON_CLIENT_MEM_FRACTION': '.95',
})

# Make the project modules imported in the next cell (data, config, train, ...) resolvable.
project_root_for_imports = '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/'
sys.path.append(project_root_for_imports)
del project_root_for_imports

In [2]:
from copy import deepcopy
from functools import partial
from time import perf_counter

import numpy as np
import pandas as pd
from jax import random, vmap, numpy as jnp
from jax.experimental import sparse as jsparse

from data.dataset import load_dataset
from data.graph_utils import spmatrix_to_graph
from config import default_precorrector_gnn_config
from scipy_linsolve import make_Chol_prec_from_bcoo, batched_cg_scipy, single_lhs_cg
from train import construction_time_with_gnn, train_inference_finetune, load_hp_and_model, make_neural_prec_model
from architecture.neural_preconditioner_design import PreCorrectorGNN, NaiveGNN, PreCorrectorMLP, PreCorrectorMLP_StaticDiag, PreCorrectorMultiblockGNN

In [3]:
# Trained PreCorrector-GNN checkpoint, stored as <path>/<folder>/<name>/<name>.eqx
path = '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/'
folder = 'results_cases/28.03_final_elliptic_grid400_precor_gnn/'
name = 'rjlvwh'
model_path = os.path.join(path, folder, name, f'{name}.eqx')

# JAX PRNG key with seed 42 (`random` is jax.random here, not the stdlib module);
# reused by later cells.
key = random.PRNGKey(42)

In [6]:
# Dataset specification passed to load_dataset(...) in the next cell.
data_config = {
    'data_dir': path,
    'pde': 'div_k_grad',
    'grid': 400,
    'variance': .5,
    'lhs_type': 'l_ic0',
    'N_samples_train': 1000,
    'N_samples_test': 200,
    'fill_factor': 1,
    'threshold': 1e-4
}
# NOTE: the classical preconditioner's construction time is intentionally NOT
# hard-coded here. load_dataset returns it (class_time_mean_test /
# class_time_std_test) and that value is what gets used; a hard-coded constant
# here was silently overwritten by that unpacking before ever being read.

# Conjugate-gradient stopping criteria: absolute tolerance and iteration cap.
cg_atol = 1e-12
cg_maxiter = 500

# Architecture overrides applied to a private copy of the default config
# (deepcopy keeps the shared default, including the nested 'mp' dict, unmodified).
model_config = deepcopy(default_precorrector_gnn_config)
model_config['static_diag'] = True
model_config['mp']['aggregate_edges'] = 'sum'

In [7]:
make_model = partial(make_neural_prec_model, model_type='precorrector_gnn')
# load_hp_and_model receives make_model and returns both the model and its config,
# so building a model here with make_model(key, model_config) would just be discarded.
# NOTE(review): consequently the overrides applied to model_config in the config cell
# (static_diag, aggregate_edges) never reach the evaluated model; confirm the
# checkpoint's saved hyperparameters are the intended ones.
model, model_config = load_hp_and_model(model_path, make_model)

# Test split only; class_time_* are the classical preconditioner's construction-time stats.
test_set = load_dataset(data_config, return_train=False)
A_test, A_pad_test, b_test, bi_edges_test, x_test, class_time_mean_test, class_time_std_test = test_set
n_test = A_test.shape[0]

time_gnn_mean, time_gnn_std = construction_time_with_gnn(model, A_test[0, ...], A_pad_test[0, ...], b_test[0, ...],
                                                         bi_edges_test[0, ...], num_rounds=n_test,
                                                         pre_time_ic=class_time_mean_test)

# Apply the GNN one system at a time (presumably to bound GPU memory). The length-1
# slices keep the leading batch axis that vmap maps over. The vmapped callable is
# built once instead of on every iteration.
batched_model = vmap(model, in_axes=0, out_axes=0)
L = [batched_model(spmatrix_to_graph(A_pad_test[i:i+1, ...], b_test[i:i+1, ...]))
     for i in range(n_test)]
P = make_Chol_prec_from_bcoo(jsparse.bcoo_concatenate(L, dimension=0))

print('Precs are combined:')
print(f' GNN prec construction time (sec) : mean = {time_gnn_mean:.3e}, std = {time_gnn_std:.3e}.\n')

# CG with PreCorrector's prec
# NOTE(review): a previous run of this cell ended with
# "LinAlgError: A is singular: zero entry on diagonal." This presumably comes from a
# triangular solve with a learned factor that has a zero diagonal entry for at least
# one test system; check the diagonal of L before trusting these statistics.
cg_func = single_lhs_cg(batched_cg_scipy, single_lhs=n_test == 1)
iters_stats, time_stats, nan_flag = cg_func(A=A_test, b=b_test, pre_time=time_gnn_mean, x0='random',
                                            key=key, P=P, atol=cg_atol,
                                            maxiter=cg_maxiter, thresholds=[1e-3, 1e-6, 1e-9, 1e-12])
# Interpolate the values directly; the previous '%s' placeholders inside f-strings were
# never substituted and printed literally.
print('CG with GNN is finished:')
print(f' iterations to atol([mean, std]): {iters_stats};')
print(f' time to atol([mean, std]): {time_stats};')
print(f' number of linsystems for which CG did not converge to atol: {nan_flag}.\n')

LinAlgError: A is singular: zero entry on diagonal.

In [None]:
# Re-display the CG statistics computed in the evaluation cell. The values are
# interpolated into the f-strings; the previous '%s' placeholders were never
# substituted and printed literally, followed by the raw value.
print('CG with GNN is finished:')
print(f' iterations to atol([mean, std]): {iters_stats};')
print(f' time to atol([mean, std]): {time_stats};')
print(f' number of linsystems for which CG did not converge to atol: {nan_flag}.\n')