[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/utkarshp1161/Active-learning-in-microscopy/blob/apply/apply_nb/STEM_EELS_SI_SVDKL_single_obj.ipynb)

# Single-objective active learning with the DigitalTwin microscope, applied to EELS spectrum-image data. [Adapted from this notebook](https://github.com/utkarshp1161/Active-learning-in-microscopy/blob/main/notebooks/single_objective_BO_SVDKL.ipynb)

Prepared by [Utkarsh Pratiush](https://github.com/utkarshp1161)




## 1. Install modules and start the DigitalTwin microscope

In [None]:
#install
!pip install botorch==0.12.0
!pip install gpytorch==1.13
!pip install git+https://github.com/pycroscopy/DTMicroscope.git
!pip install h5py
!pip install sidpy
!pip install -q pyro5
!pip install -q scifireaders
!pip install -q pynsid
!pip install pytemlib

## start dtmic
!run_server_stem

## 2. EELS data — credits: Austin Houston and Utkarsh Pratiush

- Sample: Ag-SiN, collected on March 6, 2025

In [None]:
!gdown https://drive.google.com/uc?id=1U7yTAe3Mub4tKnG6FQctyfoq2JKQ06S7 

import h5py
import sidpy
import numpy as np
from pathlib import Path
import random
from datetime import datetime
import Pyro5.api
import pickle
import matplotlib.pyplot as plt
import pyNSID
import SciFiReaders


## 3. Single Objective BO with DKL

### 3a. DKL model

In [None]:
import gpytorch
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
import math
import torch.nn as nn
import numpy as np
from typing import Tuple, Optional, Dict, Union, List
import numpy as np
import torch

# Simple ConvNet for feature extraction
class ConvNetFeatureExtractor(nn.Module):
    """Small ConvNet that maps image patches to a fixed-length feature vector.

    Architecture: two Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2) stages
    (16 then 32 channels), followed by a linear head to ``output_dim`` features.
    The linear head is created lazily on the first forward pass, once the
    flattened size of the convolutional output is known.

    Args:
        input_channels: number of channels of the input patches.
        output_dim: size of the feature vector produced for each patch.
    """

    def __init__(self, input_channels=1, output_dim=32):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.output_dim = output_dim
        self.fc = None  # built lazily in forward(), see above

    @staticmethod
    def _to_square_patches(x):
        """Reshape (batch, channel, side*side) into (batch, channel, side, side).

        Raises:
            ValueError: if the last dimension is not a perfect square.
        """
        batch_size, channel, flat = x.shape
        side = math.isqrt(flat)
        if side * side != flat:
            raise ValueError(
                f"Cannot reshape flattened patches of length {flat} into square images."
            )
        return x.reshape(batch_size, channel, side, side)

    def forward(self, x):
        """Return features of shape (batch, output_dim).

        Accepts (batch, channel, H, W) patches, or (batch, channel, side*side)
        flattened patches as passed in by BoTorch acquisition functions.
        """
        if x.dim() == 3:  # TODO: hacky way to make sure botorch acquisition function works
            x = self._to_square_patches(x)
        x = self.conv_layers(x)

        # Initialise the fully connected layer on first use; match the input's
        # device and dtype so a module moved/cast before the first call still works.
        if self.fc is None:
            flattened_size = x.view(x.size(0), -1).size(1)
            self.fc = nn.Linear(flattened_size, self.output_dim).to(device=x.device, dtype=x.dtype)

        x = x.view(x.size(0), -1)
        return self.fc(x)


# GP model with deep kernel using ConvNet feature extractor
class GPModelDKL(ApproximateGP):
    """Variational GP on top of a ConvNet feature extractor (deep kernel learning).

    Raw patches are mapped to features by ``feature_extractor``. The GP (constant
    mean, scaled RBF kernel) then operates in that feature space. A minimal
    BoTorch-style ``posterior`` is provided so the model can be passed to BoTorch
    acquisition functions.

    Args:
        inducing_points: patches used to initialise the inducing locations.
            They are passed through the feature extractor first. This is presumably
            (N, C, H, W) (train_model passes stacked patches); TODO confirm.
        likelihood: a GPyTorch likelihood (GaussianLikelihood in this notebook).
        feature_extractor: optional module. Defaults to
            ``ConvNetFeatureExtractor(input_channels=1, output_dim=32)``.
    """

    def __init__(self, inducing_points, likelihood, feature_extractor=None):
        if feature_extractor is None:
            feature_extractor = ConvNetFeatureExtractor(
                input_channels=1,  # Set according to your image channels
                output_dim=32      # Set as per the final feature dimension
            ).to(inducing_points.device)
        else:
            feature_extractor = feature_extractor.to(inducing_points.device)

        # Inducing points live in feature space, so transform them with the ConvNet.
        inducing_points = feature_extractor(inducing_points)

        # Variational setup
        variational_distribution = CholeskyVariationalDistribution(
            inducing_points.size(0)
        )
        variational_strategy = VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )

        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.num_outputs = 1  # must be one (single-objective BoTorch model)
        self.likelihood = likelihood
        self.feature_extractor = feature_extractor

    def forward(self, x):
        """GP prior over the (already extracted) features ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    @staticmethod
    def _to_square_patches(x):
        """Reshape (batch, q, side*side) into (batch, q, side, side).

        Raises:
            ValueError: if the last dimension is not a perfect square.
        """
        batch_size, channel, flat = x.shape
        side = math.isqrt(flat)
        if side * side != flat:
            raise ValueError(
                f"Cannot reshape flattened patches of length {flat} into square images."
            )
        return x.reshape(batch_size, channel, side, side)

    def __call__(self, x, use_feature_extractor=True, *args, **kwargs):
        """Evaluate the GP, first mapping raw patches to features by default.

        3-D inputs (as produced by BoTorch acquisition functions from flattened
        patches) are reshaped back into square patches before feature extraction.
        Pass ``use_feature_extractor=False`` when ``x`` already holds features.
        """
        if use_feature_extractor:
            if x.dim() == 3:
                x = self._to_square_patches(x)
            x = self.feature_extractor(x)
        return super().__call__(x, *args, **kwargs)

    def posterior(self, X, output_indices=None, observation_noise=False, *args, **kwargs) -> GPyTorchPosterior:
        """BoTorch posterior at ``X`` (switches the model and likelihood to eval mode).

        Following BoTorch's ``Model.posterior`` contract, the likelihood's observation
        noise is added only when ``observation_noise`` is truthy or a tensor. Otherwise
        the latent-function posterior is returned. ``output_indices`` is ignored
        because the model has a single output.
        """
        self.eval()
        self.likelihood.eval()
        dist = self(X)
        add_noise = isinstance(observation_noise, torch.Tensor) or bool(observation_noise)
        if add_noise:
            dist = self.likelihood(dist)
        return GPyTorchPosterior(dist)

    @property
    def hparam_dict(self):
        """Current scalar hyperparameters, keyed by their module path."""
        return {
            "likelihood.noise": self.likelihood.noise.item(),
            # outputscale belongs to the ScaleKernel; the RBF base kernel only has a lengthscale.
            "covar_module.outputscale": self.covar_module.outputscale.item(),
            "covar_module.base_kernel.lengthscale": self.covar_module.base_kernel.lengthscale.item(),
            "mean_module.constant": self.mean_module.constant.item(),
        }




### 3b. Utility functions (1)

In [None]:
def normalize_data(data: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    """Min-max scale ``data`` to the [0, 1] range.

    Works for NumPy arrays and torch tensors; the notebook calls it on a tensor of
    patches. Constant input (max == min) maps to all zeros instead of NaN (0/0).
    """
    data_min = data.min()
    data_range = data.max() - data_min
    scale = data_range if data_range != 0 else 1
    return (data - data_min) / scale


def numpy_to_torch_for_conv(np_array) -> torch.Tensor:
    """
    Turn a (batch_size, a, b) NumPy array into a float32 (batch_size, 1, a, b) tensor.

    The extra axis at position 1 is the single channel expected by Conv2d.

    Parameters:
        np_array (np.ndarray): stack of 2-D patches, shape (batch_size, a, b).

    Returns:
        torch.Tensor: float32 tensor of shape (batch_size, 1, a, b).

    Raises:
        TypeError: if ``np_array`` is not a NumPy array.
    """
    if isinstance(np_array, np.ndarray):
        return torch.from_numpy(np_array).float().unsqueeze(1)
    raise TypeError("Input must be a NumPy array.")


######################atomai utils####################################
#Credits Maxim Ziatdinov (https://github.com/ziatdinovmax): https://github.com/pycroscopy/atomai/blob/8db3e944cd9ece68c33c8e3fcca3ef3ce9a111ea/atomai/utils/img.py#L522

def get_coord_grid(imgdata: np.ndarray, step: int,
                   return_dict: bool = True
                   ) -> Union[np.ndarray, Dict[int, np.ndarray]]:
    """
    Build a regular square grid of (row, col) points for each image in a stack.

    The output format matches atomnet.predictor, so it can feed the subimage
    extraction utilities and the atomstat.imlocal class.

    Args:
        imgdata (numpy array): a single 2D image or a 3D stack (n, height, width)
        step (int): spacing between neighbouring grid points
        return_dict (bool): if True, return {frame: (N, 3) array} where the third
            column is a zero class label; otherwise return the (N, 2) grid
            repeated once per frame and stacked along axis 0

    Returns:
        Dictionary or numpy array with coordinates
    """
    stack = np.expand_dims(imgdata, axis=0) if np.ndim(imgdata) == 2 else imgdata
    n_frames, height, width = stack.shape[0], stack.shape[1], stack.shape[2]
    grid = np.array([np.array([row, col])
                     for row in range(0, height, step)
                     for col in range(0, width, step)])
    if not return_dict:
        return np.concatenate([grid] * n_frames, axis=0)
    labelled = np.concatenate((grid, np.zeros((grid.shape[0], 1))), axis=-1)
    # every frame shares the very same coordinate array
    return {frame: labelled for frame in range(n_frames)}

def get_imgstack(imgdata: np.ndarray,
                 coord: np.ndarray,
                 r: int) -> Tuple[np.ndarray]:
    """
    Crop an r x r window around each coordinate of a single image.

    Windows that do not fit fully inside the image, or that contain NaN, are
    skipped. An odd ``r`` gives a window centred on the point. An even ``r`` gives
    a window that extends one pixel further before the point than after it.

    Args:
        imgdata (2D or 3D numpy array):
            Image to crop (e.g. a network prediction of
            :math:`height \\times width \\times n channels`)
        coord (N x 2 numpy array):
            (x, y) coordinates, used as (row, column) indices; rounded to int
        r (int):
            Window size

    Returns:
        2-element tuple containing

        - Stack of subimages
        - (x, y) coordinates of their centers

        or (None, None) when no window survives.
    """
    half = r // 2
    upper = half + 1 if r % 2 != 0 else half
    patches, centers = [], []
    for point in coord:
        row = int(np.around(point[0]))
        col = int(np.around(point[1]))
        window = np.copy(imgdata[row - half:row + upper, col - half:col + upper])
        if window.shape[0:2] != (int(r), int(r)):
            continue  # clipped at the image border
        if np.isnan(window).any():
            continue
        patches.append(window[None, ...])
        centers.append(point[None, ...])
    if not patches:
        return None, None
    return np.concatenate(patches, axis=0), np.concatenate(centers, axis=0)

def extract_subimages(imgdata: np.ndarray,
                      coordinates: Union[Dict[int, np.ndarray], np.ndarray],
                      window_size: int, coord_class: int = 0) -> Tuple[np.ndarray]:
    """
    Crop subimages around the points of one class in every frame of a stack.

    Args:
        imgdata (numpy array):
            4D stack of images (n, height, width, channel), or a single 2D image.
        coordinates (dict or N x 2 numpy array):
            Dict of per-frame :math:`N \\times 3` arrays (atomnet.locator format):
            the first 2 columns are *xy* coordinates and the third is the class
            (starting at 0). An N x 2 array is also accepted when imgdata is a
            single 2D image; every point is then given class 0.
        window_size (int):
            Side of the square subimage
        coord_class (int):
            Class (3rd column) of the points to crop around

    Returns:
        3-element tuple containing

        - stack of subimages,
        - (x, y) coordinates of their centers,
        - frame number associated with each subimage

        (three empty lists if nothing could be cropped).
    """
    if isinstance(coordinates, np.ndarray):
        class_column = np.zeros((coordinates.shape[0], 1))
        coordinates = {0: np.concatenate((coordinates, class_column), axis=-1)}
    if np.ndim(imgdata) == 2:
        imgdata = imgdata[None, ..., None]  # -> (1, height, width, 1)

    stacks, centres, frame_ids = [], [], []
    for frame_no, (frame, frame_coords) in enumerate(zip(imgdata, coordinates.values())):
        wanted = frame_coords[frame_coords[:, 2] == coord_class][:, :2]
        frame_stack, frame_centres = get_imgstack(frame, wanted, window_size)
        if frame_stack is None:
            continue
        stacks.append(frame_stack)
        centres.append(frame_centres)
        frame_ids.append(np.full(len(frame_centres), frame_no, dtype=int))

    if not stacks:
        return stacks, centres, frame_ids
    return (np.concatenate(stacks, axis=0),
            np.concatenate(centres, axis=0),
            np.concatenate(frame_ids, axis=0))

### 3c. Utility functions (2)

In [None]:
#*********************************DTmic specific functions starts **********************************************#
from sklearn.metrics import mean_squared_error

def black_box(index, indices_all, e1a, e1b) -> float:
    """
    Score one grid point by summing its spectrum over an index window.

    The spectrum at grid point ``indices_all[index]`` is fetched from the global
    ``mic_server`` ("Channel_001"). The score is the sum of
    ``spectrum[e1a:e1b]``, i.e. a slice along the first axis, with ``e1b`` excluded.
    """
    values, spec_shape, spec_dtype = mic_server.get_point_data(
        spectrum_image_index="Channel_001",
        x=int(indices_all[index, 0]),  # TODO: check whether x and y need to be flipped
        y=int(indices_all[index, 1]),
    )
    spectrum = np.array(values, dtype=spec_dtype).reshape(spec_shape)
    return spectrum[e1a:e1b].sum()

def black_box_loop_height(index, indices_all, e1a, e1b) -> float:
    """
    Score one grid point as ``max(top) - min(bottom)`` of its loop spectrum.

    Used for PFM data (``scal_pfm`` option of ``run``). The spectrum at
    ``indices_all[index]`` is fetched from the global ``mic_server`` and cut into
    fixed "top" and "bottom" segments (see ``_split``), assuming 3 cycles.
    ``e1a``/``e1b`` are accepted only for interface compatibility and are unused.
    """
    values, spec_shape, spec_dtype = mic_server.get_point_data(
        spectrum_image_index="Channel_001",
        x=int(indices_all[index, 0]),  # TODO: check whether x and y need to be flipped
        y=int(indices_all[index, 1]),
    )
    spectrum = np.array(values, dtype=spec_dtype).reshape(spec_shape)

    def _split(raw_spec, cycle):
        # Segment boundaries derived from the spectrum length and the cycle count.
        period = int(len(raw_spec) / cycle)
        half = int(period / 2)
        quarter = int(period / 4)
        top_spans = ((quarter, quarter + half),
                     (quarter + 2 * half, quarter + 3 * half),
                     (quarter + 4 * half, 2 * quarter + 4 * half))
        bottom_spans = ((0, quarter),
                        (quarter + half, quarter + 2 * half),
                        (quarter + 3 * half, quarter + 4 * half))
        top = np.concatenate([raw_spec[lo:hi] for lo, hi in top_spans])
        bottom = np.concatenate([raw_spec[lo:hi] for lo, hi in bottom_spans])
        return top, bottom

    top, bottom = _split(spectrum, cycle=3)  # TODO: number of cycles is hard coded
    return np.max(top) - np.min(bottom)

def black_box_loop_area(index, indices_all, e1a, e1b) -> float:
    """
    Score one grid point as ``|sum(top) - sum(bottom)|`` of its loop spectrum.

    Used for PFM data (``scal_pfm`` option of ``run``). The spectrum at
    ``indices_all[index]`` is fetched from the global ``mic_server`` and cut into
    fixed "top" and "bottom" segments (see ``_split``), assuming 3 cycles.
    ``e1a``/``e1b`` are accepted only for interface compatibility and are unused.
    """
    values, spec_shape, spec_dtype = mic_server.get_point_data(
        spectrum_image_index="Channel_001",
        x=int(indices_all[index, 0]),  # TODO: check whether x and y need to be flipped
        y=int(indices_all[index, 1]),
    )
    spectrum = np.array(values, dtype=spec_dtype).reshape(spec_shape)

    def _split(raw_spec, cycle):
        # Segment boundaries derived from the spectrum length and the cycle count.
        period = int(len(raw_spec) / cycle)
        half = int(period / 2)
        quarter = int(period / 4)
        top_spans = ((quarter, quarter + half),
                     (quarter + 2 * half, quarter + 3 * half),
                     (quarter + 4 * half, 2 * quarter + 4 * half))
        bottom_spans = ((0, quarter),
                        (quarter + half, quarter + 2 * half),
                        (quarter + 3 * half, quarter + 4 * half))
        top = np.concatenate([raw_spec[lo:hi] for lo, hi in top_spans])
        bottom = np.concatenate([raw_spec[lo:hi] for lo, hi in bottom_spans])
        return top, bottom

    top, bottom = _split(spectrum, cycle=3)  # TODO: number of cycles is hard coded
    return np.abs(np.sum(top) - np.sum(bottom))

def black_box_positive_nucleation_bias(index, indices_all, e1a, e1b) -> float:
    """
    Score one grid point as ``mean(top) - mean(bottom)`` of its loop spectrum.

    Used for PFM data (``scal_pfm`` option of ``run``). The spectrum at
    ``indices_all[index]`` is fetched from the global ``mic_server`` and cut into
    fixed "top" and "bottom" segments (see ``_split``), assuming 3 cycles.
    ``e1a``/``e1b`` are accepted only for interface compatibility and are unused.
    """
    values, spec_shape, spec_dtype = mic_server.get_point_data(
        spectrum_image_index="Channel_001",
        x=int(indices_all[index, 0]),  # TODO: check whether x and y need to be flipped
        y=int(indices_all[index, 1]),
    )
    spectrum = np.array(values, dtype=spec_dtype).reshape(spec_shape)

    def _split(raw_spec, cycle):
        # Segment boundaries derived from the spectrum length and the cycle count.
        period = int(len(raw_spec) / cycle)
        half = int(period / 2)
        quarter = int(period / 4)
        top_spans = ((quarter, quarter + half),
                     (quarter + 2 * half, quarter + 3 * half),
                     (quarter + 4 * half, 2 * quarter + 4 * half))
        bottom_spans = ((0, quarter),
                        (quarter + half, quarter + 2 * half),
                        (quarter + 3 * half, quarter + 4 * half))
        top = np.concatenate([raw_spec[lo:hi] for lo, hi in top_spans])
        bottom = np.concatenate([raw_spec[lo:hi] for lo, hi in bottom_spans])
        return top, bottom

    top, bottom = _split(spectrum, cycle=3)  # TODO: number of cycles is hard coded
    return np.mean(top) - np.mean(bottom)

def black_box_negative_nucleation_bias(index, indices_all, e1a, e1b) -> float:
    """
    Score one grid point as ``min(top) - max(bottom)`` of its loop spectrum.

    Used for PFM data (``scal_pfm`` option of ``run``). The spectrum at
    ``indices_all[index]`` is fetched from the global ``mic_server`` and cut into
    fixed "top" and "bottom" segments (see ``_split``), assuming 3 cycles.
    ``e1a``/``e1b`` are accepted only for interface compatibility and are unused.
    """
    values, spec_shape, spec_dtype = mic_server.get_point_data(
        spectrum_image_index="Channel_001",
        x=int(indices_all[index, 0]),  # TODO: check whether x and y need to be flipped
        y=int(indices_all[index, 1]),
    )
    spectrum = np.array(values, dtype=spec_dtype).reshape(spec_shape)

    def _split(raw_spec, cycle):
        # Segment boundaries derived from the spectrum length and the cycle count.
        period = int(len(raw_spec) / cycle)
        half = int(period / 2)
        quarter = int(period / 4)
        top_spans = ((quarter, quarter + half),
                     (quarter + 2 * half, quarter + 3 * half),
                     (quarter + 4 * half, 2 * quarter + 4 * half))
        bottom_spans = ((0, quarter),
                        (quarter + half, quarter + 2 * half),
                        (quarter + 3 * half, quarter + 4 * half))
        top = np.concatenate([raw_spec[lo:hi] for lo, hi in top_spans])
        bottom = np.concatenate([raw_spec[lo:hi] for lo, hi in bottom_spans])
        return top, bottom

    top, bottom = _split(spectrum, cycle=3)  # TODO: number of cycles is hard coded
    return np.min(top) - np.max(bottom)

## a) evaluations metrics like nlpd, mse ----
def calculate_mse(y_true, y_pred):
    """Return the mean squared error between ``y_true`` and ``y_pred``.

    Thin wrapper around ``sklearn.metrics.mean_squared_error``. Lower is better,
    and squaring keeps positive and negative errors from cancelling out.
    """
    return mean_squared_error(y_true, y_pred)

def calculate_nlpd(y_true, y_pred_mean, y_pred_var):
    """Return the mean Gaussian negative log predictive density as a Python float.

    Per point: 0.5*log(2*pi*var) + 0.5*(y - mean)**2 / var. Lower values mean the
    predicted distribution (mean and uncertainty) matches the targets better.
    All inputs are expected to be torch tensors of matching shape.
    """
    log_norm = 0.5 * torch.log(2 * torch.pi * y_pred_var)
    fit_term = 0.5 * ((y_true - y_pred_mean) ** 2 / y_pred_var)
    per_point = log_norm + fit_term
    return per_point.mean().item()

### 3d. Utility functions (3)

In [None]:
from tqdm import tqdm

def calculate_scores_for_patches(unacquired_indices, indices_all, e1a, e1b, black_box_fn = black_box, debug = True) -> torch.Tensor:
    """
    Evaluate ``black_box_fn`` at every listed grid position.

    Parameters:
    - unacquired_indices: positions (into ``indices_all``) to score, in order.
    - indices_all: array of grid coordinates, one row per position.
    - e1a, e1b: forwarded unchanged to ``black_box_fn``.
    - black_box_fn: scoring callable ``f(index, indices_all, e1a, e1b)``.
    - debug: unused.

    Returns:
    - A tensor built from the scores, one per entry of ``unacquired_indices``.
    """
    return torch.tensor(
        [black_box_fn(position, indices_all, e1a, e1b) for position in unacquired_indices]
    )

def update_acquired(acquired_data, unacquired_indices, selected_indices, indices_all, e1a, e1b, black_box_fn = black_box) -> Tuple[dict, list]:
    """Query the black box at newly selected points and drop them from the candidate pool.

    Args:
        acquired_data: dict mapping position index -> observed score; updated in place.
        unacquired_indices: remaining candidate positions.
        selected_indices: positions to query now.
        indices_all, e1a, e1b: forwarded unchanged to ``black_box_fn``.
        black_box_fn: scoring callable ``f(index, indices_all, e1a, e1b)``.

    Returns:
        ``(acquired_data, unacquired_indices)``: the same (mutated) dict, and a new
        list of the remaining positions in their original order.
    """
    for idx in selected_indices:
        # NOTE: an index already present in acquired_data is queried again and overwritten.
        acquired_data[idx] = black_box_fn(idx, indices_all, e1a, e1b)
    # A set gives O(1) membership instead of scanning selected_indices per candidate.
    selected = set(selected_indices)
    unacquired_indices = [idx for idx in unacquired_indices if idx not in selected]
    return acquired_data, unacquired_indices


def load_image_and_features(img: np.ndarray , window_size : int) -> Tuple[np.ndarray, np.ndarray]:
    """Cut a ``window_size`` x ``window_size`` patch around every pixel of a 2-D image.

    Pixels whose window would leave the image (or contains NaN) are skipped. All
    patches are min-max scaled jointly to [0, 1]. If every patch value is equal,
    the patches become all zeros instead of NaN.

    Returns:
        features: patches of shape (n_patches, window_size, window_size).
        coords: integer (row, col) centre of each patch, shape (n_patches, 2).

    Raises:
        ValueError: if no window fits inside ``img``.
    """
    coordinates = get_coord_grid(img, step=1, return_dict=False)
    features_all, coords, _ = extract_subimages(img, coordinates, window_size)
    if len(features_all) == 0:
        raise ValueError(
            f"No {window_size}x{window_size} patch fits inside an image of shape {np.shape(img)}."
        )
    features_all = features_all[:, :, :, 0]  # drop the singleton channel axis
    coords = np.array(coords, dtype=int)
    # np.ptp (not the ndarray.ptp method, removed in NumPy 2.0) works on NumPy 1.x and 2.x.
    value_range = np.ptp(features_all)
    features = (features_all - features_all.min()) / (value_range if value_range > 0 else 1)
    return features, coords  # e.g. (3366, 5, 5) and (3366, 2) for a 55 x 70 image, window_size=5


def prepare_data_from_microscope(window_size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Fetch the overview image from the global ``mic_server`` and cut it into patches.

    Args:
        window_size: side length of the square patches.

    Returns:
        ``(img, features, indices_all)``: the overview image, the normalised patches,
        and the integer (row, col) centre of each patch (see load_image_and_features).
    """
    # The server sends (values, shape, dtype); rebuild the array locally.
    array_list, shape, dtype = mic_server.get_overview_image()
    img = np.array(array_list, dtype=dtype).reshape(shape)
    features, indices_all = load_image_and_features(img, window_size)
    return img, features, indices_all  # e.g. shapes (55, 70), (3366, 5, 5) and (3366, 2)

