In [1]:
import warnings
import sys
import os

warnings.filterwarnings('ignore')
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'
sys.path.append('/mnt/local/data/vtrifonov/prec-learning-Notay-loss/')

In [9]:
import jax.numpy as jnp
from jax import random, vmap, clear_caches, jit
import numpy as np
import jax 

import optax
import equinox as eqx
from equinox.nn import Conv1d
import matplotlib.pyplot as plt
from functools import partial
from time import perf_counter
import cloudpickle

from data.dataset import dataset_qtt
from linsolve.cg import ConjGrad
from linsolve.precond import llt_prec_trig_solve, llt_inv_prec
from model import MessagePassing, FullyConnectedNet, PrecNet, ConstantConv1d, MessagePassingWithDot, CorrectionNet

from utils import params_count, asses_cond, iter_per_residual, batch_indices
from data.graph_utils import direc_graph_from_linear_system_sparse
from train import train

plt.rcParams['figure.figsize'] = (11, 7)

# Train/retrain/overwrite

In [10]:
save_path = '/mnt/local/data/vtrifonov/prec-learning-Notay-loss/trained_models/'
model_name = 'check27.11_div_k_grad_64_0.5_ilu0_llt_loss_1000epoch'
train_from_scratch = True

In [31]:
class CorrectionNet_old(eqx.Module):
    '''L = L + alpha * GNN(L)
    Perseving diagonal as: diag(A) = diag(D) from A = LDL^T'''
    NodeEncoder: eqx.Module
    EdgeEncoder: eqx.Module
    MessagePass: eqx.Module
    EdgeDecoder: eqx.Module
    alpha: jax.Array

    def __init__(self, NodeEncoder, EdgeEncoder, MessagePass, EdgeDecoder, alpha):
        super(CorrectionNet_old, self).__init__()
        self.NodeEncoder = NodeEncoder
        self.EdgeEncoder = EdgeEncoder
        self.MessagePass = MessagePass
        self.EdgeDecoder = EdgeDecoder
        self.alpha = alpha
        return    
    
    def __call__(self, train_graph, bi_edges_indx):#, lhs_graph):
        nodes, edges_init, receivers, senders = train_graph
        print([a.shape for a in [nodes, edges_init, receivers, senders]])
        norm = jnp.abs(edges_init).max()
        edges = edges_init / norm
#         nodes = nodes / jnp.abs(nodes).max()
        
        nodes = self.NodeEncoder(nodes[None, ...])
        edges = self.EdgeEncoder(edges[None, ...])
        nodes, edges, receivers, senders = self.MessagePass(nodes, edges, receivers, senders)
        edges = bi_direc_edge_avg(edges, bi_edges_indx)
        edges = self.EdgeDecoder(edges)[0, ...]
        
        edges = edges * norm
        edges = edges_init + self.alpha * edges
        
        nodes, edges, receivers, senders = graph_tril(nodes, jnp.squeeze(edges), receivers, senders)
        low_tri = graph_to_low_tri_mat_sparse(nodes, edges, receivers, senders)
        return low_tri

In [32]:
class CorrectionNet_new(eqx.Module):
    '''L = L + alpha * GNN(L)
    Perseving diagonal as: diag(A) = diag(D) from A = LDL^T'''
    NodeEncoder: eqx.Module
    EdgeEncoder: eqx.Module
    MessagePass: eqx.Module
    EdgeDecoder: eqx.Module
    alpha: jax.Array

    def __init__(self, NodeEncoder, EdgeEncoder, MessagePass, EdgeDecoder, alpha):
        super(CorrectionNet_new, self).__init__()
        self.NodeEncoder = NodeEncoder
        self.EdgeEncoder = EdgeEncoder
        self.MessagePass = MessagePass
        self.EdgeDecoder = EdgeDecoder
        self.alpha = alpha
        return    
    
    def __call__(self, train_graph):#, lhs_graph):
        nodes, edges_init, receivers, senders = train_graph
        print([a.shape for a in [nodes, edges_init, receivers, senders]])
        norm = jnp.abs(edges_init).max()
        edges = edges_init / norm
#         nodes = nodes / jnp.abs(nodes).max()
        
        nodes = self.NodeEncoder(nodes[None, ...])
        edges = self.EdgeEncoder(edges[None, ...])
        nodes, edges, receivers, senders = self.MessagePass(nodes, edges, receivers, senders)
#         edges = bi_direc_edge_avg(edges, bi_edges_indx)
        edges = self.EdgeDecoder(edges)[0, ...]
        
        edges = edges * norm
        edges = edges_init + self.alpha * edges
        
        nodes, edges, receivers, senders = graph_tril(nodes, jnp.squeeze(edges), receivers, senders)
        low_tri = graph_to_low_tri_mat_sparse(nodes, edges, receivers, senders)
        return low_tri

# Setup experiment

