<a href="https://colab.research.google.com/github/sparks-baird/self-driving-lab-demo/blob/main/notebooks/qnipv-mwe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# check if in colab
# Detect whether the notebook is running on Google Colab (presumably the only
# environment where `google.colab` is already loaded) and, if so, install Ax.
import sys
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    # IPython magic, not plain Python: installs into the running kernel.
    # NOTE(review): the install is unpinned, while the next cell imports
    # Ax/BoTorch internals (e.g. botorch.acquisition.input_constructors,
    # botorch.utils.datasets) that may move between releases — consider
    # pinning a known-good ax-platform version.
    %pip install ax-platform

In [2]:
from typing import Any, Dict, Optional

import torch

# from ax.core.objective import ScalarizedObjective
from ax.modelbridge import get_sobol
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.modelbridge.transforms.unit_x import UnitX
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.random.sobol import SobolGenerator
from ax.service.ax_client import AxClient
from ax.core.observation import ObservationFeatures
from botorch.acquisition.active_learning import (
    MCSampler,
    qNegIntegratedPosteriorVariance,
)
from botorch.acquisition.input_constructors import (
    MaybeDict,
    acqf_input_constructor,
    construct_inputs_mc_base,
)
from botorch.acquisition.objective import AcquisitionObjective
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model import Model
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor


