In [1]:
# =========================
# User-defined parameters
# =========================

# Sample identifier shared by every input/output name below; change it here
# (once) to process a different acquisition.
SAMPLE_ID = "12"

SPIRAL_STACK_PATH = f"{SAMPLE_ID}_no_cells.tif"   # spiral-only stack
CELLS_STACK_PATH  = f"{SAMPLE_ID}_cells.tif"      # spiral + cells stack

OUTPUT_DIFF_PATH  = f"{SAMPLE_ID}_cells_minus_spiral.tif"  # clipped difference stack
OUTPUT_FRAMES_DIR = f"{SAMPLE_ID}_cells_only_frames"       # one TIFF per frame

# NLMeans parameters (for spiral denoising), forwarded to
# skimage.restoration.denoise_nl_means.
NL_PATCH_SIZE     = 5    # patch size compared by NL-means
NL_PATCH_DISTANCE = 6    # max search distance (pixels) for similar patches
NL_H_FACTOR       = 0.8  # filter strength: h = NL_H_FACTOR * estimated noise sigma


In [2]:
# =========================
# Imports
# =========================

import numpy as np
import tifffile as tiff
import matplotlib.pyplot as plt

import os
from skimage.restoration import denoise_nl_means, estimate_sigma
from ipywidgets import interact, IntSlider

In [3]:
# =========================
# Utility Functions
# =========================

def minmax_normalize(img):
    """
    Linearly rescale an image so its minimum maps to 0 and its maximum to 1.

    The input is cast to float32 before rescaling. An image with zero
    intensity range (all pixels equal) yields an all-zero float32 array of
    the same shape.
    """
    arr = img.astype(np.float32)
    lo = arr.min()
    span = arr.max() - lo
    if not span:
        # constant image: avoid dividing by a zero range
        return np.zeros_like(arr, dtype=np.float32)
    return (arr - lo) / span


def nlmeans_denoise(img, patch_size=5, patch_distance=6, h_factor=0.8):
    """
    NL-means denoising for fluorescence images.

    The image is min-max normalized to [0, 1], denoised with
    ``denoise_nl_means`` (fast mode), then rescaled back to its original
    intensity range.

    The filter strength is ``h = h_factor * sigma``, where ``sigma`` is
    estimated on the *normalized* image, i.e. in the same units as the data
    actually passed to ``denoise_nl_means``. (Estimating sigma on the raw
    intensities would give an ``h`` scaled by the raw intensity range and
    heavily over-smooth the [0, 1] image.)

    Parameters
    ----------
    img : np.ndarray
        Single-channel image (processed with ``channel_axis=None``).
    patch_size : int
        Patch size forwarded to ``denoise_nl_means``.
    patch_distance : int
        Patch search distance forwarded to ``denoise_nl_means``.
    h_factor : float
        Multiplier applied to the estimated noise sigma to obtain ``h``.

    Returns
    -------
    np.ndarray
        Denoised image in the original intensity range. A constant image
        (zero intensity range) is returned unchanged, cast to float32.
    """
    img_f = img.astype(np.float32)
    lo = img_f.min()
    span = img_f.max() - lo
    if span == 0:
        # Nothing to denoise; skipping also avoids calling denoise_nl_means
        # with an estimated sigma (and therefore h) of ~0.
        return img_f

    img_norm = minmax_normalize(img_f)

    # Estimate noise on the normalized image so h matches the data's scale.
    sigma_est = np.mean(estimate_sigma(img_norm, channel_axis=None))
    h = h_factor * sigma_est

    denoised = denoise_nl_means(
        img_norm,
        h=h,
        patch_size=patch_size,
        patch_distance=patch_distance,
        fast_mode=True,
        channel_axis=None
    )

    # rescale back to original intensity range
    return denoised * span + lo


