In [1]:
import torch

from rlaopt.solvers import PCGConfig, SAPConfig, SAPAccelConfig
from rlaopt.preconditioners import NystromConfig

from scalable_gp_inference.gp_inference import GPInference

In [2]:
# Reproducibility settings: float64 becomes the default floating dtype for every
# tensor created below, and a fixed seed makes the synthetic data repeatable.
SEED = 0
torch.set_default_dtype(torch.float64)
# torch.manual_seed seeds the RNGs of all devices (CPU and CUDA). The trailing ';'
# suppresses the returned Generator repr, which would otherwise embed a memory
# address in the saved cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7fb5982af8d0>

In [3]:
# Experiment configuration for a synthetic GP regression problem.
# NOTE(review): assumes GPUs cuda:2, cuda:3 and cuda:4 exist on this machine;
# edit `device_ids` to match the available hardware.
device_ids = [2, 3, 4]
# Single source of truth for the GPU set, so `device` is always a member of `devices`.
devices = {torch.device(f"cuda:{i}") for i in device_ids}
# Primary device: the synthetic data below and the solver config are placed on it.
device = torch.device(f"cuda:{device_ids[0]}")

ntr = 10000  # number of training points
ntst = 1000  # number of test points
d = 3  # input dimension
noise_variance = 1e-6 * ntr  # noise level, scaled linearly with the training-set size
kernel_type = "rbf"
# Lengthscales 1, 2, ..., d. Generated from `d` so the vector length always matches
# the input dimension (presumably one lengthscale per feature — confirm against GPInference).
kernel_lengthscale = torch.arange(
    1, d + 1, dtype=torch.get_default_dtype(), device=device
)
distributed = True
num_posterior_samples = 64
num_random_features = 64

# Synthetic standard-normal inputs and targets (draw order kept fixed for reproducibility).
Xtr = torch.randn(ntr, d, device=device)
ytr = torch.randn(ntr, device=device)
Xtst = torch.randn(ntst, d, device=device)
ytst = torch.randn(ntst, device=device)

In [4]:
# Build the GP inference model from the synthetic data and kernel settings,
# configured for distributed execution over the GPUs in `devices`.
gp_inference_model = GPInference(
    # training and test data
    Xtr=Xtr, ytr=ytr,
    Xtst=Xtst, ytst=ytst,
    # likelihood / kernel hyperparameters
    noise_variance=noise_variance,
    kernel_type=kernel_type,
    kernel_lengthscale=kernel_lengthscale,
    # execution placement
    distributed=distributed,
    devices=devices,
    # posterior sampling settings
    num_posterior_samples=num_posterior_samples,
    num_random_features=num_random_features,
)

In [5]:
# Solver selection. The two solvers share a Nyström preconditioner and stopping
# criteria; set `solver_name` to switch between them instead of keeping
# commented-out alternatives in the cell.
solver_name = "sap"  # one of: "sap", "pcg"

nystrom_config = NystromConfig(rank=100, rho=noise_variance, damping_mode="adaptive")
common_solver_kwargs = dict(
    precond_config=nystrom_config,
    max_iters=1000,
    atol=1e-6,
    rtol=1e-6,
    device=device,
)

if solver_name == "sap":
    accel_config = SAPAccelConfig(mu=noise_variance, nu=10.0)
    solver_config = SAPConfig(
        **common_solver_kwargs,
        blk_sz=ntr // 10,  # blocks of 10% of the training points
        accel_config=accel_config,
    )
elif solver_name == "pcg":
    solver_config = PCGConfig(**common_solver_kwargs)
else:
    raise ValueError(f"Unknown solver_name {solver_name!r}; expected 'sap' or 'pcg'")

In [6]:
# Run inference with the configured solver and log metrics to Weights & Biases.
results = gp_inference_model.perform_inference(
    solver_config=solver_config,
    W_init=None,  # presumably "no warm start" (solver default init) — confirm
    eval_freq=10,  # evaluation cadence, likely in solver iterations — confirm
    log_in_wandb=True,
)

Initialized with clean caches. PID: 3763803
Initialized with clean caches. PID: 3763803


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpratikrathore8[0m ([33msketchy-opts[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[PID 3764080] Computing kernel for device cuda:3...
[PID 3764080] Kernel cached. Cache size: 1
[PID 3764079] Computing kernel for device cuda:2...
[PID 3764079] Kernel cached. Cache size: 1
[PID 3764081] Computing kernel for device cuda:4...
[PID 3764081] Kernel cached. Cache size: 1
[PID 3764095] Computing kernel for device cuda:2...
[PID 3764095] Kernel cached. Cache size: 1
[PID 3764097] Computing kernel for device cuda:4...
[PID 3764097] Kernel cached. Cache size: 1
[PID 3764096] Computing kernel for device cuda:3...
[PID 3764096] Kernel cached. Cache size: 1
[PID 3764095] Using cached kernel for device cuda:2
[PID 3764096] Using cached kernel for device cuda:3
[PID 3764097] Using cached kernel for device cuda:4
[PID 3764079] Using cached kernel for device cuda:2
[PID 3764080] Using cached kernel for device cuda:3
[PID 3764081] Using cached kernel for device cuda:4
[PID 3764080] Using cached kernel for device cuda:3
[PID 3764079] Using cached kernel for device cuda:2
[PID 3764081] 

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


[PID 3764079] Using cached kernel for device cuda:2
[PID 3764080] Using cached kernel for device cuda:3
[PID 3764081] Using cached kernel for device cuda:4


0,1
cum_time,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
iter_time,█▆█▇▇▆▁▁▁▂█▇▇▆▅▇▇▇▃▂▂▁█▇▂▇▇▂▂█▆▆▄▁▁▂▅▇█▄

0,1
cum_time,279.68425
iter_time,3.96749


Cleared global caches on shutdown. PID: 3763803
Cleared global caches on shutdown. PID: 3763803
