In [None]:
import numpy as np
import probnum as pn
import linpde_gp

# Default linear solver for GP conditioning: a dense Cholesky factorisation.
# NOTE(review): presumably used by `condition_on_observations` whenever no
# explicit `solver=` is passed (the experiments below pass one for CG runs).
pn.config.default_solver_linpde_gp = linpde_gp.solvers.CholeskySolver(dense=True)

In [None]:
from dotenv import load_dotenv
import os

# Load environment variables (e.g. FIGURES_DIR) from a .env file via python-dotenv.
load_dotenv()

# Output directory for the final figure. Fall back to the current working
# directory when FIGURES_DIR is unset or empty; otherwise the save path at the
# end of the notebook becomes "None/..." (or "/...") and fails only after all
# experiments have finished.
figures_path = os.environ.get("FIGURES_DIR") or "."
os.makedirs(figures_path, exist_ok=True)

# General solution theory

In [None]:
# Set seed
# Fixes NumPy's global RNG so the coefficient samples drawn below are reproducible.
np.random.seed(2349032509)

# Disable verbose output
# Silences PyKeOps logging. NOTE(review): confirm that this attribute assignment
# still has an effect with the installed PyKeOps version.
import pykeops
pykeops.verbose = False

In [None]:
# Final time and the side lengths of the rectangular spatial domain.
T = 2.
a = b = 1
# Coefficient `c` handed to the wave-equation problem (presumably the wave speed).
c = 1

# Space-time box [0, T] x [0, a] x [0, b] (coordinate order: t, x, y) and its
# purely spatial cross-section [0, a] x [0, b].
domain = linpde_gp.domains.Box([[0, T], [0, a], [0, b]])
spatial_domain = linpde_gp.domains.Box([[0., a], [0., b]])

In [None]:
N_samples = 10
# Elementwise mean of the 2x2 initial-value sine-series coefficients.
mean = np.array([
    [1.0, -0.5],
    [-0.5, 0.],
    ])
# Elementwise *standard deviation* of those coefficients. `np.random.normal`
# interprets its second argument as an elementwise scale (std), not a covariance
# matrix; read as a covariance this matrix would not even be PSD
# (0.1 * 0.3 - 0.2**2 < 0). The name `cov` is therefore kept only as an alias.
coefficient_std = np.array([
    [0.1, 0.2],
    [0.2, 0.3],
    ])
cov = coefficient_std  # backward-compatible alias for the old, misleading name
# Independent draws, one 2x2 coefficient matrix per sample: shape (N_samples, 2, 2).
coefficients = np.random.normal(
    loc=mean, scale=coefficient_std, size=(N_samples,) + mean.shape
)
coefficients[0]

In [None]:
from linpde_gp.problems.pde import WaveEquationDirichletProblem

def get_problem(coefficients):
    """Build the wave-equation Dirichlet problem for one set of initial-value coefficients.

    Parameters
    ----------
    coefficients
        Coefficients of the ``TruncatedSineSeries`` on ``spatial_domain`` used as
        initial values (the samples drawn above are 2x2 arrays).

    Returns
    -------
    WaveEquationDirichletProblem
        Problem on the time interval ``[0, T]`` over ``spatial_domain`` with
        coefficient ``c``.
    """
    initial_values = linpde_gp.functions.TruncatedSineSeries(
        spatial_domain,
        coefficients=coefficients,
    )
    return WaveEquationDirichletProblem(
        t0=0.,
        T=T,
        spatial_domain=spatial_domain,
        c=c,
        initial_values=initial_values,
    )

sample_problem = get_problem(coefficients[0])

In [None]:
from linpde_gp.benchmarking import SolutionErrorEstimator

# Error metric against `sample_problem.solution` over the full space-time domain;
# `error_fv` below calls it on the FV posterior to score candidate lengthscales.
error_estimator = SolutionErrorEstimator(sample_problem.solution, domain)

# Prior

In [None]:
# Prior hyperparameters. The base lengthscales (0.5 in time, a third of the side
# length in space) are halved. Observation counts (N_ic_xy, N_bc_t, N_bc_spatial)
# are set in the conditioning cell below.
lengthscale_t = 0.5 / 2
lengthscale_x = a / 3 / 2
lengthscale_y = b / 3 / 2
output_scale = 1.0

