In [6]:
import napari
import cupy as cp
import numpy as np
import higra as hg
import pandas as pd
import matplotlib.pyplot as plt
from toolz import curry

from cucim.skimage.transform import downscale_local_mean
from cucim.skimage.filters import median

from skimage.measure import regionprops

import torch as th
import torch.nn.functional as F

from tifffile import imread

from dexp.processing.morphology import area_white_top_hat

from dexp_dl.inference import ModelInference
from dexp_dl.postprocessing import hierarchy
from dexp_dl.models import hrnet, unet
from dexp_dl.transforms import *

import sys
import os

# Optional dependency: sparse-deconv-py is loaded from a local checkout expected
# at ~/Softwares/sparse-decon-py. `expanduser` also works when HOME is unset.
SPARSE_DECONV_PATH = os.path.join(os.path.expanduser('~'), 'Softwares', 'sparse-decon-py')
# Guard so re-running this cell does not append the same entry again.
if SPARSE_DECONV_PATH not in sys.path:
    sys.path.append(SPARSE_DECONV_PATH)

try:
    from sparse_recon.sparse_deconv import sparse_deconv
except ImportError:
    # Deconvolution is optional; the deconvolution cell checks for None.
    sparse_deconv = None

In [7]:
### INPUT ###
IM_PATH = '10somite 100821 sample2 20x_1_MMStack_Default.ome.tif'
WEIGHTS_PATH = 'logs/hrnet_bn/last.ckpt'
CELL_CHANNEL = 2  # index (axis 0) of the channel fed to the segmentation network
Z_SCALE = 2  # block-averaging factor applied along z (axis 1) after loading

# Output CSV lives next to the input image and is named after it with every
# extension stripped (e.g. '.ome.tif'). Only the file name is split on '.', so
# dots in directory names (e.g. './data/...') cannot truncate the path.
DATASET_PATH = os.path.join(
    os.path.dirname(IM_PATH),
    os.path.basename(IM_PATH).split('.')[0] + '.csv',
)

# Make GPU 0 torch's current CUDA device; inference below runs under
# th.cuda.amp.autocast on the GPU.
th.cuda.set_device(0)

### PARAMETERS ###
# Threshold on the cell-prediction map (pred[0]); voxels above it form the
# foreground mask passed to the hierarchical segmentation.
PRED_THOLD = 0.25
# Value assigned to `cut_threshold` on every hierarchy before labels are built.
# NOTE(review): exact semantics live in dexp_dl.postprocessing.hierarchy — confirm.
CUT_THOLD = 1

In [8]:
def in_transform(image):
    """Turn an array into a float16 torch tensor with an extra leading axis.

    The legacy ``th.Tensor`` constructor yields a float32 tensor, which is given
    a new axis 0 and then cast to half precision.
    """
    tensor = th.Tensor(image)
    tensor = tensor.unsqueeze(0)
    return tensor.half()

def out_transform(image):
    """Upsample network output 2x (trilinear, corner-aligned) and squash it to (0, 1).

    Trilinear interpolation requires a 5-D (N, C, D, H, W) tensor. The input is
    presumably raw logits, given the sigmoid applied afterwards.
    """
    upsampled = F.interpolate(
        image,
        scale_factor=2,
        mode='trilinear',
        align_corners=True,
    )
    return th.sigmoid(upsampled)

def normalize(image, upper_limit):
    """Rescale ``image`` to [0, 1] between its minimum and ``upper_limit``.

    Values are shifted by the image minimum, divided by ``upper_limit - min`` and
    clipped, so anything at or above ``upper_limit`` saturates at 1.

    Parameters
    ----------
    image : numpy.ndarray
        Intensities to rescale.
    upper_limit : float
        Intensity mapped to 1 (e.g. a high quantile, to be robust to outliers).

    Returns
    -------
    numpy.ndarray
        The rescaled image clipped to [0, 1]. When ``upper_limit`` does not
        exceed the image minimum the range is degenerate, and an all-zero float
        array is returned instead of NaN/inf values.
    """
    im_min = image.min()
    # Scale by the caller-supplied limit, never by any global of the notebook.
    scale = upper_limit - im_min
    if scale <= 0:
        return np.zeros(np.shape(image), dtype=float)
    image = (image - im_min) / scale
    return np.clip(image, 0, 1)

# HRNet-W18-small-v2 built without pretrained weights (they are loaded from
# WEIGHTS_PATH below): 1 input channel, 3 output maps, image_ndim=3.
net = hrnet.hrnet_w18_small_v2(
    pretrained=False,
    in_chans=1,
    num_classes=3,
    image_ndim=3,
)

# Inference wrapper with pre/post transforms and a (48, 96, 96) tile size.
# NOTE(review): presumably the transforms are applied per tile — confirm in dexp_dl.
model = ModelInference(
    net,
    transforms=in_transform,
    after_transforms=out_transform,
    tile=(48, 96, 96),
    num_outputs=3,
)
model.load_weights(WEIGHTS_PATH)


