In [None]:
# SPECIFY the path to the input FCMA TIFF stack; it is read further down
# with skimage.io.imread using the "tifffile" plugin.
file_path = 'ENTER PATH TO FCMA TIF STACK' 

# SPECIFY the base filename (without extension) for the output stack;
# ".tif" is appended when the stack is written. A default is given.
output_filename = 'fcma_analysis_result' 

In [None]:
# Import modules to load FCMA tiff files
import skimage
from skimage import color
import skimage.io as skio
import matplotlib.pyplot as plt
import numpy as np
import random
import colorsys
from skimage.measure import find_contours
from matplotlib.patches import Polygon
from PIL import Image
import imageio.v2 as imageio
from tqdm import tqdm

# Import modules for Mask R-CNN model
import os
import sys
import skimage.draw

# Define important directories
# ROOT_DIR is "../" resolved to an absolute path, i.e. the parent of the
# current working directory at the time this cell runs.
ROOT_DIR = os.path.abspath("../")
DEFAULT_LOGS_DIR = os.path.join(ROOT_DIR, "logs")

# Create the logs folder if it doesn't exist
os.makedirs(DEFAULT_LOGS_DIR, exist_ok=True)
# Weights file that is loaded into the model in the main section below.
DEFAULT_WEIGHTS_PATH = os.path.join(ROOT_DIR, r"Mask_RCNN/mask_rcnn_fcma.h5")
# Put the Mask_RCNN folder on the import path. This must run before the
# FCMAConfig / mrcnn imports that follow, which presumably live there.
sys.path.append(os.path.join(ROOT_DIR, "Mask_RCNN"))

import FCMAConfig
import mrcnn.model as modellib

############################################################
#  Functions
############################################################

def _normalise_to_uint8(frame):
    """Min-max scale *frame* to the 0-255 range and return it as ``uint8``.

    The arithmetic is done in float64 so that narrow integer inputs cannot
    overflow. A frame whose pixels are all equal has no dynamic range; it is
    mapped to all zeros instead of dividing by zero.
    """
    scaled = np.asarray(frame, dtype=np.float64)
    low = scaled.min()
    span = scaled.max() - low
    if span == 0:
        return np.zeros(scaled.shape, dtype=np.uint8)
    return ((scaled - low) / span * 255).astype(np.uint8)


def _mean_intensity(frame, mask):
    """Return the mean of *frame* inside *mask*, rounded to 3 decimals.

    An empty mask yields NaN directly, avoiding NumPy's "mean of empty slice"
    warning.
    """
    selection = np.asarray(mask, dtype=bool)
    if not selection.any():
        return float("nan")
    return round(np.mean(frame[selection]), 3)


def process_and_save_frames(stack, output_folder, model):
    """Run detection on every frame of *stack* and render annotated frames.

    For each frame the image is normalised to 8-bit RGB, passed to
    ``model.detect``, and the mean intensity of the ORIGINAL (un-normalised)
    frame is measured inside every returned instance mask. The annotated
    frame is written to ``<output_folder>/frame_<i>.png``.

    Args:
        stack: Sized iterable of 2-D grayscale frames (it is passed to
            ``len`` and iterated).
        output_folder: Directory for the per-frame PNG files; created if it
            does not exist.
        model: Object whose ``detect(images, verbose=0)`` returns a list of
            dicts with ``'rois'``, ``'masks'`` and ``'class_ids'`` entries.

    Returns:
        A list with one annotated image array per frame, in stack order.
    """
    # Create the output folder if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)

    output_images = []
    for i, frame in tqdm(enumerate(stack), total=len(stack), desc="Processing frames"):
        # Normalise to 8-bit first, then expand grayscale to RGB
        png_frame = color.gray2rgb(_normalise_to_uint8(frame))

        # Detect RBCs using the Mask R-CNN model
        r = model.detect([png_frame], verbose=0)[0]

        # Mean fluorescence intensity of each instance, measured on the
        # original TIF frame rather than on the normalised copy
        detected_cells = [
            _mean_intensity(frame, r['masks'][:, :, j])
            for j in range(r['rois'].shape[0])
        ]

        img_result = display_instances_with_metrics(png_frame, r['rois'], r['masks'], r['class_ids'],
                                                    measurements=detected_cells)

        # Save the frame as an image in the output folder
        output_file = os.path.join(output_folder, f"frame_{i}.png")
        Image.fromarray(img_result).save(output_file)

        # PNG is lossless, so keep the in-memory pixels instead of reading
        # the file back; the compact copy drops any larger parent buffer.
        output_images.append(np.ascontiguousarray(img_result))

    return output_images

def random_colors(N, bright=True):
    """
    Return N RGB tuples in a random order.

    Hues are spaced evenly around the HSV wheel at full saturation so the
    colors are visually distinct, converted to RGB, then shuffled.
    `bright` selects a value (brightness) of 1.0; otherwise 0.7 is used.

    Function from Matterport Mask R-CNN visualize module
    """
    value = 0.7
    if bright:
        value = 1.0
    palette = [colorsys.hsv_to_rgb(step / N, 1, value) for step in range(N)]
    random.shuffle(palette)
    return palette


