## Segment Anything Model (SAM) - DIFFRACTION(ceta): live workflow
- Acquire an image using the HAADF detector
- Feed the HAADF image into the SAM pipeline:
        - segment the particles into masks
        - find the center point (centroid) of each particle
- Acquire diffraction signals with the Ceta camera at the center of each particle
#### Contributor(s): Utkarsh Pratiush <utkarshp1161@gmail.com> - 31st March 2025
#### edited - 
   

In [None]:
from stemOrchestrator.acquisition import TFacquisition, DMacquisition
from TEM.stemOrchestrator.stemOrchestrator.simulation import DMtwin
from stemOrchestrator.process import HAADF_tiff_to_png
from autoscript_tem_microscope_client import TemMicroscopeClient
from stemOrchestrator.logging_config   import setup_logging
import matplotlib.pyplot as plt
import logging
plot = plt
from typing import Dict

In [None]:
########SAM part ********************************************************************************************************

from stemOrchestrator.MLlayerSAM import setup_device, download_sam_model, initialize_sam_model, preprocess_image, generate_and_save_masks, create_normalized_particle_positions, display_image_with_masks, display_image_with_labels, extract_mask_contours, generate_mask_colors, visualize_masks_with_boundaries, extract_particle_data, print_boundary_points_info, plot_centroids, sample_particle_positions, plot_sampled_positions, create_visualization_with_masks
import pickle
import numpy as np
from numpy.typing import NDArray 


# Released SAM checkpoints, keyed by model type (file names from the
# segment-anything README); all are served from the same base URL.
_SAM_CHECKPOINTS = {
    "vit_b": "sam_vit_b_01ec64.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_h": "sam_vit_h_4b8939.pth",
}
_SAM_CHECKPOINT_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything/"


def run_sam(image_data: np.ndarray, path_folder: str, *, model_type: str = "vit_b") -> NDArray:
    """Segment particles in an image with SAM and return their centroids.

    Downloads/loads the SAM checkpoint, generates masks for ``image_data``,
    shows intermediate plots, and writes the following into ``path_folder``:
    ``masks_Au_online.pkl`` (via ``generate_and_save_masks``), the boundary
    figure path passed to ``visualize_masks_with_boundaries``,
    ``particles.pkl`` and ``sampled_boundary_pts_particles.pkl``.

    Args:
        image_data: Image to segment (the HAADF frame in this notebook); it is
            passed through ``preprocess_image`` before segmentation.
        path_folder: Output directory. It is created if it does not exist.
        model_type: SAM backbone, one of ``"vit_b"``, ``"vit_l"``, ``"vit_h"``.
            Defaults to ``"vit_b"``, the previously hard-coded choice.

    Returns:
        ``np.array`` of the centroids returned by
        ``create_visualization_with_masks``. The coordinate convention is
        presumably pixel positions in ``img_np`` (TODO confirm in MLlayerSAM).

    Raises:
        ValueError: If ``model_type`` is not one of the known checkpoints.
    """
    import os

    if model_type not in _SAM_CHECKPOINTS:
        raise ValueError(
            f"Unknown SAM model_type {model_type!r}; "
            f"expected one of {sorted(_SAM_CHECKPOINTS)}"
        )
    # The pickles below are opened with open(..., 'wb'), which fails if the
    # folder is missing; create it up front.
    os.makedirs(path_folder, exist_ok=True)

    # --- Model setup -------------------------------------------------------
    device = setup_device()
    checkpoint_path = _SAM_CHECKPOINTS[model_type]
    checkpoint_url = _SAM_CHECKPOINT_BASE_URL + checkpoint_path
    download_sam_model(model_type, checkpoint_url, checkpoint_path)
    # Only the automatic mask generator is used; the model handle is unused.
    _sam, mask_generator = initialize_sam_model(model_type, checkpoint_path, device)
    img_np = preprocess_image(image_data)

    plt.figure(figsize=(8, 8))
    plt.imshow(img_np)
    plt.title("Original Image")
    plt.axis('off')
    plt.show()

    # --- Mask generation and visualisation ---------------------------------
    masks_path = f'{path_folder}/masks_Au_online.pkl'
    masks = generate_and_save_masks(mask_generator, img_np, masks_path)
    visual_image, centroids = create_visualization_with_masks(img_np, masks)
    display_image_with_masks(visual_image, "Image with Segmentation Masks")
    display_image_with_labels(visual_image, centroids, "Image with Segmentation Masks and Labels")

    mask_contours = extract_mask_contours(masks)
    mask_colors = generate_mask_colors(len(masks))
    boundaries_path = f"{path_folder}/Segmentation Masks with Boundaries and Centroids.png"
    visualize_masks_with_boundaries(visual_image, centroids, mask_contours, mask_colors, boundaries_path)

    # --- Per-particle data --------------------------------------------------
    particles = extract_particle_data(masks)
    with open(f'{path_folder}/particles.pkl', 'wb') as f:
        pickle.dump(particles, f)

    print_boundary_points_info(particles)
    centroids_array = np.array(centroids)
    plot_centroids(centroids_array, img_np)
    positions_sampled = sample_particle_positions(particles, img_np)
    plot_sampled_positions(positions_sampled, img_np, len(centroids))
    # Particle positions normalised by the image height/width (img_np.shape[:2]).
    each_particle_position = create_normalized_particle_positions(particles, img_np.shape[:2])
    with open(f'{path_folder}/sampled_boundary_pts_particles.pkl', 'wb') as f:
        pickle.dump(each_particle_position, f)

    print("Processing complete!")
    return centroids_array