def get_prior(l_t, l_x, l_y, output_scale):
    """Zero-mean GP prior over u(t, x, y).

    The covariance is ``output_scale**2`` times a tensor product of
    one-dimensional Matérn (nu = 2.5) kernels.

    Parameters
    ----------
    l_t, l_x, l_y
        Lengthscales along the t, x and y input coordinates, respectively.
    output_scale
        Multiplies the kernel by ``output_scale**2``.
    """
    return pn.randprocs.GaussianProcess(
        mean=linpde_gp.functions.Zero(input_shape=(3,)),
        cov=output_scale**2
        * linpde_gp.randprocs.covfuncs.TensorProduct(
            linpde_gp.randprocs.covfuncs.Matern((), nu=2.5, lengthscales=l_t),
            linpde_gp.randprocs.covfuncs.Matern((), nu=2.5, lengthscales=l_x),
            linpde_gp.randprocs.covfuncs.Matern((), nu=2.5, lengthscales=l_y),
        ),
    )

u_prior = get_prior(lengthscale_t, lengthscale_x, lengthscale_y, output_scale)

In [None]:
from linpde_gp.linfuncops.diffops import TimeDerivative

# Observation counts: an N_ic_xy x N_ic_xy grid for the initial condition, plus
# the N_bc_t / N_bc_spatial sizes of the boundary grids used by `condition_bc`.
N_ic_xy = 25
N_bc_t = int(T * 20)
N_bc_spatial = 20

def condition_ic(prior):
    """Condition ``prior`` on the initial condition of ``sample_problem``.

    Both u and its time derivative are observed on an ``N_ic_xy x N_ic_xy``
    grid over ``sample_problem.initial_domain``.

    NOTE(review): the time-derivative observation reuses the initial *values*
    as its targets. ``get_problem`` only passes ``initial_values``, so confirm
    that du/dt(t0) = u(t0) is intended (rather than, e.g., a zero initial
    velocity) and agrees with ``sample_problem.solution``.
    """
    ic_points = sample_problem.initial_domain.uniform_grid((N_ic_xy, N_ic_xy))
    # The initial values are evaluated on the spatial coordinates only
    # (the leading t coordinate is dropped).
    ic_values = sample_problem.initial_condition.values(ic_points[..., 1:])

    posterior = prior.condition_on_observations(X=ic_points, Y=ic_values)
    time_derivative = TimeDerivative((3,))
    return posterior.condition_on_observations(
        X=ic_points, Y=ic_values, L=time_derivative
    )

u_ic = condition_ic(u_prior)

In [None]:
def condition_bc(prior):
    """Condition ``prior`` on u = 0 at points on the spatial boundary.

    Two space-time grids are used, each with ``N_bc_t`` points in time. The
    first has ``N_bc_spatial`` points along x and 2 along y; the second has 2
    along x and ``N_bc_spatial`` along y. Assumes ``uniform_grid`` places a
    2-point axis at the interval endpoints, i.e. on opposite faces.
    """
    posterior = prior
    for grid_shape in (
        (N_bc_t, N_bc_spatial, 2),  # presumably faces y = 0 and y = b
        (N_bc_t, 2, N_bc_spatial),  # presumably faces x = 0 and x = a
    ):
        boundary_points = domain.uniform_grid(grid_shape)
        posterior = posterior.condition_on_observations(
            X=boundary_points, Y=np.zeros(boundary_points.shape[:-1])
        )
    return posterior

u_ic_bc = condition_bc(u_ic)

In [None]:
# Candidate lengthscales for the grid search: 4 evenly spaced values from 0.1
# above each interval's lower bound up to its upper bound. The first grid is for
# the t axis; the x-interval grid is shared by x and y (see `get_u_ic_bc`).
grid_l_t = np.linspace(start=domain[0][0] + 0.1, stop=domain[0][1], num=4)
grid_l_xy = np.linspace(start=domain[1][0] + 0.1, stop=domain[1][1], num=4)

In [None]:
from gp_fvm.finite_volumes import get_grid_from_depth
from tqdm.notebook import tqdm

def get_u_ic_bc(lengthscale_t, lengthscale_xy):
    """Return the prior conditioned on the initial and boundary observations.

    ``lengthscale_xy`` is used for both spatial axes.
    """
    u_prior = get_prior(lengthscale_t, lengthscale_xy, lengthscale_xy, output_scale)
    u_ic = condition_ic(u_prior)
    return condition_bc(u_ic)