Missing keys []

Unexpected keys []


In [9]:
image = imread(IM_PATH)
# Channels first: a small axis 1 (< 5 entries) is treated as the channel axis and
# swapped with axis 0 (assumes a 4-D stack such as Z, C, Y, X — TODO confirm).
image = image.transpose((1, 0, 2, 3)) if image.shape[1] < 5 else image

# Average blocks of Z_SCALE consecutive slices along axis 1 (z) on the GPU with
# cuCIM, then copy the result back to host memory; y and x are left untouched.
image = downscale_local_mean(cp.asarray(image), (1, Z_SCALE, 1, 1)).get()

In [10]:
# Area white top-hat (dexp) applied to every channel independently, with the
# area parameter 1e4, sampling=1 and axis=0 as passed to dexp.
wth = np.stack([
    area_white_top_hat(channel, 1e4, sampling=1, axis=0)
    for channel in image
])

# Raw and filtered stacks side by side, one napari layer per channel.
viewer = napari.Viewer()
viewer.add_image(image, name='before', channel_axis=0)
viewer.add_image(wth, name='after', channel_axis=0)

# From here on the filtered stack replaces the raw one.
# NOTE: re-running this cell without re-running the loading cell filters twice.
image = wth

In [None]:
# Top of the normalization range: the 0.999 quantile of the cell channel, which
# is less sensitive to a handful of very bright voxels than the maximum.
cell_channel = image[CELL_CHANNEL]
quantile = np.quantile(cell_channel, 0.999)
print('Quantile', quantile, 'Maximum', cell_channel.max())

# normalizing image (network input)
norm_image = normalize(cell_channel, quantile)

In [None]:
# deep learning: predict on the normalized cell channel under CUDA autocast
# (mixed precision), then release unused memory cached by torch's allocator.
with th.cuda.amp.autocast():
    pred = model(norm_image)
th.cuda.empty_cache()

# displaying: every input channel plus each predicted map as additive layers
viewer = napari.Viewer()
viewer.add_image(image, name='Input image', channel_axis=0)
viewer.add_image(pred[0], name='Cell prediction', blending='additive')
viewer.add_image(pred[1], name='Distance map', blending='additive')
# a third output, when the model produces one, is shown as 'Denoising'
if len(pred) > 2:
    viewer.add_image(pred[2], name='Denoising', blending='additive')

In [None]:
# computing segmentation
# Foreground = voxels whose cell prediction (pred[0]) exceeds PRED_THOLD; pred[1]
# (displayed as 'Distance map') is passed alongside it, with higra's area-based
# watershed hierarchy as the hierarchy function.
# NOTE(review): meaning of cache/min_area/min_frontier is defined in
# dexp_dl.postprocessing.hierarchy — confirm there before tuning.
hiers = hierarchy.create_hierarchies(
    pred[0] > PRED_THOLD,
    pred[1],
    hierarchy_fun=hg.watershed_hierarchy_by_area,
    cache=True,
    min_area=10,
    min_frontier=0,
)

# Every hierarchy gets the same cut threshold; this must happen before
# `to_labels`, which builds a label volume shaped like the prediction maps.
for h in hiers:
    h.cut_threshold = CUT_THOLD
labels = hierarchy.to_labels(hiers, pred[0].shape)

# Show the segmentation; contour=1 draws 1-pixel label outlines instead of
# filled regions.
labels_layer = viewer.add_labels(labels)
labels_layer.contour = 1

In [None]:
# Expression channels: every channel except the cell channel, with the channel
# axis moved last, i.e. (C, Z, Y, X) -> (Z, Y, X, C).
indices = list(range(len(image)))
indices.remove(CELL_CHANNEL)
n_exp = len(indices)

expressions = image[indices].transpose((1, 2, 3, 0))
viewer.add_image(expressions, name='original expressions', channel_axis=3)

### DECONVOLUTION ###

# change this to use/not use deconvolution
use_deconv = False

if use_deconv and sparse_deconv is None:
    print('sparse-deconv-py package not loaded!')
    use_deconv = False

if use_deconv:
    # Parameter description (keyword arguments of `sparse_deconv`):
    #
    # background:
    #     0 = no background noise
    #     1 = low background noise
    #     2 = high background noise
    #
    # fidelity:
    #     higher values force the result to stay closer to the input
    #
    # sparsity:
    #     higher values force the result to have more zeros -- it removes more blur/noise
    #
    # NOTE:
    #     very HIGH `sparsity` and/or LOW `fidelity` might lead to an empty image as result
    deconv = np.zeros_like(expressions, dtype=float)
    for ch in range(deconv.shape[-1]):
        deconv[..., ch] = sparse_deconv(
            expressions[..., ch], [], background=1, fidelity=150, sparsity=10,
        ).get()

    viewer.add_image(deconv, name='deconvolved expressions', channel_axis=3)
    expressions = deconv

