# Remote smart STEM with AutoScript: Tutorial for Arems 2025 
- Contributors:
    - Gerd Duscher
    - Utkarsh Pratiush
    - Austin Houston

- Outline:
    - Connect to the microscope server
    - Get HAADF image
    - Find particles using segmentation:
        - Select the 10 particles with the largest area
            - Acquire EDX
            - Calibrate the beam current using the gun lens
            - Acquire EDX at a higher beam current
            - Acquire diffraction patterns






## 1a. Imports and output paths

In [None]:
from stemOrchestrator.logging_config import setup_logging

# All tutorial output (including the log configured here) goes to the working directory.
out_path = data_folder = "."
setup_logging(out_path=out_path)

In [None]:
from stemOrchestrator.acquisition import TFacquisition, DMacquisition
from stemOrchestrator.simulation import DMtwin
from stemOrchestrator.process import HAADF_tiff_to_png, tiff_to_png
from autoscript_tem_microscope_client import TemMicroscopeClient
import matplotlib.pyplot as plt
import logging
plot = plt
from typing import Dict

## 1b. Connect to the Microscope server

In [None]:
import os
import json
import logging
from pathlib import Path

# Connection details: environment variables take precedence; anything still
# missing is read from a local secrets file if it exists.
ip = os.getenv("MICROSCOPE_IP")
port = os.getenv("MICROSCOPE_PORT")

if not ip or not port:
    secret_path = Path("../../../config_secret.json")
    if secret_path.exists():
        with open(secret_path, "r", encoding="utf-8") as f:
            secret = json.load(f)
        ip = ip or secret.get("ip_TF")
        port = port or secret.get("port_TF")

# os.getenv always returns a string; normalise the port to an int so the
# client receives a numeric port whichever source supplied it.
port = int(port) if port else None

if not ip or port is None:
    logging.warning(
        "Microscope IP/port not configured: set MICROSCOPE_IP and MICROSCOPE_PORT "
        "or provide ip_TF/port_TF in %s",
        "../../../config_secret.json",
    )

config = {
    "ip": ip,
    "port": port,
    # HAADF dwell time per pixel. NOTE(review): presumably seconds
    # (40e-8 s = 0.4 us per pixel) -- confirm against TFacquisition.acquire_haadf.
    "haadf_exposure": 40e-8,
    "haadf_resolution": 512,  # square image: pixels per side
    "out_path": ".",
}

In [None]:

# Pull the settings used throughout the rest of the notebook out of `config`.
ip, port = config["ip"], config["port"]
haadf_exposure = config["haadf_exposure"]
haadf_resolution = config["haadf_resolution"]
out_path = config["out_path"]

# Open the AutoScript session (the server listens on 7521 on the Velox computer).
microscope = TemMicroscopeClient()
microscope.connect(ip, port=port)

# High-level acquisition helper wrapping the connected client.
tf_acquisition = TFacquisition(microscope=microscope)

# Optional: put the beam shift back to (0, 0) before starting.
# tf_acquisition.move_beam_shift_positon([0, 0])


## 1c. Get HAADF image

In [None]:
# Acquire the overview HAADF image; the helper hands back the image array and
# the TIFF file name, which is then also converted to PNG for quick viewing.
haadf_np_array, haadf_tiff_name = tf_acquisition.acquire_haadf(
    exposure=haadf_exposure,
    resolution=haadf_resolution,
)
HAADF_tiff_to_png(haadf_tiff_name)

## 1d. Find particles using segmentation

In [None]:
########SAM part ********************************************************************************************************

from stemOrchestrator.MLlayer.MLlayerSAM import setup_device, download_sam_model, initialize_sam_model, preprocess_image, generate_and_save_masks, create_normalized_particle_positions, display_image_with_masks, display_image_with_labels, extract_mask_contours, generate_mask_colors, visualize_masks_with_boundaries, extract_particle_data, print_boundary_points_info, plot_centroids, sample_particle_positions, plot_sampled_positions, create_visualization_with_masks
import pickle
import numpy as np
from typing import List, Dict, Union