def error_fv(depth, lengthscale_t, lengthscale_xy):
    """Solution error after additionally observing the PDE through finite volumes.

    The PDE operator is observed via ``FiniteVolumeFunctional`` on
    ``get_grid_from_depth(domain, depth)`` with a zero right-hand side. Depths
    below 5 use a dense Cholesky solve; depth 5 and deeper use IterGP's CG solver.

    Returns
    -------
    The value of ``error_estimator`` on the resulting posterior.
    """
    u_ic_bc = get_u_ic_bc(lengthscale_t, lengthscale_xy)
    domains = get_grid_from_depth(domain, depth)
    fv = linpde_gp.linfunctls.FiniteVolumeFunctional(domains, sample_problem.pde.diffop)
    if depth >= 5:
        # Fine grids: switch from the dense factorisation to iterative CG.
        solver = linpde_gp.solvers.itergp.IterGP_CG_Solver(
            threshold=1e-2, max_iterations=1000, num_actions_compressed=1000
        )
    else:
        solver = linpde_gp.solvers.CholeskySolver(dense=True)
    u_fv = u_ic_bc.condition_on_observations(
        L=fv, Y=np.zeros(domains.shape), solver=solver, fresh_start=True
    )
    return error_estimator(u_fv)

def best_lengthscales(depth, grid_t=None, grid_xy=None):
    """Grid-search the (time, space) lengthscales that minimise ``error_fv`` at ``depth``.

    Parameters
    ----------
    depth
        Finite-volume refinement depth passed to ``error_fv``.
    grid_t, grid_xy
        Candidate time / spatial lengthscales. They default to the global
        ``grid_l_t`` / ``grid_l_xy``, which are read at call time.

    Returns
    -------
    tuple
        ``(l_t, l_xy)`` with the smallest error. On ties the first candidate is
        kept. Returns ``(None, None)`` if no candidate's error compares below
        ``inf`` (e.g. all NaN).
    """
    if grid_t is None:
        grid_t = grid_l_t
    if grid_xy is None:
        grid_xy = grid_l_xy

    min_error = np.inf
    l_t_min = None
    l_xy_min = None
    for l_t in grid_t:
        for l_xy in grid_xy:
            error = error_fv(depth, l_t, l_xy)
            if error < min_error:
                min_error, l_t_min, l_xy_min = error, l_t, l_xy
    return l_t_min, l_xy_min

# Expensive: recomputes the `lengthscales` checkpoint below. The experiment loop
# indexes it by depth 0..6 (7 entries), so the search must start at depth 0.
# lengthscales = [best_lengthscales(depth) for depth in tqdm(range(6 + 1))]
# lengthscales[0]

In [None]:
# Checkpoint
# Cached grid-search result: one (l_t, l_xy) pair per finite-volume depth 0..6,
# with values drawn from `grid_l_t` / `grid_l_xy`. Indexed by depth in the
# experiment loop below.
lengthscales = [(0.1, 1.0), (0.1, 0.4), (0.7333333333333333, 0.7), (0.7333333333333333, 0.7), (0.7333333333333333, 0.4), (0.7333333333333333, 0.4), (0.1, 0.1)]

In [None]:
def _fv_observation(depth):
    """Return the FV grid at ``depth`` and the functional observing the PDE on it."""
    domains = get_grid_from_depth(domain, depth)
    fv = linpde_gp.linfunctls.FiniteVolumeFunctional(domains, sample_problem.pde.diffop)
    return domains, fv

def _cg_iterations(prior, depth, fresh_start):
    """Condition ``prior`` on zero-RHS FV observations at ``depth`` with CG.

    Returns the iteration count from the CG solver's state after the solve.
    """
    domains, fv = _fv_observation(depth)
    solver = linpde_gp.solvers.itergp.IterGP_CG_Solver(
        threshold=1e-2, max_iterations=1000, num_actions_compressed=1000
    )
    u_fv = prior.condition_on_observations(
        L=fv, Y=np.zeros(domains.shape), solver=solver, fresh_start=fresh_start
    )
    # Accessed only so that the solve has run before its state is read
    # (presumably the representer weights are computed lazily).
    u_fv.representer_weights
    return u_fv.solver.solver_state.iteration

def no_preconditioner(depth, lengthscale_t, lengthscale_xy):
    """CG iterations for the FV observations at ``depth`` with ``fresh_start=False``.

    NOTE(review): the "no preconditioner" label follows the plot legend; the
    exact meaning of ``fresh_start`` lives in linpde_gp.
    """
    return _cg_iterations(
        get_u_ic_bc(lengthscale_t, lengthscale_xy), depth, fresh_start=False
    )

