
# Single-objective Active learning using DigitalTwin microscope: Stochastic Variational Deep kernel learning in Gpytorch and BO loop in Botorch. [Recommended to take a GPU instance]
Prepared by [Utkarsh Pratiush](https://github.com/utkarshp1161)
- [parent notebook link where this is tested on Digital twin microscope](https://github.com/utkarshp1161/Active-learning-in-microscopy/blob/main/notebooks/single_objective_BO_SVDKL-novelty-search.ipynb)



## Which CUDA device
- !export CUDA_VISIBLE_DEVICES=1
- Also note, for EDX:
    - dispersion: currently 20 eV / channel
    - sum: taken over the entire spectrum counts

In [None]:
import numpy as np
from pathlib import Path
import random
from datetime import datetime
import pickle
import matplotlib.pyplot as plt

## 3. Single Objective Bayesian optimization with DKL

### 3a. DKL model 

In [None]:
import gpytorch
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
import math
import torch.nn as nn
import numpy as np
from typing import Tuple, Optional, Dict, Union, List
import numpy as np
import torch

# Simple ConvNet for feature extraction
class ConvNetFeatureExtractor(nn.Module):
    """Small CNN that maps image patches to a fixed-length feature vector.

    Two Conv2d(3x3) -> ReLU -> MaxPool(2) stages, followed by a fully
    connected layer that projects to ``output_dim`` features. The fully
    connected layer is created lazily on the first forward pass, because its
    input size depends on the spatial size of the patches.

    NOTE(review): the lazily created ``fc`` parameters only exist after the
    first forward call, so an optimizer built before that call will not see
    them.
    """

    def __init__(self, input_channels=1, output_dim=32):
        super(ConvNetFeatureExtractor, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.output_dim = output_dim
        self.fc = None  # built on the first forward pass (input size unknown until then)

    def forward(self, x):
        """Return features of shape (batch, output_dim).

        Args:
            x: Patches as (batch, channel, h, w), or flattened as
                (batch, channel, h*w) with h == w (the layout the BoTorch
                acquisition path hands over).

        Raises:
            ValueError: if a flattened patch length is not a perfect square.
        """
        if len(x.shape) == 3:  # TODO: hacky way to make sure botorch acquisition function works
            batch_size, channel, mn = x.shape
            side = math.isqrt(mn)
            if side * side != mn:
                raise ValueError(
                    f"Cannot reshape flattened patch of length {mn} into a square image."
                )
            x = x.reshape(batch_size, channel, side, side)

        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)

        # Lazily create the projection layer on the same device as the data.
        if self.fc is None:
            self.fc = nn.Linear(x.size(1), self.output_dim).to(x.device)
        return self.fc(x)


# GP model with deep kernel using ConvNet feature extractor
class GPModelDKL(ApproximateGP):
    """Stochastic-variational GP with a deep (CNN) kernel.

    Inputs are image patches. They are first mapped by ``feature_extractor``
    (a ``ConvNetFeatureExtractor`` by default); the GP, with a constant mean
    and a scaled RBF kernel, then operates on the resulting feature vectors.
    ``__call__`` also accepts patches flattened to (batch, channel, h*w).
    """

    def __init__(self, inducing_points, likelihood, feature_extractor=None):
        """
        Args:
            inducing_points: Patch tensor used to initialise the inducing
                locations. It is passed through the feature extractor first,
                so the (learnable) inducing locations live in feature space.
            likelihood: GPyTorch likelihood, stored on the model and applied in
                ``posterior``.
            feature_extractor: Optional CNN. When omitted, a 1-channel,
                32-feature ``ConvNetFeatureExtractor`` is created.
        """
        if feature_extractor is None:
            feature_extractor = ConvNetFeatureExtractor(
                input_channels=1,  # Set according to your image channels
                output_dim=32      # Set as per the final feature dimension
            ).to(inducing_points.device)
        else:
            feature_extractor = feature_extractor.to(inducing_points.device)

        # Transform inducing points with ConvNet
        inducing_points = feature_extractor(inducing_points)

        # Variational setup
        variational_distribution = CholeskyVariationalDistribution(
            inducing_points.size(0)
        )
        variational_strategy = VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )

        super(GPModelDKL, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.num_outputs = 1  # must be one
        self.likelihood = likelihood
        self.feature_extractor = feature_extractor

    def forward(self, x):
        """GP prior over already-extracted features ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    def __call__(self, x, use_feature_extractor=True, *args, **kwargs):
        """Run the feature extractor (unless disabled), then the variational GP.

        A 3-D input (batch, channel, h*w) is reshaped to square patches first,
        so flattened patches coming from BoTorch can be passed directly.

        Raises:
            ValueError: if a flattened patch length is not a perfect square.
        """
        ## TODO: to make it compatible with botorch acquisition function we need it to make patches internally from flattened patches
        if use_feature_extractor:
            if len(x.shape) == 3:
                batch_size, channel, mn = x.shape
                side = math.isqrt(mn)
                if side * side != mn:
                    raise ValueError(
                        f"Cannot reshape flattened patch of length {mn} into a square image."
                    )
                x = x.reshape(batch_size, channel, side, side)
            x = self.feature_extractor(x)
        return super().__call__(x, *args, **kwargs)

    def posterior(self, X, output_indices=None, observation_noise=False, *args, **kwargs) -> GPyTorchPosterior:
        """BoTorch-style posterior at ``X``.

        Switches the model and likelihood to eval mode (this persists after
        the call). The likelihood is always applied, so the result includes
        observation noise; ``output_indices`` and ``observation_noise`` are
        accepted for API compatibility but ignored.
        """
        self.eval()
        self.likelihood.eval()
        dist = self.likelihood(self(X))
        return GPyTorchPosterior(dist)

    @property
    def hparam_dict(self):
        """Current scalar hyperparameters, keyed by their attribute path."""
        return {
            "likelihood.noise": self.likelihood.noise.item(),
            # The outputscale belongs to the ScaleKernel itself; the RBF base
            # kernel has no outputscale attribute.
            "covar_module.outputscale": self.covar_module.outputscale.item(),
            "mean_module.constant": self.mean_module.constant.item(),
        }




### 3b. Utility F:n's - 1

In [None]:
def normalize_data(data: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    """Min-max normalize ``data`` to the [0, 1] range.

    Works for NumPy arrays and torch tensors alike. A constant input (max ==
    min) is mapped to all zeros instead of producing 0/0 = NaN.
    """
    low = data.min()
    span = data.max() - low
    if span == 0:
        span = 1  # constant input: every value equals `low`, so the result is all zeros
    return (data - low) / span


def numpy_to_torch_for_conv(np_array) -> torch.Tensor:
    """
    Turn a NumPy array of shape (batch_size, a, b) into a float32 tensor of
    shape (batch_size, 1, a, b), i.e. insert a singleton channel axis so the
    data can be fed to Conv2d layers.

    Parameters:
        np_array (np.ndarray): Array of shape (batch_size, a, b).

    Returns:
        torch.Tensor: float32 tensor of shape (batch_size, 1, a, b).

    Raises:
        TypeError: If ``np_array`` is not a NumPy array.
    """
    if isinstance(np_array, np.ndarray):
        return torch.from_numpy(np_array).float().unsqueeze(1)
    raise TypeError("Input must be a NumPy array.")


######################atomai utils####################################
#Credits Maxim Ziatdinov (https://github.com/ziatdinovmax): https://github.com/pycroscopy/atomai/blob/8db3e944cd9ece68c33c8e3fcca3ef3ce9a111ea/atomai/utils/img.py#L522

def get_coord_grid(imgdata: np.ndarray, step: int,
                   return_dict: bool = True
                   ) -> Union[np.ndarray, Dict[int, np.ndarray]]:
    """
    Build a regular square grid of (row, col) points, spaced ``step`` pixels
    apart, for every image of a stack.

    The dictionary output matches the atomnet.predictor format (N x 3 arrays
    whose third column is a class label, here always 0), so it can be fed to
    the subimage-extraction utilities and atomstat.imlocal.

    Args:
        imgdata (numpy array): a single 2D image or a 3D stack of images
        step (int): spacing between neighbouring grid points
        return_dict (bool): if True, return {frame: N x 3 array}; otherwise
            return the N x 2 grids of all frames stacked along axis 0

    Returns:
        Dictionary or numpy array with coordinates
    """
    stack = np.expand_dims(imgdata, axis=0) if np.ndim(imgdata) == 2 else imgdata
    n_frames, height, width = stack.shape[0], stack.shape[1], stack.shape[2]
    grid = np.array([np.array([row, col])
                     for row in range(0, height, step)
                     for col in range(0, width, step)])
    if not return_dict:
        return np.concatenate([grid] * n_frames, axis=0)
    # Append a zero "class" column; every frame shares the same array object.
    labelled = np.concatenate((grid, np.zeros((grid.shape[0], 1))), axis=-1)
    return dict.fromkeys(range(n_frames), labelled)

def get_imgstack(imgdata: np.ndarray,
                 coord: np.ndarray,
                 r: int) -> Tuple[np.ndarray]:
    """
    Crop square windows of side ``r`` centred on the given points of one image.

    Windows that would extend past the image border (and so come out smaller
    than r x r) or that contain NaN values are skipped.

    Args:
        imgdata (numpy array):
            Single image, e.g. :math:`height \\times width \\times n channels`
        coord (N x 2 numpy array):
            (x, y) coordinates; rounded to the nearest integer pixel
        r (int):
            Window size

    Returns:
        2-element tuple containing

        - Stack of subimages
        - (x, y) coordinates of their centers

        or ``(None, None)`` if no window could be extracted.
    """
    half = r // 2
    # Odd windows include the centre pixel plus `half` on each side.
    tail = half + 1 if r % 2 != 0 else half
    patches, centres = [], []
    for c in coord:
        cx, cy = int(np.around(c[0])), int(np.around(c[1]))
        window = np.copy(imgdata[cx - half:cx + tail, cy - half:cy + tail])
        fits = window.shape[0:2] == (int(r), int(r))
        if fits and not np.isnan(window).any():
            patches.append(window[None, ...])
            centres.append(c[None, ...])
    if not patches:
        return None, None
    return np.concatenate(patches, axis=0), np.concatenate(centres, axis=0)

def extract_subimages(imgdata: np.ndarray,
                      coordinates: Union[Dict[int, np.ndarray], np.ndarray],
                      window_size: int, coord_class: int = 0) -> Tuple[np.ndarray]:
    """
    Crop subimages around all points of one class/type in every frame.

    Args:
        imgdata (numpy array):
            4D stack of images (n, height, width, channel), or a single
            2D image (which is promoted to shape (1, height, width, 1)).
        coordinates (dict or N x 2 numpy array): Prediction from atomnet.locator
            (or any source in the same format). Each dict value is an
            :math:`N \\times 3` numpy array whose first two columns are *xy*
            coordinates and whose third column is the class (starting at 0).
            A bare N x 2 array is also accepted when imgdata is a single 2D
            image; every point is then assigned class 0.
        window_size (int):
            Side of the square for subimage cropping
        coord_class (int):
            Class of atoms/defects around which the subimages
            will be cropped (3rd column in the atomnet.locator output)

    Returns:
        3-element tuple containing

        - stack of subimages,
        - (x, y) coordinates of their centers,
        - frame number associated with each subimage

        Each element is an empty list when nothing could be extracted.
    """
    if isinstance(coordinates, np.ndarray):
        class_column = np.zeros((coordinates.shape[0], 1))
        coordinates = {0: np.concatenate((coordinates, class_column), axis=-1)}
    if np.ndim(imgdata) == 2:
        imgdata = imgdata[None, ..., None]

    stacks, centres, frames = [], [], []
    for frame_idx, (frame, frame_coords) in enumerate(zip(imgdata, coordinates.values())):
        selected = frame_coords[np.where(frame_coords[:, 2] == coord_class)][:, :2]
        stack, com = get_imgstack(frame, selected, window_size)
        if stack is not None:
            stacks.append(stack)
            centres.append(com)
            frames.append(np.ones(len(com), int) * frame_idx)

    if not stacks:
        return stacks, centres, frames
    return (np.concatenate(stacks, axis=0),
            np.concatenate(centres, axis=0),
            np.concatenate(frames, axis=0))

### 3c. Utility F:n's - 2

In [None]:
#*********************************DTmic specific functions starts **********************************************#
from sklearn.metrics import mean_squared_error

# setup edx acquisition
import autoscript_tem_toolkit.vision as vision_toolkit
from autoscript_tem_microscope_client.structures import RunOptiStemSettings, RunStemAutoFocusSettings, Point, StagePosition, AdornedImage, EdsAcquisitionSettings, AdornedSpectrum,  StemAcquisitionSettings, StageVelocity, EdsSpectrumImageSettings
from autoscript_tem_microscope_client.enumerations import DetectorType, CameraType, OptiStemMethod, OpticalMode, EdsDetectorType, ExposureTimeType




def configure_acquisition(exposure_time=2):
    """Configure the EDS acquisition settings for the first EDS detector.

    Uses the module-level ``mic_server`` client (assigned in ``run``). It picks
    the last entry of the detector's ``dispersions`` and ``shaping_times`` and
    uses live-time exposure. The settings are stored in the module-level
    ``eds_settings`` and also returned.

    Args:
        exposure_time: Exposure time forwarded to the settings (units as
            expected by the microscope API; presumably seconds, confirm).

    Returns:
        The configured ``EdsAcquisitionSettings`` instance.
    """
    global eds_settings
    detector_name = mic_server.detectors.eds_detectors[0]
    detector = mic_server.detectors.get_eds_detector(detector_name)

    eds_settings = EdsAcquisitionSettings()
    eds_settings.eds_detector = detector_name
    eds_settings.dispersion = detector.dispersions[-1]
    eds_settings.shaping_time = detector.shaping_times[-1]
    eds_settings.exposure_time = exposure_time
    eds_settings.exposure_time_type = ExposureTimeType.LIVE_TIME
    return eds_settings

def get_channel_index(energy_keV: float, dispersion: float, offset: float) -> int:
    """Map an energy to the nearest spectrum channel index.

    ``dispersion`` (per channel) and ``offset`` must be in the same unit as
    ``energy_keV`` (keV). Rounding follows Python's ``round`` (ties to even).
    """
    channel = (energy_keV - offset) / dispersion
    return int(round(channel))

import xmltodict
import json
import numpy as np

def get_dispersion_and_offset(spectrum):
    """
    Extract dispersion and offset from EDS spectrum metadata (XML).

    Parses ``spectrum.metadata.metadata_as_xml`` and reads the ``Dispersion``
    and ``OffsetEnergy`` fields of the first
    ``Metadata/Detectors/AnalyticalDetector`` entry. A missing field
    defaults to 0.

    Returns:
        (dispersion, offset) as floats, in the units stored in the metadata.
        NOTE(review): ``get_eds_black_box`` divides the resulting energy axis
        by 1000 to get keV, so these are presumably eV/channel and eV, not
        keV. Confirm against the instrument metadata.
    """
    parsed = xmltodict.parse(spectrum.metadata.metadata_as_xml)
    # Round-trip through JSON to turn xmltodict's ordered mappings into plain dicts.
    metadata = json.loads(json.dumps(parsed))

    detectors = metadata["Metadata"]["Detectors"]["AnalyticalDetector"]
    # A single detector comes back as a dict, several as a list; take the first.
    det = detectors if isinstance(detectors, dict) else detectors[0]

    dispersion = float(det.get("Dispersion", 0))
    offset = float(det.get("OffsetEnergy", 0))
    return dispersion, offset


def get_eds_black_box(index, indices_all, e1a, e1b, eds_settings, image_size, element="sum") -> float:
    """
    Black box function that returns a target score based on EDS peak intensities.

    Moves the paused STEM beam to patch ``indices_all[index]`` and acquires one
    EDS spectrum. It then sums the four per-detector blocks of
    ``spectrum.data`` into a single spectrum, plots and saves it, and scores it.

    Relies on module-level globals: ``mic_server`` (microscope client, set in
    ``run``), ``img`` (HAADF image, set in ``prepare_data_from_microscope``)
    and ``out_path`` (file prefix for the saved plot/spectrum; not defined in
    this section, presumably set by the caller).

    Args:
        index: Row of ``indices_all`` to measure.
        indices_all: Integer pixel coordinates of all patches, one pair per row.
        e1a, e1b: Unused; kept for call-signature compatibility.
        eds_settings: Settings passed to ``acquire_spectrum``.
        image_size: HAADF image side length in pixels; converts pixel
            coordinates into fractional beam positions.
        element: ``"sum"`` for total counts, or one of ``"Al"``, ``"Cu"``,
            ``"Zr"``, ``"Sb"`` to integrate counts around that element's peaks.

    Returns:
        The score (counts). 0 for an unrecognised ``element``.
    """
    # Pixel -> fractional beam position; (0, 0) is the top-left corner.
    # NOTE(review): indices_all rows come from get_coord_grid as (row, col);
    # confirm that paused_scan_beam_position expects them in this order.
    x = int(indices_all[index, 0]) / image_size
    y = int(indices_all[index, 1]) / image_size
    print("collecting spectrum at fractional coord", x, y, "true coord", x * image_size, y * image_size)
    mic_server.optics.paused_scan_beam_position = [x, y]

    # Unblank only for the acquisition and always re-blank, even if the
    # acquisition raises, so the beam is not left on the sample.
    mic_server.optics.unblank()
    try:
        spectrum = mic_server.analysis.eds.acquire_spectrum(eds_settings)
    finally:
        mic_server.optics.blank()

    # Mark the measured location on the HAADF image.
    plt.imshow(img, cmap='gray', origin="upper")
    plt.scatter(x * image_size, y * image_size, marker="o", c="y")

    # spectrum.data is split into 4 equal consecutive per-detector blocks that
    # are summed channel by channel (leftover channels, if any, are dropped).
    n_detectors = 4
    n_channels_per_detector = len(spectrum.data) // n_detectors
    spectrum_data = np.zeros(n_channels_per_detector)
    for det in range(n_detectors):
        start_idx = det * n_channels_per_detector
        spectrum_data += spectrum.data[start_idx:start_idx + n_channels_per_detector]

    # Metadata dispersion/offset are in eV; divide by 1000 for a keV axis.
    dispersion, offset = get_dispersion_and_offset(spectrum)
    energy_axis = (np.arange(len(spectrum_data)) * dispersion + offset) / 1000

    plt.figure(figsize=(12, 6))
    plt.plot(energy_axis, spectrum_data)
    plt.xlabel('Energy (keV)')
    plt.ylabel('Counts')
    plt.title('EDS Spectrum (Summed from 4 Detectors)')
    plt.xlim(0, 20)  # Focus on physically relevant energy range
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Element peak energies in keV, keyed by peak ID.
    element_peaks = {
        "Al": {"Al_Ka": 1.487},
        "Cu": {"Cu_La": 0.930, "Cu_Ka": 8.048, "Cu_Kb": 8.905},
        "Zr": {"Zr_La": 2.04, "Zr_Lb": 2.12, "Zr_Ka": 15.77, "Zr_Kb": 17.67},
        "Sb": {"Sb_La": 3.605, "Sb_Lb": 3.844, "Sb_Ka": 26.359, "Sb_Kb": 29.725}
    }

    # Mark element lines on the plot; only the first line of each element gets a legend entry.
    colors = {"Al": "red", "Cu": "blue", "Zr": "green", "Sb": "orange"}
    for elem, peaks in element_peaks.items():
        first_peak = next(iter(peaks))
        for peak_name, energy in peaks.items():
            if energy <= 30:  # lines above 30 keV are skipped (the x-axis is cut at 20 keV)
                plt.axvline(x=energy, color=colors[elem], linestyle='--', alpha=0.7,
                            label=f'{elem} ({peak_name})' if peak_name == first_peak else "")

    plt.legend()
    plt.grid(True, alpha=0.3)
    # Shared stem so the saved plot and the saved spectrum carry the same name.
    file_stem = f"{out_path}_time_{timestamp}_(x,y)_{x,y}disp_{dispersion}_offset_{offset}_4det_summed"
    plt.savefig(f"{file_stem}.png")
    plt.show()

    window_halfwidth = 3  # channels integrated on each side of a peak centre
    total_counts = 0

    if element == "sum":
        total_counts = spectrum_data.sum()
        print(f"Total spectrum counts: {total_counts}")
    elif element in element_peaks:
        # get_channel_index works in keV; the metadata values are in eV.
        dispersion_kev = dispersion / 1000
        offset_kev = offset / 1000
        for peak_name, energy in element_peaks[element].items():
            center_idx = get_channel_index(energy, dispersion_kev, offset_kev)
            start = max(0, center_idx - window_halfwidth)
            # +1 makes the window symmetric, [centre - hw, centre + hw]; clamping to
            # `start` keeps an out-of-range peak from wrapping into a negative slice.
            end = max(start, min(len(spectrum_data), center_idx + window_halfwidth + 1))
            peak_counts = spectrum_data[start:end].sum()
            total_counts += peak_counts
            print(f"{element} {peak_name} counts: {peak_counts}")
        print(f"Total {element} counts: {total_counts}")
    else:
        print(f"Warning: Element '{element}' not recognized. Available: Al, Cu, Zr, Sb, sum")

    all_element_ids = {elem: list(peaks) for elem, peaks in element_peaks.items()}
    print("Available element peak IDs:", all_element_ids)

    # Save the detector-summed spectrum next to the plot.
    np.save(f"{file_stem}.npy", spectrum_data)
    print("total counts", total_counts)
    return total_counts

# def get_eds_black_box(index, indices_all, e1a, e1b, eds_settings, image_size) -> float:
#     """
#     Black box function that returns a target score simulates the blackbox function
#     """

#     e_start,e_end = e1a, e1b
#     # move paused beam to the location index
#     x=int(indices_all[index, 0])/image_size######### TODO: check if x anf y needs to be flipped
#     y=int(indices_all[index, 1])/image_size
#     mic_server.optics.paused_scan_beam_position = list(x, y)#----> (0, 0) should be top left corner 
    
#     # collect eds
#     microscope.optics.unblank()
#     spectrum = mic_server.analysis.eds.acquire_spectrum(eds_settings)
#     microscope.optics.blank()
#     score = spectrum[e_start:e_end].sum()

#     return score




## a) evaluations metrics like nlpd, mse ----
def calculate_mse(y_true, y_pred):
    """Return the Mean Squared Error (MSE) between ``y_true`` and ``y_pred``.

    Thin wrapper around ``sklearn.metrics.mean_squared_error``. Lower is
    better; squaring keeps positive and negative errors from cancelling out.
    """
    return mean_squared_error(y_true, y_pred)

def calculate_nlpd(y_true, y_pred_mean, y_pred_var):
    """Return the mean Gaussian Negative Log Predictive Density (NLPD).

    Per point: 0.5 * log(2*pi*var) + 0.5 * (y - mean)^2 / var, averaged and
    returned as a Python float. Lower is better; it rewards accurate means
    and penalises both over- and under-confident variances.
    """
    log_normaliser = 0.5 * torch.log(2 * torch.pi * y_pred_var)
    squared_error = (y_true - y_pred_mean) ** 2
    per_point = log_normaliser + 0.5 * (squared_error / y_pred_var)
    return per_point.mean().item()

### 3d. Utility F:n's - 3

In [None]:
from tqdm import tqdm

def calculate_scores_for_patches(unacquired_indices, indices_all, e1a, e1b, black_box_fn=get_eds_black_box, debug=True,
                                 eds_settings=None, image_size=None) -> torch.Tensor:
    """
    Query the black-box function once per index and collect the scores.

    Parameters:
    - unacquired_indices: Iterable of patch indices to score.
    - indices_all: Pixel coordinates of all patches, forwarded to ``black_box_fn``.
    - e1a, e1b: Energy-window arguments forwarded to ``black_box_fn``.
    - black_box_fn: Scoring function; defaults to ``get_eds_black_box``.
    - debug: Currently unused.
    - eds_settings, image_size: Forwarded as extra positional arguments when
      either is given. ``get_eds_black_box`` requires both, so pass them when
      using the default black box. Custom 4-argument black boxes can omit them.

    Returns:
    - scores: 1-D tensor with one score per index, in iteration order.
    """
    # Without these extras the default get_eds_black_box would raise TypeError.
    if eds_settings is None and image_size is None:
        extra_args = ()
    else:
        extra_args = (eds_settings, image_size)
    scores = [black_box_fn(i, indices_all, e1a, e1b, *extra_args) for i in unacquired_indices]
    return torch.tensor(scores)  # Return as a tensor for compatibility

def update_acquired(acquired_data, unacquired_indices, selected_indices, indices_all, e1a, e1b, eds_settings, image_size, black_box_fn=get_eds_black_box) -> Tuple[dict, list]:
    """
    Measure the selected patches and move them from "unacquired" to "acquired".

    Args:
        acquired_data: Dict {patch index: score}; updated in place.
        unacquired_indices: Indices not yet measured.
        selected_indices: Indices chosen for measurement in this step.
        indices_all, e1a, e1b, eds_settings, image_size: Forwarded to ``black_box_fn``.
        black_box_fn: Scoring function; defaults to ``get_eds_black_box``.

    Returns:
        (acquired_data, unacquired_indices): the same (mutated) dict, and a new
        list of the unacquired indices minus the selected ones.
    """
    # TODO: It queries the black box every time, even for already acquired points;
    # an index already in acquired_data is re-measured and its score overwritten.
    for idx in selected_indices:
        acquired_data[idx] = black_box_fn(idx, indices_all, e1a, e1b, eds_settings, image_size)
    unacquired_indices = [idx for idx in unacquired_indices if idx not in selected_indices]
    return acquired_data, unacquired_indices


def load_image_and_features(img: np.ndarray, window_size: int) -> Tuple[np.ndarray, np.ndarray]:
    """Cut every complete ``window_size`` x ``window_size`` patch out of ``img``.

    A patch is taken around every pixel of a step-1 grid; patches that would
    cross the image border are dropped. Patch values are min-max normalised
    jointly (over all patches) to [0, 1].

    Returns:
        (features, coords): patches of shape (n_patches, window_size,
        window_size) and their integer grid coordinates (n_patches, 2),
        e.g. (3366, 5, 5) and (3366, 2).

    Raises:
        ValueError: if no complete patch fits inside the image.
    """
    coordinates = get_coord_grid(img, step=1, return_dict=False)
    features_all, coords, _ = extract_subimages(img, coordinates, window_size)
    if len(features_all) == 0:
        raise ValueError(
            f"No complete {window_size}x{window_size} patch fits inside an image of shape {np.shape(img)}."
        )
    features_all = features_all[:, :, :, 0]  # drop the singleton channel axis
    coords = np.array(coords, dtype=int)

    shifted = features_all - features_all.min()
    span = np.ptp(features_all)
    # A constant image would give 0/0 = NaN; map it to all zeros instead.
    features = shifted / span if span != 0 else np.zeros(shifted.shape)
    return features, coords


def prepare_data_from_microscope(window_size: int, haadf: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build patch features and their coordinates from a HAADF image.

    Side effect: stores ``haadf`` in the module-level ``img``, which
    ``get_eds_black_box`` reads to plot the beam position.

    Returns:
        (img, features, indices_all), e.g. shapes (55, 70), (3366, 5, 5)
        and (3366, 2).
    """
    global img  # TODO: better way to deal with this --> at this point need to plot it when collecting spectrum
    img = haadf
    features, indices_all = load_image_and_features(img, window_size)
    return img, features, indices_all



def embeddings_and_predictions(model, patches, device="cpu") -> Tuple["gpytorch.distributions.MultivariateNormal", np.ndarray]:
    """
    Get predictions and CNN embeddings from a trained model.

    The feature extractor is run once; its output is both returned as the
    embedding and fed to the GP with ``use_feature_extractor=False``.

    Args:
        model: Trained ``GPModelDKL``; it is switched to eval mode.
        patches: Patch tensor accepted by ``model.feature_extractor``.
        device: Device to run on.

    Returns:
        (predictions, embeddings): the model's predictive distribution over
        the latent function (likelihood noise is not applied), and a NumPy
        array of shape (n_patches, n_features) with the extracted features.
    """
    model.eval()
    patches = patches.to(device)
    with torch.no_grad():
        features = model.feature_extractor(patches)
        predictions = model(features, use_feature_extractor=False)
        embeddings = features.view(patches.size(0), -1).cpu().numpy()
    return predictions, embeddings

def train_model(acquired_data, patches, feature_extractor,
                device="cpu", num_epochs=50, log_interval=5,
                scalarizer_zero=False) -> ApproximateGP:
    """Fit a fresh SVDKL model (``GPModelDKL``) to the acquired points.

    Args:
        acquired_data: Dict {patch index: measured score}.
        patches: Indexable patches; ``patches[idx]`` is the input for ``idx``.
        feature_extractor: CNN handed to ``GPModelDKL``. Its weights keep
            training across calls when the same instance is reused.
        device: Device to train on.
        num_epochs: Number of full-batch Adam steps.
        log_interval: Show the current loss on the progress bar every this
            many epochs (0 or None disables it).
        scalarizer_zero: Train on all-zero targets instead of the scores.

    Returns:
        The trained model (left in train mode).
    """
    X_train = torch.stack([patches[idx] for idx in acquired_data]).to(device)
    y_train = torch.tensor(list(acquired_data.values()), dtype=torch.float32).to(device)
    if scalarizer_zero:
        y_train = torch.zeros_like(y_train)
    else:
        # Min-max scale targets to [0, 1]. Constant targets (e.g. a single
        # point) would give 0/0 = NaN, so they are only shifted to zero.
        y_min = y_train.min()
        y_range = y_train.max() - y_min
        y_train = y_train - y_min
        if y_range > 0:
            y_train = y_train / y_range

    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
    inducing_points = X_train[:10]
    model = GPModelDKL(inducing_points=inducing_points, likelihood=likelihood, feature_extractor=feature_extractor).to(device)

    model.train()
    likelihood.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # GPModelDKL is an ApproximateGP with a variational strategy, so the training
    # objective is the variational ELBO (ExactMarginalLogLikelihood is for ExactGP).
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=y_train.size(0))

    progress = tqdm(range(1, num_epochs + 1), desc="Training Progress")
    for epoch in progress:
        optimizer.zero_grad()
        output = model(X_train)
        loss = -mll(output, y_train)
        loss.backward()
        optimizer.step()
        if log_interval and epoch % log_interval == 0:
            progress.set_postfix(loss=f"{loss.item():.4f}")

    return model



### 3e. Bayesian optimization loop

In [None]:
from stemOrchestrator.acquisition import TFacquisition, DMacquisition
from stemOrchestrator.simulation import DMtwin
from autoscript_tem_microscope_client.enumerations import EdsDetectorType
from stemOrchestrator.process import HAADF_tiff_to_png, tiff_to_png
from autoscript_tem_microscope_client import TemMicroscopeClient
import matplotlib.pyplot as plt
import logging
plot = plt
from typing import Dict
import os

In [None]:
import torch
from botorch.acquisition import AnalyticAcquisitionFunction
from botorch.utils.transforms import t_batch_mode_transform

class BEACONAcquisitionFunction(AnalyticAcquisitionFunction):
    """Novelty-style acquisition: distance of a posterior sample to an elite set.

    The elite set is the top ``elite_fraction`` of acquired points by measured
    value. Its reference "behaviors" are the model's posterior means at those
    patches. A candidate's score is the mean distance between one posterior
    sample at the candidate and its ``k`` nearest elite behaviors. Candidates
    whose sampled output lies farthest from the elite therefore score highest.
    """
    def __init__(self, model, acquired_data, indices_all, patches, k=10, elite_fraction=0.1, normalize=True):
        """
        BEACON acquisition function for DKL model using Thompson sampling
        
        Args:
            model: Trained DKL model (must expose ``posterior`` and ``likelihood.noise``)
            acquired_data: Dictionary of {index: value} for acquired points
            indices_all: All coordinate indices (currently unused)
            patches: All image patches; elite patches are stacked and flattened to (n_elite, 1, h*w)
            k: Number of nearest neighbors to consider
            elite_fraction: Fraction of top points to use as elite set (at least one point is kept)
            normalize: Whether to apply z-score normalization using the elite mean/std
        """
        super().__init__(model=model)
        self.model = model
        self.k = k
        self.normalize = normalize
        
        # Build elite set from acquired data
        if len(acquired_data) > 0:
            # Get top performers (elite set): highest measured values first
            sorted_items = sorted(acquired_data.items(), key=lambda x: x[1], reverse=True)
            n_elite = max(1, int(len(sorted_items) * elite_fraction))
            elite_indices = [idx for idx, _ in sorted_items[:n_elite]]
            
            # Get elite patches and compute posterior means (keep this as means for reference set).
            # Patches are flattened to (n_elite, 1, h*w), the layout the model's __call__ un-flattens.
            elite_patches = torch.stack([patches[idx] for idx in elite_indices]).to(model.likelihood.noise.device)
            elite_patches = elite_patches.reshape(-1, 1, patches.shape[-1] * patches.shape[-2])
            
            with torch.no_grad():
                elite_posterior = model.posterior(elite_patches)
                self.elite_behaviors = elite_posterior.mean.detach()  # Keep as means for stable reference
                
                # Normalize if requested.
                # NOTE: with a single elite value (numel <= 1) or normalize=False,
                # behavior_mean/behavior_std are NOT set here; forward() guards
                # with hasattr(). With std <= 1e-8 the stats are set but unused.
                if self.normalize and self.elite_behaviors.numel() > 1:
                    self.behavior_mean = self.elite_behaviors.mean()
                    self.behavior_std = self.elite_behaviors.std()
                    if self.behavior_std > 1e-8:  # Avoid division by zero
                        self.elite_behaviors = (self.elite_behaviors - self.behavior_mean) / self.behavior_std
        else:
            # No data yet: empty elite set, so forward() returns all zeros.
            self.elite_behaviors = torch.empty(0, 1).to(model.likelihood.noise.device)
            self.behavior_mean = 0.0
            self.behavior_std = 1.0
        
    @t_batch_mode_transform(expected_q=1)
    def forward(self, X):
        """Compute BEACON acquisition function value using Thompson sampling.

        Args:
            X: Candidate batch with q == 1 enforced by the decorator
                (presumably (b, 1, h*w) flattened patches; confirm with caller).

        Returns:
            One score per candidate: mean distance to the k nearest elite
            behaviors (all zeros when the elite set is empty).
        """
        if self.elite_behaviors.numel() == 0:
            return torch.zeros(X.shape[0]).to(X.device)
            
        # Use Thompson sampling for candidate predictions: one joint posterior
        # sample over the whole batch. no_grad means the score is not differentiable.
        with torch.no_grad():
            posterior = self.model.posterior(X)
            # Thompson sampling: sample from the posterior distribution
            candidate_samples = posterior.rsample(torch.Size([1])).squeeze(0)
            
            # Ensure proper shape: cdist needs a trailing feature dimension
            if candidate_samples.dim() == 1:
                candidate_samples = candidate_samples.unsqueeze(1)
                
            # Apply same normalization to candidate samples (only if the elite was normalized)
            if self.normalize and hasattr(self, 'behavior_std') and self.behavior_std > 1e-8:
                candidate_samples = (candidate_samples - self.behavior_mean) / self.behavior_std
        
        # Calculate distances to elite behaviors (means)
        dist = torch.cdist(candidate_samples, self.elite_behaviors)  # Shape: (batch_size, n_elite)
        
        # Sort distances for each candidate point (ascending: nearest first)
        dist_sorted, _ = torch.sort(dist, dim=1)
        
        # Calculate k-NN distances and average them; k is capped at the elite size
        n_elite = dist_sorted.size(1)
        k_actual = min(self.k, n_elite)
        
        if k_actual > 0:
            # Take average of k nearest distances
            knn_distances = dist_sorted[:, :k_actual].mean(dim=1)
        else:
            knn_distances = torch.zeros(candidate_samples.size(0)).to(X.device)
        
        return knn_distances.flatten()


In [None]:


def _plot_bo_step(img, indices_all, selected_index, candidate_acq, y_pred_mean, y_pred_var, embeddings, res_dir, step):
    """Plot the input image, predicted mean/variance and acquisition map for one BO step, and pickle the maps.

    ``indices_all[j]`` is used as ``(row, col)`` when filling the maps. Markers are
    therefore drawn at ``x=col, y=row`` so that they land on the selected pixel
    under ``imshow``.

    Args:
        img: 2-D background image; its first two dims set the map size.
        indices_all: Per-patch pixel coordinates, indexable as ``[j][0], [j][1]``.
        selected_index: Patch index chosen at this step.
        candidate_acq: ``{patch_index: acquisition_value}`` for scored candidates.
        y_pred_mean, y_pred_var: Per-patch predictive mean/variance tensors.
        embeddings: Stored verbatim in the pickle.
        res_dir: Output directory.
        step: BO step number used in the file names.
    """
    coords = np.asarray(indices_all, dtype=int)
    map_shape = (img.shape[0], img.shape[1])
    acq_fn_img = np.zeros(map_shape)
    y_pred_mean_img = np.zeros(map_shape)
    y_pred_var_img = np.zeros(map_shape)

    # Only scored candidates are filled; acquired points stay 0 in the acquisition map.
    for idx, value in candidate_acq.items():
        acq_fn_img[coords[idx, 0], coords[idx, 1]] = value
    # Move predictions to host memory once instead of indexing device tensors per pixel.
    y_pred_mean_img[coords[:, 0], coords[:, 1]] = y_pred_mean.detach().cpu().numpy().reshape(-1)
    y_pred_var_img[coords[:, 0], coords[:, 1]] = y_pred_var.detach().cpu().numpy().reshape(-1)

    sel_x, sel_y = int(coords[selected_index, 1]), int(coords[selected_index, 0])

    fig, axs = plt.subplots(2, 2, figsize=(12, 10))

    im0 = axs[0, 0].imshow(img, cmap='gray', origin="upper")
    axs[0, 0].set_title('Original Image with next point selection')
    axs[0, 0].scatter([sel_x], [sel_y], color='yellow', marker='x')
    fig.colorbar(im0, ax=axs[0, 0])

    im1 = axs[0, 1].imshow(y_pred_mean_img, cmap='viridis', origin="upper")
    axs[0, 1].set_title('Predicted Mean')
    fig.colorbar(im1, ax=axs[0, 1])

    im2 = axs[1, 0].imshow(y_pred_var_img, cmap='viridis', origin="upper")
    axs[1, 0].set_title('Predicted Variance')
    fig.colorbar(im2, ax=axs[1, 0])

    im3 = axs[1, 1].imshow(acq_fn_img, cmap='viridis', origin="upper")
    axs[1, 1].set_title('Acquisition Function')
    axs[1, 1].scatter([sel_x], [sel_y], color='red', marker='x', s=100, label="Selected")
    fig.colorbar(im3, ax=axs[1, 1])

    for ax in axs.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(Path(res_dir) / f'_BO_step{step}.png')
    plt.show()
    plt.close(fig)

    predictions_data = {
        "acq_fn_img": acq_fn_img,
        "y_pred_mean_img": y_pred_mean_img,
        "y_pred_var_img": y_pred_var_img,
        "embeddings": embeddings,
    }
    with open(Path(res_dir) / f'predictions_BO_step{step}.pkl', 'wb') as f:
        pickle.dump(predictions_data, f)


def _plot_al_trajectory(img, indices_all, seed_indices, acquired_in_order, res_dir):
    """Overlay seed points and BO-acquired points (coloured by acquisition step) on the image."""
    coords = np.asarray(indices_all)
    seed_coords = coords[np.asarray(seed_indices, dtype=int)]
    acquired_coords = coords[np.asarray(acquired_in_order, dtype=int)]

    plt.figure(figsize=(10, 8))
    plt.imshow(img, cmap="gray", origin="upper")

    # coords are (row, col); scatter wants (x=col, y=row).
    plt.scatter(seed_coords[:, 1], seed_coords[:, 0], c="b", label="Seed Points", marker="o")
    time_order = np.arange(len(acquired_coords))  # position in the acquisition sequence
    scatter = plt.scatter(acquired_coords[:, 1], acquired_coords[:, 0], c=time_order, cmap="bwr", label="Acquired Points", marker="x")

    plt.xlabel("X-axis")
    plt.ylabel("Y-axis")
    plt.title("Active Learning Trajectory")
    plt.legend()
    plt.grid(True)
    cbar = plt.colorbar(scatter)
    cbar.set_label("Steps")

    plt.savefig(Path(res_dir) / "AL_traj.png")
    plt.show()
    plt.close()


def _plot_learning_curve(mean_pred, mean_var, res_dir):
    """Plot the per-step mean prediction with a band of +/- sqrt(mean predictive variance)."""
    mean_pred = np.asarray(mean_pred, dtype=float)
    spread = np.sqrt(np.asarray(mean_var, dtype=float))
    steps = np.arange(len(mean_pred))

    plt.figure(figsize=(10, 6))
    plt.plot(steps, mean_pred, label="Mean Prediction", color="blue", linewidth=2)
    plt.fill_between(
        steps,
        mean_pred - spread,
        mean_pred + spread,
        color="blue",
        alpha=0.2,
        label="Variance (±1 std)"
    )
    plt.xlabel("Steps")
    plt.ylabel("Mean Prediction")
    plt.title("Learning Curve with Variance")
    plt.legend()
    plt.grid(True)
    plt.savefig(Path(res_dir) / "AL_learning_curve.png")
    plt.show()
    plt.close()


def run(config) -> None:
    """Run one BEACON novelty-search active-learning campaign on the microscope.

    Acquires a HAADF overview, cuts it into ``window_size`` patches, and measures
    ``seed_pts`` random patches with the EDX black box. It then runs ``budget``
    steps. Each step retrains the DKL model, scores (up to 1000) unacquired
    patches with ``BEACONAcquisitionFunction``, measures the best one, and saves
    plots and pickles to a timestamped folder under ``out_dir_parent``.

    Args:
        config: Dict with keys ip, port, seed, seed_pts, budget, out_dir_parent,
            dataset_name, device, num_epochs, normalize_data, window_size,
            scal_stem, haadf_exposure, haadf_resolution, edx_exposure.

    Raises:
        ValueError: If ``scal_stem`` is not a supported scalarizer (only "sum").
    """
    ip = config["ip"]
    port = config["port"]
    seed = config["seed"]
    seed_pts = config["seed_pts"]
    budget = config["budget"]
    out_dir_parent = config["out_dir_parent"]
    dataset_name = config["dataset_name"]
    device = config["device"]
    num_epochs = config["num_epochs"]
    normalize_data_flag = config["normalize_data"]
    window_size = config["window_size"]
    scal_stem = config["scal_stem"]
    haadf_exposure = config["haadf_exposure"]
    haadf_resolution = config["haadf_resolution"]
    edx_exposure = config["edx_exposure"]

    scalarizer_zero = False  # forwarded to train_model; TODO: find a better way to handle this

    # Fail fast: any other value would leave black_box_fn / e1a / e1b undefined.
    if scal_stem != "sum":
        raise ValueError(f"Unsupported scal_stem {scal_stem!r}; only 'sum' is implemented.")
    black_box_fn = get_eds_black_box
    # Energy window forwarded to the EDS black box. Units are not visible here
    # (presumably spectrum channels/eV) — TODO confirm against get_eds_black_box.
    e1a, e1b = 0, 20

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    res_dir = Path(out_dir_parent) / f"Dataset_seed{seed}_{dataset_name}_BO_{seed_pts}_epochs{num_epochs}_budget_{budget}_{scal_stem}_ws{window_size}_{timestamp}"
    res_dir.mkdir(parents=True, exist_ok=True)

    # Connect to the microscope server. The client is also published as the module
    # global `mic_server` (presumably read by helpers defined elsewhere — TODO pass explicitly).
    global mic_server
    microscope = TemMicroscopeClient()
    microscope.connect(ip, port=port)  # 7521 on velox computer
    mic_server = microscope
    tf_acquisition = TFacquisition(microscope=microscope)

    # HAADF overview image used as DKL input.
    haadf_np_array, haadf_tiff_name = tf_acquisition.acquire_haadf(exposure=haadf_exposure, resolution=haadf_resolution, folder_path=out_dir_parent)
    image_size = haadf_resolution
    # String concatenation: out_dir_parent is expected to end with a path separator.
    HAADF_tiff_to_png(out_dir_parent + haadf_tiff_name)

    img, features, indices_all = prepare_data_from_microscope(window_size=window_size, haadf=haadf_np_array)

    patches = numpy_to_torch_for_conv(features).to(device)
    if normalize_data_flag:
        patches = normalize_data(patches)

    feature_extractor = ConvNetFeatureExtractor(input_channels=1, output_dim=2).to(device)

    # Measure randomly chosen seed patches on the microscope.
    acquired_data = {}
    unacquired_indices = list(range(len(indices_all)))
    seed_indices = random.sample(unacquired_indices, seed_pts)
    eds_settings = configure_acquisition(exposure_time=edx_exposure)
    acquired_data, unacquired_indices = update_acquired(acquired_data, unacquired_indices, seed_indices, indices_all, e1a, e1b, eds_settings, image_size, black_box_fn=black_box_fn)

    max_candidates_for_beacon = 1000  # cap on candidates scored per step (compute budget)
    mean_y_pred_mean_al = []
    mean_y_pred_variance_al = []
    bo_acquired_indices = []  # BO selections in acquisition order (excludes seeds)

    for step in range(budget):
        model = train_model(acquired_data, patches, feature_extractor, device=device, num_epochs=num_epochs, scalarizer_zero=scalarizer_zero)
        model.eval()

        # Copy so that an in-place update of unacquired_indices cannot shift candidate positions.
        candidate_indices = list(unacquired_indices)
        if not candidate_indices:
            print("no unacquired points left; stopping early")
            break
        if len(candidate_indices) > max_candidates_for_beacon:
            print("random sample due to compute issues")
            candidate_indices = random.sample(candidate_indices, max_candidates_for_beacon)

        X_candidates = torch.stack([patches[idx] for idx in candidate_indices]).to(device)
        X_candidates = X_candidates.reshape(-1, 1, window_size * window_size)

        acq_values_candidates = None
        if len(acquired_data) > 0:
            beacon_acq_func = BEACONAcquisitionFunction(
                model=model,
                acquired_data=acquired_data,
                indices_all=indices_all,
                patches=patches,
                k=10,
                elite_fraction=0.2,  # top 20% as elite set
                normalize=True
            )
            acq_values_candidates = beacon_acq_func(X_candidates)
            best_idx = torch.argmax(acq_values_candidates).item()
        else:
            best_idx = torch.randint(0, len(candidate_indices), (1,)).item()

        selected_index = candidate_indices[best_idx]
        selected_indices = [selected_index]  # can hold several indices for batch acquisition

        # Snapshot acquisition values per candidate before the acquired/unacquired update.
        candidate_acq = None
        if acq_values_candidates is not None:
            acq_list = acq_values_candidates.detach().cpu().reshape(-1).tolist()
            candidate_acq = dict(zip(candidate_indices, acq_list))

        acquired_data, unacquired_indices = update_acquired(acquired_data, unacquired_indices, selected_indices, indices_all, e1a, e1b, eds_settings, image_size, black_box_fn=black_box_fn)
        bo_acquired_indices.append(selected_index)

        print(f"**************************done BO step {step +1}")

        predictions, embeddings = embeddings_and_predictions(model, patches, device=device)
        y_pred_mean = predictions.mean
        y_pred_var = predictions.variance

        if candidate_acq is None:
            print("no plotting as acq_values not triggered yet as len of acquired data not sufficient")
        else:
            print("plotting=========================")
            _plot_bo_step(img, indices_all, selected_index, candidate_acq, y_pred_mean, y_pred_var, embeddings, res_dir, step)
            mean_y_pred_mean_al.append(float(y_pred_mean.mean()))
            mean_y_pred_variance_al.append(float(y_pred_var.mean()))

    Active_learning_statistics = {
        "img": img,
        "features": features,
        "indices_all": np.array(indices_all),
        "seed_indices": np.array(seed_indices),
        "unacquired_indices": np.array(unacquired_indices),
        "acquired_indices": np.array(bo_acquired_indices),  # in acquisition order
        "mean_y_pred_mean_al": np.array(mean_y_pred_mean_al),
        "mean_y_pred_variance_al": np.array(mean_y_pred_variance_al),
    }
    with open(res_dir / 'Active_learning_statistics.pkl', 'wb') as f:
        pickle.dump(Active_learning_statistics, f)

    _plot_al_trajectory(img, indices_all, seed_indices, bo_acquired_indices, res_dir)
    _plot_learning_curve(mean_y_pred_mean_al, mean_y_pred_variance_al, res_dir)


### 3f. Set parameters and Run experiments

In [None]:
from stemOrchestrator.logging_config   import setup_logging
from datetime import datetime
import os
# Create a timestamped output folder for this experiment and send the logs there.
exp_name = "edx-haadf-beacon"
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
data_folder = f"./{exp_name}_{current_time}/"
out_path = data_folder  # keep the trailing "/": run() joins file names onto it by string concatenation
os.makedirs(data_folder, exist_ok=True)
setup_logging(out_path=out_path)

In [None]:
from stemOrchestrator.acquisition import TFacquisition, DMacquisition
from stemOrchestrator.simulation import DMtwin
from stemOrchestrator.process import HAADF_tiff_to_png, tiff_to_png
from autoscript_tem_microscope_client import TemMicroscopeClient
import matplotlib.pyplot as plt
import logging
plot = plt
from typing import Dict

In [None]:
import os
import json
from pathlib import Path

# Microscope address: environment variables take precedence. Any value missing there
# is filled from a local, untracked secrets file.
ip = os.getenv("MICROSCOPE_IP")
port = os.getenv("MICROSCOPE_PORT")

if not ip or not port:
    secret_path = Path("../../../config_secret.json")
    if secret_path.exists():
        with open(secret_path, "r", encoding="utf-8") as f:
            secret = json.load(f)
        ip = ip or secret.get("ip_TF")
        port = port or secret.get("port_TF")

# Fail here rather than attempting to connect to `None`.
if not ip or not port:
    raise RuntimeError(
        "Microscope address not configured: set MICROSCOPE_IP/MICROSCOPE_PORT "
        "or provide ip_TF/port_TF in ../../../config_secret.json"
    )
# Environment variables are strings; normalize to an int port (e.g. 7521).
port = int(port)
print(ip, port)

config = {
        "ip": ip,
        "port": port,
        "haadf_exposure": 40e-8,  # micro-seconds per pixel
        "haadf_resolution": 512,  # square image, pixels per side
        "edx_exposure": 3e-3,  # seconds
        "seed": 5,  # RNG seed for torch / numpy / random
        "seed_pts": 5,  # number of random points measured before BO starts
        "budget": 5,  # number of BO steps (one new measurement each)
        "out_dir_parent": out_path,  # recommended: leave as is
        "dataset_name": "live_mic",  # label used in the results-folder name
        "device": "cuda",
        "num_epochs": 1,  # DKL training epochs per BO step; may need tuning for the data
        "normalize_data": True,
        "window_size": 16,  # side length of the square patches (structure-property relationship)
        "scal_stem": "sum"  # scalarizer for the EDX spectrum; run() only implements "sum"
        }
run(config)