def run_sam(image_data: np.ndarray, path_folder: str, model_type: str = "vit_b") -> tuple:
    """Run the SAM segmentation pipeline on an image and return particle positions.

    Downloads (if needed) and loads a Segment Anything checkpoint, generates
    masks for ``image_data``, shows several diagnostic figures, writes the
    masks pickle and a boundary figure to ``path_folder``, and finally builds
    per-particle positions with ``create_normalized_particle_positions``
    (using the image height/width).

    Args:
        image_data: Image to segment (e.g. the HAADF overview).
        path_folder: Folder where the masks pickle and boundary figure are written.
        model_type: SAM backbone: ``'vit_b'`` (default), ``'vit_l'`` or ``'vit_h'``.

    Returns:
        ``(all_particle_keys, each_particle_position)``: the keys view of the
        particle-position dictionary and the dictionary itself.

    Raises:
        ValueError: If ``model_type`` is not a supported backbone.
    """
    import os

    # Official Segment Anything checkpoint files, keyed by backbone.
    checkpoint_files = {
        "vit_b": "sam_vit_b_01ec64.pth",
        "vit_l": "sam_vit_l_0b3195.pth",
        "vit_h": "sam_vit_h_4b8939.pth",
    }
    if model_type not in checkpoint_files:
        raise ValueError(
            f"Unsupported SAM model_type {model_type!r}; "
            f"choose one of {sorted(checkpoint_files)}"
        )
    checkpoint_path = checkpoint_files[model_type]
    checkpoint_url = f"https://dl.fbaipublicfiles.com/segment_anything/{checkpoint_path}"

    device = setup_device()
    download_sam_model(model_type, checkpoint_url, checkpoint_path)
    # Only the mask generator is needed; the model object itself is unused here.
    _sam, mask_generator = initialize_sam_model(model_type, checkpoint_path, device)
    img_np = preprocess_image(image_data)

    plt.figure(figsize=(8, 8))
    plt.imshow(img_np)
    plt.title("Original Image")
    plt.axis('off')
    plt.show()

    # Generate masks (saved to masks_path) and overlay them on the image.
    masks_path = os.path.join(path_folder, "masks_Au_online.pkl")
    masks = generate_and_save_masks(mask_generator, img_np, masks_path)
    visual_image, centroids = create_visualization_with_masks(img_np, masks)
    display_image_with_masks(visual_image, "Image with Segmentation Masks")
    display_image_with_labels(visual_image, centroids, "Image with Segmentation Masks and Labels")

    mask_contours = extract_mask_contours(masks)
    mask_colors = generate_mask_colors(len(masks))
    boundaries_path = os.path.join(path_folder, "Segmentation Masks with Boundaries and Centroids.png")
    visualize_masks_with_boundaries(visual_image, centroids, mask_contours, mask_colors, boundaries_path)

    particles = extract_particle_data(masks)
    print_boundary_points_info(particles)
    plot_centroids(np.array(centroids), img_np)
    positions_sampled = sample_particle_positions(particles, img_np)
    plot_sampled_positions(positions_sampled, img_np, len(centroids))

    each_particle_position = create_normalized_particle_positions(particles, img_np.shape[:2])

    print("Processing complete!")
    return each_particle_position.keys(), each_particle_position

##########****************************************************************************************************************************


In [None]:
# Segment the HAADF overview to locate the particles.
all_particle_keys, each_particle_position = run_sam(
    image_data=haadf_np_array,
    path_folder=out_path,
)




In [None]:
# filter particles based on need

## 1e. Acquire a test EDX spectrum at the center of the HAADF overview

In [None]:
import xmltodict
import json
from autoscript_tem_microscope_client.structures import EdsAcquisitionSettings
from autoscript_tem_microscope_client.enumerations import  EdsDetectorType, ExposureTimeType


