In [None]:
import numpy as np
import torch

from rlaopt.solvers import PCGConfig, SAPConfig, SAPAccelConfig
from rlaopt.preconditioners import NystromConfig

from scalable_gp_inference.hparam_training import train_exact_gp_subsampled
from scalable_gp_inference.gp_inference import GPInference
from scalable_gp_inference.sdd_config import SDDConfig

from experiments.data_processing.load_torch import LOADERS

In [None]:
seed: int = 0  # single source of randomness for RNG seeding and the train/test split below

In [None]:
# Use double precision for every tensor created without an explicit dtype, and
# seed the torch and numpy global RNGs so a fresh "Restart & Run All" is repeatable.
torch.set_default_dtype(torch.float64)
torch.manual_seed(seed)  # seeds torch's RNG on all devices (CPU and CUDA)
np.random.seed(seed)  # numpy's legacy global RNG, in case any dependency draws from it

In [None]:
# GPU indices used by this run. The first one is the primary device: the data is
# loaded there and the solver runs there. All of them form the distributed pool.
device_ids = (1, 3, 4)
device = torch.device(f"cuda:{device_ids[0]}")

dataset_name = "3droad"
data = LOADERS[dataset_name](
    split_proportion=0.1,
    split_shuffle=True,
    split_seed=seed,
    standardize=True,
    dtype=torch.float64,
    device=device,
)
Xtr, ytr = data.Xtr, data.ytr
Xtst, ytst = data.Xtst, data.ytst

kernel_type = "matern32"
distributed = True
# Derived from `device_ids` so the primary device is always part of the pool.
devices = {torch.device(f"cuda:{i}") for i in device_ids}
num_posterior_samples = 64
num_random_features = 64

subsample_size = 10000

In [None]:
# Fit kernel hyperparameters with exact-GP training on a subsample of
# `subsample_size` training points. Named constants make the optimizer settings tunable.
hparam_lr = 0.1
hparam_training_iters = 100
hparams = train_exact_gp_subsampled(
    Xtr=Xtr,
    ytr=ytr,
    kernel_type=kernel_type,
    opt_hparams={"lr": hparam_lr},
    training_iters=hparam_training_iters,
    subsample_size=subsample_size,
    num_trials=1,
)

In [None]:
# Show the fitted hyperparameters via the notebook's rich display (last expression of the cell).
hparams

In [None]:
# Assemble the inference model from: the kernel choice and its fitted hyperparameters,
# the train/test split, the sampling sizes, and the device settings configured above.
gp_inference_model = GPInference(
    kernel_type=kernel_type,
    kernel_hparams=hparams,
    Xtr=Xtr,
    ytr=ytr,
    Xtst=Xtst,
    ytst=ytst,
    num_posterior_samples=num_posterior_samples,
    num_random_features=num_random_features,
    devices=devices,
    distributed=distributed,
)

In [None]:
# Linear-solver configuration. `solver_name` picks which solver config is built,
# replacing the commented-out alternatives that used to be toggled by hand.
solver_name = "sap"  # one of "sap", "sdd", "pcg"
max_iters = 1000
solver_atol = 1e-6
solver_rtol = 1e-6
# Block size for the block-based solvers (SAP / SDD): one tenth of the training set,
# clamped to at least 1 so a tiny training set cannot produce an empty block.
solver_blk_size = max(1, Xtr.shape[0] // 10)

# Nystrom preconditioner, shared by SAP and PCG. Built unconditionally so
# `nystrom_config` stays defined whichever solver is selected.
nystrom_config = NystromConfig(rank=100, rho=hparams.noise_variance, damping_mode="adaptive")

if solver_name == "sap":
    accel_config = SAPAccelConfig(mu=hparams.noise_variance, nu=10.0)
    solver_config = SAPConfig(
        precond_config=nystrom_config,
        max_iters=max_iters,
        atol=solver_atol,
        rtol=solver_rtol,
        blk_sz=solver_blk_size,
        accel_config=accel_config,
        device=device,
    )
elif solver_name == "sdd":
    solver_config = SDDConfig(
        momentum=0.9,
        step_size=100 / Xtr.shape[0],
        theta=100 / max_iters,
        blk_size=solver_blk_size,
        max_iters=max_iters,
        device=device,
        atol=solver_atol,
        rtol=solver_rtol,
    )
elif solver_name == "pcg":
    solver_config = PCGConfig(
        precond_config=nystrom_config,
        max_iters=max_iters,
        atol=solver_atol,
        rtol=solver_rtol,
        device=device,
    )
else:
    raise ValueError(f"Unknown solver_name: {solver_name!r}; expected 'sap', 'sdd' or 'pcg'")

In [None]:
# Run GP inference with the solver configured above. `eval_freq` and `log_in_wandb`
# presumably control evaluation cadence and Weights & Biases logging — TODO confirm
# against GPInference.perform_inference.
results = gp_inference_model.perform_inference(
    solver_config=solver_config,
    use_full_kernel=False,
    W_init=None,
    log_in_wandb=True,
    eval_freq=10,
)