[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sparks-baird/self-driving-lab-demo/blob/main/notebooks/6.2.1-multi-fidelity-continuous.ipynb)

# Continous Multi-fidelity Optimization

In [14]:
try:
    import google.colab
    IN_COLAB = True
except:
    IN_COLAB = False

In [15]:
# For notebook gallery thumbnail
if IN_COLAB:
    from IPython.display import Image
    Image(filename="continuous-multi-fidelity.png")


In [the previous notebook](https://colab.research.google.com/github/sparks-baird/self-driving-lab-demo/blob/main/notebooks/6.2-multi-fidelity.ipynb), I provided a brief introduction of multi-fidelity
optimization in the context of the physical sciences. This notebook will cover
Bayesian optimization using two continuous fidelity parameters (`atime` and `astep`). We'll compare the total integration time using the
multi-fidelity optimization with the integration time costs of running the simulation at the
the lowest fidelities, the highest fidelities (default), and approximately halfway in-between. For validation, the
objective (in this case, `frechet`) will be evaluated at the upper bound of `atime` and `astep`. In
this experiment, we'll allow `atime` to vary between `0..100` (upper limit is 255) which is:
$2.78 \mu s .. 280.78\mu s$. Since multiple fidelity parameters [aren't yet supported
with Ax](https://github.com/facebook/Ax/issues/1211) (i.e. requires using BoTorch directly), we'll fix `astep` to `999` (limits are `0..65534`). In terms
of physical integration time, this corresponds to lowest and highest integration times
of $278 \mu s$ and $281 ms$, respectively.

For reference, here are the docs for `atime` and `astep` (taken from [`public_mqtt_sdl_demo/lib/as7341_sensor.py`](https://github.com/sparks-baird/self-driving-lab-demo/blob/main/src/public_mqtt_sdl_demo/lib/as7341_sensor.py)):
```python
"""
...
atime : int, optional
    The integration time step size in 2.78 microsecond increments, by default 100
astep : int, optional
    The integration time step count. Total integration time will be (ATIME + 1)
    * (ASTEP + 1) * 2.78µS, by default 999, meaning 281 ms assuming atime=100
...
"""
```

We'll be using the Knowledge Gradient (KG) acquisition function. 

From [the BoTorch docs](https://botorch.org/tutorials/one_shot_kg):
> ### The one-shot Knowledge Gradient acquisition function
> The Knowledge Gradient (KG) (see [2, 3]) is a look-ahead acquisition function that quantifies the expected increase in the maximum of the modeled black-box function f from obtaining additional (random) observations collected at the candidate set x. KG often shows improved Bayesian Optimization performance relative to simpler acquisition functions such as Expected Improvement, but in its traditional form it is computationally expensive and hard to implement.
>
> ...
> 
> [2] P. Frazier, W. Powell, and S. Dayanik. A Knowledge-Gradient policy for sequential information collection. SIAM Journal on Control and Optimization, 2008.
> 
> [3] J. Wu and P. Frazier. The parallel knowledge gradient method for batch bayesian optimization. NIPS 2016.


In this tutorial, we'll still use Ax rather than delving into pure BoTorch code. Since
Ax doesn't yet support discrete fidelity parameters, for the next notebook, we'll use
BoTorch exclusively. BoTorch has [a tutorial for continuous multi-fidelity optimization](https://botorch.org/tutorials/multi_fidelity_bo#Multi-Fidelity-BO-in-BoTorch-with-Knowledge-Gradient)
which has been [adapted for Ax in a GitHub
issue](https://github.com/facebook/Ax/issues/475). We will largely base the
implementation on the example from the GitHub issue.

As before, we need to set up our `SelfDrivingLabDemo` classes. We will use the
physical experimental setting since the simulations use only a simple multiplier to account for
integration time. We'll also use the same `frechet` objective function.

In [16]:
if IN_COLAB:
    %pip install self-driving-lab-demo[ax-platform]

In [17]:
from uuid import uuid4  # universally unique identifier
from self_driving_lab_demo import SelfDrivingLabDemoLight, mqtt_observe_sensor_data

dummy = True  # @param {type:"boolean"}
pico_id = "test"  # @param {type:"string"}
log_to_mongodb = False  # speed up the evaluation
if dummy:
    num_repeats = 2
    atime_max = 5
    astep_max = 5
    time_limit_multiplier = 1
else:
    num_repeats = 5  # @param {type:"integer"}
    atime_max = 100
    astep_max = 999  # fixed for now (see note above)
    time_limit_multiplier = 10  # @param {type:"number"}

model_gen_kwargs = (
    dict(num_fantasies=2, num_restarts=2, raw_samples=8) if dummy else None
)

simulation = False  # @param {type:"boolean"}
SESSION_ID = str(uuid4())  # random session ID


def calc_integration_time_s(atime, astep):
    """
    Calculate integration time (i.e., time cost) of light sensor.

    atime : int, optional
    The integration time step size in 2.78 microsecond increments, by default 100
    astep : int, optional
        The integration time step count. Total integration time will be (ATIME + 1)
        * (ASTEP + 1) * 2.78µS, by default 999, meaning 281 ms assuming atime=100
    """
    return ((atime + 1) * (astep + 1) * 2.78) / 1e6


# total seconds of integration time
time_limit_s = time_limit_multiplier * calc_integration_time_s(atime_max, astep_max)
seeds = range(10, 10 + num_repeats)
print(f"session ID: {SESSION_ID}")

sdls = [
    SelfDrivingLabDemoLight(
        autoload=True,  # perform target data experiment automatically
        simulation=simulation,
        observe_sensor_data_fn=mqtt_observe_sensor_data,  # (default)
        observe_sensor_data_kwargs=dict(
            pico_id=pico_id, session_id=SESSION_ID, mongodb=log_to_mongodb
        ),
        target_seed=seed,
    )
    for seed in seeds
]

session ID: f0d33d7d-c270-474f-a45a-e5873c994745


In [18]:
bounds = dict(R=sdls[0].bounds["R"], G=sdls[0].bounds["G"], B=sdls[0].bounds["B"])
params = [dict(name=nm, type="range", bounds=bnd) for nm, bnd in bounds.items()]
atime_bnd = [0, atime_max]  # instead of [0, 255]
astep_bnd = [0, astep_max]  # instead of [0, 65534]
params.append(
    dict(
        name="atime",
        type="range",
        is_fidelity=True,
        bounds=atime_bnd,
        target_value=atime_bnd[1],
    )
)
params.append(
    dict(
        name="astep",
        type="fixed",
        value=astep_max,
    )
)
params

[{'name': 'R', 'type': 'range', 'bounds': [0, 89]},
 {'name': 'G', 'type': 'range', 'bounds': [0, 89]},
 {'name': 'B', 'type': 'range', 'bounds': [0, 89]},
 {'name': 'atime',
  'type': 'range',
  'is_fidelity': True,
  'bounds': [0, 5],
  'target_value': 5},
 {'name': 'astep', 'type': 'fixed', 'value': 5}]

In [19]:
from ax.service.ax_client import AxClient
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.utils.instantiation import ObjectiveProperties
import torch
from time import time

integration_time_name = "integration_time_s"
model_runtime_name = "model_runtime_s"
tracking_metric_names = [integration_time_name, model_runtime_name]

campaign_objects = []
objective_name = "frechet"
batch_size = 1
num_sobol = 6
objectives = {objective_name: ObjectiveProperties(minimize=True)}
model_kwargs = (
    {"torch_device": torch.device("cuda")} if torch.cuda.is_available() else None
)
print(f"model_kwargs: {model_kwargs}")

for i, sdl in enumerate(sdls):

    def evaluate(parameters):
        results = sdl.evaluate(parameters)
        # # remove channel names to prevent extra tracking metrics warnings
        # [results.pop(ch) for ch in sdl.channel_names]
        atime = parameters["atime"]
        astep = parameters["astep"]
        results[integration_time_name] = calc_integration_time_s(atime, astep)
        return {objective_name: results[objective_name]}

    gs = GenerationStrategy(
        steps=[
            GenerationStep(model=Models.SOBOL, num_trials=num_sobol),
            GenerationStep(
                model=Models.GPKG,
                num_trials=-1,
                model_kwargs=model_kwargs,
                model_gen_kwargs=model_gen_kwargs,
            ),
        ]
    )

    ax_client = AxClient(generation_strategy=gs)
    ax_client.create_experiment(
        name="sdl_demo_mf_experiment",
        parameters=params,
        objectives=objectives,
        tracking_metric_names=tracking_metric_names,
        overwrite_existing_experiment=True,
    )

    running_integration_time_s = 0
    # Initial sobol samples
    for i in range(num_sobol):
        t0 = time()
        parameters, trial_index = ax_client.get_next_trial()
        model_runtime_dict = {model_runtime_name: time() - t0}
        results = evaluate(parameters)
        raw_data = {**model_runtime_dict, **results}
        ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

    # KGBO
    while running_integration_time_s < time_limit_s:
        t0 = time()
        q_p, q_t = [], []
        # Simulate batches
        for q in range(batch_size):
            parameters, trial_index = ax_client.get_next_trial()
            q_p.append(parameters)
            q_t.append(trial_index)
        model_runtime_dict = {model_runtime_name: time() - t0}

        for q in range(batch_size):
            pi = q_p[q]
            ti = q_t[q]
            integration_time = calc_integration_time_s(pi["atime"], pi["astep"])
            running_integration_time_s = running_integration_time_s + integration_time
            results = evaluate(pi)
            if running_integration_time_s > time_limit_s:
                # backup the time by one iteration and break
                final_cost_s = running_integration_time_s - integration_time
                break
            raw_data = {**model_runtime_dict, **results}
            ax_client.complete_trial(trial_index=ti, raw_data=raw_data)

    print(f"Time limit reached: {time_limit_s} s")
    print(f"Running integration time {final_cost_s} s")

    campaign_objects.append({"campaign_num": i, "sdl": sdl, "ax_client": ax_client})

[INFO 08-09 18:37:54] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 08-09 18:37:54] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter R. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 18:37:54] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter G. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 18:37:54] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter B. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 18:37:54] ax.service.utils.instan

model_kwargs: {'torch_device': device(type='cuda')}


[INFO 08-09 18:37:57] ax.service.ax_client: Completed trial 0 with data: {'model_runtime_s': (0.008999, None), 'frechet': (15568.0, None)}.
[INFO 08-09 18:37:57] ax.service.ax_client: Generated new trial 1 with parameters {'R': 82, 'G': 58, 'B': 18, 'atime': 5, 'astep': 5}.
[INFO 08-09 18:38:01] ax.service.ax_client: Completed trial 1 with data: {'model_runtime_s': (0.015611, None), 'frechet': (15568.0, None)}.
[INFO 08-09 18:38:01] ax.service.ax_client: Generated new trial 2 with parameters {'R': 55, 'G': 3, 'B': 87, 'atime': 4, 'astep': 5}.
[INFO 08-09 18:38:04] ax.service.ax_client: Completed trial 2 with data: {'model_runtime_s': (0.006849, None), 'frechet': (15567.0, None)}.
[INFO 08-09 18:38:04] ax.service.ax_client: Generated new trial 3 with parameters {'R': 32, 'G': 11, 'B': 18, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:38:08] ax.service.ax_client: Completed trial 3 with data: {'model_runtime_s': (0.005571, None), 'frechet': (15568.0, None)}.
[INFO 08-09 18:38:08] ax.service.ax_

Time limit reached: 0.00010008 s
Running integration time 5.004e-05 s


[INFO 08-09 18:42:48] ax.service.ax_client: Completed trial 0 with data: {'model_runtime_s': (0.0, None), 'frechet': (9536.179791, None)}.
[INFO 08-09 18:42:48] ax.service.ax_client: Generated new trial 1 with parameters {'R': 1, 'G': 59, 'B': 25, 'atime': 3, 'astep': 5}.
[INFO 08-09 18:42:52] ax.service.ax_client: Completed trial 1 with data: {'model_runtime_s': (0.015429, None), 'frechet': (9537.335529, None)}.
[INFO 08-09 18:42:52] ax.service.ax_client: Generated new trial 2 with parameters {'R': 79, 'G': 47, 'B': 76, 'atime': 2, 'astep': 5}.
[INFO 08-09 18:42:55] ax.service.ax_client: Completed trial 2 with data: {'model_runtime_s': (0.00596, None), 'frechet': (9530.180533, None)}.
[INFO 08-09 18:42:55] ax.service.ax_client: Generated new trial 3 with parameters {'R': 26, 'G': 5, 'B': 84, 'atime': 5, 'astep': 5}.
[INFO 08-09 18:42:59] ax.service.ax_client: Completed trial 3 with data: {'model_runtime_s': (0.013637, None), 'frechet': (9533.180162, None)}.
[INFO 08-09 18:42:59] ax.se

Time limit reached: 0.00010008 s
Running integration time 0.00010007999999999999 s


In [20]:
running_integration_time_s

0.00011675999999999999

In [6]:
ax_clients

[AxClient(experiment=Experiment(sdl_demo_mf_experiment)),
 AxClient(experiment=Experiment(sdl_demo_mf_experiment)),
 AxClient(experiment=Experiment(sdl_demo_mf_experiment)),
 AxClient(experiment=Experiment(sdl_demo_mf_experiment)),
 AxClient(experiment=Experiment(sdl_demo_mf_experiment))]

## Code Graveyard

In [None]:
# params.append(
#     dict(
#         name="astep",
#         type="range",
#         is_fidelity=True,
#         bounds=astep_bnd,
#         target_value=astep_bnd[1],
#     )
# )

In [None]:
# from botorch.test_functions.multi_fidelity import AugmentedHartmann
# import torch

In [None]:
# [
#     "utc_timestamp",
#     "ch470",
#     "ch550",
#     "ch670",
#     "ch410",
#     "background", # needs to be collapsed
#     "ch620",
#     "sd_card_ready",
#     "ch510",
#     "warning",
#     "ch583",
#     "device_nickname",
#     "ch440",
#     "onboard_temperature_K",
#     "encrypted_device_id_truncated",
#     "mae",
#     "rmse",
#     "frechet",
# ]