In [1]:
# Import necessary libraries for training
import tifffile
from stardist.models import Config2D, StarDist2D
from csbdeep.utils import Path, normalize
import numpy as np
import random
import os
import time
import tensorflow as tf

# Check if GPU is available
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))


Num GPUs Available:  0


In [2]:
# Paths to your training data
data_path = 'D:/stardist_segmentation/original'  # Update with your path

# Load and normalize your data
X_train_files = sorted(Path(data_path).glob('*.tif'))  # Your image file naming pattern

if len(X_train_files) == 0:
    raise ValueError("No training images found. Please check the paths and ensure there are .tif files in the directories.")

print(f'Found {len(X_train_files)} training images.')

X_train = [normalize(tifffile.imread(file)) for file in X_train_files]
Y_train = [tifffile.imread(file) for file in X_train_files]  # Using the same images as pseudo-labels


Found 28 training images.


In [3]:
def generate_patches(X, Y, patch_size=(256, 256), num_patches=100):
    patches_X = []
    patches_Y = []
    
    for _ in range(num_patches):
        idx = random.randint(0, len(X) - 1)
        img = X[idx]
        lbl = Y[idx]
        
        h, w = img.shape[:2]
        ph, pw = patch_size
        
        if h < ph or w < pw:
            continue
        
        y = random.randint(0, h - ph)
        x = random.randint(0, w - pw)
        
        patch_X = img[y:y+ph, x:x+pw]
        patch_Y = lbl[y:y+ph, x:x+pw]
        
        patches_X.append(patch_X)
        patches_Y.append(patch_Y)
    
    return np.array(patches_X), np.array(patches_Y)

# Create patches
X_patches, Y_patches = generate_patches(X_train, Y_train, patch_size=(256, 256), num_patches=500)  # Reducing the number of patches to speed up

if len(X_patches) == 0 or len(Y_patches) == 0:
    raise ValueError("No patches generated. Please check the input images and patch size.")


In [4]:
# Define the model configuration
conf = Config2D(
    n_rays=32,
    grid=(1, 1),
    use_gpu=False,  # Ensure GPU usage
)


In [5]:
# Instantiate the StarDist model
model = StarDist2D(conf, name='stardist_custom', basedir='models')

# Custom training loop to print the epoch number and measure time
epochs = 5  # Measure time for first 5 epochs
total_time = 0

for epoch in range(epochs):
    print(f'Epoch {epoch + 1} started.')
    start_time = time.time()
    model.train(X_patches, Y_patches, validation_data=(X_patches, Y_patches), epochs=1)
    epoch_time = time.time() - start_time
    total_time += epoch_time
    print(f'Epoch {epoch + 1} completed in {epoch_time:.2f} seconds.')

average_epoch_time = total_time / epochs
print(f'Average time per epoch: {average_epoch_time:.2f} seconds.')
estimated_total_time = average_epoch_time * 400
print(f'Estimated total time for 400 epochs: {estimated_total_time / 3600:.2f} hours.')


base_model.py (198): output path for model already exists, files may be overwritten: D:\stardist_segmentation\models\stardist_custom