def get_spectrum_data(indices, energy_range, channel="Channel_001") -> (np.array, int, int):
    """Download the spectrum image and convert an energy window into axis indices.

    Args:
        indices: unused; kept for call compatibility.
        energy_range: ``(low, high)``; each bound is matched to the nearest value
            on the spectrum image's energy axis.
        channel: spectrum-image name on the microscope server.

    Returns:
        ``(spectral_img, e_start, e_end)``: the spectrum image, and the energy-axis
        indices closest to ``energy_range[0]`` and ``energy_range[1]``.
    """
    def _fetch(getter):
        # The server sends (values, shape, dtype); rebuild the array locally.
        values, arr_shape, arr_dtype = getter(spectrum_image_index=channel)
        return np.array(values, dtype=arr_dtype).reshape(arr_shape)

    spectral_img = _fetch(mic_server.get_spectrum_image)
    e_axis = _fetch(mic_server.get_spectrum_image_e_axis)
    e_start = np.abs(e_axis - energy_range[0]).argmin()
    e_end = np.abs(e_axis - energy_range[1]).argmin()
    return spectral_img, e_start, e_end

def embeddings_and_predictions(model, patches, device="cpu") -> (torch.Tensor, torch.Tensor):
    """
    Run ``model`` and its feature extractor on all patches without tracking gradients.

    Side effect: puts ``model`` into eval mode.

    Returns:
        ``(predictions, embeddings)``: the output of ``model(patches)``, and the
        feature-extractor output reshaped to (n_patches, -1) as a NumPy array.
    """
    model.eval()
    inputs = patches.to(device)
    with torch.no_grad():
        predictions = model(inputs)
        features = model.feature_extractor(inputs)
        embeddings = features.view(inputs.size(0), -1).cpu().numpy()
    return predictions, embeddings