In [13]:
pde = 'div_k_grad'      # 'poisson', 'div_k_grad'
grid = 64            # 32, 64, 128
variance = .5        # 0.1, 0.5, 1.0 1.5
lhs_type = 'l_ilu0'      # 'fd', 'ilu0', 'ilu1', 'ilu2', 'ict', 'l_ict', 'a_pow'
N_train = 1000
N_test = 200
precision = 'f64'

fill_factor = 1     # int
threshold = 1e-4     # float
power = 2            # int
N_valid_CG = 300     # Number of CG iterations for validation in the very end

In [14]:
with_cond = False
layer_ = Conv1d
# layer_ = ConstantConv1d         # 'ConstantConv1d' to make a "zero" NN initialization; 'Conv1d' to make a random initialization
alpha = jnp.array([0.])

loss_type = 'llt'               # 'llt', 'llt-res', 'inv-prec'

2024-11-27 07:14:12.459240: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [15]:
batch_size = 8
epoch_num = 1000
lr = 1e-3
schedule_params = None #[1700, 2001, 300, 1e-1]    # [start, stop, step, decay_size]

In [16]:
# if (loss_type in {'notay', 'llt-res', 'llt-res-norm'} and dataset == 'simple') or (loss_type in {'llt', 'llt-norm'} and dataset == 'krylov'):
#     raise ValueError('Not valid dataset for a chosen loss')

if schedule_params != None:
    assert len(schedule_params) == 4
    
    start, stop, step, decay_size = schedule_params
    steps_per_batch = N_train // batch_size
    start, stop, step = start*steps_per_batch, stop*steps_per_batch, step*steps_per_batch
    lr = optax.piecewise_constant_schedule(
        lr,
        {k: v for k, v in zip(np.arange(start, stop, step), [decay_size, ] * len(jnp.arange(start, stop, step)))}
    )

In [17]:
model_config = {
    'node_enc': {
        'features': [1, 16, 16],
        'N_layers': 2,
        'layer_': layer_
    },
    'edge_enc': {
        'features': [1, 16, 16],
        'N_layers': 2,
        'layer_': layer_
    },
    'edge_dec': {
        'features': [16, 16, 1],
        'N_layers': 2,
        'layer_': layer_
    },
    'mp': {
        'edge_upd': {
            'features': [48, 16, 16],
            'N_layers': 2,
            'layer_': layer_
        },
        'node_upd': {
            'features': [32, 16, 16],
            'N_layers': 2,
            'layer_': layer_
        },
        'mp_rounds': 5
    }
}

# Make dataset

In [18]:
s1 = perf_counter()
A_train, A_pad_train, b_train, u_exact_train, bi_edges_train = dataset_qtt(pde, grid, variance, lhs_type,
                                                                           return_train=True, N_samples=N_train,
                                                                           fill_factor=fill_factor, threshold=threshold,
                                                                           power=power, precision=precision)
A_test, A_pad_test, b_test, u_exact_test, bi_edges_test = dataset_qtt(pde, grid, variance, lhs_type,
                                                                      return_train=False, N_samples=N_test,
                                                                      fill_factor=fill_factor, threshold=threshold,
                                                                      power=power, precision=precision)
print(perf_counter() - s1)

23.958647273480892


# Train model

In [19]:
seed = 42
NodeEncoder = FullyConnectedNet(**model_config['node_enc'], key=random.PRNGKey(seed))
EdgeEncoder = FullyConnectedNet(**model_config['edge_enc'], key=random.PRNGKey(seed))
EdgeDecoder = FullyConnectedNet(**model_config['edge_dec'], key=random.PRNGKey(seed))

mp_rounds = 5
MessagePass = MessagePassing(
    update_edge_fn = FullyConnectedNet(**model_config['mp']['edge_upd'], key=random.PRNGKey(seed)),    
    update_node_fn = FullyConnectedNet(**model_config['mp']['node_upd'], key=random.PRNGKey(seed)),
    mp_rounds=model_config['mp']['mp_rounds']
)

# model = PrecNet(NodeEncoder=NodeEncoder, EdgeEncoder=EdgeEncoder, 
#                 EdgeDecoder=EdgeDecoder, MessagePass=MessagePass)

model = CorrectionNet(NodeEncoder=NodeEncoder, EdgeEncoder=EdgeEncoder, 
                EdgeDecoder=EdgeDecoder, MessagePass=MessagePass, alpha=alpha)

# w = jnp.zeros(A_pad_train[0, ...].nse)
# b = alpha
# model = ShiftNet(NodeEncoder=NodeEncoder, EdgeEncoder=EdgeEncoder, 
#                 EdgeDecoder=EdgeDecoder, MessagePass=MessagePass, w=w, b=b)
print(f'Parameter number: {params_count(model)}')

Parameter number: 2754


