In [30]:
import numpy as np
import matplotlib.pyplot as plt
import os
from towbintools.foundation.image_handling import read_tiff_file
from tifffile import imwrite
from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from csbdeep.utils import Path, normalize
import matplotlib
from skimage.measure import label, regionprops
from shutil import copy
import pandas as pd


In [31]:
from stardist.utils import mask_to_categorical
from stardist.plot import render_label

# Qualitative palette used to colour class ids (1, 2, 3, ...) in the class panel.
lbl_cmap_classes = matplotlib.colormaps['Set1']

# Fix the global NumPy seed *before* building the random instance colormap;
# random_label_cmap presumably draws from the global RNG, so this keeps the
# per-instance colours reproducible across runs (NOTE(review): confirm).
np.random.seed(0)
lbl_cmap = random_label_cmap()

def plot_img_label(img, lbl, cls_dict, n_classes=2, img_title="image", lbl_title="label", cls_title="classes", **kwargs):
    """Show an image, its instance labels and its per-instance class map side by side.

    Parameters:
        img (np.ndarray): Intensity image, shown in grayscale and used as the
            backdrop of both overlays (presumably the same shape as ``lbl``).
        lbl (np.ndarray): Instance label image (0 = background).
        cls_dict (dict): Mapping ``{instance_label: class_id}`` forwarded to
            ``mask_to_categorical`` as ``classes``.
        n_classes (int): Number of object classes forwarded to ``mask_to_categorical``.
        img_title, lbl_title, cls_title (str): Titles of the three panels.
        **kwargs: Accepted for call compatibility; currently unused.

    Returns:
        tuple: The three matplotlib Axes ``(image, labels, classes)``.
    """
    c = mask_to_categorical(lbl, n_classes=n_classes, classes=cls_dict)

    # Collapse the one-hot class channels into a single class-id image.
    # Channel 0 is skipped, so pixels without a class stay 0.
    res = np.zeros(lbl.shape, np.uint16)
    for i in range(1, c.shape[-1]):
        res[c[..., i] > 0] = i

    class_img = lbl_cmap_classes(res)
    background = res == 0
    class_img[..., :3][background] = 0  # black background ...
    class_img[..., -1][background] = 1  # ... fully opaque

    # Normalize once and reuse the dimmed backdrop for both overlays.
    img_dimmed = .8 * normalize(img, clip=True)

    fig, (ai, al, ac) = plt.subplots(1, 3, figsize=(17, 7), gridspec_kw=dict(width_ratios=(1., 1, 1)))
    ai.imshow(img, cmap='gray')
    ai.set_title(img_title)
    al.imshow(render_label(lbl, img_dimmed, normalize_img=False, alpha_boundary=.8, cmap=lbl_cmap))
    al.set_title(lbl_title)
    ac.imshow(class_img)
    ac.imshow(render_label(res, img_dimmed, normalize_img=False, alpha_boundary=.8, cmap=lbl_cmap_classes))
    ac.set_title(cls_title)
    for a in (ai, al, ac):
        a.axis("off")
    # Lay out only after the axes are hidden so the layout sees the final decorations.
    fig.tight_layout()
    return ai, al, ac

In [32]:
# Input dataset: multi-plane raw stacks and their instance masks.
dataset_dir = "/mnt/towbin.data/shared/spsalmon/443 test/dataset/"
raw_dir, mask_dir = (os.path.join(dataset_dir, sub) for sub in ("raw", "masks"))

# Output dataset: one file per (stack, plane) for images, masks and label classes.
split_dataset_dir = "/mnt/towbin.data/shared/spsalmon/443 test/split_dataset/"
split_images_dir = os.path.join(split_dataset_dir, "images")
split_masks_dir = os.path.join(split_dataset_dir, "masks")
split_classes_dir = os.path.join(split_dataset_dir, "class")

# Create the output tree; exist_ok makes re-running this cell harmless.
for _out_dir in (split_dataset_dir, split_images_dir, split_masks_dir, split_classes_dir):
    os.makedirs(_out_dir, exist_ok=True)