def cholesky_preconditioner(depth, lengthscale_t, lengthscale_xy):
    """CG iterations for the FV observations at ``depth`` with ``fresh_start=True``.

    NOTE(review): presumably the solver then preconditions with the Cholesky
    factor of the earlier IC/BC observations, matching the plot label.
    """
    return _cg_iterations(
        get_u_ic_bc(lengthscale_t, lengthscale_xy), depth, fresh_start=True
    )

def cholesky_and_multigrid_preconditioner(depth, coarse_depth, lengthscale_t, lengthscale_xy):
    """CG iterations at ``depth`` after first conditioning on FV data at ``coarse_depth``.

    The coarse observations are conditioned without an explicit solver
    (presumably the configured default dense Cholesky). The fine observations
    are then solved with CG and ``fresh_start=True``.
    """
    u_ic_bc = get_u_ic_bc(lengthscale_t, lengthscale_xy)
    domains_coarse, fv_coarse = _fv_observation(coarse_depth)
    u_coarse = u_ic_bc.condition_on_observations(
        L=fv_coarse, Y=np.zeros(domains_coarse.shape)
    )
    return _cg_iterations(u_coarse, depth, fresh_start=True)

In [None]:
# CG iteration counts per FV depth, one list per preconditioning strategy.
iters_no_preconditioner = []
iters_cholesky_preconditioner = []
iters_cholesky_and_multigrid_preconditioner = []  # coarse depth 2, fine depths 3..6
iters_cholesky_and_multigrid_preconditioner_depth_3 = []  # coarse depth 3, fine depths 4..6

for depth in range(6 + 1):
    print(f"Depth {depth}")
    l_t, l_xy = lengthscales[depth]

    iters_no_preconditioner.append(no_preconditioner(depth, l_t, l_xy))
    iters_cholesky_preconditioner.append(cholesky_preconditioner(depth, l_t, l_xy))

    # A multigrid run needs the fine grid to be strictly deeper than the coarse one.
    for coarse_depth, multigrid_iters in (
        (2, iters_cholesky_and_multigrid_preconditioner),
        (3, iters_cholesky_and_multigrid_preconditioner_depth_3),
    ):
        if depth > coarse_depth:
            multigrid_iters.append(
                cholesky_and_multigrid_preconditioner(depth, coarse_depth, l_t, l_xy)
            )

In [None]:
# FV observation counts per depth 0..6; assumes the depth-d grid from
# `get_grid_from_depth` has 8**d cells. The multigrid series only cover the
# depths at which those runs exist (>= 3 and >= 4, respectively).
num_observations_no_preconditioner = [8 ** depth for depth in range(6 + 1)]
num_observations_cholesky_preconditioner = num_observations_no_preconditioner
num_observations_cholesky_and_multigrid_preconditioner = num_observations_no_preconditioner[3:]
num_observations_cholesky_and_multigrid_preconditioner_depth_3 = num_observations_no_preconditioner[4:]

In [None]:
import os

from matplotlib import pyplot as plt
from tueplots import bundles, figsizes
plt.rcParams.update(bundles.icml2024())
plt.rcParams.update(figsizes.icml2024_full())
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=plt.cm.Set2.colors)

fig, ax = plt.subplots()
ax.plot(num_observations_no_preconditioner, iters_no_preconditioner, label="No preconditioner")
ax.plot(num_observations_cholesky_preconditioner, iters_cholesky_preconditioner, label="Cholesky preconditioner")
ax.plot(num_observations_cholesky_and_multigrid_preconditioner, iters_cholesky_and_multigrid_preconditioner, label="Cholesky + $8^2$ multigrid", linestyle="--")
ax.plot(num_observations_cholesky_and_multigrid_preconditioner_depth_3, iters_cholesky_and_multigrid_preconditioner_depth_3, label="Cholesky + $8^3$ multigrid", linestyle="--")
ax.set_xlabel("Number of observations")
ax.set_ylabel("Iterations")
# Every x value is a positive power of 8, so a plain base-8 log axis spaces the
# depths evenly. symlog's linear region around 0 distorted the 1 -> 8 step.
ax.set_xscale("log", base=8)
ax.set_xticks([8**i for i in range(7)])
ax.legend()
ax.grid()

# Guard against an unset or empty FIGURES_DIR, which would otherwise yield a
# "None/..." (or root-level) path and fail only after all experiments ran.
output_dir = figures_path or "."
os.makedirs(output_dir, exist_ok=True)
fig.savefig(os.path.join(output_dir, "preconditioner_ablation.pdf"))