def train_model(acquired_data, patches, feature_extractor,
                device="cpu", num_epochs=50, log_interval=5,
                scalarizer_zero=False) -> ApproximateGP:
    """Fit a fresh DKL model (GPModelDKL) to the points acquired so far.

    Args:
        acquired_data: dict mapping patch index -> observed score.
        patches: tensor of all patches; ``patches[idx]`` is the input for index ``idx``.
        feature_extractor: ConvNet passed into the new model. Its weights are trained
            in place, so reusing the same module across calls warm-starts it.
        device: torch device used for training.
        num_epochs: number of full-batch Adam steps.
        log_interval: unused; kept for backward compatibility.
        scalarizer_zero: if True, train against all-zero targets.

    Returns:
        The trained model (left in train mode).
    """
    X_train = torch.stack([patches[idx] for idx in acquired_data]).to(device)
    y_train = torch.tensor(list(acquired_data.values()), dtype=torch.float32).to(device)
    if scalarizer_zero:
        y_train = torch.zeros_like(y_train)
    else:
        # Min-max normalise targets to [0, 1]; avoid 0/0 when all observations are equal.
        y_range = y_train.max() - y_train.min()
        y_train = (y_train - y_train.min()) / (y_range if y_range > 0 else 1.0)

    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
    inducing_points = X_train[:10]  # first (up to) 10 acquired patches seed the inducing points
    model = GPModelDKL(inducing_points=inducing_points, likelihood=likelihood, feature_extractor=feature_extractor).to(device)

    model.train()
    likelihood.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # GPModelDKL is a variational ApproximateGP, so the objective is the ELBO.
    # ExactMarginalLogLikelihood would omit the KL(q(u) || p(u)) regulariser.
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=y_train.size(0))

    for _ in tqdm(range(1, num_epochs + 1), desc="Training Progress"):
        optimizer.zero_grad()
        output = model(X_train)
        loss = -mll(output, y_train)
        loss.backward()
        optimizer.step()

    return model



