This issue was moved to a discussion.

Modifying Knowledge Gradient for time-dependent kernels #578

Closed

r-ashwin opened this issue Oct 20, 2020 · 37 comments

@r-ashwin

r-ashwin commented Oct 20, 2020

Issue description

I want to modify KG for time-dependent problems as follows. Given x in X (some compact space) and 0 <= t <= T, I have a GP model with prior GP(mu, k_xt), where k_xt = k_x * k_t, with k_x capturing covariance in the 'x' space and k_t in the 't' space. At time t I have data D_t = {((x_i, t_i), y_i)}, i = 1, ..., n, with t > t_n. I want to define KG as follows:

a_KG(x, t) = E_y [ max_{x'} mu(x', T | D_t ∪ {((x, t), y)}) ]

where y is sampled from the posterior GP(mu, k_xt) | D_t evaluated at (x, t). In other words, my 'fantasy model' is conditioned at the current time t; however, my 'inner optimization' problem maximizes the posterior at T predicted by the fantasy model. My acquisition function a_KG is also defined at t.

Question: How should I modify the qKnowledgeGradient class to achieve this, so I can take advantage of the efficient one-shot implementation of qKG? I have provided code for the GP I am using if you want to work with that.

Any help is greatly appreciated! Please let me know if you need more information. Thanks!

(apologies for trying to write equations in Markdown)

import gpytorch                                      # main GP library
import numpy as np
import torch

from botorch.fit import fit_gpytorch_model           # wrapper for fitting GPyTorch models in BO
from botorch.models import SingleTaskGP
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.mlls import ExactMarginalLogLikelihood

def canned_dynamic_gp(train_x, train_y):
    '''
    fits a single-task GP for f(x,t) with a product kernel k_xt = k_x * k_t

    :param train_x:
    :param train_y:
    :return: gp object
    '''

    class ExactGPModel(gpytorch.models.ExactGP):
        num_outputs = 1  # to inform the BoTorch api
        def __init__(self, train_x, train_y, likelihood):
            super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.Rbfx_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
            self.Rbft_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        def forward(self, x):
            X = x[:, :-1]
            t = x[:, -1]
            mean_x = self.mean_module(x)
            covar_x = self.Rbfx_module(X) * self.Rbft_module(t)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    # initialize likelihood and model
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    gp         = ExactGPModel(train_x, train_y, likelihood)
    mll        = ExactMarginalLogLikelihood(gp.likelihood, gp)
    gp.likelihood.noise_covar.raw_noise_constraint.upper_bound = 1e-3 # constraint on observation noise
    fit_gpytorch_model(mll)

    return gp

# Synthetic time-dependent test function
def quadratic(x, a=-4., s=0.5):
    y = a * np.sum(np.power(x - s, 2), axis=1)  # quadratic
    return y

def f_xt_d(x, p, coeff=None, t=None):
    n, _ = np.shape(x)
    if t is None and x.shape[1] > p:
        t = x[:, -1]
        X = x[:, :-1]
    elif t is not None and x.shape[1] == p:
        X = x
    else:
        raise ValueError('x must have p+1 columns when t is None')

    if len(t) != n and len(t) != 1:
        raise ValueError('t should be of length 1 or n')

    if coeff is None:
        coeff = np.array([1., 1.])

    phi_xt = np.vstack((2 * np.sum(np.multiply(X, np.atleast_2d(np.sin(t)).T ), axis=1),
                        -np.power(np.atleast_2d(np.sin(t)), 2)
                        ))
    f_xt = np.matmul(np.atleast_2d(phi_xt).T, coeff)
    return f_xt

training_size = 100
p  = 1  # x-dimensions
lb = 0. # x lower bound
ub = 1. # x upper bound
T  = 4. # t upper bound

bounds  = torch.stack([torch.zeros(2), torch.tensor([1., T])])
fxt_func= f_xt_d

t          = np.linspace(0, 3.9, training_size)
train_x = lb + (ub-lb) * np.random.uniform(size=[training_size, p])
train_x = np.hstack((train_x, np.atleast_2d(t).T))
train_y = quadratic(np.atleast_2d(train_x[:,0]).T) + fxt_func(train_x, p,)

# convert everything to Torch
train_x = torch.tensor(train_x, dtype=torch.float32)
train_y = torch.tensor(train_y, dtype=torch.float32)
# GP model
gp = canned_dynamic_gp(train_x, train_y)

System Info

Please provide information about your setup, including

  • BoTorch Version 0.3.1
  • GPyTorch Version 1.2.0
  • PyTorch Version 1.6.0
  • Windows and Linux
@Balandat
Contributor

This is an interesting problem. It's quite related to the multi-fidelity setting, where we take a measurement at some fidelity and then project to a "target fidelity". This is done in qMultiFidelityKnowledgeGradient: https://github.com/pytorch/botorch/blob/master/botorch/acquisition/knowledge_gradient.py#L415-L417

I imagine you can do something similar where you essentially return the posterior of the fantasy model evaluated at time T. Since you probably don't want to optimize over t (i.e., when to make the observation), you will probably want to not pass it in as a variable to optimize, but instead keep that time value as an attribute on the acquisition instance.

I'm pretty swamped right now, but I can take a look later this week - hope the pointers above are helpful in the meantime.
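
For concreteness, a minimal sketch of what such a projection callable might look like (illustrative only; T and model stand in for your own horizon and fitted GP):

import torch
from botorch.acquisition import qMultiFidelityKnowledgeGradient

T = 4.0  # target time horizon (placeholder)

def project_to_T(X: torch.Tensor) -> torch.Tensor:
    # clone to avoid mutating the optimizer's tensor in place,
    # then overwrite the last (time) column with the target T
    X = X.clone()
    X[..., -1] = T
    return X

# qMFKG = qMultiFidelityKnowledgeGradient(model=model, project=project_to_T)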

@r-ashwin
Author

Thanks for the pointers -- I will give it a shot. Whenever you get the time, an illustration with some toy code would be great, since I am very new to the Pytorch paradigm.

@saitcakmak
Contributor

Hi @r-ashwin. You can achieve this by passing the fixed_features argument to optimize_acqf.

See the following simple example where I pretend that X = (x, t) and fix t = 0.5 by passing fixed_features = {1: 0.5}. fixed_features is a dictionary where the key is the index of the variable that you want fixed, and the value is the value to fix it to. By passing fixed_features = {1: 0.5}, I fix X[..., 1] = 0.5, which corresponds to the time in my example. If you had multiple values to fix, e.g. X = (x_1, x_2, t_1, t_2), you would then use fixed_features = {2: fixed_t_1, 3: fixed_t_2} etc. Hope this helps!

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition import qKnowledgeGradient
from botorch.optim import optimize_acqf

dim = 2
model = SingleTaskGP(torch.rand(10, dim), torch.rand(10, 1))

kg = qKnowledgeGradient(model=model)

candidate, value = optimize_acqf(
    acq_function=kg,
    bounds=torch.tensor([[0., 0.], [1., 1.]]),
    q=1,
    num_restarts=5,
    raw_samples=10,
    fixed_features={1: 0.5},
)

assert candidate[..., -1] == 0.5

@r-ashwin
Author

@saitcakmak Thanks for the tip -- I was not aware of the fixed_features argument! While it does not solve my problem by itself, it is a key ingredient in ensuring my a_KG(x, t) is computed at a fixed t. The other part would be to modify the fantasy model to compute only at T, as @Balandat suggested.

@saitcakmak
Contributor

saitcakmak commented Oct 23, 2020

Let me see if I understand this correctly. You want KG to be evaluated using fantasies conditioned on some t, while optimizing the value function at T. The solution suggested above would use the same value for both t and T, and that's not what you want.

In that case, a simple wrapper around qKG (below) may work. You could probably also achieve this by passing an objective to qKG that returns mu(x, T); I just don't have that much experience with those.

class aKG(qKnowledgeGradient):
    def __init__(self, model: Model, T: float, **kwargs):
        super(aKG, self).__init__(model, **kwargs)
        self.T = T

    def forward(self, X: Tensor):
        # fix the t-coordinate of the inner solutions to T (assumes q=1)
        X[..., 1:, -1] = self.T
        return super(aKG, self).forward(X)

P.S. This wrapper approach is not fully compatible with the heuristic for generating the inner solutions. The heuristic would maximize mu(x, t) over (x, t) and pick the inner solutions via a softmax approach among the optimizers of mu(x, t). If the x's maximizing mu(x, t) differ from those maximizing mu(x, T), this may produce low-quality restart points. You wouldn't run into this if you were to use an objective instead (the heuristic would maximize the objective, which is mu(x, T)). So, that may be the better solution.

@r-ashwin
Author

Yes, my fantasy model is conditioned at t but it is used to evaluate the GP posterior at T. Your suggestion makes sense to me but let me give it a shot and get back to you.

@Balandat
Contributor

As an alternative to the fixed_features argument, there is a FixedFeatureAcquisitionFunction wrapper that you should be able to use for this purpose. The benefit is that it actually reduces the dimensionality of the optimization problem (whereas the fixed_features approach does not; the optimizer basically has to figure out during the optimization that the fixed dimension is irrelevant).
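
For illustration, a minimal sketch of that wrapper, assuming a 2-dimensional input (x, t) where column 1 (time) is fixed to 0.5 and acqf is some already-constructed acquisition function:

from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction

# fix column 1 (the time) of the d=2 input to 0.5; the wrapped acqf
# is then optimized over the remaining 1-D x space only
acqf_ff = FixedFeatureAcquisitionFunction(
    acq_function=acqf, d=2, columns=[1], values=[0.5]
)

# note: pass only the bounds of the free dimension(s) to optimize_acqf,
# e.g. bounds=torch.tensor([[0.], [1.]])

(For one-shot acquisition functions like qKG there is a subtlety with the q dimension when optimizing this wrapper; see the discussion further down the thread.)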

@Balandat
Contributor

@saitcakmak The issue with the approach suggested above is that it wouldn't fantasize from t but from T. What you need is to

  1. use the project operation in the qMultiFidelityKnowledgeGradient I mentioned in my comment above. You essentially have to implement a simple callable that modifies the time in your input so that it is evaluated at T
  2. wrap the whole thing in a FixedFeatureAcquisitionFunction where the time in the input gets clamped to t instead.

Hope this helps.

@saitcakmak
Contributor

@Balandat, it would work properly if used with fixed_features, right? I guess X[..., 1:, -1] = self.T is assuming q=1, but otherwise a more general version like X[..., -self.num_fantasies:, -1] = self.T should not affect the fantasies, which are fixed to t by the fixed_features. Hopefully, I am not missing something here.

I like your approach better since it is less hacky. One issue I see with it is that it loses the smart heuristic used to generate the initial_conditions for KG (my approach gets somewhat messy here as well). Since FixedFeatureAcquisitionFunction is not a subclass of qKnowledgeGradient, optimize_acqf uses gen_batch_initial_conditions rather than gen_one_shot_kg_initial_conditions. I just thought this might be worth mentioning. It is probably fine as long as raw_samples is large enough.

@r-ashwin
Author

r-ashwin commented Oct 24, 2020

Yes, my fantasy model is conditioned at t but it is used to evaluate the GP posterior at T. Your suggestion makes sense to me but let me give it a shot and get back to you.

@saitcakmak So it looks like I cannot use gpytorch.models.ExactGP because it does not have fantasize() implemented. When I instead subclass botorch.models.SingleTaskGP, there appears to be a tensor shape mismatch when performing matmul. It seems to be due to the product composition of the kernels I am using, i.e., k = k_x * k_t. I am not sure whether these operations are not yet supported in the GPyTorch wrappers within BoTorch or I am missing something here.
I am debugging this on my end, but if this seems like a familiar issue, your feedback is appreciated. See code and error below.

Update:

I get the same error when using the multifidelity qKG. Happy to share that code as well if necessary, but did not want to clutter the space. It looks like the common problem for both cases is that the posterior() is not compatible with the product kernel.

def canned_dynamic_gp2(train_x, train_y):
    '''
    fits a single-task GP for f(x,t) with a product kernel k_xt = k_x * k_t

    :param train_x:
    :param train_y:
    :return: gp object
    '''

    class ExactGPModel(SingleTaskGP, GPyTorchModel):
        _num_outputs = 1  # to inform the BoTorch api
        ard_num_dims = 1 
        def __init__(self, train_x, train_y, likelihood):
            super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.Rbfx_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
            self.Rbft_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        def forward(self, x):
            X = x[..., :-1]
            t = x[..., -1]
#             print(X.shape)
#             print('X', X)
            mean_x = self.mean_module(x)
            covar_x = self.Rbfx_module(X) * self.Rbft_module(t)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    # initialize likelihood and model
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    gp         = ExactGPModel(train_x, train_y, likelihood)
    mll        = ExactMarginalLogLikelihood(gp.likelihood, gp)
    gp.likelihood.noise_covar.raw_noise_constraint.upper_bound = 1e-3 # constraint on observation noise
    fit_gpytorch_model(mll)

    return gp

bounds  = torch.stack([torch.zeros(2), torch.tensor([1,4])])
gp = canned_dynamic_gp2(train_x, train_y.reshape(-1,1))

from botorch.acquisition import qKnowledgeGradient
from botorch.models.model import Model
from torch import Tensor

class aKG(qKnowledgeGradient):
    def __init__(self, model: Model, T: float, **kwargs):
        super(aKG, self).__init__(model, **kwargs)
        self.T = T

    def forward(self, X: Tensor):
        X[..., 1:, -1] = self.T
        return super(aKG, self).forward(X)

qKG = aKG(gp, 2.)

eval_X = torch.cat([torch.linspace(0.,1.,50).reshape(-1,1), 1.95 * torch.ones(50).reshape(-1,1)])
eval_X = eval_X.reshape(50, -1, 2)
options = {
    "num_restarts": 10,
    "raw_samples": 1000,
}
random_evals = qKG.evaluate(eval_X, bounds=bounds, options=options)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~\Anaconda3\lib\site-packages\gpytorch\lazy\lazy_tensor.py in mul(self, other)
   1149         try:
-> 1150             _mul_broadcast_shape(self.shape, other.shape)
   1151         except RuntimeError:

~\Anaconda3\lib\site-packages\gpytorch\utils\broadcasting.py in _mul_broadcast_shape(error_msg, *shapes)
     19                 if error_msg is None:
---> 20                     raise RuntimeError("Shapes are not broadcastable for mul operation")
     21                 else:

RuntimeError: Shapes are not broadcastable for mul operation

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
<ipython-input-45-a1b272311984> in <module>
      5     "raw_samples": 1000,
      6 }
----> 7 random_evals = qKG.evaluate(eval_X, bounds=bounds, options=options)

~\Anaconda3\lib\site-packages\botorch\utils\transforms.py in decorated(cls, X, **kwargs)
    196         if cls.X_pending is not None:
    197             X = torch.cat([X, match_batch_shape(cls.X_pending, X)], dim=-2)
--> 198         return method(cls, X, **kwargs)
    199 
    200     return decorated

~\Anaconda3\lib\site-packages\botorch\utils\transforms.py in decorated(cls, X, **kwargs)
    167                 )
    168             X = X if X.dim() > 2 else X.unsqueeze(0)
--> 169             return method(cls, X, **kwargs)
    170 
    171         return decorated

~\Anaconda3\lib\site-packages\botorch\acquisition\knowledge_gradient.py in evaluate(self, X_actual, bounds, **kwargs)
    213         # construct the fantasy model of shape `num_fantasies x b`
    214         fantasy_model = self.model.fantasize(
--> 215             X=X_actual, sampler=self.sampler, observation_noise=True
    216         )
    217 

~\Anaconda3\lib\site-packages\botorch\models\model.py in fantasize(self, X, sampler, observation_noise, **kwargs)
    122         propagate_grads = kwargs.pop("propagate_grads", False)
    123         with settings.propagate_grads(propagate_grads):
--> 124             post_X = self.posterior(X, observation_noise=observation_noise)
    125         Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
    126         return self.condition_on_observations(X=X, Y=Y_fantasized, **kwargs)

~\Anaconda3\lib\site-packages\botorch\models\gpytorch.py in posterior(self, X, output_indices, observation_noise, **kwargs)
    299                     X=X, original_batch_shape=self._input_batch_shape
    300                 )
--> 301             mvn = self(X)
    302             if observation_noise is not False:
    303                 if torch.is_tensor(observation_noise):

~\Anaconda3\lib\site-packages\gpytorch\models\exact_gp.py in __call__(self, *args, **kwargs)
    312 
    313             # Get the joint distribution for training/test data
--> 314             full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
    315             if settings.debug().on():
    316                 if not isinstance(full_output, MultivariateNormal):

~\Anaconda3\lib\site-packages\gpytorch\module.py in __call__(self, *inputs, **kwargs)
     26 
     27     def __call__(self, *inputs, **kwargs):
---> 28         outputs = self.forward(*inputs, **kwargs)
     29         if isinstance(outputs, list):
     30             return [_validate_module_outputs(output) for output in outputs]

<ipython-input-40-efe4cedf8577> in forward(self, x)
     23 #             print('X', X)
     24             mean_x = self.mean_module(x)
---> 25             covar_x = self.Rbfx_module(X) * self.Rbft_module(t)
     26             return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
     27 

~\Anaconda3\lib\site-packages\gpytorch\lazy\lazy_tensor.py in __mul__(self, other)
   1889 
   1890     def __mul__(self, other):
-> 1891         return self.mul(other)
   1892 
   1893     def __radd__(self, other):

~\Anaconda3\lib\site-packages\gpytorch\lazy\lazy_tensor.py in mul(self, other)
   1151         except RuntimeError:
   1152             raise RuntimeError(
-> 1153                 "Cannot multiply LazyTensor of size {} by an object of size {}".format(self.shape, other.shape)
   1154             )
   1155 

RuntimeError: Cannot multiply LazyTensor of size torch.Size([50, 40, 40]) by an object of size torch.Size([50, 50])

@saitcakmak
Contributor

It looks like self.Rbfx_module(X) is 50 x 40 x 40 and self.Rbft_module(t) is 50 x 50. This is due to t = x[..., -1] squeezing the last dimension. You can fix this bug by replacing it with t = x[..., -1:].
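
For reference, the forward method with that one-line fix would read:

def forward(self, x):
    X = x[..., :-1]
    t = x[..., -1:]  # keep the trailing dimension so the kernel shapes broadcast
    mean_x = self.mean_module(x)
    covar_x = self.Rbfx_module(X) * self.Rbft_module(t)
    return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)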

I also noticed that you're using qKG.evaluate with the aKG wrapper defined above. I just wanted to point out that this will not be solving the inner optimization at T unless you modify the evaluate method.

@Balandat
Contributor

SingleTaskGP does some rearranging of dimensions internally to allow efficient multi-output modeling (outputs are interpreted as batches). To get the fantasize method for your custom model, you should subclass from ExactGP and GPyTorchModel (instead of from SingleTaskGP and GPyTorchModel).

@r-ashwin
Author

To get the fantasize method for your custom model, you should subclass from ExactGP and GPyTorchModel (instead of from SingleTaskGP and GPyTorchModel).

Actually, subclassing SingleTaskGP works as long as I ensure t is 2d, as suggested by @saitcakmak. On the other hand, when I subclass from ExactGP and GPyTorchModel, I get a ModuleAttributeError since it is missing _input_batch_shape.

If I may suggest, forward compatibility between any GPyTorch model and BoTorch would be great for users like myself who are interested in using several GPyTorch models within BoTorch.

@r-ashwin
Author

@saitcakmak Just to make sure I understood correctly, in your aKG wrapper above, you are updating X[...,-1] with T. However, if I want to get a fantasy model at t I should instead replace T in aKG with t, correct? Once I do this, I will have the correct fantasy model provided to evaluate(), however, how can I modify the _get_value_function to make sure the inner optimization is done at T? It appears as though it would still be doing the inner opt at t.

@saitcakmak
Contributor

On the other hand, when I subclass from ExactGP and GPyTorchModel, I get a ModuleAttributeError since it is missing _input_batch_shape.

I think you're running into a bug that I introduced here:

batch_shape = acq_function.model._input_batch_shape

That line would raise an AttributeError whenever you use qKG.evaluate with a model that is not a subclass of BatchedMultiOutputGPyTorchModel (this sets the _input_batch_shape attribute). For the models I have used, _input_batch_shape = model.train_inputs[0].shape[:-2], so you could probably replace that line with batch_shape = acq_function.model.train_inputs[0].shape[:-2] in your local installation to fix that issue. I am not sure if this is true for all GPyTorch models though. Otherwise, I believe BoTorch is compatible with all GPyTorch models.

@Balandat, is it the case that _input_batch_shape is always given by model.train_inputs[0].shape[:-2]? What would be the best way to fix this bug? I will make a separate issue for this and another bug I found regarding evaluate.

in your aKG wrapper above, you are updating X[...,-1] with T. However, if I want to get a fantasy model at t I should instead replace T in aKG with t, correct?

The forward method of qKG splits X in two over the q-batch dimension (dim=-2): the first q, i.e., X[..., :q, :], are the candidates used to fantasize, and the remaining num_fantasies, i.e., X[..., -num_fantasies:, :], are the solutions to the inner problems. The aKG wrapper above, assuming q=1, fixes the t entry of the inner solutions (X[..., -num_fantasies:, :], equivalently X[..., 1:, :] if q=1) to T. It does not do anything to the X[..., :q, :] that is used to generate the fantasies. You could then use fixed_features as explained above to fix these to t (this will fix all of them to t, but the inner solutions are then overwritten by the wrapper). But this is a dirty solution, so you may want to follow @Balandat's suggestion (re-iterated below) instead.
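
Schematically (a sketch of the split, not the exact BoTorch source):

# X has shape batch_shape x (q + num_fantasies) x d
X_actual    = X[..., :q, :]                # candidates used to generate the fantasies
X_fantasies = X[..., -num_fantasies:, :]   # one-shot solutions to the inner problems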

Once I do this, I will have the correct fantasy model provided to evaluate()

You cannot modify evaluate with X[..., -num_fantasies:, :] = T as I did for forward. Evaluate only takes an n x q x d input which are the candidates used for fantasize. It is only intended to evaluate the value of qKG acqf at that candidate point and should not be used for optimization.

however, how can I modify the _get_value_function to make sure the inner optimization is done at T?

To make sure the inner optimization is done at T, you could pass scipy_options={"fixed_features": {-1: T}} (make sure to replace -1 with the actual index) to evaluate.

In your case, it is much cleaner to use qMultiFidelityKnowledgeGradient with project=fix_to_T with

def fix_to_T(X):
    X[..., -1] = T  # replace with the actual T
    return X

You could then wrap qMultiFidelityKnowledgeGradient in FixedFeatureAcquisitionFunction to get your fantasies at t as well. This is just a cleaner way of doing it. You wouldn't have to modify the source code, and there's less room for mistakes.

Note: You cannot use evaluate with qMultiFidelityKnowledgeGradient. It will ignore the project operator and produce buggy output.

@Balandat
Contributor

If I may suggest, forward compatibility between any GPyTorch model and BoTorch would be great for users like myself that are interested in using several GPyTorch models within BoTorch.

Yes, this is a goal, but unfortunately that's not easy to do since the basic GPyTorch models don't carry around some of the metadata that we need in botorch. For instance, GPyTorch models don't always use an explicit outcome dimension, which makes their interpretation ambiguous without additional information.

We can and should work on minimizing the discrepancies here, and there are probably some aspects that we can fix. But I fear that at least right now supporting fully generic GPyTorch model plug-in seems quite challenging.

@r-ashwin
Author

@saitcakmak Got it - thanks for the clarification. As of now both approaches suggested (qKnowledgeGradient and qMultiFidelityKnowledgeGradient) work without error, after fixing the array dimensions.

@r-ashwin
Author

r-ashwin commented Oct 29, 2020

I tried the following but get an error. Per the API reference, values can be a list of length d_f (= 1 in my case), and columns can be an iterable with the indices of the columns to be fixed.

# project fidelity to T
T = 2.
def project(X, T=T):
    X[...,-1] = T
    return X

qMFKG = qMultiFidelityKnowledgeGradient(gp, project=project)
# fix the time to t in acquisition function
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
columns  = [1] # 2nd column of a [...,2] tensor
Values   = [1.95] # value to fix column in columns
qMFKG_FF = FixedFeatureAcquisitionFunction(qMFKG, 2, columns, Values)
c =[]
from botorch.optim import optimize_acqf
candidate, value   = optimize_acqf(acq_function   = qMFKG_FF,
                                       bounds         = bounds,
                                       q              = 1,
                                       num_restarts   = 5,
                                       raw_samples    = 1000,)
c
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-54-5331c77315f8> in <module>
      6                                        num_restarts   = 5,
      7                                        raw_samples    = 1000,
----> 8                                        fixed_features = {1: 1.95},)
      9 c

~\Anaconda3\lib\site-packages\botorch\optim\optimize.py in optimize_acqf(acq_function, bounds, q, num_restarts, raw_samples, options, inequality_constraints, equality_constraints, fixed_features, post_processing_func, batch_initial_conditions, return_best_only, sequential, **kwargs)
    155             num_restarts=num_restarts,
    156             raw_samples=raw_samples,
--> 157             options=options,
    158         )
    159 

~\Anaconda3\lib\site-packages\botorch\optim\initializers.py in gen_batch_initial_conditions(acq_function, bounds, q, num_restarts, raw_samples, options)
    107                     end_idx = min(start_idx + batch_limit, X_rnd.shape[0])
    108                     Y_rnd_curr = acq_function(
--> 109                         X_rnd[start_idx:end_idx].to(device=device)
    110                     ).cpu()
    111                     Y_rnd_list.append(Y_rnd_curr)

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~\Anaconda3\lib\site-packages\botorch\acquisition\fixed_feature.py in forward(self, X)
     81             `_construct_X_full`).
     82         """
---> 83         X_full = self._construct_X_full(X)
     84         return self.acq_func(X_full)
     85 

~\Anaconda3\lib\site-packages\botorch\acquisition\fixed_feature.py in _construct_X_full(self, X)
    100         if d_prime + d_f != self.d:
    101             raise ValueError(
--> 102                 f"Feature dimension d' ({d_prime}) of input must be "
    103                 f"d - d_f ({self.d - d_f})."
    104             )

ValueError: Feature dimension d' (2) of input must be d - d_f (1).

@Balandat
Contributor

Hmm looks like this is some interaction between KG and FixedFeatureAcquisitionFunction. I don't see anything obviously wrong with your code, let me take a closer look tomorrow.

@saitcakmak
Contributor

I am guessing the bounds you use below are 2-dim, so optimize_acqf uses a 2D X to evaluate qMFKG_FF. Since you fixed one dimension, qMFKG_FF expects 1D input. You should pass only the bounds corresponding to X[..., 0] to optimize_acqf, e.g., bounds=torch.tensor([[0.], [1.]]).

candidate, value   = optimize_acqf(acq_function   = qMFKG_FF,
                                      bounds         = bounds,
                                      q              = 1,
                                      num_restarts   = 5,
                                      raw_samples    = 1000,)

@r-ashwin
Author

r-ashwin commented Oct 29, 2020

Alternatively, this works (notice that I have used the fixed_features arg with qMFKG). Is using fixed_features with qMFKG therefore equivalent to wrapping qMFKG in FixedFeatureAcquisitionFunction, apart from efficiency during the optimization?

candidate, value   = optimize_acqf(acq_function   = qMFKG,
                                   bounds         = bounds,
                                   q              = 1,
                                   num_restarts   = 10,
                                   raw_samples    = 1000,
                                   fixed_features = {1:1.95}
    )
c

@r-ashwin
Author

@saitcakmak tried that and this is what I got. It looks like most of my issues come down to tensor dimension mismatches. Is there a place where all the tensor shape conventions are concretely defined? Or is defining all my tensors, e.g., train_x and test_x, with shape t-batch x q-batch x d good practice to ensure safety with botorch?

bounds=torch.tensor([[0.], [1.]])
candidate, value   = optimize_acqf(acq_function   = qMFKG_FF,
                                   bounds         = bounds,
                                   q              = 1,
                                   num_restarts   = 10,
                                   raw_samples    = 1000,
#                                    fixed_features = {1:1.95}
    )
c
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-92-af3e311dd84b> in <module>
      6                                    q              = 1,
      7                                    num_restarts   = 10,
----> 8                                    raw_samples    = 1000,
      9 #                                    fixed_features = {1:1.95}
     10     )

~\Anaconda3\lib\site-packages\botorch\optim\optimize.py in optimize_acqf(acq_function, bounds, q, num_restarts, raw_samples, options, inequality_constraints, equality_constraints, fixed_features, post_processing_func, batch_initial_conditions, return_best_only, sequential, **kwargs)
    155             num_restarts=num_restarts,
    156             raw_samples=raw_samples,
--> 157             options=options,
    158         )
    159 

~\Anaconda3\lib\site-packages\botorch\optim\initializers.py in gen_batch_initial_conditions(acq_function, bounds, q, num_restarts, raw_samples, options)
    107                     end_idx = min(start_idx + batch_limit, X_rnd.shape[0])
    108                     Y_rnd_curr = acq_function(
--> 109                         X_rnd[start_idx:end_idx].to(device=device)
    110                     ).cpu()
    111                     Y_rnd_list.append(Y_rnd_curr)

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~\Anaconda3\lib\site-packages\botorch\acquisition\fixed_feature.py in forward(self, X)
     82         """
     83         X_full = self._construct_X_full(X)
---> 84         return self.acq_func(X_full)
     85 
     86     def _construct_X_full(self, X: Tensor) -> Tensor:

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~\Anaconda3\lib\site-packages\botorch\utils\transforms.py in decorated(cls, X, **kwargs)
    167                 )
    168             X = X if X.dim() > 2 else X.unsqueeze(0)
--> 169             return method(cls, X, **kwargs)
    170 
    171         return decorated

~\Anaconda3\lib\site-packages\botorch\acquisition\knowledge_gradient.py in forward(self, X)
    391                 maximized at fixed `X_actual[b]`.
    392         """
--> 393         X_actual, X_fantasies = _split_fantasy_points(X=X, n_f=self.num_fantasies)
    394 
    395         # We only concatenate X_pending into the X part after splitting

~\Anaconda3\lib\site-packages\botorch\acquisition\knowledge_gradient.py in _split_fantasy_points(X, n_f)
    458     if n_f > X.size(-2):
    459         raise ValueError(
--> 460             f"n_f ({n_f}) must be less than the q-batch dimension of X ({X.size(-2)})"
    461         )
    462     split_sizes = [X.size(-2) - n_f, n_f]

ValueError: n_f (64) must be less than the q-batch dimension of X (1)

@saitcakmak
Contributor

saitcakmak commented Oct 29, 2020

Oh, I see what is going on with this one.

candidate, value   = optimize_acqf(acq_function   = qMFKG_FF,
                                       bounds         = bounds,
                                       q              = 1,
                                       num_restarts   = 5,
                                       raw_samples    = 1000,)

This here queries qMFKG_FF with an n x q x d = (num_restarts or raw_samples) x 1 x 1 (= bounds.shape[-1]) tensor. qMFKG_FF appends the fixed dimension and queries qMFKG with an n x q x (d+1) = n x 1 x 2 tensor. But qMFKG is a one-shot acquisition function, so it expects a q dimension of size num_fantasies + q, which is 64 + 1 = 65 in this case. Setting q=65 in optimize_acqf would fix this, but I am not sure if this is the best way to go.
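
In shape terms (a worked sketch with the default num_fantasies=64):

# optimize_acqf with q=1 evaluates qMFKG_FF on tensors of shape n x 1 x 1
# FixedFeatureAcquisitionFunction appends the fixed column:     n x 1 x 2
# one-shot qMFKG then tries to split the q-dimension into q + num_fantasies
# = 1 + 64 = 65 entries, but it only has 1, hence:
#   "n_f (64) must be less than the q-batch dimension of X (1)"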

Alternatively, this works (notice that I have used the fixed_features arg with qMFKG). Is using fixed_features with qMFKG therefore equivalent to wrapping qMFKG in FixedFeatureAcquisitionFunction, apart from efficiency during the optimization?

candidate, value   = optimize_acqf(acq_function   = qMFKG,
                                   bounds         = bounds,
                                   q              = 1,
                                   num_restarts   = 10,
                                   raw_samples    = 1000,
                                   fixed_features = {1:1.95}
    )
c

Yes, this should be equivalent. There are differences in how optimize_acqf handles the initial_condition generation under the hood as well, but besides that, this would work.

@r-ashwin
Author

This here queries qMFKG_FF with an n x q x d = (num_restarts or raw_samples) x 1 x 1 (= bounds.shape[-1]) tensor. qMFKG_FF appends the fixed dimension and queries qMFKG with an n x q x (d+1) = n x 1 x 2 tensor. But qMFKG is a one-shot acquisition function, so it expects a q dimension of size num_fantasies + q, which is 64 + 1 = 65 in this case. Setting q=65 in optimize_acqf would fix this, but I am not sure if this is the best way to go.

Your explanation makes sense, and it does remove the error when q=64. However, I do not understand why the user should have to infer the default num_fantasies and set q accordingly when they are specifically interested in q=1...

@r-ashwin
Author

r-ashwin commented Nov 2, 2020

I have a follow-up question, if you have any thoughts on this. Placing it here since it is related to the original question.

Can I fantasize at multiple ts? Currently, I fantasize at t and solve the inner optimization at T. What if I have additional fidelities t < t1 < t2 < ... < T and I want to generate fantasy models from t1, t2, ... and solve the inner opt at T?

Is this related to the expand argument to qMultiFidelityKnowledgeGradient? If so, is there a quick example you can share of how to use it? Thanks very much!

@saitcakmak
Contributor

What if I have additional fidelities t < t1 < t2 < ... < T and I want to generate fantasy models from t1, t2, ... and solve the inner opt at T?

I can interpret this in two ways: (i) you want to jointly fantasize at t1, t2, ..., i.e., generate a single fantasy model that is conditioned on multiple parallel observations; (ii) you want to generate multiple fantasy models, each conditioned on a different t. I will assume it is the first option.

expand: A callable mapping a `batch_shape x q x d` input tensor to a `batch_shape x (q + q_e) x d`-dim output tensor, where the `q_e` additional points in each q-batch correspond to additional ("trace") observations.

Based on this documentation, assuming q=1 and that your solution is of the form (x, t), you could define expand as follows.

def my_expand(X):
    x = X[..., :-1]  # extract the x part
    t_list = [t1, t2, t3]  # the additional t we want to consider, scalar tensors.
    X_list = [X]
    for t in t_list:
        X_list.append(torch.cat([x, t.repeat(*x.shape[:-1], 1)], dim=-1))
    return torch.cat(X_list, dim=-2)

my_expand here takes the input tensor [[x, t]] and returns [[x, t], [x, t1], [x, t2], [x, t3]]. You could modify this in any way you want, keeping in mind that qMFKG will jointly fantasize over [[x, t], [x, t1], [x, t2], [x, t3]], treating them as parallel observations.
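
Wired together, this might look like the following sketch, reusing the fix_to_T projection from earlier:

qMFKG = qMultiFidelityKnowledgeGradient(
    model=gp, project=fix_to_T, expand=my_expand
)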

@r-ashwin
Author

r-ashwin commented Nov 3, 2020

@saitcakmak Awesome! Thanks!

@r-ashwin
Author

r-ashwin commented Nov 3, 2020

@saitcakmak When I use expand along with the fixed_features arg within optimize_acqf, what is the expected behavior? Let's say I set fixed_features={1: t} to ensure we fantasize at t, and then use expand to fantasize jointly at t1; then the X returned by your my_expand above should have only t's and t1's in the last column? That's what I understand I should expect, but I don't see that. What I see instead is the last column ranging over [bounds[0][-1], bounds[1][-1]], along with the expanded entries with t1.

@saitcakmak
Contributor

The expected behavior is that all calls to the acquisition_function from within gen_candidates_scipy (used by optimize_acqf for optimization) will have X[..., 1] = t. Note that optimize_acqf calls the acquisition_function to generate the initial_conditions, before calling gen_candidates_scipy. This first call to acquisition_function will have X randomized. So, if you're checking the values in debug mode, you should check them after gen_candidates_scipy.

expand is used for generating the fantasies:

fantasy_model = self.model.fantasize(
    X=self.expand(X_eval), sampler=self.sampler, observation_noise=True
)

Using fixed_features, when called from within gen_candidates_scipy, you will have X_eval[..., 1] = t, where X_eval is n x 1 x 2. expand(X_eval) will then be n x 4 x 2 (following the my_expand definition above), with expand(X_eval)[..., 0, 1] = t, expand(X_eval)[..., 1, 1] = t1, expand(X_eval)[..., 2, 1] = t2, and expand(X_eval)[..., 3, 1] = t3. Any fantasy model generated here will be jointly over these four solutions.

@r-ashwin
Author

r-ashwin commented Nov 8, 2020

@saitcakmak , two remarks:

  1. The qKG.evaluate() method seems to produce different output each time I evaluate. Of course, it has inherent stochasticity, but I am using num_restarts=20 and raw_samples=1250 for a 1D problem and trying to see how high they need to be to get at least similar results between repetitions. I am running on CPUs, and increasing raw_samples beyond this value leads to an insufficient-memory error.
  2. The candidate selected by optimize_acqf with qKG, with exactly the same values for num_restarts and raw_samples as in 1, appears to be different from the maximizer that I observe in the plot of qKG.evaluate().

For V&V purposes, how can I ensure the output of optimize_acqf and qKG.evaluate() are consistent with each other? Also, is there an envelope-theorem implementation of KG in BoTorch that I can use to compare against the one-shot implementation?
Happy to share a minimal example if necessary.

@saitcakmak
Contributor

  1. In my testing, I don't see anything that is not explained by the inherent stochasticity. To further improve optimization quality, you can also increase num_restarts. This increases the computational cost, but you wouldn't run into memory issues with moderate values. Another thing that would help with memory usage is to evaluate fewer solutions in parallel, e.g., if you have X.shape=[100, 1, 2], you can call qKG.evaluate(X[0:10], ...), which would use 1/10 of the memory.

  2. Optimizing qKG is expected to produce solutions to the inner problems that are inferior to those found by qKG.evaluate(). Since qKG is optimized in a one-shot manner, the inner solutions are not globally optimized (which qKG.evaluate() tries to do). I have experimented with one-shot optimization extensively, and in general it works quite well compared to nested / envelope-theorem based optimization. If you are comparing qKG with optimize_acqf against qKG.evaluate(), you can set a larger optimization budget for qKG for fairness.

how can I ensure the output of optimize_acqf and qKG.evaluate() are consistent with each other?

Setting a larger optimization budget for optimize_acqf should help.

is there an envelope-theorem implementation of KG in BoTorch

BoTorch doesn't implement this, but you could easily write your own KG implementation. The forward pass of your acquisition function can call qKG.evaluate(). Just make sure that you do not inherit from qKnowledgeGradient, since that would lead optimize_acqf to assume you have a one-shot implementation. Note that this will be significantly more expensive than the one-shot implementation. Something like this should work:

class NestedKG(MCAcquisitionFunction):
    def __init__(self, ...):
        # define your init here, you can mostly copy qKG's
        ...

    def forward(self, X: Tensor) -> Tensor:
        # passing self here is crucial; evaluate also expects bounds,
        # which you would store in __init__
        return qKnowledgeGradient.evaluate(self, X, bounds=self.bounds)

If the recommendations here do not solve the issue and you think there is a different bug in play, I'd be happy to look into it deeper if you share a reproducible example.

@r-ashwin
Author

r-ashwin commented Nov 9, 2020

Okay I will prepare a reproducible example and drop it here.

One thing worth clarifying before that: in the forward method of qKG, the argument X has X[..., q:, :] as the current solutions to the inner optimization problem. Then, we overwrite X[...,q:,-1] = T to project them to T. This means that the inner optimization problem is first solved over the joint (x,t) space and then we overwrite the last column of the solutions to T. If this is true (?), then it is not what I intend to do. Instead, I want the inner optimization problem to maximize mu(x, T) -- in other words, first project the GP posterior to T and then optimize, as opposed to an optimize-then-project approach. In general, argmax_x mu(x, t_i) != argmax_x mu(x, t_j) for i != j.

In this regard, passing an objective to qKG seems appropriate to me. This way, I just need to subclass PosteriorMean and overwrite the X in its forward method to ensure the last column is T. Here is the solution I propose (a rough sketch follows the list below):

  • the current implementation expects the objective to be of type ScalarizedObjective and not an acquisition function -- I could perhaps overcome this by modifying _get_value_function(), where I could replace PosteriorMean with a PosteriorMean_at_T which, in its forward() method, projects the last column of X to T. Does this sound right?
  • also, is this guaranteed to solve the inner optimization problem at T?
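
A minimal sketch of what that PosteriorMean_at_T subclass might look like (illustrative only; the class name and the in-place projection mirror the proposal above):

from botorch.acquisition.analytic import PosteriorMean

class PosteriorMeanAtT(PosteriorMean):
    def __init__(self, model, T: float):
        super().__init__(model=model)
        self.T = T

    def forward(self, X: Tensor) -> Tensor:
        X = X.clone()        # avoid mutating the optimizer's tensor in place
        X[..., -1] = self.T  # project the time coordinate to T
        return super().forward(X)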

Thanks for all your responses so far -- they were very useful!

@saitcakmak
Contributor

This means that the inner optimization problem is first solved over the joint (x,t) space and then we overwrite the last column of the solutions to T.

If you use qKG.evaluate() the inner problems are optimized over (x, t) (ignores the project operator, #594 fixes this). If you instead use optimize_acqf and optimize qMFKG (uses forward calls), then the solutions are projected, so you are optimizing mu(x, T). In either case, due to some inner workings of optimize_acqf, I would recommend using #594 with the project operator (it modifies the initial condition generation procedure to account for project). This readily achieves the goal of your proposed approach, i.e., optimizes the inner problem always over mu(x, T) (assuming projection to T).

You can install #594 via pip install --upgrade git+https://github.com/saitcakmak/botorch.git@value_function. The PR hasn't been reviewed yet, so use at your own discretion.

@r-ashwin
Author

r-ashwin commented Nov 9, 2020

I see - thanks! Let me try both and see how it goes.


Update: I was able to check that your implementation in #594 does indeed do what I want. However, I am not sure I can see that optimize_acqf and evaluate are consistent. First, the candidates are different. Second, the values of the acquisition function at the optimizer differ between the two -- I thought they should be off only by current_value, but the difference seems larger than that. A minimal example to reproduce my results is below.

@r-ashwin
Author

@Balandat @saitcakmak
A related question: Am I correct in understanding that qKG.evaluate() does not have gradients propagated? If so, do you have a quick recipe for how I could compute the gradient of qKG.evaluate(X) w.r.t. X?

@saitcakmak
Contributor

@r-ashwin I think the discrepancy you observe between qMFKG.evaluate() and optimize_acqf is due to your GP model. With your custom model, evaluate produces weird output. The KG acquisition function is supposed to be a continuous function of the candidate, which is not the case when using your custom GP (there are weird jumps in the plot). If you replace it with a SingleTaskGP the problem goes away. I observed some strange behavior with model.fantasize() when using your GP, so it is possible that fantasize is where things break down.

I ran additional testing with SingleTaskGP, and didn't observe any discrepancy between qKG.evaluate() and optimize_acqf.

Am I correct in understanding that qKG.evaluate() does not have gradients propagated?

That is correct. You'd have to modify it and re-evaluate the inner solutions with propagate_grads.

_, values = gen_candidates_scipy(
    initial_conditions=initial_conditions,
    acquisition_function=value_function,
    lower_bounds=bounds[0],
    upper_bounds=bounds[1],
    options=kwargs.get("scipy_options"),
)
# get the maximizer for each batch
values, _ = torch.max(values, dim=0)

Would be replaced by something like

solutions, values = gen_candidates_scipy( 
    initial_conditions=initial_conditions, 
    acquisition_function=value_function, 
    lower_bounds=bounds[0], 
    upper_bounds=bounds[1], 
    options=kwargs.get("scipy_options"), 
) 
# get the maximizer for each batch 
indices = torch.argmax(values, dim=0)
solutions = solutions[indices]  # get the optimal solution for each inner problem
with settings.propagate_grads(True):
    values = value_function(solutions)  # evaluate the value function while propagating the gradient to X

I haven't tested this, but it should give the general idea. I've used similar implementations in the past, but they tend to be significantly slower than the one-shot approach.

@r-ashwin
Author

r-ashwin commented Nov 10, 2020

That's interesting, because I am indeed subclassing SingleTaskGP. Also, as far as the GP fit to the true function goes, everything seems fine.
As far as I have tested evaluate() on my problems, it always gave a continuous function, but one that was not differentiable everywhere, which I thought was expected.

P.S. Thanks also for the gradient tip. From what I have tested so far, the evaluate method seems to be maximized at reasonable locations for my problem, compared to optimize_acqf.
