In [None]:
import napari
import cupy as cp
import numpy as np
import higra as hg
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from toolz import curry

from cucim.skimage.transform import downscale_local_mean, rescale
from cucim.skimage.filters import gaussian

from skimage.filters import threshold_otsu
from skimage.measure import regionprops
from skimage import restoration

from scipy.linalg import lstsq

import torch as th
import torch.nn.functional as F

from tifffile import imread

from dexp_dl.inference import ModelInference
from dexp_dl.postprocessing import hierarchy
from dexp_dl.models import hrnet, unet
from dexp_dl.transforms import *

import sys
import os

sys.path.append(f'{os.environ["HOME"]}/Softwares/sparse-deconv-py')

from sparse_recon.sparse_deconv import sparse_deconv

In [None]:
### INPUT ###
IM_PATH = 'bud stage - embryo3_1_crop_denoised.tif'
WEIGHTS_PATH = 'logs/hrnet_bn/last.ckpt'
# index of the channel holding the cell signal; the other channels are treated as expressions
CELL_CHANNEL = 0
# average-downsampling factor applied to axis 1 (z, per the name) right after loading
Z_SCALE = 2
# output table path: only the final extension of IM_PATH is replaced, so directories or
# file names that contain dots (e.g. 'data/embryo.v2.tif') are handled correctly
DATASET_PATH = os.path.splitext(IM_PATH)[0] + '.csv'

# make GPU 0 the default torch CUDA device for the rest of the notebook
th.cuda.set_device(0)

### PARAMETERS ###
# foreground threshold applied to the cell prediction map (used as `pred[0] > PRED_THOLD`)
PRED_THOLD = 0.25
# cut threshold assigned to every segmentation hierarchy before labels are extracted
CUT_THOLD = 1

In [None]:
def in_transform(image):
    """Wrap an array as a float16 torch tensor with a new leading axis of size 1."""
    as_tensor = th.Tensor(image)        # tensor in torch's default float dtype
    batched = as_tensor.unsqueeze(0)    # prepend a singleton axis
    return batched.half()

def out_transform(image):
    """Upsample network outputs 2x (trilinear, corners aligned) and squash them with a sigmoid.

    ``image`` must be a 5D tensor (N, C, D, H, W), as required by trilinear interpolation.
    """
    upsampled = F.interpolate(
        image,
        scale_factor=2,
        mode='trilinear',
        align_corners=True,
    )
    return upsampled.sigmoid()

def normalize(image, upper_limit):
    """Linearly rescale ``image`` so its minimum maps to 0 and ``upper_limit`` maps to 1.

    Values above ``upper_limit`` are clipped to 1, so a robust statistic such as a
    high quantile can be passed as the limit instead of the outlier-prone maximum.

    Parameters
    ----------
    image : numpy.ndarray
        Input intensities (non-empty).
    upper_limit : float
        Intensity that is mapped to 1.

    Returns
    -------
    numpy.ndarray
        Rescaled image, clipped to [0, 1].

    Notes
    -----
    If ``upper_limit`` is not above the image minimum, the linear map is undefined
    (division by zero or by a negative range). In that case, pixels strictly above
    ``upper_limit`` are set to 1 and all others to 0, instead of producing NaN/inf.
    """
    im_min = image.min()
    # use the `upper_limit` argument itself, not any notebook-global value
    scale = upper_limit - im_min
    if scale <= 0:
        return (image > upper_limit).astype(np.float64)
    image = (image - im_min) / scale
    return np.clip(image, 0, 1)

# HRNet (w18 small v2) for 3D volumes: one input channel, three output maps.
# pretrained=False: weights are loaded from WEIGHTS_PATH at the end of this cell instead.
net = hrnet.hrnet_w18_small_v2(pretrained=False, in_chans=1, num_classes=3, image_ndim=3)

# inference wrapper around `net`; `in_transform` / `out_transform` are passed as the
# pre- and post-network transforms (tensor conversion, and 2x upsampling + sigmoid)
# NOTE(review): tiling/stitching behaviour for `tile=(48, 96, 96)` is defined inside
# dexp_dl's ModelInference — confirm there before changing the tile size.
model = ModelInference(
    net,
    transforms=in_transform,
    after_transforms=out_transform,
    tile=(48, 96, 96), num_outputs=3,
)

model.load_weights(WEIGHTS_PATH)

In [None]:
# load the multichannel stack from disk
image = imread(IM_PATH)
# if axis 1 has size 3, swap the first two axes so the channel axis comes first
# NOTE(review): assumes a 4D stack with 3 channels stored as (Z, C, Y, X) — confirm for other inputs
if image.shape[1] == 3:
    image = image.transpose((1, 0, 2, 3))
    
