<a href="https://colab.research.google.com/github/sparks-baird/self-driving-lab-demo/blob/main/notebooks/qnipv-mwe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install ax-platform

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ax-platform
  Downloading ax_platform-0.3.3-py3-none-any.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting botorch==0.8.5 (from ax-platform)
  Downloading botorch-0.8.5-py3-none-any.whl (530 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m530.3/530.3 kB[0m [31m54.7 MB/s[0m eta [36m0:00:00[0m
Collecting typeguard==2.13.3 (from ax-platform)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Collecting pyro-ppl>=1.8.4 (from botorch==0.8.5->ax-platform)
  Downloading pyro_ppl-1.8.5-py3-none-any.whl (732 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m732.5/732.5 kB[0m [31m67.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gpytorch==1.10 (from botorch==0.8.5->ax-platform)
  Downloading gpytorch-1.10-py3-none-any.whl (255 kB)
[2K     [9

In [2]:
from typing import Any, Dict, Optional

import torch

# from ax.core.objective import ScalarizedObjective
from ax.modelbridge import get_sobol
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.service.ax_client import AxClient
from botorch.acquisition.active_learning import (
    MCSampler,
    qNegIntegratedPosteriorVariance,
)
from botorch.acquisition.input_constructors import (
    MaybeDict,
    acqf_input_constructor,
    construct_inputs_mc_base,
)
from botorch.acquisition.objective import AcquisitionObjective
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model import Model
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor


@acqf_input_constructor(qNegIntegratedPosteriorVariance)
def construct_inputs_qNIPV(
    model: Model,
    mc_points: Tensor,
    training_data: MaybeDict[SupervisedDataset],
    objective: Optional[AcquisitionObjective] = None,
    X_pending: Optional[Tensor] = None,
    sampler: Optional[MCSampler] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Assemble constructor inputs for ``qNegIntegratedPosteriorVariance``.

    The function is registered for that acquisition class through the
    ``acqf_input_constructor`` decorator.

    Args:
        model: The model; only ``num_outputs`` is read here before the model
            is handed on to ``construct_inputs_mc_base``.
        mc_points: Tensor stored in the result under the ``"mc_points"`` key.
        training_data: Passed through unchanged to ``construct_inputs_mc_base``.
        objective: Forwarded to ``construct_inputs_mc_base``, except that it is
            replaced by ``None`` whenever ``model.num_outputs == 1``.
        X_pending: Passed through unchanged to ``construct_inputs_mc_base``.
        sampler: Passed through unchanged to ``construct_inputs_mc_base``.
        **kwargs: Accepted but not used by this function.

    Returns:
        The mapping produced by ``construct_inputs_mc_base`` with an additional
        (or overridden) ``"mc_points"`` entry.
    """
    # A single-output model is used without any objective.
    effective_objective = None if model.num_outputs == 1 else objective

    acqf_inputs = dict(
        construct_inputs_mc_base(
            model=model,
            training_data=training_data,
            sampler=sampler,
            X_pending=X_pending,
            objective=effective_objective,
        )
    )
    acqf_inputs["mc_points"] = mc_points
    return acqf_inputs


def objective_function(x):
    """Return the sum of squares of ``x["x1"]``, ``x["x2"]`` and ``x["x3"]``.

    Args:
        x: Mapping that provides the keys ``"x1"``, ``"x2"`` and ``"x3"``.

    Returns:
        ``{"f": (value, None)}``. The second tuple entry is always ``None``
        (presumably the slot Ax reads as the standard error -- confirm against
        the Ax raw-data format).
    """
    x1, x2, x3 = x["x1"], x["x2"], x["x3"]
    sum_of_squares = x1**2 + x2**2 + x3**2
    result = (sum_of_squares, None)
    return {"f": result}


# Search space: three continuous parameters, each bounded to [0, 1].
parameters = [
    {"name": "x1", "type": "range", "bounds": [0.0, 1.0], "value_type": "float"},
    {"name": "x2", "type": "range", "bounds": [0.0, 1.0], "value_type": "float"},
    {"name": "x3", "type": "range", "bounds": [0.0, 1.0], "value_type": "float"},
]

# Fixed seed so that the same set of points is generated on every
# Restart Kernel -> Run All.
MC_POINTS_SEED = 0
# Number of Sobol points handed to the acquisition function as `mc_points`.
NUM_MC_POINTS = 1024

# Throwaway client: it exists only to turn `parameters` into an Ax search space
# that `get_sobol` can sample from.
ax_client_tmp = AxClient()
ax_client_tmp.create_experiment(parameters=parameters)
sobol = get_sobol(ax_client_tmp.experiment.search_space, seed=MC_POINTS_SEED)

# Select the columns by name so the column order is guaranteed to follow the
# order in `parameters` rather than whatever order the frame happens to have.
parameter_names = [p["name"] for p in parameters]
mc_points = sobol.gen(NUM_MC_POINTS).param_df[parameter_names].values

# NOTE(review): these points are in raw parameter units. Presumably they must
# live in the same space as the model inputs; that coincides here because every
# bound is [0, 1] -- confirm before changing the bounds.
# The dtype is explicit so the tensor does not depend on NumPy's inferred dtype.
mcp = torch.tensor(mc_points, dtype=torch.double)

# Seed for the quasi-random initialization trials, so that the first trials are
# identical on every Restart Kernel -> Run All.
INIT_SEED = 42

# Settings for the modular BoTorch step: a SingleTaskGP surrogate paired with the
# qNegIntegratedPosteriorVariance acquisition function.
# NOTE(review): `acquisition_options` is presumably forwarded as keyword arguments
# to the input constructor registered for the acquisition class -- confirm
# against the Ax documentation.
model_kwargs_val = {
    "surrogate": Surrogate(SingleTaskGP),
    "botorch_acqf_class": qNegIntegratedPosteriorVariance,
    "acquisition_options": {"mc_points": mcp},
}

# Two-stage strategy: 5 seeded Sobol trials, then 15 model-based trials.
gs = GenerationStrategy(
    steps=[
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,
            model_kwargs={"seed": INIT_SEED},
        ),
        GenerationStep(
            model=Models.BOTORCH_MODULAR, num_trials=15, model_kwargs=model_kwargs_val
        ),
    ]
)

# Experiment settings, gathered in one place and unpacked into `create_experiment`.
experiment_config = dict(
    name="active_learning_experiment",
    parameters=parameters,
    objective_name="f",
    minimize=True,
)

ax_client = AxClient(generation_strategy=gs)
ax_client.create_experiment(**experiment_config)

# 20 iterations in total: the generation strategy allots 5 + 15 trials.
for _iteration in range(20):
    trial_params, trial_index = ax_client.get_next_trial()
    data = objective_function(trial_params)
    # `data["f"]` is the (value, None) tuple returned by `objective_function`.
    ax_client.complete_trial(trial_index=trial_index, raw_data=data["f"])

[INFO 06-27 04:40:04] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 06-27 04:40:04] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 1.0])], parameter_constraints=[]).
[INFO 06-27 04:40:04] ax.modelbridge.dispatch_utils: Using Models.GPEI since there are more ordered parameters than there are categories for the unordered categorical parameters.
[INFO 06-27 04:40:04] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=3 num_trials=None use_batch_trials=False
[INFO 06-27 04:40:04] a