[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/utkarshp1161/Active-learning-in-microscopy/blob/main/notebooks/Tuning-aberrations-STEM.ipynb)

# Goal: Optimizing Electron Microscopy Image Quality through MOBO-Driven Aberration Tuning
- Note: If you are not familiar with aberrations or electron microscopes, that's completely fine.
    - Treat this as quickly optimizing a handful of parameters to get a good signal (in this case, an image).
    - The skills you learn here carry over to your own optimization problems.
    - At the end, I have included some suggested exercises to try.

## Recommended: Use a GPU runtime in Colab

## Credits 
- Utkarsh and Sergei for this notebook and framing the optimization problem
- Gerd and Austin for the simulator 
    - Check out Austin's original GitHub repo here:
        - https://github.com/AustinHouston/pystemsim


### 1. Install dependencies

In [None]:
# Install the Bayesian-optimization stack (BoTorch/GPyTorch pinned to the versions this notebook was written for)
# and the microscopy tools: abTEM (used for orthogonalize_cell) and pyTEMlib (probe/aberration tools).
!pip install botorch==0.12.0
!pip install gpytorch==1.13
!pip install abtem
!pip install pytemlib

### 2. Simulator helper functions

In [None]:
# Description: This file contains the functions to generate synthetic data for the neural network training.
# By Austin Houston
# Date: 02/28/2024
# Updated: 05/10/2024

import dask
import numpy as np
import random
import sidpy
import dask.array as da
import scipy.special as sp
from scipy.ndimage import zoom, gaussian_filter
from skimage.draw import disk
from ase import Atoms
from ase.neighborlist import NeighborList
from scipy.fft import fft2, ifft2
import pyTEMlib.probe_tools as pt


def make_holes(atoms: Atoms, n_holes: int, hole_size: float) -> Atoms:
    """
    Create holes in an Atoms object by deleting atoms around randomly selected positions.

    All hole centres are drawn from the ORIGINAL structure. Every atom found as a
    neighbour of a centre, plus the centre itself, is collected first and then
    removed in one deletion at the end. Deleting hole by hole would renumber the
    atoms, so later centre indices would point at the wrong atoms or past the
    end of the structure.

    Parameters:
    - atoms (ase.Atoms): The input Atoms object (not modified).
    - n_holes (int): The number of holes to create.
    - hole_size (float): The radius of each hole, in the units of the atom positions.
      NOTE(review): ase's NeighborList adds its default ``skin`` to the cutoffs,
      so the effective radius may be somewhat larger than ``hole_size`` - confirm.

    Returns:
    - ase.Atoms: A new Atoms object with the holes removed (``atoms`` itself when
      ``n_holes`` is 0).

    Raises:
    - ValueError: if ``n_holes`` is negative or exceeds the number of atoms
      (raised by ``random.sample``).
    """
    num_atoms = len(atoms)
    selected_indices = random.sample(range(num_atoms), n_holes)
    if not selected_indices:
        return atoms

    # Build the neighbour list once, on the unmodified structure.
    # With per-atom cutoffs of hole_size / 2, two atoms are neighbours when their
    # distance is below the sum of cutoffs, i.e. hole_size.
    cutoffs = [hole_size / 2] * num_atoms
    nl = NeighborList(cutoffs, self_interaction=False, bothways=True)
    nl.update(atoms)

    to_delete = set(selected_indices)
    for index in selected_indices:
        neighbors, _offsets = nl.get_neighbors(index)
        to_delete.update(int(i) for i in neighbors)

    keep = [i for i in range(num_atoms) if i not in to_delete]
    return atoms[keep]

def rotate_xtal(xtal, angle):
    """
    Rotate a crystal about the z axis and crop it back to its original cell.

    The structure is first tiled 2x2 in-plane so that, after rotating about its
    centre of mass, the original cell window is still filled with atoms.

    Parameters:
    - xtal (ase.Atoms): structure whose cell defines the output window.
    - angle (float): rotation angle passed to ``Atoms.rotate`` (presumably
      degrees, as ase expects).

    Returns:
    - ase.Atoms: the rotated atoms that fall inside the original in-plane cell,
      with that cell set and scaled positions expressed in it.
    """
    tiled = xtal * (2, 2, 1)
    tiled.rotate('z', angle, 'com')

    original_cell = xtal.cell
    xy = tiled.get_positions()[:, :2]
    # In-plane fractional coordinates w.r.t. the original cell, shifted by 0.5
    # so the kept window [0, 1) sits in the middle of the 2x2 tiling.
    frac_xy = xy @ np.linalg.inv(original_cell[:2, :2]) - 0.5
    inside = ((frac_xy >= 0) & (frac_xy < 1)).all(axis=1)

    z_scaled = tiled.get_scaled_positions()[inside, 2:3]
    result = tiled[inside].copy()
    result.set_cell(original_cell, scale_atoms=False)
    result.set_scaled_positions(np.hstack([frac_xy[inside], z_scaled]))
    return result

def sub_pix_gaussian(size=10, sigma=0.2, dx=0.0, dy=0.0):
    """
    Return a ``size`` x ``size`` Gaussian stamp with a sub-pixel shift, peak-normalized to 1.

    The grid is centred at ``(size - 1) / 2``. The Gaussian peak sits at column
    offset ``-dx`` and row offset ``-dy`` from that centre (all in pixels).
    """
    half = (size - 1) / 2.0
    axis = np.arange(size) - half
    cols, rows = np.meshgrid(axis, axis)
    r2 = (cols + dx) ** 2 + (rows + dy) ** 2
    stamp = np.exp(-(r2 / (2 * sigma**2)))
    return stamp / stamp.max()

def create_pseudo_potential(xtal, pixel_size, sigma, bounds, atom_frame=11):
    """
    Rasterize a crystal into a normalized pseudo scattering-potential image.

    Every atom whose projected (x, y) position lies inside ``bounds`` is drawn as
    a Gaussian stamp (``sub_pix_gaussian``) scaled by its atomic number. The
    stamp is centred on the atom's position with sub-pixel accuracy.

    Parameters:
    - xtal (ase.Atoms): structure to project; only x and y are used.
    - pixel_size (float): pixel size in Angstrom.
    - sigma (float): Gaussian width in pixels.
    - bounds (sequence): (x_min, x_max, y_min, y_max) in Angstrom.
    - atom_frame (int): side length in pixels of each atom's stamp. Odd values
      put the stamp centre exactly on the atom's nearest pixel.

    Returns:
    - sidpy.Dataset: image of shape (pixels_x, pixels_y), with axis 0 = x and
      axis 1 = y. It is divided by its maximum, or left all zeros if no atom
      falls inside ``bounds``.
    """
    x_min, x_max = bounds[0], bounds[1]
    y_min, y_max = bounds[2], bounds[3]
    pixels_x = int((x_max - x_min) / pixel_size)
    pixels_y = int((y_max - y_min) / pixel_size)

    # Padded canvas so stamps of atoms near the edges never fall off it; cropped below.
    half = atom_frame // 2
    padding = atom_frame + 1
    potential_map = np.zeros((pixels_x + 2 * padding, pixels_y + 2 * padding))

    # Map of atomic numbers - i.e. scattering intensity
    atomic_numbers = xtal.get_atomic_numbers()
    positions = xtal.get_positions()[:, :2]
    mask = ((positions[:, 0] >= x_min) & (positions[:, 0] < x_max)
            & (positions[:, 1] >= y_min) & (positions[:, 1] < y_max))
    positions = positions[mask]
    atomic_numbers = atomic_numbers[mask]

    origin = np.array([x_min, y_min], dtype=float)
    for pos, atomic_number in zip(positions, atomic_numbers):
        # Continuous pixel coordinates measured from the image origin (x_min, y_min).
        pix = (pos - origin) / pixel_size
        nearest = np.round(pix)
        ix, iy = nearest.astype(int)
        fx, fy = pix - nearest  # sub-pixel offsets, in pixels, within [-0.5, 0.5]
        # sub_pix_gaussian puts its peak at row offset -dy and column offset -dx.
        # Rows of this map are x, so pass the offsets swapped and negated to land
        # the peak at (+fx, +fy).
        single_atom = sub_pix_gaussian(size=atom_frame, sigma=sigma, dx=-fy, dy=-fx) * atomic_number
        row0 = ix + padding - half
        col0 = iy + padding - half
        potential_map[row0:row0 + atom_frame, col0:col0 + atom_frame] += single_atom

    potential_map = potential_map[padding:padding + pixels_x, padding:padding + pixels_y]
    peak = potential_map.max() if potential_map.size else 0.0
    normalized_map = potential_map / peak if peak > 0 else potential_map

    # make a sidpy dataset
    dset = sidpy.Dataset.from_array(normalized_map, name = 'Scattering Potential')
    dset.data_type = 'image'
    dset.units = 'A.U.'
    dset.quantity = 'Scattering cross-section'
    dset.set_dimension(0, sidpy.Dimension(x_min + pixel_size * np.arange(pixels_x),
                        name='x', units='Å', quantity='Length',dimension_type='spatial'))
    dset.set_dimension(1, sidpy.Dimension(y_min + pixel_size * np.arange(pixels_y),
                        name='y', units='Å', quantity='Length',dimension_type='spatial'))

    return dset