In [3]:
@acqf_input_constructor(qNegIntegratedPosteriorVariance)
def construct_inputs_qNIPV(
    model: Model,
    mc_points: Tensor,
    training_data: MaybeDict[SupervisedDataset],
    objective: Optional[AcquisitionObjective] = None,
    X_pending: Optional[Tensor] = None,
    sampler: Optional[MCSampler] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Assemble the keyword arguments for ``qNegIntegratedPosteriorVariance``.

    The decorator registers this function as the input constructor for
    qNIPV, so ``mc_points`` can be supplied through the acquisition options.

    Args:
        model: the model the acquisition function will be built on.
        mc_points: points over which the posterior variance is integrated;
            passed through unchanged.
        training_data: forwarded to ``construct_inputs_mc_base``.
        objective: forwarded to ``construct_inputs_mc_base``, except that it
            is replaced by ``None`` when the model has a single output.
        X_pending: forwarded to ``construct_inputs_mc_base``.
        sampler: forwarded to ``construct_inputs_mc_base``.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        A new dict holding the base MC inputs plus a ``"mc_points"`` entry.
    """
    # Single-output models get no objective; multi-output models keep theirs.
    effective_objective = None if model.num_outputs == 1 else objective

    mc_base = construct_inputs_mc_base(
        model=model,
        training_data=training_data,
        sampler=sampler,
        X_pending=X_pending,
        objective=effective_objective,
    )
    return dict(mc_base, mc_points=mc_points)


def objective_function(x):
    """Evaluate the toy objective ``x1**2 + x2**2 + x3**2``.

    Args:
        x: mapping with numeric entries ``"x1"``, ``"x2"`` and ``"x3"``;
            any other keys are ignored.

    Returns:
        ``{"f": (value, None)}`` — the value paired with an unspecified
        (``None``) standard error, in Ax's ``(mean, SEM)`` raw-data form.
    """
    value = sum(x[name] ** 2 for name in ("x1", "x2", "x3"))
    return {"f": (value, None)}

In [4]:
# Three continuous "range" inputs, each spanning [0.0, upper] with its own
# upper bound.
parameters = [
    {"name": name, "type": "range", "bounds": [0.0, upper], "value_type": "float"}
    for name, upper in (("x1", 5.0), ("x2", 10.0), ("x3", 15.0))
]

# Temporary client: it is used below only for `experiment.search_space`
# (to draw and transform the Monte Carlo points), not for running trials.
ax_client_tmp = AxClient()
ax_client_tmp.create_experiment(parameters=parameters)

[INFO 07-08 05:53:51] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 07-08 05:53:51] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 5.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 10.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).
[INFO 07-08 05:53:51] ax.modelbridge.dispatch_utils: Using Models.GPEI since there are more ordered parameters than there are categories for the unordered categorical parameters.
[INFO 07-08 05:53:51] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=3 num_trials=None use_batch_trials=False
[INFO 07-08 05:53:51]

In [8]:
# MC Points
# Points over which qNIPV integrates the posterior variance. They must be in
# the same transformed input space (and column order) the BoTorch model sees.
# WARNING: assumes only UnitX transform, https://ax.dev/docs/models.html#transforms
num_mc_sobol = 2**13
search_space = ax_client_tmp.experiment.search_space
sobol = get_sobol(search_space)
obs_features = [
    ObservationFeatures(arm_params)  # avoid shadowing the global `parameters`
    for arm_params in sobol.gen(num_mc_sobol).param_df.to_dict("records")
]
ux = UnitX(search_space)
obs_features_ux = ux.transform_observation_features(obs_features)

# Build columns in the search space's parameter order rather than relying on
# the insertion order of each ObservationFeatures.parameters dict.
param_names = list(search_space.parameters)
mc_points = [
    [obs.parameters[name] for name in param_names] for obs in obs_features_ux
]
# Ax's BoTorch model bridge runs in double precision; match it explicitly
# instead of letting torch.tensor default to float32.
mcp = torch.tensor(mc_points, dtype=torch.double)
mcp

tensor([[0.9892, 0.6544, 0.7226],
        [0.1807, 0.2227, 0.4768],
        [0.3692, 0.9154, 0.9483],
        ...,
        [0.3693, 0.4689, 0.1891],
        [0.1809, 0.6693, 0.7213],
        [0.9891, 0.2103, 0.4779]])

In [9]:
# Active-learning model: a SingleTaskGP surrogate with the qNIPV acquisition
# function, integrating posterior variance over the MC points `mcp`.
model_kwargs_val = {
    "surrogate": Surrogate(SingleTaskGP),
    "botorch_acqf_class": qNegIntegratedPosteriorVariance,
    "acquisition_options": {"mc_points": mcp},
}

# Single source of truth for the trial budget: the loop below runs exactly as
# many trials as the generation strategy provides, so the two cannot drift.
num_sobol_trials = 5
num_qnipv_trials = 15

gs = GenerationStrategy(
    steps=[
        GenerationStep(model=Models.SOBOL, num_trials=num_sobol_trials),
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=num_qnipv_trials,
            model_kwargs=model_kwargs_val,
        ),
    ]
)

ax_client = AxClient(generation_strategy=gs)
ax_client.create_experiment(
    name="active_learning_experiment",
    parameters=parameters,
    objective_name="f",
    # qNIPV only targets posterior variance, so the optimization direction
    # does not steer candidate selection here. NOTE(review): presumably it
    # still matters for Ax bookkeeping such as best-point reporting.
    minimize=True,
)

for _ in range(num_sobol_trials + num_qnipv_trials):
    trial_params, trial_index = ax_client.get_next_trial()
    data = objective_function(trial_params)
    ax_client.complete_trial(trial_index=trial_index, raw_data=data["f"])

[INFO 07-08 06:15:45] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 07-08 06:15:45] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 5.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 10.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).
[INFO 07-08 06:15:45] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 2.694745, 'x2': 0.122303, 'x3': 11.500634}.
[INFO 07-08 06:15:45] ax.service.ax_client: Completed trial 0 with data: {'f': (139.541185, None)}.
[INFO 07-08 06:15:45] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 0.896566, 'x2': 4.253006, 'x3': 0.689612}.
[INFO 07-08 06:15:45] ax.service.ax_client: Completed trial 1 with data: {'f': (1

## Code Graveyard

In [None]:
# @acqf_input_constructor(qNegIntegratedPosteriorVariance)
# def construct_inputs_qNIPV(
#     model: Model,
#     training_data: MaybeDict[SupervisedDataset],
#     mc_points: Optional[Tensor] = None,
#     objective: Optional[AcquisitionObjective] = None,
#     X_pending: Optional[Tensor] = None,
#     sampler: Optional[MCSampler] = None,
#     **kwargs: Any,
# ) -> Dict[str, Any]:
#     if model.num_outputs == 1:
#         objective = None

#     base_inputs = construct_inputs_mc_base(
#         model=model,
#         training_data=training_data,
#         sampler=sampler,
#         X_pending=X_pending,
#         objective=objective,
#     )
    
#     if mc_points is None:
#         # generate sobol points
#         bounds = torch.tensor([[0.0, 1.0]] * len(list(model.parameters)))
#         mc_points = torch.Tensor(SobolGenerator().gen(1024, bounds))

#     return {**base_inputs, "mc_points": mc_points}


# def objective_function(x):
#     f = x["x1"] ** 2 + x["x2"] ** 2 + x["x3"] ** 2
#     return {"f": (f, None)}