In [1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import os
import numpy as np
from tifffile import imread
from csbdeep.utils import Path, normalize
from csbdeep.io import save_tiff_imagej_compatible
from stardist.models import StarDist3D
from skimage.morphology import disk, dilation
from tqdm.notebook import tqdm

In [2]:
import tensorflow as tf

# Restrict TensorFlow to a single GPU and let it allocate memory on demand.
# Index 1 (the second detected GPU) is the preferred device; when the machine
# exposes fewer GPUs, fall back to the last one detected instead of failing
# with an IndexError.
PREFERRED_GPU_INDEX = 1

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    selected_gpu = gpus[min(PREFERRED_GPU_INDEX, len(gpus) - 1)]
    tf.config.experimental.set_visible_devices(selected_gpu, 'GPU')

    # Memory growth is enabled on every detected GPU (not only the visible one).
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
else:
    print("No GPU detected; TensorFlow will run on the CPU.")

In [3]:
# from tensorflow.compat.v1 import ConfigProto
# from tensorflow.compat.v1 import InteractiveSession

# config = ConfigProto()
# config.gpu_options.allow_growth = True
# session = InteractiveSession(config=config)

## Input

In [4]:
# Root folder of the dataset; input images are read from <base_path>/output and
# the segmentation results are written to <base_path>/stardist_segmentation.
base_path = 'Z:/Data/Analyzed/2024-01-08-Jiakun-MouseSpleen64Gene/images/flamingo/'
data_path = os.path.join(base_path, 'output')
output_path = os.path.join(base_path, 'stardist_segmentation')

# Fail early (before any directory is created) when the input folder is
# missing, e.g. because the drive is not mounted or base_path is mistyped.
if not os.path.isdir(data_path):
    raise FileNotFoundError(f"Input image folder not found: {data_path}")

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs(output_path, exist_ok=True)

In [5]:
# Load the pretrained StarDist3D model identified by its name inside the
# base directory; no new configuration object is supplied (None).
model_name = '3D_spleen_resnet_2'
model_basedir = 'models'
model = StarDist3D(None, name=model_name, basedir=model_basedir)

Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.642898, nms_thresh=0.5.


## Batch prediction

In [11]:
# Names of the tiles handled by this batch: Position763 ... Position1524
# (inclusive). Numbers are zero-padded to at least three digits.
first_position, last_position = 763, 1524
positions = [
    f"Position{number:03d}"
    for number in range(first_position, last_position + 1)
]
# Axes over which the percentile normalization is computed jointly.
axis_norm = (0, 1, 2)

# Thresholds passed explicitly to the prediction call.
prob_thresh, nms_thresh = 0.6, 0.1

# Flat 2D structuring element of radius 1, used to dilate the label masks.
se = disk(1, dtype=np.int32)

In [13]:
# Segment every position with the loaded model, dilate the resulting label
# masks slice by slice, save them as ImageJ-compatible TIFFs, and record the
# thresholds plus the number of segmented objects per position in a log file.
with open(os.path.join(output_path, "log_2.txt"), "w") as f:
    # The header ends with a newline so the first position entry starts on its
    # own line instead of being glued to the nms_threshold value.
    f.write(f"prob_threshold: {prob_thresh}\nnms_threshold: {nms_thresh}\n")

    # Iterate the list directly: wrapping it in enumerate() hides its length
    # from tqdm, which then cannot display a total or an ETA.
    for current_position in tqdm(positions):
        current_img = imread(os.path.join(data_path, f"{current_position}.tif"))
        # Percentile normalization (1st to 99.8th) computed over axis_norm.
        current_img = normalize(current_img, 1, 99.8, axis=axis_norm)
        labels, details = model.predict_instances(
            current_img,
            n_tiles=[1, 2, 2],
            prob_thresh=prob_thresh,
            nms_thresh=nms_thresh,
        )

        # Dilate each slice along the first axis independently with the 2D
        # structuring element; the result is written back into `labels`.
        for z in range(labels.shape[0]):
            labels[z] = dilation(labels[z], se)

        current_output = os.path.join(output_path, f"{current_position}.tif")
        # NOTE(review): the recorded output shows the labels being converted
        # from int32 to int16 on save - confirm that label ids per tile stay
        # below 32768.
        save_tiff_imagej_compatible(current_output, labels, axes='ZYX')

        # Count the distinct non-zero labels. Unlike `len(unique) - 1`, this is
        # also correct when the background label 0 does not occur in the volume.
        ncells = np.count_nonzero(np.unique(labels))
        f.write(f"{current_position}: {ncells}\n")
            

0it [00:00, ?it/s]


  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A
 25%|█████████████████████                                                               | 1/4 [00:11<00:34, 11.62s/it][A
 50%|██████████████████████████████████████████                                          | 2/4 [00:15<00:14,  7.08s/it][A
 75%|███████████████████████████████████████████████████████████████                     | 3/4 [00:19<00:05,  5.60s/it][A
100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:23<00:00,  5.80s/it][A
__init__.py (43): Converting data type from 'int32' to ImageJ-compatible 'int16'.

  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A
 25%|█████████████████████                                                               | 1/4 [00:03<00:11,  3.75s/it][A
 50%|██████████████████████████████████████████        