# making it anisotropic
# average-pool axis 1 by Z_SCALE on the GPU (cupy), then `.get()` copies the result back to numpy
image = downscale_local_mean(cp.asarray(image), (1, Z_SCALE, 1, 1)).get()

In [None]:
# robust upper intensity limit of the cell channel (99.9th percentile), used instead of the max
quantile = np.quantile(image[CELL_CHANNEL], 0.999)
print('Quantile', quantile, 'Maximum', image[CELL_CHANNEL].max())

# normalizing image
# rescale the cell channel to [0, 1] between its minimum and `quantile`; brighter values are clipped
norm_image = normalize(image[CELL_CHANNEL], quantile)

In [None]:
# deep learning
# run inference under CUDA automatic mixed precision
# NOTE(review): `th.cuda.amp.autocast()` is deprecated in recent PyTorch in favour of
# `th.autocast('cuda')` — switch once the minimum supported torch version allows it.
with th.cuda.amp.autocast():
    pred = model(norm_image)
    
# release cached, currently unused GPU memory held by torch's allocator
th.cuda.empty_cache()  

# displaying
viewer = napari.Viewer()
# one layer per input channel (axis 0 is the channel axis at this point)
viewer.add_image(image, channel_axis=0, name='Input image')
# output maps, named by index: 0 = cell prediction, 1 = distance map, optional 2 = denoising
viewer.add_image(pred[0], blending='additive', name='Cell prediction')
viewer.add_image(pred[1], blending='additive', name='Distance map')
if len(pred) > 2:
    viewer.add_image(pred[2], blending='additive', name='Denoising')

In [None]:
# computing segmentation
# build watershed-by-area hierarchies from the thresholded cell mask (pred[0] > PRED_THOLD)
# and the distance map (pred[1])
# NOTE(review): how `create_hierarchies` uses the distance map and the `min_area` /
# `min_frontier` / `cache` options is defined in dexp_dl — confirm there.
hiers = hierarchy.create_hierarchies(
    pred[0] > PRED_THOLD,
    pred[1],
    hierarchy_fun=hg.watershed_hierarchy_by_area,
    cache=True,
    min_area=10,
    min_frontier=0,
)
    
# apply the same cut threshold to every hierarchy, then merge them into one label volume
for h in hiers:
    h.cut_threshold = CUT_THOLD
labels = hierarchy.to_labels(hiers, pred[0].shape)
    
labels_layer = viewer.add_labels(labels)
# draw label outlines (thickness 1) instead of filled regions
labels_layer.contour = 1

In [None]:
# separating expression channels (every channel except CELL_CHANNEL)
# NOTE: no blurring happens here; the optional sparse deconvolution below is the only filtering.
if not 0 <= CELL_CHANNEL < image.shape[0]:
    raise ValueError(
        f'CELL_CHANNEL={CELL_CHANNEL} is out of range for an image with {image.shape[0]} channels'
    )

# drop the cell channel and move channels last: (C, ...) -> (..., C)
expressions = np.delete(image, CELL_CHANNEL, axis=0).transpose((1, 2, 3, 0))

# the following cells unpack exactly two expression channels (exp1, exp2)
if expressions.shape[-1] != 2:
    raise ValueError(
        f'expected 2 expression channels after removing CELL_CHANNEL, got {expressions.shape[-1]}'
    )

viewer.add_image(expressions, name='original expressions', channel_axis=3)

### DECONVOLUTION ###

# change this to use/not use deconvolution
use_deconv = True

# Parameter description (forwarded to `sparse_deconv`):
#   background:
#       0 = no background noise
#       1 = low background noise
#       2 = high background noise
#   fidelity:
#       higher values force the result to be closer to the input
#   sparsity:
#       higher values force the result to have more zeros --- it removes more blur/noise
#   NOTE:
#       very HIGH `sparsity` and/or LOW `fidelity` might lead to an empty image as result
DECONV_PARAMS = dict(background=1, fidelity=150, sparsity=10)

if use_deconv:
    deconv = np.zeros_like(expressions, dtype=float)
    for channel in range(expressions.shape[-1]):
        # `.get()` copies the (presumably cupy) result back into host memory
        deconv[..., channel] = sparse_deconv(expressions[..., channel], [], **DECONV_PARAMS).get()

    viewer.add_image(deconv, name='deconvolved expressions', channel_axis=3)

    expressions = deconv

In [None]:
# extracting expressions from segments
props = regionprops(labels, expressions)
exp_statistics = curry(np.mean, axis=0)