def get_channel_index(energy_keV: float, dispersion: float, offset: float) -> int:
    """Convert an energy into the nearest spectrum channel index.

    ``energy_keV``, ``dispersion`` (energy per channel) and ``offset`` (energy
    of channel 0) must all use the same energy unit; the parameter name assumes
    keV. NOTE(review): if the calibration comes from spectrum metadata stored
    in eV, convert one side first.

    Raises:
        ValueError: If ``dispersion`` is not positive (the mapping would be
            undefined or reversed).
    """
    if dispersion <= 0:
        raise ValueError(f"dispersion must be positive, got {dispersion!r}")
    # Python's round() rounds halves to even.
    return int(round((energy_keV - offset) / dispersion))



def get_dispersion_and_offset(spectrum):
    """Read the energy calibration of an EDS spectrum from its XML metadata.

    Parses ``spectrum.metadata.metadata_as_xml`` and reads
    ``Metadata/Detectors/AnalyticalDetector``; when several detectors are
    listed, the first one is used. Missing ``Dispersion`` / ``OffsetEnergy``
    entries fall back to 0.

    Returns:
        ``(dispersion, offset)`` as floats, in the unit stored in the metadata
        (presumably eV: the notebook divides by 1000 to obtain keV -- confirm).
    """
    parsed = json.loads(json.dumps(xmltodict.parse(spectrum.metadata.metadata_as_xml)))
    analytical = parsed["Metadata"]["Detectors"]["AnalyticalDetector"]

    # xmltodict yields a dict for a single element and a list for repeated ones.
    first = analytical if isinstance(analytical, dict) else analytical[0]

    disp = float(first.get("Dispersion", 0))
    off = float(first.get("OffsetEnergy", 0))
    return disp, off

def configure_acquisition(exposure_time=2):
    """Build EDS acquisition settings for the first EDS detector.

    Uses the module-level ``microscope`` client. The settings object is also
    bound to the module-level name ``eds_settings`` so later cells can reuse it.

    Args:
        exposure_time: Exposure (seconds in this notebook), interpreted as live time.

    Returns:
        The configured ``EdsAcquisitionSettings``.
    """
    global eds_settings

    detector_name = microscope.detectors.eds_detectors[0]
    detector = microscope.detectors.get_eds_detector(detector_name)

    eds_settings = settings = EdsAcquisitionSettings()
    settings.eds_detector = detector_name
    # Use the last available dispersion / shaping time; the original author
    # notes the last dispersion corresponds to a 20 keV range -- confirm.
    settings.dispersion = detector.dispersions[-1]
    settings.shaping_time = detector.shaping_times[-1]
    settings.exposure_time = exposure_time
    settings.exposure_time_type = ExposureTimeType.LIVE_TIME
    return settings

In [None]:
# position beam at center
microscope.optics.paused_scan_beam_position = (0.5,0.5) 

In [None]:
# Acquire a test EDS spectrum at the parked beam position.
edx_exposure = 1  # in seconds
eds_settings = configure_acquisition(exposure_time=edx_exposure)

microscope.optics.unblank()
try:
    spectrum = microscope.analysis.eds.acquire_spectrum(eds_settings)
finally:
    # Re-blank even if the acquisition fails, so the beam is not left on the sample.
    microscope.optics.blank()

# spectrum.data holds the spectra of the 4 detectors back to back; sum them
# channel by channel into a single spectrum.
n_detectors = 4
raw_counts = np.asarray(spectrum.data)
n_channels_per_detector = len(raw_counts) // n_detectors
spectrum_data = (
    raw_counts[: n_detectors * n_channels_per_detector]
    .reshape(n_detectors, n_channels_per_detector)
    .sum(axis=0, dtype=float)
)

# Energy axis from the spectrum calibration, divided by 1000 to get keV.
dispersion, offset = get_dispersion_and_offset(spectrum)
energy_axis = (np.arange(len(spectrum_data)) * dispersion + offset) / 1000

plt.figure(figsize=(12, 6))
plt.plot(energy_axis, spectrum_data)
plt.xlabel('Energy (keV)')
plt.ylabel('Counts')
plt.title('EDS Spectrum (Summed from 4 Detectors)')
plt.xlim(0, 20)  # focus on the physically relevant energy range
plt.show()

## 1f. Let's calibrate the gun lens and increase the current for a stronger EDX signal
- Gun-lens vs. screen-current calibration
- Credits: Austin Houston

