[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/utkarshp1161/Active-learning-in-microscopy/blob/main/notebooks/multi_objective_BO_SVDKL.ipynb)

# Multi-objective Active learning using DigitalTwin microscope: Stochastic Variational Deep kernel learning in Gpytorch and BO loop in Botorch. [Recommended to take a GPU instance]
- For single objective please [see](https://github.com/utkarshp1161/Active-learning-in-microscopy/blob/main/notebooks/single_objective_BO_SVDKL.ipynb) 

Prepared by [Utkarsh Pratiush](https://github.com/utkarshp1161)
- Get in touch with any questions or for discussion: [utkarshp1161@gmail.com]
- For reading related to this notebook please refer to [wilson et al 2016](https://arxiv.org/abs/1611.00336), [wilson et al 2015](https://arxiv.org/abs/1511.02222) and [Sebastian et al 2021](https://arxiv.org/abs/2102.12108)
- Reference to [Gpytorch example](https://docs.gpytorch.ai/en/v1.12/examples/06_PyTorch_NN_Integration_DKL/Deep_Kernel_Learning_DenseNet_CIFAR_Tutorial.html)



## 1. Install modules and start DigitalTwin microscope

In [None]:
# Install the pinned BoTorch / GPyTorch versions, the DTMicroscope package
# (from GitHub) and the HDF5 / sidpy / Pyro5 / SciFiReaders / pyNSID helpers
# that the cells below import.
!pip install botorch==0.12.0
!pip install gpytorch==1.13
!pip install git+https://github.com/pycroscopy/DTMicroscope.git
!pip install h5py
!pip install sidpy
!pip install -q pyro5
!pip install -q scifireaders
!pip install -q pynsid

# Start the DTMicroscope server process.
# NOTE(review): presumably this is the server the later `mic_server` proxy
# talks to - confirm, and confirm whether this command blocks the cell.
!run_server_stem

## 2. BEPS data - credits: yongtao liu

### 2a. Download data

In [None]:
# Download the data file from Google Drive.
# NOTE(review): presumably this is the BEPS_1d7um_0009.h5 file opened in the
# preprocessing cell - confirm the downloaded file name.
!gdown https://drive.google.com/uc?id=1x3Nw7ZQoZZlxDusDulgKLt2P0Wu3SJPo

### 2b. Preprocess data 


In [None]:
import h5py
import sidpy
import numpy as np
from pathlib import Path
import random
from datetime import datetime
import Pyro5.api
import pickle
import matplotlib.pyplot as plt
import pyNSID
import SciFiReaders

# Open the BEPS HDF5 file and print its group/dataset tree.
# NOTE(review): the file is opened read/write ('r+') and is not closed in the
# visible code, although only reads are performed on it here.
input_file = "BEPS_1d7um_0009.h5"
h5_f = h5py.File(input_file, 'r+')
sidpy.hdf.hdf_utils.print_tree(h5_f)


# Handles to the datasets used below.
# NOTE(review): the fit comes from 'Raw_Data-SHO_Fit_001' while the
# spectroscopic values come from 'Raw_Data-SHO_Fit_000' - confirm intended.
sho_mat = h5_f['Measurement_000/Channel_000/Raw_Data-SHO_Fit_001/Fit']
spec_val = h5_f['Measurement_000/Channel_000/Raw_Data-SHO_Fit_000/Spectroscopic_Values']
pos_inds = h5_f['Measurement_000/Channel_000/Position_Indices']
# Number of positions along each axis = largest position index + 1.
pos_dim_sizes = [np.max(pos_inds[:,0])+1, np.max(pos_inds[:,1]+1)]
# 'r' field of the Channel_001 raw data (presumably topography; it is not used
# in the visible code below).
topo = h5_f['Measurement_000/Channel_001/Raw_Data/']['r']

# Reshape the flat fit table to (positions axis 0, positions axis 1, spectral step)
# and pull out the individual fit fields by name.
sho_mat_ndim = sho_mat[:].reshape(pos_dim_sizes[0], pos_dim_sizes[1], -1)
amp_mat_ndim = sho_mat_ndim['Amplitude [V]']
phase_mat_ndim = sho_mat_ndim['Phase [rad]']
fre_mat_ndim = sho_mat_ndim['Frequency [Hz]']
q_mat_ndim = sho_mat_ndim['Quality Factor']



# Global phase extrema, used as the shift applied in the correction below.
min_ = np.min(phase_mat_ndim)
max_ = np.max(phase_mat_ndim)

# Phase values above 1.5 are shifted down by the full phase range (max_ - min_);
# all other values are kept as they are.
pha_correct = np.where(phase_mat_ndim > 1.5, phase_mat_ndim + (min_-max_), phase_mat_ndim)

# Constant offset added to every phase value.
# NOTE(review): 1.5 and pi-0.55 look like dataset-specific empirical constants - confirm.
pha_correct = pha_correct + np.pi-0.55

# Separate on-field and off-field responses: even spectral steps go to the
# *_on_field arrays, odd steps to the *_off_field arrays.
pix_x, pix_y, full_spec_len = sho_mat_ndim.shape[:3]
spec_len = full_spec_len // 2

# Strided selections over the first 2 * spec_len spectral steps.
on_steps = slice(0, 2 * spec_len, 2)
off_steps = slice(1, 2 * spec_len, 2)
field_shape = (pix_x, pix_y, spec_len)

amp_off_field = np.zeros(field_shape)
pha_off_field = np.zeros(field_shape)
fre_off_field = np.zeros(field_shape)

amp_on_field = np.zeros(field_shape)
pha_on_field = np.zeros(field_shape)
fre_on_field = np.zeros(field_shape)

v_step = np.zeros(spec_len)

# Assign into the pre-allocated float arrays so the dtype conversion matches a
# per-step copy.
amp_off_field[...] = amp_mat_ndim[:, :, off_steps]
pha_off_field[...] = pha_correct[:, :, off_steps]
fre_off_field[...] = fre_mat_ndim[:, :, off_steps] / 1000  # frequency divided by 1000
amp_on_field[...] = amp_mat_ndim[:, :, on_steps]
pha_on_field[...] = pha_correct[:, :, on_steps]
fre_on_field[...] = fre_mat_ndim[:, :, on_steps] / 1000
# Spectroscopic value (row 0) at each on-field step.
v_step[...] = spec_val[0, on_steps]


# Off-field signal combined as amplitude * cos(phase).
pola_off_field = amp_off_field*np.cos(pha_off_field)


## Cut a part for DKL exploration

exp_data = pola_off_field[50:, 20:70]
img_data = amp_off_field[50:, 20:70]

# Structural image: off-field amplitude averaged over the spectral axis of the
# cropped region.
struc_img = img_data.mean(2)
plot_pixx1 = 10; plot_pixy1 = 30
plot_pixx2 = 40; plot_pixy2 = 10


def norm_(x):
    """Min-max scale ``x`` to the [0, 1] range.

    Uses ``np.ptp`` rather than the ``ndarray.ptp`` method, which was removed
    in NumPy 2.0.
    """
    return (x - x.min()) / np.ptp(x)


img = norm_(struc_img)
spectra = norm_(exp_data)

# coordinates = get_coord_grid(img, step = 1, return_dict=False)


### 2c. Prepare data for DigitalTwin Microscope

In [None]:
# Wrap the overview image and the spectrum image as sidpy datasets.
image = img
spectrum_image = spectra
## Set scale bar and energy axis of spectrum based on your data
scale = 1
energy_axis = np.linspace(0, 1, 256)

data_sets = {
    'Channel_000': sidpy.Dataset.from_array(image, name="overview image"),
    'Channel_001': sidpy.Dataset.from_array(spectrum_image, name="spectrum image"),
}

data_sets['Channel_000'].data_type = 'image'
data_sets['Channel_001'].data_type = 'spectral_image'

# Both channels get the same spatial axes; the axis lengths are taken from the
# spectrum image ('Channel_001') in both cases.
spatial_shape = data_sets['Channel_001'].shape
for channel_key in ('Channel_000', 'Channel_001'):
    for axis, axis_name in enumerate(('x', 'y')):
        data_sets[channel_key].set_dimension(
            axis,
            sidpy.Dimension(np.arange(spatial_shape[axis]) * scale,
                            name=axis_name, units='nm', quantity='Length',
                            dimension_type='spatial'))

# Third axis of the spectrum image is the spectral one.
data_sets['Channel_001'].set_dimension(
    2,
    sidpy.Dimension(energy_axis,
                    name='energy_scale', units='ev', quantity='Energy',
                    dimension_type='spectral'))


def save_dataset_dictionary(h5_file, datasets):
    """Write a dictionary of datasets into a new indexed 'Measurement_' group.

    Args:
        h5_file: Open HDF5 file (or group) in which the measurement group is created.
        datasets: Mapping of names to ``sidpy.Dataset`` objects or plain dicts.
            A trailing '/' on a name is dropped.

    Returns:
        The measurement group that was created.
    """
    measurement_group = sidpy.hdf.prov_utils.create_indexed_group(h5_file, 'Measurement_')
    for name, item in datasets.items():
        if name[-1] == '/':
            name = name[:-1]
        if isinstance(item, sidpy.Dataset):
            # sidpy datasets get their own sub-group; keep a handle to the
            # written HDF5 dataset on the sidpy object and flush the file.
            target_group = measurement_group.create_group(name)
            written = pyNSID.hdf_io.write_nsid_dataset(item, target_group)
            item.h5_dataset = written
            written.file.flush()
            continue
        if isinstance(item, dict):
            sidpy.hdf.hdf_utils.write_dict_to_h5_group(measurement_group, item, name)
            continue
        # Anything else is reported and skipped.
        print('could not save item ', name, 'of dataset dictionary')
    return measurement_group

# Write the prepared datasets to 'yl_beps.h5' (append mode). The context manager
# closes the file even if saving raises.
dataset_name = 'yl_beps.h5'
with h5py.File(dataset_name, mode='a') as h5_file:
    save_dataset_dictionary(h5_file, data_sets)





## 3. Multi Objective Bayesian optimization with DKL
- Note: we use one surrogate model per objective; the next point is the one with the highest value across both surrogates (objectives).

### 3a. DKL model 

In [12]:
import gpytorch
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
import math
import torch.nn as nn
import numpy as np
from typing import Tuple, Optional, Dict, Union, List
import numpy as np
import torch

# Simple ConvNet for feature extraction
class ConvNetFeatureExtractor(nn.Module):
    """Small two-stage ConvNet that maps image patches to feature vectors.

    Each stage is Conv2d(3x3, padding 1) -> ReLU -> MaxPool2d(2). The final
    linear layer is created lazily on the first forward pass, once the
    flattened size of the convolutional output is known.

    Args:
        input_channels: Number of channels of the input patches.
        output_dim: Size of the returned feature vector.
    """

    def __init__(self, input_channels=1, output_dim=32):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.output_dim = output_dim
        # Created on the first forward pass (input size is unknown until then).
        self.fc = None

    def forward(self, x):
        """Return features of shape (batch, output_dim).

        Args:
            x: Either a 4-D tensor (batch, channel, height, width) or a 3-D
                tensor (batch, channel, height*width) of flattened square
                patches (the form in which BoTorch acquisition functions pass
                the inputs).

        Raises:
            ValueError: If a 3-D input's last dimension is not a perfect
                square, so it cannot be restored to a square patch.
        """
        if x.dim() == 3:
            batch_size, channels, n_pixels = x.shape
            side = math.isqrt(n_pixels)
            if side * side != n_pixels:
                raise ValueError(
                    f"Cannot reshape flattened patches of length {n_pixels} "
                    "into square images."
                )
            x = x.reshape(batch_size, channels, side, side)

        # Convolutional stages, then flatten everything but the batch axis.
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)

        if self.fc is None:
            # Build the linear head on the same device as the activations.
            self.fc = nn.Linear(x.size(1), self.output_dim).to(x.device)

        return self.fc(x)


# GP model with deep kernel using ConvNet feature extractor
class GPModelDKL(ApproximateGP):
    """Variational GP whose kernel operates on ConvNet features (deep kernel).

    Args:
        inducing_points: Raw input patches; they are passed through the
            feature extractor and the resulting features are used as the
            (learnable) inducing locations.
        likelihood: Likelihood stored on the model and applied in ``posterior``.
        feature_extractor: Optional feature-extractor module. When ``None`` a
            ``ConvNetFeatureExtractor(input_channels=1, output_dim=32)`` is created.
    """

    def __init__(self, inducing_points, likelihood, feature_extractor=None):
        if feature_extractor is None:
            feature_extractor = ConvNetFeatureExtractor(
                input_channels=1,  # Set according to your image channels
                output_dim=32      # Set as per the final feature dimension
            )
        feature_extractor = feature_extractor.to(inducing_points.device)

        # The GP lives in feature space, so the inducing points are features too.
        inducing_points = feature_extractor(inducing_points)

        # Variational setup
        variational_distribution = CholeskyVariationalDistribution(
            inducing_points.size(0)
        )
        variational_strategy = VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )

        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.num_outputs = 1  # must be one
        self.likelihood = likelihood
        self.feature_extractor = feature_extractor

    def forward(self, x):
        """Return the GP prior over ``x`` (``x`` is already in feature space)."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    def __call__(self, x, use_feature_extractor=True, *args, **kwargs):
        """Evaluate the model on raw patches (or on features directly).

        Args:
            x: Patches of shape (batch, channel, height, width), or flattened
                square patches of shape (batch, channel, height*width) as
                passed by BoTorch acquisition functions.
            use_feature_extractor: When ``False``, ``x`` is treated as features
                and passed to the GP unchanged.
        """
        if use_feature_extractor:
            if len(x.shape) == 3:
                # Restore flattened patches to square images.
                batch_size, channel, mn = x.shape[0], x.shape[1], x.shape[2]
                d = math.sqrt(mn)  # TODO: what if mn is not a perfect square?
                x = x.reshape(int(batch_size), int(channel), int(d), int(d))
            x = self.feature_extractor(x)
        return super().__call__(x, *args, **kwargs)

    def posterior(self, X, output_indices=None, observation_noise=False, *args, **kwargs) -> GPyTorchPosterior:
        """Return the posterior at ``X`` wrapped for BoTorch.

        The model and likelihood are switched to eval mode. The likelihood is
        always applied to the model output; ``output_indices`` and
        ``observation_noise`` are accepted but not used.
        """
        self.eval()
        self.likelihood.eval()
        dist = self.likelihood(self(X))
        return GPyTorchPosterior(dist)

    @property
    def hparam_dict(self):
        """Current scalar hyperparameters as plain Python floats.

        The output scale is read from the ``ScaleKernel`` (``covar_module``);
        its RBF base kernel has no ``outputscale`` attribute.
        """
        return {
            "likelihood.noise": self.likelihood.noise.item(),
            "covar_module.outputscale": self.covar_module.outputscale.item(),
            "mean_module.constant": self.mean_module.constant.item(),
        }




### 3b. Utility F:n's - 1

In [13]:
def normalize_data(data : np.ndarray) -> np.ndarray:  # Expected data type: torch.Tensor
    """Min-max scale ``data`` so its smallest value maps to 0 and its largest to 1."""
    lowest = data.min()
    span = data.max() - lowest
    return (data - lowest) / span


def numpy_to_torch_for_conv(np_array) -> torch.Tensor:
    """
    Turn a NumPy stack of 2-D patches into a float tensor with a channel axis.

    Parameters:
        np_array (np.ndarray): Input NumPy array of shape (batch_size, a, b).

    Returns:
        torch.Tensor: Float tensor of shape (batch_size, 1, a, b).

    Raises:
        TypeError: If the input is not a NumPy array.
    """
    if isinstance(np_array, np.ndarray):
        # float32 conversion first, then insert the channel axis at index 1.
        return torch.from_numpy(np_array).float().unsqueeze(1)
    raise TypeError("Input must be a NumPy array.")


######################atomai utils####################################
#Credits Maxim Ziatdinov (https://github.com/ziatdinovmax): https://github.com/pycroscopy/atomai/blob/8db3e944cd9ece68c33c8e3fcca3ef3ce9a111ea/atomai/utils/img.py#L522

def get_coord_grid(imgdata: np.ndarray, step: int,
                   return_dict: bool = True
                   ) -> Union[np.ndarray, Dict[int, np.ndarray]]:
    """
    Build a regular grid of (row, column) coordinates for every image in a stack.

    Args:
        imgdata (numpy array): 2D image or 3D stack of images
        step (int): spacing between neighbouring grid points
        return_dict (bool): if True, return a dictionary keyed by frame index
            whose values are N x 3 arrays (coordinates plus a zero class column);
            otherwise return the N x 2 grids of all frames concatenated

    Returns:
        Dictionary or numpy array with coordinates
    """
    stack = np.expand_dims(imgdata, axis=0) if np.ndim(imgdata) == 2 else imgdata
    n_frames, height, width = stack.shape[0], stack.shape[1], stack.shape[2]

    grid = np.array([
        np.array([row, col])
        for row in range(0, height, step)
        for col in range(0, width, step)
    ])

    if not return_dict:
        return np.concatenate([grid] * n_frames, axis=0)

    # Append the class column (all zeros); every frame shares the same array.
    grid = np.concatenate((grid, np.zeros((grid.shape[0], 1))), axis=-1)
    return dict.fromkeys(range(n_frames), grid)

def get_imgstack(imgdata: np.ndarray,
                 coord: np.ndarray,
                 r: int) -> Tuple[np.ndarray]:
    """
    Crop square windows around the given coordinates of a single image

    Args:
        imgdata (3D numpy array):
            Image with dimensions
            :math:`height \\times width \\times n channels`
        coord (N x 2 numpy array):
            (x, y) coordinates of the window centers
        r (int):
            Window size

    Returns:
        2-element tuple containing

        - Stack of subimages
        - (x, y) coordinates of their centers

        Windows that do not have the full r x r size or that contain NaNs
        are dropped; ``(None, None)`` is returned when no window is kept.
    """
    half = r // 2
    crops, centers = [], []
    for center in coord:
        # The window starts r//2 before the rounded center and spans r pixels,
        # for odd and even r alike.
        row0 = int(np.around(center[0])) - half
        col0 = int(np.around(center[1])) - half
        crop = np.copy(imgdata[row0:row0 + r, col0:col0 + r])
        if crop.shape[0:2] != (int(r), int(r)) or np.isnan(crop).any():
            continue
        crops.append(crop[None, ...])
        centers.append(center[None, ...])
    if not crops:
        return None, None
    return np.concatenate(crops, axis=0), np.concatenate(centers, axis=0)

def extract_subimages(imgdata: np.ndarray,
                      coordinates: Union[Dict[int, np.ndarray], np.ndarray],
                      window_size: int, coord_class: int = 0) -> Tuple[np.ndarray]:
    """
    Crop square subimages around the coordinates of one class, frame by frame

    Args:
        imgdata (numpy array):
            4D stack of images (n, height, width, channel).
            A single 2D image is also accepted.
        coordinates (dict or N x 2 numpy array):
            Dictionary of :math:`N \\times 3` arrays, one per frame, where the
            first 2 columns are *xy* coordinates and the third column is the
            class (starts with 0). An N x 2 array may be passed instead when
            imgdata is a single 2D image; all points then get class 0.
        window_size (int):
            Side of the square for subimage cropping
        coord_class (int):
            Class (3rd coordinate column) around which subimages are cropped

    Returns:
        3-element tuple containing

        - stack of subimages,
        - (x, y) coordinates of their centers,
        - frame number associated with each subimage

        All three are empty lists when no subimage could be extracted.
    """
    if isinstance(coordinates, np.ndarray):
        class_column = np.zeros((coordinates.shape[0], 1))
        coordinates = {0: np.concatenate((coordinates, class_column), axis=-1)}
    if np.ndim(imgdata) == 2:
        # Add the frame and channel axes around a single 2D image.
        imgdata = imgdata[None, ..., None]

    stacks, centers, frames = [], [], []
    per_frame = zip(imgdata, coordinates.values())
    for frame_idx, (frame, frame_coord) in enumerate(per_frame):
        selected = frame_coord[frame_coord[:, 2] == coord_class][:, :2]
        crops, crop_centers = get_imgstack(frame, selected, window_size)
        if crops is not None:
            stacks.append(crops)
            centers.append(crop_centers)
            frames.append(np.ones(len(crop_centers), int) * frame_idx)

    if not stacks:
        return stacks, centers, frames
    return (np.concatenate(stacks, axis=0),
            np.concatenate(centers, axis=0),
            np.concatenate(frames, axis=0))

### 3c. Utility F:n's - 2

In [14]:
#*********************************DTmic specific functions starts **********************************************#
from sklearn.metrics import mean_squared_error

# Number of cycles assumed in every measured spectrum.
_LOOP_CYCLES = 3  # TODO: hard coded cycle now


def _fetch_point_spectrum(index, indices_all):
    """Fetch the 'Channel_001' point spectrum at ``indices_all[index]`` from ``mic_server``."""
    array_list, shape, dtype = mic_server.get_point_data(
        spectrum_image_index="Channel_001",
        x=int(indices_all[index, 0]),  # TODO: check if x and y need to be flipped
        y=int(indices_all[index, 1])
    )
    return np.array(array_list, dtype=dtype).reshape(shape)


def _split_loop_branches(raw_spec, cycle):
    """Split a spectrum into its concatenated 'top' and 'bottom' loop segments.

    With ``cycle_len = len(raw_spec) // cycle``, ``h = cycle_len // 2`` and
    ``q = cycle_len // 4`` the segments are:

    - top:    [q, q+h), [q+2h, q+3h), [q+4h, 2q+4h)
    - bottom: [0, q),   [q+h, q+2h),  [q+3h, q+4h)
    """
    cycle_len = int(len(raw_spec) / cycle)
    half_len = int(cycle_len / 2)
    q_len = int(cycle_len / 4)
    loop_top = np.concatenate([
        raw_spec[q_len: q_len + half_len],
        raw_spec[q_len + 2*half_len: q_len + 3*half_len],
        raw_spec[q_len + 4*half_len: 2*q_len + 4*half_len],
    ])
    loop_bottom = np.concatenate([
        raw_spec[:q_len],
        raw_spec[q_len + half_len: q_len + 2*half_len],
        raw_spec[q_len + 3*half_len: q_len + 4*half_len],
    ])
    return loop_top, loop_bottom


def black_box(index, indices_all, e1a, e1b) -> float:
    """
    Score a point by summing its spectrum over the index window [e1a, e1b).

    Args:
        index: Row of ``indices_all`` to measure.
        indices_all: Array of (x, y) positions.
        e1a, e1b: Start and end indices of the spectral window.
    """
    spectrum = _fetch_point_spectrum(index, indices_all)
    return spectrum[e1a:e1b].sum()


def black_box_loop_height(index, indices_all, e1a, e1b) -> float:
    """
    Score a point by max(top segments) - min(bottom segments) of its spectrum.

    ``e1a`` and ``e1b`` are accepted for signature compatibility and not used.
    """
    spectrum = _fetch_point_spectrum(index, indices_all)
    loop_top, loop_bottom = _split_loop_branches(spectrum, _LOOP_CYCLES)
    return np.max(loop_top) - np.min(loop_bottom)


def black_box_loop_area(index, indices_all, e1a, e1b) -> float:
    """
    Score a point by |sum(top segments) - sum(bottom segments)| of its spectrum.

    ``e1a`` and ``e1b`` are accepted for signature compatibility and not used.
    """
    spectrum = _fetch_point_spectrum(index, indices_all)
    loop_top, loop_bottom = _split_loop_branches(spectrum, _LOOP_CYCLES)
    return np.abs(np.sum(loop_top) - np.sum(loop_bottom))


def black_box_positive_nucleation_bias(index, indices_all, e1a, e1b) -> float:
    """
    Score a point by mean(top segments) - mean(bottom segments) of its spectrum.

    ``e1a`` and ``e1b`` are accepted for signature compatibility and not used.
    """
    spectrum = _fetch_point_spectrum(index, indices_all)
    loop_top, loop_bottom = _split_loop_branches(spectrum, _LOOP_CYCLES)
    return np.mean(loop_top) - np.mean(loop_bottom)


def black_box_negative_nucleation_bias(index, indices_all, e1a, e1b) -> float:
    """
    Score a point by min(top segments) - max(bottom segments) of its spectrum.

    ``e1a`` and ``e1b`` are accepted for signature compatibility and not used.
    """
    spectrum = _fetch_point_spectrum(index, indices_all)
    loop_top, loop_bottom = _split_loop_branches(spectrum, _LOOP_CYCLES)
    return np.min(loop_top) - np.max(loop_bottom)

## a) evaluations metrics like nlpd, mse ----
def calculate_mse(y_true, y_pred):
    """Return the Mean Squared Error (MSE) between targets and predictions.

    Smaller values indicate better predictions; squaring keeps positive and
    negative errors from cancelling out.
    """
    return mean_squared_error(y_true, y_pred)

def calculate_nlpd(y_true, y_pred_mean, y_pred_var):
    """Return the mean Negative Log Predictive Density (NLPD) as a Python float.

    Each point is scored under a Gaussian with the predicted mean and variance,
    so both the prediction error and the stated uncertainty matter; lower is better.
    """
    # Normalisation term of the Gaussian log-density.
    log_term = 0.5 * torch.log(2 * torch.pi * y_pred_var)
    # Squared error scaled by the predicted variance.
    error_term = 0.5 * ((y_true - y_pred_mean) ** 2 / y_pred_var)
    per_point = log_term + error_term
    return per_point.mean().item()

### 3d. Utility F:n's - 3

In [15]:
from tqdm import tqdm

def calculate_scores_for_patches(unacquired_indices, indices_all, e1a, e1b, black_box_fn = black_box) -> torch.Tensor:
    """
    Evaluate the black-box score at every listed index.

    Parameters:
    - unacquired_indices: Iterable of indices (rows of indices_all) to score.
    - indices_all: Array of positions, passed through to black_box_fn.
    - e1a, e1b: Passed through to black_box_fn.
    - black_box_fn: Scoring function called as black_box_fn(i, indices_all, e1a, e1b).

    Returns:
    - scores: Tensor with one score per index, in input order.
    """
    scores = [black_box_fn(i, indices_all, e1a, e1b) for i in unacquired_indices]
    return torch.tensor(scores)  # Return as a tensor for compatibility

def update_acquired(acquired_data, unacquired_indices, selected_indices, indices_all, e1a, e1b, black_box_fn = black_box) -> (np.array, list):
    """Measure the selected indices and drop them from the unacquired list.

    ``acquired_data`` is updated in place (and returned) with one black-box
    score per selected index; an index that is already present is measured
    again and overwritten. The second return value is a new list holding the
    unacquired indices that were not selected.
    """
    # TODO: It queries the black box every time, even on already acquired points.
    for picked in selected_indices:
        acquired_data[picked] = black_box_fn(picked, indices_all, e1a, e1b)

    remaining = list(filter(lambda candidate: candidate not in selected_indices,
                            unacquired_indices))
    return acquired_data, remaining


def load_image_and_features(img: np.ndarray , window_size : int) -> (np.ndarray, np.ndarray):
    """Cut an image into square patches and return them with their centers.

    A patch of side ``window_size`` is extracted around every pixel (grid
    step 1); positions for which no full-size patch exists are dropped by
    ``extract_subimages``. The patches are min-max scaled jointly to [0, 1].

    Args:
        img: 2-D image.
        window_size: Side length of the square patches.

    Returns:
        ``(features, coords)``: the normalised patch stack (channel axis
        removed) and the integer center coordinates of the patches.
    """
    coordinates = get_coord_grid(img, step=1, return_dict=False)
    features_all, coords, _ = extract_subimages(img, coordinates, window_size)
    features_all = features_all[:, :, :, 0]  # drop the singleton channel axis
    coords = np.array(coords, dtype=int)
    # np.ptp instead of ndarray.ptp(), which was removed in NumPy 2.0.
    features = (features_all - features_all.min()) / np.ptp(features_all)
    return features, coords# shapes (3366, 5, 5) and (3366, 2)


def prepare_data_from_microscope(window_size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Fetch the overview image from ``mic_server`` and cut it into patches.

    Args:
        window_size: Side length of the square patches.

    Returns:
        ``(img, features, indices_all)``: the overview image, the normalised
        patch stack and the patch center coordinates, as produced by
        ``load_image_and_features``.
    """
    array_list, shape, dtype = mic_server.get_overview_image()
    # The server returns the image as (data, shape, dtype); rebuild the array.
    img = np.array(array_list, dtype=dtype).reshape(shape)
    features, indices_all = load_image_and_features(img, window_size)
    return img, features, indices_all# shapes (55, 70), (3366, 5, 5) and (3366, 2)

def get_spectrum_data(indices, energy_range, channel="Channel_001") -> (np.array, int, int):
    """Fetch a spectrum image and locate an energy window on its energy axis.

    Args:
        indices: Not used by this function.
        energy_range: Pair (low, high) of energy values.
        channel: Spectrum-image channel requested from ``mic_server``.

    Returns:
        ``(spectral_img, e_start, e_end)`` where the two integers are the
        positions on the energy axis closest to ``energy_range[0]`` and
        ``energy_range[1]``.
    """
    def _rebuild(payload):
        # The server returns arrays as (data, shape, dtype) triples.
        array_list, shape, dtype = payload
        return np.array(array_list, dtype=dtype).reshape(shape)

    spectral_img = _rebuild(mic_server.get_spectrum_image(spectrum_image_index=channel))
    E_axis = _rebuild(mic_server.get_spectrum_image_e_axis(spectrum_image_index=channel))

    e_start = abs(E_axis - energy_range[0]).argmin()
    e_end = abs(E_axis - energy_range[1]).argmin()
    return spectral_img, e_start, e_end


from botorch.utils.multi_objective.pareto import is_non_dominated

def plot_pareto_front(acquired_data1, acquired_data2, step, save_path=None):
    """
    Plot the current Pareto front based on acquired data from both objectives and return Pareto indices.

    Both objectives are min-max scaled to [0, 1] over the commonly acquired
    indices before the non-dominated points are determined. An objective with
    zero range (for example a single common point) is mapped to all zeros
    rather than dividing by zero.

    Args:
        acquired_data1 (dict): Acquired data for Objective 1. Format: {index: value1}
        acquired_data2 (dict): Acquired data for Objective 2. Format: {index: value2}
        step (int): Current BO step.
        save_path (str, optional): Path to save the plot. The plot is displayed in either case.

    Returns:
        list: List of indices that are Pareto optimal.
    """
    def _unit_scale(values):
        # Min-max scaling that stays finite for constant input.
        span = values.max() - values.min()
        if span == 0:
            return np.zeros_like(values, dtype=float)
        return (values - values.min()) / span

    if not acquired_data1 or not acquired_data2:
        print("Insufficient data to plot Pareto front.")
        return []

    # Only indices measured for both objectives can be compared.
    common_indices = list(set(acquired_data1.keys()).intersection(set(acquired_data2.keys())))
    if not common_indices:
        print("No common indices between acquired_data1 and acquired_data2 for Pareto front plotting.")
        return []

    # Extract and normalise objective values
    obj1 = _unit_scale(np.array([acquired_data1[idx] for idx in common_indices]))
    obj2 = _unit_scale(np.array([acquired_data2[idx] for idx in common_indices]))

    # Combine objectives into a single (n_points, 2) tensor
    objectives = np.stack([obj1, obj2], axis=1)
    objectives_tensor = torch.tensor(objectives, dtype=torch.float32)

    # Determine Pareto optimal points using BoTorch's is_non_dominated
    pareto_mask = is_non_dominated(objectives_tensor).numpy()

    pareto_obj1 = obj1[pareto_mask]
    pareto_obj2 = obj2[pareto_mask]
    pareto_indices = [idx for idx, is_pareto in zip(common_indices, pareto_mask) if is_pareto]

    # Plotting
    plt.figure(figsize=(8, 6))
    plt.scatter(obj1, obj2, label='Acquired Points', color='blue')
    plt.scatter(pareto_obj1, pareto_obj2, label='Pareto Front', color='red')

    # Annotate Pareto points with their indices
    for idx, x, y in zip(pareto_indices, pareto_obj1, pareto_obj2):
        plt.annotate(str(idx), (x, y), textcoords="offset points", xytext=(5,5), ha='left', fontsize=8)

    plt.xlabel('Objective 1')
    plt.ylabel('Objective 2')
    plt.title(f'Pareto Front after BO Step {step +1}')
    plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)
        plt.show()
        plt.close()
    else:
        plt.show()

    return pareto_indices


def embeddings_and_predictions(model, patches, device="cpu") -> (torch.Tensor, torch.Tensor):
    """
    Run a trained model on patches and also return their feature embeddings.

    The model is put in eval mode and evaluated without gradient tracking.
    Returns ``(predictions, embeddings)``: the model output for the patches,
    and the feature-extractor output flattened to one row per patch and
    converted to a NumPy array on the CPU.
    """
    model.eval()
    patches = patches.to(device)
    n_patches = patches.size(0)
    with torch.no_grad():
        predictions = model(patches)
        features = model.feature_extractor(patches)
        embeddings = features.view(n_patches, -1).cpu().numpy()
    return predictions, embeddings

def train_model(acquired_data, patches, feature_extractor,
                device="cpu", num_epochs=50, log_interval=5,
                scalarizer_zero=False) -> ApproximateGP:
    """Fit a deep-kernel variational GP on the acquired points.

    Args:
        acquired_data (dict): Mapping {patch index: measured score}.
        patches: Indexable collection of patch tensors; ``patches[idx]`` is the
            input for index ``idx``.
        feature_extractor: Feature-extractor module handed to ``GPModelDKL``
            (``None`` lets the model build its default one).
        device: Device on which the data and model are placed.
        num_epochs: Number of full-batch Adam steps.
        log_interval: Currently unused.
        scalarizer_zero: If True, all training targets are replaced by zeros.

    Returns:
        The trained ``GPModelDKL`` (left in training mode).
    """
    X_train = torch.stack([patches[idx] for idx in acquired_data]).to(device)
    y_train = torch.tensor(list(acquired_data.values()), dtype=torch.float32).to(device)
    if scalarizer_zero:
        y_train = torch.zeros_like(y_train)
    else:
        # Min-max normalise the targets; constant targets map to zeros instead
        # of producing NaNs through a division by zero.
        y_range = y_train.max() - y_train.min()
        if y_range == 0:
            y_train = torch.zeros_like(y_train)
        else:
            y_train = (y_train - y_train.min()) / y_range

    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
    # The first (up to) 10 training inputs seed the inducing points.
    inducing_points = X_train[:10]
    model = GPModelDKL(inducing_points=inducing_points, likelihood=likelihood, feature_extractor=feature_extractor).to(device)

    model.train()
    likelihood.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # The model is an ApproximateGP, so it is trained on the variational ELBO
    # (which includes the KL term of the variational distribution) rather than
    # on the exact marginal log likelihood.
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=y_train.size(0))

    for epoch in tqdm(range(1, num_epochs + 1), desc="Training Progress"):
        optimizer.zero_grad()
        output = model(X_train)
        loss = -mll(output, y_train)
        loss.backward()
        optimizer.step()

    return model



