## Cell segmentation

Please note that this notebook covers cell segmentation only. If you have already generated segmented images (in Fiji/ImageJ, ilastik, or with a custom Python/MATLAB program), skip this notebook and move directly to `makemontage.ipynb`.

You may need to adapt this code to segment your cell type correctly.

## Import everything and set up the environment

We use the TkAgg backend for matplotlib to allow interactive plotting, which the ROI selection step requires. It also means figures open in separate windows instead of being embedded in the notebook.

In [None]:
import os
from shutil import copyfile, which
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
import tifffile
from matplotlib.pyplot import pause
from matplotlib.path import Path


# Interactive Tk windows (needed for RoiPoly drawing) instead of inline figures.
matplotlib.use('TKAgg')
# No grid lines over images in any axes created from here on.
plt.rcParams['axes.grid'] = False
# Interactive mode: figures draw without blocking unless show(block=True) is used.
plt.ion()

## Set up the file paths and load the data

In [None]:

Current_folder = os.getcwd()
Parent_folder = os.path.dirname(Current_folder)

# Input location. Set the SEGMENTATION_DATA_DIR environment variable to point at a
# different data folder instead of editing this hard-coded default.
folder_name = os.environ.get("SEGMENTATION_DATA_DIR", "/media/tatsatb/Data/")
file_name = "Experiment.tif"
full_path_to_load = os.path.join(folder_name, file_name)

# Output folder for the extracted channels and masks (created if missing).
Path_to_Save = os.path.join(folder_name, 'MASKS', 'Experiment')
os.makedirs(Path_to_Save, exist_ok=True)
tiff_file = full_path_to_load

# Read only the page count and page geometry. The context manager releases the
# file handle; previously the TiffFile stayed open for the whole session.
with tifffile.TiffFile(tiff_file) as info:
    n_images = len(info.pages)
    rows, cols = info.pages[0].shape[:2]

# Pages are interleaved two channels per time point (split with 0::2 / 1::2 later),
# so each channel gets half of the pages.
if n_images % 2:
    print(f"Warning: odd page count ({n_images}); expected interleaved DIC/Red pairs.")
n_frames = n_images // 2



## Load the separated channels into numpy arrays

In [None]:


# Pre-allocate the full interleaved stack, indexed (row, col, page).
I_tmp = np.zeros((rows, cols, n_images), dtype=np.uint16)

full_path_red_channel_data = os.path.join(Path_to_Save, 'Red_channel.npy')
full_path_DIC_channel_data = os.path.join(Path_to_Save, 'DIC_channel.npy')

# Open the TIFF once and read it page by page. The previous version called
# tifffile.imread(tiff_file, key=i) in the loop, re-opening and re-parsing the
# file for every single page.
with tifffile.TiffFile(tiff_file) as tif:
    for i in range(n_images):
        I_tmp[:, :, i] = tif.pages[i].asarray()

# De-interleave: even pages -> DIC, odd pages -> Red.
DIC_channel = I_tmp[:, :, 0::2]
Red_channel = I_tmp[:, :, 1::2]

np.save(full_path_red_channel_data, Red_channel)
np.save(full_path_DIC_channel_data, DIC_channel)


