In [1]:
import torch

from rlaopt.solvers import PCGConfig, SAPConfig, SAPAccelConfig
from rlaopt.preconditioners import NystromConfig

from scalable_gp_inference.hparam_training import train_exact_gp_subsampled
from scalable_gp_inference.gp_inference import GPInference

In [2]:
# Global numerics/reproducibility settings for the rest of the notebook:
# newly created floating-point tensors default to float64, and the global
# RNG is seeded so the synthetic data generated later is reproducible.
SEED = 0
torch.set_default_dtype(torch.float64)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f3e380ab8b0>

In [3]:
device = torch.device("cuda:2")

ntr = 10000
ntst = 1000
d = 3
data_noise = 0.04

freqs = 2 * torch.pi * torch.randn(d)
Xtr = torch.linspace(0, 1, ntr).unsqueeze(1).expand(-1, d)
Xtst = torch.linspace(0, 1, ntst).unsqueeze(1).expand(-1, d)
ytr = torch.sin(Xtr @ freqs) + torch.randn(Xtr.shape[0]) * (data_noise ** 0.5)
ytst = torch.sin(Xtst @ freqs) + torch.randn(Xtst.shape[0]) * (data_noise ** 0.5)

Xtr = Xtr.to(device)
Xtst = Xtst.to(device)
ytr = ytr.to(device)
ytst = ytst.to(device)


kernel_type = "rbf"

# Multi-GPU setup: one CUDA device per id. The primary `device` (cuda:2) is one of these.
GPU_IDS = (2, 3, 4)
distributed = True
devices = {torch.device(f"cuda:{gpu_id}") for gpu_id in GPU_IDS}

# Passed straight to GPInference; presumably the number of posterior samples to draw
# and the number of random features used to draw them — confirm against GPInference.
num_posterior_samples = 64
num_random_features = 64

In [4]:
# Fit the kernel hyperparameters (the printed result below carries signal_variance,
# kernel_lengthscale and noise_variance). subsample_size equals the full training-set
# size and num_trials is 1, so presumably this trains on all of Xtr in a single trial.
# opt_hparams={"lr": 0.1} is presumably the optimizer learning rate — confirm.
hparams = train_exact_gp_subsampled(
    Xtr=Xtr,
    ytr=ytr,
    kernel_type=kernel_type,
    opt_hparams={"lr": 0.1},
    training_iters=100,
    subsample_size=Xtr.shape[0],
    num_trials=1,
)

In [5]:
# Print the fitted hyperparameters. The data was generated with noise variance
# `data_noise`, so `noise_variance` should come out close to it if fitting worked.
print(str(hparams))

GPHparams(signal_variance=1.340938055313761, kernel_lengthscale=tensor([0.6043, 0.6043, 0.6043], device='cuda:2'), noise_variance=0.04129490914970158)


In [6]:
# Assemble the GP inference model from the train/test split, the trained kernel
# hyperparameters, the sampling settings, and the multi-device configuration.
gp_inference_kwargs = dict(
    # data
    Xtr=Xtr,
    ytr=ytr,
    Xtst=Xtst,
    ytst=ytst,
    # kernel
    kernel_type=kernel_type,
    kernel_hparams=hparams,
    # sampling
    num_posterior_samples=num_posterior_samples,
    num_random_features=num_random_features,
    # devices
    distributed=distributed,
    devices=devices,
)
gp_inference_model = GPInference(**gp_inference_kwargs)

In [7]:
# Which linear solver to configure: "sap" (default) or "pcg". Both use the same
# Nyström preconditioner, iteration budget, tolerances, and device.
solver_name = "sap"

# Rank-100 Nyström preconditioner; rho is set to the learned noise variance.
nystrom_config = NystromConfig(rank=100, rho=hparams.noise_variance, damping_mode="adaptive")
# Acceleration settings consumed only by the SAP solver.
accel_config = SAPAccelConfig(mu=hparams.noise_variance, nu=10.0)

# Settings shared by both solver configs.
common_solver_kwargs = dict(
    precond_config=nystrom_config,
    max_iters=1000,
    atol=1e-6,
    rtol=1e-6,
    device=device,
)

if solver_name == "sap":
    solver_config = SAPConfig(
        **common_solver_kwargs,
        blk_sz=ntr // 10,  # block size = one tenth of the training set
        accel_config=accel_config,
    )
elif solver_name == "pcg":
    solver_config = PCGConfig(**common_solver_kwargs)
else:
    raise ValueError(f"Unknown solver_name {solver_name!r}; expected 'sap' or 'pcg'.")

In [8]:
# Run inference with the configured solver. W_init=None presumably means the solver
# starts from its default initial iterate, and eval_freq presumably controls how
# often (in iterations) metrics are evaluated — confirm. Metrics are logged to wandb.
inference_options = {
    "W_init": None,
    "eval_freq": 10,
    "log_in_wandb": True,
}
results = gp_inference_model.perform_inference(solver_config=solver_config, **inference_options)

Initialized with clean caches. PID: 2554832
Initialized with clean caches. PID: 2554832


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpratikrathore8[0m ([33msketchy-opts[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[PID 2557937] Computing kernel for device cuda:2...
[PID 2557937] Kernel cached. Cache size: 1
[PID 2557939] Computing kernel for device cuda:4...
[PID 2557939] Kernel cached. Cache size: 1
[PID 2557938] Computing kernel for device cuda:3...
[PID 2557938] Kernel cached. Cache size: 1
[PID 2557953] Computing kernel for device cuda:3...
[PID 2557953] Kernel cached. Cache size: 1
[PID 2557952] Computing kernel for device cuda:2...
[PID 2557952] Kernel cached. Cache size: 1
[PID 2557954] Computing kernel for device cuda:4...
[PID 2557954] Kernel cached. Cache size: 1
[PID 2557953] Using cached kernel for device cuda:3
[PID 2557954] Using cached kernel for device cuda:4
[PID 2557952] Using cached kernel for device cuda:2
[PID 2557938] Using cached kernel for device cuda:3
[PID 2557937] Using cached kernel for device cuda:2
[PID 2557939] Using cached kernel for device cuda:4
[PID 2557937] Using cached kernel for device cuda:2
[PID 2557938] Using cached kernel for device cuda:3
[PID 2557939] 

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
cum_time,▁▁▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▆▆▆▆▇▇██
iter_time,▁▇▆▄▃▅▃▄▄▅▄▄▄▄▄█▆▆▆▄▆▇▆▇▆

0,1
cum_time,41.58854
iter_time,2.11672


Cleared global caches on shutdown. PID: 2554832
Cleared global caches on shutdown. PID: 2554832