### 3e. Bayesian optimization loop

In [None]:


def run(config) -> None:
    """Run single-objective Bayesian optimisation (DKL model + LogEI) on the DigitalTwin microscope.

    ``config`` must provide: ``seed``, ``seed_pts``, ``budget``, ``in_dir``,
    ``out_dir_parent``, ``dataset_name``, ``device``, ``num_epochs``,
    ``normalize_data``, ``window_size`` and ``scal_pfm``.

    Procedure:
    - Random seed points are queried first.
    - Each BO step retrains the DKL model and picks the unacquired patch with the
      highest log expected improvement, then queries the microscope at that patch.
    - Each step saves a prediction figure and ``predictions_BO_step{step}.pkl`` in a
      timestamped results directory.
    - Summary statistics go to ``Active_learning_statistics.pkl``, and the acquisition
      trajectory is plotted.

    Raises:
        ValueError: if ``scal_pfm`` is neither None nor a known PFM scalarizer name.
    """
    # Extract all configuration variables
    seed = config["seed"]
    seed_pts = config["seed_pts"]
    budget = config["budget"]
    in_dir = config["in_dir"]
    out_dir_parent = config["out_dir_parent"]
    dataset_name = config["dataset_name"]
    device = config["device"]
    num_epochs = config["num_epochs"]
    normalize_data_flag = config["normalize_data"]
    window_size = config["window_size"]
    scal_pfm = config["scal_pfm"]

    energy_range = [0, 1]  # TODO: can be confusing as used in eels data - for now ignore for beps

    scalarizer_zero = False  # TODO: forwarded to train_model; find a better way to handle this

    # Choose the black-box scalarizer: None -> EELS spectral sum, otherwise a PFM loop metric.
    pfm_black_boxes = {
        "loop_area": black_box_loop_area,
        "loop_height": black_box_loop_height,
        "positive_nucleation_bias": black_box_positive_nucleation_bias,
        "negative_nucleation_bias": black_box_negative_nucleation_bias,
    }
    if scal_pfm is None:
        black_box_fn = black_box
    elif scal_pfm in pfm_black_boxes:
        black_box_fn = pfm_black_boxes[scal_pfm]
    else:
        raise ValueError(
            f"Unknown scal_pfm {scal_pfm!r}; expected None or one of {sorted(pfm_black_boxes)}"
        )

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    res_dir = Path(out_dir_parent) / f"Dataset_seed{seed}_{dataset_name}_BO_{seed_pts}_epochs{num_epochs}_budget_{budget}_{scal_pfm}_ws{window_size}_{timestamp}"
    res_dir.mkdir(parents=True, exist_ok=True)

    # Connect to the microscope server; the black-box helpers read this global.
    uri = "PYRO:microscope.server@localhost:9091"
    global mic_server  # TODO: later see better way to do this
    mic_server = Pyro5.api.Proxy(uri)

    # Register the dataset with the digital-twin STEM microscope
    # (example data: https://github.com/pycroscopy/DTMicroscope/raw/utk/data/STEM/SI/test_stem.h5).
    dataset_path = in_dir + "/" + dataset_name
    mic_server.initialize_microscope("STEM")
    mic_server.register_data(dataset_path)

    # Prepare features and indices from microscope image
    img, features, indices_all = prepare_data_from_microscope(window_size=window_size)
    patches = numpy_to_torch_for_conv(features)

    # Set up energy ranges for scalarizer extraction
    spectral_img, e1a, e1b = get_spectrum_data(indices_all, energy_range)

    patches = patches.to(device)
    if normalize_data_flag:
        patches = normalize_data(patches)

    feature_extractor = ConvNetFeatureExtractor(input_channels=1, output_dim=2).to(device)
    acquired_data = {}
    unacquired_indices = list(range(len(indices_all)))  # positions into indices_all

    selected_indices = random.sample(unacquired_indices, seed_pts)
    seed_indices = selected_indices

    # Ground truth for evaluation: score every grid point, then min-max normalise (guarding 0/0).
    true_scalarizer = calculate_scores_for_patches(unacquired_indices, indices_all, e1a, e1b, black_box_fn=black_box_fn)
    ts_range = true_scalarizer.max() - true_scalarizer.min()
    true_scalarizer = (true_scalarizer - true_scalarizer.min()) / (ts_range if ts_range > 0 else 1)

    # Query the microscope at the seed points
    acquired_data, unacquired_indices = update_acquired(acquired_data, unacquired_indices, selected_indices, indices_all, e1a, e1b, black_box_fn=black_box_fn)

    from botorch.acquisition import LogExpectedImprovement

    mean_y_pred_mean_al = []
    mean_y_pred_variance_al = []
    mae_list = []
    nlpd_list = []
    # Start Bayesian Optimization loop
    for step in range(budget):
        if not unacquired_indices:
            print(f"No unacquired points left; stopping after {step} BO steps.")
            break

        # Train the DKL model on everything acquired so far
        model = train_model(acquired_data, patches, feature_extractor, device=device, num_epochs=num_epochs, scalarizer_zero=scalarizer_zero)
        model.eval()

        # Candidates = all unacquired patches, flattened to (b, q=1, d=window_size**2) as BoTorch
        # expects; GPModelDKL reshapes them back into square patches internally.
        candidate_indices = unacquired_indices
        X_candidates = torch.stack([patches[idx] for idx in candidate_indices]).to(device)
        X_candidates = X_candidates.reshape(-1, 1, window_size * window_size)

        # Incumbent on the same min-max scale as the training targets (guarding 0/0).
        y_train = torch.tensor(list(acquired_data.values()), dtype=torch.float32).to(device)
        y_range = y_train.max() - y_train.min()
        y_train = (y_train - y_train.min()) / (y_range if y_range > 0 else 1.0)

        acq_func = LogExpectedImprovement(model=model, best_f=y_train.max().to(device))
        with torch.no_grad():  # only the acquisition values are needed, not gradients
            acq_values = acq_func(X_candidates)
        best_idx = torch.argmax(acq_values).item()
        selected_indices = [candidate_indices[best_idx]]  # could hold several indices for batch acquisition

        # Update acquired data with new observations
        acquired_data, unacquired_indices = update_acquired(acquired_data, unacquired_indices, selected_indices, indices_all, e1a, e1b, black_box_fn=black_box_fn)

        print(f"**************************done BO step {step +1}")

        predictions, embeddings = embeddings_and_predictions(model, patches, device=device)
        y_pred_mean = predictions.mean
        y_pred_var = predictions.variance

        # Calculate MSE and NLPD against the ground truth over the whole grid
        mse = calculate_mse(true_scalarizer.cpu(), y_pred_mean.cpu())
        mae = np.sqrt(mse)  # NOTE: this is the RMSE; the name "mae" is kept for the saved outputs
        nlpd = calculate_nlpd(true_scalarizer.cpu(), y_pred_mean.cpu(), y_pred_var.cpu())

        # Scatter per-patch values back onto image-shaped grids (vectorised assignment)
        rows, cols = indices_all[:, 0], indices_all[:, 1]
        true_scalarizer_img = np.zeros((img.shape[0], img.shape[1]))
        y_pred_mean_img = np.zeros((img.shape[0], img.shape[1]))
        y_pred_var_img = np.zeros((img.shape[0], img.shape[1]))
        true_scalarizer_img[rows, cols] = true_scalarizer.detach().cpu().numpy()
        y_pred_mean_img[rows, cols] = y_pred_mean.detach().cpu().numpy()
        y_pred_var_img[rows, cols] = y_pred_var.detach().cpu().numpy()

        # Display the images
        fig, axs = plt.subplots(2, 2, figsize=(12, 10))

        im0 = axs[0, 0].imshow(img, cmap='gray')
        axs[0, 0].set_title('Original Image with next point selection')
        axs[0, 0].scatter([int(indices_all[selected_indices[0]][1])], [int(indices_all[selected_indices[0]][0])], color='yellow', marker='x')
        fig.colorbar(im0, ax=axs[0, 0])

        im1 = axs[0, 1].imshow(y_pred_mean_img, cmap='viridis')
        axs[0, 1].set_title('Predicted Mean')
        fig.colorbar(im1, ax=axs[0, 1])

        im2 = axs[1, 0].imshow(y_pred_var_img, cmap='viridis')
        axs[1, 0].set_title('Predicted Variance')
        fig.colorbar(im2, ax=axs[1, 0])

        im3 = axs[1, 1].imshow(true_scalarizer_img, cmap='viridis')
        axs[1, 1].set_title('True Scalarizer')
        fig.colorbar(im3, ax=axs[1, 1])

        for ax in axs.flat:
            ax.axis('off')

        fig.suptitle(f'MAE: {mae:.4f}, NLPD: {nlpd:.4f}', fontsize=16)
        plt.tight_layout()
        # NOTE(review): ':' in this file name is not a valid character on Windows file systems.
        plt.savefig(Path(res_dir) / f'predictions_MAE: {mae:.4f}, NLPD: {nlpd:.4f}_BO_step{step}.png')
        plt.show()
        plt.close()

        # Save predictions as a .pkl file
        predictions_data = {
            "true_scalarizer_img": true_scalarizer_img,
            "y_pred_mean_img": y_pred_mean_img,
            "y_pred_var_img": y_pred_var_img,
            "mse": mse,
            "nlpd": nlpd,
            "embeddings": embeddings,
        }

        with open(Path(res_dir) / f'predictions_BO_step{step}.pkl', 'wb') as f:
            pickle.dump(predictions_data, f)

        mean_y_pred_mean_al.append(y_pred_mean.mean().cpu())
        mean_y_pred_variance_al.append(y_pred_var.mean().cpu())
        mae_list.append(mae)
        nlpd_list.append(nlpd)

    # Save run statistics as a .pkl file (dtype=int keeps empty index arrays usable for indexing)
    Active_learning_statistics = {
        "img": img,
        "features": features,
        "indices_all": np.array(indices_all),
        "seed_indices": np.array(seed_indices, dtype=int),
        "unacquired_indices": np.array(unacquired_indices, dtype=int),
        "mean_y_pred_mean_al": np.array(mean_y_pred_mean_al),
        "mean_y_pred_variance_al": np.array(mean_y_pred_variance_al),
        "mae": np.array(mae_list),
        "nlpd": np.array(nlpd_list)
                }

    with open(Path(res_dir) / 'Active_learning_statistics.pkl', 'wb') as f:
        pickle.dump(Active_learning_statistics, f)

    # ---- Plot the acquisition trajectory ----
    predictions_data = Active_learning_statistics
    img = np.array(predictions_data["img"])  # background image
    seed_indices = np.array(predictions_data["seed_indices"])  # positions in indices_all
    unacquired_indices = np.array(predictions_data["unacquired_indices"])  # remaining positions
    indices_all = np.array(predictions_data["indices_all"])  # (row, col) of every position

    # Map seed_indices and unacquired_indices to their coordinates in indices_all
    seed_coords = indices_all[seed_indices]
    unacquired_coords = indices_all[unacquired_indices]

    # BO-acquired points (seeds excluded) in acquisition order: dicts preserve insertion
    # order, so the colour map below encodes time (a set difference would sort by index).
    seed_set = set(seed_indices.tolist())
    acquired_indices = np.array([idx for idx in acquired_data if idx not in seed_set], dtype=int)
    acquired_coords = indices_all[acquired_indices]

    # Plot the results
    plt.figure(figsize=(10, 8))

    # Display the image as the background
    plt.imshow(img, cmap="gray")

    # Plot the seed points in blue
    plt.scatter(seed_coords[:, 1], seed_coords[:, 0], c="b", label="Seed Points", marker="o")

    time_order = np.arange(len(acquired_coords))  # 0 = first BO acquisition
    scatter = plt.scatter(acquired_coords[:, 1], acquired_coords[:, 0], c=time_order, cmap="bwr", label="Acquired Points", marker="x")

    # Unacquired points are not drawn; unacquired_coords is kept for optional plotting.

    # Set plot labels and legend
    plt.xlabel("X-axis")
    plt.ylabel("Y-axis")
    plt.title("Active Learning Trajectory")
    plt.legend()
    plt.grid(True)
    # Add a colorbar and label it as "Steps"
    cbar = plt.colorbar(scatter)
    cbar.set_label("Steps")


    plt.savefig(Path(res_dir) / "AL_traj.png")
    plt.show()
    plt.close()


    # Extract data for learning curve
    mean_y_pred_mean_al = np.array(predictions_data["mean_y_pred_mean_al"])
    mean_y_pred_variance_al = np.array(predictions_data["mean_y_pred_variance_al"])
    mae_list = np.array(predictions_data["mae"])
    nlpd_list = np.array(predictions_data["nlpd"])

    steps = np.arange(len(mean_y_pred_mean_al))  # Assuming the steps are sequential indices

    # Calculate the upper and lower bounds using variance
    upper_bound = mean_y_pred_mean_al + np.sqrt(mean_y_pred_variance_al)
    lower_bound = mean_y_pred_mean_al - np.sqrt(mean_y_pred_variance_al)

    # Plot the learning curve
    plt.figure(figsize=(10, 6))

    # Plot the mean predictions
    plt.plot(steps, mean_y_pred_mean_al, label="Mean Prediction", color="blue", linewidth=2)

    # Fill between the upper and lower bounds to represent variance
    plt.fill_between(
        steps,
        lower_bound,
        upper_bound,
        color="blue",
        alpha=0.2,
        label="Variance (±1 std)"
    )

    # Add labels, title, and legend
    plt.xlabel("Steps")
    plt.ylabel("Mean Prediction")
    plt.title("Learning Curve with Variance")
    plt.legend()
    plt.grid(True)
    plt.savefig(Path(res_dir) /"AL_learning_curve.png")
    plt.show()
    plt.close()


    plt.figure(figsize=(10, 6))

    plt.plot(steps, mae_list, color="red", linewidth=2)

    plt.xlabel("Steps")
    plt.ylabel("Mean absolute ERROR")
    plt.grid(True)
    plt.savefig(Path(res_dir) / "AL_error_curve.png")
    plt.show()
    plt.close()


