## Setup

Please execute the cell(s) below to initialize the notebook environment.

In [None]:
# @title Install dependencies
!pip install poetry

In [None]:
# @title Install SMorph Python module
!pip install https://github.com/swanandlab/SMorph/releases/download/v0.1.1-alpha/SMorph-0.1.1.tar.gz

In [None]:
# Imports
# True when running inside Google Colab (detected from the IPython shell's
# repr). napari is only imported and used outside Colab.
on_colab = 'google.colab' in str(get_ipython())

import warnings
# NOTE(review): this silences *all* warnings for the rest of the session,
# including ones from smorph/skimage that may be worth seeing.
warnings.filterwarnings('ignore')

if not on_colab:
    import napari
import smorph as sm
import smorph.util.autocrop as ac
import ipywidgets as widgets

In [None]:
# Helper function
def view_3D(*args):
    """Open the given layers together in a 3D napari viewer.

    Each positional argument is a dict of keyword arguments for one napari
    layer and must contain at least ``'data'`` (the array to show). The first
    dict becomes the base image layer. Each following dict is added as a
    labels layer when its data looks like an integer label image (maximum
    greater than 1 and integer-valued), otherwise as an image layer.

    The caller's dicts are never modified. Does nothing on Colab (napari is
    not imported there) or when called without arguments.
    """
    if on_colab or not args:
        return
    with napari.gui_qt():
        viewer = napari.view_image(**args[0], ndisplay=3)
        for layer_kwargs in args[1:]:
            data_max = layer_kwargs['data'].max()
            if data_max > 1 and data_max % 1 == 0:
                # 'colormap' and 'gamma' are dropped for labels layers
                # (presumably add_labels does not accept them). Drop them from
                # a copy so the caller's dict keeps them.
                label_kwargs = {key: value for key, value in layer_kwargs.items()
                                if key not in ('colormap', 'gamma')}
                viewer.add_labels(**label_kwargs)
            else:
                viewer.add_image(**layer_kwargs)

---

## Step 1: Import Confocal Microscopic Image of the Tissue

Set `CONFOCAL_TISSUE_IMAGE` to the path of the image file to be processed.
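- The image is then smoothed with a grayscale morphological closing (`skimage.morphology.closing`). Non-local means denoising with auto-calibrated parameters is available in the cell but is currently commented out.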

In [None]:
CONFOCAL_TISSUE_IMAGE = 'Datasets/Confocal/SAL,DMI, FLX ADN HALO_TREATMENT_21 DAYS/CONTROLS_CZI/MSPP2.1MA_1_SINGLE MARK _CTRL_21 DAY/MSP2.1MA_1_SINGLE MARK_20X_SEC 2_RIGHT_HILUS.czi'  #@param

import numpy as np
from skimage.morphology import closing

original = ac.import_confocal_image(CONFOCAL_TISSUE_IMAGE)

if original.ndim == 2:
    # A single optical section: rescale intensities to [0, 1] and add a
    # leading axis so the rest of the pipeline sees a one-slice 3D stack.
    intensity_range = original.max() - original.min()
    if intensity_range > 0:
        original = (original - original.min()) / intensity_range
    else:
        # Constant image: avoid 0/0, which would turn every pixel into NaN.
        original = np.zeros_like(original, dtype=float)
    original = np.expand_dims(original, 0)

# Grayscale morphological closing (skimage default footprint) used as the
# smoothing step. Alternative pipelines kept for reference:
denoised = closing(original)
# deconvolved = ac.deconvolve(original, CONFOCAL_TISSUE_IMAGE, iters=10)
# denoiser = ac.calibrate_nlm_denoiser(original)
# denoise_parameters = denoiser.keywords['denoiser_kwargs']
# print(denoise_parameters)
# denoised = ac.denoise(original, denoise_parameters)
ac.projectXYZ(denoised, .5, .5, 1)

In [None]:
# Compare the loaded stack with its smoothed version in 3D.
view_3D(
    {'data': original, 'colormap': 'gray', 'name': 'original'},
    {'data': denoised, 'colormap': 'gray', 'name': 'denoised'},
)

In [None]:
# import imagej
# import scyjava
# scyjava.config.add_option('-Xmx6g')
# ij = imagej.init('C:/Program Files (x86)/Fiji.app', headless=False)
# from scyjava import jimport
# WindowManager = jimport('ij.WindowManager')

In [None]:
# # You need to specify that no changes have been made, or else
# # the close window dialogue at the end will ask for confirmation.
# ij.ui().show('denoised', ij.py.to_java(denoised))
# img = WindowManager.getCurrentImage()
# img.changes = False

In [None]:
# plugin = '3D Fast Filters'

# args = { 
#     'filter': 'Median',  # StringField
#     'radius_x_unit': .6918883015525974,  # NumericField
#     'radius_x_pix': 1,  # NumericField
#     'radius_y_unit': .6918883015525974,  # NumericField
#     'radius_y_pix': 1,  # NumericField
#     'radius_z_unit': 1.0785801681301463,
#     'radius_z_pix': 1
# }

# ij.py.run_plugin(plugin, args)#, ij1_style=False)
# result = WindowManager.getWindow()
# # stack = result.getStack()
# # stack.setPixels(stack.getProcessor().getPixels(), stack.getCurrentSlice())
# # ij.py.show(stack)
# # result_array = ij.py.from_java(stack)
# # result_array.shape
# ij.py.from_java(result)

In [None]:
# plugin = 'Tubeness'

# args = {
#     'sigma': .691884765625
# }

# ij.py.run_plugin(plugin, args, ij1_style=False)
# result = WindowManager.getCurrentImage()
# result_array = ij.py.from_java(result)
# view_3D({'data': result_array, 'colormap': 'gray', 'name': 'result_array'},
#         {'data': denoised, 'colormap': 'gray', 'name': 'denoised'})

In [None]:
from skimage import exposure

# Alternatives tried earlier, kept for reference:
# deconvolved = ac.deconvolve(denoised, CONFOCAL_TISSUE_IMAGE, iters=10)
# denoised = exposure.equalize_adapthist(denoised, clip_limit=0.03)
# denoiser = ac.calibrate_nlm_denoiser(img_adapteq)
# denoise_parameters = denoiser.keywords['denoiser_kwargs']
# img_adapteq = ac.denoise(img_adapteq, denoise_parameters)

# Clip limit applied by the next cell; updated by the slider below.
CLIP_LIMIT = .03


def adapt_eq(clip_limit=0.03):
    """Preview adaptive histogram equalization (CLAHE) of `denoised`.

    Records `clip_limit` in the global CLIP_LIMIT so the next cell uses the
    last value picked on the slider, then shows the XYZ projections of the
    equalized volume. `denoised` itself is not modified here.
    """
    global CLIP_LIMIT
    CLIP_LIMIT = clip_limit
    preview = exposure.equalize_adapthist(denoised, clip_limit=clip_limit)
    ac.projectXYZ(preview, .5, .5, 1)


_ = widgets.interact(adapt_eq, clip_limit=(0, 1, .01))

In [None]:
# Apply CLAHE with the clip limit last chosen in the preview above.
# NOTE: this overwrites `denoised`; running the cell again equalizes the
# already-equalized volume. Re-run Step 1 first to start from scratch.
denoised = exposure.equalize_adapthist(image=denoised, clip_limit=CLIP_LIMIT)

---

## Step 2: Select ROI using Polygonal Lasso Tool

Set three variables:
- `SELECT_ROI`: `True` if you want to select an ROI manually; otherwise `False`
- `NAME_ROI`: Name of the manually selected ROI
- `FILE_ROI`: Path to the ROI file; otherwise `None`

In [None]:
SELECT_ROI = True
NAME_ROI = ''
FILE_ROI = 'Datasets/Confocal/SAL,DMI, FLX ADN HALO_TREATMENT_21 DAYS/CONTROLS_CZI/MSPP2.1MA_1_SINGLE MARK _CTRL_21 DAY/MSP2.1MA_1_SINGLE MARK_20X_SEC 2_RIGHT_HILUS-ML.roi'
linebuilder = None
import matplotlib.pyplot as plt
%matplotlib widget

IMG_NAME = CONFOCAL_TISSUE_IMAGE.split('/')[-1].split('.')[0]

import matplotlib.pyplot as plt
file_roi_widget = widgets.Text(value=FILE_ROI, description='file_roi')
def roi_interact(select_roi=SELECT_ROI,
                 name_roi=NAME_ROI, draw_roi=False, file_roi=FILE_ROI):
  global SELECT_ROI, NAME_ROI, FILE_ROI, linebuilder, file_roi_widget
  SELECT_ROI, NAME_ROI, FILE_ROI = select_roi, name_roi, file_roi
  NAME_ROI = NAME_ROI if SELECT_ROI else ''
  if draw_roi:
    FILE_ROI = None
    file_roi_widget.layout.visibility = 'hidden'
  else:
    file_roi_widget.value = FILE_ROI
    file_roi_widget.layout.visibility = 'visible'
    plt.clf()
  linebuilder = None if not SELECT_ROI else ac.select_ROI(denoised, IMG_NAME + '-' + NAME_ROI, FILE_ROI)


_ = widgets.interact(roi_interact, select_roi=SELECT_ROI,
                     name_roi=NAME_ROI, draw_roi=False, file_roi=file_roi_widget)

In [None]:
%matplotlib inline
if SELECT_ROI:
    original, denoised = ac.mask_ROI(original, denoised, linebuilder)
    ac.projectXYZ(denoised, .5, .5, 1)

---

## Step 3: Segmentation

### 3.1 Threshold & color label cells

Set two parameters:
- `LOW_THRESH`: Pixel intensity value corresponding to the faintest branch's edge
- `HIGH_THRESH`: Pixel intensity value corresponding to the faintest soma

Understand their effect by configuring three parameters:
- `LOW_DELTA`: Pixel intensity value corresponding to change in `LOW_THRESH`
- `HIGH_DELTA`: Pixel intensity value corresponding to change in `HIGH_THRESH`
- `N_STEPS`: Number of steps of delta in threshold to take in both directions

In [None]:
import skimage
otsu_value = skimage.filters.threshold_otsu(denoised)

LOW_THRESH = .4
HIGH_THRESH = .5

LOW_DELTA = .1
HIGH_DELTA = .1
N_STEPS = 1
results = None
%matplotlib inline
def test_thresholds(low_thresh_init, low_thresh, high_thresh_init,
                    high_thresh, low_delta, high_delta, n_steps):
  global results, LOW_THRESH, HIGH_THRESH, LOW_DELTA, HIGH_DELTA, N_STEPS
  LOW_THRESH, HIGH_THRESH, N_STEPS = low_thresh, high_thresh, n_steps
  if low_thresh_init is not None:
    LOW_THRESH = eval(f'skimage.filters.threshold_{low_thresh_init}(denoised)')
  if high_thresh_init is not None:
    if high_thresh_init == 'isodata':
      HIGH_THRESH = eval(f'skimage.filters.threshold_{high_thresh_init}(denoised)')
  LOW_DELTA, HIGH_DELTA = low_delta, high_delta
  results = ac.testThresholds(denoised, LOW_THRESH, HIGH_THRESH, LOW_DELTA,
                              HIGH_DELTA, N_STEPS, 'gist_earth')

_ = widgets.interact(test_thresholds,
                     low_thresh_init=[None, *sm.util.THRESHOLD_METHODS],
                     low_thresh=widgets.FloatSlider(LOW_THRESH, min=0, max=1, step=.01,
                                                    readout_format='.4f', layout=widgets.Layout(width='100%')),
                     high_thresh_init=[None, *sm.util.THRESHOLD_METHODS],
                     high_thresh=widgets.FloatSlider(HIGH_THRESH, min=0, max=1, step=.01,
                                                     readout_format='.4f', layout=widgets.Layout(width='100%')),
                     low_delta=widgets.FloatSlider(LOW_DELTA, min=0, max=1, step=.0005,
                                                   readout_format='.4f', layout=widgets.Layout(width='100%')),
                     high_delta=widgets.FloatSlider(HIGH_DELTA, min=0, max=1, step=.0005,
                                                    readout_format='.4f', layout=widgets.Layout(width='100%')),
                     n_steps=widgets.IntSlider(N_STEPS, min=0, max=10,
                                               layout=widgets.Layout(width='100%'))
)

In [None]:
# Raw stack plus the layers returned by ac.testThresholds in the cell above.
base_layer = {'data': original, 'colormap': 'inferno', 'name': 'original'}
view_3D(base_layer, *results)

### Thresholding results

In [None]:
# Segment `denoised` with the low/high thresholds chosen above, then label
# the resulting objects.
thresholded = ac.threshold(denoised,
                           LOW_THRESH, HIGH_THRESH)
labels = ac.label_thresholded(thresholded)

In [None]:
# Sum of the thresholded volume (the foreground voxel count, assuming
# `thresholded` is binary), displayed as this cell's output.
prefiltering_volume = thresholded.sum()
'Prefiltering Volume: {}'.format(prefiltering_volume)

### 3.2 Filter segmented cells by removing those on the border (touching the convex hull)

In [None]:
# Border filtering (discarding objects that touch the border of the
# approximated tissue and may be only partially captured) is currently
# disabled: every labelled object is kept. To re-enable it, use:
# filtered_labels = ac.filter_labels(labels, thresholded, None)#linebuilder) if original.shape[0] > 1 else labels
filtered_labels = labels
ac.projectXYZ(filtered_labels,
              .5, .5, 1,
              'gist_earth')

In [None]:
# Objects kept after filtering next to all labelled objects.
view_3D(
    {'data': filtered_labels, 'colormap': 'gray', 'gamma': .8, 'name': 'filtered_labels'},
    {'data': labels, 'colormap': 'gist_earth', 'gamma': .8, 'name': 'labels'},
)

In [None]:
# Regions of the filtered label image, their centroids (used as napari
# point coordinates) and an 'obj' property holding each region's index,
# which the viewer below draws as the point's text.
regions = ac.arrange_regions(filtered_labels)
centroid_coords = [region.centroid for region in regions]
pts_properties = {'obj': list(range(len(regions)))}

### 3.3 Visualize segmented cells to determine cutoff volumes

#### 3.3.1 Check segmented cells on whole image

In [None]:
# 3D check of the segmentation with each object's index drawn at its
# centroid (the point markers themselves are transparent). Skipped on Colab.
if not on_colab:
    with napari.gui_qt():
        viewer = napari.view_image(denoised, name='denoised', ndisplay=3)
        viewer.add_labels(filtered_labels, name='filtered_labels')
        viewer.add_points(
            centroid_coords,
            properties=pts_properties,
            text='obj',
            edge_color='transparent',
            face_color='transparent',
        )

In [None]:
# Viewer whose points layer (seeded with red crosses at the object
# centroids) is presumably edited by hand to mark somas; the next cell reads
# this viewer's layers back. Skipped on Colab.
if not on_colab:
    with napari.gui_qt():
        man_viewer = napari.view_image(denoised, ndisplay=3)
        man_viewer.add_labels(filtered_labels,
                              name='filtered_labels')
        man_viewer.add_points(centroid_coords,
                              symbol='cross',
                              size=5,
                              face_color='red')

In [None]:
from os import listdir, path

import numpy as np
import tifffile
from skimage.measure import label
from smorph.util.autocrop._io import _build_multipoint_roi

# Keep only the segmented objects whose bounding box contains at least one of
# the soma points marked in `man_viewer` (previous cell). When no points are
# available (e.g. on Colab, where napari is not used), all objects are kept.
somas_est = np.empty((0, 3))
if not on_colab:
    # Select the points layer by type: in insertion order, layers[1] is the
    # labels layer, not the points.
    point_layers = [layer for layer in man_viewer.layers
                    if isinstance(layer, napari.layers.Points)]
    if point_layers:
        marked = np.concatenate([layer.data for layer in point_layers])
        if len(marked):
            somas_est = np.unique(marked, axis=0)

filtered_regions = []
for region in regions:
    minz, miny, minx, maxz, maxy, maxx = region.bbox
    ll = np.array([minz, miny, minx])  # lower corner of the bounding box
    ur = np.array([maxz, maxy, maxx])  # upper corner of the bounding box
    inidx = np.all(np.logical_and(ll <= somas_est, somas_est <= ur), axis=1)
    if inidx.any():
        filtered_regions.append(region)

if len(somas_est):
    regions = filtered_regions
    print(f'Kept {len(regions)} objects containing a marked soma.')
else:
    print('No marked somas found; keeping all segmented objects.')

#### 3.3.2: Check batches of objects

In [None]:
# Group the objects into pages of 50 for the batch viewer below; N_BATCHES is
# presumably the resulting number of pages.
N_BATCHES = ac.paginate_objs(regions,
                             pg_size=50)

In [None]:
def plot_batch(BATCH_NO):
    """Widget callback: show page `BATCH_NO` of the objects as 2D MIP views."""
    ac.project_batch(BATCH_NO, N_BATCHES, regions, original)
    plt.show()


# Use the `BATCH_NO` slider to page through the detected objects.
_ = widgets.interact(
    plot_batch,
    BATCH_NO=widgets.IntSlider(min=0, max=N_BATCHES - 1,
                               layout=widgets.Layout(width='100%')),
)

#### 3.3.3: Check individual objects
Select individual objects using `OBJ_INDEX`.

In [None]:
OBJ_INDEX = 0
extracted_cell = None
minz, miny, minx, maxz, maxy, maxx = 0, 0, 0, 0, 0, 0


def plot_single(obj_index):
    """Widget callback: extract and project one object from `original`.

    Records the chosen index in OBJ_INDEX, the extracted volume in
    `extracted_cell` and the object's bounding box in the globals
    minz..maxx, then shows the XYZ projections of the extracted volume.
    """
    global OBJ_INDEX, extracted_cell, minz, miny, minx, maxz, maxy, maxx
    OBJ_INDEX = obj_index
    chosen = regions[OBJ_INDEX]
    extracted_cell = ac.extract_obj(chosen, original)
    minz, miny, minx, maxz, maxy, maxx = chosen.bbox
    ac.projectXYZ(extracted_cell, .5, .5, 1)


_ = widgets.interact(
    plot_single,
    obj_index=widgets.IntSlider(min=0, max=len(regions) - 1,
                                layout=widgets.Layout(width='100%')),
)

---

## Step 4: Export autocropped 3D cells or 2D max intensity projections

Set two parameters:
- `LOW_VOLUME_CUTOFF`: to filter out noise/artifacts
- `HIGH_VOLUME_CUTOFF`: to filter out cell clusters

For choosing between 3D segmented cells and 2D max intensity projections:
- Set `OUTPUT_OPTION` = '3d' for 3D cells,
- Set `OUTPUT_OPTION` = 'mip' for Max Intensity Projections, or
- Set `OUTPUT_OPTION` = 'both' to export both.

In [None]:
LOW_VOLUME_CUTOFF = 200  # filter noise/artifacts
HIGH_VOLUME_CUTOFF = 3133  # filter cell clusters
OUTPUT_OPTION = 'mip'  # '3d' for 3D cells, 'mip' for Max Intensity Projections
SEGMENT_TYPE = 'segmented'
reconstructed_cells = None

import numpy as np


def volume_range(low_volume_cutoff=LOW_VOLUME_CUTOFF,
                 high_volume_cutoff=HIGH_VOLUME_CUTOFF, output_option=OUTPUT_OPTION,
                 segment_type=SEGMENT_TYPE):
    """Widget callback: preview the objects whose volume is within the cutoffs.

    Stores all arguments in the matching globals used by the export cell,
    then rebuilds `reconstructed_cells`. This is a volume shaped like
    `denoised` that holds `denoised` intensities inside the hole-filled mask
    of every region with ``low_volume_cutoff <= region.area <=
    high_volume_cutoff``, and zero elsewhere. Shows its XYZ projections.
    """
    global LOW_VOLUME_CUTOFF, HIGH_VOLUME_CUTOFF, OUTPUT_OPTION, SEGMENT_TYPE, reconstructed_cells
    LOW_VOLUME_CUTOFF, HIGH_VOLUME_CUTOFF = low_volume_cutoff, high_volume_cutoff
    OUTPUT_OPTION, SEGMENT_TYPE = output_option, segment_type

    reconstructed_cells = np.zeros_like(denoised)
    for region in regions:
        if not LOW_VOLUME_CUTOFF <= region.area <= HIGH_VOLUME_CUTOFF:
            continue
        minz, miny, minx, maxz, maxy, maxx = region.bbox
        box = (slice(minz, maxz), slice(miny, maxy), slice(minx, maxx))
        cell = region.filled_image * denoised[box]
        # np.maximum instead of +=: a hole-filled mask can cover voxels of
        # another object, which would otherwise be added twice.
        reconstructed_cells[box] = np.maximum(reconstructed_cells[box], cell)
    ac.projectXYZ(reconstructed_cells, .5, .5, 1, 'gist_heat')


# Slider upper bound is the largest object volume. It is computed explicitly,
# because regions[-1].area is the maximum only if `regions` is sorted by
# ascending area (and it fails on an empty list).
max_area = int(max((region.area for region in regions), default=0))
_ = widgets.interact(volume_range, low_volume_cutoff=widgets.IntSlider(value=LOW_VOLUME_CUTOFF,
                         min=0, max=max_area, layout=widgets.Layout(width='100%')),
                     high_volume_cutoff=widgets.IntSlider(value=HIGH_VOLUME_CUTOFF,
                         min=0, max=max_area, layout=widgets.Layout(width='100%')),
                     output_option=['3d', 'mip', 'both'],
                     segment_type=['segmented', 'unsegmented', 'both'])

In [None]:
# Compare the stack with the volume-filtered objects previewed above.
view_3D(
    {'data': original, 'colormap': 'gray', 'name': 'original'},
    {'data': reconstructed_cells, 'name': 'output'},
)

In [None]:
# Export the objects using the volume cutoffs and output options chosen above.
ac.export_cells(
    CONFOCAL_TISSUE_IMAGE,
    LOW_VOLUME_CUTOFF,
    HIGH_VOLUME_CUTOFF,
    OUTPUT_OPTION,
    original,
    regions,
    None,
    SEGMENT_TYPE,
    NAME_ROI,
    linebuilder,
)