[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/utkarshp1161/Active-learning-in-microscopy/blob/main/notebooks/multi_objective_BO_SVDKL.ipynb)

# Multi-objective active learning with the DigitalTwin microscope: stochastic variational deep kernel learning in GPyTorch and a BO loop in BoTorch. [A GPU instance is recommended]
- For the single-objective version, please see [this notebook](https://github.com/utkarshp1161/Active-learning-in-microscopy/blob/main/notebooks/single_objective_BO_SVDKL.ipynb)

Prepared by [Utkarsh Pratiush](https://github.com/utkarshp1161)
- Get in touch with any questions or for discussion: [utkarshp1161@gmail.com]
- For reading related to this notebook please refer to [wilson et al 2016](https://arxiv.org/abs/1611.00336), [wilson et al 2015](https://arxiv.org/abs/1511.02222) and [Sebastian et al 2021](https://arxiv.org/abs/2102.12108)
- Reference to [Gpytorch example](https://docs.gpytorch.ai/en/v1.12/examples/06_PyTorch_NN_Integration_DKL/Deep_Kernel_Learning_DenseNet_CIFAR_Tutorial.html)



## 1. Install modules and start DigitalTwin microscope

In [None]:
# Install the pinned BoTorch / GPyTorch versions, the DTMicroscope package
# (from GitHub) and the HDF5 / sidpy / Pyro5 packages imported by later cells.
!pip install -q botorch==0.12.0
!pip install -q gpytorch==1.13
!pip install -q git+https://github.com/pycroscopy/DTMicroscope.git
!pip install -q h5py
!pip install -q sidpy
!pip install -q pyro5
!pip install -q scifireaders
!pip install -q pynsid
!pip install -q pytemlib

## Start the DigitalTwin microscope server.
# NOTE(review): `run_server_stem` is presumably a console script installed by
# DTMicroscope above - confirm it returns control to the notebook.
!run_server_stem

## 2. BEPS data - credits: yongtao liu

### 2a. Download data

In [None]:
!gdown https://drive.google.com/uc?id=1UMub2L-9X8imtvTbz_aE4l4Gz-TW5Yd5

### 2b. Preprocess data 


In [None]:
import h5py
import sidpy
import numpy as np
from pathlib import Path
import random
from datetime import datetime
import Pyro5.api
import pickle
import matplotlib.pyplot as plt
import pyNSID
import SciFiReaders

# Open the downloaded BEPS measurement and print its HDF5 tree for inspection.
# NOTE(review): the file is opened read/write ('r+') although this cell only
# reads from it, and the handle is never closed here.
input_file = "BEPS_1d7um_0009.h5"
h5_f = h5py.File(input_file, 'r+')
sidpy.hdf.hdf_utils.print_tree(h5_f)


# Handles to the SHO fit results, spectroscopic values and position indices.
# NOTE(review): `Fit` is read from Raw_Data-SHO_Fit_001 while
# `Spectroscopic_Values` is read from Raw_Data-SHO_Fit_000 - confirm intended.
sho_mat = h5_f['Measurement_000/Channel_000/Raw_Data-SHO_Fit_001/Fit']
spec_val = h5_f['Measurement_000/Channel_000/Raw_Data-SHO_Fit_000/Spectroscopic_Values']
pos_inds = h5_f['Measurement_000/Channel_000/Position_Indices']
# Grid size along each position axis = largest position index + 1.
pos_dim_sizes = [np.max(pos_inds[:,0])+1, np.max(pos_inds[:,1]+1)]
# 'r' field of the Channel_001 raw data; not used further in this cell.
topo = h5_f['Measurement_000/Channel_001/Raw_Data/']['r']

# Reshape the fit table to (rows, cols, spectral steps) and pull out the
# named fields of the structured array.
sho_mat_ndim = sho_mat[:].reshape(pos_dim_sizes[0], pos_dim_sizes[1], -1)
amp_mat_ndim = sho_mat_ndim['Amplitude [V]']
phase_mat_ndim = sho_mat_ndim['Phase [rad]']
fre_mat_ndim = sho_mat_ndim['Frequency [Hz]']
q_mat_ndim = sho_mat_ndim['Quality Factor']



min_ = np.min(phase_mat_ndim)
max_ = np.max(phase_mat_ndim)

# Phase values above 1.5 are shifted by (min - max); all others are kept.
pha_correct = np.where(phase_mat_ndim > 1.5, phase_mat_ndim + (min_-max_), phase_mat_ndim)

# Constant phase offset (empirical value from the original notebook).
pha_correct = pha_correct + np.pi-0.55

# Separate on-field and off-field data: even spectral indices go to the
# "on_field" arrays and odd indices to the "off_field" arrays.
full_spec_len = sho_mat_ndim.shape[2]
spec_len = int(full_spec_len/2)
pix_x, pix_y = sho_mat_ndim.shape[0], sho_mat_ndim.shape[1]

amp_off_field = np.zeros((pix_x, pix_y, spec_len))
pha_off_field = np.zeros((pix_x, pix_y, spec_len))
fre_off_field = np.zeros((pix_x, pix_y, spec_len))

amp_on_field = np.zeros((pix_x, pix_y, spec_len))
pha_on_field = np.zeros((pix_x, pix_y, spec_len))
fre_on_field = np.zeros((pix_x, pix_y, spec_len))

v_step = np.zeros(spec_len)

for i in range (spec_len):
  amp_off_field[:,:, i] = amp_mat_ndim[:,:,2*i+1]
  pha_off_field[:,:, i] = pha_correct[:,:,2*i+1]
  fre_off_field[:,:, i] = fre_mat_ndim[:,:,2*i+1]/1000  # 'Frequency [Hz]' field divided by 1000
  amp_on_field[:,:, i] = amp_mat_ndim[:,:,2*i]
  pha_on_field[:,:, i] = pha_correct[:,:,2*i]
  fre_on_field[:,:, i] = fre_mat_ndim[:,:,2*i]/1000
  v_step[i] = spec_val[0,2*i]  # first row of the spectroscopic values at the even step


# Off-field amplitude multiplied by the cosine of the corrected phase.
pola_off_field = amp_off_field*np.cos(pha_off_field)

# Mean off-field amplitude over the spectral axis (overwritten below after cropping).
struc_img = amp_off_field.mean(2)


## Cut a part for DKL exploration

exp_data = pola_off_field[50:, 20:70]
img_data = amp_off_field[50:, 20:70]

# Structure image of the cropped region: mean amplitude over the spectral axis.
struc_img = img_data.mean(2)
plot_pixx1 = 10; plot_pixy1 = 30
plot_pixx2 = 40; plot_pixy2 = 10

# Min-max normalisation to [0, 1]. np.ptp(x) is used rather than the
# x.ptp() method because NumPy 2.0 removed ndarray.ptp.
norm_ = lambda x: (x - np.min(x)) / np.ptp(x)

# Normalised inputs exported to the DigitalTwin microscope in the next section.
img = norm_(struc_img)
spectra = norm_(exp_data)

# coordinates = get_coord_grid(img, step = 1, return_dict=False)


### 2c. Prepare data for DigitalTwin Microscope

In [None]:
image = img
spectrum_image = spectra
## Set scale bar and energy axis of spectrum based on your data
scale = 1
# One axis sample per spectral point so the energy axis always has the same
# length as the last dimension of the spectrum image (previously fixed at 256).
energy_axis = np.linspace(0, 1, spectrum_image.shape[-1])

data_sets = {'Channel_000': sidpy.Dataset.from_array(image , name = "overview image"),
             'Channel_001': sidpy.Dataset.from_array(spectrum_image, name = "spectrum image")}

data_sets['Channel_000'].data_type = 'image'
data_sets['Channel_001'].data_type = 'spectral_image'

# Both channels get the same two spatial axes (x, y) in nm; the image and the
# spectrum image are crops of the same region, so their spatial shapes agree.
for _channel in data_sets.values():
    for _axis, _axis_name in enumerate(('x', 'y')):
        _channel.set_dimension(_axis, sidpy.Dimension(np.arange(_channel.shape[_axis])*scale,
                                                      name=_axis_name, units='nm', quantity='Length',
                                                      dimension_type='spatial'))

# Third axis of the spectrum image is the (normalised) spectral axis.
data_sets['Channel_001'].set_dimension(2, sidpy.Dimension(energy_axis,
                                          name='energy_scale', units='ev', quantity='Energy',
                                          dimension_type='spectral'))


def save_dataset_dictionary(h5_file, datasets):
    """Store every entry of ``datasets`` in a new indexed ``Measurement_`` group.

    Args:
        h5_file: open, writable h5py file (or group) to write into.
        datasets: mapping of name -> ``sidpy.Dataset`` or ``dict``. A trailing
            ``/`` on a name is dropped. Entries of any other type are reported
            with ``print`` and skipped.

    Returns:
        The newly created measurement group.
    """
    measurement_grp = sidpy.hdf.prov_utils.create_indexed_group(h5_file, 'Measurement_')
    for name, item in datasets.items():
        name = name[:-1] if name[-1] == '/' else name

        if isinstance(item, sidpy.Dataset):
            # Each sidpy dataset gets its own sub-group; keep a handle to the
            # written HDF5 dataset on the sidpy object and flush to disk.
            written = pyNSID.hdf_io.write_nsid_dataset(item, measurement_grp.create_group(name))
            item.h5_dataset = written
            written.file.flush()
            continue

        if isinstance(item, dict):
            sidpy.hdf.hdf_utils.write_dict_to_h5_group(measurement_grp, item, name)
            continue

        print('could not save item ', name, 'of dataset dictionary')
    return measurement_grp

dataset_name = 'yl_beps.h5'
# The context manager closes the file even if saving raises. The file is
# opened in append mode, so re-running this cell adds another indexed
# Measurement_ group rather than replacing the previous one.
with h5py.File(dataset_name, mode='a') as h5_file:
    save_dataset_dictionary(h5_file, data_sets)





## 3. Multi Objective Bayesian optimization with DKL
- Note: each objective is modeled by its own surrogate; the next point is then chosen as the one with the highest value across both surrogates (objectives).

### 3a. DKL model 

In [None]:
import torch
import torch.nn as nn
from botorch.models.model import Model
from botorch.models.gpytorch import GPyTorchModel
from botorch.acquisition.multi_objective.monte_carlo import qExpectedHypervolumeImprovement
from botorch.optim import optimize_acqf
# from botorch.utils.multi_objective.box_decomposition import NondominatedPartitioning
from botorch.utils.sampling import sample_simplex
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.test_functions.multi_objective import BraninCurrin
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.mlls.variational_elbo import VariationalELBO
from gpytorch.likelihoods import GaussianLikelihood
from torch.optim import Adam
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from botorch.optim import optimize_acqf_discrete
from botorch.acquisition.multi_objective import ExpectedHypervolumeImprovement
from botorch.utils.multi_objective.box_decompositions.non_dominated import FastNondominatedPartitioning
from tqdm import tqdm
import gpytorch
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
import math
import torch.nn as nn
import numpy as np
from typing import Tuple, Optional, Dict, Union, List
import numpy as np
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Make float64 torch's default floating dtype. `dtype` is also read as a
# module-level global by numpy_to_torch_for_conv.
dtype = torch.float64
torch.set_default_dtype(dtype)

print(device)


# Simple ConvNet for feature extraction
class ConvNetFeatureExtractor(nn.Module):
    """Small two-stage ConvNet that maps image patches to feature vectors.

    Two ``Conv2d -> ReLU -> MaxPool2d(2)`` stages are followed by a single
    linear layer producing ``output_dim`` features. The linear layer is built
    on the first forward pass because its input size depends on the patch
    size. Run one forward pass *before* creating an optimizer, otherwise the
    linear layer's parameters will not be part of ``parameters()`` yet.

    Args:
        input_channels: number of channels of the input patches.
        output_dim: length of the feature vector returned per patch.
    """

    def __init__(self, input_channels=1, output_dim=32):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.output_dim = output_dim
        self.fc = None  # Created lazily in forward() once the flattened size is known

    def forward(self, x):
        """Return features of shape ``(batch, output_dim)``.

        Args:
            x: either ``(batch, channel, height, width)`` patches or
                flattened square patches of shape ``(batch, channel, side*side)``
                (the form BoTorch acquisition functions pass in).

        Raises:
            ValueError: if a flattened input does not hold a square patch.
        """
        if x.ndim == 3:
            # Flattened patches: recover the square (side, side) layout.
            batch_size, channel, mn = x.shape
            side = math.isqrt(mn)
            if side * side != mn:
                raise ValueError(
                    f"Flattened patch length {mn} is not a perfect square; "
                    "cannot reshape it to a square patch."
                )
            x = x.reshape(batch_size, channel, side, side)

        # Convolutional stages, then flatten everything but the batch axis.
        x = self.conv_layers(x).flatten(1)

        if self.fc is None:
            # Build the head on the same device/dtype as the activations.
            self.fc = nn.Linear(x.size(1), self.output_dim).to(device=x.device, dtype=x.dtype)

        return self.fc(x)

# GP model with deep kernel using ConvNet feature extractor
class GPModelDKL(ApproximateGP, Model):
    """Variational GP on top of a neural feature extractor (deep kernel learning).

    Inputs are image patches (optionally flattened); they are passed through
    ``feature_extractor`` and the GP (constant mean, scaled RBF kernel) is
    defined on the resulting feature vectors. The class also implements the
    BoTorch ``Model`` interface (``posterior`` / ``num_outputs``) for a single
    output.

    Args:
        inducing_points: patches used to initialise the inducing locations,
            either ``(n, channel, height, width)`` or flattened
            ``(n, channel, height*width)``.
        likelihood: GPyTorch likelihood used in ``posterior``.
        feature_extractor: module mapping patches to ``(n, feature_dim)``.
            Required despite the ``None`` default kept for signature
            compatibility.
        input_shape: ``(channel, height, width)`` of a single patch.

    Raises:
        ValueError: if ``feature_extractor`` is ``None``.
    """

    def __init__(self, inducing_points, likelihood, feature_extractor=None, input_shape=(1, 5, 5)):
        if feature_extractor is None:
            raise ValueError("GPModelDKL requires a feature_extractor module.")
        self.input_shape = input_shape

        # Inducing points live in feature space, so map the patches through
        # the extractor first. No autograd graph is needed for this one-off
        # initialisation.
        flat_inducing = self._flatten_input(inducing_points)
        with torch.no_grad():
            feature_inducing = feature_extractor(self._reshape_to_patch(flat_inducing))

        variational_distribution = CholeskyVariationalDistribution(feature_inducing.size(0))
        variational_strategy = VariationalStrategy(
            self, feature_inducing, variational_distribution, learn_inducing_locations=True
        )

        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self._num_outputs = 1  # storing as private attribute
        self.likelihood = likelihood
        self.feature_extractor = feature_extractor

    def _flatten_input(self, x):
        """Flatten 4-D patches ``(batch, channel, h, w)`` to ``(batch, 1, -1)``.

        Inputs of any other rank are returned unchanged.
        """
        if x.ndim == 4:  # (batch, channel, height, width)
            return x.reshape(x.size(0), 1, -1)
        return x

    def _reshape_to_patch(self, x):
        """Reshape 3-D flattened input ``(batch, channel, flat)`` to ``(batch, *input_shape)``.

        Inputs of any other rank are returned unchanged.
        """
        if x.ndim == 3:  # (batch, channel, flattened)
            batch_size = x.size(0)
            return x.reshape(batch_size, *self.input_shape)
        return x

    def forward(self, x):
        """GP prior over feature vectors ``x`` (already extracted features)."""
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)

    def __call__(self, x, use_feature_extractor=True, *args, **kwargs):
        """Evaluate the GP, extracting features from patches first by default.

        Pass ``use_feature_extractor=False`` when ``x`` already holds features.
        """
        if use_feature_extractor:
            # Ensure patch format, then map to [batch_size, feature_dim].
            x = self._reshape_to_patch(x)
            x = self.feature_extractor(x)
        return super().__call__(x, *args, **kwargs)

    def posterior(self, X, output_indices=None, observation_noise=False, *args, **kwargs):
        """Return a ``GPyTorchPosterior`` with independent per-point marginals.

        Switches the model to eval mode and evaluates without gradients. The
        likelihood is always applied (``observation_noise`` and
        ``output_indices`` are accepted for interface compatibility but not
        used). Each input point gets its own 1-D normal built from the
        predictive mean and variance, i.e. correlations between points are
        dropped.
        """
        self.eval()
        with torch.no_grad():
            # Ensure correct shape for feature extraction
            if X.ndim == 2:  # If input is (batch_size, flattened_dim)
                X = X.unsqueeze(1)  # Add channel dim
            X = self._reshape_to_patch(X)
            # Features will be [batch_size, feature_dim]
            features = self.feature_extractor(X)
            dist = self.likelihood(self(features, use_feature_extractor=False))
            # Ensure output has correct shape for MOBO
            mean = dist.mean.unsqueeze(-1)  # Shape: [batch_size, 1]
            variance = dist.variance.unsqueeze(-1)  # Shape: [batch_size, 1]
            dist = gpytorch.distributions.MultivariateNormal(mean, torch.diag_embed(variance))
        return GPyTorchPosterior(dist)

    @property
    def num_outputs(self) -> int:
        """The number of outputs of the model."""
        return self._num_outputs

    @property
    def hparam_dict(self):
        """Current scalar hyperparameters as plain Python floats.

        ``outputscale`` belongs to the ``ScaleKernel`` wrapper and
        ``lengthscale`` to the RBF base kernel.
        """
        return {
            "likelihood.noise": self.likelihood.noise.item(),
            "covar_module.outputscale": self.covar_module.outputscale.item(),
            "covar_module.base_kernel.lengthscale": self.covar_module.base_kernel.lengthscale.item(),
            "mean_module.constant": self.mean_module.constant.item(),
        }




### 3b. Utility functions - 1

In [None]:
def normalize_data(data : np.ndarray) -> np.ndarray:
    """Min-max normalize ``data`` to the [0, 1] range.

    Works on any object exposing ``min()`` / ``max()`` and arithmetic
    (NumPy arrays as well as torch tensors). Constant input, for which the
    range is zero, is mapped to all zeros instead of dividing by zero.
    """
    lowest = data.min()
    span = data.max() - lowest
    if span == 0:
        # Every element equals the minimum: return floating-point zeros.
        return (data - lowest) * 0.0
    return (data - lowest) / span


def numpy_to_torch_for_conv(np_array) -> torch.Tensor:
    """
    Turn a stack of 2-D arrays into a channel-first tensor for conv layers.

    Parameters:
        np_array (np.ndarray): Input NumPy array of shape (batch_size, a, b).

    Returns:
        torch.Tensor: Tensor of shape (batch_size, 1, a, b), created with the
        module-level ``dtype``.

    Raises:
        TypeError: If the input is not a NumPy array.
    """
    if isinstance(np_array, np.ndarray):
        # Copy into a tensor and insert the channel axis at position 1.
        return torch.tensor(np_array, dtype=dtype).unsqueeze(1)
    raise TypeError("Input must be a NumPy array.")


######################atomai utils####################################
#Credits Maxim Ziatdinov (https://github.com/ziatdinovmax): https://github.com/pycroscopy/atomai/blob/8db3e944cd9ece68c33c8e3fcca3ef3ce9a111ea/atomai/utils/img.py#L522

def get_coord_grid(imgdata: np.ndarray, step: int,
                   return_dict: bool = True
                   ) -> Union[np.ndarray, Dict[int, np.ndarray]]:
    """
    Build a regular grid of (row, col) coordinates for each image in a stack.

    Args:
        imgdata (numpy array): 2D image or 3D stack of images
        step (int): spacing between neighbouring grid points
        return_dict (bool): if True, return ``{frame: N x 3 array}`` where the
            third column is a class label of zeros (same format as
            atomnet.predictor); otherwise return a single ``(frames*N) x 2``
            array with the grid repeated once per frame

    Returns:
        Dictionary or numpy array with coordinates
    """
    stack = np.expand_dims(imgdata, axis=0) if np.ndim(imgdata) == 2 else imgdata
    n_frames = stack.shape[0]
    grid = np.array([
        np.array([row, col])
        for row in range(0, stack.shape[1], step)
        for col in range(0, stack.shape[2], step)
    ])
    if not return_dict:
        return np.concatenate([grid] * n_frames, axis=0)
    # Append the class column (all zeros) expected by the dictionary format.
    grid = np.concatenate((grid, np.zeros((grid.shape[0], 1))), axis=-1)
    return {frame: grid for frame in range(n_frames)}

def get_imgstack(imgdata: np.ndarray,
                 coord: np.ndarray,
                 r: int) -> Tuple[np.ndarray]:
    """
    Crop square windows of side ``r`` around the given centres of one image.

    Windows that do not have the full ``r x r`` size (e.g. centres too close
    to the border) or that contain NaNs are dropped.

    Args:
        imgdata (3D numpy array):
            Image with dimensions height x width x n channels
        coord (N x 2 numpy array):
            (x, y) coordinates of the window centres
        r (int):
            Window size

    Returns:
        2-element tuple containing

        - Stack of subimages
        - (x, y) coordinates of their centers

        or ``(None, None)`` when no window was kept.
    """
    half = r // 2
    # Odd windows include one extra pixel on the upper side of the centre.
    extra = 1 if r % 2 != 0 else 0
    crops, centers = [], []
    for point in coord:
        row = int(np.around(point[0]))
        col = int(np.around(point[1]))
        window = np.copy(
            imgdata[row - half:row + half + extra,
                    col - half:col + half + extra])
        if window.shape[0:2] != (int(r), int(r)):
            continue
        if np.isnan(window).any():
            continue
        crops.append(window[None, ...])
        centers.append(point[None, ...])
    if not crops:
        return None, None
    return np.concatenate(crops, axis=0), np.concatenate(centers, axis=0)

def extract_subimages(imgdata: np.ndarray,
                      coordinates: Union[Dict[int, np.ndarray], np.ndarray],
                      window_size: int, coord_class: int = 0) -> Tuple[np.ndarray]:
    """
    Crop square subimages around coordinates of one class, frame by frame.

    Args:
        imgdata (numpy array):
            4D stack of images (n, height, width, channel).
            A single 2D image is also accepted.
        coordinates (dict or N x 2 numpy array):
            Dictionary ``{frame: N x 3 array}`` whose first two columns are
            xy coordinates and whose third column is the class (starting at
            0). An N x 2 array may be passed for a single 2D image; it is
            assigned class 0.
        window_size (int):
            Side of the square for subimage cropping
        coord_class (int):
            Class (3rd coordinate column) around which subimages are cropped

    Returns:
        3-element tuple containing

        - stack of subimages,
        - (x, y) coordinates of their centers,
        - frame number associated with each subimage

        All three are empty lists when no subimage could be extracted.
    """
    if isinstance(coordinates, np.ndarray):
        class_column = np.zeros((coordinates.shape[0], 1))
        coordinates = {0: np.concatenate((coordinates, class_column), axis=-1)}
    if np.ndim(imgdata) == 2:
        # Promote a single image to (1, height, width, 1).
        imgdata = imgdata[None, ..., None]

    stacks, centers, frames = [], [], []
    for frame_no, (frame, frame_coord) in enumerate(
            zip(imgdata, coordinates.values())):
        selected = frame_coord[frame_coord[:, 2] == coord_class][:, :2]
        crops, kept = get_imgstack(frame, selected, window_size)
        if crops is not None:
            stacks.append(crops)
            centers.append(kept)
            frames.append(np.full(len(kept), frame_no, dtype=int))

    if len(stacks) > 0:
        stacks = np.concatenate(stacks, axis=0)
        centers = np.concatenate(centers, axis=0)
        frames = np.concatenate(frames, axis=0)
    return stacks, centers, frames

### 3c. Utility functions - 2

In [None]:
#*********************************DTmic specific functions starts **********************************************#
from sklearn.metrics import mean_squared_error

# Number of cycles assumed in each point spectrum.
# TODO: hard coded cycle now
_LOOP_CYCLES = 3


def _fetch_point_spectrum(index, indices_all):
    """Fetch the spectrum at ``indices_all[index]`` from the microscope server."""
    array_list, shape, dtype = mic_server.get_point_data(
        spectrum_image_index="Channel_001",
        x=int(indices_all[index, 0]),  # TODO: check if x and y need to be flipped
        y=int(indices_all[index, 1])
    )
    return np.array(array_list, dtype=dtype).reshape(shape)


def _split_loop_branches(raw_spec, cycle):
    """Split a multi-cycle spectrum into concatenated "top" and "bottom" parts.

    With ``h`` half a cycle and ``q`` a quarter cycle (both truncated to
    int), the top part is ``[q, q+h) + [q+2h, q+3h) + [q+4h, 2q+4h)`` and the
    bottom part is ``[0, q) + [q+h, q+2h) + [q+3h, q+4h)``. The segment
    layout is written for three cycles.
    """
    cycle_len = int(len(raw_spec) / cycle)
    half_len = int(cycle_len / 2)
    q_len = int(cycle_len / 4)
    top_bounds = (
        (q_len, q_len + half_len),
        (q_len + 2*half_len, q_len + 3*half_len),
        (q_len + 4*half_len, 2*q_len + 4*half_len),
    )
    bottom_bounds = (
        (0, q_len),
        (q_len + half_len, q_len + 2*half_len),
        (q_len + 3*half_len, q_len + 4*half_len),
    )
    loop_top = np.concatenate([raw_spec[lo:hi] for lo, hi in top_bounds])
    loop_bottom = np.concatenate([raw_spec[lo:hi] for lo, hi in bottom_bounds])
    return loop_top, loop_bottom


def black_box(index, indices_all, e1a, e1b) -> float:
    """
    Score a point as the sum of its spectrum over the slice ``[e1a:e1b]``.

    Args:
        index: row of ``indices_all`` to measure.
        indices_all: N x 2 array of (x, y) positions.
        e1a, e1b: start and end index of the spectral window.
    """
    spectrum = _fetch_point_spectrum(index, indices_all)
    return spectrum[e1a:e1b].sum()


def black_box_loop_height(index, indices_all, e1a, e1b) -> float:
    """
    Score a point as max(top part) - min(bottom part) of its spectrum.

    ``e1a`` and ``e1b`` are accepted for signature compatibility with
    ``black_box`` but are not used.
    """
    loop_top, loop_bottom = _split_loop_branches(
        _fetch_point_spectrum(index, indices_all), _LOOP_CYCLES)
    return np.max(loop_top) - np.min(loop_bottom)


def black_box_loop_area(index, indices_all, e1a, e1b) -> float:
    """
    Score a point as |sum(top part) - sum(bottom part)| of its spectrum.

    ``e1a`` and ``e1b`` are accepted for signature compatibility with
    ``black_box`` but are not used.
    """
    loop_top, loop_bottom = _split_loop_branches(
        _fetch_point_spectrum(index, indices_all), _LOOP_CYCLES)
    return np.abs(np.sum(loop_top)-np.sum(loop_bottom))


def black_box_positive_nucleation_bias(index, indices_all, e1a, e1b) -> float:
    """
    Score a point as mean(top part) - mean(bottom part) of its spectrum.

    ``e1a`` and ``e1b`` are accepted for signature compatibility with
    ``black_box`` but are not used.
    """
    loop_top, loop_bottom = _split_loop_branches(
        _fetch_point_spectrum(index, indices_all), _LOOP_CYCLES)
    return np.mean(loop_top) - np.mean(loop_bottom)


def black_box_negative_nucleation_bias(index, indices_all, e1a, e1b) -> float:
    """
    Score a point as min(top part) - max(bottom part) of its spectrum.

    ``e1a`` and ``e1b`` are accepted for signature compatibility with
    ``black_box`` but are not used.
    """
    loop_top, loop_bottom = _split_loop_branches(
        _fetch_point_spectrum(index, indices_all), _LOOP_CYCLES)
    return np.min(loop_top) - np.max(loop_bottom)

## a) evaluations metrics like nlpd, mse ----
def calculate_mse(y_true, y_pred):
    """Return the mean squared error between ``y_true`` and ``y_pred``.

    Smaller is better; squaring keeps positive and negative errors from
    cancelling each other out.
    """
    return mean_squared_error(y_true, y_pred)

def calculate_nlpd(y_true, y_pred_mean, y_pred_var):
    """Return the mean Gaussian negative log predictive density as a float.

    Lower values mean the predicted means *and* variances match the true
    values better.
    """
    squared_error = (y_true - y_pred_mean) ** 2
    per_point = 0.5 * torch.log(2 * torch.pi * y_pred_var) + 0.5 * (squared_error / y_pred_var)
    return per_point.mean().item()

### 3d. Utility functions - 3

In [None]:
from tqdm import tqdm

def calculate_scores_for_patches(unacquired_indices, indices_all, e1a, e1b, black_box_fn = black_box) -> torch.Tensor:
    """
    Evaluate the black-box function at every index in ``unacquired_indices``.

    Parameters:
    - unacquired_indices: iterable of row indices into ``indices_all``.
    - indices_all: array of positions passed through to ``black_box_fn``.
    - e1a, e1b: spectral window bounds passed through to ``black_box_fn``.
    - black_box_fn: callable ``(index, indices_all, e1a, e1b) -> score``.

    Returns:
    - scores: tensor with one score per index, in iteration order.
    """
    return torch.tensor([
        black_box_fn(position, indices_all, e1a, e1b)
        for position in unacquired_indices
    ])

def update_acquired(acquired_data, unacquired_indices, selected_indices, indices_all, e1a, e1b, black_box_fn = black_box) -> (np.array, list):
    """Measure the selected points and remove them from the unacquired pool.

    ``acquired_data`` is updated in place (index -> score) and returned
    together with a new list of the indices that are still unacquired.
    """
    # TODO: the black box is queried again even for points already acquired.
    for picked in selected_indices:
        acquired_data[picked] = black_box_fn(picked, indices_all, e1a, e1b)

    still_unacquired = []
    for candidate in unacquired_indices:
        if candidate not in selected_indices:
            still_unacquired.append(candidate)
    return acquired_data, still_unacquired


def load_image_and_features(img: np.ndarray , window_size : int) -> (np.ndarray, np.ndarray):
    """Cut ``window_size`` patches around every pixel and min-max normalize them.

    Returns the normalized patch stack and the integer centre coordinates of
    the patches that were kept.
    """
    grid = get_coord_grid(img, step=1, return_dict=False)
    patches, centers, _ = extract_subimages(img, grid, window_size)
    patches = patches[:, :, :, 0]  # drop the trailing channel axis
    centers = np.array(centers, dtype=int)
    # Global min-max scaling; np.ptp(x) is used because NumPy 2.0 removed x.ptp().
    features = (patches - np.min(patches)) / np.ptp(patches)
    return features, centers  # e.g. shapes (3366, 5, 5) and (3366, 2)


def prepare_data_from_microscope(window_size: int) -> (np.ndarray, np.ndarray, np.ndarray):
    """Fetch the overview image from ``mic_server`` and build patch features.

    Args:
        window_size: side of the square patches passed to
            ``load_image_and_features``.

    Returns:
        ``(img, features, indices_all)``: the overview image, the patch
        features and the patch centre coordinates.
    """
    # The server returns the flat data plus the shape and dtype needed to rebuild the array.
    array_list, shape, dtype = mic_server.get_overview_image()
    img = np.array(array_list, dtype=dtype).reshape(shape)#
    features, indices_all = load_image_and_features(img, window_size)



    return img, features, indices_all# example shapes from the original run: (55, 70), (3366, 5, 5) and (3366, 2)

def get_spectrum_data(indices, energy_range, channel="Channel_001") -> (np.array, int, int):
    """Fetch a spectrum image and locate an energy window on its axis.

    Args:
        indices: not used by this function.
        energy_range: ``(low, high)`` energies; each is matched to the
            nearest sample of the channel's energy axis.
        channel: spectrum image channel requested from ``mic_server``.

    Returns:
        ``(spectral_img, e_start, e_end)`` where the last two are the axis
        indices closest to ``energy_range[0]`` and ``energy_range[1]``.
    """
    raw, shp, dt = mic_server.get_spectrum_image(spectrum_image_index=channel)
    spectral_img = np.array(raw, dtype=dt).reshape(shp)

    raw, shp, dt = mic_server.get_spectrum_image_e_axis(spectrum_image_index=channel)
    energy_axis = np.array(raw, dtype=dt).reshape(shp)

    e_start = abs(energy_axis - energy_range[0]).argmin()
    e_end = abs(energy_axis - energy_range[1]).argmin()
    return spectral_img, e_start, e_end


from botorch.utils.multi_objective.pareto import is_non_dominated

def plot_pareto_front(acquired_data1, acquired_data2, step, save_path=None):
    """Plot the normalized two-objective Pareto front of the acquired points.

    Args:
        acquired_data1: mapping index -> value of objective 1.
        acquired_data2: mapping index -> value of objective 2.
        step: zero-based BO step, shown (as ``step + 1``) in the title.
        save_path: if given, the figure is saved there and closed; otherwise
            it is shown.

    Returns:
        List of indices (present in both mappings) that are Pareto optimal,
        in ascending index order. Empty list if there is nothing to plot.
    """
    if not acquired_data1 or not acquired_data2:
        log_with_context("Insufficient data to plot Pareto front.")
        return []

    # Only points measured for both objectives can be compared.
    common_indices = sorted(set(acquired_data1.keys()) & set(acquired_data2.keys()))
    if not common_indices:
        log_with_context("Insufficient data to plot Pareto front.")
        return []

    objectives = np.zeros((len(common_indices), 2))
    for i, idx in enumerate(common_indices):
        objectives[i, 0] = acquired_data1[idx]
        objectives[i, 1] = acquired_data2[idx]

    # Min-max normalize each objective; a constant objective (zero range)
    # is mapped to 0 instead of producing NaNs.
    lowest = objectives.min(axis=0)
    span = objectives.max(axis=0) - lowest
    span[span == 0] = 1.0
    objectives = (objectives - lowest) / span
    objectives_tensor = torch.tensor(objectives, dtype=torch.float32)

    # Find Pareto optimal points
    pareto_mask = is_non_dominated(objectives_tensor).numpy()
    pareto_indices = [idx for idx, is_pareto in zip(common_indices, pareto_mask) if is_pareto]
    pareto_front = objectives[pareto_mask]

    # Sort the front along objective 1 for a clean line, and reorder the
    # labels with the same permutation so each annotation stays on its point.
    order = pareto_front[:, 0].argsort()
    sorted_front = pareto_front[order]
    sorted_labels = [pareto_indices[i] for i in order]

    # Create visualization
    plt.figure(figsize=(10, 8))
    plt.scatter(objectives[:, 0], objectives[:, 1],
               c='lightgray', marker='o', label='All Points', alpha=0.5)

    # Plot Pareto front with connecting lines
    plt.plot(sorted_front[:, 0], sorted_front[:, 1],
            'r--', linewidth=2, label='Pareto Front')
    plt.scatter(sorted_front[:, 0], sorted_front[:, 1],
               c='red', marker='*', s=100, label='Pareto Optimal')

    # Annotate Pareto points
    for idx, point in zip(sorted_labels, sorted_front):
        plt.annotate(f'#{idx}',
                    (point[0], point[1]),
                    xytext=(5, 5),
                    textcoords='offset points',
                    fontsize=8)

    plt.xlabel('Objective 1 (Normalized)')
    plt.ylabel('Objective 2 (Normalized)')
    plt.title(f'Pareto Front Evolution - Step {step + 1}')
    plt.legend(loc='upper right')
    plt.grid(True, linestyle='--', alpha=0.7)

    # Add statistics text box
    stats_text = f'Total Points: {len(common_indices)}\n'
    stats_text += f'Pareto Points: {len(pareto_indices)}'
    plt.text(0.02, 0.98, stats_text,
             transform=plt.gca().transAxes,
             bbox=dict(facecolor='white', alpha=0.8),
             verticalalignment='top')

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

    return pareto_indices


def embeddings_and_predictions(model, patches, device="cpu") -> (torch.Tensor, torch.Tensor):
    """
    Evaluate a trained model on a batch of patches without tracking gradients.

    The model is switched to eval mode and ``patches`` is moved to ``device``.

    Returns:
        A ``(predictions, embeddings)`` pair. ``predictions`` is whatever
        ``model(patches)`` returns. ``embeddings`` is the output of
        ``model.feature_extractor(patches)`` flattened to one row per patch
        and converted to a CPU NumPy array (note: not a tensor, despite the
        return annotation).
    """
    model.eval()
    batch = patches.to(device)
    n_patches = batch.size(0)
    with torch.no_grad():
        predictions = model(batch)
        latent = model.feature_extractor(batch)
        embeddings = latent.view(n_patches, -1).cpu().numpy()
    return predictions, embeddings

def train_model(acquired_data, patches, feature_extractor, 
                device="cpu", num_epochs=50, log_interval=5,
                scalarizer_zero=False, debug=False, window_size=5) -> ApproximateGP:
    """
    Fit a fresh ``GPModelDKL`` on the points acquired so far.

    Args:
        acquired_data: Mapping whose keys index into ``patches`` and whose
            values are the scalar targets for those patches.
        patches: Indexable collection of patch tensors.
        feature_extractor: Network handed to ``GPModelDKL``.
        device: Torch device for the training tensors, likelihood and model.
        num_epochs: Number of full-batch Adam steps.
        log_interval: Accepted for interface compatibility; not used here.
        scalarizer_zero: If True, train against all-zero targets instead of
            the min-max scaled targets.
        debug: Accepted for interface compatibility; not used here.
        window_size: Patch side length, forwarded as
            ``input_shape=(1, window_size, window_size)``.

    Returns:
        The trained model. It is left in training mode; callers switch it
        to eval mode themselves.
    """
    # Stack the acquired patches and flatten each one to a (1, features) row.
    X_train_patches = torch.stack([patches[idx] for idx in acquired_data]).to(device)
    X_train_flat = X_train_patches.reshape(X_train_patches.size(0), 1, -1)

    y_train = torch.tensor(list(acquired_data.values()), dtype=torch.float32).to(device)
    if scalarizer_zero:
        y_train = torch.zeros_like(y_train)
    else:
        y_min = y_train.min()
        y_range = y_train.max() - y_min
        if y_range == 0:
            # All targets identical (or a single point): plain min-max scaling
            # would divide by zero and feed NaN targets into the loss.
            y_train = torch.zeros_like(y_train)
        else:
            y_train = (y_train - y_min) / y_range

    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
    model = GPModelDKL(
        inducing_points=X_train_flat[:10],  # first (up to) 10 training rows seed the inducing points
        likelihood=likelihood,
        feature_extractor=feature_extractor,
        input_shape=(1, window_size, window_size)  # original patch shape
    ).to(device)

    model.train()
    likelihood.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # NOTE(review): the return annotation says ApproximateGP, yet the objective
    # is ExactMarginalLogLikelihood. GPModelDKL is defined outside this block;
    # confirm whether gpytorch.mlls.VariationalELBO was intended.
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    for _ in tqdm(range(1, num_epochs + 1), desc="Training Progress"):
        optimizer.zero_grad()
        output = model(X_train_flat)  # Pass flattened input
        loss = -mll(output, y_train)
        loss.backward()
        optimizer.step()

    return model

# def train_models(acquired_data1, acquired_data2, acquired_data3, patches, feature_extractor1, feature_extractor2, feature_extractor3,
#                 device="cpu", window_size=5):
#     # Train both models with flattened inputs
#     model1 = train_model(acquired_data1, patches, feature_extractor1, device=device, window_size=window_size)
#     model2 = train_model(acquired_data2, patches, feature_extractor2, device=device, window_size=window_size)
#     model3 = train_model(acquired_data3, patches, feature_extractor3, device=device, window_size=window_size)

    
#     # Combine models for MOBO
#     models = [model1, model2, model3]
#     model = ModelListGP(*models)  # Properly initialize ModelListGP
#     return model
def train_models(acquired_data1, acquired_data2, patches, feature_extractor1, feature_extractor2,
                device="cpu", window_size=5, num_epochs=50):
    """
    Train one DKL model per objective and bundle them for multi-objective BO.

    Args:
        acquired_data1: Acquired observations for the first objective.
        acquired_data2: Acquired observations for the second objective.
        patches: Patch collection shared by both models.
        feature_extractor1: Feature extractor for the first model.
        feature_extractor2: Feature extractor for the second model.
        device: Torch device forwarded to ``train_model``.
        window_size: Patch side length forwarded to ``train_model``.
        num_epochs: Training steps per model, forwarded to ``train_model``.
            Defaults to 50, which is ``train_model``'s own default, so
            existing callers are unaffected.

    Returns:
        A ``ModelListGP`` wrapping the two trained models, first objective
        first.
    """
    shared_kwargs = dict(device=device, num_epochs=num_epochs, window_size=window_size)
    model1 = train_model(acquired_data1, patches, feature_extractor1, **shared_kwargs)
    model2 = train_model(acquired_data2, patches, feature_extractor2, **shared_kwargs)

    # Combine the per-objective models for MOBO.
    return ModelListGP(model1, model2)


### 3e. Bayesian optimization loop

In [None]:
def _resolve_black_box(scal_pfm):
    """Map a ``scal_pfm*`` config value to its black-box objective function."""
    if scal_pfm is None:
        return black_box
    if scal_pfm == "loop_area":
        # NOTE(review): "loop_area" resolves to the same function as
        # "loop_height"; this looks like a copy-paste slip. Confirm whether a
        # dedicated loop-area objective exists and should be used instead.
        return black_box_loop_height
    if scal_pfm == "loop_height":
        return black_box_loop_height
    if scal_pfm == "positive_nucleation_bias":
        return black_box_positive_nucleation_bias
    if scal_pfm == "negative_nucleation_bias":
        return black_box_negative_nucleation_bias
    raise ValueError(
        f"Unknown scalarizer {scal_pfm!r}; expected None, 'loop_area', 'loop_height', "
        "'positive_nucleation_bias' or 'negative_nucleation_bias'"
    )


def _min_max_normalize(values):
    """Scale a tensor to [0, 1]; a constant tensor maps to zeros instead of NaN."""
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        return torch.zeros_like(values)
    return (values - lowest) / span


def _plot_step_diagnostics(model, patches, indices_all, device):
    """Plot predicted rewards, embedding maps and latent scatter for both objective models."""
    model1 = model.models[0]
    model2 = model.models[1]
    pred1, embeddings1 = embeddings_and_predictions(model1, patches, device)
    pred2, embeddings2 = embeddings_and_predictions(model2, patches, device)
    figures = []

    # ---------- Reward Predictions ----------
    fig, ax = plt.subplots(1, 2, figsize=(10, 4))
    figures.append(fig)

    sc0 = ax[0].scatter(indices_all[:, 1], indices_all[:, 0], c=pred1.mean.cpu().numpy(), cmap='viridis')
    sc1 = ax[1].scatter(indices_all[:, 1], indices_all[:, 0], c=pred2.mean.cpu().numpy(), cmap='viridis')

    ax[0].set_title('model1: predicted reward')
    ax[1].set_title('model2: predicted reward')

    for a, sc in zip(ax, [sc0, sc1]):
        a.set_xlabel('X')
        a.set_ylabel('Y')
        a.invert_yaxis()
        plt.colorbar(sc, ax=a, fraction=0.046, pad=0.04)

    plt.tight_layout()

    # ---------- Spatial Embedding Channels ----------
    fig, ax = plt.subplots(2, 2, figsize=(10, 8))
    figures.append(fig)

    panels = [
        (ax[0, 0], embeddings1[:, 0], "model1: embedding[:, 0]"),
        (ax[0, 1], embeddings1[:, 1], "model1: embedding[:, 1]"),
        (ax[1, 0], embeddings2[:, 0], "model2: embedding[:, 0]"),
        (ax[1, 1], embeddings2[:, 1], "model2: embedding[:, 1]"),
    ]
    for axis, channel, title in panels:
        sc = axis.scatter(indices_all[:, 1], indices_all[:, 0], c=channel, cmap='plasma')
        axis.set_title(title)
        axis.set_xlabel('X')
        axis.set_ylabel('Y')
        axis.invert_yaxis()
        plt.colorbar(sc, ax=axis, fraction=0.046, pad=0.04)

    plt.tight_layout()

    # ---------- Latent Scatter Plots ----------
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    figures.append(fig)

    ax[0].scatter(embeddings1[:, 0], embeddings1[:, 1], alpha=0.7, edgecolor='k')
    ax[1].scatter(embeddings2[:, 0], embeddings2[:, 1], alpha=0.7, edgecolor='k')

    ax[0].set_title('model1: latent space')
    ax[1].set_title('model2: latent space')

    for a in ax:
        a.set_xlabel('Embedding 1')
        a.set_ylabel('Embedding 2')
        a.grid(True)

    plt.tight_layout()

    # Show this step's figures now and release them, so that three new
    # figures per BO step do not pile up for the whole run.
    plt.show()
    for fig in figures:
        plt.close(fig)


def _plot_trajectory(img, indices_all, seed_indices, acquired_order, res_dir):
    """Plot seed points and BO-acquired points (coloured by acquisition step) over the image."""
    indices_all = np.array(indices_all)
    seed_coords = indices_all[np.array(seed_indices, dtype=int)]
    # acquired_order is kept in acquisition order so the colour scale really encodes the step.
    acquired_coords = indices_all[np.array(acquired_order, dtype=int)]

    plt.figure(figsize=(10, 8))

    # Display the image or grid as the background
    plt.imshow(np.array(img), cmap="gray", origin="lower")

    # Plot the seed points in blue
    plt.scatter(seed_coords[:, 1], seed_coords[:, 0], c="b", label="Seed Points", marker="o")

    time_order = np.arange(len(acquired_coords))
    scatter = plt.scatter(acquired_coords[:, 1], acquired_coords[:, 0], c=time_order, cmap="bwr", label="Acquired Points", marker="x")

    plt.xlabel("X-axis")
    plt.ylabel("Y-axis")
    plt.title("Active Learning Trajectory")
    plt.legend()
    plt.grid(True)
    cbar = plt.colorbar(scatter)
    cbar.set_label("Steps")

    plt.savefig(Path(res_dir) / "AL_traj.png")
    plt.show()
    plt.close()


def run(config) -> None:
    """
    Run a two-objective Bayesian-optimisation loop against the DigitalTwin microscope.

    Reads these keys from ``config``: ``seed``, ``seed_pts``, ``budget``,
    ``in_dir``, ``out_dir_parent``, ``dataset_name``, ``device``,
    ``num_epochs``, ``normalize_data``, ``window_size``, ``scal_pfm1`` and
    ``scal_pfm2``. ``mobo_w1`` / ``mobo_w2`` may be present but are not used
    by this loop.

    Side effects: sets the module-level ``mic_server`` proxy, creates a
    timestamped results directory under ``out_dir_parent``, shows per-step
    plots, and writes ``Active_learning_statistics.pkl`` and ``AL_traj.png``
    into the results directory.

    Raises:
        ValueError: if ``scal_pfm1`` or ``scal_pfm2`` is not a recognised
            objective name (or ``None``).
    """
    seed = config["seed"]
    seed_pts = config["seed_pts"]
    budget = config["budget"]
    in_dir = config["in_dir"]
    out_dir_parent = config["out_dir_parent"]
    dataset_name = config["dataset_name"]
    device = config["device"]
    num_epochs = config["num_epochs"]
    normalize_data_flag = config["normalize_data"]
    window_size = config["window_size"]
    scal_pfm1 = config["scal_pfm1"]
    scal_pfm2 = config["scal_pfm2"]

    energy_range1 = [0, 1]  # TODO: can be confusing as used in eels data - for now ignore for beps
    energy_range2 = [0, 1]  # TODO: can be confusing as used in eels data - for now ignore for beps

    # Resolve both objectives up front so a bad name fails before any
    # microscope traffic rather than as an unbound variable later on.
    black_box_fn1 = _resolve_black_box(scal_pfm1)
    black_box_fn2 = _resolve_black_box(scal_pfm2)

    scalarizer_zero = False  # TODO: default value; forwarded to train_model -- find a better way to handle

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    res_dir = Path(out_dir_parent) / f"MOBO_seed{seed}_Dataset_{dataset_name}_BO_{seed_pts}_epochs{num_epochs}_budget_{budget}_{scal_pfm1}_{scal_pfm2}_ws{window_size}_{timestamp}"
    res_dir.mkdir(parents=True, exist_ok=True)

    # Connect to the microscope server
    uri = "PYRO:microscope.server@localhost:9091"
    global mic_server  # TODO: later see better way to do this
    mic_server = Pyro5.api.Proxy(uri)

    dataset_path = in_dir + "/" + dataset_name
    mic_server.initialize_microscope("STEM")
    mic_server.register_data(dataset_path)

    # Prepare features and indices from microscope image
    img, features, indices_all = prepare_data_from_microscope(window_size=window_size)
    patches = numpy_to_torch_for_conv(features)

    # Set up energy ranges for scalarizer extraction
    _, e1a, e1b = get_spectrum_data(indices_all, energy_range1)
    _, e2a, e2b = get_spectrum_data(indices_all, energy_range2)

    patches = patches.to(device)
    if normalize_data_flag:
        patches = normalize_data(patches)

    feature_extractor1 = ConvNetFeatureExtractor(input_channels=1, output_dim=2).to(device)
    feature_extractor2 = ConvNetFeatureExtractor(input_channels=1, output_dim=2).to(device)

    acquired_data1 = {}
    acquired_data2 = {}
    unacquired_indices1 = list(range(len(indices_all)))  # TODO: need to change later to use the indices_all
    unacquired_indices2 = list(range(len(indices_all)))  # TODO: need to change later to use the indices_all

    # Both objectives start from the same random seed points.
    seed_indices = random.sample(unacquired_indices1, seed_pts)

    # NOTE(review): these full-grid score evaluations are not consumed anywhere
    # below. They are kept because calculate_scores_for_patches is defined
    # outside this block and may have side effects; drop them if it is pure.
    calculate_scores_for_patches(unacquired_indices1, indices_all, e1a, e1b, black_box_fn=black_box_fn1)
    calculate_scores_for_patches(unacquired_indices2, indices_all, e2a, e2b, black_box_fn=black_box_fn2)

    # Query the microscope for the seed points.
    acquired_data2, unacquired_indices2 = update_acquired(acquired_data2, unacquired_indices2, seed_indices, indices_all, e2a, e2b, black_box_fn=black_box_fn2)
    acquired_data1, unacquired_indices1 = update_acquired(acquired_data1, unacquired_indices1, seed_indices, indices_all, e1a, e1b, black_box_fn=black_box_fn1)

    # These stay empty; the keys are kept so the pickle layout is unchanged.
    mean_y_pred_mean_al = []
    mean_y_pred_variance_al = []
    # Indices chosen by BO, in the order they were acquired.
    acquired_order = []

    # Start Bayesian Optimization loop
    for step in range(budget):
        candidate_indices = unacquired_indices1
        if len(candidate_indices) == 0:
            print(f"No unacquired points left; stopping after {step} BO steps")
            break

        # Train one DKL model per objective, honouring config["num_epochs"].
        model = ModelListGP(
            train_model(acquired_data1, patches, feature_extractor1, device=device,
                        num_epochs=num_epochs, scalarizer_zero=scalarizer_zero, window_size=window_size),
            train_model(acquired_data2, patches, feature_extractor2, device=device,
                        num_epochs=num_epochs, scalarizer_zero=scalarizer_zero, window_size=window_size),
        )
        model.eval()

        # Candidate set: flattened unacquired patches.
        X_candidates = torch.stack([patches[idx] for idx in candidate_indices]).to(device)
        X_candidates_flat = X_candidates.reshape(X_candidates.size(0), -1)

        # Observed objectives, each min-max scaled, stacked to shape [n, 2].
        train_y1 = _min_max_normalize(torch.tensor(list(acquired_data1.values()), dtype=torch.float32).to(device))
        train_y2 = _min_max_normalize(torch.tensor(list(acquired_data2.values()), dtype=torch.float32).to(device))
        train_y = torch.stack([train_y1, train_y2], dim=-1)

        ref_point = torch.zeros(2, device=device)  # For maximization
        partitioning = FastNondominatedPartitioning(
            ref_point=ref_point,
            Y=train_y
        )
        qEHVI = qExpectedHypervolumeImprovement(
            model=model,
            ref_point=ref_point.clone(),
            partitioning=partitioning,
        )

        new_x, _ = optimize_acqf_discrete(
            acq_function=qEHVI,
            choices=X_candidates_flat,
            q=1,
        )

        # Map the chosen candidate row back to a dataset index. Take the first
        # match so identical patches do not make .item() fail.
        matches = torch.where(
            (X_candidates_flat == new_x.view(1, -1)).all(dim=1)
        )[0]
        selected_idx = matches[0].item()
        selected_indices = [candidate_indices[selected_idx]]

        # Score the new point with the same objectives used for the seed points.
        acquired_data1, unacquired_indices1 = update_acquired(acquired_data1, unacquired_indices1, selected_indices, indices_all, e1a, e1b, black_box_fn=black_box_fn1)
        acquired_data2, unacquired_indices2 = update_acquired(acquired_data2, unacquired_indices2, selected_indices, indices_all, e2a, e2b, black_box_fn=black_box_fn2)
        acquired_order.extend(selected_indices)

        print(f"**************************done BO step {step +1}", end='\r')

        plot_pareto_front(acquired_data1, acquired_data2, step)
        _plot_step_diagnostics(model, patches, indices_all, device)

    # Save run statistics as a .pkl file
    Active_learning_statistics = {
        "img": img,
        "features": features,
        "indices_all": np.array(indices_all),
        "seed_indices": np.array(seed_indices),
        "unacquired_indices": np.array(unacquired_indices1),
        "acquired_indices": np.array(acquired_order, dtype=int),
        "mean_y_pred_mean_al": np.array(mean_y_pred_mean_al),
        "mean_y_pred_variance_al": np.array(mean_y_pred_variance_al),
    }

    with open(Path(res_dir) / 'Active_learning_statistics.pkl', 'wb') as f:
        pickle.dump(Active_learning_statistics, f)

    _plot_trajectory(img, indices_all, seed_indices, acquired_order, res_dir)


### 3f. Set parameters and Run experiments

In [None]:
# Experiment settings consumed by run().
config = {
    "seed": 1,  # for repeatability
    "seed_pts": 10,  # How many points you want to start your BO with?
    "budget": 50,  # How much experimental budget do you have?
    "in_dir": ".",  # recommended : leave as is
    "out_dir_parent": "out",  # recommended : leave as is
    "dataset_name": "yl_beps.h5",  # name of data to be loaded in DTmicroscope
    # A GPU is recommended, but fall back to CPU so the cell still runs without one.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "num_epochs": 100,  # Number of epochs the dkl model trains at each experimental step - Might need tuning based on data
    "normalize_data": True,
    "window_size": 16,  # For square patches - structure property relationship
    "scal_pfm1": "positive_nucleation_bias",  # What physics interested in? options on this data: "loop_area", "loop_height", "positive_nucleation_bias", "negative_nucleation_bias"
    "scal_pfm2": "negative_nucleation_bias",  # What physics interested in? options on this data: "loop_area", "loop_height", "positive_nucleation_bias", "negative_nucleation_bias"
    "mobo_w1": 1,  # weight of first objective on the acquisition values
    "mobo_w2": 1,  # weight of second objective on the acquisition values
}
run(config)