### 3f. Set parameters and run the experiment

In [None]:
from pathlib import Path

import torch

# Experiment configuration consumed by `run`. The key names must stay as they
# are, because `run` reads them.
config = {
    "seed": 1,  # random seed, for repeatability
    "seed_pts": 10,  # number of points to start the BO loop with
    "budget": 50,  # experimental budget available to the BO loop
    # Directory that holds the input dataset. This defaults to the current
    # working directory, which is where the `gdown` cell above saves its
    # download. The previous value was a user-specific cluster path
    # (/nfs/home/upratius/...) that does not exist on Colab or other machines.
    # NOTE(review): confirm the downloaded file name matches "dataset_name".
    "in_dir": str(Path.cwd()),
    "out_dir_parent": "out",  # parent directory for results; recommended: leave as is
    "dataset_name": "SiN-Au-Overview.h5",  # name of the data file loaded into the DTMicroscope
    # Fall back to CPU when no CUDA device is present (e.g. a CPU-only Colab
    # runtime); hard-coding "cuda" fails there.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # Number of epochs the DKL model trains at each experimental step; this
    # may need tuning for the data.
    "num_epochs": 100,
    "normalize_data": True,
    "window_size": 16,  # side length of the square patches (structure-property relationship)
    # Physical quantity of interest (scalarizer). NOTE(review): the options
    # previously listed here ("loop_area", "loop_height",
    # "positive_nucleation_bias", "negative_nucleation_bias") look
    # PFM-specific. Confirm which values `run` supports for this EELS dataset.
    "scal_pfm": None,
}
run(config)