- Move to a hole in the sample and park the beam there; we want to measure how the beam current varies with the gun-lens setting.



In [None]:
# Remember the gun-lens (monochromator focus) value so it can be restored after the sweep.
original_gun_lens = microscope.optics.monochromator.focus
print(original_gun_lens)

In [None]:
import time

# Sweep the gun lens over a fixed range and record the screen current at each
# setting (the beam should be parked over a hole for this measurement).
gun_lens_series = np.linspace(5, 100, 20)
settle_time_s = 1  # wait after each change before measuring
current_series = []

try:
    for val in gun_lens_series:
        # Absolute gun-lens value (not an offset from original_gun_lens).
        microscope.optics.monochromator.focus = val
        time.sleep(settle_time_s)
        current_series.append(microscope.detectors.screen.measure_current())
finally:
    # Restore the starting value even if the sweep fails or is interrupted.
    microscope.optics.monochromator.focus = original_gun_lens

# Scale by 1e12 to picoamperes (assumes measure_current() returns amperes -- confirm).
current_series = np.array(current_series) * 1e12

In [None]:
# Fit a polynomial calibration curve: screen current (pA) vs gun-lens value.
# The degree cannot exceed (number of points - 1); otherwise the least-squares
# problem is underdetermined and np.polyfit returns an unreliable fit.
degree = min(11, len(gun_lens_series) - 1)
coeffs = np.polyfit(gun_lens_series, current_series, degree)
poly_func = np.poly1d(coeffs)

# Evaluate the fit densely over the measured range only.
x_fit = np.linspace(min(gun_lens_series), max(gun_lens_series), 500)
y_fit = poly_func(x_fit)

plt.figure()
plt.plot(x_fit, y_fit, label=f'degree-{degree} fit')
plt.scatter(gun_lens_series, current_series, marker='X', c='r', label='measured')
plt.xlabel('Gun Lens Value')
plt.ylabel('Screen Current (pA)')
plt.legend()
plt.show()

In [None]:
def current_to_gun(desired_current, poly_func, bounds=None):
    """Invert the current-vs-gun-lens calibration polynomial.

    Solves ``poly_func(x) == desired_current`` for ``x`` and returns the largest
    real solution.

    Args:
        desired_current: Target current, in the units of the fitted data
            (pA in this notebook).
        poly_func: ``np.poly1d`` mapping gun-lens value to current.
        bounds: Optional ``(low, high)`` gun-lens range. When given, only roots
            inside it are considered; passing the calibrated sweep range avoids
            picking a root where the polynomial is extrapolating.

    Returns:
        The gun-lens value as a Python ``float``.

    Raises:
        ValueError: If no real solution exists (inside ``bounds``, if given).
    """
    # Roots of poly_func(x) - desired_current = 0; keep only the real ones.
    roots = (poly_func - desired_current).r
    x_real = roots[np.isreal(roots)].real

    if bounds is not None:
        low, high = bounds
        x_real = x_real[(x_real >= low) & (x_real <= high)]

    if x_real.size == 0:
        where = f" within {bounds}" if bounds is not None else ""
        raise ValueError(
            f"No real gun-lens value gives a current of {desired_current}{where}"
        )

    # Several crossings can exist; by convention take the largest one.
    return float(np.max(x_real))

def set_current(desired_current, calibration=None):
    """Set the gun lens so the beam current matches ``desired_current``.

    Args:
        desired_current: Target current (pA in this notebook).
        calibration: Calibration ``np.poly1d``; defaults to the module-level
            ``poly_func``.

    Returns:
        The gun-lens value that was applied, as a float.
    """
    poly = poly_func if calibration is None else calibration
    # Cast to a plain float: the inversion can yield a NumPy array or scalar,
    # and the notebook sets the focus with a Python float elsewhere.
    gun_value = float(current_to_gun(desired_current, poly))
    microscope.optics.monochromator.focus = gun_value
    return gun_value

