
Batching Gaussian Processes #1679

Open
@mileslucas

Description


To begin: I tried logging in with GitHub and also creating an account on the Pyro forums, but neither of those worked, so I'm posting here.

Problem

I need to fit a batch of four independent Gaussian processes, and I don't want to use for loops to fit each one. The current GPs are able to broadcast properly to my outputs, but I can't batch them so that each of the four processes is fit independently.

My input data is a tensor of shape (264, 3), so each input has 3 feature dimensions. My observed data has shape (4, 264), where 4 is the batch shape and corresponds to the four independent Gaussian processes I'm trying to train.

I'm using an RBF kernel, and the formula I want is the ARD form

k(X, Z) = \sigma^2 \exp\left(-\frac{1}{2} \sum_{d=1}^{3} \frac{(X_d - Z_d)^2}{\ell_d^2}\right)

where X (and Z) have shape (Nsamples, 3).

This ARD kernel is easy enough to implement:

import torch
import pyro.contrib.gp as gp

amplitude = torch.tensor(20.)
lengthscale = torch.tensor([300., 30., 30.])  # one lengthscale per feature dimension (ARD)
kernel = gp.kernels.RBF(input_dim=3, variance=amplitude, lengthscale=lengthscale)
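
To make sure I had the right form, I checked the kernel matrix against the formula above (this check is my own, not from the docs):

# Sanity check: Pyro's RBF should reproduce the ARD formula above.
Xc = torch.randn(5, 3)
K_pyro = kernel(Xc)
diff = (Xc.unsqueeze(1) - Xc.unsqueeze(0)) / lengthscale  # pairwise scaled differences, (5, 5, 3)
K_manual = amplitude * torch.exp(-0.5 * (diff ** 2).sum(-1))
assert torch.allclose(K_pyro, K_manual)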

My first problem

In my derivation, I use priors on my variance and lengthscales such that each lengthscale in the kernel has its own Gamma prior and the variance has a uniform prior:

\ell_d \sim \mathrm{Gamma}(2,\, \beta_d) \text{ with } \beta = (0.0075, 0.075, 0.075), \qquad \sigma^2 \sim \mathrm{Uniform}(10, 150)

Unfortunately, I can't set a multidimensional prior on the kernel lengthscale, regardless of the kernel's input dimension:

import torch
import tqdm
import pyro
import pyro.distributions as dist
import pyro.contrib.gp as gp

# Dummy data
X = torch.ones((264, 3))
y = torch.ones((4, 264))

amplitude = torch.tensor(20.)
lengthscale = torch.tensor([300., 30., 30.])
# Gamma(concentration, rate) for each of the three lengthscales
priors = torch.tensor([
    [2., 2., 2.],
    [0.0075, 0.075, 0.075]
])
kernel = gp.kernels.RBF(input_dim=3, variance=amplitude, lengthscale=lengthscale)
kernel.set_prior("lengthscale", dist.Gamma(priors[0], priors[1]))
kernel.set_prior("variance", dist.Uniform(torch.tensor(10.), torch.tensor(150.)))

gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(1.))

# Optimize
optimizer = torch.optim.Adam(gpr.parameters(), lr=0.01)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
losses = []
num_steps = int(1e4)
for i in tqdm.trange(num_steps):
    optimizer.zero_grad()
    loss = loss_fn(gpr.model, gpr.guide)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

This yields ValueError: Model and guide event_dims disagree at site 'GPR/RBF/lengthscale': 0 vs 1

The same code works if I drop the priors.
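
One thing I tried (guessing at the API here, so this may not be the intended usage) was declaring the three lengthscales a single event so the prior's event_dim matches what the guide expects:

# Guess at a workaround: mark the last dimension of the Gamma prior as an
# event dimension (event_dim=1) rather than a batch dimension.
kernel.set_prior("lengthscale", dist.Gamma(priors[0], priors[1]).to_event(1))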

Batching Problems

Beyond the priors, I would really like to batch my GPs to take advantage of tensor math (or even parallelization) instead of looping over each batch dimension; a sketch of the loop I'm trying to avoid follows below. I tried a similar approach in TensorFlow Probability: setting up the ARD kernel and using TF in general was a struggle, but their batching scheme worked really well.
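
For reference, the loop-based fallback looks roughly like this (one model and one training loop per output row):

# Fallback I'm trying to avoid: one independent GP per row of y.
models = []
for i in range(4):
    k = gp.kernels.RBF(input_dim=3, variance=torch.tensor(20.),
                       lengthscale=torch.tensor([300., 30., 30.]))
    models.append(gp.models.GPRegression(X, y[i], k, noise=torch.tensor(1.)))
# ...then repeat the optimizer/ELBO loop below once per model.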

The approach I would like to use is something like this:

# Dummy data
X = torch.ones((264, 3))
y = torch.ones((4, 264))

amplitude = 20 * torch.ones(4)  # one variance per GP
lengthscale = torch.tensor([300., 30., 30.]).repeat(4, 1)  # shape (4, 3): per-GP ARD lengthscales
kernel = gp.kernels.RBF(input_dim=3, variance=amplitude, lengthscale=lengthscale)

gpr = gp.models.GPRegression(X, y, kernel, noise=torch.ones(4))

optimizer = torch.optim.Adam(gpr.parameters(), lr=0.01)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
losses = []
num_steps = int(1e4)
for i in tqdm.trange(num_steps):
    optimizer.zero_grad()
    loss = loss_fn(gpr.model, gpr.guide)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

This fails with RuntimeError: The size of tensor a (264) must match the size of tensor b (4) at non-singleton dimension 0
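
If I'm reading the kernel code right (this is my guess at the failure point, not confirmed), the error comes from the lengthscale scaling step inside the kernel, which broadcasts X against the batched lengthscale:

# Minimal reproduction of the broadcast failure (my reading of the error):
X = torch.ones((264, 3))
lengthscale = torch.tensor([300., 30., 30.]).repeat(4, 1)  # (4, 3)
scaled = X / lengthscale  # (264, 3) vs (4, 3) -> RuntimeError at dim 0

Broadcasting would work if the lengthscale had shape (4, 1, 3), but the rest of the kernel and model code doesn't seem to expect that extra batch dimension.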

Moreover, I'd love to bring priors into this to have a final working product like:

# Dummy data
X = torch.ones((264, 3))
y = torch.ones((4, 264))

amplitude = 20 * torch.ones(4)
lengthscale = torch.tensor([300., 30., 30.]).repeat(4, 1)  # (4, 3)
kernel = gp.kernels.RBF(input_dim=3, variance=amplitude, lengthscale=lengthscale)
# Gamma(concentration, rate) priors for each lengthscale dimension
gpriors = torch.tensor([
    [2., 2., 2.],
    [0.0075, 0.075, 0.075]
])
kernel.set_prior("lengthscale", dist.Gamma(gpriors[0], gpriors[1]))
kernel.set_prior("variance", dist.Uniform(torch.tensor(10.), torch.tensor(150.)))

gpr = gp.models.GPRegression(X, y, kernel, noise=torch.ones(4))

optimizer = torch.optim.Adam(gpr.parameters(), lr=0.01)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
losses = []
num_steps = int(1e4)
for i in tqdm.trange(num_steps):
    optimizer.zero_grad()
    loss = loss_fn(gpr.model, gpr.guide)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
