In [None]:
import napari
import cupy as cp
import numpy as np
import higra as hg
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from toolz import curry

from cucim.skimage.transform import downscale_local_mean, rescale
from cucim.skimage.filters import gaussian

from skimage.filters import threshold_otsu
from skimage.measure import regionprops
from skimage import restoration

from scipy.linalg import lstsq

import torch as th
import torch.nn.functional as F

from tifffile import imread

from dexp_dl.inference import ModelInference
from dexp_dl.postprocessing import hierarchy
from dexp_dl.models import hrnet, unet
from dexp_dl.transforms import *

In [None]:
### INPUT ###
IM_PATH = 'bud stage - embryo3_1_crop_denoised.tif'  # multi-channel TIFF stack to process
WEIGHTS_PATH = 'logs/hrnet_bn/last.ckpt'  # checkpoint loaded into the network
CELL_CHANNEL = 0  # index of the channel fed to the segmentation network
Z_SCALE = 2  # block size used to average slices along z after loading
# Output table path: same as the image path with the extension swapped for `.csv`.
# Only the LAST dot-separated suffix is stripped; splitting on the first dot would
# truncate paths such as './data/stack.tif' or 'stack.ome.tif'.
# Assumes IM_PATH carries a file extension.
DATASET_PATH = IM_PATH.rsplit('.', 1)[0] + '.csv'

# Make GPU 0 the current CUDA device for PyTorch.
th.cuda.set_device(0)

### PARAMETERS ###
PRED_THOLD = 0.25  # threshold applied to the cell-prediction map (pred[0]) to get the foreground mask
CUT_THOLD = 1  # value assigned to `cut_threshold` of every hierarchy before labels are extracted
EXP_BLUR = 1  # Gaussian sigma used to smooth the expression channels; values <= 0 skip the blur

In [None]:
def in_transform(image):
    """Convert `image` to a half-precision torch tensor with a new leading axis of size 1."""
    tensor = th.Tensor(image)
    tensor.unsqueeze_(0)  # in-place: prepend the extra axis
    return tensor.half()

def out_transform(image):
    """Upsample the network output by 2x (trilinear, aligned corners) and apply a sigmoid."""
    upsampled = F.interpolate(
        image,
        scale_factor=2,
        mode='trilinear',
        align_corners=True,
    )
    return th.sigmoid(upsampled)

def normalize(image, upper_limit):
    """
    Linearly rescale `image` so that its minimum maps to 0 and `upper_limit` maps to 1.

    Values above `upper_limit` are clipped to 1. The input array is not modified.

    Parameters
    ----------
    image : np.ndarray
        Array to normalize.
    upper_limit : float
        Intensity that is mapped to 1 (e.g. a high quantile of `image`).

    Returns
    -------
    np.ndarray
        Array with the same shape as `image` and values in [0, 1].
    """
    im_min = image.min()
    # Use the argument, not a notebook-level global, so the function works on a fresh kernel.
    image = (image - im_min) / (upper_limit - im_min)
    return np.clip(image, 0, 1)

# 3D HRNet with one input channel and three output channels; no pretrained weights are
# requested here because the checkpoint is loaded explicitly below.
net = hrnet.hrnet_w18_small_v2(
    pretrained=False,
    in_chans=1,
    num_classes=3,
    image_ndim=3,
)

# Inference wrapper around the network, configured with a (48, 96, 96) tile and three outputs.
# Presumably `transforms` runs before the network and `after_transforms` on its output
# — confirm against dexp_dl.
model = ModelInference(
    net,
    transforms=in_transform,
    after_transforms=out_transform,
    tile=(48, 96, 96),
    num_outputs=3,
)

model.load_weights(WEIGHTS_PATH)

In [None]:
image = imread(IM_PATH)
# A size-3 second axis is taken to be the channel axis and is moved to the front.
if image.shape[1] == 3:
    image = image.transpose(1, 0, 2, 3)

# making it anisotropic: average blocks of Z_SCALE slices along axis 1 on the GPU,
# then copy the result back to host memory
image = downscale_local_mean(
    cp.asarray(image),
    (1, Z_SCALE, 1, 1),
).get()

In [None]:
# Open a napari window and show the loaded stack, with axis 0 split into separate channel layers.
viewer = napari.Viewer()
viewer.add_image(image, channel_axis=0, name='Input image')

In [None]:
# 99.9th-percentile intensity of the cell channel, printed next to the maximum for comparison.
quantile = np.quantile(image[CELL_CHANNEL], 0.999)
print('Quantile', quantile, 'Maximum', image[CELL_CHANNEL].max())

