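# Segmentation of trial acquisitions (calcium imaging)
* Segmentation of the (motion-corrected) trial images with a custom-trained Cellpose model
* Pipeline: load the processed `sum_elastic_*` stack → rescale to [0, 1] and apply CLAHE → segment every plane (or, if several parameter combinations are given, compare them on a single test plane in napari) → save the label masks to the `masks` folder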


In [2]:
# Imports
import glob
import os
from itertools import product

import napari
import numpy as np
import skimage.exposure
from cellpose import io, models
from tifffile import imread, imwrite

from scripts.sample_db import SampleDB

# Load the sample database.
# Set the SAMPLE_DB_PATH environment variable to point at a different copy
# (e.g. when the network share is not mounted); otherwise the share path is used.
db_path = os.environ.get(
    'SAMPLE_DB_PATH',
    r'\\tungsten-nas.fmi.ch\tungsten\scratch\gfriedri\montruth\sample_db.csv',
)
sample_db = SampleDB()
sample_db.load(db_path)
print(sample_db)

SampleDB(sample_ids=['20220426_RM0008_130hpf_fP1_f3', '20220118_RM0012_124hpf_fP8_f2', '20220427_RM0008_126hpf_fP3_f3'])


In [3]:
# Load the experiment selected for segmentation
sample_id = '20220427_RM0008_126hpf_fP3_f3'
exp = sample_db.get_sample(sample_id)
print(exp.sample.id)

# Custom-trained Cellpose model. A full path to a trained model belongs in
# `pretrained_model`; `model_type` is documented for built-in / GUI-registered names.
model_path = r'D:\montruth\cellpose\models\CP_20230803_101131'
model = models.CellposeModel(pretrained_model=model_path, gpu=True)

# Shortcuts to sample parameters/information
sample = exp.sample
root_path = exp.paths.root_path
trials_path = exp.paths.trials_path

n_planes = exp.params_lm.n_planes
n_frames = exp.params_lm.n_frames
# NOTE(review): the name suggests a count, but this reads `lm_stack_range` — confirm
n_slices = exp.params_lm.lm_stack_range
n_trials = exp.params_lm.n_trials
# Factor applied to n_planes by the segmentation loop (2 when doubling is enabled)
doubling = 2 if exp.params_lm.doubling else 1

# Entries of the trial acquisitions folder
trial_paths = os.listdir(trials_path)

# Locate the processed (motion-corrected) 'sum_elastic_*' stack.
# Sorting makes the choice deterministic if more than one file matches.
processed_folder = os.path.join(trials_path, 'processed')
print(f'Processed trials folder exists: {os.path.exists(processed_folder)}')
candidate_paths = sorted(glob.glob(os.path.join(processed_folder, 'sum_elastic_*.tif')))
if not candidate_paths:
    raise FileNotFoundError(f"No 'sum_elastic_*.tif' stack found in {processed_folder}")
if len(candidate_paths) > 1:
    print(f'Warning: {len(candidate_paths)} matching stacks found, using {candidate_paths[0]}')
images_path = candidate_paths[0]
images_stack_raw = io.imread(images_path)

# Contrast enhancement: rescale to [0, 1], then CLAHE.
# NOTE(review): CLAHE is applied to the full stack in a single call, so skimage treats
# every axis (including the plane/frame axes) as spatial — confirm this is intended
# rather than per-plane equalization.
images_stack_rescaled = skimage.exposure.rescale_intensity(images_stack_raw, out_range=(0, 1))
images_stack_clahe = skimage.exposure.equalize_adapthist(images_stack_rescaled)
# Downstream cells segment and display `images_stack`; the raw data stays in `images_stack_raw`
images_stack = images_stack_clahe

# Output folder for the label masks
masks_folder = os.path.join(trials_path, "masks")
os.makedirs(masks_folder, exist_ok=True)


20220427_RM0008_126hpf_fP3_f3
Processed trials folder exists: True


100%|██████████| 192/192 [00:00<00:00, 266.29it/s]


In [4]:
# Cellpose parameter values to apply (one value each) or to test (several values).
# Exactly one combination runs the full segmentation; more than one puts the
# segmentation cell into test mode on a single plane.
cellprob_threshold_range = [-3]
flow_threshold_range = [0]
resample_options = [True]
augment_options = [False]
stitch_threshold_range = [0.01]

# Output label stack with the same shape as the images. It is zero-initialised rather than
# np.empty, so any plane that is never segmented (e.g. in test mode) is saved as background
# instead of uninitialised memory.
masks_stack = np.zeros(images_stack.shape, dtype=np.uint16)

# All combinations, in the order the segmentation cell unpacks them:
# (cellprob_threshold, flow_threshold, resample, augment, stitch_threshold)
parameter_combinations = list(product(
    cellprob_threshold_range,
    flow_threshold_range,
    resample_options,
    augment_options,
    stitch_threshold_range,
))
print(f"Number of combinations to test: {len(parameter_combinations)}")


Number of combinations to test: 1


In [14]:
def _format_params(cellprob_threshold, flow_threshold, resample, augment, stitch_threshold):
    """Return the label used for napari layer names and the saved mask filename."""
    return f"cp_{cellprob_threshold}-ft_{flow_threshold}-st_{stitch_threshold}-resample_{resample}-augment_{augment}"


def _segment(images, cellprob_threshold, flow_threshold, resample, augment, stitch_threshold):
    """Run the Cellpose model on one plane's images and return only the label masks."""
    masks, _, _ = model.eval(images,
                             channels=[0, 0],
                             cellprob_threshold=cellprob_threshold,
                             flow_threshold=flow_threshold,
                             resample=resample,
                             augment=augment,
                             stitch_threshold=stitch_threshold)
    return masks


if len(parameter_combinations) > 1:
    # Test mode: segment a single plane with every combination and compare them in napari
    print('Multiple combinations found: Test mode active')
    test_viewer = napari.Viewer()
    test_plane = 3
    images = images_stack[test_plane]
    # Show the input so the candidate masks can be judged against it
    test_viewer.add_image(images, name=f'plane_{test_plane}')

    for idx, params in enumerate(parameter_combinations):
        params_text = _format_params(*params)
        print(f"Combination {idx + 1}/{len(parameter_combinations)}: {params_text}")
        test_viewer.add_labels(_segment(images, *params), name=params_text)

else:
    # Full run: segment every plane with the single parameter combination
    label_max = np.iinfo(masks_stack.dtype).max
    for plane in range(n_planes * doubling):
        print(f"Processing plane: {plane}")
        images = images_stack[plane]
        print('images shape', images.shape)

        for idx, params in enumerate(parameter_combinations):
            # `params_text` is also read by the save cell to name the output file
            params_text = _format_params(*params)
            print(f"Combination {idx + 1}/{len(parameter_combinations)}: {params_text}")
            masks = _segment(images, *params)
            # Storing into the integer masks stack would silently wrap labels above its dtype range
            if masks.max() > label_max:
                raise ValueError(
                    f"Plane {plane}: label {masks.max()} exceeds the {masks_stack.dtype} range of masks_stack"
                )
            masks_stack[plane] = masks

    viewer = napari.Viewer()
    viewer.add_image(images_stack)
    viewer.add_labels(masks_stack)


Processing plane: 0
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0-st_0.01-resample_True-augment_False


100%|██████████| 23/23 [00:00<00:00, 95.02it/s] 


Processing plane: 1
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0-st_0.01-resample_True-augment_False


100%|██████████| 23/23 [00:00<00:00, 119.78it/s]


Processing plane: 2
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0-st_0.01-resample_True-augment_False


100%|██████████| 23/23 [00:00<00:00, 105.73it/s]


Processing plane: 3
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0-st_0.01-resample_True-augment_False


100%|██████████| 23/23 [00:00<00:00, 85.82it/s] 


Processing plane: 4
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0-st_0.01-resample_True-augment_False


100%|██████████| 23/23 [00:00<00:00, 82.72it/s] 


Processing plane: 5
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0-st_0.01-resample_True-augment_False


100%|██████████| 23/23 [00:00<00:00, 77.82it/s] 


Processing plane: 6
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0-st_0.01-resample_True-augment_False


100%|██████████| 23/23 [00:00<00:00, 74.19it/s] 


Processing plane: 7
images shape (24, 256, 512)
Combination 1/1: cp_-3-ft_0-st_0.01-resample_True-augment_False


100%|██████████| 23/23 [00:00<00:00, 82.73it/s] 
  warn(message=warn_message)


In [15]:
# Save the label masks from the full segmentation run.
# In test mode (several combinations) the segmentation cell never fills `masks_stack`,
# and `params_text` would only name the last combination, so nothing is written then.
if len(parameter_combinations) == 1:
    masks_path = os.path.join(masks_folder, f'masks_{exp.sample.id}_{params_text}.tif')
    imwrite(masks_path, masks_stack)
    print(f'Saved masks to {masks_path}')
else:
    print(f'Masks not saved: expected exactly one parameter combination, got {len(parameter_combinations)}')