## Segment Anything Model (SAM) - EDX: live workflow
- Acquire an image using HAADF detector
- Feed HAADF image in SAM pipeline:
        - get particles with segmented masks
        - find center point of each particle
- Acquire EDX detector signals at center of each particle
#### Contributor(s): Utkarsh Pratiush <utkarshp1161@gmail.com> - 2nd May 2025
#### edited - 
   

In [1]:
from stemOrchestrator.logging_config   import setup_logging
data_folder  = "."
out_path = data_folder
setup_logging(out_path=out_path) 

In [2]:
from stemOrchestrator.acquisition import TFacquisition, DMacquisition
from stemOrchestrator.simulation import DMtwin
from autoscript_tem_microscope_client.enumerations import EdsDetectorType
from stemOrchestrator.process import HAADF_tiff_to_png, tiff_to_png
from autoscript_tem_microscope_client import TemMicroscopeClient
import matplotlib.pyplot as plt
import logging
plot = plt
from typing import Dict
import os

In [3]:
########SAM part ********************************************************************************************************

from stemOrchestrator.MLlayer.MLlayerSAM import setup_device, download_sam_model, initialize_sam_model, preprocess_image, generate_and_save_masks, create_normalized_particle_positions, display_image_with_masks, display_image_with_labels, extract_mask_contours, generate_mask_colors, visualize_masks_with_boundaries, extract_particle_data, print_boundary_points_info, plot_centroids, sample_particle_positions, plot_sampled_positions, create_visualization_with_masks
import pickle
import numpy as np
from numpy.typing import NDArray 
from typing import List, Dict, Union


def run_sam(image_data: np.ndarray, path_folder: str) -> Union[List, Dict]:
    """Main function to run SAM segmentation pipeline."""
    device = setup_device()
    
    model_type = "vit_b"  # Options: 'vit_b', 'vit_l', 'vit_h'
    checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
    checkpoint_path = "sam_vit_b_01ec64.pth"
    download_sam_model(model_type, checkpoint_url, checkpoint_path)
    sam, mask_generator = initialize_sam_model(model_type, checkpoint_path, device)
    img_np = preprocess_image(image_data)
    
    plt.figure(figsize=(8, 8))
    plt.imshow(img_np)
    plt.title("Original Image")
    plt.axis('off')
    plt.show()
    
    # Generate and visualize masks
    masks_path = f'{path_folder}/masks_Au_online.pkl'
    masks = generate_and_save_masks(mask_generator, img_np, masks_path)
    visual_image, centroids = create_visualization_with_masks(img_np, masks)
    display_image_with_masks(visual_image, "Image with Segmentation Masks")
    display_image_with_labels(visual_image, centroids, "Image with Segmentation Masks and Labels")
    
    mask_contours = extract_mask_contours(masks)
    mask_colors = generate_mask_colors(len(masks))
    boundaries_path = f"{path_folder}/Segmentation Masks with Boundaries and Centroids.png"
    visualize_masks_with_boundaries(visual_image, centroids, mask_contours, mask_colors, boundaries_path)
    particles = extract_particle_data(masks)
    # Save particle data
    with open(f'{path_folder}/particles.pkl', 'wb') as f:
        pickle.dump(particles, f)
    
    print_boundary_points_info(particles)
    centroids_array = np.array(centroids)
    plot_centroids(centroids_array, img_np)
    positions_sampled = sample_particle_positions(particles, img_np)
    plot_sampled_positions(positions_sampled, img_np, len(centroids))
    each_particle_position = create_normalized_particle_positions(particles, img_np.shape[:2])
    with open(f'{path_folder}/sampled_boundary_pts_particles.pkl', 'wb') as f: # Save normalized particle positions
        pickle.dump(each_particle_position, f)
    
    all_particle_keys = each_particle_position.keys()


    
    print("Processing complete!")
    return all_particle_keys, each_particle_position

##########****************************************************************************************************************************