def get_masks(xtal, pixel_size=0.1, radius=3, axis_extent=None, mode='one_hot'):
    """
    Build per-element segmentation masks by drawing a disk around each atom.

    Parameters:
    - xtal (ase.Atoms): structure; only x/y positions and atomic numbers are used.
    - pixel_size (float): size of one mask pixel, in the units of the positions.
    - radius (float): disk radius in pixels.
    - axis_extent (tuple | None): (xmin, xmax, ymin, ymax) of the image. Defaults
      to the bounding box of the atom positions.
    - mode (str, case-insensitive):
        'one_hot' -> (n_elements + 1, H, W) uint8 stack, channel 0 = background;
        'binary'  -> (H, W) array, 1 where any atom disk is, else 0;
        'integer' -> (H, W) uint8 labels, 0 = background, k = k-th element in
                     ascending atomic-number order.

    Returns:
    - numpy.ndarray as described for ``mode`` (rows = y, columns = x).

    Raises:
    - ValueError: if ``mode`` is not one of the supported values.
    """
    mode_key = mode.lower()
    if mode_key not in ('one_hot', 'binary', 'integer'):
        raise ValueError("Invalid mode. Choose from 'one_hot', 'binary', or 'integer'")

    positions = xtal.get_positions()[:, :2]
    atomic_numbers = xtal.get_atomic_numbers()
    # Label elements 1..K in ascending atomic-number order; 0 is reserved for background.
    _, inverse_indices = np.unique(atomic_numbers, return_inverse=True)
    atom_ids = inverse_indices + 1
    unique_atom_ids = np.unique(atom_ids)

    # Determine image size
    if axis_extent is not None:
        xmin, xmax, ymin, ymax = axis_extent
    else:
        xmin, xmax = np.min(positions[:, 0]), np.max(positions[:, 0])
        ymin, ymax = np.min(positions[:, 1]), np.max(positions[:, 1])
    img_height = int((ymax - ymin) / pixel_size)
    img_width = int((xmax - xmin) / pixel_size)

    master_mask = np.zeros((len(unique_atom_ids), img_height, img_width), dtype=np.uint8)

    def create_mask_for_atom(atom_id):
        # Each task writes only its own channel, so the parallel tasks never overlap.
        channel = master_mask[atom_id - 1]
        for x, y in positions[atom_ids == atom_id]:
            x_pixel = int((x - xmin) / pixel_size)
            y_pixel = int((y - ymin) / pixel_size)
            # rows are y, columns are x; shape= clips the disk to the image
            rr, cc = disk((y_pixel, x_pixel), radius, shape=channel.shape)
            channel[rr, cc] = 1

    # Parallelize the mask creation (one task per element)
    tasks = [dask.delayed(create_mask_for_atom)(atom_id) for atom_id in unique_atom_ids]
    dask.compute(*tasks)

    if mode_key == 'one_hot':
        background_mask = (np.sum(master_mask, axis=0) == 0).astype(np.uint8)
        return np.concatenate([background_mask[None], master_mask], axis=0)

    if mode_key == 'binary':
        return np.where(np.sum(master_mask, axis=0) > 0, 1, 0)

    # 'integer': where disks of different elements overlap, the later
    # (higher atomic number) label wins.
    final_mask = np.zeros((img_height, img_width), dtype=np.uint8)
    for i, channel in enumerate(master_mask):
        final_mask[channel == 1] = i + 1
    return final_mask


def airy_disk(potential, resolution=1.1, scale=2.5):
    """
    Build a normalized Airy-pattern PSF on the same grid as ``potential``.

    Parameters:
    - potential (sidpy.Dataset): 2D image. Its shape and ``x.slope`` (Angstrom per
      pixel) define the PSF grid.
    - resolution (float): nominal resolution in Angstrom; larger values widen the PSF.
    - scale (float): empirical factor multiplying ``pixel_size / resolution`` to give
      the Bessel argument per pixel of radius. The default 2.5 reproduces the
      original behaviour; the original author noted it was kept because "it
      works", not derived from the Airy first-zero condition.

    Returns:
    - sidpy.Dataset: PSF with the same shape as ``potential``, summing to 1. Its
      peak is at index (size_x // 2, size_y // 2), which is the element that
      ``ifftshift`` moves to the origin before FFT convolution.
    """
    size_x = potential.shape[0]
    size_y = potential.shape[1]
    # Centre at n // 2 to match the fftshift/ifftshift convention.
    x = np.arange(size_x) - size_x // 2
    y = np.arange(size_y) - size_y // 2
    # 'ij' indexing keeps axis 0 = x, so the PSF has potential's shape even when not square.
    xx, yy = np.meshgrid(x, y, indexing='ij')
    rr = np.sqrt(xx**2 + yy**2)

    pixel_size = potential.x.slope # Angstrom/pixel
    k = pixel_size / resolution * scale  # Bessel argument per pixel of radius

    # Calculate the Airy pattern (PSF)
    with np.errstate(divide='ignore', invalid='ignore'):
        psf = (2 * sp.j1(k * rr) / (k * rr))**2
        psf[rr == 0] = 1  # limit of (2 J1(u) / u)^2 as u -> 0

    # Normalize the PSF
    psf /= np.sum(psf)

    dset = sidpy.Dataset.from_array(psf, name = 'Probe PSF')
    dset.data_type = 'image'
    dset.units = 'A.U.'
    dset.quantity = 'Probability'
    dset.set_dimension(0, sidpy.Dimension(pixel_size * np.arange(size_x),
                        name='x', units='Å', quantity='Length',dimension_type='spatial'))
    dset.set_dimension(1, sidpy.Dimension(pixel_size * np.arange(size_y),
                        name='y', units='Å', quantity='Length',dimension_type='spatial'))

    return dset

def get_probe(ab, potential, verbose=True):
    """
    Compute the probe for aberration settings ``ab`` on the grid of ``potential``.

    Parameters:
    - ab (dict): aberration/optics dictionary passed to ``pyTEMlib.probe_tools.get_probe``.
    - potential (sidpy.Dataset): supplies the grid size (its shape) and the pixel
      size (``x.slope``, Angstrom/pixel) used for the output axes.
    - verbose (bool): forwarded to ``pt.get_probe``. Pass False to silence its
      output, e.g. inside optimization loops. Defaults to True, as before.

    Returns:
    - sidpy.Dataset: the probe array from ``pt.get_probe``, with spatial x/y axes.
    """
    pixel_size = potential.x.slope # Angstrom/pixel
    size_x, size_y = potential.shape

    # pt.get_probe returns three values; only the first (the probe) is used here.
    probe, _A_k, _chi = pt.get_probe(ab, size_x, size_y, scale='mrad', verbose=verbose)

    dset = sidpy.Dataset.from_array(probe, name = 'Probe PSF')
    dset.data_type = 'image'
    dset.units = 'A.U.'
    dset.quantity = 'Probability'
    dset.set_dimension(0, sidpy.Dimension(pixel_size * np.arange(size_x),
                        name='x', units='Å', quantity='Length',dimension_type='spatial'))
    dset.set_dimension(1, sidpy.Dimension(pixel_size * np.arange(size_y),
                        name='y', units='Å', quantity='Length',dimension_type='spatial'))

    return dset


def convolve_kernel(potential, psf):
    """
    Convolve ``potential`` with ``psf`` via FFT and rescale the result to [0, 1].

    ``ifftshift`` moves the element at index n // 2 of ``psf`` to the origin, so
    ``psf`` is presumably centred there (TODO confirm for probes from
    ``pt.get_probe``). The product of FFTs makes this a circular (periodic)
    convolution.

    Parameters:
    - potential (sidpy.Dataset): 2D image; its metadata is copied to the result.
    - psf (array-like / sidpy.Dataset): kernel with the same shape as ``potential``.

    Returns:
    - sidpy.Dataset: |potential * psf| shifted to a minimum of 0 and divided by its
      maximum. NOTE(review): a constant result would give 0/0 here.
    """
    psf_shifted = da.fft.ifftshift(psf)
    image = da.fft.ifft2(da.fft.fft2(potential) * da.fft.fft2(psf_shifted))
    image = da.absolute(image)
    image = image - image.min()
    image = image / image.max()

    dset = potential.like_data(image)
    dset.units = 'A.U.'
    dset.quantity = 'Intensity'

    return dset


def poisson_noise(image, counts = 10e8):
    """
    Apply Poisson (shot) noise to ``image`` at a total dose of ``counts``.

    The image is shifted to a minimum of 0 and scaled to sum to 1, so the expected
    total number of counts over the whole image is ``counts``. The noisy result
    is then shifted and scaled to [0, 1]. Uses NumPy's global RNG.

    Note: the default ``10e8`` equals 1e9 counts.

    Parameters:
    - image (sidpy.Dataset): input image; its ``like_data`` provides the metadata
      of the result.
    - counts (float): total expected dose.

    Returns:
    - sidpy.Dataset with values in [0, 1]. It is all zeros if the input is flat
      (no signal to distribute) or if the noisy draw is constant.
    """
    arr = np.asarray(image, dtype=float)
    arr = arr - arr.min()
    total = arr.sum()
    if total > 0:
        noisy = np.random.poisson(arr / total * counts).astype(float)
    else:
        # A flat image would give NaN Poisson rates (0 / 0); there is no signal to noise up.
        noisy = np.zeros_like(arr)

    noisy = noisy - noisy.min()
    peak = noisy.max()
    if peak > 0:
        noisy = noisy / peak
    return image.like_data(noisy)


def lowfreq_noise(image, noise_level=0.1, freq_scale=0.1):
    """
    Add smooth, low-spatial-frequency Gaussian noise to ``image``.

    White Gaussian noise (std ``noise_level``, drawn from NumPy's global RNG) is
    low-pass filtered in Fourier space by a separable Gaussian of width
    ``freq_scale`` (in cycles per pixel). The real part of the result is added
    to the image.

    Returns:
    - a dataset built with ``image.like_data`` holding ``image + noise``.
    """
    rows, cols = image.shape
    white = np.random.normal(0, noise_level, (rows, cols))

    # Separable Gaussian low-pass gain over the FFT frequency grid.
    two_s2 = 2 * freq_scale**2
    row_gain = np.exp(-np.square(np.fft.fftfreq(rows)) / two_s2)
    col_gain = np.exp(-np.square(np.fft.fftfreq(cols)) / two_s2)
    smooth = np.fft.ifft2(np.fft.fft2(white) * np.outer(row_gain, col_gain)).real

    return image.like_data(image + smooth)


