In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import skimage
import matplotlib
matplotlib.rcParams["image.interpolation"] = 'nearest'
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tifffile import imread
from csbdeep.utils import Path, normalize
from csbdeep.io import save_tiff_imagej_compatible

from stardist import random_label_cmap
from stardist.models import StarDist3D
import skimage.io as io

import os

# Let a second OpenMP runtime load instead of aborting (presumably the
# duplicate-libiomp problem seen with some conda/macOS installs).
# NOTE(review): this runs after numpy/stardist are already imported above; it
# may need to be set before those imports to take effect -- confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Seed NumPy's global RNG before building the label colour map, presumably so
# the random label colours are the same on every run.
np.random.seed(6)
lbl_cmap = random_label_cmap()


In [None]:
# Running
def merge_overlap(
    image,
    labels,
    x_pos,
    x_num,
    x_st,
    cut_x,
    y_pos,
    y_num,
    y_st,
    cut_y,
    centroids,
    half_overlap_size,
):
    """Paste the interior of one predicted tile into the full label volume.

    The tile is trimmed by ``half_overlap_size`` pixels on every side that has
    a neighbouring tile, so adjacent tiles meet in the middle of their overlap.
    Sides on the outer edge of the volume are not trimmed.

    Args:
        image: full label volume; written in place on axes 1 and 2.
        labels: predicted label tile (same rank as ``image``).
        x_pos, y_pos: index of this tile along axes 1 and 2.
        x_num, y_num: number of tiles along axes 1 and 2.
        x_st, y_st: start offset of this tile inside ``image``.
        cut_x, cut_y: nominal tile extent (the last tile may be smaller;
            NumPy slicing clips both sides consistently).
        centroids: iterable of ``(z, x, y)`` points in tile-local coordinates.
        half_overlap_size: pixels trimmed from each interior side.

    Returns:
        ``(image, included_centroids)``. ``included_centroids`` lists the points
        that fall inside the kept region, still in tile-local coordinates.
    """
    # Trim half the overlap only on sides that have a neighbouring tile. With a
    # single tile along an axis (num == 1) neither side is trimmed; the former
    # if/elif chain trimmed the right side there and left that strip unwritten.
    x_lo = half_overlap_size if x_pos > 0 else 0
    x_hi = cut_x - (half_overlap_size if x_pos < x_num - 1 else 0)
    y_lo = half_overlap_size if y_pos > 0 else 0
    y_hi = cut_y - (half_overlap_size if y_pos < y_num - 1 else 0)

    image[:, x_st + x_lo : x_st + x_hi, y_st + y_lo : y_st + y_hi] = labels[
        :, x_lo:x_hi, y_lo:y_hi
    ]

    # Keep only centroids whose (x, y) lies in the region that was pasted, so
    # objects in the overlap are reported by exactly one tile.
    included_centroids = [
        point
        for point in centroids
        if x_lo <= point[1] < x_hi and y_lo <= point[2] < y_hi
    ]

    return image, included_centroids