In [None]:
# Example: pick the gun-lens value that gives a 60 pA screen current for imaging.
desired_current = 60  # pA
gun_val = current_to_gun(
    desired_current=desired_current,
    poly_func=poly_func,
)
# Apply it, cast to a plain Python float.
microscope.optics.monochromator.focus = float(gun_val)

print(f'Set to: {desired_current} with gun value: {gun_val}')

In [None]:
# microscope.optics.paused_scan_beam_position = (0,0) # park the beam in the corner

## 1g. Acquire EDX at the particle centers

In [None]:

def acquire_and_plot_combined(image_data, particle_key, *, n_detectors=4):
    """Acquire an EDS spectrum at the current beam position and plot it.

    The beam is unblanked only for the duration of the acquisition. The left
    panel shows ``image_data`` with the current paused-beam position marked;
    the right panel shows the spectrum summed over all detectors.

    Args:
        image_data: 2-D overview image (e.g. HAADF) the beam position refers to.
        particle_key: Identifier of the particle; used only in the log line.
        n_detectors: Number of detector spectra stored back to back in
            ``spectrum.data`` (4 on this instrument).
    """
    print(f"edx at centers for particle{particle_key}")

    microscope.optics.unblank()
    try:
        spectrum = microscope.analysis.eds.acquire_spectrum(eds_settings)
    finally:
        # Re-blank even if the acquisition fails, so the beam is not left on the sample.
        microscope.optics.blank()

    # Sum the per-detector spectra channel by channel.
    raw_counts = np.asarray(spectrum.data)
    n_channels = len(raw_counts) // n_detectors
    spectrum_data = (
        raw_counts[: n_detectors * n_channels]
        .reshape(n_detectors, n_channels)
        .sum(axis=0, dtype=float)
    )

    # Energy axis from the spectrum calibration, divided by 1000 to get keV.
    dispersion, offset = get_dispersion_and_offset(spectrum)
    energy_axis = (np.arange(n_channels) * dispersion + offset) / 1000

    position = tf_acquisition.query_paused_beam_positon()
    x, y = position.x, position.y
    formatted_position = f"({x:.2g}, {y:.2g})"

    fig, (ax_img, ax_spec) = plt.subplots(1, 2, figsize=(18, 6))

    ax_img.imshow(image_data, cmap='gray')
    ax_img.set_title('Acquired Image')
    ax_img.set_axis_off()
    # Scale the position to pixels: in imshow, x runs along columns (shape[1])
    # and y along rows (shape[0]).
    ax_img.scatter(x * image_data.shape[1], y * image_data.shape[0], c='r', s=100,
                   marker='x', label=f"Position: {formatted_position}")
    ax_img.legend()

    ax_spec.plot(energy_axis, spectrum_data)
    ax_spec.set_title('EDS Spectrum (Summed from 4 Detectors)')
    ax_spec.set_xlabel('Energy (keV)')
    ax_spec.set_ylabel('Counts')
    ax_spec.set_xlim(0, 20)

    plt.tight_layout()
    plt.show()


    
        
def run_acquisition_for_particle(image_data, particle_key, particle_dict):
    """Move the paused beam to a particle's centroid and acquire/plot EDS there.

    Args:
        image_data: Overview image shown next to the spectrum.
        particle_key: Key of the particle in ``particle_dict``.
        particle_dict: Mapping ``key -> {"centroid": (x, y), ...}``; the
            centroid is passed straight to ``move_paused_beam`` (presumably
            normalised coordinates -- confirm).
    """
    centroid = list(particle_dict[particle_key]["centroid"])
    tf_acquisition.move_paused_beam(centroid[0], centroid[1])
    # The EDX plotting routine only needs the image and the particle key;
    # passing particle_dict as a third argument raises TypeError.
    acquire_and_plot_combined(image_data, particle_key)

In [None]:
# Visit every segmented particle and record an EDX spectrum at its centroid.
for particle_key in all_particle_keys:
    run_acquisition_for_particle(
        image_data=haadf_np_array,
        particle_key=particle_key,
        particle_dict=each_particle_position,
    )

## 1h. Acquire diffraction patterns at the particle centers

In [None]:


def acquire_and_plot_combined(image_data, particle_key, ceta_exposure, *,
                              percentile=99, gamma=1, edge_crop=256):
    """Acquire a CETA diffraction image at the current beam position and plot it.

    The acquired frame is converted to PNG via ``tiff_to_png``. For display,
    the frame is clipped at the given intensity percentile, rescaled to [0, 1],
    raised to ``gamma`` and cropped by ``edge_crop`` pixels on every side. The
    left panel shows ``image_data`` with the current paused-beam position marked.

    Args:
        image_data: 2-D overview image (e.g. HAADF) the beam position refers to.
        particle_key: Identifier of the particle; used only in the log line.
        ceta_exposure: CETA exposure time (seconds in this notebook).
        percentile: Intensity percentile above which pixels are clipped.
        gamma: Power applied to the normalised frame (1 = linear).
        edge_crop: Pixels removed from each border before display; skipped if
            the frame is too small to crop that much.
    """
    print(f"cbed at centres for particle{particle_key}")
    ceta_cp_array, ceta_tiff_name = tf_acquisition.acquire_ceta_or_flucam(
        exposure=ceta_exposure, resolution=1024, camera="ceta")
    tiff_to_png(ceta_tiff_name)

    # Work on a float copy: camera frames may be integer-typed, and in-place
    # division of an integer array raises a casting error.
    frame = ceta_cp_array.astype(float)

    # Clip the brightest pixels (presumably the direct beam) so weaker features stay visible.
    upper = np.percentile(frame.ravel(), percentile)
    norm_data = np.clip(frame, 0, upper)
    norm_data -= norm_data.min()
    peak = norm_data.max()
    if peak > 0:  # a flat frame would otherwise become all-NaN
        norm_data /= peak
    norm_data = norm_data ** gamma

    rows, cols = norm_data.shape[:2]
    if edge_crop > 0 and 2 * edge_crop < min(rows, cols):
        norm_data = norm_data[edge_crop:-edge_crop, edge_crop:-edge_crop]

    position = tf_acquisition.query_paused_beam_positon()
    x, y = position.x, position.y
    formatted_position = f"({x:.2g}, {y:.2g})"

    fig, (ax_img, ax_dp) = plt.subplots(1, 2, figsize=(18, 6))

    ax_img.imshow(image_data, cmap='gray')
    ax_img.set_title('Acquired Image')
    ax_img.set_axis_off()
    # Scale the position to pixels: in imshow, x runs along columns (shape[1])
    # and y along rows (shape[0]).
    ax_img.scatter(x * image_data.shape[1], y * image_data.shape[0], c='r', s=100,
                   marker='x', label=f"Position: {formatted_position}")
    ax_img.legend()

    # Percentile-clipped, normalised frame (linear when gamma == 1).
    ax_dp.imshow(norm_data, cmap='gray')
    ax_dp.set_title(f'Acquired CETA Image at Position: {formatted_position}')
    ax_dp.set_axis_off()

    plt.tight_layout()
    plt.show()



def run_acquisition_for_particle(image_data, particle_key, particle_dict, ceta_exposure):
    """Park the beam on one particle's centroid, then record and plot a CETA frame.

    Args:
        image_data: Overview image shown next to the diffraction pattern.
        particle_key: Key of the particle in ``particle_dict``.
        particle_dict: Mapping ``key -> {"centroid": (x, y), ...}``.
        ceta_exposure: Exposure forwarded to ``acquire_and_plot_combined``.
    """
    centroid = list(particle_dict[particle_key]["centroid"])
    beam_x, beam_y = centroid[0], centroid[1]
    tf_acquisition.move_paused_beam(beam_x, beam_y)
    acquire_and_plot_combined(image_data, particle_key, ceta_exposure)




    

In [None]:

ceta_exposure = 0.1  # seconds per CETA frame

# Visit each particle and record a diffraction pattern at its centroid.
# Timing observed by the authors: ~25 s per reading (about 1 min 35 s for 4 points).
for particle_key in all_particle_keys:
    run_acquisition_for_particle(
        haadf_np_array,
        particle_key,
        each_particle_position,
        ceta_exposure,
    )

##