def grid_crop(image_master, crop_size=512, crop_glide=128):
    '''
    Slices an image into smaller, overlapping square crops on a regular grid.

    Useful for processing large images in fixed-size pieces, e.g. for models with
    a fixed input size.

    Parameters:
    - image_master: array-like image of shape (H, W) or (H, W, C). H and W may differ.
    - crop_size (int, optional): side length of each square crop. Default is 512 pixels.
    - crop_glide (int, optional): stride between neighbouring crops. Crops overlap
                                  when it is smaller than ``crop_size``. Default is 128 pixels.

    Returns:
    - cropped_ims: float16 array of shape (n_crops, crop_size, crop_size[, C]),
                   ordered row-major (all crops of the first row band first).
                   Empty (n_crops == 0) when the image is smaller than one crop.

    Note:
    - Crops that would extend past the image edge are not produced, so up to
      ``crop_glide - 1`` trailing rows/columns may be left uncovered.
    - The float16 output reduces memory usage but also the precision of pixel values.
    '''
    image_master = np.asarray(image_master)
    height, width = image_master.shape[:2]
    trailing = image_master.shape[2:]  # e.g. (C,) for multi-channel images

    # Number of crop origins that fit along each axis; never negative.
    n_rows = max(0, int((height - crop_size) // crop_glide) + 1)
    n_cols = max(0, int((width - crop_size) // crop_glide) + 1)

    cropped_ims = np.zeros((n_rows, n_cols, crop_size, crop_size) + trailing)
    for row in range(n_rows):
        for col in range(n_cols):
            r0, c0 = int(row * crop_glide), int(col * crop_glide)
            cropped_ims[row, col] = image_master[r0:r0 + crop_size, c0:c0 + crop_size]

    return cropped_ims.reshape((-1, crop_size, crop_size) + trailing).astype('float16')


def resize_image(array, n, order = 3):
    """
    Rescale the last two axes of an array so that its longer spatial side becomes ``n``.

    Both spatial axes are zoomed by the same factor ``n / max(height, width)``:
    square inputs become ``n x n``, and non-square inputs keep their aspect ratio.
    Leading axes (channels, batch, ...) are left untouched.

    Parameters:
    array (numpy.ndarray): input with at least 2 dimensions; the last two are spatial.
    n (int): target size of the longer spatial side.
    order (int): spline order for ``scipy.ndimage.zoom`` (0 = nearest neighbour,
        suitable for label masks; 3 = cubic).

    Returns:
    numpy.ndarray: the resized array, as float32.

    Raises:
    ValueError: if ``array`` has fewer than 2 dimensions.
    """
    array = np.asarray(array)
    if array.ndim < 2:
        raise ValueError(f"resize_image expects at least 2 dimensions, got shape {array.shape}")

    height, width = array.shape[-2:]
    zoom_factor = n / max(height, width)
    factors = [1] * (array.ndim - 2) + [zoom_factor, zoom_factor]
    return zoom(array.astype(np.float32), factors, order=order)


def shotgun_crop(image, crop_size=512, magnification_var = None, n_crops=10, seed=42, return_binary = False, roi = 'middle'):
    """
    Randomly crop ``n_crops`` square patches (optionally with varying magnification) and resize each to ``crop_size``.

    Parameters:
    image (numpy.ndarray): input; the last two axes are spatial. Any leading axes
        (e.g. channels) are kept in every crop.
    crop_size (int, optional): output side length of each crop. Defaults to 512.
    magnification_var (float, optional): if given, each crop's side length before
        resizing is drawn uniformly from the integers in
        [int(crop_size * (1 - magnification_var)), int(crop_size * (1 + magnification_var))).
        Defaults to None (all crops are ``crop_size``).
    n_crops (int, optional): number of crops to generate. Defaults to 10.
    seed (int, optional): seed for ``numpy.random.default_rng``. The same seed
        reproduces the same crop sizes and positions.
    return_binary (bool, optional): if True, resize with nearest-neighbour
        interpolation (order 0) so label/mask values are preserved; otherwise
        cubic (order 3).
    roi (str, optional): 'middle' keeps every crop at least ``crop_size // 4``
        pixels away from all image edges. Any other value allows crops anywhere.

    Returns:
    numpy.ndarray: stacked crops of shape (n_crops, [leading axes,] crop_size, crop_size).

    Raises:
    ValueError: if a crop (plus the 'middle' margins) does not fit in the image.

    Important:
    If using this function on an image and mask together, make sure to use the same seed for both.
    """
    order = 0 if return_binary else 3

    # Dedicated generator: reproducible and independent of NumPy's global RNG.
    crop_rng = np.random.default_rng(seed)

    # Crop sizes before resizing; varying them emulates changing magnification.
    if magnification_var is not None:
        low = int(crop_size * (1 - magnification_var))
        high = int(crop_size * (1 + magnification_var))
        crop_sizes = crop_rng.integers(low, high, n_crops).astype(int)
    else:
        crop_sizes = np.full(n_crops, crop_size)

    h, w = image.shape[-2:]
    edge_cutoff = crop_size // 4 if roi == 'middle' else 0
    crops = []
    for size in crop_sizes:
        if h - size - 2 * edge_cutoff <= 0 or w - size - 2 * edge_cutoff <= 0:
            raise ValueError(
                f"Crop of size {size} with edge margin {edge_cutoff} does not fit in an image of size {h}x{w}"
            )
        top = crop_rng.integers(edge_cutoff, h - size - edge_cutoff)
        left = crop_rng.integers(edge_cutoff, w - size - edge_cutoff)
        crop = image[..., top:top + size, left:left + size]
        crops.append(resize_image(crop, crop_size, order))

    return np.stack(crops, axis=0)


### 3. Introduce simulator

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# %matplotlib ipympl

from ase.io import read

from abtem.atoms import orthogonalize_cell
import pyTEMlib.probe_tools as pt



In [None]:
# download cif --> crystal structure of tungsten disulfide (WS2)
!wget https://raw.githubusercontent.com/AustinHouston/pystemsim/main/crystal_files/WS2.cif


In [None]:
# Scattering potential
xtal = read('WS2.cif')
# Convert to an orthogonal (rectangular) cell so it can be tiled into a rectangular field of view.
xtal, transform = orthogonalize_cell(xtal, allow_transform=True, return_transform=True)
# Tile 30 x 20 x 1; presumably large enough to cover the 96 Å field of view - confirm for other crystals.
xtal = xtal * (30, 20, 1)
# NOTE(review): `positions` is not used by the visible cells below.
positions = xtal.get_positions()[:, :2]
pixel_size = 0.106 # angstrom/pixel
fov = 96 # angstroms
frame = (0,fov,0,fov) # limits of the image in angstroms
# Image is int(96 / 0.106) = 905 pixels per side.
potential = create_pseudo_potential(xtal, pixel_size, sigma=1, bounds=frame, atom_frame=11)

# Probe
# Start from pyTEMlib's target aberrations for a "Spectra300" at 60 kV, then override the optics below.
ab = pt.get_target_aberrations("Spectra300", 60000)
ab['acceleration_voltage'] = 60e3 # eV
ab['FOV'] = fov /12 # Angstroms
ab['convergence_angle'] = 30 # mrad
ab['wavelength'] = pt.get_wavelength(ab['acceleration_voltage'])
ab['C10'] = 1  # defocus; set to 0 in the next cell
# Zero both components of the C23 aberration so only the coefficients tuned below matter.
ab['C23a'] = 0
ab['C23b'] = 0
pt.print_aberrations(ab)

In [None]:
# change ab here
ab['C10'] = 0 # defocus
ab['C12a'] = 0 # twofold astigmatism (a)
ab['C12b'] = 0 # twofold astigmatism (b)
# Simulation pipeline: probe from `ab` -> blur the potential with the probe
# -> add smooth low-frequency background -> add Poisson shot noise (1e7 counts).
probe = get_probe(ab, potential)
image = convolve_kernel(potential, probe)
noisy_image = lowfreq_noise(image, noise_level=0.5, freq_scale=.04)
sim_im = poisson_noise(noisy_image, counts=1e7)

view = sim_im.plot()

In [None]:
def contrast_rms(im, eps=1e-12):
    """
    Return the RMS contrast of an image: std(im) / mean(im).

    ``eps`` is added to the mean so a zero-mean image gives a large finite value
    instead of a division by zero. This matches the eps-guarded definition used
    later in the notebook.
    """
    return np.std(im) / (np.mean(im) + eps)


In [None]:
# How contrast varies as we change defocus - c10
# Sweep C10 over 21 evenly spaced values in [-8, 8], simulate one noisy image per
# value, and plot its RMS contrast. This sweep motivated the C10 range used for BO below.
param_range = 8
params = np.linspace(-param_range, param_range, 21)

rms_contrasts = []
images = []
for defocus in params:
    ab['C10'] = defocus
    probe = get_probe(ab, potential)
    image = convolve_kernel(potential, probe)
    noisy_image = lowfreq_noise(image, noise_level=0.5, freq_scale=.04)
    sim_im = poisson_noise(noisy_image, counts=1e7)
    rms_contrasts.append(contrast_rms(np.array(sim_im)))
    images.append(sim_im)
# Restore zero defocus so later cells start from a focused probe.
ab['C10'] = 0
plt.figure()
plt.plot(params, rms_contrasts)

### 4. Let's do multi-objective GP optimization on C1 and A1 (3 parameters in total: C1 = defocus C10, and A1 = twofold astigmatism with components C12a and C12b)

#### 4a. Define reward functions, set parameter ranges, and collect seed data

In [None]:
# ========== SETUP STEM SIMULATOR ==========
# Rebuilds the same potential and probe settings as section 3, so this section
# does not depend on edits made there. Overwrites the globals xtal, potential, ab, ...
print("Setting up STEM simulator...")
xtal = read('WS2.cif')
from abtem.atoms import orthogonalize_cell
xtal, transform = orthogonalize_cell(xtal, allow_transform=True, return_transform=True)
xtal = xtal * (30, 20, 1)

pixel_size = 0.106  # Angstrom / pixel
fov = 96  # Angstrom
frame = (0, fov, 0, fov)
potential = create_pseudo_potential(xtal, pixel_size, sigma=1, bounds=frame, atom_frame=11)

# Setup probe aberrations
ab = pt.get_target_aberrations("Spectra300", 60000)
ab['acceleration_voltage'] = 60e3
ab['FOV'] = fov / 12
ab['convergence_angle'] = 30
ab['wavelength'] = pt.get_wavelength(ab['acceleration_voltage'])
# C10 is overwritten by every call of get_stem_image_contrast_and_fft.
ab['C10'] = 1
ab['C23a'] = 0
ab['C23b'] = 0


def contrast_rms(im, eps=1e-12):
    """Return the RMS contrast std(im) / mean(im); ``eps`` keeps a zero mean from dividing by zero."""
    spread = np.std(im)
    return spread / (np.mean(im) + eps)

def fft_snr_generic(im, kmin_frac=0.3, eps=1e-12, bg_band=(0.05, 0.15)):
    """
    Ratio of mean high-frequency power to mean power in a low-frequency background ring.

    The image is Hann-windowed to suppress edge leakage, and its power spectrum
    is normalized to unit sum. Radii are measured in pixels from the spectrum
    centre, as fractions of the largest radius (the corner).

    Parameters:
    im (2D array): image.
    kmin_frac (float): radii >= kmin_frac * r_max count as high frequencies.
    eps (float): guards against division by zero.
    bg_band (tuple): (lo, hi) fractions of r_max defining the background ring
        lo * r_max <= r < hi * r_max. Defaults to (0.05, 0.15).

    Returns:
    float: the ratio, or NaN if either region contains no pixels (very small images).
    """
    h, w = im.shape
    wy = np.hanning(h)[:, None]
    wx = np.hanning(w)[None, :]
    imw = im * wy * wx

    F = np.fft.fftshift(np.fft.fft2(imw))
    P = (np.abs(F)**2).astype(np.float64)
    P /= (P.sum() + eps)  # dose/scale invariance: overall scale cancels in the ratio (up to eps)

    yy, xx = np.mgrid[0:h, 0:w]
    cy, cx = h//2, w//2
    rr = np.hypot(yy - cy, xx - cx)
    rmax = rr.max()
    high = rr >= (kmin_frac * rmax)
    bg_lo, bg_hi = bg_band
    low = (rr >= bg_lo * rmax) & (rr < bg_hi * rmax)  # background ring
    if not high.any() or not low.any():
        return float('nan')
    return (P[high].mean()) / (P[low].mean() + eps)




import matplotlib.pyplot as plt
from scipy import ndimage

def _fft_peak_analysis(sim_array, dc_halfwidth=10, min_distance=15, smooth_sigma=2, percentile=99.5):
    """Find bright spots in the log power spectrum of ``sim_array`` and score the farthest one.

    Parameters:
    - sim_array (2D ndarray): image, already standardized by the caller.
    - dc_halfwidth (int): half-width in pixels of the square around the spectrum
      centre that is zeroed to suppress the DC / low-frequency blob.
    - min_distance (float): spots whose centroid is this close (pixels) to the centre are ignored.
    - smooth_sigma (float): Gaussian smoothing of the log spectrum before thresholding.
    - percentile (float): smoothed-spectrum pixels above this percentile count as spots.

    Returns:
    dict with 'fft_score' (farthest spot radius / min(centre), 0.0 if none),
    'max_radius' (pixels, 0 if none), and the intermediate arrays used for plotting.
    """
    power = np.abs(np.fft.fftshift(np.fft.fft2(sim_array)))**2
    power_log = np.log1p(power)

    center = np.array(power.shape) // 2
    cy, cx = center
    power_log[cy - dc_halfwidth:cy + dc_halfwidth, cx - dc_halfwidth:cx + dc_halfwidth] = 0

    smoothed = ndimage.gaussian_filter(power_log, sigma=smooth_sigma)
    threshold = np.percentile(smoothed, percentile)
    peaks = smoothed > threshold
    labeled, num_peaks = ndimage.label(peaks)

    # Centroid (x, y) and distance from the centre of every spot outside the exclusion radius.
    far_peaks = []
    for i in range(1, num_peaks + 1):
        rows, cols = np.where(labeled == i)
        peak_y, peak_x = np.mean(rows), np.mean(cols)
        distance = np.sqrt((peak_x - cx)**2 + (peak_y - cy)**2)
        if distance > min_distance:
            far_peaks.append((peak_x, peak_y, distance))

    if far_peaks:
        max_radius = max(p[2] for p in far_peaks)
        fft_score = float(max_radius / min(center))
    else:
        max_radius = 0
        fft_score = 0.0

    return {
        'fft_score': fft_score,
        'max_radius': max_radius,
        'far_peaks': far_peaks,
        'num_peaks': num_peaks,
        'power_log': power_log,
        'smoothed': smoothed,
        'peaks': peaks,
        'threshold': threshold,
        'center': center,
    }


def _plot_fft_diagnostics(sim_im, analysis):
    """Show a 2x3 diagnostic figure for the FFT peak analysis of one simulated image."""
    power_log = analysis['power_log']
    smoothed = analysis['smoothed']
    center = analysis['center']
    max_radius = analysis['max_radius']
    threshold = analysis['threshold']

    _fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    axes[0,0].imshow(sim_im, cmap='gray')
    axes[0,0].set_title('Original Image')

    axes[0,1].imshow(power_log, cmap='hot')
    axes[0,1].set_title('FFT Power (log)')

    axes[0,2].imshow(smoothed, cmap='hot')
    axes[0,2].set_title('Smoothed FFT')

    axes[1,0].imshow(analysis['peaks'], cmap='gray')
    axes[1,0].set_title(f'Detected Peaks ({analysis["num_peaks"]})')

    axes[1,1].imshow(power_log, cmap='hot')
    for peak_x, peak_y, distance in analysis['far_peaks']:
        axes[1,1].plot(peak_x, peak_y, 'rx', markersize=10)
        if distance == max_radius:
            axes[1,1].plot(peak_x, peak_y, 'go', markersize=15, fillstyle='none', linewidth=2)
    axes[1,1].plot(center[1], center[0], 'b+', markersize=20)
    axes[1,1].set_title('Peaks Marked (Green=Farthest)')

    # Maximum of the smoothed spectrum in 1-pixel-wide rings around the centre.
    y, x = np.ogrid[:smoothed.shape[0], :smoothed.shape[1]]
    r = np.sqrt((x - center[1])**2 + (y - center[0])**2)
    radial_profile = []
    for rad in range(0, int(min(center))):
        ring_mask = (r >= rad) & (r < rad+1)
        if ring_mask.any():
            radial_profile.append(smoothed[ring_mask].max())
    axes[1,2].plot(radial_profile)
    if max_radius > 0:
        axes[1,2].axvline(max_radius, color='g', linestyle='--', linewidth=2, label=f'Max peak: {max_radius:.1f}px')
    axes[1,2].axhline(threshold, color='r', linestyle='--', label='threshold')
    axes[1,2].set_xlabel('Radius (px)')
    axes[1,2].set_ylabel('Max Power')
    axes[1,2].legend()
    axes[1,2].set_title('Radial Power Profile')

    plt.tight_layout()
    plt.suptitle(f'FFT Score: {analysis["fft_score"]:.3f} (max_r={max_radius:.1f}px)', y=1.00, fontsize=14)
    plt.show()


def get_stem_image_contrast_and_fft(c10, c12a, c12b, plot_diagnostics=False):
    """
    Simulate one STEM image for the given aberrations and return two quality metrics.

    NOTE: this writes ``ab['C10']``, ``ab['C12a']`` and ``ab['C12b']`` into the
    notebook-global ``ab`` dict (the values persist after the call). It also
    reads the global ``potential``.

    Parameters:
    - c10 (float): defocus (C10).
    - c12a, c12b (float): twofold-astigmatism components (C12a, C12b).
    - plot_diagnostics (bool): if True, show a figure of the FFT peak analysis.

    Returns:
    - contrast (float): RMS contrast (std / mean) of the noisy simulated image.
    - fft_score (float): centroid radius of the farthest detected FFT spot divided
      by half the smaller image side. Larger means finer periodic detail is
      resolved; 0.0 if no spot is found.
    - sim_im (sidpy.Dataset): the simulated noisy image.
    """
    ab['C10'] = c10
    ab['C12a'] = c12a
    ab['C12b'] = c12b

    probe = get_probe(ab, potential)
    image = convolve_kernel(potential, probe)
    noisy_image = lowfreq_noise(image, noise_level=0.5, freq_scale=0.04)
    sim_im = poisson_noise(noisy_image, counts=1e7)

    contrast = contrast_rms(np.array(sim_im))

    # Standardize so every image enters the (non-scale-invariant) log1p spectrum on the same scale.
    sim_array = np.array(sim_im, dtype=float)
    sim_array = (sim_array - sim_array.mean()) / sim_array.std()

    analysis = _fft_peak_analysis(sim_array)
    if plot_diagnostics:
        _plot_fft_diagnostics(sim_im, analysis)

    return contrast, analysis['fft_score'], sim_im

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Set random seed for reproducibility
# NumPy's global RNG drives the simulator noise (np.random.normal / np.random.poisson)
# and the seed-point sampling below; torch's RNG is used by BoTorch.
np.random.seed(42)
torch.manual_seed(42)


In [None]:
# Parameter ranges based on your exploration
# (lower, upper) search bounds per aberration coefficient, in whatever units
# pyTEMlib's `ab` dict uses - TODO confirm. The C10 range matches the ±8 defocus sweep above.
param_ranges = {
    'C10': (-8, 8),    # defocus
    'C12a': (-10, 10), # twofold astigmatism (a)
    'C12b': (-10, 10)  # twofold astigmatism (b)
}

In [None]:
# Create full grid (coarser for computational efficiency)
# The grid is only used to pick seed points; BO itself searches the continuous box.
n_grid = 7  # 7^3 = 343 points
c10_grid = np.linspace(*param_ranges['C10'], n_grid)
c12a_grid = np.linspace(*param_ranges['C12a'], n_grid)
c12b_grid = np.linspace(*param_ranges['C12b'], n_grid)
C10, C12A, C12B = np.meshgrid(c10_grid, c12a_grid, c12b_grid, indexing='ij')
# full_grid: shape (n_grid**3, 3), columns (C10, C12a, C12b); with 'ij' indexing C10 varies slowest.
full_grid = np.stack([C10.flatten(), C12A.flatten(), C12B.flatten()], axis=1)

In [None]:
# Sample seed points
# n_seed distinct grid points (no replacement), drawn with NumPy's global RNG.
n_seed = 4
seed_indices = np.random.choice(len(full_grid), n_seed, replace=False)
seed_points = full_grid[seed_indices]  # shape (n_seed, 3): (C10, C12a, C12b)

In [None]:
# ========== QUERY SEED POINTS ==========
# Evaluate both objectives at every seed point. Each row of seed_scores is one
# observation: [objective 1 = RMS contrast, objective 2 = FFT score].
print(f"\nQuerying {n_seed} seed points...")
seed_scores = []
seed_images = []

for i, (c10, c12a, c12b) in enumerate(seed_points):
    contrast, fft_score, sim_im = get_stem_image_contrast_and_fft(c10, c12a, c12b, plot_diagnostics=True)
    seed_scores.append(np.array((contrast, fft_score)))
    seed_images.append(sim_im)
    print(f"Seed {i+1}/{n_seed}, C10={c10:.2f}, C12a={c12a:.2f}, C12b={c12b:.2f}, "
          f"contrast={contrast:.4f}, fft_score={fft_score:.4f}")

seed_scores = np.array(seed_scores)  # shape (n_seed, 2)


#### 4b. MOBO

In [None]:
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.fit import fit_gpytorch_mll
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
import numpy as np
import matplotlib.pyplot as plt


from botorch.acquisition.multi_objective import qLogExpectedHypervolumeImprovement 
from botorch.utils.multi_objective.box_decompositions import NondominatedPartitioning
from botorch.utils.multi_objective import is_non_dominated


# Set random seeds
# Re-seeding here makes the BO section reproducible from this point on:
# NumPy drives the simulator noise, torch the BoTorch/GPyTorch randomness.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# ========== INITIAL SEED POINTS (from previous code) ==========
print(f"Starting with {n_seed} seed points...")
# seed_scores has one column per objective: [:, 0] = contrast, [:, 1] = fft_score.
# (seed_scores.max() would mix both objectives.)
print(f"Best initial contrast: {seed_scores[:, 0].max():.4f}")
print(f"Best initial FFT score: {seed_scores[:, 1].max():.4f}")

# Convert to tensors
train_X = torch.tensor(seed_points, dtype=torch.float64)  # (n_seed, 3)
train_Y = torch.tensor(seed_scores, dtype=torch.float64)  # (n_seed, 2)

# Define bounds for optimization: row 0 = lower, row 1 = upper; columns (C10, C12a, C12b)
bounds = torch.tensor([
    [param_ranges['C10'][0], param_ranges['C12a'][0], param_ranges['C12b'][0]],  # lower bounds
    [param_ranges['C10'][1], param_ranges['C12a'][1], param_ranges['C12b'][1]]   # upper bounds
], dtype=torch.float64)



In [None]:
import torch

# ---- device & dtype ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64

# Global side effect: every tensor created later in this session defaults to float64.
torch.set_default_dtype(dtype)

# move inputs/bounds to device+dtype
train_X = train_X.to(device=device, dtype=dtype)
train_Y = train_Y.to(device=device, dtype=dtype)
bounds  = bounds.to(device=device, dtype=dtype)

# Running history of all evaluated points; index i in all_X / all_Y / all_images is the same evaluation.
n_bo_steps = 50
all_X = train_X.clone()
all_Y = train_Y.clone()
all_images = seed_images.copy()

print("\n" + "="*60)
print("Starting Multi-Objective Bayesian Optimization with EHVI")
print("="*60)

# Hypervolume reference point: slightly below the worst seed value of each objective.
# It is computed once from the seeds and kept fixed for the whole run.
# NOTE(review): if an objective has identical seed values, std is 0 and the ref point equals that minimum.
ref_point = train_Y.min(dim=0).values - 0.1 * train_Y.std(dim=0)
print(f"Reference point: {ref_point.detach().cpu().numpy()}")

for step in range(n_bo_steps):
    print(f"\n--- BO Step {step + 1}/{n_bo_steps} ---")
    
    # Train GP (model follows tensor's device/dtype)
    # One SingleTaskGP over both objectives; inputs are normalized to the unit cube, outputs standardized.
    print("Training Multi-Output GP...")
    gp_model = SingleTaskGP(
        all_X, all_Y,
        input_transform=Normalize(d=all_X.shape[-1]).to(device=device, dtype=dtype),
        outcome_transform=Standardize(m=all_Y.shape[-1]).to(device=device, dtype=dtype),
    ).to(device=device, dtype=dtype)
    
    # Starting value for the likelihood noise before hyperparameter fitting.
    gp_model.likelihood.noise_covar.initialize(noise=0.01)
    mll = ExactMarginalLogLikelihood(gp_model.likelihood, gp_model).to(device=device, dtype=dtype)
    fit_gpytorch_mll(mll)
    
    # EHVI acquisition
    # Pareto set of observed outcomes (is_non_dominated treats objectives as maximized by default).
    print("Computing Pareto frontier...")
    pareto_mask = is_non_dominated(all_Y)
    pareto_Y = all_Y[pareto_mask]
    print(f"Pareto frontier size: {pareto_Y.shape[0]}")
    
    partitioning = NondominatedPartitioning(ref_point=ref_point, Y=pareto_Y)
    EHVI = qLogExpectedHypervolumeImprovement(
        model=gp_model,
        ref_point=ref_point.tolist(),
        partitioning=partitioning,
    )
    
    # Optimize
    # One candidate per step (q=1): 100 raw samples seed 10 gradient-based restarts inside `bounds`.
    print("Optimizing acquisition function...")
    candidate, acq_value = optimize_acqf(
        acq_function=EHVI,
        bounds=bounds,
        q=1,
        num_restarts=10,
        raw_samples=100,
    )
    
    next_X = candidate.detach()
    next_params = next_X.squeeze().detach().cpu().numpy()
    # NOTE: the acquisition is qLogExpectedHypervolumeImprovement, so acq_value is log-EHVI despite the label.
    print(f"EHVI value: {acq_value:.6f}")
    
    # Query simulator (returns CPU values)
    print("Querying STEM simulator...")
    objective1, objective2, next_image = get_stem_image_contrast_and_fft(
        next_params[0], next_params[1], next_params[2], plot_diagnostics=True
    )
    next_Y = torch.tensor([[objective1, objective2]], dtype=dtype, device=device)
    
    print(f"Observed objectives: [{objective1:.4f}, {objective2:.4f}]")
    
    # Update tensors on-device
    all_X = torch.cat([all_X, next_X], dim=0)
    all_Y = torch.cat([all_Y, next_Y], dim=0)
    all_images.append(next_image)
    
    # The newest observation is the last row; report whether it joined the Pareto set.
    new_pareto_mask = is_non_dominated(all_Y)
    if new_pareto_mask[-1]:
        print("✓ NEW PARETO POINT!")


#### 4c. Let's look at the Pareto front

In [None]:
# ========== FIND EXTREME AND MID PARETO POINTS ==========
# Identify the Pareto-optimal evaluations of the section-4 BO run, pick the
# per-objective extremes plus one "balanced" point, and visualise them.
print("\n" + "="*60)
print("Pareto Frontier Analysis")
print("="*60)

final_pareto_mask = is_non_dominated(all_Y)
final_pareto_X = all_X[final_pareto_mask]
final_pareto_Y = all_Y[final_pareto_mask]
# Row positions (into all_X / all_Y / all_images) of the Pareto-optimal evaluations.
pareto_indices = torch.where(final_pareto_mask)[0].cpu().numpy()

print(f"Number of Pareto optimal points: {final_pareto_Y.shape[0]}")

# Extreme points. These are indices into the Pareto subset (final_pareto_Y);
# pareto_indices maps them back to rows of all_X / all_Y / all_images.
extreme_indices = []

# Extreme for Objective 1
max_obj1_idx = torch.argmax(final_pareto_Y[:, 0]).item()
min_obj1_idx = torch.argmin(final_pareto_Y[:, 0]).item()

# Extreme for Objective 2
max_obj2_idx = torch.argmax(final_pareto_Y[:, 1]).item()
min_obj2_idx = torch.argmin(final_pareto_Y[:, 1]).item()

extreme_indices.extend([max_obj1_idx, min_obj1_idx, max_obj2_idx, min_obj2_idx])
extreme_indices = list(set(extreme_indices))  # Remove duplicates

# Find middle point (balanced trade-off):
# min-max normalise each objective over the front, then take the point closest
# to (0.5, 0.5). The 1e-8 guards against a zero range (e.g. a 1-point front).
normalized_pareto_Y = (final_pareto_Y - final_pareto_Y.min(dim=0).values) / (final_pareto_Y.max(dim=0).values - final_pareto_Y.min(dim=0).values + 1e-8)
distances_to_center = torch.norm(normalized_pareto_Y - 0.5, dim=1)
mid_idx = torch.argmin(distances_to_center).item()

# Combine: extremes + mid
selected_indices = sorted(list(set(extreme_indices + [mid_idx])))


def _pareto_role(front_idx):
    """Return the (label, color) used for a Pareto-front index in every plot.

    A point can be several extremes at once; the first match in the order
    MAX Obj1, MIN Obj1, MAX Obj2, MIN Obj2, balanced wins, so the image grid
    and the objective-space stars always agree.
    """
    if front_idx == max_obj1_idx:
        return "MAX Obj1", 'red'
    if front_idx == min_obj1_idx:
        return "MIN Obj1", 'blue'
    if front_idx == max_obj2_idx:
        return "MAX Obj2", 'green'
    if front_idx == min_obj2_idx:
        return "MIN Obj2", 'orange'
    if front_idx == mid_idx:
        return "BALANCED (Mid)", 'purple'
    return "Extreme", 'black'


print(f"\nSelected Pareto points for visualization: {len(selected_indices)}")
for idx in selected_indices:
    pareto_idx = pareto_indices[idx]
    params = all_X[pareto_idx].cpu().numpy()
    obj1, obj2 = all_Y[pareto_idx, 0].item(), all_Y[pareto_idx, 1].item()

    # In the text summary list every role the point plays, not just the first.
    label = ""
    if idx == max_obj1_idx:
        label += "[MAX Obj1] "
    if idx == min_obj1_idx:
        label += "[MIN Obj1] "
    if idx == max_obj2_idx:
        label += "[MAX Obj2] "
    if idx == min_obj2_idx:
        label += "[MIN Obj2] "
    if idx == mid_idx:
        label += "[MID/Balanced] "

    print(f"  {label}")
    print(f"    Obj1={obj1:.4f}, Obj2={obj2:.4f}")
    print(f"    C10={params[0]:.2f}, C12a={params[1]:.2f}, C12b={params[2]:.2f}")

# ========== VISUALIZE ONLY EXTREME + MID PARETO IMAGES ==========
n_selected = len(selected_indices)
n_cols = min(3, n_selected)
n_rows = int(np.ceil(n_selected / n_cols))

fig = plt.figure(figsize=(7*n_cols, 7*n_rows))

for plot_idx, pareto_idx_in_frontier in enumerate(selected_indices):
    ax = fig.add_subplot(n_rows, n_cols, plot_idx + 1)

    pareto_idx = pareto_indices[pareto_idx_in_frontier]
    img = all_images[pareto_idx]
    params = all_X[pareto_idx].cpu().numpy()
    obj1, obj2 = all_Y[pareto_idx, 0].item(), all_Y[pareto_idx, 1].item()
    label, color = _pareto_role(pareto_idx_in_frontier)

    ax.imshow(np.array(img), cmap='gray')
    ax.set_title(
        f'{label}\n'
        f'Obj1={obj1:.4f}, Obj2={obj2:.4f}\n'
        f'C10={params[0]:.1f}, C12a={params[1]:.1f}\n'
        f'C12b={params[2]:.1f}\n',
        fontsize=12,
        fontweight='bold',
        color=color
    )
    ax.axis('off')

plt.tight_layout()
plt.savefig('pareto_extreme_mid_images.png', dpi=150, bbox_inches='tight')
plt.show()

# ========== COMBINED RESULTS PLOT ==========
fig = plt.figure(figsize=(18, 6))

# 1. Objective space
ax1 = plt.subplot(1, 3, 1)
ax1.scatter(all_Y[:, 0].cpu().numpy(), all_Y[:, 1].cpu().numpy(),
            c='lightblue', s=150, alpha=0.6, edgecolors='gray',
            label='All evaluations')
ax1.scatter(final_pareto_Y[:, 0].cpu().numpy(), final_pareto_Y[:, 1].cpu().numpy(),
            c='lightcoral', s=200, alpha=0.5, edgecolors='black',
            linewidths=1, label='Pareto frontier')

# Highlight extreme and mid points with the same colours as the image grid and
# give each role exactly one legend entry ('_nolegend_' hides repeats).
legend_roles_seen = set()
for idx in selected_indices:
    role, color = _pareto_role(idx)
    ax1.scatter(final_pareto_Y[idx, 0].cpu().numpy(), final_pareto_Y[idx, 1].cpu().numpy(),
                c=color, s=400, marker='*', edgecolors='black', linewidths=2, zorder=10,
                label='_nolegend_' if role in legend_roles_seen else role)
    legend_roles_seen.add(role)

ax1.set_xlabel('Objective 1 (Contrast)', fontsize=12)
ax1.set_ylabel('Objective 2 (Other Metric)', fontsize=12)
ax1.set_title('Pareto Frontier (★ = Extreme/Mid points)', fontsize=14, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# 2. BO progress
ax2 = plt.subplot(1, 3, 2)
iterations = range(len(all_Y))
ax2.plot(iterations, all_Y[:, 0].cpu().numpy(), 'o-', label='Objective 1',
         alpha=0.7, linewidth=2, markersize=8)
ax2.plot(iterations, all_Y[:, 1].cpu().numpy(), 's-', label='Objective 2',
         alpha=0.7, linewidth=2, markersize=8)
# Seeds occupy indices 0..len(train_Y)-1 and the first BO proposal is at
# len(train_Y); draw the marker between them.
ax2.axvline(len(train_Y) - 0.5, color='red', linestyle='--',
            label='BO start', linewidth=2)
ax2.set_xlabel('Iteration', fontsize=12)
ax2.set_ylabel('Objective Value', fontsize=12)
ax2.set_title('BO Progress', fontsize=14, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

# 3. Hypervolume of the front after each evaluation (starting from the seeds),
#    measured against the same ref_point the optimisation used.
from botorch.utils.multi_objective.hypervolume import Hypervolume

hv_computer = Hypervolume(ref_point=ref_point)
hypervolumes = []
for i in range(len(train_Y), len(all_Y) + 1):
    current_Y = all_Y[:i]
    pareto_mask_i = is_non_dominated(current_Y)
    pareto_Y_i = current_Y[pareto_mask_i]
    hv = hv_computer.compute(pareto_Y_i)
    hypervolumes.append(hv)

ax3 = plt.subplot(1, 3, 3)
ax3.plot(range(len(train_Y), len(all_Y) + 1), hypervolumes, 'o-',
         linewidth=2, markersize=8, color='purple')
ax3.set_xlabel('Iteration', fontsize=12)
ax3.set_ylabel('Hypervolume', fontsize=12)
ax3.set_title('Hypervolume Improvement', fontsize=14, fontweight='bold')
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('multi_objective_summary.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n=== SUMMARY ===")
print(f"Total evaluations: {len(all_Y)}")
print(f"Pareto frontier size: {len(pareto_indices)}")
print(f"Extreme + Mid points shown: {len(selected_indices)}")
print(f"Final hypervolume: {hypervolumes[-1]:.4f}")

### 5. Let's do multi-objective GP optimization on B2 and A2 [4 parameters in total: axial coma B2 has two components (C21a, C21b) and threefold astigmatism A2 has two components (C23a, C23b)]

#### 5a. Define reward functions, set parameter ranges, and collect seed data

In [None]:
# ========== SETUP STEM SIMULATOR ==========
# Rebuilds the specimen potential and the probe-aberration dictionary used by the
# section-5 simulator wrappers. The globals assigned here (xtal, transform,
# pixel_size, fov, frame, potential, ab) are read by those wrappers.
print("Setting up STEM simulator...")

# --- Specimen: WS2 from CIF, orthogonalised, then tiled 30 x 20 x 1 ---
xtal = read('WS2.cif')
from abtem.atoms import orthogonalize_cell
xtal, transform = orthogonalize_cell(xtal, allow_transform=True, return_transform=True)
xtal = xtal * (30, 20, 1)

# --- Potential sampling: square window of side `fov` sampled at `pixel_size` ---
pixel_size, fov = 0.106, 96
frame = (0, fov, 0, fov)
potential = create_pseudo_potential(xtal, pixel_size, sigma=1, bounds=frame, atom_frame=11)

# --- Probe optics: start from pyTEMlib's "Spectra300" target aberrations ---
ab = pt.get_target_aberrations("Spectra300", 60000)
for _name, _value in (('acceleration_voltage', 60e3),
                      ('FOV', fov / 12),
                      ('convergence_angle', 30)):
    ab[_name] = _value
# The wavelength must be derived after the voltage override above.
ab['wavelength'] = pt.get_wavelength(ab['acceleration_voltage'])
# Defocus (C10) held at 1; the C23 pair starts at 0 and is overwritten per query.
for _name, _value in (('C10', 1), ('C23a', 0), ('C23b', 0)):
    ab[_name] = _value


def contrast_rms(im, eps=1e-12):
    """Return the RMS contrast of ``im``: its standard deviation over its mean.

    ``eps`` is added to the mean so that a zero-mean image does not divide by zero.
    """
    return np.std(im) / (np.mean(im) + eps)

def fft_snr_generic(im, kmin_frac=0.3, eps=1e-12):
    """Score an image by the ratio of high- to low-frequency spectral power.

    The image is Hann-windowed in both directions, Fourier transformed, and its
    power spectrum normalised to unit sum. The score is the mean power at radii
    ``>= kmin_frac * r_max`` divided by the mean power in a background ring
    between 5% and 15% of ``r_max``, where ``r_max`` is the distance from the
    (shifted) spectrum centre to its farthest pixel.

    Parameters
    ----------
    im : 2-D array
        Image to score.
    kmin_frac : float
        Inner radius of the high-frequency region as a fraction of ``r_max``.
    eps : float
        Small constant guarding both divisions against zero.
    """
    n_rows, n_cols = im.shape
    windowed = im * np.hanning(n_rows)[:, None] * np.hanning(n_cols)[None, :]

    spectrum = np.fft.fftshift(np.fft.fft2(windowed))
    power = (np.abs(spectrum) ** 2).astype(np.float64)
    power = power / (power.sum() + eps)  # normalise away overall intensity scale

    row_idx, col_idx = np.ogrid[0:n_rows, 0:n_cols]
    radius = np.hypot(row_idx - n_rows // 2, col_idx - n_cols // 2)
    r_max = radius.max()

    high_band = radius >= (kmin_frac * r_max)
    background = (radius >= 0.05 * r_max) & (radius < 0.15 * r_max)
    return power[high_band].mean() / (power[background].mean() + eps)


def get_stem_image_contrast(c23a, c23b, c21a, c21b):
    """Simulate a noisy STEM image for the given aberrations and return its contrast.

    Side effect: writes C23a, C23b, C21a and C21b into the global probe
    dictionary ``ab`` before building the probe.

    Returns
    -------
    contrast
        ``contrast_rms`` of the simulated image.
    sim_im
        The simulated image as returned by ``poisson_noise``.
    """
    ab['C23a'] = c23a
    ab['C23b'] = c23b
    ab['C21a'] = c21a
    ab['C21b'] = c21b

    # Call the notebook-level simulator helpers directly, exactly as
    # get_stem_image_contrast_and_fft does.
    probe = get_probe(ab, potential)
    image = convolve_kernel(potential, probe)
    noisy_image = lowfreq_noise(image, noise_level=0.5, freq_scale=0.04)
    sim_im = poisson_noise(noisy_image, counts=1e7)

    contrast = contrast_rms(np.array(sim_im))
    return contrast, sim_im

def get_stem_image_contrast_and_fft(c23a, c23b, c21a, c21b):
    """Simulate a noisy STEM image and score it with two image-quality metrics.

    The four aberration coefficients are written into the global probe
    dictionary ``ab`` (a side effect) before the probe is built.

    Returns
    -------
    tuple
        ``(contrast, fft_score, sim_im)``: the ``contrast_rms`` value, the
        ``fft_snr_generic`` value, and the simulated image itself.
    """
    for key, value in (('C23a', c23a), ('C23b', c23b), ('C21a', c21a), ('C21b', c21b)):
        ab[key] = value

    clean = convolve_kernel(potential, get_probe(ab, potential))
    sim_im = poisson_noise(
        lowfreq_noise(clean, noise_level=0.5, freq_scale=0.04),
        counts=1e7,
    )

    contrast = contrast_rms(np.array(sim_im))
    fft_score = fft_snr_generic(np.array(sim_im))
    return contrast, fft_score, sim_im

In [None]:
import numpy as np
import torch
import botorch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from botorch.acquisition import UpperConfidenceBound
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Set random seed for reproducibility.
# NumPy's global RNG drives the random seed-point selection (np.random.choice)
# further down; torch's global RNG is seeded too, although the MOBO cell re-seeds
# both right before optimisation.
np.random.seed(42)
torch.manual_seed(42)


In [None]:
# Search ranges for the four aberration coefficients optimised in section 5.
# Key order matches the simulator-wrapper arguments and the `bounds` tensor.
# Units follow the pyTEMlib aberration dictionary `ab` (TODO confirm).
param_ranges = {
    'C23a': (-200, 200),  # threefold astigmatism A2, component a
    'C23b': (-200, 200),  # threefold astigmatism A2, component b
    'C21a': (-500, 500),  # axial coma B2, component a
    'C21b': (-500, 500),  # axial coma B2, component b
}

In [None]:
# Create the full candidate grid over the four aberration coefficients.
# With n_grid values per axis it holds n_grid**4 points (20**4 = 160,000 here);
# it is only used to draw the random seed points below.
n_grid = 20
c23a_grid = np.linspace(*param_ranges['C23a'], n_grid)
c23b_grid = np.linspace(*param_ranges['C23b'], n_grid)

c21a_grid = np.linspace(*param_ranges['C21a'], n_grid)
c21b_grid = np.linspace(*param_ranges['C21b'], n_grid)

# 'ij' indexing keeps the column order (C23a, C23b, C21a, C21b) of full_grid.
C23A, C23B, C21A, C21B = np.meshgrid(c23a_grid, c23b_grid, c21a_grid, c21b_grid, indexing='ij')
full_grid = np.stack([C23A.flatten(), C23B.flatten(), C21A.flatten(), C21B.flatten()], axis=1)

In [None]:
# Draw n_seed distinct grid rows (sampling without replacement) as the
# initial design for the optimiser.
n_seed = 3
seed_indices = np.random.choice(full_grid.shape[0], n_seed, replace=False)
seed_points = full_grid[seed_indices]

In [None]:
# ========== QUERY SEED POINTS ==========
# Evaluate both objectives (RMS contrast, FFT score) at every seed point.
# seed_scores ends up as an (n_seed, 2) array; seed_images keeps the images in
# the same order, so row k of seed_scores belongs to seed_images[k].
print(f"\nQuerying {n_seed} seed points...")
seed_scores = []
seed_images = []

for i, (c23a, c23b, c21a, c21b) in enumerate(seed_points):
    contrast, fft_score, sim_im = get_stem_image_contrast_and_fft(c23a, c23b, c21a, c21b)
    rewards = np.array((contrast, fft_score))
    seed_scores.append(rewards)
    seed_images.append(sim_im)
    print(f"Seed {i+1}/{n_seed}, C23a={c23a:.2f}, C23b={c23b:.2f}, C21a={c21a:.2f}, "
          f"C21b={c21b:.2f}, contrast={contrast:.4f}, fft_score={fft_score:.4f}")

seed_scores = np.array(seed_scores)

#### 5b. MOBO (multi-objective Bayesian optimization)

In [None]:
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.fit import fit_gpytorch_mll
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
import numpy as np
import matplotlib.pyplot as plt


from botorch.acquisition.multi_objective import qLogExpectedHypervolumeImprovement 
from botorch.utils.multi_objective.box_decompositions import NondominatedPartitioning
from botorch.utils.multi_objective import is_non_dominated


# Set random seeds.
# Re-seeding right before the MOBO cells makes the optimisation run reproducible
# regardless of how many random draws earlier cells consumed.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# ========== INITIAL SEED POINTS (from previous code) ==========
print(f"Starting with {n_seed} seed points...")
# seed_scores has one column per objective: [contrast, fft_score]. Report the
# best of each separately; a bare .max() would mix the two objectives.
print(f"Best initial contrast: {seed_scores[:, 0].max():.4f}")
print(f"Best initial FFT score: {seed_scores[:, 1].max():.4f}")

# Convert to tensors
train_X = torch.tensor(seed_points, dtype=torch.float64)
train_Y = torch.tensor(seed_scores, dtype=torch.float64)

# Box bounds for optimize_acqf: row 0 = lower, row 1 = upper, columns in the
# same order as the seed points (C23a, C23b, C21a, C21b).
bounds = torch.tensor([
    [param_ranges['C23a'][0], param_ranges['C23b'][0], param_ranges['C21a'][0], param_ranges['C21b'][0]],  # lower bounds
    [param_ranges['C23a'][1], param_ranges['C23b'][1], param_ranges['C21a'][1], param_ranges['C21b'][1]]   # upper bounds
], dtype=torch.float64)



In [None]:
import torch

# ---- device & dtype ----
# Use the GPU when available; all tensors are kept in float64.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64
torch.set_default_dtype(dtype)

# move inputs/bounds to device+dtype
train_X = train_X.to(device=device, dtype=dtype)
train_Y = train_Y.to(device=device, dtype=dtype)
bounds  = bounds.to(device=device, dtype=dtype)

n_bo_steps = 50
# Running history. all_X / all_Y / all_images grow by one evaluation per step in
# the same order, so one row index identifies the same evaluation in each.
all_X = train_X.clone()
all_Y = train_Y.clone()
all_images = seed_images.copy()

print("\n" + "="*60)
print("Starting Multi-Objective Bayesian Optimization with EHVI")
print("="*60)

# Hypervolume reference point: slightly worse than the worst seed value in each
# objective (both objectives are maximised). It stays fixed for the whole run.
ref_point = train_Y.min(dim=0).values - 0.1 * train_Y.std(dim=0)
print(f"Reference point: {ref_point.detach().cpu().numpy()}")

for step in range(n_bo_steps):
    print(f"\n--- BO Step {step + 1}/{n_bo_steps} ---")

    # Refit a fresh multi-output GP on all data so far. Inputs are normalised to
    # the unit cube and outputs standardised; model, transforms and mll all share
    # one device/dtype.
    print("Training Multi-Output GP...")
    gp_model = SingleTaskGP(
        all_X, all_Y,
        input_transform=Normalize(d=all_X.shape[-1]).to(device=device, dtype=dtype),
        outcome_transform=Standardize(m=all_Y.shape[-1]).to(device=device, dtype=dtype),
    ).to(device=device, dtype=dtype)

    gp_model.likelihood.noise_covar.initialize(noise=0.01)
    mll = ExactMarginalLogLikelihood(gp_model.likelihood, gp_model).to(device=device, dtype=dtype)
    fit_gpytorch_mll(mll)

    # EHVI acquisition over the box decomposition of the current Pareto front.
    print("Computing Pareto frontier...")
    pareto_mask = is_non_dominated(all_Y)
    pareto_Y = all_Y[pareto_mask]
    print(f"Pareto frontier size: {pareto_Y.shape[0]}")

    partitioning = NondominatedPartitioning(ref_point=ref_point, Y=pareto_Y)
    EHVI = qLogExpectedHypervolumeImprovement(
        model=gp_model,
        ref_point=ref_point.tolist(),  # list is fine for EHVI
        partitioning=partitioning,
    )

    # Optimize (bounds already on correct device/dtype)
    print("Optimizing acquisition function...")
    candidate, acq_value = optimize_acqf(
        acq_function=EHVI,
        bounds=bounds,
        q=1,
        num_restarts=10,
        raw_samples=100,
    )

    next_X = candidate.detach()  # stays on device
    next_params = next_X.squeeze().detach().cpu().numpy()  # CPU for simulator
    # qLogExpectedHypervolumeImprovement returns log(EHVI), so this value is
    # negative whenever the expected improvement is below 1.
    print(f"log-EHVI value: {acq_value.item():.6f}")

    # Query
    print("Querying STEM simulator...")
    objective1, objective2, next_image = get_stem_image_contrast_and_fft(
        next_params[0], next_params[1], next_params[2], next_params[3]
    )
    next_Y = torch.tensor([[objective1, objective2]], dtype=dtype, device=device)

    print(f"Observed objectives: [{objective1:.4f}, {objective2:.4f}]")

    # Update (tensors already on device)
    all_X = torch.cat([all_X, next_X], dim=0)
    all_Y = torch.cat([all_Y, next_Y], dim=0)
    all_images.append(next_image)

    # The newest evaluation is the last row; report if it joined the front.
    new_pareto_mask = is_non_dominated(all_Y)
    if new_pareto_mask[-1]:
        print("✓ NEW PARETO POINT!")


Observed objectives: [0.4826, 0.0175]

--- BO Step 50/50 ---
Training Multi-Output GP...
Computing Pareto frontier...
Pareto frontier size: 39
Optimizing acquisition function...
EHVI value: -10.172721
Querying STEM simulator...
0.03
Observed objectives: [0.5048, 0.0156]
✓ NEW PARETO POINT!


#### 5c. Let's look at the Pareto front

In [None]:
# ========== FIND EXTREME AND MID PARETO POINTS ==========
# Pareto analysis for the section-5 (C23/C21) run. The MOBO cell may have put
# all_X / all_Y on the GPU, so every tensor -> NumPy conversion goes via .cpu().
print("\n" + "="*60)
print("Pareto Frontier Analysis")
print("="*60)

final_pareto_mask = is_non_dominated(all_Y)
final_pareto_X = all_X[final_pareto_mask]
final_pareto_Y = all_Y[final_pareto_mask]
# Row positions (into all_X / all_Y / all_images) of the Pareto-optimal evaluations.
pareto_indices = torch.where(final_pareto_mask)[0].cpu().numpy()

print(f"Number of Pareto optimal points: {final_pareto_Y.shape[0]}")

# Extreme points. These are indices into the Pareto subset (final_pareto_Y);
# pareto_indices maps them back to rows of all_X / all_Y / all_images.
extreme_indices = []

# Extreme for Objective 1
max_obj1_idx = torch.argmax(final_pareto_Y[:, 0]).item()
min_obj1_idx = torch.argmin(final_pareto_Y[:, 0]).item()

# Extreme for Objective 2
max_obj2_idx = torch.argmax(final_pareto_Y[:, 1]).item()
min_obj2_idx = torch.argmin(final_pareto_Y[:, 1]).item()

extreme_indices.extend([max_obj1_idx, min_obj1_idx, max_obj2_idx, min_obj2_idx])
extreme_indices = list(set(extreme_indices))  # Remove duplicates

# Find middle point (balanced trade-off):
# min-max normalise each objective over the front, then take the point closest
# to (0.5, 0.5). The 1e-8 guards against a zero range (e.g. a 1-point front).
normalized_pareto_Y = (final_pareto_Y - final_pareto_Y.min(dim=0).values) / (final_pareto_Y.max(dim=0).values - final_pareto_Y.min(dim=0).values + 1e-8)
distances_to_center = torch.norm(normalized_pareto_Y - 0.5, dim=1)
mid_idx = torch.argmin(distances_to_center).item()

# Combine: extremes + mid
selected_indices = sorted(list(set(extreme_indices + [mid_idx])))


def _pareto_role(front_idx):
    """Return the (label, color) used for a Pareto-front index in every plot.

    A point can be several extremes at once; the first match in the order
    MAX Obj1, MIN Obj1, MAX Obj2, MIN Obj2, balanced wins, so the image grid
    and the objective-space stars always agree.
    """
    if front_idx == max_obj1_idx:
        return "MAX Obj1", 'red'
    if front_idx == min_obj1_idx:
        return "MIN Obj1", 'blue'
    if front_idx == max_obj2_idx:
        return "MAX Obj2", 'green'
    if front_idx == min_obj2_idx:
        return "MIN Obj2", 'orange'
    if front_idx == mid_idx:
        return "BALANCED (Mid)", 'purple'
    return "Extreme", 'black'


print(f"\nSelected Pareto points for visualization: {len(selected_indices)}")
for idx in selected_indices:
    pareto_idx = pareto_indices[idx]
    params = all_X[pareto_idx].cpu().numpy()
    obj1, obj2 = all_Y[pareto_idx, 0].item(), all_Y[pareto_idx, 1].item()

    # In the text summary list every role the point plays, not just the first.
    label = ""
    if idx == max_obj1_idx:
        label += "[MAX Obj1] "
    if idx == min_obj1_idx:
        label += "[MIN Obj1] "
    if idx == max_obj2_idx:
        label += "[MAX Obj2] "
    if idx == min_obj2_idx:
        label += "[MIN Obj2] "
    if idx == mid_idx:
        label += "[MID/Balanced] "

    print(f"  {label}")
    print(f"    Obj1={obj1:.4f}, Obj2={obj2:.4f}")
    print(f"    C23a={params[0]:.2f}, C23b={params[1]:.2f}, C21a={params[2]:.2f}, C21b={params[3]:.2f}")

# ========== VISUALIZE ONLY EXTREME + MID PARETO IMAGES ==========
n_selected = len(selected_indices)
n_cols = min(3, n_selected)
n_rows = int(np.ceil(n_selected / n_cols))

fig = plt.figure(figsize=(7*n_cols, 7*n_rows))

for plot_idx, pareto_idx_in_frontier in enumerate(selected_indices):
    ax = fig.add_subplot(n_rows, n_cols, plot_idx + 1)

    pareto_idx = pareto_indices[pareto_idx_in_frontier]
    img = all_images[pareto_idx]
    params = all_X[pareto_idx].cpu().numpy()
    obj1, obj2 = all_Y[pareto_idx, 0].item(), all_Y[pareto_idx, 1].item()
    label, color = _pareto_role(pareto_idx_in_frontier)

    ax.imshow(np.array(img), cmap='gray')
    ax.set_title(
        f'{label}\n'
        f'Obj1={obj1:.4f}, Obj2={obj2:.4f}\n'
        f'C23a={params[0]:.2f}, C23b={params[1]:.2f}\n'
        f'C21a={params[2]:.2f}, C21b={params[3]:.2f}',
        fontsize=12,
        fontweight='bold',
        color=color
    )
    ax.axis('off')

plt.tight_layout()
# Distinct filename so the section-4 figure of the same kind is not overwritten.
plt.savefig('pareto_extreme_mid_images_C23_C21.png', dpi=150, bbox_inches='tight')
plt.show()

# ========== COMBINED RESULTS PLOT ==========
fig = plt.figure(figsize=(18, 6))

# 1. Objective space
ax1 = plt.subplot(1, 3, 1)
ax1.scatter(all_Y[:, 0].cpu().numpy(), all_Y[:, 1].cpu().numpy(),
            c='lightblue', s=150, alpha=0.6, edgecolors='gray',
            label='All evaluations')
ax1.scatter(final_pareto_Y[:, 0].cpu().numpy(), final_pareto_Y[:, 1].cpu().numpy(),
            c='lightcoral', s=200, alpha=0.5, edgecolors='black',
            linewidths=1, label='Pareto frontier')

# Highlight extreme and mid points with the same colours as the image grid and
# give each role exactly one legend entry ('_nolegend_' hides repeats).
legend_roles_seen = set()
for idx in selected_indices:
    role, color = _pareto_role(idx)
    ax1.scatter(final_pareto_Y[idx, 0].cpu().numpy(), final_pareto_Y[idx, 1].cpu().numpy(),
                c=color, s=400, marker='*', edgecolors='black', linewidths=2, zorder=10,
                label='_nolegend_' if role in legend_roles_seen else role)
    legend_roles_seen.add(role)

ax1.set_xlabel('Objective 1 (Contrast)', fontsize=12)
ax1.set_ylabel('Objective 2 (FFT SNR)', fontsize=12)
ax1.set_title('Pareto Frontier (★ = Extreme/Mid points)', fontsize=14, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# 2. BO progress
ax2 = plt.subplot(1, 3, 2)
iterations = range(len(all_Y))
ax2.plot(iterations, all_Y[:, 0].cpu().numpy(), 'o-', label='Objective 1',
         alpha=0.7, linewidth=2, markersize=8)
ax2.plot(iterations, all_Y[:, 1].cpu().numpy(), 's-', label='Objective 2',
         alpha=0.7, linewidth=2, markersize=8)
# Seeds occupy indices 0..len(train_Y)-1 and the first BO proposal is at
# len(train_Y); draw the marker between them.
ax2.axvline(len(train_Y) - 0.5, color='red', linestyle='--',
            label='BO start', linewidth=2)
ax2.set_xlabel('Iteration', fontsize=12)
ax2.set_ylabel('Objective Value', fontsize=12)
ax2.set_title('BO Progress', fontsize=14, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

# 3. Hypervolume of the front after each evaluation (starting from the seeds),
#    measured against the same ref_point the optimisation used.
from botorch.utils.multi_objective.hypervolume import Hypervolume

hv_computer = Hypervolume(ref_point=ref_point)
hypervolumes = []
for i in range(len(train_Y), len(all_Y) + 1):
    current_Y = all_Y[:i]
    pareto_mask_i = is_non_dominated(current_Y)
    pareto_Y_i = current_Y[pareto_mask_i]
    hv = hv_computer.compute(pareto_Y_i)
    hypervolumes.append(hv)

ax3 = plt.subplot(1, 3, 3)
ax3.plot(range(len(train_Y), len(all_Y) + 1), hypervolumes, 'o-',
         linewidth=2, markersize=8, color='purple')
ax3.set_xlabel('Iteration', fontsize=12)
ax3.set_ylabel('Hypervolume', fontsize=12)
ax3.set_title('Hypervolume Improvement', fontsize=14, fontweight='bold')
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('multi_objective_summary_C23_C21.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n=== SUMMARY ===")
print(f"Total evaluations: {len(all_Y)}")
print(f"Pareto frontier size: {len(pareto_indices)}")
print(f"Extreme + Mid points shown: {len(selected_indices)}")
print(f"Final hypervolume: {hypervolumes[-1]:.4f}")

### 6. Suggested exercises
- a) Try different reward functions
- b) Try other multi-objective acquisition functions
- c) Try surrogate models other than Gaussian processes
- d) Try PCA-GP or VAE-GP style methods
- e) Go for higher-order aberrations — and get adventurous with the parameters you choose