def _tile_grid(length, tile_size, overlap_size):
    """Split one axis of ``length`` pixels into overlapping tiles.

    Returns ``(num_tiles, tile_len, step)``: tile ``i`` starts at ``i * step``
    and is ``tile_len`` pixels long (the last one may be clipped by the image).
    """
    max_step = tile_size - overlap_size
    if max_step <= 0:
        raise ValueError("tile_size must be larger than overlap_size")
    # At least one tile, even when the axis is shorter than the overlap
    # (previously this evaluated to 0 and divided by zero).
    num_tiles = max(1, -(-(length - overlap_size) // max_step))
    # Ceil (not floor) division so num_tiles * step + overlap >= length, i.e.
    # the last tile always reaches the end of the axis.
    step = -(-length // num_tiles)
    return num_tiles, step + overlap_size, step


def _show_reconstruction(img, full):
    """Show the middle Z slice and middle axis-1 slice of ``img``, alone and with ``full`` overlaid."""
    z = img.shape[0] // 2
    y = img.shape[1] // 2
    fig, axes = plt.subplots(2, 2, figsize=(26, 16))
    panels = [
        (axes[0, 0], img[z], None, "XY slice"),
        (axes[0, 1], img[:, y], None, "XZ slice"),
        (axes[1, 0], img[z], full[z], "XY slice"),
        (axes[1, 1], img[:, y], full[:, y], "XZ slice"),
    ]
    for ax, raw, overlay, title in panels:
        ax.imshow(raw, cmap="gray", clim=(0, 1))
        if overlay is not None:
            ax.imshow(overlay, cmap=lbl_cmap, alpha=0.5)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    plt.show()


def divide_and_reconstract(
    model,
    image,
    great_number=1000,
    overlap_size=50,
    tile_size=512,
    axis_norm=None,
    show=True,
):
    """Segment a large volume tile by tile with a StarDist model and stitch the result.

    The volume is tiled along axes 1 and 2 (axis 0 is never split); assumes
    ``image`` is (Z, Y, X) or channels-last -- TODO confirm for other layouts.
    Each tile's labels are shifted by ``tile_index * great_number`` so ids stay
    unique across tiles, then merged with ``merge_overlap``.

    Args:
        model: object exposing ``predict_instances(tile) -> (labels, details)``
            with ``details["points"]`` holding (z, x, y) centroids per object.
        image: raw volume to segment.
        great_number: per-tile label-id stride; a tile with a label id above
            this value collides with the next tile's ids (a warning is issued).
        overlap_size: overlap between neighbouring tiles, in pixels.
        tile_size: maximum tile extent along axes 1 and 2.
        axis_norm: axes passed to ``normalize``; ``None`` uses the notebook
            global ``axis_norm`` if defined, else ``(0, 1, 2)``.
        show: plot a preview of the normalized image and stitched labels.

    Returns:
        ``(full, centroids)``: the int32 label volume with ``image``'s shape,
        and an ``(N, 3)`` int array of centroids in full-image coordinates.
    """
    import warnings

    if axis_norm is None:
        # Backward compatibility: this function used to read the notebook-level
        # `axis_norm` global directly.
        axis_norm = globals().get("axis_norm", (0, 1, 2))

    img = normalize(image, 1, 99.8, axis=axis_norm)
    # zeros (not empty_like) so pixels no tile writes are background, not
    # garbage; int32 (not int16) because tile offsets overflow int16 quickly.
    full = np.zeros(image.shape, dtype=np.int32)

    x_num, cut_x, x_step = _tile_grid(image.shape[1], tile_size, overlap_size)
    y_num, cut_y, y_step = _tile_grid(image.shape[2], tile_size, overlap_size)
    half_overlap_size = overlap_size // 2

    centroids = []
    for x_pos in range(x_num):
        x_st = x_pos * x_step
        for y_pos in range(y_num):
            # Start positions are derived per tile, so y restarts at 0 on every
            # row (the running `y_st += y_step` never reset before).
            y_st = y_pos * y_step
            cut = img[:, x_st : x_st + cut_x, y_st : y_st + cut_y]
            if cut.size == 0:
                continue

            labels, details = model.predict_instances(cut)

            # Row-major tile index must use y_num; x_pos * x_num collides
            # whenever x_num != y_num.
            offset = (x_pos * y_num + y_pos) * great_number
            if labels.max() > great_number:
                warnings.warn(
                    "tile (%d, %d): label id %d exceeds great_number=%d; ids "
                    "will collide with the next tile"
                    % (x_pos, y_pos, labels.max(), great_number)
                )
            # Shift object ids into this tile's range, keeping background at 0.
            labels = np.where(labels > 0, labels + offset, 0)

            full, tile_centroids = merge_overlap(
                full,
                labels,
                x_pos,
                x_num,
                x_st,
                cut_x,
                y_pos,
                y_num,
                y_st,
                cut_y,
                centroids=details["points"],
                half_overlap_size=half_overlap_size,
            )
            # merge_overlap returns tile-local points (possibly none); convert
            # to a (k, 3) int array in full-image coordinates.
            tile_centroids = np.asarray(tile_centroids, dtype=int).reshape(-1, 3)
            centroids.append(tile_centroids + np.array([0, x_st, y_st]))

    if show:
        _show_reconstruction(img, full)
    return full, np.concatenate(centroids, axis=0)

In [None]:
#Running
# Glob pattern selecting the test volumes to segment -- adjust to your data.
S = sorted(glob('data/ctx_dapi/test/images/*0512.tif'))
print(S)
assert S, "no input files matched the glob pattern"
X = list(map(imread, S))

n_channel = 1 if X[0].ndim == 3 else X[0].shape[-1]
axis_norm = (0,1,2)   # normalize channels independently
# axis_norm = (0,1,2,3) # normalize channels jointly
if n_channel > 1:
    # Channels-last 3D stacks are (Z, Y, X, C): the channel axis is 3, so
    # normalization is joint only when axis 3 is included.
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))

# Show the middle Z slice of every test image, one panel per image.
fig, axes = plt.subplots(1, len(X), figsize=(16, 16), squeeze=False)
for i, (ax, x) in enumerate(zip(axes.flat, X)):
    ax.imshow(x[x.shape[0] // 2], cmap='gray')
    ax.set_title(i)
    ax.axis('off')
fig.tight_layout()
plt.show()

In [None]:
# Segment every test volume tile by tile and save three files per input into
# the input's directory with "images" replaced by "r4": the label image
# (*_predict.tif), one "z,x,y" line per object (*_predict.txt) and a dilated
# centroid mask (*_centroid.tif).
# NOTE(review): `model` is never created in this notebook; it must already hold
# a trained StarDist3D model, e.g. StarDist3D(None, name=..., basedir=...).
import warnings

from skimage.morphology import ball, dilation

for src_path, image in zip(S, X):
    # `dim=8` was dropped: divide_and_reconstract has no such parameter.
    labels, centroids = divide_and_reconstract(model, image, great_number=1000)

    # Only the directory part is rewritten, so file names containing
    # "images"/"tif"/"predict" are not mangled.
    src_dir, src_name = os.path.split(src_path)
    out_dir = src_dir.replace("images", "r4")  # output dir needed
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    stem = os.path.splitext(src_name)[0]
    label_path = os.path.join(out_dir, stem + "_predict.tif")
    print(label_path)

    # uint16 silently wraps label ids above 65535; widen instead of corrupting.
    if labels.max() > np.iinfo(np.uint16).max:
        warnings.warn("%s: label ids exceed uint16, saving as uint32" % src_path)
        labels = labels.astype(np.uint32)
    else:
        labels = labels.astype(np.uint16)
    io.imsave(label_path, labels)

    points = np.asarray(centroids, dtype=int).reshape(-1, 3)
    np.savetxt(os.path.join(out_dir, stem + "_predict.txt"), points,
               fmt="%d", delimiter=",")

    # Mark each centroid and grow it into a ball so it is visible in viewers.
    centroid_mask = np.zeros(labels.shape, dtype=np.uint16)
    centroid_mask[tuple(points.T)] = 65535
    centroid_mask = dilation(centroid_mask, ball(5))
    io.imsave(os.path.join(out_dir, stem + "_centroid.tif"), centroid_mask)