# Segmentation trial acquisitions (calcium imaging)
* Segmentation of (motion corrected) images


In [11]:
import numpy as np
from cellpose import models
import napari
import os
from scripts.sample_db import SampleDB
from tifffile import imwrite, imread
import glob
from itertools import product

# Load the sample database
db_path = r'\\tungsten-nas.fmi.ch\tungsten\scratch\gfriedri\montruth\sample_db.csv'
sample_db = SampleDB()
sample_db.load(db_path)
print(sample_db)

SampleDB(sample_ids=['20220426_RM0008_130hpf_fP1_f3', '20220118_RM0012_124hpf_fP8_f2', '20220427_RM0008_126hpf_fP3_f3'])


In [12]:
# Loading experiment
sample_id = '20220427_RM0008_126hpf_fP3_f3'
exp = sample_db.get_sample(sample_id)
print(exp.sample.id)

# Import model
model_path = r'D:\montruth\cellpose\models\CP_20230803_101131' 
model = models.CellposeModel(model_type=model_path, gpu=True)

# Making shortcuts of sample parameters/information
sample = exp.sample
root_path = exp.paths.root_path
trials_path = exp.paths.trials_path

n_planes = exp.params_lm.n_planes
n_frames = exp.params_lm.n_frames
n_slices = exp.params_lm.lm_stack_range
n_trials = exp.params_lm.n_trials
doubling = 2 if exp.params_lm.doubling else 1

# Getting paths of the trial acquisitions
trial_paths = os.listdir(trials_path)

# Define the path for the preprocessed folder
processed_folder = os.path.join(trials_path, 'processed')
os.makedirs(processed_folder, exist_ok=True)

# Define the path for the masks folder
masks_folder = os.path.join(trials_path, "masks")
os.makedirs(masks_folder, exist_ok=True)

# Segment all trials per plane
all_masks = []
images_path = glob.glob(os.path.join(processed_folder, 'sum_elastic_*.tif'))[0]
mask_plane_path = os.path.join(masks_folder, f"masks_{exp.sample.id}.tif")
images_stack = imread(images_path)



20220427_RM0008_126hpf_fP3_f3


In [13]:
# Define parameter ranges
cellprob_threshold_range = [-3]
flow_threshold_range = [0.1]
resample_options = [True]
augment_options = [False]
stitch_threshold_range = [0.01]

# Prepare the output array
mask_stack = np.zeros_like(images_stack)
print(mask_stack.shape)

# Generate all combinations of parameters
parameter_combinations = list(product(cellprob_threshold_range, flow_threshold_range, resample_options, augment_options,stitch_threshold_range))
print(f"Number of combinations to test: {len(parameter_combinations)}")


(8, 24, 256, 512)
Number of combinations to test: 1


In [14]:
# Initialize the Napari viewer
viewer = napari.Viewer()

# Loop through each plane and process images
for plane in range(n_planes*doubling):
    print(f"Processing plane: {plane}")
    images = images_stack[plane]
    print('images shape', images.shape)

    for idx, (cellprob_threshold, flow_threshold, resample, augment, stitch_threshold) in enumerate(parameter_combinations):
        params_text = f"cp_{cellprob_threshold}-ft_{flow_threshold}-st_{stitch_threshold}-resample_{resample}_augment={augment}"
        combi_text = f"Combination {idx + 1}/{len(parameter_combinations)}: {params_text}"
        print(combi_text)

        # Segment the images using Cellpose with current parameter combination
        masks, _, _ = model.eval(images, 
                                 channels=[0, 0], 
                                 cellprob_threshold=cellprob_threshold, 
                                 flow_threshold=flow_threshold, 
                                 resample=resample, 
                                 augment=augment, 
                                 stitch_threshold=stitch_threshold)
        
        # Store the masks for visualization
        mask_stack[plane] = masks
        
        # Add the masks to Napari viewer
        viewer.add_labels(masks, name=params_text)

  warn(message=warn_message)


Processing plane: 0
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0.1-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 1533.27it/s]


Processing plane: 1
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0.1-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 1533.23it/s]


Processing plane: 2
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0.1-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 1533.54it/s]


Processing plane: 3
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0.1-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 1642.78it/s]


Processing plane: 4
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0.1-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 1533.40it/s]


Processing plane: 5
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0.1-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 1436.96it/s]


Processing plane: 6
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0.1-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 1533.79it/s]


Processing plane: 7
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0.1-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 1149.93it/s]


In [18]:
viewer.add_image(images_stack[0], name="images")
viewer.add_labels(mask_stack[0], name= "masks")

TypeError: Only integer types are supported for Labels layers, but data contains float32.

In [17]:
mask_stack.shape

(8, 24, 256, 512)

In [49]:
# Save the mask stack for each parameter combination
output_path = f"{mask_plane_path.replace('.tif', '')}_params_{params_text}.tif"
imwrite(output_path, mask_stack)

print("All combinations processed and saved.")
viewer.add_image(images_stack)
napari.run()


  warn(message=warn_message)


Processing plane: 0
images shape (24, 256, 512)
Combination 1/1: cp_-1-ft_0.4-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 348.43it/s]


