In [1]:
import sys
import os

# Backend-selection variables must be in the environment BEFORE jax is imported.
# JAX reads JAX_PLATFORMS into its config when the module is imported, and the
# CUDA / XLA client settings are consumed when the backend is first created.
# Assigning them after `import jax` is therefore not guaranteed to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = ''
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'

import jax

# Keep 32-bit precision (no float64 / int64 arrays).
jax.config.update("jax_enable_x64", False)

# Project root, so project modules (data, config, train, ...) are importable.
# Guard against appending a duplicate entry every time the cell is re-run.
PROJECT_DIR = '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/'
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

In [3]:
import numpy as np
import jax.numpy as jnp
# Sanity check: allocating a tiny array forces JAX to initialise its backend.
jnp.ones((1,))

Array([1.], dtype=float32)

In [4]:
# import parafields
# import numpy as np
# import matplotlib.pyplot as plt

# grid = 400
# cov_model = 'gaussian'
# var = .5

# bounds = [90, 280]

# ls_contr = []
# while len(ls_contr) != 500:
#     field = parafields.generate_field(cells=[grid+1, grid+1], covariance=cov_model, variance=var)
#     k = field.evaluate()
#     contrast = np.exp(k.max() - k.min())
#     if not bounds[0] <= contrast <= bounds[1]:
#         continue
#     ls_contr.append(contrast)

In [5]:
# print(np.round(np.min(ls_contr), 2), np.round(np.mean(ls_contr), 2), np.round(np.max(ls_contr), 2))
# plt.hist(ls_contr, bins=100);

In [6]:
import jax.numpy as jnp

from data.qtt import poisson, div_k_grad

In [7]:
path = '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/'
save_dir = os.path.join(path, 'paper_datasets')
# Create the dataset directory (and any missing parents) if needed.
# exist_ok=True tolerates only "already exists". The previous bare
# `except: pass` also hid real failures such as permission errors.
os.makedirs(save_dir, exist_ok=True)

# Number of generated samples per split.
N_train = 1000
N_test = 200

In [6]:
from time import perf_counter

# Quick timing run: build a 2-sample train set and a 2-sample test set on the
# 256 grid. The total elapsed time (seconds) ends up in `t_data`.
boundaries_ls = [[[80, 250]]]
var_ls = [0.5]

t_data = perf_counter()
for g, bound_ls in zip([256], boundaries_ls):
    for var_i, b_ in zip(var_ls, bound_ls):
        gen_kwargs = dict(bounds=b_, cov_model='gaussian', var=var_i)
        A_train, b_train, x_train, k_stats_train = div_k_grad(2, g, **gen_kwargs)
        A_test, b_test, x_test, k_stats_test = div_k_grad(2, g, **gen_kwargs)
t_data = perf_counter() - t_data

In [8]:
from copy import deepcopy
from functools import partial
from jax import random

from data.dataset import load_dataset
from data.graph_utils import spmatrix_to_graph
from config import default_precorrector_gnn_config
from scipy_linsolve import make_Chol_prec_from_bcoo, batched_cg_scipy, single_lhs_cg
from train import construction_time_with_gnn, train_inference_finetune, load_hp_and_model, make_neural_prec_model
from architecture.neural_preconditioner_design import PreCorrectorGNN, NaiveGNN, PreCorrectorMLP, PreCorrectorMLP_StaticDiag, PreCorrectorMultiblockGNN

In [9]:
# Dataset description consumed by load_dataset.
data_config = dict(
    data_dir=path,
    pde='div_k_grad',
    grid=256,
    variance=.5,
    lhs_type='l_ic0',
    N_samples_train=1000,
    N_samples_test=200,
    fill_factor=1,
    threshold=1e-4,
)
# Hard-coded reference time for the classical preconditioner.
# NOTE(review): reassigned again when the test set is loaded.
class_time_mean_test = 5.461 * 1e-3

# Trained PreCorrector-GNN checkpoint to evaluate.
name = 'us1puz'
model_path = os.path.join(path, 'results_cases/29.01_final_elliptic_grid_precor_gnn', name, f'{name}.eqx')

model_config = deepcopy(default_precorrector_gnn_config)
model_config['static_diag'] = False
model_config['mp']['aggregate_edges'] = 'max'

make_model = partial(make_neural_prec_model, model_type='precorrector_gnn')
# NOTE(review): this freshly initialised `model` (and the edited model_config)
# is replaced on the next line by load_hp_and_model. It is presumably redundant;
# confirm load_hp_and_model does not depend on it before removing.
model = make_model(random.PRNGKey(42), model_config)
model, model_config = load_hp_and_model(model_path, make_model)

In [10]:
test_set = load_dataset(data_config, return_train=False)
(A_test, A_pad_test, b_test, bi_edges_test,
 x_test, class_time_mean_test, class_time_std_test) = test_set