In [4]:
def main(config :Dict) -> None:
    ip = config["ip"]
    port = config["port"]
    haadf_exposure = config["haadf_exposure"]
    haadf_resolution = config["haadf_resolution"]
    out_path = config["out_path"]
    microscope_tf = TemMicroscopeClient()
    
    
    global tf_acquisition
    tf_acquisition = TFacquisition(microscope=microscope_tf, offline=True)

    # Get haadf
    haadf_np_array, haadf_tiff_name = tf_acquisition.acquire_haadf(exposure = haadf_exposure, resolution=haadf_resolution)

    HAADF_tiff_to_png(haadf_tiff_name)
    logging.info("END acquisition.")
    
    all_particle_keys, each_particle_position = run_sam(haadf_np_array, out_path)
    
    

    def acquire_and_plot_combined(image_data, eds_detector_name, particle_key, particle_dict, edx_exposure):
        """Acquire EDS spectrum, CETA image, and plot all three components in a single figure."""
        
        
        # lets do just eels at boundary points
        print(f"edx at centers for particle{particle_key}")
        # Ensure the directory exists
        directory = f'{out_path}/particle{particle_key}'
        if not os.path.exists(directory):
            os.makedirs(directory)


        # Acquire the EDS spectrum
        settings = tf_acquisition.configure_eds_settings(eds_detector_name, dispersion=5, shaping_time=3e-6, exposure_time=edx_exposure)
        tf_acquisition.unblank_beam()
        spec = tf_acquisition.acquire_eds(settings)
        tf_acquisition.blank_beam()


        # Get the current beam position
        position = tf_acquisition.query_paused_beam_positon()
        x = position.x
        y = position.y
        formatted_position = f"({x:.2g}, {y:.2g})"

        # Create a figure with three subplots
        fig, axs = plt.subplots(1, 2, figsize=(18, 6))
        axs[0].imshow(image_data, cmap='gray')
        axs[0].set_title('Acquired Image')
        axs[0].set_axis_off()  # Hide axes for the image plot
        axs[0].scatter(x * image_data.shape[0], y * image_data.shape[1], c='r', s=100, marker='x', label=f"Position: {formatted_position}")

        # Plot the EDS spectrum
        axs[1].plot(np.arange(len(spec)) * 5 / 1000, spec, label="EDS Spectrum")
        axs[1].set_title('Acquired EDS Spectrum')
        axs[1].set_xlabel('Channel (KeV)')
        axs[1].set_ylabel('Counts')
        axs[1].legend()

        plt.tight_layout()
        plt.savefig(f'{out_path}/particle{particle_key}/haadf_edx_at_centroid of particle{particle_key}_{formatted_position}.png', dpi = 300)
        plt.close()
        np.save(f'{out_path}/particle{particle_key}/EDX_spectrum at position {formatted_position}.npy', spec)# eds spectrum save


        
    def run_acquisition_for_particle(image_data, particle_key, particle_dict, edx_exposure):
            """Run the entire acquisition process for a given position."""
            position = list(particle_dict[particle_key]["centroid"])
            tf_acquisition.move_paused_beam(position[0], position[1])
            acquire_and_plot_combined(image_data, eds_detector_name, particle_key, particle_dict, edx_exposure)

    spec_metadata_dict = {"edx_exposure_sec" : 0.05}

    edx_exposure = spec_metadata_dict["edx_exposure_sec"]

    file_path = f'{out_path}/spectrometer_metadata_dict.pkl'
    # Save the dictionary to disk
    with open(file_path, 'wb') as file:
        pickle.dump(spec_metadata_dict, file)
    print(f"spectrometer_metadata_dictsaved to {file_path}")
    eds_detector_name = EdsDetectorType.SUPER_X
    for particle_key in all_particle_keys:
        run_acquisition_for_particle(haadf_np_array, eds_detector_name, particle_key, each_particle_position, edx_exposure)

    

In [5]:
import os
import json
from pathlib import Path

# Initialize to None
ip = os.getenv("MICROSCOPE_IP")
port = os.getenv("MICROSCOPE_PORT")


if not ip or not port:
    secret_path = Path("../../config_secret.json")
    if secret_path.exists():
        with open(secret_path, "r") as f:
            secret = json.load(f)
            ip = ip or secret.get("ip_TF")
            port = port or secret.get("port_TF")


if not ip:
    ip = input("Enter microscope IP: ")
if not port:
    port = input("Enter microscope Port: ")

try:
    port = int(port)
except ValueError:
    raise ValueError("Port must be an integer")

config = {
    "ip": ip,
    "port": port,
    "haadf_exposure": 40e-8,  # microseconds per pixel
    "haadf_resolution": 512,  # square
    "out_path": "."
}

main(config=config)

Failed to initialize detectors: Cannot perform a call because the client is not running.


ApiException: Cannot perform a call because the client is not running.