# normalizing image: the cell channel is passed to `normalize` together with that quantile
norm_image = normalize(image[CELL_CHANNEL], quantile)

In [None]:
# deep learning
# Run the model on the normalized cell channel under CUDA automatic mixed precision.
# NOTE(review): recent PyTorch releases deprecate `th.cuda.amp.autocast()` in favour of
# `th.amp.autocast('cuda')` — confirm against the installed version.
with th.cuda.amp.autocast():
    pred = model(norm_image)

# displaying
# A new viewer is opened with the input plus one additive layer per network output.
viewer = napari.Viewer()
viewer.add_image(image, channel_axis=0, name='Input image')
viewer.add_image(pred[0], blending='additive', name='Cell prediction')
viewer.add_image(pred[1], blending='additive', name='Distance map')
# The third output is optional and is only shown when present.
if len(pred) > 2:
    viewer.add_image(pred[2], blending='additive', name='Denoising')

In [None]:
# computing segmentation
# Hierarchies are built from the thresholded cell prediction (pred[0] > PRED_THOLD)
# and the second network output (displayed above as the distance map).
hiers = hierarchy.create_hierarchies(
    pred[0] > PRED_THOLD,
    pred[1],
    hierarchy_fun=hg.watershed_hierarchy_by_area,
    cache=True,
    min_area=10,
    min_frontier=0,
)

# Every hierarchy gets the same cut threshold before the label volume is produced.
for h in hiers:
    h.cut_threshold = CUT_THOLD
labels = hierarchy.to_labels(hiers, pred[0].shape)

# Add the labels to the viewer; `contour = 1` presumably draws outlines instead of
# filled regions — confirm against napari.
labels_layer = viewer.add_labels(labels)
labels_layer.contour = 1

In [None]:
class RemoveBackground:
    """
    Depth-dependent background removal for (z, y, x) stacks.

    The background removal is computed using the cell predictions.
    On construction, the median intensity of `cell_channel` is computed for each z-slice
    through a numpy masked array built with `cell_mask`, and a straight line is fitted to
    these medians. Intercept and slope are saved in `self.coefs`.
    When called on a stack, Otsu's threshold is estimated for each z-slice, the median of
    those thresholds is taken, the previously fitted slope is added as a per-slice offset,
    and the result is subtracted from the stack (negative values are clipped to 0).
    Note: we are not centralizing the threshold to z=0 because we observed that this
    produces better results.

    NOTE(review): numpy masked arrays EXCLUDE entries whose mask is True, so the medians
    are taken over pixels where `cell_mask` is False — confirm this polarity is intended.

    Parameters
    ----------
    cell_mask : np.ndarray of bool, shape (z, y, x)
        Mask handed to `np.ma.MaskedArray` (True entries are ignored in the medians).
    cell_channel : np.ndarray, shape (z, y, x)
        Intensity volume from which the per-slice medians are computed.
    max_proj_mask : bool
        If True, `cell_mask` is replaced by its maximum projection along z, repeated
        for every slice.
    display : bool
        If True, plot the per-slice medians and the fitted line with matplotlib.

    Raises
    ------
    ValueError
        If every z-slice is fully masked, so no median is available for the fit.
    """
    def __init__(self, cell_mask, cell_channel, max_proj_mask=False, display=True):
        n_slices = cell_mask.shape[0]
        if max_proj_mask:
            # Share one 2D mask (max projection along z) across all slices.
            cell_mask = np.max(cell_mask, axis=0)
            cell_mask = np.tile(cell_mask, (n_slices, 1, 1))

        self.cell_mask = cell_mask
        cells = np.ma.MaskedArray(cell_channel, self.cell_mask)
        y = np.ma.median(cells, axis=(1, 2))
        if display:
            plt.plot(y)

        x = np.arange(len(y))
        X = np.stack([np.ones(len(x)), x], axis=1)
        # A fully masked slice yields a masked median; np.linalg.lstsq would silently use
        # the meaningless underlying data, so such slices are left out of the fit.
        valid = ~np.ma.getmaskarray(y)
        if not valid.any():
            raise ValueError('cell_mask masks every pixel of every z-slice; cannot fit the intensity decay')
        self.coefs = np.linalg.lstsq(X[valid], np.ma.getdata(y)[valid], rcond=None)[0]
        if display:
            plt.plot(X @ self.coefs)
            plt.show()

    def __call__(self, stack):
        """Return `stack` (z, y, x) with the estimated per-slice background subtracted, clipped at 0."""
        estimated_intensity = [threshold_otsu(z_slice) for z_slice in stack]
        # Only the fitted slope is used; the offset comes from the median Otsu threshold.
        decay = np.arange(len(stack)) * self.coefs[1]
        background_noise = np.median(estimated_intensity) + decay
        return np.clip(stack - background_noise[:, None, None], 0, None)
    