Processing plane: 1
images shape (24, 256, 512)
Combination 1/1: cp_-1-ft_0.4-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 353.80it/s]


Processing plane: 2
images shape (24, 256, 512)
Combination 1/1: cp_-1-ft_0.4-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 343.23it/s]


Processing plane: 3
images shape (24, 256, 512)
Combination 1/1: cp_-1-ft_0.4-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 307.00it/s]


Processing plane: 4
images shape (24, 256, 512)
Combination 1/1: cp_-1-ft_0.4-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 221.12it/s]


Processing plane: 5
images shape (24, 256, 512)
Combination 1/1: cp_-1-ft_0.4-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 280.45it/s]


Processing plane: 6
images shape (24, 256, 512)
Combination 1/1: cp_-1-ft_0.4-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 348.43it/s]


Processing plane: 7
images shape (24, 256, 512)
Combination 1/1: cp_-1-ft_0.4-st_0.01-resample_True_augment=False


100%|██████████| 23/23 [00:00<00:00, 396.51it/s]


All combinations processed and saved.


In [45]:
viewer.add_image(images_stack[5])

<Image layer 'Image' at 0x1d780e5aa70>

In [19]:
from skimage import exposure
from skimage.morphology import label
# Visualize masks
viewer = napari.Viewer()
matched_stack = exposure.match_histograms(images_stack,mask_stack)
viewer.add_image(images_stack[0], name='images')
viewer.add_labels(label(mask_stack[0]), name='masks')

  warn(message=warn_message)


<Labels layer 'masks' at 0x1d6cd6e8190>

(8, 24, 256, 512)

In [13]:
#Proofreading segmentation

# --- Imports ---
# Standard libraries
import os
import re
from datetime import datetime
# Image processing and data handling
import numpy as np
import tifffile
import imageio.v2 as imageio
# Visualization and GUI
import napari
from napari.utils.colormaps import Colormap


# --- Data Loading ---
# Load image data
image_path = mask_plane_path
img = tifffile.imread(image_path)


# Set path to save proofreading
dir_to_save = os.path.join(trials_path,"masks")
print(f"Predictions will be saved in {dir_to_save}")

# --- Napari Viewer Setup ---
viewer = napari.Viewer()
scale = (1,1,1)
img_layer = viewer.add_image(img, blending="additive",  contrast_limits =[0,8000], name="img", scale=scale)

load_proofread = False
if load_proofread == True:
    proofread_path = r"C:\Users\montruth\fishPy\tests\proofreading\20231006_162735_proofreading.tif"
    proofread_data = tifffile.imread(proofread_path)
else:
    # Proofread layer for showing proofreading status
    proofread_data = np.zeros_like(img.data)
    
proofread_layer = viewer.add_labels(proofread_data, name='proofread', opacity=2, visible=True)

# Add relabelled stack as a label layer
label_layer = viewer.add_labels(img, opacity=0.3, name="labels", scale=scale)
label_layer.contour=2

# Start at the beginning of the stack
viewer.dims.current_step = [0,0,0]

# --- Keybindings ---
@viewer.bind_key('x')
def toggle_label_visibility(viewer):
    """Toggle visibility of the label layer."""
    if "labels" in viewer.layers:
        viewer.layers["labels"].visible = not viewer.layers["labels"].visible

@viewer.bind_key('a')
def activate_erase_mode(viewer):
    """Activate erase mode for label layers."""
    for layer in viewer.layers:
        if isinstance(layer, napari.layers.Labels):
            layer.mode = 'erase'

@viewer.bind_key('s')
def activate_paint_mode(viewer):
    """Activate paint mode for label layers."""
    for layer in viewer.layers:
        if isinstance(layer, napari.layers.Labels):
            layer.mode = 'paint'

@viewer.bind_key('d')
def activate_fill_mode(viewer):
    """Activate fill mode for label layers."""
    for layer in viewer.layers:
        if isinstance(layer, napari.layers.Labels):
            layer.mode = 'fill'

@viewer.bind_key('f')
def activate_pick_and_mark_mode(viewer):
    """Activate pick mode for label layers and mark the spot with a point."""
    global last_added_point_z

    layer =  viewer.layers["labels"]
    viewer.layers["labels"].visible = True
    viewer.layers.selection.active = viewer.layers['labels']
    viewer.layers["labels"].mode = 'pick'

    def on_click(layer, event):
        """Handle mouse click event to mark the spot with a point."""
        global last_added_point_z
        if event.type == 'mouse_press':
            # Get click coordinates
            coord = viewer.cursor.position
            last_added_point_z = coord[0]

            # If a points layer named 'marks' doesn't exist, create it
            if 'marks' not in viewer.layers:
                viewer.add_points(coord, name='marks', face_color='green', edge_color='white', symbol= 'cross', size=4, opacity=0.5)
            else:
                viewer.layers['marks'].data = np.array([coord])  # Update the existing points layer

            total_z_slices = int(viewer.dims.range[0][1])

            # Add the point at the clicked position on the 'marks' layer for each z-slice
            min_mark = int(max(0, last_added_point_z - 10))
            max_mark = int(min(total_z_slices, last_added_point_z + 10))

            for z in range(min_mark, max_mark):
                if z != int(coord[0]):
                    viewer.layers['marks'].add([z, coord[1], coord[2]])

            viewer.layers.selection.active = viewer.layers['labels']
            viewer.layers.selection.selected = [viewer.layers['img'], viewer.layers['labels'], viewer.layers['marks']]
            viewer.layers['labels'].mode = 'fill'

            # Disconnect the callback to prevent further marking until 'f' is pressed again
            layer.mouse_drag_callbacks.remove(on_click)

    # Connect the callback
    layer.mouse_drag_callbacks.append(on_click)