# NOTE(review): the classical-preconditioner time just loaded with the test set
# is overridden by this hard-coded value. Presumably it was re-measured on the
# current hardware; confirm.
class_time_mean_test = 1.8 * 1e-3

# Time GNN-based preconditioner construction on the first test system.
time_gnn_mean, time_gnn_std = construction_time_with_gnn(
    model,
    A_test[0, ...],
    A_pad_test[0, ...],
    b_test[0, ...],
    None,
    num_rounds=200,
    pre_time_ic=class_time_mean_test,
)

In [12]:
# Mean GNN construction time in compact scientific notation.
format(time_gnn_mean, '.1e')

'8.3e-02'

In [7]:
# Path of the train file for the current (g, var_i) loop values.
os.path.join(save_dir, f'div_k_grad{g}_Gauss{var_i}_train.npz')

'/mnt/local/data/vtrifonov/prec-learning-Notay-loss/paper_datasets/div_k_grad512_Gauss0.5_train.npz'

In [13]:
# Div-k-grad grid 400
# For each grid, generate N_train / N_test samples per (variance, contrast
# bounds) pair and save each split as an .npz file in save_dir.
grids = [400]
boundaries_ls = [
    [[90, 280]],
]
var_ls = [0.5]
# zip() silently drops unmatched entries, so fail loudly on a length mismatch.
assert len(grids) == len(boundaries_ls), 'need one bounds list per grid'

for g, bound_ls in zip(grids, boundaries_ls):
    assert len(var_ls) == len(bound_ls), f'grid {g}: need one bounds pair per variance'
    for var_i, b_ in zip(var_ls, bound_ls):
        A_train, b_train, x_train, k_stats_train = div_k_grad(N_train, g, bounds=b_, cov_model='gaussian', var=var_i)
        A_test, b_test, x_test, k_stats_test = div_k_grad(N_test, g, bounds=b_, cov_model='gaussian', var=var_i)

        # Both splits share one on-disk layout: matrix .data / .indices, rhs,
        # solution, and the k statistics. Train is written before test.
        splits = {
            'train': (A_train, b_train, x_train, k_stats_train),
            'test': (A_test, b_test, x_test, k_stats_test),
        }
        for split, (A_s, rhs_s, x_s, k_s) in splits.items():
            jnp.savez(os.path.join(save_dir, f'div_k_grad{g}_Gauss{var_i}_{split}.npz'),
                      Aval=A_s.data, Aind=A_s.indices, b=rhs_s, x=x_s,
                      kmin=k_s['min'], kmax=k_s['max'], kmean=k_s['mean'])

In [5]:
# Poisson: generate and save one train / test pair per grid size.
for g in [256]:
    A_train, b_train, x_train = poisson(N_train, g)
    A_test, b_test, x_test = poisson(N_test, g)
    jnp.savez(
        os.path.join(save_dir, f'poisson{g}_train.npz'),
        Aval=A_train.data, Aind=A_train.indices, b=b_train, x=x_train,
    )
    jnp.savez(
        os.path.join(save_dir, f'poisson{g}_test.npz'),
        Aval=A_test.data, Aind=A_test.indices, b=b_test, x=x_test,
    )

2025-01-21 12:02:15.984922: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:
# Div-k-grad grid 64, 128
# For each grid, generate N_train / N_test samples per (variance, contrast
# bounds) pair and save each split as an .npz file in save_dir.
grids = [64, 128]
boundaries_ls = [
    [[400, 1200]],
    [[400, 1200]],
]
var_ls = [1.1]
# zip() silently drops unmatched entries, so fail loudly on a length mismatch.
assert len(grids) == len(boundaries_ls), 'need one bounds list per grid'

for g, bound_ls in zip(grids, boundaries_ls):
    assert len(var_ls) == len(bound_ls), f'grid {g}: need one bounds pair per variance'
    for var_i, b_ in zip(var_ls, bound_ls):
        A_train, b_train, x_train, k_stats_train = div_k_grad(N_train, g, bounds=b_, cov_model='gaussian', var=var_i)
        A_test, b_test, x_test, k_stats_test = div_k_grad(N_test, g, bounds=b_, cov_model='gaussian', var=var_i)

        # Both splits share one on-disk layout: matrix .data / .indices, rhs,
        # solution, and the k statistics. Train is written before test.
        splits = {
            'train': (A_train, b_train, x_train, k_stats_train),
            'test': (A_test, b_test, x_test, k_stats_test),
        }
        for split, (A_s, rhs_s, x_s, k_s) in splits.items():
            jnp.savez(os.path.join(save_dir, f'div_k_grad{g}_Gauss{var_i}_{split}.npz'),
                      Aval=A_s.data, Aind=A_s.indices, b=rhs_s, x=x_s,
                      kmin=k_s['min'], kmax=k_s['max'], kmean=k_s['mean'])

