In [None]:
import torch

from rlaopt.solvers import PCGConfig, SAPConfig, SAPAccelConfig
from rlaopt.preconditioners import NystromConfig

from scalable_gp_inference.kernel_linsys import KernelLinSys

In [None]:
# Make float64 the default floating-point dtype for tensors created below.
torch.set_default_dtype(torch.float64)
# Fix the global RNG seed so the random X and B drawn later are reproducible.
torch.manual_seed(seed=0)

In [None]:
# Primary device: the default CUDA device when one is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Problem dimensions: n samples, d input features, k right-hand sides.
n = 1_000_000
d = 3
k = 10
reg = 1e-8 * n  # regularization value, scaled with the number of samples
kernel_type = "rbf"
# One lengthscale entry per input feature (three entries, matching d = 3).
kernel_lengthscale = torch.tensor([1.0, 2.0, 3.0], device=device)
residual_tracking_idx = None

# Device set handed to KernelLinSys. At most 5 devices are requested, but the
# count is clamped to the GPUs actually present so that the set never names a
# CUDA device that does not exist on this machine.
n_gpus = torch.cuda.device_count() if device.type == "cuda" else 0
if n_gpus > 0:
    n_chunks = min(5, n_gpus)
    distributed = True
    devices = {torch.device(f"cuda:{i}") for i in range(n_chunks)}
else:
    # No GPU: stay consistent with the CPU fallback chosen for `device`.
    # NOTE(review): assumes KernelLinSys accepts a single-device set when
    # distributed is False -- confirm against its constructor.
    n_chunks = 1
    distributed = False
    devices = {device}

# Random inputs (n x d) and right-hand sides (n x k), created on the primary device.
X = torch.randn(n, d, device=device)
B = torch.randn(n, k, device=device)

In [None]:
# Build the kernel linear-system object from the data and settings defined in
# the setup cell. All arguments are passed positionally, in the order below.
kernel_linsys = KernelLinSys(
    X,
    B,
    reg,
    kernel_type,
    kernel_lengthscale,
    residual_tracking_idx,
    distributed,
    devices,
)

In [None]:
# Nystrom preconditioner settings: rank 100, rho set to the regularization
# value, and the "adaptive" damping mode.
nystrom_config = NystromConfig(
    rank=100,
    rho=reg,
    damping_mode="adaptive",
)

# Acceleration settings for SAP: mu is set to the regularization value, nu to 10.
accel_config = SAPAccelConfig(
    mu=reg,
    nu=10.0,
)

# Active solver: SAP using the Nystrom preconditioner and acceleration above,
# with a block size of one tenth of the samples and at most 300 iterations.
solver_config = SAPConfig(
    device=device,
    precond_config=nystrom_config,
    accel_config=accel_config,
    blk_sz=n // 10,
    max_iters=300,
    atol=1e-6,
    rtol=1e-6,
)

# Alternative solver: uncomment to use PCG with the same preconditioner instead.
# solver_config = PCGConfig(
#     precond_config=nystrom_config,
#     max_iters=1000,
#     atol=1e-6,
#     rtol=1e-6,
#     device=device,
# )

In [None]:
# Run the configured solver starting from an all-zeros iterate shaped like B.
# The call returns the solution together with a log; logging to Weights &
# Biases is switched on and uses the project name given below.
solution, log = kernel_linsys.solve(
    solver_config=solver_config,
    W_init=torch.zeros_like(B),
    log_in_wandb=True,
    wandb_init_kwargs=dict(project="test_krr_linsys_class"),
)

In [None]:
# Look up the entry stored under the last key of the log (in the log's own key
# order) and print both the key and the entry.
final_key = list(log.keys())[-1]
final_log_entry = log[final_key]
print("Final log entry key:", final_key)
print("Final log entry:", final_log_entry)