In [2]:
import sys
import os
# Pin the process to a single GPU and set the fraction of its memory that the
# XLA client may preallocate. Both variables are written before the first JAX
# import further down, so they are in place when the backend starts.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": '3',
    'XLA_PYTHON_CLIENT_MEM_FRACTION': '.95',
})

# Extend the import search path with the project checkout
# (presumably where utils, config, experiments, ... live).
sys.path.append('/mnt/local/data/vtrifonov/prec-learning-Notay-loss/')

In [3]:
from copy import deepcopy

import optax
import numpy as np
from sklearn.model_selection import ParameterGrid

from utils import grid_script
from config import default_precorrector_gnn_config
from experiments.script_gnn_prec import script_gnn_prec

In [4]:
import os
import getpass
import logging
import traceback
from time import perf_counter

import pandas as pd
from jax import random, vmap, numpy as jnp

from data.dataset import load_dataset
from data.graph_utils import spmatrix_to_graph
from train import construction_time_with_gnn, train_inference_finetune
from scipy_linsolve import make_Chol_prec_from_bcoo, batched_cg_scipy, single_lhs_cg
from architecture.neural_preconditioner_design import PreCorrectorGNN, NaiveGNN, PreCorrectorMLP, PreCorrectorMLP_StaticDiag, PreCorrectorMultiblockGNN

In [5]:
# This notebook is intentionally hardcoded to the
# generalization experiment presented in the manuscript

In [6]:
# Identity and locations of the run evaluated in this notebook.
path = '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/'  # root for both data_dir and the result folders
folder_model = 'results_cases/29.01_final_elliptic_grid_precor_gnn'  # sub-folder holding the saved model
folder_log = 'results_cases/30.01_generalization_precor_gnn'  # sub-folder used for the log/loss file paths
name = '009nt9'  # run id: names the per-run directory and the .eqx/.log/.npz files
seed = 42  # seed for the JAX PRNG key

In [7]:
# Dataset selection handed to load_dataset. The grid/variance given here are
# baseline values; the evaluation cells further down pick their own per run.
# NaN marks fields that are not set for this experiment
# (presumably ignored for test-only IC(0) loading - confirm in load_dataset).
data_config = dict(
    data_dir=path,
    pde='div_k_grad',
    grid=128,
    variance=.7,
    lhs_type='l_ic0',
    N_samples_train=np.nan,
    N_samples_test=200,
    fill_factor=np.nan,
    threshold=np.nan,
)

# Training settings. Only the model type is filled in; every optimisation
# field is a NaN placeholder (the run below uses model_use='inference',
# so presumably none of them is read - confirm in train_inference_finetune).
train_config = dict(
    model_type='precorrector_gnn',
    loss_type=np.nan,
    batch_size=np.nan,
    optimizer=np.nan,
    lr=np.nan,
    optim_params=np.nan,
    epoch_num=np.nan,
)

# Top-level experiment description: locations, run mode, CG solver limits and
# the nested data/train configs. 'model_config' starts as a NaN placeholder
# and is assigned once the model configuration has been built.
config = dict(
    path=path,
    folder_model=folder_model,
    folder_log=folder_log,
    name=name,
    model_use='inference',
    save_model=False,
    cg_maxiter=700,
    cg_atol=1e-12,
    data_config=data_config,
    model_config=np.nan,
    train_config=train_config,
    seed=seed,
)

In [8]:
# Start from a private copy of the project defaults (the shared default object
# is never modified) and register it in the experiment config straight away;
# both names refer to the same object, so the overrides below show up in both.
config['model_config'] = model_config = deepcopy(default_precorrector_gnn_config)

# Overrides for this run: 'max' edge aggregation in the 'mp' section
# (presumably message passing - confirm) and static_diag turned off.
model_config['mp']['aggregate_edges'] = 'max'
model_config['static_diag'] = False

In [9]:
# Resolve the per-run file locations from the experiment config.
run_name = config['name']
project_root = config['path']

# Saved model: <path>/<folder_model>/<name>/<name>.eqx
base_dir = os.path.join(project_root, config['folder_model'], run_name)
model_file = os.path.join(base_dir, f'{run_name}.eqx')