def apply_mask(image, mask, color, alpha=0.2):
    """
    Blend `color` into `image` wherever `mask` equals 1.

    The first three channels of `image` are modified in place: masked pixels
    become `pixel * (1 - alpha) + alpha * color[c] * 255`, all other pixels
    are left as they are. The same `image` object is returned.

    Function from Matterport Mask R-CNN visualize module
    """
    selected = mask == 1
    keep = 1 - alpha
    for channel in range(3):
        plane = image[:, :, channel]
        tint = alpha * color[channel] * 255
        image[:, :, channel] = np.where(selected, plane * keep + tint, plane)
    return image

def display_instances_with_metrics(image, boxes, masks, class_ids, show_mask = True, measurements=None,
                                   crop_box=(371, 1241, 226, 1412)):
    """
    Overlay instance masks and fluorescence intensity measurements to corresponding frame.
    Return masked image.

    Function modified from Matterport Mask R-CNN visualize module to display the calculated measurements

    image: RGB image array; it is not modified (a copy is tinted).
    boxes: [N, 4] array of (y1, x1, y2, x2) boxes; all-zero rows are skipped.
    masks: [height, width, N] array with one mask per instance.
    class_ids: [N] array; only its length is checked here.
    show_mask: when True, tint each instance with a semi-transparent color.
    measurements: optional sequence with one value per instance, drawn as
        an "Intensity: ..." caption near the box. When None, no captions
        are drawn.
    crop_box: (top, bottom, left, right) pixel window cut out of the
        rendered canvas, or None to return the whole canvas.
        NOTE(review): the default window looks tuned to one particular
        input size / figure DPI - confirm before using other image sizes.

    Returns the rendered canvas as an array (RGBA as produced by the Agg
    canvas), cropped to `crop_box`.
    """

    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")
    else:
        assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]

    height, width = image.shape[:2]

    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(16, 16))
    try:
        # Generate random colors
        colors = random_colors(N)

        # Show area outside image boundaries.
        ax.set_ylim(height + 10, -10)
        ax.set_xlim(-10, width + 10)
        ax.axis('off')

        # Wider dtype so the blending arithmetic in apply_mask cannot wrap
        masked_image = image.astype(np.uint32).copy()

        for i in range(N):
            instance_color = colors[i]

            # Bounding box
            if not np.any(boxes[i]):
                # Skip this instance. Has no bbox. Likely lost in image cropping.
                continue

            y1, x1, y2, x2 = boxes[i]

            # Display fluorescence intensity measurement of each instance
            if measurements is not None:
                caption = "Intensity: {}".format(measurements[i])

                # Place the caption slightly above the bounding box, pulled
                # back inside the image near the right and top edges
                text_x = x1
                text_y = y1 - 8

                if text_x >= width - 50:
                    text_x = width - 65

                if text_y <= 0:
                    text_y = 8

                ax.text(text_x, text_y, caption,
                        color='w', size=13, backgroundcolor="none")

            # Mask
            mask = masks[:, :, i]
            if show_mask:
                masked_image = apply_mask(masked_image, mask, instance_color)

            # Mask Polygon
            # Pad to ensure proper polygons for masks that touch image edges.
            padded_mask = np.zeros(
                (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
            padded_mask[1:-1, 1:-1] = mask
            contours = find_contours(padded_mask, 0.5)

            for verts in contours:
                # Subtract the padding and flip (y, x) to (x, y)
                verts = np.fliplr(verts) - 1
                ax.add_patch(Polygon(verts, facecolor="none", edgecolor=instance_color))

        ax.imshow(masked_image.astype(np.uint8))

        # Convert the figure to an image through the public canvas API;
        # np.array copies the pixels so they outlive the closed figure
        fig.canvas.draw()
        img = np.array(fig.canvas.buffer_rgba())
    finally:
        # Close the figure to release resources, even if drawing failed
        plt.close(fig)

    # Crop figure
    if crop_box is not None:
        top, bottom, left, right = crop_box
        img = img[top:bottom, left:right]

    return img


############################################################
#  Main
############################################################

# Load Model Configuration
config = FCMAConfig.InferenceConfig()
config.display()


# Create model
model = modellib.MaskRCNN(mode="inference",
                        config=config,
                        model_dir=DEFAULT_LOGS_DIR)

# Select weights file to load
weights_path = DEFAULT_WEIGHTS_PATH

# Load weights
print("Logs: ", DEFAULT_LOGS_DIR)
print("Loading weights ", weights_path)
model.load_weights(weights_path, by_name=True)
model.keras_model.compile(run_eagerly=config.RUN_EAGERLY)

# Open TIFF stack
stack = skio.imread(file_path, plugin="tifffile")

# Define the output folder
output_folder = os.path.join(ROOT_DIR, 'Results')

# Create the output folder if it doesn't exist
os.makedirs(output_folder, exist_ok=True)

# Process input stack; detect and calculate intensity
output_images = process_and_save_frames(stack, output_folder, model)

# Create an image stack from the saved frames
output_stack_file = os.path.join(output_folder, "{}.tif".format(output_filename))
imageio.mimwrite(output_stack_file, output_images, format="tif")

print("Analysis complete. Image stack creation is complete.")

# Iterate through the files in the directory and delete frame_*.png files.
# The loop uses its own variable name so that the user-specified `file_path`
# (the input stack) is not overwritten; re-running this cell would otherwise
# try to read a PNG that has just been deleted.
for filename in os.listdir(output_folder):
    if filename.startswith("frame_") and filename.endswith(".png"):
        frame_png_path = os.path.join(output_folder, filename)
        os.remove(frame_png_path)

print("Individual frame images have been deleted.")