
ApproximateGP compatibility #347

Closed · cisprague opened this issue Jan 8, 2020 · 11 comments
Labels: bug (Something isn't working)

@cisprague:
🐛 Bug: gpytorch.models.ApproximateGP compatibility

To reproduce

**Code snippet to reproduce**

import gpytorch, torch, numpy as np, matplotlib.pyplot as plt, tqdm, botorch
from scipy.interpolate import griddata

class Model(botorch.models.gpytorch.BatchedMultiOutputGPyTorchModel, gpytorch.models.ApproximateGP):

    def __init__(self, idim, grid_size):

        '''
        Variational GP.

        idim: input dimensions
        grid_size: number of inducing points
        '''

        # variational distribution and strategy
        vdist = gpytorch.variational.CholeskyVariationalDistribution(grid_size)
        vstra = gpytorch.variational.VariationalStrategy(
            self,
            torch.rand(grid_size, idim),
            vdist,
            learn_inducing_locations=True
        )

        # become an approximate GP
        super().__init__(vstra)
        
        # mean
        self.mean = gpytorch.means.ConstantMean()

        # covariance (kernel)
        self.cov = gpytorch.kernels.MaternKernel(ard_num_dims=2)
        self.cov = gpytorch.kernels.ScaleKernel(self.cov)

        # likelihood
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

        # loss record
        self.loss = list()

    def forward(self, x):

        '''
        Returns a multivariate normal distribution,
        parameterised by the mean and covariance.

        x: inputs
        '''

        m = self.mean(x)
        c = self.cov(x)
        return gpytorch.distributions.MultivariateNormal(m, c)

    def fit(self, xtrn, xstd, ytrn, lr=1e-1, epo=None, verbose=True):

        '''
        Optimises the hyperparameters of the GP
        with noisy inputs and noisy outputs.
        Accommodates noisy inputs by randomly
        sampling them at every iteration.

        xtrn: means of training data
        xstd: standard deviation of training data
        ytrn: noisy output targets
        lr: learning rate
        epo: number of training iterations
        '''

        # preprocess
        botorch.models.utils.validate_input_scaling(xtrn, ytrn)
        self._validate_tensor_args(xtrn, ytrn)
        self._set_dimensions(xtrn, ytrn)
        xtrn, ytrn, _  = self._transform_tensor_args(xtrn, ytrn)

        # let's train MFer
        self.train()
        self.likelihood.train()

        # optimisation algorithm
        opt = torch.optim.Adam(self.parameters(), lr=lr)

        # loss function
        mll = gpytorch.mlls.PredictiveLogLikelihood(
            self.likelihood, self, xtrn.size(-2)
        )

        # convergence checker
        convergence = botorch.optim.utils.ConvergenceCriterion()

        # progress bar
        pb = tqdm.tqdm(range(epo))

        # training iterations
        for i in pb:

            # zero gradients
            opt.zero_grad()

            # sample noisy inputs
            xsmp = torch.distributions.Normal(xtrn, xstd).rsample()

            # make prediction
            output = self(xsmp)

            # loss
            loss = -mll(output, ytrn)

            # compute gradient
            loss.backward()

            # optimisation step
            opt.step()

            # record loss
            self.loss.append(loss.item())

            # set progress bar description
            if verbose: pb.set_description('Loss {}'.format(loss.item()))

            # if converged
            if convergence.evaluate(loss.detach()):
                break

        # done training :)
        self.eval()
        self.likelihood.eval()

class Decider(object):

    def __init__(self, objective, model):

        # underlying objective function
        self.objective = objective

        # the underlying model
        self.model = model

        # observations
        self.xsmean = torch.empty(torch.Size([0, 2]))
        self.xsstd  = torch.empty(torch.Size([0, 1]))
        self.f = torch.empty(torch.Size([0, 1]))
        self.fmax = torch.empty(torch.Size([0, 1]))

    def decide(self, x, xstd, lr=1e-1):

        # sample the objective function
        obj = self.objective(
            torch.distributions.Normal(x, xstd).rsample()
        )

        # record sample
        self.xsmean = torch.cat((self.xsmean, x))
        self.xsstd = torch.cat((self.xsstd, xstd))
        self.f = torch.cat((self.f, obj))
        self.fmax = torch.cat((self.fmax, self.f.max().view(-1,1)))

        # fit the model to the observations
        self.model.fit(self.xsmean, self.xsstd, self.f, lr=lr, epo=1000, verbose=True)

        # initialise acquisition function
        acq = botorch.acquisition.ProbabilityOfImprovement(
            self.model,
            self.f.max()
        )

        # optimise acquisition function
        x, ximp = botorch.optim.optimize_acqf(
            acq,
            torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
            1,
            2,
            500
        )

        # return proposed sample and improvement
        return x, xstd, ximp

    def decision_process(self, x, xstd, ndec=4, lr=1e-2):

        # decision making loop
        for _ in range(ndec):

            # get sample candidate and improvement
            x, xstd, ximp = self.decide(x, xstd, lr=lr)

def franke(x, noise=0.02):
    n, d = x.shape
    X, Y = x[:,0].view(n, 1), x[:,1].view(n, 1)
    term1 = .75*torch.exp(-((9*X - 2).pow(2) + (9*Y - 2).pow(2))/4)
    term2 = .75*torch.exp(-((9*X + 1).pow(2))/49 - (9*Y + 1)/10)
    term3 = .5*torch.exp(-((9*X - 7).pow(2) + (9*Y - 3).pow(2))/4)
    term4 = .2*torch.exp(-(9*X - 4).pow(2) - (9*Y - 7).pow(2))
    f = term1 + term2 + term3 - term4
    noise = torch.randn(f.size())*noise
    return f + noise

def botst():

    # initial sample num_points=1 x input_dim=2
    x = torch.rand(1,2).uniform_(0, 1)
    # initial sample num_points=1 x 1
    xstd = torch.rand(1,1).uniform_(0, 0.01)

    # instantiate Bayesian decider
    decider = Decider(franke, Model(2, 10))

    # make decision
    decider.decision_process(x, xstd, ndec=400, lr=1e-1)

    # test points
    n = 50
    xv, yv = torch.meshgrid([torch.linspace(0, 1, n), torch.linspace(0, 1, n)])
    xtst = torch.cat((
        xv.contiguous().view(xv.numel(), 1),
        yv.contiguous().view(yv.numel(), 1)),
        dim=1
    )
    ytst = franke(xtst, noise=0)

    # plot ground truth, learnt model + sampled points, max objective
    fig, ax = plt.subplots(3)

    x, y = np.meshgrid(xtst[:,0], xtst[:,1])
    z = griddata(
        (xtst[:,0], xtst[:,1]),
        ytst.flatten(),
        (x, y),
        method='linear'
    )
    ax[0].contourf(x, y, z)
    ax[0].set_title('Ground truth')

    with torch.no_grad():
        x, y = np.meshgrid(xtst[:,0], xtst[:,1])
        res = decider.model(xtst)
        res = decider.model.likelihood(res)
        z = griddata(
            (xtst[:,0], xtst[:,1]),
            res.mean,
            (x, y),
            method='linear'
        )
        ax[1].contourf(x, y, z)
        ax[1].plot(decider.xsmean[:,0], decider.xsmean[:,1], 'kx')
        ax[1].set_title('Gaussian process')

    ax[2].plot(decider.fmax[:,0], 'k-')

    plt.tight_layout()

    plt.show()

if __name__ == '__main__':
    botst()

**Stack trace/error message**

Loss 1.4877734184265137:   0%|                                                                                                           | 0/1000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "bug.py", line 256, in <module>
    botst()
  File "bug.py", line 210, in botst
    decider.decision_process(x, xstd, ndec=400, lr=1e-1)
  File "bug.py", line 186, in decision_process
    x, xstd, ximp = self.decide(x, xstd, lr=lr)
  File "bug.py", line 174, in decide
    500
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/botorch/optim/optimize.py", line 170, in optimize_acqf
    fixed_features=fixed_features,
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/botorch/gen.py", line 123, in gen_candidates_scipy
    options={k: v for k, v in options.items() if k != "method"},
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/scipy/optimize/_minimize.py", line 600, in minimize
    callback=callback, **options)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/scipy/optimize/lbfgsb.py", line 335, in _minimize_lbfgsb
    f, g = func_and_grad(x)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/scipy/optimize/lbfgsb.py", line 285, in func_and_grad
    f = fun(x, *args)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/scipy/optimize/optimize.py", line 326, in function_wrapper
    return function(*(wrapper_args + args))
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/scipy/optimize/optimize.py", line 64, in __call__
    fg = self.fun(x, *args)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/botorch/gen.py", line 110, in f
    loss = -acquisition_function(X_fix).sum()
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/botorch/utils/transforms.py", line 141, in decorated
    return method(cls, X)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/botorch/acquisition/analytic.py", line 231, in forward
    posterior = self._get_posterior(X=X)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/botorch/acquisition/analytic.py", line 62, in _get_posterior
    posterior = self.model.posterior(X)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/botorch/models/gpytorch.py", line 255, in posterior
    mvn = self(X)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/gpytorch/models/approximate_gp.py", line 81, in __call__
    return self.variational_strategy(inputs, prior=prior)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/gpytorch/variational/variational_strategy.py", line 165, in __call__
    return super().__call__(x, prior=prior)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/gpytorch/variational/_variational_strategy.py", line 127, in __call__
    variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/gpytorch/module.py", line 24, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "/home/cisprague/anaconda3/lib/python3.7/site-packages/gpytorch/variational/variational_strategy.py", line 100, in forward
    interp_term = torch.triangular_solve(induc_data_covar.double(), L, upper=False)[0].to(full_inputs.dtype)
RuntimeError: The size of tensor a (2) must match the size of tensor b (500) at non-singleton dimension 0

Expected Behavior

BoTorch should be compatible with any GPyTorch model by inheriting from botorch.models.gpytorch.GPyTorchModel, but this does not seem to work for a gpytorch.models.ApproximateGP with a variational strategy. It should work like the other models do.
I am trying to model a mapping from n x d heteroskedastic inputs to n x m homoskedastic outputs.

System information

Please complete the following information:

  • BoTorch version: 0.1.4
  • GPyTorch version: 0.3.6
  • Torch version: 1.3.1

Additional context

I am using a mobile robot, whose position uncertainty grows over time, to build an environmental map from sensor measurements.

cisprague added the bug label on Jan 8, 2020
@Balandat (Contributor) commented Jan 8, 2020:

While approximate GPs should work with the standard model/posterior API, we haven't actually used/tested them extensively. I'm currently traveling, will take a look at the specific issue here once I get back.

@Balandat (Contributor):

OK, so upon further digging it seems that this particular failure has to do with the caching of Cholesky factors inside GPyTorch. Specifically, if I manually compute the Cholesky factor for induc_induc_covar I get a factor of the right size, whereas the current code appears to pick up a cached result. Maybe @gpleiss or @jacobrgardner have immediate thoughts (I'm not very familiar with the variational strategy implementation).

[Screenshot: Screen Shot 2020-01-10 at 6.01.26 AM]
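For reference, a minimal sketch of the kind of manual check described above, assuming you are in a debugger inside the variational strategy and have access to its induc_induc_covar lazy tensor (the variable names here are taken from that context and are otherwise illustrative):

import torch
from gpytorch.utils.cholesky import psd_safe_cholesky

# Hypothetical debugging snippet: recompute the Cholesky factor of the
# inducing-point covariance directly, bypassing any cached value, and
# inspect its shape.
L_fresh = psd_safe_cholesky(induc_induc_covar.evaluate())
print(L_fresh.shape)  # expected: num_inducing x num_inducing (grid_size=10 in this report)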

@Balandat (Contributor):

This may be related to cornellius-gp/gpytorch#10

@jacobrgardner:

Let me take a look and see if this is on our end -- in general we believe the variational code is as stable as the exact code at this point, but there may be some issue related to multiple batches or some other particularly complicated use case.

@Balandat (Contributor):

Yeah, I guess something funky must be going on with the caching here; if you look at the screenshot, self._cholesky_factor(induc_induc_covar) should return psd_safe_cholesky(induc_induc_covar.evaluate()) when no cached value is used (unless I'm missing something).

@jacobrgardner:

@cisprague (Author):

Thanks for the help everyone. If manually computing the Cholesky factor gives the correct size, is there a straightforward way to do this with the GPyTorch or BoTorch API?

Just to reiterate, my goal is to map multidimensional heteroskedastic inputs to multidimensional homoskedastic outputs. Is there an easier way to do this? It seems variational methods are the go-to for this.

@thomasahle:

This is also needed to support the BernoulliLikelihood, which only works with ApproximateGP.
I wrote a BayesianOptimizer supporting ApproximateGP here as a demo: https://github.com/thomasahle/noisy-bayesian-optimization
It only supports lower confidence bounds so far.

@Balandat (Contributor):

@thomasahle Here is a simple demo for using the BoTorch acquisition function & optimization machinery with an Approximate GP with Bernoulli Likelihood (model taken from the gpytorch tutorial):
botorch_bernoulli_approx_gp.ipynb.txt

For the basic use case, this is as simple as

  1. Mixing in GPyTorchModel as another superclass to the GP model
  2. Defining a _num_outputs attribute (used internally by BoTorch)

The BatchedMultiOutputGPyTorchModel does some trickery that likely doesn't play well with variational inference, we'll have to take a closer look at that.
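A minimal sketch of that mixin pattern, assuming a single-output Bernoulli classification model along the lines of the GPyTorch tutorial (class, attribute, and kernel choices other than GPyTorchModel and _num_outputs are illustrative, not taken from the linked notebook):

import torch
import gpytorch
from botorch.models.gpytorch import GPyTorchModel


class ApproximateGPClassifier(gpytorch.models.ApproximateGP, GPyTorchModel):

    _num_outputs = 1  # tells BoTorch how many outputs the model produces

    def __init__(self, inducing_points):
        # standard variational setup over a fixed set of inducing points
        vdist = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(-2))
        vstrat = gpytorch.variational.VariationalStrategy(
            self, inducing_points, vdist, learn_inducing_locations=True
        )
        super().__init__(vstrat)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.likelihood = gpytorch.likelihoods.BernoulliLikelihood()

    def forward(self, x):
        # latent GP: multivariate normal parameterised by mean and covariance
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


# usage sketch: BoTorch's posterior/acquisition machinery sees a GPyTorchModel
model = ApproximateGPClassifier(inducing_points=torch.rand(10, 2))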

@Balandat (Contributor):

@cisprague as Jake said, #1047 will probably fix at least part of this issue.

@Balandat (Contributor) commented May 4, 2020:

I'm going to close this for now as this should work fine when not using the BatchedMultiOutputGPyTorchModel. Feel free to re-open if needed.
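For completeness, a minimal sketch of plugging such a plain GPyTorchModel-mixin model into the acquisition optimization path (the acquisition function and its parameters here are illustrative, not prescribed by this issue):

import torch
from botorch.acquisition import UpperConfidenceBound
from botorch.optim import optimize_acqf

# `model` is assumed to be a trained ApproximateGP with the GPyTorchModel mixin;
# bounds describe a 2-d unit-cube input domain as in the reproduce snippet.
acq = UpperConfidenceBound(model, beta=0.2)
candidate, acq_value = optimize_acqf(
    acq_function=acq,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
    q=1,
    num_restarts=2,
    raw_samples=500,
)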

Balandat closed this as completed May 4, 2020