### 3e. Bayesian optimization loop

In [20]:
def _resolve_black_box(name):
    """Return the black-box scoring function for a PFM scalarizer name."""
    if name is None:
        return black_box
    # NOTE(review): "loop_area" is routed to the loop-height scorer, exactly
    # as before. This looks like a copy-paste slip -- confirm whether a
    # dedicated loop-area scorer should be used here.
    if name in ("loop_area", "loop_height"):
        return black_box_loop_height
    if name == "positive_nucleation_bias":
        return black_box_positive_nucleation_bias
    if name == "negative_nucleation_bias":
        return black_box_negative_nucleation_bias
    raise ValueError(
        f"Unknown PFM scalarizer {name!r}; expected one of 'loop_area', "
        "'loop_height', 'positive_nucleation_bias', 'negative_nucleation_bias' or None"
    )


def _min_max_normalize(values):
    """Rescale a tensor to [0, 1]; a constant tensor maps to zeros, not NaN."""
    lowest = values.min()
    span = (values.max() - lowest).clamp_min(torch.finfo(values.dtype).eps)
    return (values - lowest) / span


def run(config) -> None:
    """Run the two-objective Bayesian-optimisation experiment.

    Two DKL models (one per objective) are retrained at every step. Each
    model scores all unacquired patches with LogExpectedImprovement, and the
    point on the Pareto front of the two acquisition values with the largest
    summed acquisition is measured next. Results are written below
    ``config["out_dir_parent"]``.

    Args:
        config: Dict providing the keys ``seed``, ``seed_pts``, ``budget``,
            ``in_dir``, ``out_dir_parent``, ``dataset_name``, ``device``,
            ``num_epochs``, ``normalize_data``, ``window_size``,
            ``scal_pfm1`` and ``scal_pfm2``.

    Raises:
        KeyError: If a required config key is missing.
        ValueError: If ``scal_pfm1``/``scal_pfm2`` is not a known name.
    """
    # Extract all configuration variables
    seed = config["seed"]
    seed_pts = config["seed_pts"]
    budget = config["budget"]
    in_dir = config["in_dir"]
    out_dir_parent = config["out_dir_parent"]
    dataset_name = config["dataset_name"]
    device = config["device"]
    num_epochs = config["num_epochs"]
    normalize_data_flag = config["normalize_data"]
    window_size = config["window_size"]
    scal_pfm1 = config["scal_pfm1"]
    scal_pfm2 = config["scal_pfm2"]

    # TODO: energy ranges come from the EELS workflow; ignored for BEPS for now.
    energy_range1 = [0, 1]
    energy_range2 = [0, 1]

    black_box_fn1 = _resolve_black_box(scal_pfm1)
    black_box_fn2 = _resolve_black_box(scal_pfm2)

    # TODO: always False for now; forwarded to train_model.
    scalarizer_zero = False

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    res_dir = Path(out_dir_parent) / f"benchmark/MOBO_pareto_hack_Dataset_seed{seed}_{dataset_name}_BO_{seed_pts}_epochs{num_epochs}_budget_{budget}_{scal_pfm1}_{scal_pfm2}_ws{window_size}_{timestamp}"
    res_dir.mkdir(parents=True, exist_ok=True)

    # Connect to the microscope server. The proxy is published as a module
    # global (TODO: find a cleaner way to share it).
    uri = "PYRO:microscope.server@localhost:9091"
    global mic_server
    mic_server = Pyro5.api.Proxy(uri)

    dataset_path = f"{in_dir}/{dataset_name}"
    mic_server.initialize_microscope("STEM")
    mic_server.register_data(dataset_path)

    # Prepare features and indices from microscope image
    img, features, indices_all = prepare_data_from_microscope(window_size=window_size)
    patches = numpy_to_torch_for_conv(features)

    # Set up energy ranges for scalarizer extraction
    _, e1a, e1b = get_spectrum_data(indices_all, energy_range1)
    _, e2a, e2b = get_spectrum_data(indices_all, energy_range2)

    patches = patches.to(device)
    if normalize_data_flag:
        patches = normalize_data(patches)

    feature_extractor1 = ConvNetFeatureExtractor(input_channels=1, output_dim=2).to(device)
    feature_extractor2 = ConvNetFeatureExtractor(input_channels=1, output_dim=2).to(device)

    acquired_data1 = {}
    acquired_data2 = {}
    # TODO: switch to using indices_all directly.
    unacquired_indices1 = list(range(len(indices_all)))
    unacquired_indices2 = list(range(len(indices_all)))

    # Both objectives start from the same randomly drawn seed points.
    selected_indices1 = random.sample(unacquired_indices1, seed_pts)
    selected_indices2 = selected_indices1
    seed_indices = selected_indices1

    # Normalised scores over every patch.
    # NOTE(review): these are not used below; retained in case the scoring
    # call has side effects other code depends on.
    true_scalarizer1 = calculate_scores_for_patches(unacquired_indices1, indices_all, e1a, e1b, black_box_fn=black_box_fn1)
    true_scalarizer1 = (true_scalarizer1 - true_scalarizer1.min()) / (true_scalarizer1.max() - true_scalarizer1.min())
    true_scalarizer2 = calculate_scores_for_patches(unacquired_indices2, indices_all, e2a, e2b, black_box_fn=black_box_fn2)
    true_scalarizer2 = (true_scalarizer2 - true_scalarizer2.min()) / (true_scalarizer2.max() - true_scalarizer2.min())

    # Measure the seed points for both objectives.
    acquired_data2, unacquired_indices2 = update_acquired(acquired_data2, unacquired_indices2, selected_indices2, indices_all, e2a, e2b, black_box_fn=black_box_fn2)
    acquired_data1, unacquired_indices1 = update_acquired(acquired_data1, unacquired_indices1, selected_indices1, indices_all, e1a, e1b, black_box_fn=black_box_fn1)

    from botorch.acquisition import LogExpectedImprovement

    mean_y_pred_mean_al = []
    mean_y_pred_variance_al = []
    # Patch indices chosen by the BO loop, in the order they were acquired.
    acquired_order = []

    # Start Bayesian Optimization loop
    for step in range(budget):
        if not unacquired_indices1:
            print(f"No unacquired candidates left; stopping after {step} BO steps")
            break

        # Train one DKL model per objective
        model1 = train_model(acquired_data1, patches, feature_extractor1, device=device, num_epochs=num_epochs, scalarizer_zero=scalarizer_zero)
        model2 = train_model(acquired_data2, patches, feature_extractor2, device=device, num_epochs=num_epochs, scalarizer_zero=scalarizer_zero)
        model1.eval()
        model2.eval()

        # Candidate set = all unacquired patches. Snapshot the list so that
        # positions in the acquisition tensors can be mapped back to patch
        # indices after update_acquired has run.
        candidate_indices = list(unacquired_indices1)
        X_candidates = torch.stack([patches[idx] for idx in candidate_indices]).to(device)
        # One candidate per batch entry (q=1), flattened patch as the feature
        # vector: the layout used when calling the acquisition function
        # directly rather than via optimize_acqf_discrete.
        X_candidates = X_candidates.reshape(-1, 1, window_size * window_size)

        # train_model min-max normalises its targets, so best_f must be
        # taken from targets normalised the same way.
        y_train1 = _min_max_normalize(torch.tensor(list(acquired_data1.values()), dtype=torch.float32).to(device))
        y_train2 = _min_max_normalize(torch.tensor(list(acquired_data2.values()), dtype=torch.float32).to(device))

        acq_func1 = LogExpectedImprovement(model=model1, best_f=y_train1.max())
        acq_func2 = LogExpectedImprovement(model=model2, best_f=y_train2.max())

        # Only the values are needed, so skip building autograd graphs.
        with torch.no_grad():
            acq_values1 = acq_func1(X_candidates)
            acq_values2 = acq_func2(X_candidates)

        # Pareto front over the two acquisition values, shape (Q, 2).
        acq_matrix = torch.stack([acq_values1, acq_values2], dim=1)
        pareto_mask = is_non_dominated(acq_matrix)
        pareto_positions = torch.nonzero(pareto_mask, as_tuple=True)[0]

        if pareto_positions.numel() == 0:
            # Fallback: no Pareto point found, use the best objective-1 value.
            best_position = int(torch.argmax(acq_values1))
        else:
            # Among Pareto points pick the largest summed acquisition value.
            aggregate_acq = acq_values1[pareto_mask] + acq_values2[pareto_mask]
            best_position = int(pareto_positions[torch.argmax(aggregate_acq)])

        # best_position is a position within candidate_indices. Translate it
        # to the patch index before acquiring (the two differ as soon as any
        # point has been removed from the unacquired list).
        selected_index = candidate_indices[best_position]
        selected_indices = [selected_index]  # can hold several indices for batch acquisition
        acquired_order.append(selected_index)

        # Measure the new point with the same scalarizers used for the seeds.
        acquired_data1, unacquired_indices1 = update_acquired(acquired_data1, unacquired_indices1, selected_indices, indices_all, e1a, e1b, black_box_fn=black_box_fn1)
        acquired_data2, unacquired_indices2 = update_acquired(acquired_data2, unacquired_indices2, selected_indices, indices_all, e2a, e2b, black_box_fn=black_box_fn2)

        pareto_indices = plot_pareto_front(acquired_data1, acquired_data2, step, save_path=f"{res_dir}/Pareto_plot_step{step}.png")

        print(f"**************************done BO step {step +1}")
        print("total points in pareto_indices", len(pareto_indices))

    # Save run statistics as a .pkl file
    Active_learning_statistics = {
        "img": img,
        "features": features,
        "indices_all": np.array(indices_all),
        "seed_indices": np.array(seed_indices),
        "unacquired_indices": np.array(unacquired_indices1),
        "acquired_indices": np.array(acquired_order, dtype=int),
        "mean_y_pred_mean_al": np.array(mean_y_pred_mean_al),
        "mean_y_pred_variance_al": np.array(mean_y_pred_variance_al),
    }

    with open(res_dir / "Active_learning_statistics.pkl", "wb") as f:
        pickle.dump(Active_learning_statistics, f)

    # Trajectory plot: seed points plus BO-acquired points coloured by step.
    background = np.array(img)
    coords_all = np.array(indices_all)
    seed_coords = coords_all[np.array(seed_indices)]
    # Index in acquisition order so the colour scale really encodes the step
    # (a set difference would return the indices sorted instead).
    acquired_coords = coords_all[np.array(acquired_order, dtype=int)]

    plt.figure(figsize=(10, 8))
    plt.imshow(background, cmap="gray", origin="lower")
    plt.scatter(seed_coords[:, 1], seed_coords[:, 0], c="b", label="Seed Points", marker="o")

    time_order = np.arange(len(acquired_coords))
    scatter = plt.scatter(acquired_coords[:, 1], acquired_coords[:, 0], c=time_order, cmap="bwr", label="Acquired Points", marker="x")

    plt.xlabel("X-axis")
    plt.ylabel("Y-axis")
    plt.title("Active Learning Trajectory")
    plt.legend()
    plt.grid(True)
    cbar = plt.colorbar(scatter)
    cbar.set_label("Steps")

    plt.savefig(res_dir / "AL_traj.png")
    plt.show()
    plt.close()


### 3f. Set parameters and Run experiments

In [None]:
# Experiment configuration consumed by run(); every key below is required.
config = {
    "seed": 1,  # for repeatability
    "seed_pts": 10,  # number of points the BO loop starts with
    "budget": 50,  # number of BO steps (experimental budget)
    "in_dir": ".",  # recommended: leave as is
    "out_dir_parent": "out",  # recommended: leave as is
    "dataset_name": "yl_beps.h5",  # name of the data to be loaded in DTMicroscope
    # A GPU is recommended; fall back to the CPU when none is available.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "num_epochs": 100,  # DKL training epochs per BO step; may need tuning per dataset
    "normalize_data": True,
    "window_size": 16,  # side length of the square patches (structure-property relationship)
    # Physics of interest for each objective. Options on this data:
    # "loop_area", "loop_height", "positive_nucleation_bias", "negative_nucleation_bias"
    "scal_pfm1": "positive_nucleation_bias",
    # Second objective: this key must be "scal_pfm2". Repeating "scal_pfm1"
    # silently overwrites the first objective and leaves "scal_pfm2" missing.
    "scal_pfm2": "negative_nucleation_bias",
}
run(config)