## Script to perform drift correction using AutoScript
- Get the reference image - ref_image
- Get a new image - new_image
- Get the pixel correction: Pixel_movement_x, Pixel_movement_y --> drift_correction(ref_image, new_image)
- Get the pixel size from new_image
- Convert to physical movement: movement_x, movement_y = pixelsize_x * Pixel_movement_x, pixelsize_y * Pixel_movement_y
- Apply the correction (by shifting the beam) and acquire a new image to check the result



In [None]:
# the Autoscript packages

from autoscript_tem_microscope_client import TemMicroscopeClient
from autoscript_tem_microscope_client.enumerations import *
from autoscript_tem_microscope_client.structures import *
import numpy as np
# General packages
import os, time, sys, math

# General image processing packages
from matplotlib import pyplot as plot
import numpy as np
import cv2 as cv



In [None]:

# Create the AutoScript client and connect it to the instrument.
microscope = TemMicroscopeClient()
ip = ""  # hard-code an address here to skip the interactive prompt
ip = ip or input("Please enter the IP address of the microscope: ")
microscope.connect(ip)
print("Connected to the microscope")

In [None]:
import numpy as np
from scipy.ndimage import fourier_shift
import matplotlib.pyplot as plot

def drift_correction(image1: np.ndarray, image2: np.ndarray) -> tuple[int, int]:
    """Return the integer pixel shift between two images using FFT cross-correlation.

    The peak of the circular cross-correlation
    ``ifft(fft(image1) * conj(fft(image2)))`` locates the displacement between
    the frames. The returned shift is the one that registers ``image2`` onto
    ``image1``. If ``image2`` is a circularly shifted copy of ``image1``, then
    ``np.roll(image2, (Pixel_movement_y, Pixel_movement_x), axis=(0, 1))``
    reproduces ``image1``. In other words, it is the *negative* of the
    apparent drift of ``image2`` relative to ``image1``.

    Args:
        image1 (np.ndarray): Reference image, 2-D, indexed as (row, column).
        image2 (np.ndarray): Drifted image, same shape as ``image1``.

    Returns:
        Pixel_movement_x (int): Shift along the column (X) axis.
        Pixel_movement_y (int): Shift along the row (Y) axis.
        For an axis of length ``n``, peak indices above ``n // 2`` are wrapped
        to negative values, so each component lies in ``[n // 2 - n + 1, n // 2]``.

    Raises:
        ValueError: If either image is not 2-D or the shapes differ.
    """
    ref = np.asarray(image1)
    mov = np.asarray(image2)
    if ref.ndim != 2 or mov.ndim != 2:
        raise ValueError(f"expected 2-D images, got {ref.ndim}-D and {mov.ndim}-D")
    if ref.shape != mov.shape:
        # Without this check, broadcastable shapes such as (1, N) vs (M, N)
        # would silently yield a meaningless shift.
        raise ValueError(f"image shapes differ: {ref.shape} vs {mov.shape}")

    # Circular cross-correlation via the convolution theorem.
    cross_corr = np.fft.ifftn(np.fft.fftn(ref) * np.conj(np.fft.fftn(mov)))

    # Index of the correlation peak, as (row, column).
    peak = np.array(np.unravel_index(np.argmax(np.abs(cross_corr)), cross_corr.shape))

    # The FFT treats the images as periodic, so a peak index past the midpoint
    # of an axis corresponds to a negative shift along that axis.
    shape = np.array(ref.shape)
    shifts = np.where(peak > shape // 2, peak - shape, peak)

    Pixel_movement_x, Pixel_movement_y = int(shifts[1]), int(shifts[0])
    return Pixel_movement_x, Pixel_movement_y




In [None]:
# Step 1: Get reference image (raw detector data as a NumPy array).
# NOTE(review): 256 and 4e-5 look like frame size (pixels) and dwell time (s);
# confirm against the AutoScript acquire_stem_image signature.
ref_image = microscope.acquisition.acquire_stem_image(DetectorType.HAADF, 256, 4e-5).data

# Normalize the reference image to 8-bit. The result is used for display and
# for the correlation in Step 3.
# A constant frame has zero range: dividing by it would produce NaNs, so fall
# back to an all-zero image and warn instead.
img = ref_image - np.min(ref_image)
img_range = np.max(img)
if img_range > 0:
    ref_image_data = (255 * (img / img_range)).astype(np.uint8)
else:
    print('Warning: reference image is constant; drift cannot be measured from it.')
    ref_image_data = np.zeros(img.shape, dtype=np.uint8)

# Plot the acquired reference image
fig = plot.figure(figsize=(6, 6))
plot.imshow(ref_image_data, cmap='gray')
plot.title('Acquired image')
plot.show()


In [None]:

# Step 2: Acquire new image. Keep the whole acquisition object: its metadata
# provides the pixel size used in Step 4.
new_image_haadf = microscope.acquisition.acquire_stem_image(DetectorType.HAADF, 256, 4e-5)

# Normalize the new image to 8-bit. The result is used for display and for
# the correlation in Step 3.
# A constant frame has zero range: dividing by it would produce NaNs, so fall
# back to an all-zero image and warn instead.
img = new_image_haadf.data - np.min(new_image_haadf.data)
img_range = np.max(img)
if img_range > 0:
    new_image_data = (255 * (img / img_range)).astype(np.uint8)
else:
    print('Warning: new image is constant; drift cannot be measured from it.')
    new_image_data = np.zeros(img.shape, dtype=np.uint8)

# Plot the acquired new image
fig = plot.figure(figsize=(6, 6))
plot.imshow(new_image_data, cmap='gray')
plot.title('Acquired image')
plot.show()



In [None]:
# Step 3: Measure the integer pixel shift between the 8-bit reference and new frames.
Pixel_movement_x, Pixel_movement_y = drift_correction(
    ref_image_data,
    new_image_data,
)
print(
    "Pixel movement in X direction: " + str(Pixel_movement_x),
    "Pixel movement in Y direction: " + str(Pixel_movement_y),
)

In [None]:
# Step 4: Read the calibrated pixel size from the new image's metadata
# (meters per pixel, per the original annotation).
pixelsize_x, pixelsize_y = (
    new_image_haadf.metadata.binary_result.pixel_size.x,
    new_image_haadf.metadata.binary_result.pixel_size.y,
)

# Step 5: Scale the pixel shift to a physical distance (meters).
# NOTE(review): X follows image columns and Y follows image rows. The image
# row axis may point opposite to the microscope's Y axis, so verify the
# sign of each component on the instrument.
movement_x, movement_y = (
    pixelsize_x * Pixel_movement_x,
    pixelsize_y * Pixel_movement_y,
)

print(
    f'Actual movement in X direction: {movement_x} m',
    f'Actual movement in Y direction: {movement_y} m',
)

In [None]:
## drift correction by moving stage
# current_position = microscope.specimen.stage.position

# # Calculate new position after drift correction
# new_position_x = current_position.x - movement_x
# new_position_y = current_position.y - movement_y

# print(f'Current stage position: {current_position.x} m, {current_position.y} m')
# microscope.specimen.stage.absolute_move([new_position_x, new_position_y])

In [None]:
# drift correction by moving beam
# Read the current beam shift and add the physical offset computed in Step 5.
# NOTE(review): drift_correction's peak convention gives the shift that
# registers the new image onto the reference, which is the negative of the
# apparent drift. Image rows (Y) may also run opposite to the beam-shift Y axis.
# The commented-out stage variant subtracts the movement while this cell adds
# it. Confirm the sign of each axis on the instrument before relying on this.
# NOTE(review): the unpacking assumes beam_shift yields exactly two
# components (x, y). Verify against the AutoScript deflector API.
old_x, old_y = microscope.optics.deflectors.beam_shift 
microscope.optics.deflectors.beam_shift = [old_x + movement_x, old_y + movement_y]

In [None]:
# Optionally, acquire another image to confirm the drift correction
corrected_image = microscope.acquisition.acquire_stem_image(DetectorType.HAADF, 256, 4e-5).data

# Quantify the remaining offset against the raw reference frame.
# After a successful correction both components should be close to 0.
residual_x, residual_y = drift_correction(ref_image, corrected_image)
print(f'Residual pixel movement in X direction: {residual_x}',
      f'Residual pixel movement in Y direction: {residual_y}')

# Plot the corrected image (optional)
fig = plot.figure(figsize=(6, 6))
plot.imshow(corrected_image, cmap='gray')
plot.title('Corrected image after drift correction')
plot.show()