def show_first_frame(stack, title):
    """Display frame 0 of a (rows, cols, frames) stack with a colorbar.

    Parameters
    ----------
    stack : numpy.ndarray
        3-D image stack; only ``stack[:, :, 0]`` is shown.
    title : str
        Axes title.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the frame was drawn on.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    im = ax.imshow(stack[:, :, 0], cmap='viridis')
    ax.grid(False)
    fig.colorbar(im, ax=ax)
    ax.set_title(title)
    fig.tight_layout()
    plt.show()
    return ax


ax1 = show_first_frame(Red_channel, 'Channel Red - First Frame')
ax3 = show_first_frame(DIC_channel, 'Channel DIC - First Frame')

## Use RoiPoly to select regions of interest (ROIs) in the images

This is important for segmenting the cells, especially cells that come close to one another. In each frame, draw a polygon around the cell you want to segment; the code then builds a mask from that polygon.

In [None]:
from roipoly import RoiPoly


def enhance_contrast(frame, low_in=2, high_in=98):
    """Percentile-based contrast stretch of an image to 8-bit, for display.

    Values at or below the ``low_in`` percentile map to 0, values at or above the
    ``high_in`` percentile map to 255, and values in between are scaled linearly
    (truncated to integers).

    Parameters
    ----------
    frame : array_like
        Image to stretch (any numeric dtype).
    low_in, high_in : float
        Lower / upper percentiles (0-100) used as the stretch limits.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array with the same shape as ``frame``. If the two percentiles
        coincide (for example a blank or constant frame), an all-zero array is
        returned instead of dividing by zero.
    """
    # Work in float so the subtraction/scaling below cannot wrap around.
    frame_float = np.asarray(frame, dtype=float)

    # Stretch limits taken from the image's own intensity distribution.
    low_val = np.percentile(frame_float, low_in)
    high_val = np.percentile(frame_float, high_in)

    if high_val <= low_val:
        # Flat image: there is no range to stretch, and the division below would be
        # 0/0 -> NaN, whose cast to uint8 is undefined.
        return np.zeros(frame_float.shape, dtype=np.uint8)

    # Clip and stretch contrast (same arithmetic order as before, so results are unchanged).
    adjusted = np.clip(frame_float, low_val, high_val)
    adjusted = ((adjusted - low_val) * 255 / (high_val - low_val))

    # Ensure uint8 output
    return np.clip(adjusted, 0, 255).astype(np.uint8)


def process_single_channel(frame, frame_num):
    """Let the user outline one region on ``frame`` and zero everything outside it.

    A contrast-stretched copy of the frame is shown in its own window and a RoiPoly
    polygon is drawn on it. The window is reopened, with an incremented attempt counter
    in the title, until a polygon with at least three vertices is obtained. Any other
    exception raised while drawing or building the mask is printed and also triggers a
    retry.

    Parameters
    ----------
    frame : numpy.ndarray
        2-D image to mask. It is not modified.
    frame_num : int
        Zero-based frame index; shown 1-based in the window title.

    Returns
    -------
    tuple
        ``(masked_copy, mask)``. ``mask`` is a boolean array of ``frame.shape`` that is
        True for pixel coordinates inside the polygon. ``masked_copy`` equals ``frame``
        with the outside pixels set to 0. Returns ``(None, None)`` if the user
        interrupts with Ctrl-C.
    """
    display_image = enhance_contrast(frame)

    def _mask_from_outline(xs, ys):
        # Test every integer (col, row) pixel coordinate against the drawn polygon.
        height, width = frame.shape
        vertices = [(xs[k], ys[k]) for k in range(len(xs))]
        col_grid, row_grid = np.meshgrid(np.arange(width), np.arange(height))
        pixel_coords = np.column_stack((col_grid.ravel(), row_grid.ravel()))
        outline = Path(vertices)
        return outline.contains_points(pixel_coords).reshape(frame.shape)

    attempt = 1
    while True:
        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(111)
        ax.grid(False)
        ax.imshow(display_image, cmap='gray')
        ax.set_title(f'Draw ROI - Frame {frame_num+1} - Attempt {attempt}')

        try:
            roi = RoiPoly(fig=fig, ax=ax, color='r')
            # Block here while the user draws on the window.
            plt.show(block=True)

            # A polygon needs at least three vertices to enclose any area.
            has_polygon = hasattr(roi, 'x') and len(roi.x) > 2
            if not has_polygon:
                print("Invalid ROI - please try again")
                plt.close(fig)
                attempt += 1
                continue

            mask = _mask_from_outline(roi.x, roi.y)
            masked = frame.copy()
            masked[~mask] = 0
            plt.close(fig)
            return masked, mask

        except KeyboardInterrupt:
            print("Process interrupted by user. Exiting...")
            plt.close(fig)
            return None, None

        except Exception as err:
            print(f"ROI Error: {str(err)}")
            plt.close(fig)
            attempt += 1

def process_all_frames(single_channel, masks=None):
    """Interactively mask every frame of a (rows, cols, frames) stack.

    For each frame, ``process_single_channel`` asks the user to draw an ROI. Pixels
    outside the ROI are zeroed.

    Parameters
    ----------
    single_channel : numpy.ndarray
        3-D stack; frames run along the last axis.
    masks : optional
        Currently unused; kept so existing calls keep working.

    Returns
    -------
    tuple
        ``(processed, all_masks)``. ``processed`` has the shape and dtype of
        ``single_channel`` with the outside-ROI pixels set to 0. ``all_masks`` is a
        boolean stack of the same shape.

    Raises
    ------
    RuntimeError
        If the user aborts ROI selection (``process_single_channel`` returned
        ``(None, None)``). Previously this surfaced as an obscure TypeError when
        ``None`` was assigned into the array.
    """
    processed = np.zeros_like(single_channel)
    all_masks = np.zeros_like(single_channel, dtype=bool)
    total_frames = single_channel.shape[2]

    # Iterate over the stack's own frame axis rather than the global ``n_frames``,
    # so the function is correct for any stack that is passed in.
    for i in range(total_frames):
        frame_out, mask = process_single_channel(single_channel[:, :, i].copy(), i)
        if mask is None:
            raise RuntimeError(
                f"ROI selection aborted at frame {i + 1} of {total_frames}; "
                "no masks were produced for the remaining frames."
            )
        processed[:, :, i] = frame_out
        all_masks[:, :, i] = mask
    return processed, all_masks


# Draw one ROI per red-channel frame; pixels outside each ROI are zeroed.
processed_red_channel, red_masks = process_all_frames(Red_channel)


## Display a sample frame for the processed images

In [None]:
# First ROI-masked red frame, as a quick sanity check of the ROI step.
fig, ax1 = plt.subplots(figsize=(4, 4))
im = ax1.imshow(processed_red_channel[:, :, 0], cmap='viridis')
ax1.grid(False)
fig.colorbar(im, ax=ax1)
# Distinct title so this window is not confused with the raw red-frame figure,
# which is also titled 'Channel Red - First Frame'.
ax1.set_title('Processed Red Channel - First Frame')
fig.tight_layout()
plt.show()

## Threshold the image to segment the cells

Change the threshold value to see how it affects the segmentation, and adjust it to suit your image data.

In [None]:
from scipy.signal import convolve2d
from matplotlib.widgets import Cursor
plt.rc("axes", grid=False)

# Intensity cut-off applied to the smoothed red channel; tune per dataset.
threshold = 1200

# 3x3 mean (box) filter used to suppress pixel noise before thresholding.
kernel = np.ones((3, 3)) / 9.0

# The smoothed stack keeps the input dtype, as before, and is used for display only.
smoothed_image_red = np.zeros_like(processed_red_channel)
above_threshold = np.zeros(processed_red_channel.shape, dtype=bool)

for frame_idx in range(smoothed_image_red.shape[2]):
    smoothed_frame = convolve2d(processed_red_channel[:, :, frame_idx], kernel,
                                mode='same', boundary='fill', fillvalue=0)
    # Compare the floating-point mean against the threshold *before* storing it.
    # Writing into the integer stack truncates (e.g. 1200.7 -> 1200), which would
    # wrongly exclude pixels whose local mean is just above the threshold.
    above_threshold[:, :, frame_idx] = smoothed_frame > threshold
    smoothed_image_red[:, :, frame_idx] = smoothed_frame

# 0/1 integer stack; same dtype as the previous np.where(...) result.
thresholded_combined = np.where(above_threshold, 1, 0)

# Visual check of the first frame at each stage, plus the DIC reference.
panels = (
    (processed_red_channel[:, :, 0], 'Processed Red Channel - First Frame', 'viridis'),
    (smoothed_image_red[:, :, 0], 'Smoothed Red Channel - First Frame', 'viridis'),
    (thresholded_combined[:, :, 0], 'Thresholded Red Channel - First Frame', 'viridis'),
    (enhance_contrast(DIC_channel[:, :, 0]), 'DIC Channel - First Frame', 'gray'),
)

fig, axes = plt.subplots(2, 2, figsize=(12, 4))
for ax, (image, title, cmap) in zip(axes.flat, panels):
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)

fig.tight_layout()
plt.show(block=True)
plt.close(fig)



## Refine the segmentation with morphological processing (fill holes, erode, despeckle, dilate)

In [None]:
from scipy import ndimage
from skimage.morphology import binary_erosion, disk, binary_dilation
from skimage.filters import median


# One output stack per stage, each with the shape/dtype of the 0/1 threshold stack.
# Every stage reads the previous stage's stored stack.
filled, eroded, despecked, dilated, mask_image_update = (
    np.zeros_like(thresholded_combined) for _ in range(5)
)

# Structuring elements: radius-2 disk for erosion/dilation, radius-1 for the median.
footprint_r2 = disk(2)
footprint_r1 = disk(1)

for frame in range(thresholded_combined.shape[2]):
    # Close interior holes in each thresholded object.
    filled[:, :, frame] = ndimage.binary_fill_holes(thresholded_combined[:, :, frame])
    # Shrink objects to cut thin bridges between neighbours.
    eroded[:, :, frame] = binary_erosion(filled[:, :, frame], footprint_r2)
    # Remove isolated specks with a small median filter.
    despecked[:, :, frame] = median(eroded[:, :, frame], footprint_r1)
    # Grow objects back by the same radius used for the erosion.
    dilated[:, :, frame] = ndimage.binary_dilation(despecked[:, :, frame], footprint_r2)

# The final mask is the dilated stack.
mask_image_update[...] = dilated

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(mask_image_update[:,:,0], cmap='viridis')

## Save the binarized, segmented images as TIFF files

In [None]:

# Save the cleaned 0/1 mask stack as .npy. The file name now matches the one used by
# the loading cell below ('Thresholded_mask.npy'). It was previously misspelled
# 'Thesholded_mask.npy', so the saving and loading cells pointed at different files.
full_path_mask = os.path.join(Path_to_Save, 'Thresholded_mask.npy')
np.save(full_path_mask, mask_image_update)

# Scale 0/1 -> 0/255 and write an 8-bit multi-page TIFF with one page per frame
# (the frame axis is moved to the front).
mask_image_final_8bit = (mask_image_update * 255).astype(np.uint8)

mask_save_path = os.path.join(Path_to_Save, 'mask_8bit.tif')
tifffile.imwrite(mask_save_path, mask_image_final_8bit.transpose(2, 0, 1))


## Load the binary image for testing

In [None]:
# Path of the .npy copy of the mask stack (not read here; kept for reference).
full_path_mask = os.path.join(Path_to_Save, 'Thresholded_mask.npy')
mask_save_path = os.path.join(Path_to_Save, 'mask_8bit.tif')

mask_image_load = tifffile.imread(mask_save_path)
# tifffile may return a one-frame stack as a 2-D (rows, cols) array. Restore the
# leading frame axis so the transpose below always sees (frames, rows, cols).
if mask_image_load.ndim == 2:
    mask_image_load = mask_image_load[np.newaxis, :, :]

# Back to (rows, cols, frames) and from 0/255 to 0.0/1.0.
mask_image_reshape = mask_image_load.transpose(1, 2, 0)
mask_image_final = mask_image_reshape.astype(np.int64) / 255

## At this point, please move to `makemontage.ipynb` notebook. 