# Log and loss files: <path>/<folder_log>/<name>/{<name>.log, losses_<name>.npz}
log_dir = os.path.join(project_root, config['folder_log'], run_name)
log_file = os.path.join(log_dir, f'{run_name}.log')
loss_file = os.path.join(log_dir, f'losses_{run_name}.npz')

# PRNG key built from the configured seed.
key = random.PRNGKey(config['seed'])

2025-01-30 21:51:03.884076: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [9]:
# Result table: grid size -> coefficient variance -> result slot.
# Every slot starts as the sentinel -42 and is overwritten by the evaluation
# loop. Grid 128 has no 0.7 entry (128 / 0.7 is the baseline in data_config,
# so presumably the training setting - confirm).
NOT_EVALUATED = -42
variances_by_grid = {
    64: (0.5, 0.7, 1.1),
    128: (0.5, 1.1),
    256: (0.5, 0.7, 1.1),
}
iters_dict = {
    grid_size: {variance: NOT_EVALUATED for variance in variances}
    for grid_size, variances in variances_by_grid.items()
}

# Obtain the model via train_inference_finetune, pointing it at the saved
# .eqx file. With the config above this runs with model_use='inference' and
# save_model=False. The second positional argument is None (presumably the
# training data, which inference does not need - confirm). Only the first of
# the three returned values is kept.
model_source = dict(
    model_path=model_file,
    model_use=config['model_use'],
    save_model=config['save_model'],
)
model, _, _ = train_inference_finetune(key, None, model_config, train_config, **model_source)

# Evaluate the loaded model on every (grid, variance) pair listed in
# iters_dict and replace each placeholder with [iters_stats, nan_flag].
for grid, stats_by_variance in iters_dict.items():
    for var in stats_by_variance:

        # Work on a per-run copy instead of writing grid/variance into the
        # shared data_config: the global config is left exactly as defined,
        # so the last (grid, variance) of this loop no longer leaks into
        # whatever cell runs next.
        run_data_config = {**data_config, 'grid': grid, 'variance': var}

        test_set = load_dataset(run_data_config, return_train=False)
        A_test, A_pad_test, b_test, bi_edges_test, x_test, class_time_mean_test, class_time_std_test = test_set

        # Apply the model to each test sample (vmap over the leading axis) and
        # turn its output into a preconditioner
        # (presumably L is a Cholesky-type factor per sample - confirm).
        L = vmap(model, in_axes=(0), out_axes=(0))(spmatrix_to_graph(A_pad_test, b_test))
        P = make_Chol_prec_from_bcoo(L)

        # single_lhs is switched on when the test set holds exactly one matrix.
        cg_func = single_lhs_cg(batched_cg_scipy, single_lhs=A_test.shape[0] == 1)
        iters_stats, _, nan_flag = cg_func(A=A_test, b=b_test, pre_time=0, x0='random',
                                           key=key, P=P, atol=config['cg_atol'],
                                           maxiter=config['cg_maxiter'], thresholds=[1e-3, 1e-6, 1e-9, 1e-12])

        # Assigning to an existing key does not resize the dict, so this is
        # safe while iterating over it.
        stats_by_variance[var] = [iters_stats, nan_flag]

In [10]:
# NOTE(review): scratch expression with no effect on the analysis (presumably a
# kernel liveness check after the long evaluation cell); this cell and its
# output can be deleted before sharing the notebook.
2+3

5

In [12]:
# Display the collected results: grid -> variance -> [iters_stats, nan_flag],
# as filled in by the evaluation loop.
iters_dict