In [None]:
# extracting expressions from segments
# regionprops with a multichannel intensity image: each region's
# `intensity_image[image]` is (n_voxels, n_exp), so the mean over axis 0 is the
# per-channel mean expression of that segment.
props = regionprops(labels, expressions)
exp_statistics = curry(np.mean, axis=0)

# spreading expression to segments
# Float mask (as for `deconv`) so the means are not truncated when
# `expressions` has an integer dtype.
mask = np.zeros_like(expressions, dtype=float)
for p in props:
    exps = exp_statistics(p.intensity_image[p.image])
    for i in range(len(exps)):
        mask[p.slice + (i,)][p.image] = exps[i]

# One napari layer per expression channel; zip stops at the three colours.
layer_colors = ('green', 'red', 'blue')
exps_layer = viewer.add_image(mask, name='orig. label exp.', channel_axis=3)
for c, l in zip(layer_colors, exps_layer):
    l.colormap = c

In [None]:
# One row per segment: label, centroid (z, y, x) and mean expression per channel.
# A segment is kept only if every mean expression exceeds the lower contrast
# limit of the matching 'orig. label exp.' layer.
# NOTE(review): those limits are whatever is currently set in the napari UI, so
# the selection cannot be reproduced from code alone.
lower_limits = [layer.contrast_limits[0] for layer in exps_layer]
rows = []
for p in props:
    exps = exp_statistics(p.intensity_image[p.image])
    if all(exp > low for exp, low in zip(exps, lower_limits)):
        rows.append([p.label, *p.centroid, *exps])

exp_columns = [f'exp{i}' for i in range(n_exp)]
df = pd.DataFrame(rows, columns=['label', 'z', 'y', 'x'] + exp_columns)


def normalize_column(df, column, upper_quantile=1):
    """Rescale one DataFrame column to [0, 1] and return the matching scalar map.

    The column is shifted by its minimum, divided by the ``upper_quantile``
    quantile of the shifted values, and clipped to [0, 1].

    Parameters
    ----------
    df : pandas.DataFrame
        Source frame; it is not modified.
    column : str
        Name of the numeric column to rescale (integer columns are accepted).
    upper_quantile : float, default 1
        Quantile in [0, 1] of the shifted column that is mapped to 1; the
        default 1 uses the maximum.

    Returns
    -------
    tuple of (numpy.ndarray, callable)
        The normalized float column, and ``norm_fun(x)``, which applies the same
        shift, scale and clip to a single value.

    Raises
    ------
    ValueError
        If the column is empty (no minimum exists).
    """
    # Float copy: leaves the caller's frame untouched and lets integer columns
    # be divided in place.
    col = df[column].to_numpy(dtype=float, copy=True)
    minimum = col.min()
    col -= minimum
    scale = np.quantile(col, upper_quantile)
    if scale <= 0:
        # Degenerate range (e.g. a constant column): fall back to the full range,
        # or to 1 when every value equals the minimum, instead of dividing by 0.
        col_max = col.max()
        scale = col_max if col_max > 0 else 1.0
    col /= scale

    def norm_fun(x):
        return min(1, max(0, (x - minimum) / scale))

    return np.clip(col, 0, 1), norm_fun

# computing colors: row i of `grb` is the RGBA contribution of expression i
# (exp0 -> green, exp1 -> red, exp2 -> blue, matching `layer_colors`).
grb = np.array([[0, 1, 0, 1],
                [1, 0, 0, 1],
                [0, 0, 1, 1]])

# Normalize each expression column and keep its scalar mapping for later cells.
scaled_columns = []
norm_funs = []
for i in range(n_exp):
    scaled, fun = normalize_column(df, f'exp{i}')
    scaled_columns.append(scaled)
    norm_funs.append(fun)

df_colors = np.zeros((len(df), 4))
for i, scaled in enumerate(scaled_columns):
    df_colors += scaled[:, None] * grb[i][None, :]
df_colors[:, 3] = 1  # alpha channel: fully opaque

plt.scatter(x=df['exp0'], y=df['exp1'], c=df_colors)
plt.xlabel('exp. 0')
plt.ylabel('exp. 1')

df.to_csv(DATASET_PATH, index=False)

In [None]:
# spreading normalized expression to the selected segments only
# Float mask: the normalized values lie in [0, 1] and would all be truncated to
# 0 if `expressions` had an integer dtype.
mask = np.zeros_like(expressions, dtype=float)
for p in props:
    exps = exp_statistics(p.intensity_image[p.image])
    # Same selection rule as the DataFrame cell: above the lower contrast limit
    # of every 'orig. label exp.' layer (interactive napari state).
    if all(exp > l.contrast_limits[0] for exp, l in zip(exps, exps_layer)):
        for i in range(len(exps)):
            mask[p.slice + (i,)][p.image] = norm_funs[i](exps[i])

layers = viewer.add_image(mask, name='selected label exp.', channel_axis=3)
for c, l in zip(layer_colors, layers):
    l.colormap = c