
[Bug] fit_gpytorch_model complains of tensors on multiple devices when using the KroneckerMultiTaskGP #1323

Closed
123epsilon opened this issue Jul 22, 2022 · 18 comments
Labels
bug Something isn't working

Comments

@123epsilon

🐛 Bug

Despite my model and all tensors in the script being on the GPU, fit_gpytorch_model complains that tensors exist on both cuda:0 and the CPU.

To reproduce

This code works when I use a multi-output model built from a series of single-task exact GPs with the corresponding SumMarginalLogLikelihood, but switching to the multitask GP causes this error. The problem doesn't seem to lie in my problem function: the same code works with other models, and all I've changed is the model and the likelihood calculation.

Code snippet to reproduce

from problem_function import reconstruction_problem
from botorch.utils.transforms import normalize
from botorch.models.transforms.outcome import Standardize
from botorch.models.multitask import KroneckerMultiTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch import fit_gpytorch_model
from botorch.utils.sampling import draw_sobol_samples
import torch

def generate_initial_data(problem, n=6):
    # generate training data
    train_x = draw_sobol_samples(bounds=problem.bounds, n=n, q=1).squeeze(1)
    train_obj_true = problem(train_x)
    return train_x, train_obj_true

problem_kwargs = {
    'master_h5': 'master.h5',
    'position_csv': 'pos.csv',
    'probe_npy': '../probe.npy',
    'reconstruct_iter': 100
}

tkwargs = {
        "dtype": torch.double,
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}

problem = reconstruction_problem(device='cuda', **problem_kwargs).to(**tkwargs)

train_x, train_obj = generate_initial_data(problem, n=6)

print(train_x.size()) #torch.Size([6, 3])
print(train_obj.size()) #torch.Size([6, 2])
print(train_x.is_cuda) #True
print(train_obj.is_cuda) #True

train_x = normalize(train_x, problem.bounds)
model = KroneckerMultiTaskGP(train_x, train_obj, outcome_transform=Standardize(m=2)).to('cuda')

print(next(model.parameters()).is_cuda) #True
print(next(model.likelihood.parameters()).is_cuda) #True

mll = ExactMarginalLogLikelihood(model.likelihood, model)

print(next(model.parameters()).is_cuda) #True
print(next(model.likelihood.parameters()).is_cuda) #True
print(train_x.is_cuda) #True
print(train_obj.is_cuda) #True
print(next(mll.parameters()).is_cuda) #True

fit_gpytorch_model(mll)
 

Stack trace/error message

Traceback (most recent call last):
  File "/gpfs/fs1/home/ac.arhammkhan/tike_exp/experiments/mttest.py", line 42, in <module>
    fit_gpytorch_model(mll)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/fit.py", line 130, in fit_gpytorch_model
    mll, _ = optimizer(mll, track_iterations=False, **kwargs)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/optim/fit.py", line 239, in fit_gpytorch_scipy
    res = minimize(
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_minimize.py", line 692, in minimize
    res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_lbfgsb_py.py", line 308, in _minimize_lbfgsb
    sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 263, in _prepare_scalar_function
    sf = ScalarFunction(fun, x0, args, grad, hess,
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 158, in __init__
    self._update_fun()
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 251, in _update_fun
    self._update_fun_impl()
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 155, in update_fun
    self.f = fun_wrapped(self.x)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 137, in fun_wrapped
    fx = fun(np.copy(x), *args)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 76, in __call__
    self._compute_if_needed(x, *args)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 70, in _compute_if_needed
    fg = self.fun(x, *args)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/optim/utils.py", line 217, in _scipy_objective_and_grad
    return _handle_numerical_errors(error=e, x=x)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/optim/utils.py", line 244, in _handle_numerical_errors
    raise error  # pragma: nocover
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/optim/utils.py", line 215, in _scipy_objective_and_grad
    loss = -mll(*args).sum()
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/module.py", line 30, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py", line 63, in forward
    res = self._add_other_terms(res, params)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py", line 43, in _add_other_terms
    res.add_(prior.log_prob(closure(module)).sum())
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/priors/lkj_prior.py", line 105, in log_prob
    log_prob_corr = self.correlation_prior.log_prob(correlations)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/priors/lkj_prior.py", line 62, in log_prob
    return super().log_prob(X_cholesky)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/priors/prior.py", line 27, in log_prob
    return super(Prior, self).log_prob(self.transform(x))
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/torch/distributions/lkj_cholesky.py", line 116, in log_prob
    order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Additional Info

When I run this script with the problem initialization line changed to drop .to(**tkwargs) (i.e. problem = reconstruction_problem(device='cuda', **problem_kwargs)), the script throws a different error:

Traceback (most recent call last):
  File "/gpfs/fs1/home/ac.arhammkhan/tike_exp/experiments/mttest.py", line 42, in <module>
    fit_gpytorch_model(mll)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/fit.py", line 128, in fit_gpytorch_model
    sample_all_priors(mll.model)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/optim/utils.py", line 40, in sample_all_priors
    raise RuntimeError(
RuntimeError: Must provide inverse transform to be able to sample from prior.

Is fit_gpytorch_model incompatible with this model?

Expected Behavior

The fit_gpytorch_model method should run without error.

System information

BoTorch 0.6.5
GPyTorch 1.7.0
PyTorch 1.11.0post202

Linux 5.4.0-117-generic

@123epsilon 123epsilon added the bug Something isn't working label Jul 22, 2022
@Balandat
Contributor

Hmm, this looks like a bug. @esantorella, would you be able to take a look at this next week?

@123epsilon
Author

@Balandat Is there any workaround for this you can recommend?

@Balandat
Contributor

Balandat commented Jul 28, 2022

I tried this out but wasn't able to reproduce this issue. I had to make some changes since I don't have the problem_function module in your code example. The following code runs without error for me on a GPU machine:

# from problem_function import reconstruction_problem
from botorch.utils.transforms import normalize
from botorch.models.transforms.outcome import Standardize
from botorch.models.multitask import KroneckerMultiTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch import fit_gpytorch_model
from botorch.utils.sampling import draw_sobol_samples
import torch

def generate_initial_data(problem, n=6):
    # generate training data
    train_x = draw_sobol_samples(bounds=problem.bounds, n=n, q=1).squeeze(1)
    train_obj_true = problem(train_x)
    return train_x, train_obj_true

problem_kwargs = {
    'master_h5': 'master.h5',
    'position_csv': 'pos.csv',
    'probe_npy': '../probe.npy',
    'reconstruct_iter': 100
}

tkwargs = {
        "dtype": torch.double,
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}

# problem = reconstruction_problem(device='cuda', **problem_kwargs).to(**tkwargs)
# train_x, train_obj = generate_initial_data(problem, n=6)

train_x = torch.rand(6, 3, **tkwargs)
train_obj = torch.rand(6, 2, **tkwargs)

print(train_x.size()) #torch.Size([6, 3])
print(train_obj.size()) #torch.Size([6, 2])
print(train_x.is_cuda) #True
print(train_obj.is_cuda) #True

# train_x = normalize(train_x, problem.bounds)
train_x = normalize(train_x, torch.tensor([[0., 0., 0.], [1., 1., 1.]], **tkwargs))

model = KroneckerMultiTaskGP(train_x, train_obj, outcome_transform=Standardize(m=2)).to('cuda')

print(next(model.parameters()).is_cuda) #True
print(next(model.likelihood.parameters()).is_cuda) #True

mll = ExactMarginalLogLikelihood(model.likelihood, model)

print(next(model.parameters()).is_cuda) #True
print(next(model.likelihood.parameters()).is_cuda) #True
print(train_x.is_cuda) #True
print(train_obj.is_cuda) #True
print(next(mll.parameters()).is_cuda) #True

fit_gpytorch_model(mll)

If I modify train_x = normalize(train_x, torch.tensor([[0., 0., 0.], [1., 1., 1.]], **tkwargs)) to instead be train_x = normalize(train_x, torch.tensor([[0., 0., 0.], [1., 1., 1.]])) then I do get the device error, but at a different place in the stack (in normalize).
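The difference between the two normalize calls is only where the bounds tensor lives. A minimal sketch of the defensive pattern, assuming normalize computes (X - lower) / (upper - lower) over the last dimension; normalize_like is a hypothetical helper written in plain PyTorch, not BoTorch's implementation:

```python
import torch


def normalize_like(X: torch.Tensor, bounds: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: min-max scale X to [0, 1] using a 2 x d bounds tensor.

    The bounds are first moved to X's device and dtype, so bounds created with
    the default settings (CPU, float32) cannot cause a device mismatch.
    """
    bounds = bounds.to(device=X.device, dtype=X.dtype)
    lower, upper = bounds[0], bounds[1]
    return (X - lower) / (upper - lower)


tkwargs = {
    "dtype": torch.double,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}
train_x = torch.rand(6, 3, **tkwargs)
# Deliberately created without **tkwargs, i.e. on the CPU as float32.
bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
train_x_norm = normalize_like(train_x, bounds)
print(train_x_norm.device, train_x_norm.dtype)
```

Creating the bounds with **tkwargs in the first place, as in the working snippet, achieves the same thing without the extra copy.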

Would it be possible to share the exact code you're using? Without that it'll be hard to debug this.

@123epsilon
Author

The backend I'm using for some of my problem calculations is a bit involved to install; luckily, I have a more minimal example based on your code above:

from botorch.utils.transforms import normalize
from botorch.models.transforms.outcome import Standardize
from botorch.models.multitask import KroneckerMultiTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch import fit_gpytorch_model
import torch
from botorch.utils.sampling import draw_sobol_samples

tkwargs = {
        "dtype": torch.double,
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    }

print(f'Device {tkwargs["device"]}')

train_x = torch.rand(1, 3, **tkwargs)
train_obj = torch.rand(1, 2, **tkwargs)

print(train_x.size())
print(train_obj.size())
print(train_x.is_cuda)
print(train_obj.is_cuda)
train_x = normalize(train_x, torch.tensor([[0., 0., 0.], [1., 1., 1.]], **tkwargs))

model = KroneckerMultiTaskGP(train_x, train_obj, outcome_transform=Standardize(m=2))
print(next(model.parameters()).is_cuda)
print(next(model.likelihood.parameters()).is_cuda)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

print(next(model.parameters()).is_cuda)
print(next(model.likelihood.parameters()).is_cuda)
print(train_x.is_cuda)
print(train_obj.is_cuda)
print(next(mll.parameters()).is_cuda)

fit_gpytorch_model(mll)

This still produces the same error:

Device cuda
torch.Size([1, 3])
torch.Size([1, 2])
True
True
True
True
True
True
True
True
True
Traceback (most recent call last):
  File "/gpfs/fs1/home/ac.arhammkhan/tike_exp/experiments/mttest.py", line 55, in <module>
    fit_gpytorch_model(mll)#.to(**tkwargs))
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/fit.py", line 130, in fit_gpytorch_model
    mll, _ = optimizer(mll, track_iterations=False, **kwargs)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/optim/fit.py", line 239, in fit_gpytorch_scipy
    res = minimize(
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_minimize.py", line 692, in minimize
    res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_lbfgsb_py.py", line 308, in _minimize_lbfgsb
    sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 263, in _prepare_scalar_function
    sf = ScalarFunction(fun, x0, args, grad, hess,
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 158, in __init__
    self._update_fun()
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 251, in _update_fun
    self._update_fun_impl()
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 155, in update_fun
    self.f = fun_wrapped(self.x)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 137, in fun_wrapped
    fx = fun(np.copy(x), *args)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 76, in __call__
    self._compute_if_needed(x, *args)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 70, in _compute_if_needed
    fg = self.fun(x, *args)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/optim/utils.py", line 217, in _scipy_objective_and_grad
    return _handle_numerical_errors(error=e, x=x)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/optim/utils.py", line 244, in _handle_numerical_errors
    raise error  # pragma: nocover
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/optim/utils.py", line 215, in _scipy_objective_and_grad
    loss = -mll(*args).sum()
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/module.py", line 30, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py", line 63, in forward
    res = self._add_other_terms(res, params)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py", line 43, in _add_other_terms
    res.add_(prior.log_prob(closure(module)).sum())
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/priors/lkj_prior.py", line 105, in log_prob
    log_prob_corr = self.correlation_prior.log_prob(correlations)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/priors/lkj_prior.py", line 62, in log_prob
    return super().log_prob(X_cholesky)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/gpytorch/priors/prior.py", line 27, in log_prob
    return super(Prior, self).log_prob(self.transform(x))
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/torch/distributions/lkj_cholesky.py", line 116, in log_prob
    order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@Balandat
Contributor

Hmm, it looks like your example doesn't call .to('cuda') on the instantiated KroneckerMultiTaskGP. Does doing that fix the issue? (Note that this was part of your original code example.)

@123epsilon
Author


Explicitly adding this to the model initialization line, as in my original code example, results in the same error. It's worth noting that even in the example without it, torch reports the model parameters as being on the GPU (the print(next(model.parameters()).is_cuda) statement).
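Note that next(model.parameters()).is_cuda only inspects the first parameter. A somewhat more thorough check, sketched in plain PyTorch with a hypothetical helper name (tensor_devices), collects the devices of every registered parameter and buffer; more than one device in the result means something was left behind. It still cannot see tensors stored as plain Python attributes, so a single device in the result is not proof that everything is on the GPU.

```python
import itertools

import torch


def tensor_devices(module: torch.nn.Module) -> set:
    """Return the set of devices used by all parameters and buffers of `module`.

    Tensors stored as plain attributes (not registered as parameters or
    buffers) are NOT visible to this check.
    """
    tensors = itertools.chain(module.parameters(), module.buffers())
    return {t.device for t in tensors}


# Example on a plain PyTorch module; a model or mll could be passed instead.
layer = torch.nn.BatchNorm1d(4)  # has both parameters and buffers
print(tensor_devices(layer))
```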

@Balandat
Contributor

Balandat commented Aug 2, 2022

Interesting. I don't run into this error (this is on the current dev versions of PyTorch and GPyTorch). My first guess is that this is either a change in type promotion on the PyTorch side or a change in the recent GPyTorch setup (some of the changes from GPyTorch 1.7.0 to 1.8.0 deal with moving tensors to the correct devices). Can you try running this on PyTorch 1.12 with GPyTorch 1.8.0 to see if that fixes the issue?
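To illustrate the general kind of device bug meant here (an assumption; the thread does not pin down the exact cause): Module.to() only moves tensors registered as parameters or buffers, so a tensor kept as a plain attribute, for example a parameter of a torch.distributions object, stays where it was created. The toy module below is hypothetical and uses a dtype cast in place of a GPU move so it runs on a CPU-only machine; the mechanism is the same for devices.

```python
import torch


class Holder(torch.nn.Module):
    """Toy module: one tensor registered as a buffer, one stored as a plain attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("tracked", torch.ones(2))  # converted by .to()
        self.untracked = torch.ones(2)  # plain attribute: ignored by .to()


holder = Holder().to(torch.double)  # the dtype cast stands in for a device move
print(holder.tracked.dtype)    # torch.float64
print(holder.untracked.dtype)  # torch.float32
```

A later operation that mixes the two tensors would then fail with a dtype (or, for devices, a device) mismatch, much like the LKJCholesky log_prob call in the traceback.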

@123epsilon
Author

Thank you for the help! I can confirm that the above script works with pytorch 1.12.0, gpytorch 1.8.0, and botorch 0.6.5 with no device errors.

As an aside (I can open a separate issue for this as well, but any intuition you have now would be helpful): when training on the CPU, after optimizing over my custom problem function for a long time without any issue, my training script seemingly randomly throws the following error:

Traceback (most recent call last):
  File "/lcrc/project/ECP-EZ/arham/cesm_exp/experiments/multitask_mobo.py", line 109, in <module>
    fit_gpytorch_model(mll, sequential=sequential)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/fit.py", line 128, in fit_gpytorch_model
    sample_all_priors(mll.model)
  File "/home/ac.arhammkhan/.conda/envs/ptychodus/lib/python3.10/site-packages/botorch/optim/utils.py", line 40, in sample_all_priors
    raise RuntimeError(
RuntimeError: Must provide inverse transform to be able to sample from prior.

Looking at the source, I can't understand why this property of the model would change at any point during training. I am following the same loop as the MOBO tutorial in the docs, except that I initialize the model and compute the mll as shown above. The error seems to appear only after running for a few hours.

@saitcakmak
Contributor

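So, you'd only call sample_all_priors if the initial attempt at model fitting failed: on a retry, fit_gpytorch_model re-initializes the hyperparameters by sampling from their priors (the sample_all_priors call at botorch/fit.py line 128 in the traceback above). It looks like KroneckerMultiTaskGP has some priors that do not support this functionality.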
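The control flow described here can be sketched as follows. This is a simplified, hypothetical illustration (fit_with_retries, fit_once, and sample_priors are made-up names), not BoTorch's actual code. It shows why a model whose priors cannot be sampled only fails once a first fitting attempt has already failed, which is consistent with the error showing up only after hours of otherwise successful iterations.

```python
from typing import Callable, List


def fit_with_retries(
    fit_once: Callable[[], bool],
    sample_priors: Callable[[], None],
    max_retries: int = 5,
) -> int:
    """Hypothetical fit-with-retries loop; returns the number of attempts made.

    `fit_once` returns True on success. Priors are sampled only on retries
    (attempt > 0), never on the first attempt.
    """
    for attempt in range(max_retries):
        if attempt > 0:
            # Only reached after a failed attempt: re-initialize from the priors.
            sample_priors()
        if fit_once():
            return attempt + 1
    return max_retries


calls: List[str] = []
outcomes = iter([False, True])  # simulate: first attempt fails, second succeeds


def fake_fit() -> bool:
    calls.append("fit")
    return next(outcomes)


def fake_sample() -> None:
    calls.append("sample")


n_attempts = fit_with_retries(fake_fit, fake_sample)
print(n_attempts, calls)  # 2 ['fit', 'sample', 'fit']
```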

@saitcakmak
Contributor

It looks like both MultitaskGaussianLikelihood and IndexKernel are missing a setting_closure. I don't have much context on what a setting_closure is; the GPyTorch description reads:

            setting_closure (callable, optional):
                A function taking in the module instance and a tensor in (transformed) parameter space,
                initializing the internal parameter representation to the proper value by applying the
                inverse transform. Enables setting parametres directly in the transformed space, as well
                as sampling parameter values from priors (see `sample_from_prior`)

I guess it needs to be provided for you to be able to sample from the priors and set those values on the model.
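To make the role of a setting closure concrete, here is a toy sketch in plain PyTorch. It does not use GPyTorch; ToyModule, closure, setting_closure, and sample_from_prior are hypothetical stand-ins that mirror the docstring quoted above. The closure maps the raw parameter into the transformed space, and the setting closure applies the inverse transform so that a value drawn from a prior can be written back into the raw parameter.

```python
import torch
import torch.nn.functional as F


class ToyModule:
    """Hypothetical stand-in for a module with a positivity-constrained parameter.

    The optimizer works on `raw_lengthscale`; the user-facing value is
    softplus(raw_lengthscale), which is always positive.
    """

    def __init__(self) -> None:
        self.raw_lengthscale = torch.zeros(1, dtype=torch.double)


def closure(m: ToyModule) -> torch.Tensor:
    # Raw parameter -> transformed (constrained) space; this is what a prior scores.
    return F.softplus(m.raw_lengthscale)


def setting_closure(m: ToyModule, value: torch.Tensor) -> None:
    # Inverse softplus: raw = log(exp(value) - 1), written stably as
    # value + log(-expm1(-value)).
    m.raw_lengthscale = value + torch.log(-torch.expm1(-value))


def sample_from_prior(m: ToyModule, prior: torch.distributions.Distribution) -> torch.Tensor:
    # Draw in the transformed space, then write back through the inverse transform.
    # Without a setting closure there would be no way to perform this write-back.
    value = prior.sample(torch.Size([1]))
    setting_closure(m, value)
    return value


torch.manual_seed(0)
module = ToyModule()
prior = torch.distributions.Gamma(
    torch.tensor(3.0, dtype=torch.double), torch.tensor(6.0, dtype=torch.double)
)
sample = sample_from_prior(module, prior)
print(torch.allclose(closure(module), sample))  # True
```

For a simple positivity constraint the inverse transform is a one-liner; the difficulty with a covariance prior is that the "inverse" must map a whole sampled matrix back onto the kernel's parameters.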

@Balandat
Contributor

This is most likely the LKJCovariancePrior over the intra-task correlation matrix, defined by default here:

task_covar_prior = LKJCovariancePrior(

Tracing this down, it is registered here: https://github.com/cornellius-gp/gpytorch/blob/d171863c50ab16b5bfb7035e579dcbe53169e703/gpytorch/kernels/index_kernel.py#L71

Basically this would need a setting_closure. In this case we're passing a covariance matrix Sigma in, so what we'd have to do here is define the closure to take in Sigma, factor it into a correlation matrix C and the variances var, perform a root decomposition of C and then set the covar_factor and var attributes of the IndexKernel.
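The decomposition outlined above can be sketched in plain PyTorch. This is a hypothetical helper, not GPyTorch code, and it rests on stated assumptions: the IndexKernel task covariance has the form covar_factor @ covar_factor.T + diag(var), its rank equals the number of tasks, and var must stay strictly positive. To satisfy the last point, the sketch reserves a small jitter for var and pushes everything else into covar_factor, which is one design choice among several.

```python
import torch


def sigma_to_index_kernel_params(sigma: torch.Tensor, jitter: float = 1e-6):
    """Hypothetical helper: split a task covariance into (covar_factor, var).

    Returns tensors such that covar_factor @ covar_factor.T + diag(var)
    reproduces `sigma` up to floating-point error.
    """
    n = sigma.shape[-1]
    # Keep a small positive diagonal for `var`; the remainder must stay positive definite.
    var = torch.full((n,), jitter, dtype=sigma.dtype, device=sigma.device)
    remainder = sigma - torch.diag(var)
    # Factor the remainder as S @ C @ S with S = diag(std) and C a correlation matrix.
    std = remainder.diagonal().sqrt()
    corr = remainder / (std.unsqueeze(-1) * std.unsqueeze(-2))
    # Root decomposition of C (here a Cholesky factor): C = root @ root.T.
    root = torch.linalg.cholesky(corr)
    # Scaling row i by std[i] gives covar_factor @ covar_factor.T = S @ C @ S.
    covar_factor = std.unsqueeze(-1) * root
    return covar_factor, var


sigma = torch.tensor([[2.0, 0.6], [0.6, 1.0]], dtype=torch.double)
covar_factor, var = sigma_to_index_kernel_params(sigma)
reconstructed = covar_factor @ covar_factor.T + torch.diag(var)
print(torch.allclose(reconstructed, sigma))  # True
```

In an actual setting closure, the two returned tensors would then be written into the kernel's covar_factor parameter and its constrained var.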

@Balandat
Contributor

cc @j-wilson who I just talked to about LKJ priors earlier today...

@123epsilon
Author

Hi, just wondering about the status of this issue and whether there is anything I could do in the meantime to make use of a multitask GP. In my use case, every input is evaluated on all objective functions, which is why I defaulted to the Kronecker model above.

@jackliu333

Does fit_gpytorch_model(mll) only work on the CPU and not on the GPU? I have to move mll from CUDA to the CPU, run fit_gpytorch_model(mll), and then move it back to the GPU to complete one BO loop. Is there a way to do this entirely on the GPU?

@Balandat
Contributor

@jackliu333 what makes you think that fit_gpytorch_model (now deprecated and superseded by fit_gpytorch_mll) does not work when the mll lives on the GPU? Are you getting errors when trying to do this?

@jackliu333

jackliu333 commented Dec 15, 2022

Yes, I ran into the following error:

from botorch.fit import fit_gpytorch_mll
fit_gpytorch_mll(mll);
---------------------------------------------------------------------------
MDNotImplementedError                     Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/botorch/utils/dispatcher.py in __call__(self, *args, **kwargs)
     92         try:
---> 93             return func(*args, **kwargs)
     94         except MDNotImplementedError:

8 frames
MDNotImplementedError: 

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/torch/_tensor.py in __array__(self, dtype)
    955             return self.numpy()
    956         else:
--> 957             return self.numpy().astype(dtype, copy=False)
    958 
    959     # Wrap Numpy array again in a suitable tensor when done, to support e.g.

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

which disappears after I move mll to the CPU:
fit_gpytorch_mll(mll.cpu());

The model is on the GPU:
print(next(model.parameters()).is_cuda) True
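The TypeError above is PyTorch refusing to convert a CUDA tensor to NumPy implicitly: NumPy arrays live in host memory, so a CUDA tensor must first be copied to the CPU, as the error message suggests. A minimal sketch of the explicit conversion, using a hypothetical helper name (to_numpy); it shows the general PyTorch pattern, not how BoTorch's fitting code handles this internally:

```python
import numpy as np
import torch


def to_numpy(t: torch.Tensor) -> np.ndarray:
    """Copy a tensor to host memory if needed and return it as a NumPy array.

    .detach() drops the autograd graph (required for tensors with
    requires_grad=True), and .cpu() is a no-op for tensors already on the CPU.
    """
    return t.detach().cpu().numpy()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.rand(3, device=device, requires_grad=True)
arr = to_numpy(x)
print(type(arr), arr.shape)
```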

@Balandat
Contributor

Moving this to a new issue: #1566

@esantorella
Member

Closing in favor of #1860 for clarity since the original issue is resolved.


5 participants