In [33]:
# Raw stacks and masks are paired purely by sorted filename order, so both
# folders must contain exactly one file per stack, with matching sort order.
raw_images = sorted(os.path.join(raw_dir, f) for f in os.listdir(raw_dir))
masks = sorted(os.path.join(mask_dir, f) for f in os.listdir(mask_dir))
# zip(raw_images, masks) would silently drop or misalign unmatched files.
if len(raw_images) != len(masks):
    raise ValueError(
        f"{len(raw_images)} raw images but {len(masks)} masks; "
        f"{raw_dir!r} and {mask_dir!r} must contain one file per stack."
    )

# # rename into Time0000_Point000X
# renamed_raw_dir = os.path.join(dataset_dir, "raw")
# renamed_mask_dir = os.path.join(dataset_dir, "corrected_mask")

# os.makedirs(renamed_raw_dir, exist_ok=True)
# os.makedirs(renamed_mask_dir, exist_ok=True)

# for i, (raw_image, mask) in enumerate(zip(raw_images, masks)):
#     target_raw_path = os.path.join(renamed_raw_dir, f"Time0000_Point{str(i).zfill(4)}.tif")
#     target_mask_path = os.path.join(renamed_mask_dir, f"Time0000_Point{str(i).zfill(4)}.tif")
#     copy(raw_image, target_raw_path)
#     copy(mask, target_mask_path)
    

In [34]:
def get_patch_of_label(mask_of_label, image, patch_size=64):
    """
    Compute a patch around the centroid of a region, padding with zeros if necessary.

    The patch is cut from the last two (spatial) axes of ``image``; any leading
    axes (e.g. channels) are kept whole. Parts of the patch that fall outside
    the image are filled with zeros of ``image``'s dtype.

    Parameters:
        mask_of_label (np.ndarray): 2D mask of the region; non-zero pixels belong to it.
        image (np.ndarray): Image whose last two axes are indexed with the
            mask's row/column coordinates.
        patch_size (int): The size of the patch (must be even).

    Returns:
        np.ndarray: Array of shape ``image.shape[:-2] + (patch_size, patch_size)``,
        positioned so the centroid falls at index ``patch_size // 2``
        (rounded down) on both spatial axes.

    Raises:
        ValueError: If ``patch_size`` is odd or ``mask_of_label`` is empty.
    """
    if patch_size % 2 != 0:
        raise ValueError("patch_size must be even")

    # Centroid = mean coordinate of the region's pixels (the same value
    # skimage's regionprops reports as `centroid`).
    coords = np.argwhere(mask_of_label)
    if coords.size == 0:
        raise ValueError("mask_of_label is empty; cannot compute a centroid")
    centroid = coords.mean(axis=0)

    # Patch boundaries. Use floor rather than int(): int() truncates toward
    # zero, which would shift patches of objects near the top/left border by one
    # pixel relative to interior objects.
    half_size = patch_size // 2
    minr = int(np.floor(centroid[0] - half_size))
    minc = int(np.floor(centroid[1] - half_size))
    maxr, maxc = minr + patch_size, minc + patch_size

    # Zero-filled output; anything not covered by the image stays 0.
    padded_shape = list(image.shape[:-2]) + [patch_size, patch_size]
    patch = np.zeros(padded_shape, dtype=image.dtype)

    # Overlap between the requested window and the image, in image coordinates
    # (src_*) and in patch coordinates (dst_*).
    src_minr, src_minc = max(0, minr), max(0, minc)
    src_maxr, src_maxc = min(image.shape[-2], maxr), min(image.shape[-1], maxc)
    dst_minr, dst_minc = src_minr - minr, src_minc - minc
    dst_maxr, dst_maxc = dst_minr + (src_maxr - src_minr), dst_minc + (src_maxc - src_minc)

    # Copy the overlapping region from the image to the patch
    patch[..., dst_minr:dst_maxr, dst_minc:dst_maxc] = image[..., src_minr:src_maxr, src_minc:src_maxc]

    return patch

