In [3]:
# Runtime configuration. Set the environment variables before JAX is imported
# (next cell) so they are in place when JAX initializes its backend.
import os
import sys
import warnings

# Silence Python warnings to keep notebook output readable.
warnings.filterwarnings('ignore')

os.environ.update({
    # Hide all CUDA devices so JAX runs on the CPU.
    "CUDA_VISIBLE_DEVICES": '',
    # Fraction of device memory XLA may preallocate (only matters if a GPU is visible).
    'XLA_PYTHON_CLIENT_MEM_FRACTION': '.95',
})

# Project root, so the project modules imported in the next cell can be found.
sys.path.append('/mnt/local/data/vtrifonov/prec-learning-Notay-loss/')

In [4]:
import jax.numpy as jnp
from jax import random, vmap
from jax.experimental import sparse as jsparse
from jax.scipy.sparse.linalg import cg as jcg

import numpy as np
from scipy.sparse.linalg import LinearOperator, cg
from scipy.sparse import tril, triu, eye

import matplotlib.pyplot as plt
from functools import partial
from time import perf_counter
import ilupp

from linsolve.precond import llt_prec_trig_solve, llt_inv_prec
from utils import factorsILUp, batchedILUp, ILU2, jILU2, iter_per_residual, jBCOO_to_scipyCSR
from linsolve.cg import ConjGrad
from data.qtt import poisson, div_k_grad, scipy_validation, solve_precChol, solve_precLU

## Datasets: Poisson and div-k-grad

In [5]:
# Dataset configuration: output location and generator sizes.
path = '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/'
save_dir = os.path.join(path, 'paper_datasets')

# Create the output directory (and any missing parents) if needed. Only
# "already exists" is tolerated; other failures such as permission errors
# surface here instead of later, inside the `savez` calls.
os.makedirs(save_dir, exist_ok=True)

grid_ls = [32, 64, 128]  # grid sizes passed to the generators; also used in file names
N_train = 1000           # sample count passed to the generators for the train split
N_test = 200             # sample count passed to the generators for the test split

In [7]:
# Poisson

# Poisson datasets: for every grid size, draw a train and a test set with
# `poisson(n, g)` and write each split to `<save_dir>/poisson{g}_{split}.npz`
# (the matrix's `data`/`indices` as Aval/Aind, plus `b` and `x`).
for g in grid_ls:
    A_train, b_train, x_train = poisson(N_train, g)
    A_test, b_test, x_test = poisson(N_test, g)
    poisson_splits = {
        'train': (A_train, b_train, x_train),
        'test': (A_test, b_test, x_test),
    }
    for split_name, (A_s, b_s, x_s) in poisson_splits.items():
        jnp.savez(os.path.join(save_dir, f'poisson{g}_{split_name}.npz'),
                  Aval=A_s.data, Aind=A_s.indices, b=b_s, x=x_s)

In [7]:
# Div-k-grad

# Div-k-grad datasets: for every grid size and every variance of the Gaussian
# covariance model, draw train/test sets with `div_k_grad` and write each split to
# `<save_dir>/div_k_grad{g}_Gauss{var}_{split}.npz`, including the 'min'/'max'/'mean'
# entries of the statistics dict that `div_k_grad` returns as its 4th output.
gauss_variances = (0.1, 0.5, 1.0, 1.5)
for g in grid_ls:
    for var_i in gauss_variances:
        A_train, b_train, x_train, k_stats_train = div_k_grad(N_train, g, cov_model='gaussian', var=var_i)
        A_test, b_test, x_test, k_stats_test = div_k_grad(N_test, g, cov_model='gaussian', var=var_i)
        file_stem = os.path.join(save_dir, f'div_k_grad{g}_Gauss{var_i}')
        splits = (
            ('train', A_train, b_train, x_train, k_stats_train),
            ('test', A_test, b_test, x_test, k_stats_test),
        )
        for split_name, A_s, b_s, x_s, k_s in splits:
            jnp.savez(f'{file_stem}_{split_name}.npz',
                      Aval=A_s.data, Aind=A_s.indices, b=b_s, x=x_s,
                      kmin=k_s['min'], kmax=k_s['max'], kmean=k_s['mean'])