@viewer.bind_key('Control-f')
def activate_pick_mode(viewer):
    """Activate pick mode without marking a point."""
    for layer in viewer.layers:
        if isinstance(layer, napari.layers.Labels):
            layer.mode = 'pick'

@viewer.bind_key('r')
def toggle_proofread_visibility(viewer):
    """Toggle visibility of the proofread layer."""
    if "proofread" in viewer.layers:
        viewer.layers["proofread"].visible = not viewer.layers["proofread"].visible

@viewer.bind_key('Space', overwrite=True)
def transfer_proofread_label_and_remove_points(viewer):
    """Transfer the proofread label and remove points."""
    global proofread_data
    # Get the currently selected label
    selected_label = label_layer.selected_label
    print(f"selected label: {selected_label}")
    
    # Set the currently selected label in the copied_data
    proofread_data[label_layer.data == selected_label] = selected_label
    
    # 3. Refresh the new_label_layer to reflect the changes
    proofread_layer.data = proofread_data
    
    proofread_layer.refresh()
    print(f"ROIs proofread: {len(np.unique(viewer.layers['proofread'].data))}")

@viewer.bind_key('q')  
def delete_selected_label(viewer):
    """Delete the selected label in the active layer."""
    layer_active = viewer.layers.selection.active
    print(layer_active)
    selected_label = layer_active.selected_label
    if selected_label != 0:  # Ensure you're not deleting the background
        layer_active.data[layer_active.data == selected_label] = 0
        layer_active.refresh()

@viewer.bind_key("l")
def save_proofread_labels(viewer):
    """Save proofread labels to disk."""
    # Construct the filename
    current_time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    proofreading_filename = f"{current_time_str}_proofreading.tif"
    profreading_save_path = os.path.join(dir_to_save, proofreading_filename)
    labels_filename = f"{current_time_str}_labels.tif"
    labels_save_path = os.path.join(dir_to_save, labels_filename)
    # Save
    tifffile.imwrite(profreading_save_path, proofread_layer.data)
    tifffile.imwrite(labels_save_path, label_layer.data)
    print(f"Saved file {labels_save_path}")
    print(f"Saved file {profreading_save_path}")

@viewer.bind_key('c')
def connect_labels(viewer):
    """Connect two selected labels in the active layer."""
    active_layer = viewer.layers.selection.active
    if isinstance(active_layer, napari.layers.Labels):
        labels = list(active_layer.selected_label)
        if len(labels) == 2:
            active_layer.data[active_layer.data == labels[1]] = labels[0]
            active_layer.refresh()

@viewer.bind_key('Ctrl-s')
def split_labels(viewer):
    """Split the selected label into two new labels in the active layer."""
    active_layer = viewer.layers.selection.active
    if isinstance(active_layer, napari.layers.Labels):
        selected_label = active_layer.selected_label
        if selected_label != 0:
            # Generate a new label ID
            new_label = active_layer.data.max() + 1
            # Logic to split the label can be customized; for simplicity, we'll use a thresholding approach here.
            mask = active_layer.data == selected_label
            split_mask = mask & (np.random.rand(*mask.shape) > 0.5)
            active_layer.data[split_mask] = new_label
            active_layer.refresh()

# Handling shifts in z
@viewer.bind_key('Ctrl-z')
def handle_shifts_in_z(viewer):
    """Handle shifts in z and align the labels accordingly."""
    active_layer = viewer.layers.selection.active
    if isinstance(active_layer, napari.layers.Labels):
        data = active_layer.data
        for i in range(1, data.shape[0]):
            previous_slice = data[i-1]
            current_slice = data[i]
            unique_labels = np.unique(previous_slice)
            for label in unique_labels:
                if label != 0:
                    shift = np.mean(np.argwhere(current_slice == label)[:, 0] - np.argwhere(previous_slice == label)[:, 0])
                    if np.abs(shift) > 1:  # Apply shift threshold to adjust labels
                        data[i] = np.roll(data[i], int(-shift), axis=0)
        active_layer.data = data
        active_layer.refresh()

# --- Start the Napari Viewer ---
viewer.show()


Predictions will be saved in \\tungsten-nas.fmi.ch\tungsten\scratch\gfriedri\montruth\2P_RawData\2022-04-26\f3\trials\masks




TypeError: Only integer types are supported for Labels layers, but data contains float32.