# Fit the background model on the cell channel, using the same prediction threshold
# (PRED_THOLD) as the segmentation step so the two cannot drift apart.
bkg_rem = RemoveBackground(pred[0] > PRED_THOLD, image[CELL_CHANNEL], max_proj_mask=True, display=False)

In [None]:
# blurring expressions
# Every channel other than the cell channel is an expression channel; keep them
# channel-first until the background has been removed.
if CELL_CHANNEL == 0:
    exp_channels = image[1:]
elif CELL_CHANNEL == 2:
    exp_channels = image[:CELL_CHANNEL]
else:
    raise RuntimeError(f'CELL_CHANNEL must be 0 or 2, got {CELL_CHANNEL}')

# `bkg_rem` works on one (z, y, x) stack at a time, so it is applied per channel BEFORE
# the channels are moved to the last axis. Stacking builds a new array, which leaves
# `image` untouched and makes this cell safe to re-run.
expressions = np.stack([bkg_rem(channel) for channel in exp_channels], axis=-1)

if EXP_BLUR > 0:
    # NOTE(review): recent scikit-image/cuCIM releases replace `multichannel=True` with
    # `channel_axis=-1` — confirm against the installed version.
    expressions = gaussian(cp.asarray(expressions, dtype=np.float32), sigma=EXP_BLUR, multichannel=True).get()
viewer.add_image(expressions, name='processed expressions', channel_axis=3)

In [None]:
# extracting expressions from segments
# One region per label, measured against the channels-last expression volume.
props = regionprops(labels, expressions)
# Median over axis 0, i.e. over the pixels selected inside each region (one value per channel).
exp_statistics = curry(np.median, axis=0)

records = []
for region in props:
    exp1, exp2 = exp_statistics(region.intensity_image[region.image])
    records.append([region.label, exp1, exp2, *region.centroid])

df = pd.DataFrame(records, columns=['label', 'tbxt', 'sox2', 'z', 'y', 'x'])

# computing colors
green = np.array([[0, 1, 0, 1]])  # tbxt
magenta = np.array([[1, 0, 1, 1]])  # sox2

def normalize_column(df, column, upper_quantile=0.95):
    """
    Rescale a DataFrame column to [0, 1] between its minimum and an upper quantile.

    Parameters
    ----------
    df : pd.DataFrame
        Table holding the column; it is not modified.
    column : str
        Name of the column to normalize.
    upper_quantile : float
        Quantile (of the min-shifted values) that is mapped to 1; larger values are clipped.

    Returns
    -------
    np.ndarray
        Float array with one value in [0, 1] per row.
    """
    # Work on a float copy so integer columns support the division below.
    col = df[column].to_numpy(dtype=float, copy=True)
    col -= col.min()
    scale = np.quantile(col, upper_quantile)
    if scale <= 0:
        # Degenerate case (e.g. most values equal the minimum): dividing would give NaN/inf.
        # Return the limiting result instead: 1 for values above the minimum, 0 otherwise.
        return (col > 0).astype(float)
    return np.clip(col / scale, 0, 1)

tbxt, sox2 = (normalize_column(df, gene) for gene in ('tbxt', 'sox2'))

# One RGBA row per cell: green weighted by tbxt plus magenta weighted by sox2.
colors = tbxt[:, None] * green + sox2[:, None] * magenta
colors[:, 3] = 1  # alpha channel

plt.scatter(x=df['tbxt'], y=df['sox2'], c=colors)
plt.xlabel('tbxt')
plt.ylabel('sox2')

In [None]:
# spreading expression to segments
# The medians are floating-point values, so the painted volume is float32: an integer
# buffer would truncate them (and zero out data whose intensities lie in [0, 1]).
mask = np.zeros(labels.shape + (2,), dtype=np.float32)
for p in props:
    exp1, exp2 = exp_statistics(p.intensity_image[p.image])
    # `p.slice` selects the region's bounding box (a view); `p.image` picks its pixels inside it.
    mask[p.slice + (0,)][p.image] = exp1
    mask[p.slice + (1,)][p.image] = exp2

viewer.add_image(mask, name='painted segments', channel_axis=3)

In [None]:
# Save the per-cell table (label, expression medians, centroid) to DATASET_PATH without the index column.
df.to_csv(DATASET_PATH, index=False)