{64: {0.5: [{0.001: [48.1, 1.71],
    1e-06: [62.5, 2.12],
    1e-09: [76.7, 2.62],
    1e-12: [90.7, 3.07]},
   {0.001: 0, 1e-06: 0, 1e-09: 0, 1e-12: 0}],
  0.7: [{0.001: [52.3, 2.17],
    1e-06: [67.6, 2.83],
    1e-09: [82.6, 3.4],
    1e-12: [97.3, 4.03]},
   {0.001: 0, 1e-06: 0, 1e-09: 0, 1e-12: 0}],
  1.1: [{0.001: [59.9, 3.31],
    1e-06: [76.8, 3.95],
    1e-09: [93.1, 4.65],
    1e-12: [108.8, 5.4]},
   {0.001: 0, 1e-06: 0, 1e-09: 0, 1e-12: 0}]},
 128: {0.5: [{0.001: [68.0, 2.93],
    1e-06: [85.9, 3.49],
    1e-09: [102.7, 4.02],
    1e-12: [119.1, 4.57]},
   {0.001: 0, 1e-06: 0, 1e-09: 0, 1e-12: 0}],
  1.1: [{0.001: [91.6, 6.23],
    1e-06: [114.8, 7.78],
    1e-09: [136.2, 8.8],
    1e-12: [156.3, 9.75]},
   {0.001: 0, 1e-06: 0, 1e-09: 0, 1e-12: 0}]},
 256: {0.5: [{0.001: [117.1, 6.08],
    1e-06: [145.3, 7.21],
    1e-09: [170.5, 8.21],
    1e-12: [193.7, 9.18]},
   {0.001: 0, 1e-06: 0, 1e-09: 0, 1e-12: 0}],
  0.7: [{0.001: [139.9, 11.15],
    1e-06: [173.6, 13.51],
    1e

In [11]:
# Grid sizes for the sparsity report: powers of two from 2**5 = 32 to 2**8 = 256.
grid_ls = [2 ** exponent for exponent in range(5, 9)]

In [14]:
def density_percent(mat):
    """Return the stored-entry count of ``mat`` as a percentage of ``shape[0] ** 2``.

    ``mat`` must expose ``.nse`` and ``.shape``.
    """
    return mat.nse*100/(mat.shape[0]**2)


def report_sparsity(label, batch):
    """Print ``label`` followed by shape and density of the first item of ``batch``.

    The value printed after 'nnz' is a percentage of the full square matrix,
    not a raw count; the label text is kept so existing outputs stay comparable.
    """
    sample = batch[0, ...]
    print(label)
    print('  shape', sample.shape, 'nnz', density_percent(sample))


# One entry per dataset to load: (printed label, data_config overrides).
# Every entry spells out all three fields it relies on, so no load depends on
# values left behind by a previous load or a previous grid iteration.
LHS_VARIANTS = (
    (' IC(0)', {'lhs_type': 'l_ic0', 'threshold': np.nan, 'fill_factor': np.nan}),
    (' ICt(1)', {'lhs_type': 'l_ict', 'threshold': 1e-4, 'fill_factor': 1}),
    (' ICt(5)', {'lhs_type': 'l_ict', 'threshold': 1e-4, 'fill_factor': 5}),
)

for grid in grid_ls:
    print('grid =', grid)

    for position, (label, overrides) in enumerate(LHS_VARIANTS):
        # Per-load copy: the shared data_config is never modified here.
        run_data_config = {**data_config, 'grid': grid, 'variance': .5, **overrides}
        test_set = load_dataset(run_data_config, return_train=False)
        A_test, A_pad_test, b_test, bi_edges_test, x_test, class_time_mean_test, class_time_std_test = test_set

        # The system matrix itself is reported once per grid, taken from the
        # first (IC(0)) load.
        if position == 0:
            report_sparsity(' A', A_test)

        report_sparsity(label, A_pad_test)

    print()

grid = 32
 A
  shape (1024, 1024) nnz 0.47607421875
 IC(0)
  shape (1024, 1024) nnz 0.286865234375
 ICt(1)
  shape (1024, 1024) nnz 0.37860870361328125
 ICt(5)
  shape (1024, 1024) nnz 0.7584571838378906

grid = 64
 A
  shape (4096, 4096) nnz 0.12054443359375
 IC(0)
  shape (4096, 4096) nnz 0.072479248046875
 ICt(1)
  shape (4096, 4096) nnz 0.09613633155822754
 ICt(5)
  shape (4096, 4096) nnz 0.19243359565734863

grid = 128
 A
  shape (16384, 16384) nnz 0.03032684326171875
 IC(0)
  shape (16384, 16384) nnz 0.018215179443359375
 ICt(1)
  shape (16384, 16384) nnz 0.024223700165748596
 ICt(5)
  shape (16384, 16384) nnz 0.048476457595825195

grid = 256
 A
  shape (65536, 65536) nnz 0.007605552673339844
 IC(0)
  shape (65536, 65536) nnz 0.004565715789794922
 ICt(1)
  shape (65536, 65536) nnz 0.006079697050154209
 ICt(5)
  shape (65536, 65536) nnz 0.012163445353507996