##########****************************************************************************************************************************


In [None]:
def main(config: Dict) -> None:
    """Run the HAADF -> SAM -> Ceta diffraction workflow.

    Connects to the TF microscope, acquires one HAADF image, converts it to
    PNG, segments it with ``run_sam`` and then triggers one Ceta acquisition
    per particle centroid.

    Args:
        config: Required keys ``"ip"``, ``"port"``, ``"haadf_exposure"``,
            ``"haadf_resolution"`` and ``"out_path"``. Optional keys
            ``"ceta_exposure"`` (default ``0.1``) and ``"ceta_resolution"``
            (default ``64``) control the per-particle Ceta acquisition.

    Side effects:
        Binds the module-level global ``tf_acquisition`` so that later
        notebook cells can reuse the acquisition object.
    """
    ip = config["ip"]
    port = config["port"]
    haadf_exposure = config["haadf_exposure"]
    haadf_resolution = config["haadf_resolution"]
    out_path = config["out_path"]
    # Defaults reproduce the values that used to be hard-coded in the loop.
    ceta_exposure = config.get("ceta_exposure", 0.1)
    ceta_resolution = config.get("ceta_resolution", 64)
    setup_logging(out_path=out_path)

    microscope_tf = TemMicroscopeClient()
    microscope_tf.connect(ip, port=port)  # 7521 on velox computer
    # DM twin is instantiated but not used while DMacquisition stays disabled.
    microscope_dm = DMtwin()

    global tf_acquisition
    # NOTE(review): offline=True is passed even though a live connection was
    # just opened; confirm this is intended for the live workflow.
    tf_acquisition = TFacquisition(microscope=microscope_tf, offline=True)
    # dm_acquisition = DMacquisition(microscope=microscope_dm, offline=True)

    # Acquire the HAADF survey image that SAM segments.
    haadf_np_array, haadf_tiff_name = tf_acquisition.acquire_haadf(
        exposure=haadf_exposure, resolution=haadf_resolution
    )
    HAADF_tiff_to_png(haadf_tiff_name)
    logging.info("END acquisition.")

    centroids_array = run_sam(haadf_np_array, out_path)
    n_points = len(centroids_array)
    if n_points == 0:
        logging.warning("SAM found no particles; skipping Ceta acquisition.")
        return

    for index, point in enumerate(centroids_array, start=1):
        # TODO(review): the beam is never moved to `point` before acquiring, so
        # every iteration records the same location. Position the beam at the
        # centroid (pixel -> scan coordinates) once the TFacquisition API for
        # that is confirmed.
        logging.info("Ceta acquisition %d/%d at centroid %s", index, n_points, point)
        _ceta_array, ceta_tiff_name = tf_acquisition.acquire_ceta(
            exposure=ceta_exposure, resolution=ceta_resolution
        )
        logging.info("Ceta acquisition %d/%d returned %s", index, n_points, ceta_tiff_name)
    

## Connect to the microscope and run the workflow

In [None]:
import os
import json
from pathlib import Path

# Resolve connection settings in priority order:
# 1. MICROSCOPE_IP / MICROSCOPE_PORT environment variables,
# 2. ../../config_secret.json (relative to the current working directory),
# 3. an interactive prompt.
ip = os.getenv("MICROSCOPE_IP")
port = os.getenv("MICROSCOPE_PORT")

if not ip or not port:
    secret_path = Path("../../config_secret.json")
    if secret_path.exists():
        with open(secret_path, "r", encoding="utf-8") as f:
            secret = json.load(f)
        ip = ip or secret.get("ip_TF")
        port = port or secret.get("port_TF")

if not ip:
    ip = input("Enter microscope IP: ")
if not port:
    port = input("Enter microscope Port: ")

try:
    port = int(port)
except (TypeError, ValueError) as err:
    # Chain the original error so the offending value stays in the traceback.
    raise ValueError("Port must be an integer") from err

config = {
    "ip": ip,
    "port": port,
    # Per-pixel dwell time. 40e-8 (= 0.4 µs) only makes sense in seconds, not
    # microseconds. TODO confirm the unit expected by acquire_haadf.
    "haadf_exposure": 40e-8,
    "haadf_resolution": 512,  # square image, pixels per side
    "out_path": ".",
}

main(config=config)


In [None]:
# Insert the Ceta camera, then take one high-resolution acquisition.
# The original line referenced `insert` without calling it, which did nothing.
# NOTE(review): assumes `ceta_cam.insert` is a zero-argument method; confirm.
tf_acquisition.ceta_cam.insert()
ceta_cp_array, ceta_tiff_name = tf_acquisition.acquire_ceta(exposure=0.1, resolution=1024)
