# Segmentation of cell nuclei with [Cellpose](http://www.cellpose.org/)

This notebook segments cell nuclei in each z-layer and combines them into a 3D mask.

#### Documentation

- [Cellpose documentation](https://cellpose.readthedocs.io/en/latest/)
- [Paper](https://www.biorxiv.org/content/10.1101/2020.02.02.931238v1)
- [Code](https://github.com/MouseLand/cellpose)

## Requirements
- A folder with images that should be segmented. All z-layers for a specific sample must be combined into a single file. To combine z-layers and channels, run [images_to_stack.ipynb](images_to_stack.ipynb).

## Config

### The following code imports and declares functions used for the processing:

In [None]:
#################################
#  Don't modify the code below  #
#################################

import os
import re
import time
import numpy as np
import intake_io
import xarray as xr
import warnings
from tqdm import tqdm
from skimage import io
from skimage.morphology import remove_small_objects
from scipy import ndimage
import matplotlib.pyplot as plt
from am_utils.utils import walk_dir, imsave

from cellpose import models, utils, plot, transforms

from lib import rescale_intensity

## Specify data paths and analysis parameters

### Please provide data paths:

`input_dir`: folder with images of cells to be segmented

`output_dir`: folder to save results

`output_combined`: (optional) folder to save cell segmentation results combined with the raw data

In [None]:
# Folder that is scanned for the images to segment.
input_dir = "/research/sharedresources/cbi/common/Anna/test/input"
# Folder where the segmentation masks are written (one .tif per input image).
output_dir = "/research/sharedresources/cbi/common/Anna/test/Analysis/cell_segmentation"
# Optional folder for the masks stacked together with the raw channels;
# keep None to skip saving the combined images.
output_combined = None

### The following code lists all image files in the input directory:

In [None]:
#################################
#  Don't modify the code below  #
#################################

# List the image files located under `input_dir`.
samples = walk_dir(input_dir)

# Report how many files were found, followed by their paths.
n_found = len(samples)
print(f'{n_found} images were found:')
print(np.array(samples))

### The following code loads a random image

In [None]:
#################################
#  Don't modify the code below  #
#################################

# Pick one image at random; it is used to preview the metadata and to tune
# the segmentation parameters before the whole folder is processed.
if len(samples) == 0:
    raise ValueError('No images were found in the input directory')
sample = samples[np.random.randint(len(samples))]
dataset = intake_io.imload(sample)
print(dataset)

print('\n\nThe following voxel sizes were detected:')
# Detected spacing per spatial axis; None means "not detected".
coords = {'x': None, 'y': None, 'z': None}
for c in ['x', 'y', 'z']:
    axis_values = dataset.coords[c].data if c in dataset.coords else []
    if len(axis_values) > 1:
        # The voxel size is the step between the first two coordinate values.
        # An axis with a single entry (e.g. one z-plane) carries no spacing
        # information and is reported as "not detected" below.
        coords[c] = axis_values[1] - axis_values[0]
        print(rf'{c}: {coords[c]}')
    elif c in dataset.dims:
        print(rf'{c}: not detected, will use 1')

print('\n The following channels have been detected:')
if 'c' in dataset.dims:
    print(dataset['c'].data)
else:
    print('none (the image has no channel axis)')

### Please specify correct voxel size

Keep `None`, if the value loaded from the dataset is correct

In [None]:
# Voxel size overrides, one per spatial axis. A value left as None is not
# applied, so the size detected from the dataset is kept for that axis.
x = None
y = None
z = None

### Please specify channel names

If you'd like to relabel channels, specify channel names as an array (e.g. ['channel_1', 'channel_2']). Specify `None` to keep the default channel labels.

In [None]:
# Labels assigned to the 'c' coordinate when the images are (re)loaded;
# None keeps the channel labels read from the dataset.
channel_names = ['DNA', 'GFP']

### The following code updates voxel sizes and channel names:

In [None]:
#################################
#  Don't modify the code below  #
#################################
# Fall back to the channel labels stored in the dataset. They are converted to
# a plain list so that the value has the same type as a user-supplied list and
# supports list concatenation when the combined output is assembled.
if channel_names is None:
    channel_names = list(dataset['c'].data)

# User-supplied voxel sizes take precedence over the detected ones.
for axis_name, override in zip(['x', 'y', 'z'], [x, y, z]):
    if override is not None:
        coords[axis_name] = override

# Reload the example image with the final spacing and channel labels.
dataset = intake_io.imload(sample, metadata={"spacing": coords,
                                             "coords": {'c': channel_names}})

dataset

### Specify the channel to segment

In [None]:
# Label of the channel to segment; it is used to select along the 'c'
# coordinate of the dataset, so it must match one of the channel labels.
channel = 'DNA'

### The following code displays an example image

In [None]:
# Display the middle z-plane of the selected channel (for an even number of
# planes, the lower of the two central planes).
n_planes = len(dataset['z'])
middle_z = dataset['z'][(n_planes - 1) // 2]
img = dataset.loc[dict(c=channel, z=middle_z)]['image'].data
plt.figure(figsize=(8, 8))
io.imshow(img)

### Please specify model parameters

`diameter`: average nucleus diameter in pixels; set to `None` to automatically detect the cell diameter

`model_type`: `nuclei` for nuclei segmentation, `cyto` to segment cells

`do_3D`: to perform segmentation in 3D set to `True`, to segment layer-by-layer set to `False`. If there is only one layer of cells, choose `False` since the 2D segmentation is faster

In [None]:
# Average object diameter in pixels, passed to model.eval; None requests
# automatic detection of the diameter.
diameter = 120 
# Name of the pretrained Cellpose model to load.
model_type = "cyto" # it seems like the "cyto" model works better for this dataset
# True: segment the whole z-stack in 3D; False: segment layer by layer.
do_3D = False

### Advanced parameters (do not change unless you know what you are doing):

`flow_threshold`: increase if model returns too few masks, decrease if model returns too many ill-shaped masks

`probability_threshold`:  decrease if model returns too few cells, increase if model returns too many cells  
Values should be between -6 and +6

In [None]:
# Passed to model.eval as `flow_threshold`.
flow_threshold = 0.4        # default: 0.4
# Passed to model.eval as `cellprob_threshold`.
probability_threshold = 0   # default: 0.0
# Passed to model.eval as `channels`; presumably [channel to segment, nuclear
# channel] with 0 meaning grayscale/none -- confirm against the Cellpose docs.
channels = [0,0]

### The following code segments the random image and displays the results:

In [None]:
#################################
#  Don't modify the code below  #
#################################
# Select the data to segment: the full z-stack of the chosen channel for a 3D
# run, otherwise only the middle z-plane.
selection = dict(c=channel)
anisotropy = None
if do_3D:
    # Ratio of the z spacing to the x spacing.
    anisotropy = dataset.coords['z'].data[1] / dataset.coords['x'].data[1]
else:
    selection['z'] = dataset['z'][(len(dataset['z']) - 1) // 2]
img = dataset.loc[selection]['image'].data
imgs = [rescale_intensity(np.array(img))]

# Run Cellpose on the example image.
model = models.Cellpose(gpu=True, model_type=model_type)
eval_options = dict(channels=channels,
                    diameter=diameter,
                    flow_threshold=flow_threshold,
                    cellprob_threshold=probability_threshold,
                    do_3D=do_3D,
                    anisotropy=anisotropy)
masks, flows, styles, diams = model.eval(imgs, **eval_options)

# Display the results; volumetric results are shown as their central slice.
for img, maski, flow_outputs in zip(imgs, masks, flows):
    flowi = flow_outputs[0]
    fig = plt.figure(figsize=(30, 10))
    if len(img.shape) > 2:
        img = img[img.shape[0] // 2]
        maski = maski[maski.shape[0] // 2]
        flowi = flowi[flowi.shape[0] // 2]
    plot.show_segmentation(fig, img, maski, flowi, channels=channels)
    plt.tight_layout()
    plt.show()

### The following code segments all images and saves the results: 

In [None]:
%%time
#################################
#  Don't modify the code below  #
#################################

# Build the network once: it does not depend on the sample being processed.
model = models.Cellpose(gpu=True, model_type=model_type)

for i, sample in enumerate(samples):
    print(sample)
    print(fr'Processing sample {i+1} of {len(samples)}')
    dataset = intake_io.imload(sample, metadata={"spacing": coords,
                                                 "coords": {'c': channel_names}})

    n_planes = len(dataset['z'])
    # A 3D run or a single-plane image is passed to Cellpose as one image;
    # otherwise every z-plane is segmented as a separate 2D image.
    single_image = do_3D or n_planes == 1
    channel_data = dataset.loc[dict(c=channel)]['image'].data
    planes = [channel_data] if single_image else channel_data
    if n_planes > 1:
        # Ratio of the z spacing to the x spacing.
        anisotropy = dataset.coords['z'].data[1] / dataset.coords['x'].data[1]
    else:
        anisotropy = None
    imgs = [rescale_intensity(np.array(plane)) for plane in planes]

    masks, flows, styles, diams = model.eval(imgs, channels=channels, diameter=diameter,
                                             flow_threshold=flow_threshold,
                                             cellprob_threshold=probability_threshold,
                                             do_3D=do_3D, anisotropy=anisotropy)

    output = masks[0] if single_image else np.array(masks)

    if not do_3D and output.ndim > 2:
        # Combine the per-plane 2D masks into one 3D mask.
        # Reference plane: the plane with the largest segmented area; for
        # stacks of more than 21 planes the first and last 10 are ignored.
        area = (output > 0).sum(axis=(-2, -1))
        if len(area) > 21:
            ind = np.argmax(area[10:-10]) + 10
        else:
            ind = np.argmax(area)
        labels = output[ind:ind + 1].copy()

        # Smooth the binary foreground with a 3x3x3 median filter, then give
        # every foreground voxel the label found at the same (y, x) position
        # in the reference plane (the single plane is broadcast along z).
        output = ndimage.median_filter(output > 0, 3)
        output = output * labels

        # `diameter` may be None (automatic detection), in which case the
        # size threshold is derived from the diameter(s) returned by
        # model.eval -- presumably the estimated object sizes; confirm
        # against the Cellpose documentation.
        if diameter is not None:
            effective_diameter = diameter
        else:
            effective_diameter = float(np.mean(diams))
        # Discard objects smaller than a sphere of radius diameter / 4.
        minvol = 4. / 3 * np.pi * (effective_diameter / 4) ** 3
        output = remove_small_objects(output, min_size=minvol)

    # Mirror the input folder structure in the output folder. Only the final
    # extension is replaced, so a folder or file name that happens to contain
    # the extension text is left untouched.
    fn = os.path.splitext(os.path.relpath(sample, input_dir))[0] + '.tif'
    imsave(os.path.join(output_dir, fn), output.astype(np.uint16))

    if output_combined is not None:
        # `channel_names` may be a numpy array; a list is needed to append
        # the label of the segmentation channel.
        names = list(channel_names)
        # Raw channels followed by the mask as an extra channel; assumes the
        # channel axis is the first dimension of dataset['image'] and that
        # `output` has the shape of one channel -- TODO confirm.
        stack = np.zeros((len(names) + 1,) + output.shape)
        for ch_ind, chname in enumerate(names):
            stack[ch_ind] = np.array(dataset.loc[dict(c=chname)]['image'].data)
        stack[-1] = output

        stack = xr.Dataset(data_vars=dict(image=(dataset['image'].dims, stack.astype(np.uint16))),
                           coords=dict(c=names + ['Nuclei segmentation'],
                                       x=dataset.coords['x'],
                                       y=dataset.coords['y'],
                                       z=dataset.coords['z']),
                           attrs=dataset.attrs)
        stack['image'].attrs = dataset['image'].attrs

        combined_path = os.path.join(output_combined, fn)
        os.makedirs(os.path.dirname(combined_path), exist_ok=True)
        intake_io.imsave(stack, combined_path)
        
    