In [55]:
# Print the k statistics (kmin / kmax / kmean) stored in every div-k-grad dataset,
# grouped by grid size and in sorted file order. Expected file names follow
# `div_k_grad{g}_Gauss{var}_{train|test}.npz`, as written by the generation cell.
ls32, ls64, ls128 = [], [], []
files_by_grid = {32: ls32, 64: ls64, 128: ls128}

for f in os.listdir(save_dir):
    full_path = os.path.join(save_dir, f)
    # `f` is a bare file name: check it inside save_dir, not relative to the CWD.
    if os.path.isdir(full_path) or not f.endswith('.npz'):
        continue
    for g, files in files_by_grid.items():
        # Match the whole `div_k_grad{g}_` prefix so grid sizes cannot collide.
        if f.startswith(f'div_k_grad{g}_'):
            files.append(full_path)
            break

for files in files_by_grid.values():
    files.sort()

for i, (g, files) in enumerate(files_by_grid.items()):
    if i:
        print()  # blank line between grid sizes
    for d in files:
        # Parse split and variance from the file name only (never from the directory
        # part of the path, which could itself contain 'train' or 'test').
        stem = os.path.basename(d)[:-len('.npz')]
        rest, _, split = stem.rpartition('_')
        if split not in ('train', 'test'):
            continue
        var = rest.partition('Gauss')[2]
        # Use a context manager so the .npz archive is closed after reading.
        with np.load(d) as a:
            kmin, kmax, kmean = (float(a[key]) for key in ('kmin', 'kmax', 'kmean'))
        label = f'{split},'  # padded to 6 chars so 'test,' lines align with 'train,'
        print(f"Grid {g}, {label:<6} var {var}. Min k: {kmin:.0f}, max k: {kmax:.0f}, mean k: {kmean:.0f}")

Grid 32, test,  var 0.1. Min k: 5, max k: 11, mean k: 7
Grid 32, train, var 0.1. Min k: 5, max k: 13, mean k: 7
Grid 32, test,  var 0.5. Min k: 35, max k: 234, mean k: 91
Grid 32, train, var 0.5. Min k: 32, max k: 474, mean k: 91
Grid 32, test,  var 1.0. Min k: 136, max k: 2097, mean k: 565
Grid 32, train, var 1.0. Min k: 130, max k: 4241, mean k: 607
Grid 32, test,  var 1.5. Min k: 520, max k: 16452, mean k: 2755
Grid 32, train, var 1.5. Min k: 361, max k: 25140, mean k: 2620

Grid 64, test,  var 0.1. Min k: 5, max k: 12, mean k: 8
Grid 64, train, var 0.1. Min k: 5, max k: 13, mean k: 8
Grid 64, test,  var 0.5. Min k: 46, max k: 555, mean k: 117
Grid 64, train, var 0.5. Min k: 40, max k: 373, mean k: 111
Grid 64, test,  var 1.0. Min k: 256, max k: 3832, mean k: 837
Grid 64, train, var 1.0. Min k: 183, max k: 6276, mean k: 816
Grid 64, test,  var 1.5. Min k: 552, max k: 26604, mean k: 3930
Grid 64, train, var 1.5. Min k: 566, max k: 34226, mean k: 4009

Grid 128, test,  var 0.1. Min k:

In [46]:
# Inspect the grid-32 div-k-grad file list built by the k-statistics cell.
# NOTE(review): the saved output below is stale. This cell ran (In [46]) before the
# current run of that cell (In [55]), which sorts the list, so the output shows
# unsorted order. Re-run to refresh.
ls32

['/mnt/local/data/vtrifonov/prec-learning-Notay-loss/paper_datasets/div_k_grad32_Gauss0.5_test.npz',
 '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/paper_datasets/div_k_grad32_Gauss0.5_train.npz',
 '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/paper_datasets/div_k_grad32_Gauss0.1_train.npz',
 '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/paper_datasets/div_k_grad32_Gauss1.0_test.npz',
 '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/paper_datasets/div_k_grad32_Gauss1.0_train.npz',
 '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/paper_datasets/div_k_grad32_Gauss1.5_train.npz',
 '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/paper_datasets/div_k_grad32_Gauss1.5_test.npz',
 '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/paper_datasets/div_k_grad32_Gauss0.1_test.npz']