In [1]:
import warnings
import sys
import os

# Hide library warnings and pin JAX to the CPU. The environment variables are
# set here, before the JAX imports in the next cell, so the backend sees them.
warnings.filterwarnings('ignore')
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '',               # expose no GPUs -> CPU-only run
    'XLA_PYTHON_CLIENT_MEM_FRACTION': '.95',
})
# Presumably makes the project packages imported below (linsolve, utils, data) importable.
sys.path.append('/mnt/local/data/vtrifonov/prec-learning-Notay-loss/')

In [2]:
import jax.numpy as jnp
from jax import random, vmap
from jax.experimental import sparse as jsparse
from jax.scipy.sparse.linalg import cg as jcg

import numpy as np
from scipy.sparse.linalg import LinearOperator, cg
from scipy.sparse import tril, triu, eye

import matplotlib.pyplot as plt
from functools import partial
from time import perf_counter
import ilupp

from linsolve.precond import llt_prec_trig_solve, llt_inv_prec
from utils import factorsILUp, batchedILUp, ILU2, jILU2, iter_per_residual, jBCOO_to_scipyCSR
from linsolve.cg import ConjGrad
from data.qtt import poisson, div_k_grad, scipy_validation, solve_precChol, solve_precLU

## Dataset generation: Poisson and div-k-grad

In [3]:
# Root of the project checkout; generated datasets are written to <root>/paper_datasets.
path = '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/'
save_dir = os.path.join(path, 'paper_datasets')
# makedirs(exist_ok=True) is idempotent on re-runs and also creates missing parent
# folders. Any other failure is reported explicitly (warnings are silenced globally,
# so print to stderr) instead of being swallowed by a bare `except`.
try:
    os.makedirs(save_dir, exist_ok=True)
except OSError as err:
    print(f"Could not create dataset directory {save_dir!r}: {err}", file=sys.stderr)

grid_ls = [32, 64, 128, 256]   # grid sizes (points per side) to generate
N_train = 1000                 # samples per training dataset
N_test = 200                   # samples per test dataset

In [4]:
# Poisson: for every grid size, generate a train and a test split and store each
# as an .npz archive holding the matrix values/indices (Aval, Aind), b and x.
for grid in grid_ls:
    A_train, b_train, x_train = poisson(N_train, grid)
    A_test, b_test, x_test = poisson(N_test, grid)
    stem = os.path.join(save_dir, 'poisson' + str(grid))
    splits = {
        '_train.npz': (A_train, b_train, x_train),
        '_test.npz': (A_test, b_test, x_test),
    }
    for suffix, (A_split, b_split, x_split) in splits.items():
        jnp.savez(stem + suffix, Aval=A_split.data, Aind=A_split.indices, b=b_split, x=x_split)

In [5]:
# Div-k-grad: per grid size, three Gaussian-covariance datasets with increasing
# variance. Each variance has its own `bounds` pair; judging by the printed k
# statistics below, these appear to be the target [min, max] range of k.
boundaries_ls = [
    [[5, 11], [35, 180], [180, 700]],
    [[5, 12], [45, 200], [200, 750]],
    [[6, 14], [50, 300], [300, 800]],
    [[6, 15], [60, 300], [400, 900]],
]
var_ls = [0.1, 0.5, 0.7]

for grid, bounds_per_var in zip(grid_ls, boundaries_ls):
    for var_i, bounds in zip(var_ls, bounds_per_var):
        gen_kwargs = dict(bounds=bounds, cov_model='gaussian', var=var_i)
        A_train, b_train, x_train, k_stats_train = div_k_grad(N_train, grid, **gen_kwargs)
        A_test, b_test, x_test, k_stats_test = div_k_grad(N_test, grid, **gen_kwargs)
        stem = os.path.join(save_dir, 'div_k_grad' + str(grid) + '_Gauss' + str(var_i))
        splits = (
            ('_train.npz', A_train, b_train, x_train, k_stats_train),
            ('_test.npz', A_test, b_test, x_test, k_stats_test),
        )
        for suffix, A_split, b_split, x_split, k_split in splits:
            jnp.savez(stem + suffix,
                      Aval=A_split.data, Aind=A_split.indices, b=b_split, x=x_split,
                      kmin=k_split['min'], kmax=k_split['max'], kmean=k_split['mean'])

In [4]:
def collect_div_k_grad_files(directory, grid_sizes):
    """Group the div-k-grad archives in ``directory`` by grid size.

    Parameters
    ----------
    directory : str
        Folder holding the generated datasets.
    grid_sizes : iterable of int
        Grid sizes to look for.

    Returns
    -------
    dict[int, list[str]]
        For each grid size, the sorted full paths of non-directory entries whose
        name starts with ``'div_k_grad<grid>_'`` (the prefix used when saving).
    """
    files_by_grid = {grid: [] for grid in grid_sizes}
    for fname in os.listdir(directory):
        full_path = os.path.join(directory, fname)
        # Test the joined path: a bare `fname` would be resolved against the CWD.
        if os.path.isdir(full_path):
            continue
        for grid in files_by_grid:
            # Trailing '_' prevents e.g. grid 32 from matching 'div_k_grad320...'.
            if fname.startswith('div_k_grad' + str(grid) + '_'):
                files_by_grid[grid].append(full_path)
                break
    for paths in files_by_grid.values():
        paths.sort()
    return files_by_grid


