In [None]:
import torch

from scalable_gp_inference.hparam_training import _train_exact_gp

In [2]:
SEED = 0

# Use double precision: floating-point tensors created after this call without an
# explicit dtype default to float64.
torch.set_default_dtype(torch.float64)
# Seed torch's RNG so the random frequencies and noise drawn in the data cell are
# reproducible. The trailing semicolon suppresses the returned Generator's repr.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f51cc0eb9d0>

In [3]:
# Pick a device: prefer the second GPU (the one originally used), then any GPU,
# then the CPU, so the notebook still runs on machines without a "cuda:1".
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

n = 10000  # number of training points
d = 3      # input dimension

# Training inputs: n regularly spaced points in [0, 1] (inclusive). `expand` makes
# every one of the d columns a view of the same grid, so all d input dimensions
# are identical.
train_x = torch.linspace(0, 1, n).unsqueeze(1).expand(-1, d)

# Targets: y = sin(x @ freqs) + Gaussian noise. freqs holds one random frequency per
# input dimension (2*pi times a standard-normal draw); the noise has variance 0.04.
true_noise_variance = 0.04
freqs = 2 * torch.pi * torch.randn(d)
train_y = torch.sin(train_x @ freqs) + \
    torch.randn(train_x.shape[0]) * (true_noise_variance ** 0.5)

train_x = train_x.to(device)
train_y = train_y.to(device)

In [4]:
# Fit exact-GP hyperparameters with an RBF kernel. The last two positional arguments
# look like optimizer kwargs ({"lr": 0.1}) and an iteration count (100) — TODO
# confirm against _train_exact_gp's signature.
hparams = _train_exact_gp(train_x, train_y, "rbf", {"lr": 0.1}, 100)

# Show each learned hyperparameter on its own line.
for field_name in ("signal_variance", "kernel_lengthscale", "noise_variance"):
    print(getattr(hparams, field_name))

1.3410016050517224
tensor([0.6043, 0.6043, 0.6043], device='cuda:1')
0.041294909693002424


In [5]:
# Adding a GPHparams to itself combines it field by field; the printed values are
# twice those of `hparams`.
hparams_doubled = hparams + hparams

for field_name in ("signal_variance", "kernel_lengthscale", "noise_variance"):
    print(getattr(hparams_doubled, field_name))

2.682003210103445
tensor([1.2086, 1.2086, 1.2086], device='cuda:1')
0.08258981938600485


In [6]:
# Division by a scalar: an int and a float divisor are both accepted and give the
# same result.
for divisor in (3, 3.0):
    print(hparams / divisor)

GPHparams(signal_variance=0.4470005350172408, kernel_lengthscale=tensor([0.2014, 0.2014, 0.2014], device='cuda:1'), noise_variance=0.013764969897667475)
GPHparams(signal_variance=0.4470005350172408, kernel_lengthscale=tensor([0.2014, 0.2014, 0.2014], device='cuda:1'), noise_variance=0.013764969897667475)


In [7]:
# Move the hyperparameters to the CPU; the tensor-valued lengthscale is moved too.
hparams_cpu = hparams.to("cpu")
print(hparams_cpu)

GPHparams(signal_variance=1.3410016050517224, kernel_lengthscale=tensor([0.6043, 0.6043, 0.6043]), noise_variance=0.041294909693002424)


In [14]:
import pickle

In [15]:
# Round-trip the CPU copy of the hyperparameters through pickle and verify that every
# field survives, instead of only eyeballing the printed repr.
# NOTE: pickle.load can execute arbitrary code — only load files you created yourself.
hparams_cpu = hparams.to("cpu")
with open("hparam.pkl", "wb") as f:
    pickle.dump(hparams_cpu, f)
with open("hparam.pkl", "rb") as f:
    hparams_loaded = pickle.load(f)

assert hparams_loaded.signal_variance == hparams_cpu.signal_variance
assert torch.equal(hparams_loaded.kernel_lengthscale, hparams_cpu.kernel_lengthscale)
assert hparams_loaded.noise_variance == hparams_cpu.noise_variance

print(hparams_loaded)

GPHparams(signal_variance=1.3410016050517224, kernel_lengthscale=tensor([0.6043, 0.6043, 0.6043]), noise_variance=0.041294909693002424)