Using default values: prob_thresh=0.5, nms_thresh=0.4.
Epoch 1 started.
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1001s[0m 10s/step - dist_dist_iou_metric: 0.1766 - dist_relevant_mae: 4.1032 - dist_relevant_mse: 209.8510 - loss: 1.4095 - prob_kld: 0.1716 - val_dist_dist_iou_metric: 0.3874 - val_dist_relevant_mae: 3.2889 - val_dist_relevant_mse: 167.3624 - val_loss: 1.1371 - val_prob_kld: 0.0647 - learning_rate: 3.0000e-04

Loading network weights from 'weights_best.h5'.
Epoch 1 completed in 2067.16 seconds.
Epoch 2 started.
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m924s[0m 9s/step - dist_dist_iou_metric: 0.3939 - dist_relevant_mae: 2.8995 - dist_relevant_mse: 126.1358 - loss: 1.0587 - prob_kld: 0.0631 - val_dist_dist_iou_metric: 0.4234 - val_dist_relevant_mae: 3.1555 - val_dist_relevant_mse: 155.1654 - val_loss: 1.1041 - val_prob_kld: 0.0583 - learning_rate: 3.0000e-04

Loading network weights from 'weights_best.h5'.
Epoch 2 completed in 2021.44 seco

In [6]:
# Import necessary libraries for segmentation
import os
import tifffile
import cv2
import numpy as np
import matplotlib.pyplot as plt
from stardist.models import StarDist2D
from csbdeep.utils import normalize
from stardist.plot import render_label


In [7]:
# Directory containing the images
input_dir = 'D:/stardist_segmentation/original'
output_dir = 'D:/stardist_segmentation/segmentation_results'

# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)


In [8]:
# Load the custom StarDist model
model = StarDist2D(None, name='stardist_custom', basedir='models')


Loading network weights from 'weights_best.h5'.
Couldn't load thresholds from 'thresholds.json', using default values. (Call 'optimize_thresholds' to change that.)
Using default values: prob_thresh=0.5, nms_thresh=0.4.


In [9]:
# Process each image in the directory
for image_name in os.listdir(input_dir):
    if image_name.endswith('.tif') or image_name.endswith('.tiff'):
        image_path = os.path.join(input_dir, image_name)
        img = tifffile.imread(image_path)

        # Normalize the image
        img_normalized = normalize(img, 1, 99.8, axis=(0, 1))

        # Predict instances with default parameters
        labels, details = model.predict_instances(img_normalized)

        # Post-processing using advanced morphological operations
        kernel = np.ones((3, 3), np.uint8)
        labels_post = cv2.morphologyEx(labels.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
        labels_post = cv2.morphologyEx(labels_post, cv2.MORPH_OPEN, kernel)

        # Additional post-processing (dilation followed by erosion)
        labels_post = cv2.dilate(labels_post, kernel, iterations=1)
        labels_post = cv2.erode(labels_post, kernel, iterations=1)

        # Save the results
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.imshow(img, cmap='gray')
        plt.title('Input Image')
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(render_label(labels_post, img=img))
        plt.title('Prediction + Input Overlay')
        plt.axis('off')

        plt.tight_layout()
        output_path = os.path.join(output_dir, f'{os.path.splitext(image_name)[0]}_segmentation.png')
        plt.savefig(output_path)
        plt.close()

        print(f'Saved segmentation for {image_name} to {output_path}')


Saved segmentation for Muc1_Ecad_SPC_x20_2_XY10_00001_CH3.tif to D:/stardist_segmentation/segmentation_results\Muc1_Ecad_SPC_x20_2_XY10_00001_CH3_segmentation.png
Saved segmentation for Muc1_Ecad_SPC_x20_2_XY10_00002_CH3.tif to D:/stardist_segmentation/segmentation_results\Muc1_Ecad_SPC_x20_2_XY10_00002_CH3_segmentation.png
Saved segmentation for Muc1_Ecad_SPC_x20_2_XY10_00007_CH3.tif to D:/stardist_segmentation/segmentation_results\Muc1_Ecad_SPC_x20_2_XY10_00007_CH3_segmentation.png
Saved segmentation for Muc1_Ecad_SPC_x20_2_XY10_00008_CH3.tif to D:/stardist_segmentation/segmentation_results\Muc1_Ecad_SPC_x20_2_XY10_00008_CH3_segmentation.png
Saved segmentation for Muc1_Ecad_SPC_x20_2_XY10_00009_CH3.tif to D:/stardist_segmentation/segmentation_results\Muc1_Ecad_SPC_x20_2_XY10_00009_CH3_segmentation.png
Saved segmentation for Muc1_Ecad_SPC_x20_2_XY10_00010_CH3.tif to D:/stardist_segmentation/segmentation_results\Muc1_Ecad_SPC_x20_2_XY10_00010_CH3_segmentation.png
Saved segmentation for