def print_k_stats(paths, grid):
    """Print min/max/mean of k stored in each div-k-grad archive.

    The split and the variance are parsed from the file *name* (not the full
    path), which is expected to look like ``div_k_grad<grid>_Gauss<var>_<split>.npz``.
    Only three characters after 'Gauss' are shown, which fits the one-decimal
    variances used in this notebook (e.g. '0.1').
    """
    for path in paths:
        fname = os.path.basename(path)
        if 'train' in fname:
            split_label = 'train,'
        elif 'test' in fname:
            split_label = 'test, '  # extra space aligns the columns with 'train,'
        else:
            continue
        var = fname.split('Gauss')[1][:3]
        # The context manager closes the zip handle behind the NpzFile.
        with np.load(path) as archive:
            kmin, kmax, kmean = archive['kmin'], archive['kmax'], archive['kmean']
        print(f"Grid {grid}, {split_label} var {var}. Min k: {kmin:.0f}, max k: {kmax:.0f}, mean k: {kmean:.0f}")


_SUMMARY_GRIDS = (32, 64, 128, 256)
div_k_grad_files = collect_div_k_grad_files(save_dir, _SUMMARY_GRIDS)
# Keep the per-grid lists under their original names for any later cells.
ls32, ls64, ls128, ls256 = (div_k_grad_files[grid] for grid in _SUMMARY_GRIDS)

for group_idx, grid in enumerate(_SUMMARY_GRIDS):
    if group_idx:
        print()  # blank line between grid groups
    print_k_stats(div_k_grad_files[grid], grid)

Grid 32, test,  var 0.1. Min k: 5, max k: 11, mean k: 7
Grid 32, train, var 0.1. Min k: 5, max k: 11, mean k: 7
Grid 32, test,  var 0.5. Min k: 38, max k: 179, mean k: 86
Grid 32, train, var 0.5. Min k: 36, max k: 179, mean k: 86
Grid 32, test,  var 0.7. Min k: 180, max k: 672, mean k: 278
Grid 32, train, var 0.7. Min k: 180, max k: 697, mean k: 277

Grid 64, test,  var 0.1. Min k: 5, max k: 12, mean k: 8
Grid 64, train, var 0.1. Min k: 5, max k: 12, mean k: 8
Grid 64, test,  var 0.5. Min k: 45, max k: 198, mean k: 104
Grid 64, train, var 0.5. Min k: 45, max k: 200, mean k: 103
Grid 64, test,  var 0.7. Min k: 201, max k: 722, mean k: 316
Grid 64, train, var 0.7. Min k: 200, max k: 742, mean k: 318

Grid 128, test,  var 0.1. Min k: 6, max k: 14, mean k: 8
Grid 128, train, var 0.1. Min k: 6, max k: 14, mean k: 8
Grid 128, test,  var 0.5. Min k: 53, max k: 270, mean k: 115
Grid 128, train, var 0.5. Min k: 50, max k: 297, mean k: 116
Grid 128, test,  var 0.7. Min k: 300, max k: 795, mean k

In [7]:
def make_BCOO(Aval, Aind, N_points):
    """Assemble a sparse BCOO matrix for an ``N_points x N_points`` grid.

    Parameters
    ----------
    Aval : array, shape (..., nnz)
        Non-zero values. Any leading axes are treated as batch axes.
    Aind : array, shape (..., nnz, 2)
        (row, col) indices of the non-zeros, with the same batch axes as ``Aval``.
    N_points : int
        Grid points per side; each matrix is ``N_points**2 x N_points**2``.

    Returns
    -------
    jsparse.BCOO
        Shape ``(N_points**2, N_points**2)`` for unbatched input, or
        ``batch_shape + (N_points**2, N_points**2)`` when ``Aind`` carries batch
        axes (e.g. the stacked samples stored in the saved .npz archives).
    """
    n = N_points ** 2
    # BCOO treats every axis of `Aind` before the trailing (nnz, 2) as a batch
    # axis, so the declared shape must include those axes.
    batch_shape = tuple(jnp.shape(Aind)[:-2])
    return jsparse.BCOO((Aval, Aind), shape=batch_shape + (n, n))

In [8]:
# Inspect one saved archive (grid 32, variance 0.1, test split). Build the path
# from the configured `save_dir` instead of repeating the absolute location.
# The NpzFile stays open because the following cells read from `a`.
a = jnp.load(os.path.join(save_dir, 'div_k_grad32_Gauss0.1_test.npz'))

In [9]:
# Names of the arrays stored in the archive.
[*a.keys()]

['Aval', 'Aind', 'b', 'x', 'kmin', 'kmax', 'kmean']

In [10]:
# Matrix values and (row, col) indices; the leading axis presumably indexes the
# N_test samples (200, per the printed shapes).
A_val, A_ind = a['Aval'], a['Aind']
(A_val.shape, A_ind.shape)

((200, 4992), (200, 4992, 2))