In [20]:
# data = (X_train, X_test, y_train, y_test)
data = (
    [A_train, A_pad_train, b_train, bi_edges_train, u_exact_train],
    [A_test, A_pad_test, b_test, bi_edges_test, u_exact_test],
    jnp.array([1]), jnp.array([1])
)
train_config = {
    'optimizer': optax.adam,
    'lr': lr,
    'optim_params': {},#{'weight_decay': 1e-8}, 
    'epoch_num': epoch_num,
    'batch_size': batch_size,
}

In [33]:
model_new = CorrectionNet_new(NodeEncoder=NodeEncoder, EdgeEncoder=EdgeEncoder, 
                EdgeDecoder=EdgeDecoder, MessagePass=MessagePass, alpha=alpha)
model_old = CorrectionNet_old(NodeEncoder=NodeEncoder, EdgeEncoder=EdgeEncoder, 
                EdgeDecoder=EdgeDecoder, MessagePass=MessagePass, alpha=alpha)
nodes, edges, receivers, senders, _ = direc_graph_from_linear_system_sparse(A_pad_test, b_test)


In [34]:
L = vmap(model_new, in_axes=(0), out_axes=(0))((nodes, edges, receivers, senders))#, bi_edges_test)#, (lhs_nodes, lhs_edges, lhs_receivers, lhs_senders))

[(4096,), (20224,), (20224,), (20224,)]


NameError: name 'graph_tril' is not defined

In [35]:
L = vmap(model_old, in_axes=((0, 0, 0, 0), 0), out_axes=(0))((nodes, edges, receivers, senders), bi_edges_test)#, (lhs_nodes, lhs_edges, lhs_receivers, lhs_senders))

[(4096,), (20224,), (20224,), (20224,)]


NameError: name 'bi_direc_edge_avg' is not defined

In [None]:
if train_from_scratch:
    s = perf_counter()
    model, losses = train(model, data, train_config, loss_name=loss_type, repeat_step=1, with_cond=with_cond)
    dt = perf_counter() - s
    
    with open(save_path + model_name + '.pkl', 'wb') as f:
        cloudpickle.dump(model, f)
else:
    with open(save_path + model_name + '.pkl', 'rb') as f:
        model = cloudpickle.load(f)
    losses, dt = [np.nan], np.nan

In [None]:
dt

In [None]:
print('alpha:', end='')
model.alpha

In [None]:
_, axes = plt.subplots(1, 2, figsize=(15, 5))

axes[0].plot(range(len(losses[0])), losses[1], label='Test')
axes[0].plot(range(len(losses[0])), losses[0], label='Train')
axes[0].legend()
axes[0].set_yscale('log')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss');
axes[0].grid();

axes[1].plot(range(len(losses[0])), losses[2], label='Test')
axes[1].legend()
axes[1].set_yscale('log')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Cond of $P^{-1}A$')
axes[1].grid();

plt.tight_layout()

print(f'Final values\n  train loss: {losses[0][-1]:.4f}\n   test loss: {losses[1][-1]:.4f}\n    LLT cond: {losses[2][-1]:.0f}')
print(f'\nMinimim test loss `{jnp.min(losses[1]).item():.4f}` at epoch `{jnp.argmin(losses[1]).item():.0f}`')
print(f'\nMinimim test P^(-1)A cond `{jnp.min(losses[2]).item():.0f}` at epoch `{jnp.argmin(losses[2]).item():.0f}`')

In [None]:
losses[1][500], losses[1][-1]

## Make precs

In [16]:
nodes, edges, receivers, senders, _ = direc_graph_from_linear_system_sparse(A_pad_test, b_test)
# lhs_nodes, lhs_edges, lhs_receivers, lhs_senders, _ = direc_graph_from_linear_system_sparse(A_test, b_test)

L = vmap(model, in_axes=(0), out_axes=(0))((nodes, edges, receivers, senders))#, bi_edges_test)#, (lhs_nodes, lhs_edges, lhs_receivers, lhs_senders))
# del model, data, A_train, A_pad_train, b_train, u_exact_train, bi_edges_train, bi_edges_test
# clear_caches()

# Apply model to CG

In [19]:
from linsolve.scipy_linsolve import batched_cg_scipy, make_Chol_prec_from_bcoo, cg_scipy
from utils import jBCOO_to_scipyCSR

import scipy.sparse.linalg as sci_sp_linalg
import scipy.linalg as sci_linalg

In [20]:
P = make_Chol_prec_from_bcoo(L)

In [22]:
_, iters_mean, iters_std, time_mean, time_std = batched_cg_scipy(A_test, b_test, P=P, atol=1e-12, maxiter=300)

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

In [None]:
print('iters')
display(iters_mean)
iters_std

### Classical prec

In [69]:
from linsolve.scipy_linsolve import make_Chol_prec

In [None]:
P_class = make_Chol_prec(L)

In [None]:
_, iters_mean, iters_std, time_mean, time_std = batched_cg_scipy(A_test, b_test, P=P_class, atol=1e-12, maxiter=400)

In [None]:
print('iters')
display(iters_mean)
iters_std