In [35]:
def build_class_dict(raw, mask, low_threshold = 15, high_threshold = 25, patch_size = 64, gfp_channel = 1):
    """
    Assign a class to every labelled object in ``mask`` from its GFP signal.

    For each non-zero label, a ``patch_size`` x ``patch_size`` patch centred on
    the object's centroid is cut from ``raw`` and ``mask``. The object's GFP
    signal is the median GFP intensity of its own pixels minus the median GFP
    intensity of the local background. The local background is the patch pixels
    that belong to no object and lie inside the image, not in the zero padding.

    Parameters:
        raw (np.ndarray): Channel-first image (C, H, W); channel ``gfp_channel``
            is read as the GFP signal.
        mask (np.ndarray): Instance label image (H, W); 0 is background.
        low_threshold (float): Background-corrected GFP below which an object is class 2.
        high_threshold (float): Background-corrected GFP above which an object is class 1.
        patch_size (int): Side of the square patch (must be even).
        gfp_channel (int): Index of the GFP channel in ``raw`` (default 1).

    Returns:
        dict: ``{label: class}``, where class 1 = GFP-high (epidermal),
        2 = GFP-low (other nuclei type) and 3 = in between / undeterminable.
    """
    class_dict = {}
    # Marks which pixels of a patch come from the image rather than from the
    # zero padding added beyond the border.
    in_image = np.ones(mask.shape, dtype=bool)

    for lbl in np.unique(mask):
        if lbl == 0:
            continue
        mask_of_label = mask == lbl
        patch = get_patch_of_label(mask_of_label, raw, patch_size=patch_size)
        mask_patch = get_patch_of_label(mask_of_label, mask, patch_size=patch_size)
        valid_patch = get_patch_of_label(mask_of_label, in_image, patch_size=patch_size)

        gfp_patch = patch[gfp_channel]
        object_pixels = gfp_patch[mask_patch == lbl]
        # Background excludes other objects AND the padding; padding zeros would
        # otherwise drag the background median toward 0 for border objects.
        background_pixels = gfp_patch[(mask_patch == 0) & valid_patch]

        if object_pixels.size == 0 or background_pixels.size == 0:
            # No usable reference -> uncertain (the NaN comparisons previously
            # fell through to this class as well, but with a RuntimeWarning).
            class_dict[lbl] = 3
            continue

        # Cast to float so unsigned intensities cannot wrap on subtraction.
        gfp_intensity_label = float(np.median(object_pixels)) - float(np.median(background_pixels))

        if gfp_intensity_label > high_threshold:
            c = 1  # if high enough, epidermal
        elif gfp_intensity_label < low_threshold:
            c = 2  # if really low, other nuclei type
        else:
            c = 3  # if in between, uncertain

        class_dict[lbl] = c

    return class_dict

# Split every (raw stack, mask stack) pair into per-plane files:
#   images/{i}_{j}.tiff  - channel 0 of plane j of stack i
#   masks/{i}_{j}.tiff   - instance mask of that plane
#   class/{i}_{j}.csv    - label -> class table from build_class_dict
# Planes without any labelled object are skipped.
for i, (raw_img, mask_path) in enumerate(zip(raw_images, masks)):
    raw = read_tiff_file(raw_img)
    mask = read_tiff_file(mask_path)

    # Planes are paired by index and labels are looked up by pixel position, so
    # the stacks must agree in plane count and spatial size. Fail before writing
    # anything instead of raising mid-stack or silently ignoring mask planes.
    if raw.shape[0] != mask.shape[0] or raw.shape[-2:] != mask.shape[-2:]:
        raise ValueError(
            f"Shape mismatch between {raw_img} {raw.shape} and {mask_path} {mask.shape}"
        )

    for j, plane in enumerate(raw):
        mask_plane = mask[j]

        raw_save_path = os.path.join(split_images_dir, f"{i}_{j}.tiff")
        mask_save_path = os.path.join(split_masks_dir, f"{i}_{j}.tiff")
        class_save_path = os.path.join(split_classes_dir, f"{i}_{j}.csv")

        class_dict = build_class_dict(plane, mask_plane)

        # Save only planes that contain at least one labelled object.
        if len(class_dict) == 0:
            continue

        imwrite(raw_save_path, plane[0], compression="zlib")
        imwrite(mask_save_path, mask_plane, compression="zlib")

        # Store the label -> class mapping as a two-column CSV.
        class_df = pd.DataFrame(list(class_dict.items()), columns=['label', 'class'])
        class_df.to_csv(class_save_path, index=False)