In [4]:
def subtract_spiral(
    spiral_stack,
    cells_stack,
    denoise_spiral=True,
    patch_size=None,
    patch_distance=None,
    h_factor=None,
):
    """
    Subtract spiral-only signal from spiral+cells stack, frame by frame.

    For every frame (axis 0), the spiral frame is optionally NL-means
    denoised. Both frames are then independently min-max normalized to
    [0, 1], the normalized spiral is subtracted from the normalized cells
    frame, and negative values are clipped to 0.

    Parameters
    ----------
    spiral_stack, cells_stack : np.ndarray
        Stacks of identical shape; axis 0 is iterated as the frame axis
        (presumably (frames, rows, cols) — TODO confirm for other layouts).
    denoise_spiral : bool
        If True, denoise each spiral frame with ``nlmeans_denoise`` first.
    patch_size, patch_distance, h_factor : optional
        NL-means parameters. ``None`` (default) falls back to the notebook
        constants NL_PATCH_SIZE / NL_PATCH_DISTANCE / NL_H_FACTOR, read at
        call time so edits to the config cell take effect on re-run.

    Returns:
        spiral_processed : float32 array of normalized (optionally denoised)
            spiral frames
        cells_normalized : float32 array of normalized cells frames
        diff_clipped     : float32 array, max(cells_norm - spiral_norm, 0)
    """
    assert spiral_stack.shape == cells_stack.shape, \
        "Spiral and cells stacks must have the same shape."
    # A single-page TIFF loads as 2D; iterating it would treat rows as frames.
    assert spiral_stack.ndim >= 3, \
        f"Expected stacks with a leading frame axis, got shape {spiral_stack.shape}."

    if denoise_spiral:
        nl_kwargs = dict(
            patch_size=NL_PATCH_SIZE if patch_size is None else patch_size,
            patch_distance=NL_PATCH_DISTANCE if patch_distance is None else patch_distance,
            h_factor=NL_H_FACTOR if h_factor is None else h_factor,
        )

    spiral_proc = np.zeros_like(spiral_stack, dtype=np.float32)
    cells_norm  = np.zeros_like(cells_stack, dtype=np.float32)
    diff_clip   = np.zeros_like(cells_stack, dtype=np.float32)

    for i, (spiral, cells) in enumerate(zip(spiral_stack, cells_stack)):
        if denoise_spiral:
            spiral = nlmeans_denoise(spiral, **nl_kwargs)

        spiral_proc[i] = minmax_normalize(spiral)
        cells_norm[i]  = minmax_normalize(cells)

        # Negative residue (spiral brighter than cells image) is discarded.
        diff_clip[i] = np.clip(cells_norm[i] - spiral_proc[i], 0, None)

    return spiral_proc, cells_norm, diff_clip


In [6]:
# =========================
# Load image stacks
# =========================

# Read the spiral-only stack first, then the spiral+cells stack.
spiral_stack, cells_stack = (
    tiff.imread(path) for path in (SPIRAL_STACK_PATH, CELLS_STACK_PATH)
)

# =========================
# Run processing
# =========================

spiral_processed, cells_normalized, diff_clipped = subtract_spiral(
    spiral_stack, cells_stack, denoise_spiral=True
)


In [7]:
# =========================
# Check the Results for all frames
# =========================

def show_frame(frame_idx=0):
    """
    Plot one frame of the processing results as four side-by-side panels.

    Panels: normalized cells+spiral, processed spiral, clipped difference
    (grayscale), and the signed (unclipped) difference on the diverging
    'bwr' colormap centred at zero. The last panel makes visible the
    negative residue that clipping discards (spiral brighter than cells).

    Reads the notebook globals ``cells_normalized``, ``spiral_processed``
    and ``diff_clipped`` produced by ``subtract_spiral``.
    """
    cells = cells_normalized[frame_idx]
    spiral = spiral_processed[frame_idx]
    signed_diff = cells - spiral

    # Symmetric limits put 0 at the white midpoint of 'bwr'; fall back to 1
    # for an all-zero (or non-finite) difference so the range is valid.
    lim = float(np.abs(signed_diff).max())
    if not np.isfinite(lim) or lim == 0:
        lim = 1.0

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    axes[0].imshow(cells, cmap='gray')
    axes[0].set_title("Cells + spiral (norm)")

    axes[1].imshow(spiral, cmap='gray')
    axes[1].set_title("Spiral only (processed)")

    axes[2].imshow(diff_clipped[frame_idx], cmap='gray')
    axes[2].set_title("Cells only (diff clipped)")

    axes[3].imshow(signed_diff, cmap='bwr', vmin=-lim, vmax=lim)
    axes[3].set_title("Cells - spiral (signed, unclipped)")

    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


# Slider spans every frame index of the processed stack.
interact(
    show_frame,
    frame_idx=IntSlider(value=0, min=0, max=len(diff_clipped) - 1, step=1),
);


interactive(children=(IntSlider(value=0, description='frame_idx', max=85), Output()), _dom_classes=('widget-in…

In [8]:
# =========================
# Save Results
# =========================

# Whole clipped-difference stack in a single TIFF file.
tiff.imwrite(OUTPUT_DIFF_PATH, diff_clipped.astype(np.float32))

# The same data again, one float32 TIFF per frame, zero-padded index from 0.
os.makedirs(OUTPUT_FRAMES_DIR, exist_ok=True)

for i in range(diff_clipped.shape[0]):
    output_path = os.path.join(OUTPUT_FRAMES_DIR, f"cells_only_{i:04d}.tif")
    tiff.imwrite(output_path, diff_clipped[i].astype(np.float32))