In [7]:
# Div-k-grad grid 256
# For each grid, generate N_train / N_test samples per (variance, contrast
# bounds) pair and save each split as an .npz file in save_dir.
grids = [256]
boundaries_ls = [
    [[80, 250], [250, 900], [400, 1200]],
]
var_ls = [0.5, 0.7, 1.1]
# zip() silently drops unmatched entries, so fail loudly on a length mismatch.
assert len(grids) == len(boundaries_ls), 'need one bounds list per grid'

for g, bound_ls in zip(grids, boundaries_ls):
    assert len(var_ls) == len(bound_ls), f'grid {g}: need one bounds pair per variance'
    for var_i, b_ in zip(var_ls, bound_ls):
        A_train, b_train, x_train, k_stats_train = div_k_grad(N_train, g, bounds=b_, cov_model='gaussian', var=var_i)
        A_test, b_test, x_test, k_stats_test = div_k_grad(N_test, g, bounds=b_, cov_model='gaussian', var=var_i)

        # Both splits share one on-disk layout: matrix .data / .indices, rhs,
        # solution, and the k statistics. Train is written before test.
        splits = {
            'train': (A_train, b_train, x_train, k_stats_train),
            'test': (A_test, b_test, x_test, k_stats_test),
        }
        for split, (A_s, rhs_s, x_s, k_s) in splits.items():
            jnp.savez(os.path.join(save_dir, f'div_k_grad{g}_Gauss{var_i}_{split}.npz'),
                      Aval=A_s.data, Aind=A_s.indices, b=rhs_s, x=x_s,
                      kmin=k_s['min'], kmax=k_s['max'], kmean=k_s['mean'])

In [8]:
# Summarise the k statistics (min / mean / max) stored in every saved
# div_k_grad dataset, grouped by grid size. Files within a grid are listed
# in sorted order.
summary_grids = [32, 64, 128, 256]
grid_files = {grid: [] for grid in summary_grids}
for f in os.listdir(save_dir):
    # listdir yields bare names. Resolve against save_dir before the isdir test;
    # checking `f` alone looked the name up relative to the current directory.
    if os.path.isdir(os.path.join(save_dir, f)):
        continue
    for grid in summary_grids:
        # The trailing '_' keeps e.g. grid 32 from matching 'div_k_grad320_...'.
        if f'div_k_grad{grid}_' in f:
            grid_files[grid].append(os.path.join(save_dir, f))
            break
for files in grid_files.values():
    files.sort()
# Keep the per-grid lists available under their original names.
ls32, ls64, ls128, ls256 = grid_files[32], grid_files[64], grid_files[128], grid_files[256]

for i, grid in enumerate(summary_grids):
    if i:
        print()
    for d in grid_files[grid]:
        fname = os.path.basename(d)
        if 'train' in fname:
            split_label = 'train,'
        elif 'test' in fname:
            split_label = 'test, '
        else:
            continue
        # Variance is the text between 'Gauss' and the next '_' (e.g. '0.5').
        # A fixed [:3] slice would truncate values such as 0.05.
        var_str = fname.split('Gauss')[1].split('_')[0]
        # `d` already includes save_dir. The context manager closes the .npz handle.
        with jnp.load(d) as a:
            print(f"Grid {grid}, {split_label} var {var_str}. Min k: {a['kmin']:.0f}, mean k: {a['kmean']:.0f}, max k: {a['kmax']:.0f}")

Grid 32, test,  var 0.1. Min k: 5, mean k: 7, max k: 11
Grid 32, train, var 0.1. Min k: 5, mean k: 7, max k: 11
Grid 32, test,  var 0.5. Min k: 37, mean k: 86, max k: 172
Grid 32, train, var 0.5. Min k: 35, mean k: 86, max k: 178
Grid 32, test,  var 0.7. Min k: 180, mean k: 280, max k: 692
Grid 32, train, var 0.7. Min k: 180, mean k: 278, max k: 682

Grid 64, test,  var 0.1. Min k: 5, mean k: 8, max k: 12
Grid 64, train, var 0.1. Min k: 5, mean k: 8, max k: 12
Grid 64, test,  var 0.5. Min k: 46, mean k: 103, max k: 188
Grid 64, train, var 0.5. Min k: 46, mean k: 103, max k: 199
Grid 64, test,  var 0.7. Min k: 202, mean k: 308, max k: 702
Grid 64, train, var 0.7. Min k: 200, mean k: 317, max k: 742
Grid 64, test,  var 1.1. Min k: 404, mean k: 757, max k: 1182
Grid 64, train, var 1.1. Min k: 401, mean k: 768, max k: 1199

Grid 128, test,  var 0.1. Min k: 6, mean k: 8, max k: 13
Grid 128, train, var 0.1. Min k: 6, mean k: 8, max k: 14
Grid 128, test,  var 0.5. Min k: 52, mean k: 117, max 