# spreading expression to segments
mask = np.zeros(labels.shape + (2,), dtype=expressions.dtype)
for p in props:
    exp1, exp2 = exp_statistics(p.intensity_image[p.image])
    mask[p.slice + (0,)][p.image] = exp1
    mask[p.slice + (1,)][p.image] = exp2

exp1_layer, exp2_layer = viewer.add_image(mask, name='orig. label exp.', channel_axis=3)
exp1_layer.colormap = 'green'
exp2_layer.colormap = 'red'

In [None]:
# rows are collected as lists first and converted to a DataFrame after the loop
df = []

# keep only segments whose mean expression exceeds the lower contrast limit of BOTH napari
# layers; the limits are read from the viewer, so the selection depends on whatever contrast
# was set interactively before this cell ran
for p in props:
    exp1, exp2 = exp_statistics(p.intensity_image[p.image])
    if exp1 > exp1_layer.contrast_limits[0] and exp2 > exp2_layer.contrast_limits[0]:
        df.append([p.label, exp1, exp2, *p.centroid])

# one row per selected segment; centroid coordinates follow the label volume's axis order
df = pd.DataFrame(df, columns=['label', 'exp1', 'exp2', 'z', 'y', 'x'])

# computing colors
# RGBA base colors, scaled per segment below
green = np.array([[0, 1, 0, 1]])  # exp1
red = np.array([[1, 0, 0, 1]])  # exp2

def normalize_column(df, column, upper_quantile=1):
    """Min/quantile-normalize one DataFrame column into [0, 1].

    The column is shifted so its minimum becomes 0, divided by the ``upper_quantile``
    quantile of the shifted values, and clipped to [0, 1].

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding the column. It is not modified.
    column : str
        Name of a numeric column (integer columns are accepted).
    upper_quantile : float, optional
        Quantile in [0, 1] of the shifted values that is mapped to 1. The default of 1
        uses the maximum.

    Returns
    -------
    normalized : numpy.ndarray
        Float copy of the column mapped into [0, 1].
    norm_fun : callable
        Scalar function applying the same mapping, with clipping, to new values.

    Notes
    -----
    If the scale is not positive (e.g. all values are equal), a scale of 1 is used
    instead of dividing by zero, which previously produced NaN colors. An empty column
    yields an empty array and a mapping that only clips to [0, 1].
    """
    # float copy: supports integer columns (in-place true division) and never mutates `df`
    col = df[column].to_numpy(dtype=float, copy=True)

    if col.size == 0:
        minimum, scale = 0.0, 1.0
    else:
        minimum = col.min()
        col -= minimum
        scale = np.quantile(col, upper_quantile)
        if not scale > 0:  # also catches NaN
            scale = 1.0
        col /= scale

    def norm_fun(x):
        return min(1, max(0, (x - minimum) / scale))

    return np.clip(col, 0, 1), norm_fun

# per-segment normalized values in [0, 1], plus scalar mapping functions reused by the next cell
blending_exp1, normfun_exp1 = normalize_column(df, 'exp1')
blending_exp2, normfun_exp2 = normalize_column(df, 'exp2')

# green intensity encodes exp1 and red intensity encodes exp2; the sum also adds up the
# alpha values, so alpha is reset to fully opaque afterwards
colors = green * blending_exp1[:, None] + red * blending_exp2[:, None]
colors[:,3] = 1 # alpha channel

plt.scatter(x=df['exp1'], y=df['exp2'], c=colors)
plt.xlabel('exp. 1'); plt.ylabel('exp. 2')

# persist the selected-segment table to DATASET_PATH (derived from IM_PATH)
df.to_csv(DATASET_PATH, index=False)

In [None]:
# matching segments color to plot
# Paint exactly the segments shown in the scatter plot (the rows of `df`) with the same
# normalized values used for the plot colors. Reusing `df` instead of re-applying the
# napari contrast-limit filter keeps both views consistent, even if the contrast limits
# were changed after the plot cell ran. It also avoids recomputing every segment mean.
props_by_label = {p.label: p for p in props}
mask = np.zeros(labels.shape + (2,), dtype=np.float16)
for row in df.itertuples(index=False):
    p = props_by_label[row.label]
    # `p.slice + (c,)` is basic indexing (a view), so the boolean assignment writes into `mask`
    mask[p.slice + (0,)][p.image] = normfun_exp1(row.exp1)
    mask[p.slice + (1,)][p.image] = normfun_exp2(row.exp2)

layers = viewer.add_image(mask, name='selected label exp.', channel_axis=3)
layers[0].colormap = 'green'
layers[1].colormap = 'red'