# Processing Pipeline for Consistent CTCF in Fluorescence Microscopy

Here's a comprehensive pipeline for analyzing fluorescence microscopy images that measures corrected total cell fluorescence (CTCF = integrated density − cell area × mean background intensity) consistently across different conditions and microscopes:



In [None]:
## Import necessary libraries
import time, os, sys
import traceback
from datetime import datetime

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
import tifffile as tiff
from PIL import Image


import numpy as np
import pandas as pd
import scipy.ndimage as ndi
from scipy import sparse
from skimage import measure, draw, exposure
from skimage.transform import resize
from skimage.segmentation import find_boundaries, watershed
from skimage.filters import threshold_otsu, gaussian
from skimage.feature import peak_local_max
from skimage.morphology import remove_small_objects, binary_closing, disk

from sklearn.mixture import GaussianMixture
from sklearn.cluster import KMeans
import scipy.stats

from tqdm import tqdm
import traceback
import concurrent.futures
from functools import partial
from filelock import FileLock
import torch
import glob
import gc

mpl.rcParams['figure.dpi'] = 200
mpl.use('Agg')  # Set non-interactive backend globally

from AutoImgUtils import * 

In [None]:
!nvcc --version
!nvidia-smi

import os, shutil
import numpy as np
import matplotlib.pyplot as plt
from cellpose import core, utils, io, models, metrics, denoise
from glob import glob

# Ask cellpose whether a usable GPU is available. `use_GPU` is presumably a
# bool (it is used below as a 0/1 index into `yn`) — confirm against the
# cellpose `core.use_gpu` documentation.
use_GPU = core.use_gpu()
# Lookup table turning the flag into text: index 0 -> 'NO', index 1 -> 'YES'.
yn = ['NO', 'YES']
print(f'>>> GPU activated? {yn[use_GPU]}')


## Function Definitions

### Visualization functions
This section defines the visualization functions used to save quality-control plots of the background subtraction and of the segmentation results.

In [None]:
def _bgvis_unit_normalize(img):
    """Return *img* as float64 scaled linearly to [0, 1].

    A constant image (max == min) maps to all zeros instead of producing
    NaNs from a 0/0 division.
    """
    img = np.asarray(img, dtype=np.float64)
    lo, hi = np.min(img), np.max(img)
    if hi > lo:
        return (img - lo) / (hi - lo)
    return np.zeros_like(img)


def _bgvis_red_overlay(gray01, mask):
    """Return an RGB copy of grayscale *gray01* with pixels where boolean *mask* is True painted pure red."""
    rgb = np.stack([gray01, gray01, gray01], axis=-1)
    rgb[mask] = (1.0, 0.0, 0.0)
    return rgb


def _bgvis_percentile_stretch(img, p_low=2, p_high=98):
    """Contrast-stretch *img* so its [p_low, p_high] percentile range fills the dtype range."""
    return exposure.rescale_intensity(
        img,
        in_range=tuple(np.percentile(img, (p_low, p_high))),
        out_range='dtype',
    )


def _bgvis_draw_mip(ax, mip, mip_mask, title, enhance_contrast):
    """Draw one maximum-intensity projection with background pixels (mip_mask True) in red."""
    if enhance_contrast:
        mip = exposure.equalize_adapthist(mip)
    ax.imshow(_bgvis_red_overlay(_bgvis_unit_normalize(mip), mip_mask))
    ax.set_title(title)
    ax.axis('off')


def _bgvis_draw_3d(axes, channel_image, bg_mask, enhance_contrast):
    """Fill rows 0-1 of the 3x3 grid: three Z-slices and the XY/YZ/XZ projections."""
    z_depth = channel_image.shape[0]
    mask_is_3d = bg_mask.ndim == 3

    # Row 1: representative Z-slices, background painted red when the mask is 3D.
    z_positions = [z_depth // 4, z_depth // 2, 3 * z_depth // 4]
    slice_titles = ["25% Z-Depth", "50% Z-Depth", "75% Z-Depth"]
    for col, (z_pos, title) in enumerate(zip(z_positions, slice_titles)):
        slice_img = channel_image[z_pos]
        display_img = _bgvis_percentile_stretch(slice_img) if enhance_contrast else slice_img
        ax = axes[0, col]
        ax.imshow(display_img, cmap='gray')
        ax.set_title(f'Original - {title}')
        if mask_is_3d:
            ax.imshow(_bgvis_red_overlay(_bgvis_unit_normalize(display_img), bg_mask[z_pos]))
        ax.axis('off')

    # Row 2: maximum-intensity projections. A 2D mask can only be shown on the
    # XY projection; the side views then get no background overlay.
    if mask_is_3d:
        mip_masks = (bg_mask.max(axis=0), bg_mask.max(axis=2), bg_mask.max(axis=1))
    else:
        mip_masks = (bg_mask, None, None)
    projections = (
        ('XY Maximum Projection', 0),
        ('YZ Maximum Projection', 2),
        ('XZ Maximum Projection', 1),
    )
    for col, ((title, axis), mip_mask) in enumerate(zip(projections, mip_masks)):
        mip = np.max(channel_image, axis=axis)
        if mip_mask is None:
            mip_mask = np.zeros(mip.shape, dtype=bool)
        _bgvis_draw_mip(axes[1, col], mip, mip_mask, title, enhance_contrast)

    # The bottom row only hosts the histogram (middle panel); hide the spare axes.
    axes[2, 0].axis('off')
    axes[2, 2].axis('off')


def _bgvis_draw_2d(axes, channel_image, bg_model, bg_mask, enhance_contrast):
    """Fill the image panels of the 2x2 grid: image, raw mask, red background overlay."""
    display_img = _bgvis_percentile_stretch(channel_image) if enhance_contrast else channel_image

    axes[0, 0].imshow(display_img, cmap='gray')
    axes[0, 0].set_title('Original Channel' + (' (Contrast Enhanced)' if enhance_contrast else ''))

    axes[0, 1].imshow(bg_model['mask'], cmap='hot')
    axes[0, 1].set_title(f'Background Mask\nMean: {bg_model["mean"]:.2f}, Std: {bg_model["std"]:.2f}')

    axes[1, 0].imshow(_bgvis_red_overlay(_bgvis_unit_normalize(channel_image), bg_mask))
    axes[1, 0].set_title('Background Regions (Red)')


def _bgvis_draw_histogram(hist_ax, channel_image, bg_model):
    """Plot the density histogram of all pixels with every weighted GMM component overlaid."""
    flat_img = np.ravel(channel_image)
    hist_range = (np.min(flat_img), np.max(flat_img))
    hist_ax.hist(flat_img, bins=100, range=hist_range, density=True,
                 alpha=0.6, color='gray', label='Pixel Intensity')

    x = np.linspace(hist_range[0], hist_range[1], 1000)
    colors = ['blue', 'green', 'red', 'purple', 'orange']
    bg_component = bg_model['bg_component']
    is_composite = bg_model.get('method') == 'composite_enhanced'

    for i in range(len(bg_model['component_means'])):
        # squeeze() accepts plain scalars as well as size-1 arrays, e.g. rows
        # taken straight from a one-feature sklearn GaussianMixture.
        weight = float(np.squeeze(bg_model['component_weights'][i]))
        mean = float(np.squeeze(bg_model['component_means'][i]))
        std = float(np.sqrt(np.squeeze(bg_model['component_covs'][i])))
        y = weight * scipy.stats.norm.pdf(x, mean, std)

        is_bg = i == bg_component
        label = f"Background (μ={mean:.1f})" if is_bg else f"Component {i+1} (μ={mean:.1f})"
        if is_composite:
            label += " [Composite]"
        hist_ax.plot(x, y, color=colors[i % len(colors)],
                     alpha=0.8 if is_bg else 0.5, linewidth=2, label=label)

    if is_composite:
        info_text = f"Composite Model Applied\nChannel-specific mean: {bg_model['mean']:.2f}\nChannel-specific std: {bg_model['std']:.2f}\nBackground pixels: {bg_model['bg_percentage']:.1f}%"
        hist_ax.text(0.02, 0.98, info_text, transform=hist_ax.transAxes,
                     verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
                     fontsize=8)

    hist_ax.set_title('Pixel Intensity Distribution')
    hist_ax.set_xlabel('Pixel Value')

    # Log x-axis only for a wide, strictly positive range: a zero/negative
    # minimum would divide by zero and a log axis cannot show it anyway.
    lo, hi = float(hist_range[0]), float(hist_range[1])
    if lo > 0 and hi / lo > 100:
        hist_ax.set_xscale('log')

    hist_ax.set_ylabel('Density')
    hist_ax.legend(fontsize=8)

    bg_stats_text = f"Channel BG: μ={bg_model['mean']:.2f}, σ={bg_model['std']:.2f}"
    hist_ax.text(0.98, 0.02, bg_stats_text, transform=hist_ax.transAxes,
                 horizontalalignment='right', verticalalignment='bottom',
                 bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8),
                 fontsize=8)


def visualize_background_mask(channel_image, bg_model, output_path, n_components=3, enhance_contrast=True):
    """Visualize background mask from GMM model with distribution plots.

    2D input (H, W) produces a 2x2 figure:
    - the (optionally contrast-stretched) image;
    - the raw mask;
    - the image with background pixels in red;
    - the intensity histogram.

    3D input (Z, H, W with Z > 1) produces a 3x3 figure:
    - the 25%/50%/75% Z-slices;
    - the XY/YZ/XZ maximum-intensity projections with background in red;
    - the histogram in the bottom row.

    Parameters:
    - channel_image: single-channel 2D (H, W) or 3D (Z, H, W) image.
    - bg_model: dict read for these keys:
      - 'mask': background mask; values > 0.5 count as background.
      - 'mean', 'std': background statistics.
      - 'component_weights', 'component_means', 'component_covs': the
        histogram panel is drawn only when all three are present. In that
        case 'bg_component' is also required.
      - 'method', 'bg_percentage': read for composite models.
    - output_path: file path the figure is saved to (200 dpi).
    - n_components: unused; kept for API compatibility.
    - enhance_contrast: percentile stretching for slices, CLAHE for MIPs.
    """
    is_3d = len(channel_image.shape) == 3 and channel_image.shape[0] > 1
    # Boolean selector for the red overlays: an integer 0/1 mask would
    # otherwise act as a fancy index, and a float mask cannot index at all.
    bg_mask = np.asarray(bg_model['mask']) > 0.5

    if is_3d:
        fig, axes = plt.subplots(3, 3, figsize=(18, 15), gridspec_kw={'height_ratios': [3, 3, 1]})
    else:
        fig, axes = plt.subplots(2, 2, figsize=(15, 12), gridspec_kw={'height_ratios': [3, 1]})

    try:
        if is_3d:
            _bgvis_draw_3d(axes, channel_image, bg_mask, enhance_contrast)
        else:
            _bgvis_draw_2d(axes, channel_image, bg_model, bg_mask, enhance_contrast)

        if all(key in bg_model for key in ('component_weights', 'component_means', 'component_covs')):
            _bgvis_draw_histogram(axes[2, 1] if is_3d else axes[1, 1], channel_image, bg_model)

        fig.tight_layout()
        fig.savefig(output_path, dpi=200)
    finally:
        # Close even when drawing or saving fails so figures do not accumulate.
        plt.close('all')

def save_mask_as_tiff(mask, output_path, bit_depth=16):
    """
    Save a labeled mask as an unsigned-integer TIFF.

    Parameters:
    - mask: integer label image
    - output_path: path of the TIFF file to write
    - bit_depth: 16 writes uint16; any other value writes uint32. A 16-bit
      request is upgraded to 32-bit when the largest label exceeds 65535, so
      such labels are not wrapped by the uint16 cast.
    """
    mask = np.asarray(mask)
    # Scan the array once. An empty mask has no labels; np.max would raise on it.
    max_label = mask.max() if mask.size else 0

    if bit_depth == 16 and max_label > np.iinfo(np.uint16).max:
        print(f"Warning: Mask has {max_label} labels which exceeds 16-bit range. Using 32-bit.")
        bit_depth = 32

    out_dtype = np.uint16 if bit_depth == 16 else np.uint32
    tiff.imwrite(output_path, mask.astype(out_dtype))

def _qc_to_uint8(img):
    """Stretch *img* to uint8 using its 2nd-98th percentiles.

    If the percentiles coincide (flat or very sparse region), fall back to the
    min-max range. A constant region returns zeros. The stretch therefore
    never divides by zero and never casts NaN/inf to uint8.
    """
    img = np.asarray(img, dtype=np.float64)
    lo, hi = np.percentile(img, (2, 98))
    if hi <= lo:
        lo, hi = img.min(), img.max()
    if hi <= lo:
        return np.zeros(img.shape, dtype=np.uint8)
    return np.clip((img - lo) / (hi - lo) * 255, 0, 255).astype(np.uint8)


def _qc_label_overlay(label_img):
    """Build an RGBA overlay giving every nonzero label a semi-transparent tab10 color."""
    overlay = np.zeros((*label_img.shape, 4), dtype=np.uint8)
    for lab in np.unique(label_img):
        if lab <= 0:  # 0 is background
            continue
        color = np.array(plt.cm.tab10(lab % 10)) * 255
        overlay[label_img == lab] = [*color[:3], 128]
    return overlay


def _qc_window(cy, cx, h, w, half_size):
    """Return (y1, y2, x1, x2) of a square window around (cy, cx), clipped to an h x w image."""
    return max(0, cy - half_size), min(h, cy + half_size), max(0, cx - half_size), min(w, cx + half_size)


def _qc_finish(fig, out_path):
    """Lay out *fig*, leave room for its suptitle, and save it at 150 dpi."""
    fig.tight_layout()
    fig.subplots_adjust(top=0.95)
    fig.savefig(out_path, dpi=150)


def _qc_save_region_3d(image, cell_masks, prop, region_idx, channels_to_show, region_size, out_path):
    """Save the middle-slice / label-overlay / MIP figure for one 3D cell region."""
    z_depth, h, w, _ = image.shape
    _, cy, cx = (int(c) for c in prop.centroid)
    y1, y2, x1, x2 = _qc_window(cy, cx, h, w, region_size // 2)

    # Show at most 11 Z-slices centred on the cell's Z extent
    # (regionprops bbox max is exclusive).
    z_min, z_max = prop.bbox[0], prop.bbox[3]
    z_center = (z_min + z_max) // 2
    z_half_range = min(5, (z_max - z_min) // 2)
    z_start = max(0, z_center - z_half_range)
    z_end = min(z_depth, z_center + z_half_range + 1)
    num_z_slices = z_end - z_start
    mid_offset = num_z_slices // 2
    middle_z = z_start + mid_offset

    region_masks = cell_masks[z_start:z_end, y1:y2, x1:x2]
    # The mask overlays do not depend on the channel, so build them once.
    mid_overlay = _qc_label_overlay(region_masks[mid_offset])
    mip_overlay = np.zeros((*region_masks.shape[1:], 4), dtype=np.uint8)
    mip_overlay[np.any(region_masks > 0, axis=0)] = [255, 0, 0, 128]  # red, semi-transparent

    num_channels = len(channels_to_show)
    # Three panels per row: at least 3 inches each, capped at 20 inches overall.
    fig_width = min(20, max(9, num_z_slices * 3))
    fig, axes = plt.subplots(num_channels, 3, figsize=(fig_width, 4 * num_channels), squeeze=False)
    try:
        fig.suptitle(f"3D Region {region_idx+1}: Cell {prop.label} (Volume: {prop.area}px, Z-range: {z_min}-{z_max})")
        for row, ch in enumerate(channels_to_show):
            if ch >= image.shape[-1]:  # requested channel does not exist; leave row blank
                continue
            ax_raw, ax_seg, ax_mip = axes[row]
            mid_norm = _qc_to_uint8(image[middle_z, y1:y2, x1:x2, ch])

            ax_raw.imshow(mid_norm, cmap='gray')
            ax_raw.set_title(f"Channel {ch+1} (Z={middle_z})")
            ax_raw.axis('off')

            ax_seg.imshow(mid_norm, cmap='gray')
            ax_seg.imshow(mid_overlay)
            ax_seg.set_title(f"Channel {ch+1} with segmentation")
            ax_seg.axis('off')

            mip = np.max(image[z_start:z_end, y1:y2, x1:x2, ch], axis=0)
            ax_mip.imshow(_qc_to_uint8(mip), cmap='gray')
            ax_mip.imshow(mip_overlay)
            ax_mip.set_title(f"Channel {ch+1} MIP")
            ax_mip.axis('off')
        _qc_finish(fig, out_path)
    finally:
        plt.close(fig)


def _qc_save_region_2d(image, cell_masks, prop, region_idx, channels_to_show, region_size, out_path):
    """Save the raw / label-overlay figure for one 2D cell region."""
    h, w = image.shape[0:2]
    y1, y2, x1, x2 = _qc_window(int(prop.centroid[0]), int(prop.centroid[1]), h, w, region_size // 2)
    # Channel-independent, so build the overlay once per region.
    overlay = _qc_label_overlay(cell_masks[y1:y2, x1:x2])

    num_channels = len(channels_to_show)
    fig, axes = plt.subplots(num_channels, 2, figsize=(12, 4 * num_channels), squeeze=False)
    try:
        fig.suptitle(f"Region {region_idx+1}: Cell {prop.label} (Area: {prop.area}px)")
        for row, ch in enumerate(channels_to_show):
            if ch >= image.shape[-1]:  # requested channel does not exist; leave row blank
                continue
            ax_raw, ax_seg = axes[row]
            region_norm = _qc_to_uint8(image[y1:y2, x1:x2, ch])

            ax_raw.imshow(region_norm, cmap='gray')
            ax_raw.set_title(f"Channel {ch+1}")
            ax_raw.axis('off')

            ax_seg.imshow(region_norm, cmap='gray')
            ax_seg.imshow(overlay)
            ax_seg.set_title(f"Channel {ch+1} with segmentation")
            ax_seg.axis('off')
        _qc_finish(fig, out_path)
    finally:
        plt.close(fig)


def save_segmentation_qc_images(image, cell_masks, output_dir, img_name, config=None):
    """
    Save quality control images showing zoomed-in regions of segmentation results,
    with enhanced support for 3D data visualization.

    Cells are chosen as follows:
    - the `qc_num_regions` largest cells (by area);
    - when more than `qc_num_regions` further cells remain, the same number
      of those cells picked at random via np.random.

    One PNG per cell is written to `<output_dir>/qc_regions/`. It has one row
    per requested channel, and each window is `qc_region_size` pixels around
    the cell centroid.

    Parameters:
    - image: Multi-channel image array (H,W,C for 2D or Z,H,W,C for 3D)
    - cell_masks: Integer mask with cell labels (H,W for 2D or Z,H,W for 3D)
    - output_dir: Directory to save output images
    - img_name: Base name of the image being processed (extension is stripped)
    - config: optional dict. Read keys:
      - 'use_3d'
      - 'qc_num_regions' (default 3)
      - 'qc_region_size' (default 400)
      - 'qc_channels' (default: all channels)
    """
    config = {} if config is None else config
    is_3d = config.get('use_3d', False) or (len(image.shape) == 4 and image.shape[0] > 1)

    num_regions = config.get('qc_num_regions', 3)
    region_size = config.get('qc_region_size', 400)
    channels_to_show = config.get('qc_channels', list(range(image.shape[-1])))

    props = measure.regionprops(cell_masks)
    if len(props) == 0:
        print("No cells detected for QC visualization")
        return
    if len(channels_to_show) == 0:
        # plt.subplots cannot build a figure with zero rows.
        print("No channels selected for QC visualization")
        return

    qc_dir = os.path.join(output_dir, "qc_regions")
    os.makedirs(qc_dir, exist_ok=True)

    by_area = sorted(props, key=lambda p: p.area, reverse=True)
    n_top = min(num_regions, len(props))
    selected_props = by_area[:n_top]
    remaining_props = by_area[n_top:]
    if len(remaining_props) > num_regions:
        picks = np.random.choice(len(remaining_props), num_regions, replace=False)
        selected_props.extend(remaining_props[k] for k in picks)

    base_name = os.path.splitext(img_name)[0]
    save_region = _qc_save_region_3d if is_3d else _qc_save_region_2d
    for i, prop in enumerate(selected_props):
        region_name = f"region_{i+1}_cell_{prop.label}"
        out_path = os.path.join(qc_dir, f"{base_name}_{region_name}.png")
        save_region(image, cell_masks, prop, i, channels_to_show, region_size, out_path)

def _viz_release(fig, canvas):
    """Best-effort teardown of an Agg figure/canvas pair.

    FigureCanvasAgg only gains a `renderer` attribute once it has drawn.
    The attribute is therefore fetched with getattr; otherwise a failure
    before the first draw would raise AttributeError from the cleanup itself.
    """
    if fig is not None:
        fig.clf()
    renderer = getattr(canvas, 'renderer', None)
    if renderer is not None:
        renderer.clear()


def _viz_draw_3d(fig, image, masks, measurements):
    """Draw the 2-row 3D layout: middle Z-slice row and maximum-intensity-projection row."""
    z_depth, _, _, n_channels = image.shape
    middle_z = z_depth // 2
    grid = fig.add_gridspec(2, n_channels + 1)

    # Row 1, col 1: foreground at the middle slice, with labels for up to 30 cells.
    mid_mask = masks[middle_z]
    ax = fig.add_subplot(grid[0, 0])
    ax.imshow(mid_mask > 0, cmap='viridis')
    ax.set_title(f'Segmentation (Z={middle_z})')
    ax.axis('off')

    if len(measurements) > 0 and 'z_range' in measurements[0]:
        slice_cells = [cell for cell in measurements
                       if cell['z_range'][0] <= middle_z <= cell['z_range'][1]]
        for cell in slice_cells[:30]:
            if 'centroid_3d' in cell:
                z, y, x = cell['centroid_3d']
                if abs(z - middle_z) <= 2:  # only label cells centred near this slice
                    ax.text(x, y, str(cell['label']), color='red', fontsize=5)

    # Row 2, col 1: binary XY projection of the masks.
    ax = fig.add_subplot(grid[1, 0])
    mip_xy = np.max(masks, axis=0) > 0
    ax.imshow(mip_xy, cmap='viridis')
    ax.set_title('Segmentation MIP (XY)')
    ax.axis('off')

    # Boundary overlays are the same for every channel: compute them once.
    slice_boundaries = find_boundaries(mid_mask > 0)
    mip_boundaries = find_boundaries(mip_xy)

    for ch_idx in range(n_channels):
        channel_3d = image[:, :, :, ch_idx]

        ax = fig.add_subplot(grid[0, ch_idx + 1])
        ax.imshow(exposure.equalize_adapthist(channel_3d[middle_z], clip_limit=0.03), cmap='hot')
        ax.imshow(slice_boundaries, alpha=0.3, cmap='cool')
        ax.set_title(f'Ch {ch_idx+1} (Z={middle_z})')
        ax.axis('off')

        ax = fig.add_subplot(grid[1, ch_idx + 1])
        ax.imshow(exposure.equalize_adapthist(np.max(channel_3d, axis=0), clip_limit=0.03), cmap='hot')
        ax.imshow(mip_boundaries, alpha=0.3, cmap='cool')
        ax.set_title(f'Ch {ch_idx+1} MIP')
        ax.axis('off')


def _viz_draw_2d(fig, image, masks, measurements, debug):
    """Draw the single-row 2D layout: segmentation panel followed by one panel per channel."""
    n_channels = image.shape[-1]
    grid = fig.add_gridspec(1, n_channels + 1)

    if debug:
        print("Rendering segmentation mask...")
    ax = fig.add_subplot(grid[0, 0])
    mask_display = exposure.equalize_adapthist((masks > 0).astype(float))
    ax.imshow(mask_display, cmap='viridis')
    ax.set_title('Cell Segmentation (Enhanced)')
    ax.axis('off')

    if len(measurements) > 0:
        if debug:
            print("Adding cell labels...")
        # Label at most 50 cells: a random sample for lists, otherwise the first ones.
        num_labels = min(50, len(measurements))
        if isinstance(measurements, list):
            picks = np.random.choice(len(measurements), num_labels, replace=False)
            label_subset = [measurements[k] for k in picks]
        else:
            label_subset = measurements[:num_labels]
        for cell in label_subset:
            y, x = cell['centroid']
            ax.text(x, y, str(cell['label']), color='red', fontsize=5)

    # Same foreground boundary overlay for every channel: compute it once.
    boundaries = find_boundaries(masks > 0)

    channel_range = range(n_channels)
    if debug:
        from tqdm import tqdm
        channel_range = tqdm(channel_range, desc="Processing channels")

    for ch in channel_range:
        if debug:
            ch_start = time.time()
        ax = fig.add_subplot(grid[0, ch + 1])
        ax.imshow(exposure.equalize_adapthist(image[:, :, ch], clip_limit=0.03), cmap='hot')
        ax.set_title(f'Channel {ch+1} (Enhanced)')
        ax.axis('off')
        ax.imshow(boundaries, alpha=0.3, cmap='cool')
        if debug:
            print(f"  Channel {ch+1} rendered in {time.time() - ch_start:.2f}s")


def create_visualization(image, masks, measurements, output_path, debug=False):
    """
    Create multi-panel visualization for QC with 2D/3D support and optimized memory usage

    Best-effort: any exception is printed with its traceback and NOT re-raised.
    Figure resources are released on both the success and the failure path.

    Parameters:
    - image: Multi-channel image (H,W,C for 2D or Z,H,W,C for 3D)
    - masks: Cell segmentation masks
    - measurements: per-cell records. They are indexed with the keys 'label'
      and 'centroid' (2D), or 'z_range' and 'centroid_3d' (3D).
    - output_path: Where to save the visualization
    - debug: Enable detailed timing and progress tracking
    """
    fig = None
    canvas = None
    try:
        if debug:
            start_time = time.time()
            print(f"Starting visualization for {output_path}...")

        gc.collect()

        is_3d = len(image.shape) == 4 and image.shape[0] > 1
        if is_3d:
            n_channels = image.shape[3]
            fig_width = max(15, n_channels * 3)  # widen with channel count
            # Figure + Agg canvas directly (not pyplot) so nothing is kept in pyplot's registry.
            fig = Figure(figsize=(fig_width, 10), dpi=200)
            canvas = FigureCanvasAgg(fig)
            if debug:
                print(f"Created 3D visualization figure with size {fig_width}x10 inches")
            _viz_draw_3d(fig, image, masks, measurements)
        else:
            _, _, n_channels = image.shape
            dpi = 300
            fig_width = (n_channels + 1) * 4  # 4 inches per panel
            fig_height = 4
            fig = Figure(figsize=(fig_width, fig_height), dpi=dpi)
            canvas = FigureCanvasAgg(fig)
            if debug:
                print(f"Created figure with size {fig_width}x{fig_height} inches at {dpi} DPI")
                print(f"Processing {n_channels} channels and {len(measurements)} cells")
            _viz_draw_2d(fig, image, masks, measurements, debug)

        if debug:
            print("Saving figure...")
            save_start = time.time()

        fig.tight_layout()
        fig.savefig(output_path, bbox_inches='tight')

        if debug:
            print(f"Visualization saved in {time.time() - save_start:.2f}s")
            print(f"Total visualization time: {time.time() - start_time:.2f}s")

    except Exception as e:
        print(f"Error in visualization: {str(e)}")
        traceback.print_exc()
    finally:
        _viz_release(fig, canvas)
        gc.collect()


### Utility Functions

In [None]:
def resample_image(image, scale_factor=0.5, downsample_z=False):
    """
    Resample an image by ``scale_factor`` in X/Y and optionally in Z.

    Supported layouts:
    - (H, W)         single-channel 2D image
    - (H, W, C)      multi-channel 2D image
    - (Z, H, W, C)   multi-channel 3D stack (any Z, including Z == 1)

    Parameters:
    - image: input array in one of the layouts above
    - scale_factor: multiplicative factor applied to every resampled axis
    - downsample_z: for 4D input, also rescale the Z axis (default: keep Z)

    Returns:
    - Resampled array with the same dtype as ``image``. Every resampled axis keeps
      at least one pixel. Integer outputs are rounded to the nearest value instead
      of being truncated.

    Raises:
    - ValueError for arrays that are not 2D, 3D or 4D
    """
    def _scaled(n):
        # resize() cannot produce a zero-length axis, so keep at least one pixel
        return max(1, int(n * scale_factor))

    if image.ndim == 4:
        z_depth, h, w, c = image.shape
        new_h, new_w = _scaled(h), _scaled(w)

        if downsample_z:
            # Downsample in all spatial dimensions including Z
            resized = resize(image, (_scaled(z_depth), new_h, new_w, c),
                             preserve_range=True, anti_aliasing=True)
        else:
            # Rescale each Z-slice independently so the Z resolution is preserved
            resized = np.stack([
                resize(image[z], (new_h, new_w, c), preserve_range=True, anti_aliasing=True)
                for z in range(z_depth)
            ])
    elif image.ndim in (2, 3):
        # 2D image, with or without a trailing channel axis (channels are never resampled)
        h, w = image.shape[:2]
        out_shape = (_scaled(h), _scaled(w)) + tuple(image.shape[2:])
        resized = resize(image, out_shape, preserve_range=True, anti_aliasing=True)
    else:
        raise ValueError(f"Unsupported image shape for resampling: {image.shape}")

    # resize() returns floats; round before casting so integer images are not biased downward
    if np.issubdtype(image.dtype, np.integer):
        resized = np.rint(resized)

    return resized.astype(image.dtype)

def create_composite_image(image, method='max', config=None):
    """
    Collapse the channel (last) axis of a multi-channel image into one composite plane.

    Each channel is first rescaled independently to [0, 1]. Its 1st/99th
    percentiles are clipped and mapped linearly, so every channel contributes on
    the same scale. Constant channels become all zeros.

    Parameters:
    - image: array whose last axis indexes channels; it is not modified
    - method: 'mean', 'max' or 'weighted'
    - config: optional dict. For 'weighted', 'channel_weights' gives one weight per
      channel. Without it, weights are the normalized per-channel variances (or all
      ones when every variance is zero).

    Returns:
    - (composite, normalized_img): the composite (channel axis removed) and the
      float32 per-channel normalized image

    Raises:
    - ValueError for an unknown method
    """
    # astype() always returns a new array, so the caller's image is left untouched
    normalized_img = image.astype(np.float32)
    n_channels = normalized_img.shape[-1]

    for ch in range(n_channels):
        band = normalized_img[..., ch]
        lo = np.percentile(band, 1)
        hi = np.percentile(band, 99)

        if hi > lo:
            normalized_img[..., ch] = (np.clip(band, lo, hi) - lo) / (hi - lo)
        else:
            # Degenerate (flat) channel: contributes nothing to the composite
            normalized_img[..., ch] = np.zeros_like(band)

    if method == 'mean':
        return np.mean(normalized_img, axis=-1), normalized_img
    if method == 'max':
        return np.max(normalized_img, axis=-1), normalized_img
    if method != 'weighted':
        raise ValueError(f"Unknown composite method: {method}")

    weights = config.get('channel_weights', None) if config else None
    if weights is None:
        # Channels with more spread (in normalized space) get proportionally more weight
        variances = np.array([np.var(normalized_img[..., ch]) for ch in range(n_channels)])
        weight_total = np.sum(variances)
        weights = variances / weight_total if weight_total > 0 else np.ones_like(variances)

    composite = np.zeros(normalized_img.shape[:-1])
    for ch in range(n_channels):
        composite += weights[ch] * normalized_img[..., ch]

    return composite, normalized_img

def check_gpu_availability():
    """Return True when PyTorch reports a usable CUDA device, otherwise False."""
    return bool(torch.cuda.is_available())

### Background Subtraction

In [None]:
def estimate_background_gmm(image, config = None, n_components=2, sample_ratio=0.05, 
                                max_iter=100, max_components=6):
    """
    Estimate a GMM background model for every channel of a multi-channel image.

    Two strategies:
    - composite (config['use_bg_composite'] truthy and more than one channel):
      one GMM is fitted on the max-composite of the percentile-normalized
      channels. It classifies each normalized channel. Mean/std are then measured
      on that channel's ORIGINAL intensities. If the composite fit fails, the
      function falls back to the per-channel strategy.
    - per-channel (otherwise): estimate_background_gmm_single on each channel.

    Parameters:
    - image: array whose LAST axis indexes channels
    - config: optional configuration dictionary (may be None); forwarded to the
      single-channel estimator
    - n_components, sample_ratio, max_iter, max_components: forwarded to
      estimate_background_gmm_single

    Returns:
    - dict mapping channel index -> background model dict. Each model's 'mask'
      has the shape of a single channel (image.shape[:-1]).
    """
    # Private copy: supports config=None and never mutates the caller's dict
    cfg = dict(config) if config else {}
    n_channels = image.shape[-1]
    gmm_kwargs = dict(n_components=n_components, sample_ratio=sample_ratio,
                      max_iter=max_iter, max_components=max_components)
    bg_models = {}

    if cfg.get('use_bg_composite', False) and n_channels > 1:
        # Create composite image for background estimation from normalized channels
        composite_image, normalized_channels = create_composite_image(image, method='max', config=cfg)

        composite_bg_model, gmm_model = _run_background_gmm_single(composite_image, cfg, **gmm_kwargs)

        if gmm_model is not None:
            for ch in tqdm(range(n_channels), desc="Estimating Background Masks (composite)"):
                original_channel = image[..., ch]

                # The composite GMM was fitted in normalized space, so predict on normalized data
                bg_mask_ch = predict_background_mask(
                    normalized_channels[..., ch], gmm_model, composite_bg_model['bg_component']
                )

                # Channel statistics are reported in the ORIGINAL intensity units
                bg_pixels = original_channel[bg_mask_ch]

                if bg_pixels.size > 0:
                    ch_bg_mean = np.mean(bg_pixels)
                    ch_bg_std = np.std(bg_pixels)
                    bg_percentage = np.sum(bg_mask_ch) / bg_mask_ch.size * 100
                else:
                    # Fallback if no background pixels found
                    print(f"Warning: No background pixels found for channel {ch+1}, using global statistics")
                    ch_bg_mean = np.mean(original_channel)
                    ch_bg_std = np.std(original_channel)
                    bg_percentage = 0

                bg_models[ch] = {
                    'mean': ch_bg_mean,
                    'std': ch_bg_std,
                    'mask': bg_mask_ch,  # Channel-specific mask
                    'bg_percentage': bg_percentage,
                    'component_means': composite_bg_model['component_means'],  # From composite
                    'component_weights': composite_bg_model['component_weights'],  # From composite
                    'component_covs': composite_bg_model['component_covs'],  # From composite
                    'n_components': composite_bg_model['n_components'],
                    'bg_component': composite_bg_model['bg_component'],
                    'method': 'composite_enhanced'  # Flag for enhanced method
                }
            return bg_models

        print("Warning: composite background estimation failed, estimating each channel individually")

    # Per-channel estimation: single-channel images, composite disabled, or composite fallback
    for ch in tqdm(range(n_channels), desc="Estimating Background GMM"):
        channel_data = image[..., ch].copy()  # Avoid modifying original data
        bg_models[ch], _ = _run_background_gmm_single(channel_data, cfg, **gmm_kwargs)

    return bg_models


def _run_background_gmm_single(data, cfg, **gmm_kwargs):
    """Call estimate_background_gmm_single and always return (model_dict, gmm_or_None)."""
    out = estimate_background_gmm_single(data, cfg, **gmm_kwargs)
    # A tuple is returned when 'use_bg_composite' is set; otherwise (and on the
    # estimator's failure path) it may be a bare dict with no fitted model
    if isinstance(out, tuple):
        return out[0], out[1]
    return out, None

def predict_background_mask(normalized_channel, gmm_model, bg_component):
    """
    Classify every pixel of a channel with an already fitted mixture model.

    Parameters:
    - normalized_channel: array of pixel values (any shape); flattened into a
      single feature column for prediction
    - gmm_model: fitted model exposing predict(X) on an (n_pixels, 1) array
    - bg_component: component index regarded as background

    Returns:
    - boolean array shaped like normalized_channel, True where the predicted
      component equals bg_component. On any error the message is printed and an
      all-False mask is returned instead (conservative: nothing is background).
    """
    try:
        feature_column = normalized_channel.flatten().reshape(-1, 1)
        labels = gmm_model.predict(feature_column)
        return (labels == bg_component).reshape(normalized_channel.shape)
    except Exception as err:
        print(f"Error in background prediction: {str(err)}")
        return np.zeros_like(normalized_channel, dtype=bool)

def estimate_background_gmm_single(image, config = None, n_components=2, sample_ratio=0.05, 
                                max_iter=100, max_components=6):
    """
    Fast background estimation using a 1D GMM on pixel intensities, with optional
    adaptive (BIC-based) component selection.

    Parameters:
    - image: image array (e.g. H,W or Z,H,W); it is flattened, so any shape works
    - config: optional configuration dictionary (may be None). Keys read:
        'use_3d'           halve sample_ratio and raise the minimum sample size
        'adaptive_gmm'     fit 1..max_components components and keep the lowest BIC
        'use_bg_composite' also return the fitted GaussianMixture (see Returns)
    - n_components: number of GMM components when not adaptive (default 2)
    - sample_ratio: fraction of pixels sampled for fitting (default 0.05)
    - max_iter: maximum GMM iterations (default 100)
    - max_components: largest component count tried in adaptive mode (default 6)

    Returns:
    - dict with background 'mean', 'std', 'mask' (shape of ``image``),
      'bg_percentage' and the GMM parameters. The background component is the one
      with the lowest mean.
    - if config['use_bg_composite'] is truthy: the tuple ``(result, gmm)`` instead.
    - on failure the result is a fallback dict (mean 0, std 0, all-False mask). In
      the tuple form, gmm is None.
    """
    # A missing config is treated as empty so every lookup below is safe
    config = config or {}
    return_gmm = bool(config.get('use_bg_composite'))
    use_3d = config.get('use_3d', False)

    if use_3d:
        sample_ratio = sample_ratio * 0.5  # Reduce sample size for 3D images
        min_sample_nr = 50000  # Minimum sample size for 3D images
    else:
        min_sample_nr = 5000  # Minimum sample size for 2D images

    gmm = None
    try:
        # Flatten and sample
        flat_img = image.flatten()
        n_samples = max(min_sample_nr, int(sample_ratio * flat_img.size))

        # Systematic sampling (every step-th pixel) is cheaper than random sampling
        step = max(1, flat_img.size // n_samples)
        sample_data = flat_img[::step].reshape(-1, 1)

        if config.get('adaptive_gmm', False):
            candidates = [_fit_kmeans_initialized_gmm(sample_data, n, max_iter)
                          for n in range(1, max_components + 1)]
            bic_scores = [model.bic(sample_data) for model in candidates]

            # Select model with lowest BIC score
            best_idx = int(np.argmin(bic_scores))
            gmm = candidates[best_idx]
            n_components = gmm.n_components
            print(f"Adaptive GMM selected {n_components} components with BIC: {bic_scores[best_idx]:.2f}")
            del candidates
        else:
            gmm = _fit_kmeans_initialized_gmm(sample_data, n_components, max_iter)

        # Extract model parameters; covariances_ holds one 1x1 matrix per component for 1D data
        means = gmm.means_.flatten()
        covs = np.array([gmm.covariances_[i].flatten()[0] for i in range(gmm.n_components)])
        weights = gmm.weights_

        # Identify background component (lowest mean)
        bg_component = np.argmin(means)
        bg_mean = means[bg_component]
        bg_std = np.sqrt(covs[bg_component])

        # Classify every pixel (not only the sample) with the fitted model
        pixel_labels = gmm.predict(flat_img.reshape(-1, 1))
        bg_mask = (pixel_labels == bg_component).reshape(image.shape)

        del sample_data, flat_img, pixel_labels

        result = {
            'mean': bg_mean,
            'std': bg_std,
            'mask': bg_mask,
            'bg_percentage': np.sum(bg_mask) / bg_mask.size * 100,
            'component_means': means,
            'component_weights': weights,
            'component_covs': covs,
            'n_components': n_components,
            'bg_component': bg_component,
        }

    except Exception as e:
        print(f"Error in GMM background estimation: {str(e)}")
        traceback.print_exc()
        # Fallback values; the model is dropped so callers never use a half-fitted GMM
        result = {'mean': 0, 'std': 0, 'mask': np.zeros_like(image, dtype=bool), 'bg_percentage': 0.0}
        gmm = None

    return (result, gmm) if return_gmm else result


def _fit_kmeans_initialized_gmm(sample_data, n_components, max_iter):
    """Fit a GaussianMixture on ``sample_data`` with means initialized by K-means."""
    # Seeding K-means as well as the GMM makes repeated runs give identical models
    kmeans = KMeans(n_clusters=n_components, n_init=1, max_iter=100, random_state=42)
    kmeans.fit(sample_data)

    gmm = GaussianMixture(
        n_components=n_components,
        random_state=42,
        n_init=1,
        max_iter=max_iter,
        tol=1e-3,
        means_init=kmeans.cluster_centers_,
    )
    gmm.fit(sample_data)
    return gmm

### Segmentation and Alignment

In [None]:
def segment_nuclei_threshold(batch_images_nuclei, config):
    """
    Threshold-based nuclei segmentation for a batch of images.

    Dispatches to the GPU implementation when config['use_gpu'] (default True) is
    truthy and check_gpu_availability() succeeds. Otherwise it uses the CPU
    process-pool implementation.

    Parameters:
    - batch_images_nuclei: list of nucleus-channel images (already background-corrected)
    - config: configuration dictionary. Reads 'nuclei_thresh_factor' (2.0),
      'upper_thresh' (60000), 'nuclei_min_size' (5), 'nuclei_max_size' (1000)
      and 'use_gpu' (True).

    Returns:
    - masks_nuclei: list of integer label masks, one per input image
    """
    # Positional order matches both backends: factor, upper, min_size, max_size
    seg_params = (
        config.get('nuclei_thresh_factor', 2.0),
        config.get('upper_thresh', 60000),
        config.get('nuclei_min_size', 5),
        config.get('nuclei_max_size', 1000),
    )
    # Short-circuit: the CUDA probe only runs when the GPU is not disabled in config
    gpu_enabled = config.get('use_gpu', True) and check_gpu_availability()

    print(f"Processing {len(batch_images_nuclei)} nuclei images using {'GPU' if gpu_enabled else 'CPU'}...")

    if not gpu_enabled:
        return _segment_nuclei_threshold_cpu_parallel(batch_images_nuclei, *seg_params, config)

    return _segment_nuclei_threshold_gpu(batch_images_nuclei, *seg_params)

def _segment_nuclei_threshold_gpu(batch_images_nuclei, lower_thresh_factor, upper_thresh, 
                                 min_size, max_size):
    """
    GPU-accelerated batch nuclei segmentation.

    Thresholding runs on CUDA. A pixel is foreground when it lies strictly between
    (mean + lower_thresh_factor * std) of its own image and upper_thresh. The std
    is the population std, matching np.std in _process_single_nucleus_image, so
    GPU and CPU thresholds agree up to float32 precision. Closing, labeling and
    size filtering then run on the CPU.

    Same-shaped images are sent to the GPU in batches of up to 4. If shapes
    differ, images are processed one at a time. If anything fails, the whole
    batch is re-segmented by _segment_nuclei_threshold_cpu_parallel.

    Parameters:
    - batch_images_nuclei: sequence of nucleus images
    - lower_thresh_factor: number of standard deviations above the mean
    - upper_thresh: exclusive upper intensity limit
    - min_size, max_size: inclusive object area range that is kept

    Returns:
    - masks_nuclei: list of integer label masks, one per input image, in input order
    """
    device = torch.device('cuda')
    masks_nuclei = []

    try:
        shapes = [img.shape for img in batch_images_nuclei]
        if len(set(shapes)) > 1:
            print(f"Images have different sizes: {shapes}. Processing individually on GPU.")
            gpu_batch_size = 1
        else:
            # Small GPU batches keep device memory usage bounded
            gpu_batch_size = max(1, min(4, len(batch_images_nuclei)))

        for batch_start in range(0, len(batch_images_nuclei), gpu_batch_size):
            gpu_batch = batch_images_nuclei[batch_start:batch_start + gpu_batch_size]
            batch_tensor = torch.stack([_nuclei_to_float_tensor(img) for img in gpu_batch]).to(device)

            with torch.no_grad():
                flat = batch_tensor.reshape(batch_tensor.shape[0], -1)
                thresholds = flat.mean(dim=1) + lower_thresh_factor * flat.std(dim=1, unbiased=False)
                # One threshold per image, broadcast over all spatial axes (2D or 3D)
                thresholds = thresholds.view(-1, *([1] * (batch_tensor.dim() - 1)))
                binary_masks = (batch_tensor > thresholds) & (batch_tensor < upper_thresh)

            binary_masks_cpu = binary_masks.cpu().numpy()

            # Release GPU memory before the CPU-bound post-processing
            del batch_tensor, binary_masks, flat, thresholds
            torch.cuda.empty_cache()

            masks_nuclei.extend(
                _label_and_filter_nuclei(binary_mask, min_size, max_size)
                for binary_mask in binary_masks_cpu
            )

    except Exception as e:
        print(f"GPU nuclei segmentation failed: {e}. Falling back to CPU.")
        return _segment_nuclei_threshold_cpu_parallel(
            batch_images_nuclei, lower_thresh_factor, upper_thresh, 
            min_size, max_size, {'max_workers': 2}  # Reduce workers to avoid crashes
        )

    return masks_nuclei


def _nuclei_to_float_tensor(img):
    """Convert an image to a contiguous float32 CPU tensor (also for uint16 input)."""
    # torch.from_numpy rejects some NumPy dtypes (e.g. uint16 on older PyTorch) and
    # negative strides, so convert to a contiguous float32 array first
    return torch.from_numpy(np.ascontiguousarray(img, dtype=np.float32))


def _label_and_filter_nuclei(binary_mask, min_size, max_size):
    """Close, clean and label a binary mask, keeping objects with area in [min_size, max_size]."""
    binary_mask = binary_closing(binary_mask, disk(2))
    binary_mask = remove_small_objects(binary_mask, min_size=min_size)

    nuclei_labels = measure.label(binary_mask)

    # Relabel consecutively through a lookup table (one pass over the image)
    label_mapping = np.zeros(int(nuclei_labels.max(initial=0)) + 1, dtype=nuclei_labels.dtype)
    new_label = 1
    for prop in measure.regionprops(nuclei_labels):
        if min_size <= prop.area <= max_size:
            label_mapping[prop.label] = new_label
            new_label += 1

    return label_mapping[nuclei_labels]

def _segment_nuclei_threshold_cpu_parallel(batch_images_nuclei, lower_thresh_factor, 
                                          upper_thresh, min_size, max_size, config=None):
    """
    CPU batch nuclei segmentation using a process pool, with an in-process fallback.

    Each image is segmented by _process_single_nucleus_image in a worker process
    (at most 2 workers, fewer if config['max_workers'] is smaller). Any image
    without a result is retried sequentially in the current process. This covers
    a worker error, a broken pool, or a failure to start the pool. If the retry
    also fails, an all-zero int32 mask of the image's shape is used.

    Parameters:
    - batch_images_nuclei: sequence of nucleus images
    - lower_thresh_factor, upper_thresh, min_size, max_size: forwarded to
      _process_single_nucleus_image
    - config: optional dict; only 'max_workers' is read

    Returns:
    - masks_nuclei: list of label masks in the same order as batch_images_nuclei
    """
    process_func = partial(
        _process_single_nucleus_image,
        lower_thresh_factor=lower_thresh_factor,
        upper_thresh=upper_thresh,
        min_size=min_size,
        max_size=max_size
    )

    # Worker count is deliberately capped at 2
    max_workers = min(2, config.get('max_workers', 2)) if config else 2
    masks_nuclei = [None] * len(batch_images_nuclei)

    try:
        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            future_to_idx = {
                executor.submit(process_func, img): idx 
                for idx, img in enumerate(batch_images_nuclei)
            }

            for future in concurrent.futures.as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    # as_completed only yields finished futures, so no timeout is needed
                    masks_nuclei[idx] = future.result()
                except Exception as e:
                    # Left as None: retried in-process below instead of silently zeroed
                    print(f"Error processing nuclei image {idx}: {str(e)}")

    except Exception as e:
        print(f"Parallel processing failed: {str(e)}. Processing sequentially.")
        traceback.print_exc()

    # Sequential pass for every image the pool did not deliver
    for idx, img in enumerate(batch_images_nuclei):
        if masks_nuclei[idx] is not None:
            continue
        try:
            masks_nuclei[idx] = process_func(img)
        except Exception as e:
            print(f"Error processing nuclei image {idx} sequentially: {str(e)}")
            masks_nuclei[idx] = np.zeros(img.shape, dtype=np.int32)

    return masks_nuclei

def _process_single_nucleus_image(nucleus_image, lower_thresh_factor, upper_thresh, min_size, max_size):
    """
    Segment nuclei in one image by global thresholding (worker function for the
    CPU process pool). No background subtraction is done here.

    Foreground pixels lie strictly between (mean + lower_thresh_factor * std)
    and upper_thresh, using the image's own mean and population std (np.std).
    The mask is closed with disk(2), and objects smaller than min_size are
    removed. It is then labeled, and only objects with area in
    [min_size, max_size] are kept.

    Returns:
    - label image, 0 = background, kept objects relabeled 1..N in label order
    """
    # Calculate threshold
    threshold = np.mean(nucleus_image) + lower_thresh_factor * np.std(nucleus_image)

    # Apply threshold with upper limit
    binary_mask = (nucleus_image > threshold) & (nucleus_image < upper_thresh)

    # Morphological operations
    binary_mask = binary_closing(binary_mask, disk(2))
    binary_mask = remove_small_objects(binary_mask, min_size=min_size)

    # Label connected components
    nuclei_labels = measure.label(binary_mask)

    # Relabel through a lookup table: one pass over the image instead of one per region
    label_mapping = np.zeros(int(nuclei_labels.max(initial=0)) + 1, dtype=nuclei_labels.dtype)
    new_label = 1
    for prop in measure.regionprops(nuclei_labels):
        if min_size <= prop.area <= max_size:
            label_mapping[prop.label] = new_label
            new_label += 1

    return label_mapping[nuclei_labels]

In [None]:
def segment_nuclei_from_background_mask(bg_models_nuclei, config):
    """
    Segment nuclei for a batch directly from pre-computed background masks.

    No intensity image is needed. Each model's 'mask' marks background pixels,
    and its complement is labeled by _process_single_nucleus_bg_mask.

    Parameters:
    - bg_models_nuclei: sequence of background models, each with a 'mask' array
    - config: configuration dictionary. Reads 'nuclei_min_size' (5),
      'nuclei_max_size' (1000), 'nuclei_morphology_radius' (1) and
      'max_workers' (capped at 8).

    Returns:
    - masks_nuclei: list of label masks in the same order as bg_models_nuclei.
      If the thread pool fails, all models are reprocessed sequentially. Any
      model that still fails gets an all-zero int32 mask.
    """
    worker = partial(
        _process_single_nucleus_bg_mask,
        min_size=config.get('nuclei_min_size', 5),
        max_size=config.get('nuclei_max_size', 1000),
        morphology_radius=config.get('nuclei_morphology_radius', 1),
    )

    print(f"Processing {len(bg_models_nuclei)} nuclei images using background masks (CPU optimized)...")

    n_threads = min(8, config.get('max_workers', 8))

    try:
        # Threads share the masks in memory, so nothing has to be pickled
        with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as pool:
            return list(pool.map(worker, bg_models_nuclei))
    except Exception as e:
        print(f"Parallel processing failed: {str(e)}. Processing sequentially.")

    masks_nuclei = []
    for bg_model in bg_models_nuclei:
        try:
            masks_nuclei.append(worker(bg_model))
        except Exception as e:
            print(f"Error processing nuclei image: {str(e)}")
            masks_nuclei.append(np.zeros(bg_model['mask'].shape, dtype=np.int32))

    return masks_nuclei

def _process_single_nucleus_bg_mask(bg_model, min_size=5, max_size=1000, morphology_radius=1):
    """
    Label nuclei as the connected non-background regions of a background mask.

    Parameters:
    - bg_model: background model dict; only its 'mask' (True = background) is used
    - min_size: minimum object area kept (inclusive)
    - max_size: maximum object area kept (inclusive)
    - morphology_radius: disk radius for binary closing. If <= 0, closing and the
      small-object pre-filter are skipped.

    Returns:
    - int32 label image (0 = background). Kept objects are renumbered 1..N in
      label order.
    """
    # Foreground = everything not assigned to the background component. Coerce to
    # bool so ~ is a logical (not bitwise) negation whatever the mask dtype.
    foreground_mask = ~np.asarray(bg_model['mask'], dtype=bool)

    if morphology_radius > 0:
        foreground_mask = binary_closing(foreground_mask, disk(morphology_radius))
        # Light pre-filter only (min_size // 3); the strict size filter happens below
        foreground_mask = remove_small_objects(foreground_mask, min_size=max(1, min_size // 3))

    nuclei_labels = measure.label(foreground_mask)
    n_labels = int(nuclei_labels.max()) if nuclei_labels.size else 0

    if n_labels == 0:
        # Same int32 dtype as the non-empty result below
        return nuclei_labels.astype(np.int32)

    # Relabel via a lookup table applied with fancy indexing
    label_mapping = np.zeros(n_labels + 1, dtype=np.int32)
    new_label = 1
    for prop in measure.regionprops(nuclei_labels):
        if min_size <= prop.area <= max_size:
            label_mapping[prop.label] = new_label
            new_label += 1

    return label_mapping[nuclei_labels]

In [None]:
def segment_cells_cellpose(image, config, bg_models=None):
    """
    Cell segmentation using Cellpose with optimized memory usage.

    Parameters:
    - image: array whose last axis indexes channels (e.g. H,W,C; Z,H,W,C with
      config['use_3d'])
    - config: configuration dictionary. Keys read:
      - 'use_3d', 'use_gpu'
      - 'segmentation_channels': 0-based channel indices, default all
      - 'cell_diameter': forwarded to Cellpose as `diameter`; None/absent keeps
        Cellpose's default
      - 'anisotropy', 'flow_threshold', 'cellprob_threshold'
    - bg_models: optional dict mapping channel index -> background model with a
      'mean'. That mean is subtracted (clipped at 0) before segmentation.

    Returns:
    - masks from Cellpose. On error, an all-zero int32 mask of shape
      image.shape[:-1].
    """
    use_3d = config.get('use_3d', False)
    # bg_models is optional; an empty dict means "no background subtraction"
    bg_models = bg_models or {}

    model = models.CellposeModel(gpu=config.get('use_gpu', True))

    segmentation_channels = config.get('segmentation_channels', list(range(image.shape[-1])))
    print(f"Using custom segmentation with channels: {[ch+1 for ch in segmentation_channels]}")

    # np.stack already copies, so the caller's image is never modified
    img_to_segment = np.stack([image[..., ch] for ch in segmentation_channels], axis=-1)

    # Apply background subtraction if available
    for i, ch in enumerate(segmentation_channels):
        if ch in bg_models:
            ch_mean = bg_models[ch]['mean']
            img_to_segment[..., i] = np.clip(img_to_segment[..., i] - ch_mean, 0, None)
            print(f"Applied background subtraction to channel {ch+1} (mean: {ch_mean:.2f})")

    print(f"Running Cellpose with channels: {[ch+1 for ch in segmentation_channels]}")
    print(f"Image shape for segmentation: {img_to_segment.shape}")

    try:
        masks, flows, styles = model.eval(
            img_to_segment,
            # Honors the diameter that segment_cells_with_downsampling rescales
            diameter=config.get('cell_diameter'),
            anisotropy=config.get('anisotropy', 3.0),
            flow_threshold=config.get('flow_threshold', 0.4),
            cellprob_threshold=config.get('cellprob_threshold', 0.0),
            normalize=True,
            progress=True,
            do_3D=use_3d
        )

        # Free memory
        del img_to_segment, model, flows, styles
        gc.collect()

        print(f"Segmentation complete! Found {len(np.unique(masks))-1} objects")
        return masks

    except Exception as e:
        print(f"ERROR in Cellpose: {str(e)}")
        traceback.print_exc()
        # Spatial shape of one channel: (H, W) for 2D, (Z, H, W) for 3D input
        return np.zeros(image.shape[:-1], dtype=np.int32)

def segment_cells_with_downsampling(image, config, bg_models=None):
    """
    Segment cells, optionally on a downsampled copy of the image to save memory.

    Parameters:
    - image: array whose last axis indexes channels (H,W,C or Z,H,W,C)
    - config: configuration dictionary. 'downsample_factor' (default 1.0) sets the
      scale; values >= 1.0 segment at full resolution. When downsampling,
      'cell_diameter' (default 20.0) is scaled by the same factor.
    - bg_models: optional background models forwarded to segment_cells_cellpose

    Returns:
    - int32 label mask with the spatial shape of one channel (image.shape[:-1]).
      It is all zeros on error. The full-resolution path returns Cellpose's
      masks as-is.
    """
    downsample_factor = config.get('downsample_factor', 1.0)
    # Spatial shape of one channel: (H, W) for 2D input, (Z, H, W) for 3D input
    spatial_shape = image.shape[:-1]

    try:
        if downsample_factor >= 1.0:
            # Process at original resolution
            return segment_cells_cellpose(image, config, bg_models)

        small_image = resample_image(image, downsample_factor)

        # Adjust cell diameter for downsampled image
        small_config = config.copy()
        small_config['cell_diameter'] = config.get('cell_diameter', 20.0) * downsample_factor

        small_masks = segment_cells_cellpose(small_image, small_config, bg_models)

        # Free memory before upsampling
        del small_image
        gc.collect()

        # Nearest-neighbour (order=0), no anti-aliasing: label values must stay intact
        masks_upscaled = resize(small_masks, spatial_shape, order=0,
                                preserve_range=True, anti_aliasing=False)
        masks_upscaled = masks_upscaled.astype(np.int32)

        del small_masks

        return masks_upscaled

    except Exception as e:
        print(f"Error in segmentation: {str(e)}")
        traceback.print_exc()
        # Return empty mask in case of error
        return np.zeros(spatial_shape, dtype=np.int32)

In [None]:
def align_cell_masks_to_nuclei(nuclei_masks, cell_masks, is_3d=False):
    """
    Relabel cell masks so every cell carries the label of its nucleus.

    Each nucleus is matched to the cell label it overlaps with the most (pixel
    count). The output starts as a copy of ``nuclei_masks``. Every pixel of a
    matched cell that lies outside all nuclei then receives that nucleus'
    label. Cells that no nucleus maps to are dropped. If several nuclei map to
    the same cell, the cell body takes the label of the largest such nucleus
    label (nuclei are processed in ascending label order and later ones win).

    Parameters:
    - nuclei_masks: Integer mask with nuclei labels (0 = background)
    - cell_masks: Integer mask with cell labels, same shape as nuclei_masks
    - is_3d: Kept for backward compatibility. The implementation works for
      masks of any dimensionality, so this flag does not change behavior.

    Returns:
    - combined_masks: Mask labeled consistently with the nuclei. If there are
      no nuclei, or nuclei and cells do not overlap at all, ``cell_masks`` is
      returned unchanged.
    """
    # Quick early exit if no nuclei
    if np.max(nuclei_masks) == 0:
        print("No nuclei detected for alignment")
        return cell_masks

    flat_nuclei = nuclei_masks.ravel()
    flat_cells = cell_masks.ravel()

    # Only pixels labeled in both masks contribute to the overlap counts
    valid_pixels = (flat_nuclei > 0) & (flat_cells > 0)
    if not np.any(valid_pixels):
        print("No overlap between nuclei and cells")
        return cell_masks

    rows = flat_nuclei[valid_pixels].astype(np.int64)
    cols = flat_cells[valid_pixels].astype(np.int64)

    n_max = int(np.max(nuclei_masks)) + 1
    c_max = int(np.max(cell_masks)) + 1

    # Building a CSR matrix from COO triplets sums duplicate (row, col)
    # entries, so each stored value is the pixel overlap of a nucleus/cell pair.
    overlap_matrix = sparse.csr_matrix(
        (np.ones(rows.size, dtype=np.int32), (rows, cols)), shape=(n_max, c_max)
    )
    indptr = overlap_matrix.indptr
    overlap_cells = overlap_matrix.indices
    overlap_counts = overlap_matrix.data

    # Lookup table: cell label -> matched nucleus label (0 = unmatched).
    # Iterate only nuclei that actually overlap a cell. Deriving them from the
    # data (rather than unique(...)[1:]) does not assume a background pixel
    # exists.
    cell_to_nuc_array = np.zeros(c_max, dtype=np.int32)
    for nucleus_label in np.unique(rows):
        start, end = indptr[nucleus_label], indptr[nucleus_label + 1]
        best_cell = overlap_cells[start + np.argmax(overlap_counts[start:end])]
        cell_to_nuc_array[best_cell] = nucleus_label

    # Vectorized relabel (any dimensionality): pixels of matched cells that are
    # not already covered by a nucleus take the label of their cell's nucleus.
    combined_masks = nuclei_masks.copy()
    nucleus_of_pixel = cell_to_nuc_array[cell_masks]
    fill = (nucleus_of_pixel > 0) & (nuclei_masks == 0)
    combined_masks[fill] = nucleus_of_pixel[fill]

    return combined_masks

In [None]:
def filter_masks_by_size(masks, min_size=None, max_size=None, config=None):
    """
    Keep only objects whose size lies in [min_size, max_size] and relabel them 1..K.

    Size is the pixel (voxel) count by default. When
    ``config['filter_by_radius']`` is true, size is half the major axis length
    (a "radius"), computed with ``measure.regionprops``.

    Parameters:
    - masks: Non-negative integer label mask (0 = background)
    - min_size: Minimum size; defaults to config['cell_min_size'] (5)
    - max_size: Maximum size; defaults to config['cell_max_size'] (200)
    - config: Configuration dictionary

    Returns:
    - filtered_masks: int32 mask with kept objects relabeled consecutively in
      ascending order of their original label. All zeros (int32) if nothing
      is kept.
    """
    if config is None:
        config = {}

    filter_by_radius = config.get('filter_by_radius', False)
    min_filter = min_size if min_size is not None else config.get('cell_min_size', 5)
    max_filter = max_size if max_size is not None else config.get('cell_max_size', 200)

    masks = np.asarray(masks)
    # Same dtype as the non-empty result, so callers always get int32
    empty_result = np.zeros(masks.shape, dtype=np.int32)

    if filter_by_radius:
        props = measure.regionprops(masks)
        if not props:
            return empty_result
        original_labels = np.array([prop.label for prop in props])
        sizes = np.array([prop.major_axis_length / 2.0 for prop in props])
    else:
        # Pixel count per label in a single vectorized pass
        counts = np.bincount(masks.ravel().astype(np.intp, copy=False))
        original_labels = np.flatnonzero(counts)
        original_labels = original_labels[original_labels > 0]  # skip background
        if original_labels.size == 0:
            return empty_result
        sizes = counts[original_labels]

    keep = (sizes >= min_filter) & (sizes <= max_filter)
    valid_labels = original_labels[keep]
    if valid_labels.size == 0:
        return empty_result

    # Lookup table old label -> new consecutive label (0 for dropped objects)
    label_mapping = np.zeros(int(np.max(original_labels)) + 1, dtype=np.int32)
    label_mapping[valid_labels] = np.arange(1, valid_labels.size + 1, dtype=np.int32)

    return label_mapping[masks]

### Cell Measurement and Counting
These functions measure each cell's fluorescence (CTCF) and count the positive cells per channel.

In [None]:
def get_threshold_value(channel_data, bg_mean, bg_std, config, ch):
    """Compute the intensity threshold above which a cell counts as positive.

    Methods (``config['positive_threshold_method']``):
    - 'bg_plus_std' (default): bg_mean + n * bg_std, with n taken from
      'positive_threshold_std_multiplier' (default 2.0)
    - 'percentile': the 'positive_threshold_percentile' (default 75)
      percentile of channel_data, using linear interpolation
    - 'otsu': Otsu threshold of channel_data
    - anything else ('manual'): config[f'channel_{ch+1}_threshold'], falling
      back to bg_mean + 2 * bg_std

    ``channel_data`` may be a NumPy array or a torch tensor on any device.
    """
    if config is None:
        config = {}

    method = config.get('positive_threshold_method', 'bg_plus_std')

    if method == 'bg_plus_std':
        n_std = config.get('positive_threshold_std_multiplier', 2.0)
        return bg_mean + (n_std * bg_std)

    if method == 'percentile':
        percentile = config.get('positive_threshold_percentile', 75)
        if isinstance(channel_data, torch.Tensor):
            try:
                return torch.quantile(channel_data, percentile / 100.0).item()
            except RuntimeError:
                # torch.quantile rejects non-float tensors and very large
                # inputs. np.percentile's default 'linear' interpolation
                # matches torch.quantile's default, so fall back to NumPy.
                channel_data = channel_data.detach().cpu().numpy()
        return np.percentile(channel_data, percentile)

    if method == 'otsu':
        if isinstance(channel_data, torch.Tensor):
            channel_data = channel_data.detach().cpu().numpy()
        return threshold_otsu(channel_data)

    # 'manual' (or any unrecognized method)
    return config.get(f'channel_{ch+1}_threshold', bg_mean + 2*bg_std)

def measure_cells(image, cell_masks, bg_models, config=None):
    """
    Unified CTCF measurement entry point for 2D and 3D data.

    Dispatches to the 2D or 3D implementation, on the GPU when enabled and
    available (``check_gpu_availability``). Any exception raised by a GPU
    implementation triggers a fallback to the matching CPU implementation.

    Parameters:
    - image: Multi-channel image array (Y,X,C) for 2D or (Z,Y,X,C) for 3D
    - cell_masks: Integer mask with cell labels (Y,X) for 2D or (Z,Y,X) for 3D
    - bg_models: Background models for each channel
    - config: Configuration dictionary ('use_3d', 'use_gpu',
      'auto_detect_gpu', ...); None means all defaults

    Returns:
    - List of cell measurements
    """
    # The default config=None previously crashed on the first config.get call
    if config is None:
        config = {}

    # Treat as 3D when requested, or when a 4D image comes with a 3D mask
    is_3d = config.get('use_3d', False) or (len(image.shape) == 4 and len(cell_masks.shape) == 3)

    # Check if GPU should be used
    use_gpu = config.get('use_gpu', True)
    if use_gpu and config.get('auto_detect_gpu', True):
        use_gpu = check_gpu_availability()

    if is_3d:
        print("Using 3D CTCF measurement...")
        if use_gpu:
            try:
                return measure_cells_ctcf_3d_gpu(image, cell_masks, bg_models, config)
            except Exception as e:
                print(f"3D GPU CTCF measurement failed: {str(e)}. Falling back to CPU.")
                return measure_cells_ctcf_3d_cpu(image, cell_masks, bg_models, config)
        return measure_cells_ctcf_3d_cpu(image, cell_masks, bg_models, config)

    # 2D implementations
    if use_gpu:
        try:
            return measure_cells_ctcf_gpu(image, cell_masks, bg_models, config)
        except Exception as e:
            print(f"GPU CTCF measurement failed: {str(e)}. Falling back to CPU.")
            return measure_cells_ctcf_cpu(image, cell_masks, bg_models, config)
    return measure_cells_ctcf_cpu(image, cell_masks, bg_models, config)
    
def measure_cells_ctcf_cpu(image, cell_masks, bg_models, config=None):
    """
    Measure CTCF for all cells and channels of a 2D image on the CPU in parallel.

    Positivity thresholds are computed once per channel with
    ``get_threshold_value``. Cell labels are then split into batches that
    ``process_cell_batch`` measures in a ProcessPoolExecutor.

    Parameters:
    - image: Multi-channel image whose last axis is the channel
    - cell_masks: Integer label mask (0 = background)
    - bg_models: Mapping channel index -> dict with 'mean' and 'std'
    - config: Configuration dictionary ('max_workers', 'batch_size',
      threshold settings); None means all defaults

    Returns:
    - List of per-cell measurement dicts; an empty list when the mask has no
      cells.
    """
    if config is None:
        config = {}
    if bg_models is None:
        bg_models = {}

    start_time = time.time()

    # Get unique cell labels
    unique_labels = np.unique(cell_masks)
    unique_labels = unique_labels[unique_labels > 0]  # Skip background (0)
    total_cells = len(unique_labels)

    if total_cells == 0:
        print("No cells found in mask.")
        # A plain list, like every other measurement path (previously `[], {}`)
        return []

    print(f"Measuring CTCF for {total_cells} cells using parallel CPU processing...")

    # Pre-compute thresholds for positive cell counting
    thresholds = {}
    for ch in range(image.shape[-1]):
        bg_mean = bg_models[ch]['mean'] if ch in bg_models else 0
        bg_std = bg_models[ch]['std'] if ch in bg_models else np.std(image[..., ch])
        thresholds[ch] = get_threshold_value(image[..., ch], bg_mean, bg_std, config, ch)

    # os.cpu_count() may return None when the count cannot be determined
    max_workers = config.get('max_workers', min(os.cpu_count() or 1, 8))
    batch_size = config.get('batch_size', max(10, total_cells // (max_workers * 2)))

    label_batches = [unique_labels[i:i + batch_size] for i in range(0, len(unique_labels), batch_size)]
    print(f"Processing {len(label_batches)} batches with up to {max_workers} workers")

    # NOTE(review): ProcessPoolExecutor pickles this partial (including image
    # and cell_masks) for every batch, so very large images make each task
    # expensive to ship.
    process_batch_func = partial(
        process_cell_batch,
        image=image,
        cell_masks=cell_masks,
        bg_models=bg_models,
        thresholds=thresholds
    )

    results = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        batch_results = tqdm(
            executor.map(process_batch_func, label_batches),
            total=len(label_batches),
            desc="Processing batches"
        )
        for batch_result in batch_results:
            results.extend(batch_result)

    total_time = time.time() - start_time
    cells_per_sec = total_cells / (total_time + 1e-6)
    print(f"Parallel CTCF measurement complete: {total_cells} cells in {total_time:.2f}s ({cells_per_sec:.2f} cells/sec)")

    return results

def process_cell_batch(cell_labels, image, cell_masks, bg_models, thresholds = None):
    """
    Measure CTCF for a batch of 2D cells (called by parallel worker).

    Parameters:
    - cell_labels: Iterable of cell labels to measure
    - image: Multi-channel image indexed as image[y, x, channel]
    - cell_masks: 2D integer label mask aligned with image's first two axes
    - bg_models: Mapping channel index -> dict with a 'mean' background value.
      Channels without an entry use a background of 0.
    - thresholds: Mapping channel index -> positivity threshold. Channels
      without a threshold (or thresholds=None) are reported as not positive.

    Returns:
    - List of dicts with 'label', 'area', 'centroid' (y, x) and per-channel
      'ctcf', 'mean', 'total', 'bg_value', 'c_positive'. Labels absent from
      the mask are skipped.
    """
    bg_models = bg_models or {}
    thresholds = thresholds or {}

    # One pass over the mask yields every label's bounding box. Each cell is
    # then measured inside its crop instead of comparing the full image once
    # per cell. Crops keep row-major pixel order, so the values are identical.
    bboxes = ndi.find_objects(cell_masks)
    full_image = (slice(None),) * cell_masks.ndim
    n_channels = image.shape[-1]

    batch_results = []
    for label in cell_labels:
        label = int(label)
        if 1 <= label <= len(bboxes):
            bbox = bboxes[label - 1]
            if bbox is None:  # label does not occur in the mask
                continue
        else:
            # Outside find_objects' range (e.g. label 0): search the full image
            bbox = full_image

        cell_mask = cell_masks[bbox] == label
        y_local, x_local = np.nonzero(cell_mask)
        if len(y_local) == 0:
            continue

        # Shift crop coordinates back to full-image coordinates
        y_indices = y_local + (bbox[0].start or 0)
        x_indices = x_local + (bbox[1].start or 0)
        area = len(y_indices)

        cell_data = {
            'label': label,
            'area': int(area),
            'centroid': (float(np.mean(y_indices)), float(np.mean(x_indices))),
            'ctcf': {},
            'mean': {},
            'total': {},
            'bg_value': {},
            'c_positive': {}
        }

        image_crop = image[bbox]
        for ch in range(n_channels):
            cell_pixels = image_crop[:, :, ch][cell_mask]

            bg_value = bg_models[ch]['mean'] if ch in bg_models else 0
            cell_data['bg_value'][ch] = float(bg_value)

            total_intensity = np.sum(cell_pixels)
            mean_intensity = np.mean(cell_pixels)
            # CTCF = integrated density - area * mean background
            ctcf = total_intensity - (area * bg_value)

            cell_data['total'][ch] = float(total_intensity)
            cell_data['mean'][ch] = float(mean_intensity)
            cell_data['ctcf'][ch] = float(ctcf)

            # Matches the 3D worker: no threshold for a channel -> not positive
            if ch in thresholds:
                cell_data['c_positive'][ch] = bool(mean_intensity > thresholds[ch])
            else:
                cell_data['c_positive'][ch] = False

        batch_results.append(cell_data)

    return batch_results

def measure_cells_ctcf_gpu(image, cell_masks, bg_models, config=None, debug=False):
    """
    GPU-accelerated CTCF measurement with integrated positive cell counting (2D).

    Parameters:
    - image: Multi-channel image indexed as image[y, x, channel]
    - cell_masks: 2D integer label mask (0 = background)
    - bg_models: Mapping channel index -> dict with 'mean' and 'std'
    - config: Configuration dictionary ('gpu_batch_size', threshold settings);
      None means all defaults
    - debug: Unused; kept for interface compatibility

    Returns:
    - List of per-cell measurement dicts (same layout as the CPU path). An
      empty list is returned when the mask has no cells. A torch RuntimeError
      (e.g. GPU out of memory) falls back to ``measure_cells_ctcf_cpu``.
    """
    if config is None:
        config = {}
    if bg_models is None:
        bg_models = {}

    start_time = time.time()

    # Nothing to measure: skip the GPU transfer. An empty mask used to produce
    # a zero batch size, which made range() raise ValueError.
    if not np.any(cell_masks > 0):
        print("No cells found in mask.")
        return []

    device = torch.device('cuda')
    try:
        # Convert dtypes on the host: older PyTorch releases cannot wrap uint16
        # NumPy arrays (typical for microscopy images and label masks).
        image_tensor = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32)).to(device)
        mask_tensor = torch.from_numpy(np.ascontiguousarray(cell_masks, dtype=np.int32)).to(device)

        # Cell IDs without assuming label 0 is present
        cell_ids = torch.unique(mask_tensor)
        cell_ids = cell_ids[cell_ids > 0]
        total_cells = len(cell_ids)
        n_channels = image_tensor.shape[-1]
        print(f"Measuring CTCF for {total_cells} cells across {image.shape[-1]} channels on GPU...")

        # Pre-compute thresholds for positive cell counting
        thresholds = {}
        for ch in range(n_channels):
            bg_mean = bg_models[ch]['mean'] if ch in bg_models else 0
            bg_std = bg_models[ch]['std'] if ch in bg_models else torch.std(image_tensor[..., ch]).item()
            thresholds[ch] = get_threshold_value(image_tensor[..., ch], bg_mean, bg_std, config, ch)

        measurements = []
        batch_size = config.get('gpu_batch_size', min(500, total_cells))

        for batch_start in tqdm(range(0, total_cells, batch_size), desc="Processing cell batches"):
            for cell_id in cell_ids[batch_start:batch_start + batch_size].tolist():
                with torch.no_grad():
                    cell_mask = mask_tensor == cell_id
                    cell_area = torch.sum(cell_mask).item()
                    if cell_area == 0:
                        continue

                    y_indices, x_indices = torch.where(cell_mask)
                    centroid_y = torch.mean(y_indices.float()).item()
                    centroid_x = torch.mean(x_indices.float()).item()

                cell_data = {
                    'label': int(cell_id),
                    'area': int(cell_area),
                    'centroid': (float(centroid_y), float(centroid_x)),
                    'ctcf': {},
                    'mean': {},
                    'total': {},
                    'bg_value': {},
                    'c_positive': {}
                }

                for ch in range(n_channels):
                    with torch.no_grad():
                        channel_data = image_tensor[:, :, ch]
                        bg_value = bg_models[ch]['mean'] if ch in bg_models else 0
                        cell_data['bg_value'][ch] = float(bg_value)

                        cell_pixels = torch.masked_select(channel_data, cell_mask)
                        total_intensity = torch.sum(cell_pixels).item()
                        mean_intensity = torch.mean(cell_pixels).item() if cell_pixels.numel() > 0 else 0
                        ctcf = total_intensity - (cell_area * bg_value)

                        cell_data['total'][ch] = float(total_intensity)
                        cell_data['mean'][ch] = float(mean_intensity)
                        cell_data['ctcf'][ch] = float(ctcf)
                        cell_data['c_positive'][ch] = bool(mean_intensity > thresholds[ch])

                measurements.append(cell_data)

            # Free cached GPU memory after each batch
            torch.cuda.empty_cache()

        total_time = time.time() - start_time
        cells_per_sec = total_cells / (total_time + 1e-6)
        print(f"GPU CTCF measurement complete: {total_cells} cells in {total_time:.2f}s ({cells_per_sec:.2f} cells/sec)")

        return measurements

    except RuntimeError as e:
        print(f"GPU memory error: {str(e)}. Falling back to CPU.")
        traceback.print_exc()
        return measure_cells_ctcf_cpu(image, cell_masks, bg_models, config)
    

def measure_cells_ctcf_3d_cpu(image, cell_masks, bg_models, config=None):
    """
    Measure CTCF for all cells in 3D z-stack data using parallel CPU processing.

    Parameters:
    - image: 3D multi-channel image array (Z, Y, X, C)
    - cell_masks: 3D integer mask with cell labels (Z, Y, X)
    - bg_models: Mapping channel index -> dict with 'mean' and 'std'. A single
      background value per channel is applied to the whole stack.
    - config: Configuration dictionary ('max_workers', 'batch_size',
      threshold settings); None means all defaults

    Returns:
    - List of cell measurements (see process_cell_batch_3d); an empty list when
      the mask has no cells.
    """
    if config is None:
        config = {}
    if bg_models is None:
        bg_models = {}

    start_time = time.time()

    unique_labels = np.unique(cell_masks)
    unique_labels = unique_labels[unique_labels > 0]  # Skip background (0)
    total_cells = len(unique_labels)

    if total_cells == 0:
        print("No cells found in 3D mask.")
        return []

    print(f"Measuring 3D CTCF for {total_cells} cells using parallel CPU processing...")

    # Pre-compute thresholds for positive cell counting
    thresholds = {}
    for ch in range(image.shape[-1]):
        bg_mean = bg_models[ch]['mean'] if ch in bg_models else 0
        bg_std = bg_models[ch]['std'] if ch in bg_models else np.std(image[..., ch])
        thresholds[ch] = get_threshold_value(image[..., ch], bg_mean, bg_std, config, ch)

    # os.cpu_count() may return None when the count cannot be determined
    max_workers = config.get('max_workers', min(os.cpu_count() or 1, 8))
    batch_size = config.get('batch_size', max(5, total_cells // (max_workers * 2)))

    label_batches = [unique_labels[i:i + batch_size] for i in range(0, len(unique_labels), batch_size)]
    print(f"Processing {len(label_batches)} batches with up to {max_workers} workers")

    # NOTE(review): ProcessPoolExecutor pickles this partial (including the
    # full image stack and masks) for every batch; costly for large stacks.
    process_batch_func = partial(
        process_cell_batch_3d,
        image=image,
        cell_masks=cell_masks,
        bg_models=bg_models,
        thresholds=thresholds,
        config=config
    )

    results = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        batch_results = tqdm(
            executor.map(process_batch_func, label_batches),
            total=len(label_batches),
            desc="Processing 3D cell batches"
        )
        for batch_result in batch_results:
            results.extend(batch_result)

    total_time = time.time() - start_time
    cells_per_sec = total_cells / (total_time + 1e-6)
    print(f"3D CTCF measurement complete: {total_cells} cells in {total_time:.2f}s ({cells_per_sec:.2f} cells/sec)")

    return results

def process_cell_batch_3d(cell_labels, image, cell_masks, bg_models, thresholds=None, config=None):
    """
    Measure CTCF for a batch of 3D cells (called by parallel worker).

    Parameters:
    - cell_labels: Iterable of cell labels to measure
    - image: 4D image indexed as image[z, y, x, channel]
    - cell_masks: 3D integer label mask aligned with image's first three axes
    - bg_models: Mapping channel index -> dict with a 'mean' background value.
      Channels without an entry use 0. The same value is used for all slices.
    - thresholds: Mapping channel index -> positivity threshold. Channels
      without one are reported as not positive.
    - config: Accepted for interface compatibility; not used here

    Returns:
    - List of dicts with 'label', 'area' (= voxel count), 'volume',
      'centroid' (y, x), 'centroid_3d' (z, y, x), 'z_range' (min_z, max_z)
      and per-channel 'ctcf', 'mean', 'total', 'bg_value', 'c_positive'.
      Labels absent from the mask are skipped.
    """
    # Unpacking also validates that the image is 4D (Z, Y, X, C)
    _, _, _, n_channels = image.shape
    bg_models = bg_models or {}
    thresholds = thresholds or {}

    # One pass over the volume yields every label's bounding box. Each cell is
    # then measured inside its crop instead of comparing the full stack per
    # cell. Crops keep C-order voxel order, so the values are identical.
    bboxes = ndi.find_objects(cell_masks)
    full_volume = (slice(None),) * cell_masks.ndim

    batch_results = []
    for label in cell_labels:
        label = int(label)
        if 1 <= label <= len(bboxes):
            bbox = bboxes[label - 1]
            if bbox is None:  # label does not occur in the mask
                continue
        else:
            # Outside find_objects' range (e.g. label 0): search the full stack
            bbox = full_volume

        cell_mask = cell_masks[bbox] == label
        z_local, y_local, x_local = np.nonzero(cell_mask)
        if len(z_local) == 0:
            continue

        # Shift crop coordinates back to full-stack coordinates
        z_indices = z_local + (bbox[0].start or 0)
        y_indices = y_local + (bbox[1].start or 0)
        x_indices = x_local + (bbox[2].start or 0)

        volume = len(z_indices)  # total voxels
        z_centroid = np.mean(z_indices)
        y_centroid = np.mean(y_indices)
        x_centroid = np.mean(x_indices)

        cell_data = {
            'label': label,
            'area': int(volume),  # volume stored as area for 2D compatibility
            'volume': int(volume),
            'centroid': (float(y_centroid), float(x_centroid)),  # 2D-style centroid
            'centroid_3d': (float(z_centroid), float(y_centroid), float(x_centroid)),
            'z_range': (int(np.min(z_indices)), int(np.max(z_indices))),
            'ctcf': {},
            'mean': {},
            'total': {},
            'bg_value': {},
            'c_positive': {}
        }

        image_crop = image[bbox]
        for ch in range(n_channels):
            cell_voxels = image_crop[:, :, :, ch][cell_mask]

            bg_value = bg_models[ch]['mean'] if ch in bg_models else 0
            cell_data['bg_value'][ch] = float(bg_value)

            total_intensity = np.sum(cell_voxels)
            mean_intensity = np.mean(cell_voxels) if len(cell_voxels) > 0 else 0
            # Background correction scales with the number of voxels
            ctcf = total_intensity - (volume * bg_value)

            cell_data['total'][ch] = float(total_intensity)
            cell_data['mean'][ch] = float(mean_intensity)
            cell_data['ctcf'][ch] = float(ctcf)

            if ch in thresholds:
                cell_data['c_positive'][ch] = bool(mean_intensity > thresholds[ch])
            else:
                # Fallback if no threshold provided for this channel
                cell_data['c_positive'][ch] = False

        batch_results.append(cell_data)

    return batch_results

def measure_cells_ctcf_3d_gpu(image, cell_masks, bg_models, config=None, debug=False):
    """
    GPU-accelerated 3D CTCF measurement with integrated positive cell counting.

    Parameters:
    - image: 3D multi-channel image array (Z, Y, X, C)
    - cell_masks: 3D integer mask with cell labels (Z, Y, X)
    - bg_models: Mapping channel index -> dict with 'mean' and 'std'
    - config: Configuration dictionary ('gpu_batch_size_3d', threshold
      settings); None means all defaults
    - debug: Unused; kept for interface compatibility

    Returns:
    - List of per-cell measurement dicts (same layout as the 3D CPU path). An
      empty list is returned when the mask has no cells. A torch RuntimeError
      (e.g. GPU out of memory) falls back to ``measure_cells_ctcf_3d_cpu``.

    Raises:
    - ValueError: if ``image`` is not 4D
    """
    if config is None:
        config = {}
    if bg_models is None:
        bg_models = {}

    start_time = time.time()

    if len(image.shape) != 4:
        raise ValueError("Expected 4D input for 3D data (Z, Y, X, C)")
    channels = image.shape[3]

    # Nothing to measure: skip the GPU transfer. An empty mask used to produce
    # a zero batch size, which made range() raise ValueError.
    if not np.any(cell_masks > 0):
        print("No cells found in 3D mask.")
        return []

    device = torch.device('cuda')
    try:
        # Convert dtypes on the host: older PyTorch releases cannot wrap uint16
        # NumPy arrays (typical for microscopy stacks and label masks).
        image_tensor = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32)).to(device)
        mask_tensor = torch.from_numpy(np.ascontiguousarray(cell_masks, dtype=np.int32)).to(device)

        # Cell IDs without assuming label 0 is present
        cell_ids = torch.unique(mask_tensor)
        cell_ids = cell_ids[cell_ids > 0]
        total_cells = len(cell_ids)
        print(f"Measuring 3D CTCF for {total_cells} cells across {channels} channels on GPU...")

        # Pre-compute thresholds for positive cell counting
        thresholds = {}
        for ch in range(channels):
            bg_mean = bg_models[ch]['mean'] if ch in bg_models else 0
            bg_std = bg_models[ch]['std'] if ch in bg_models else torch.std(image_tensor[..., ch]).item()
            thresholds[ch] = get_threshold_value(image_tensor[..., ch], bg_mean, bg_std, config, ch)

        # Smaller batches for 3D data to limit GPU memory pressure
        batch_size = config.get('gpu_batch_size_3d', min(100, total_cells))
        measurements = []

        for batch_start in tqdm(range(0, total_cells, batch_size), desc="Processing 3D cell batches"):
            for cell_id in cell_ids[batch_start:batch_start + batch_size].tolist():
                with torch.no_grad():
                    cell_mask = mask_tensor == cell_id
                    cell_volume = torch.sum(cell_mask).item()
                    if cell_volume == 0:
                        continue

                    z_indices, y_indices, x_indices = torch.where(cell_mask)
                    centroid_z = torch.mean(z_indices.float()).item()
                    centroid_y = torch.mean(y_indices.float()).item()
                    centroid_x = torch.mean(x_indices.float()).item()
                    min_z = torch.min(z_indices).item()
                    max_z = torch.max(z_indices).item()

                cell_data = {
                    'label': int(cell_id),
                    'area': int(cell_volume),  # volume stored as area for 2D compatibility
                    'volume': int(cell_volume),
                    'centroid': (float(centroid_y), float(centroid_x)),  # 2D-style centroid
                    'centroid_3d': (float(centroid_z), float(centroid_y), float(centroid_x)),
                    'z_range': (int(min_z), int(max_z)),
                    'ctcf': {},
                    'mean': {},
                    'total': {},
                    'bg_value': {},
                    'c_positive': {}
                }

                for ch in range(channels):
                    with torch.no_grad():
                        channel_data = image_tensor[:, :, :, ch]
                        bg_value = bg_models[ch]['mean'] if ch in bg_models else 0
                        cell_data['bg_value'][ch] = float(bg_value)

                        cell_voxels = torch.masked_select(channel_data, cell_mask)
                        total_intensity = torch.sum(cell_voxels).item()
                        mean_intensity = torch.mean(cell_voxels).item() if cell_voxels.numel() > 0 else 0
                        ctcf = total_intensity - (cell_volume * bg_value)

                        cell_data['total'][ch] = float(total_intensity)
                        cell_data['mean'][ch] = float(mean_intensity)
                        cell_data['ctcf'][ch] = float(ctcf)
                        cell_data['c_positive'][ch] = bool(mean_intensity > thresholds[ch])

                measurements.append(cell_data)

            # Free cached GPU memory after each batch
            torch.cuda.empty_cache()

        total_time = time.time() - start_time
        cells_per_sec = total_cells / (total_time + 1e-6)
        print(f"GPU 3D CTCF measurement complete: {total_cells} cells in {total_time:.2f}s ({cells_per_sec:.2f} cells/sec)")

        return measurements

    except RuntimeError as e:
        print(f"GPU memory error for 3D data: {str(e)}. Falling back to CPU.")
        traceback.print_exc()
        return measure_cells_ctcf_3d_cpu(image, cell_masks, bg_models, config)

### Folder-Wise Processing

In [None]:
def prepare_image_for_cellpose(image, config, bg_models):
    """
    Build the CellPose input image for the configured model, with background subtraction.

    The per-channel background mean ``bg_models[ch]['mean']`` (when present for a
    channel) is subtracted and negative values are clipped to 0.

    Supported ``config['cellpose_model']`` values:
    - 'nuclei_only': returns only the nucleus channel (no channel axis).
    - 'cyto3': returns a 2-channel stack [nucleus, cytoplasm] on the last axis.
    - 'cpsam' (default when the key is missing, matching the model that
      ``process_experiment_folder`` initializes): returns the channels listed in
      ``config['segmentation_channels']`` (0-based; default: all channels),
      stacked on the last axis in that order.

    Parameters:
    - image: array with channels on the last axis
    - config: configuration dictionary ('nucleus_channel' and 'cytoplasm_channel' are 1-based)
    - bg_models: dict mapping 0-based channel index -> model dict containing 'mean'

    Returns:
    - The prepared image. Multi-channel outputs keep the input dtype (background-
      subtracted values are cast back into it); the single-channel 'nuclei_only'
      output is the raw result of the subtraction (typically floating point).
      The input ``image`` is never modified.

    Raises:
    - ValueError: if 'cellpose_model' names an unsupported model.
    """
    model_name = config.get('cellpose_model', 'cpsam')

    if model_name == 'nuclei_only':
        nuc_channel_idx = config.get('nucleus_channel', 1) - 1
        img_to_segment = image[..., nuc_channel_idx].copy()

        # Apply background subtraction if available
        if nuc_channel_idx in bg_models:
            ch_mean = bg_models[nuc_channel_idx]['mean']
            img_to_segment = np.clip(img_to_segment - ch_mean, 0, None)

        return img_to_segment

    if model_name == 'cyto3':
        # Nucleus first, cytoplasm second (this order is relied on downstream)
        source_channels = [
            config.get('nucleus_channel', 1) - 1,
            config.get('cytoplasm_channel', 4) - 1,
        ]
    elif model_name == 'cpsam':
        # Use the configured segmentation channels (all channels by default)
        source_channels = config.get('segmentation_channels', list(range(image.shape[-1])))
    else:
        raise ValueError(f"Unsupported CellPose model: {model_name}. "
                         "Please specify 'nuclei_only', 'cyto3', or 'cpsam'.")

    # np.stack allocates a new array, so the caller's image is left untouched
    img_to_segment = np.stack([image[..., ch] for ch in source_channels], axis=-1)

    # Apply background subtraction per output channel where a model exists
    for out_idx, ch in enumerate(source_channels):
        if ch in bg_models:
            ch_mean = bg_models[ch]['mean']
            img_to_segment[..., out_idx] = np.clip(img_to_segment[..., out_idx] - ch_mean, 0, None)

    return img_to_segment
    
def process_experiment_folder(experiment_folder, config, results_dir=None):
    """
    Segment every TIF/TIFF image in one experiment folder and measure per-cell CTCF.

    Images are processed in batches of ``config['batch_size']`` (default 4). For each batch:
      1. images are loaded and the shortest axis (assumed to be channels) is moved last,
      2. per-channel background models are fitted (``estimate_background_gmm``) and the
         CellPose input is built (``prepare_image_for_cellpose``), optionally downsampled,
      3. cells are segmented with CellPose, nuclei with the configured method, and the
         cell masks are either size-filtered ('quick_size_filtering') or aligned to nuclei,
      4. cells are measured (``measure_cells``) and per-image CSVs, visualizations,
         QC crops and mask TIFFs are written to ``results_dir``.
    A per-image summary table is written to ``experiment_results.csv``.

    If segmentation of a batch fails, every image in that batch gets empty masks,
    so it still produces (empty) outputs instead of aborting the whole folder.

    Parameters:
    - experiment_folder: Path to folder containing TIF/TIFF files
    - config: Configuration dictionary
    - results_dir: Optional custom results directory (if None, a timestamped
      '<folder>_Analysis_<timestamp>' directory is created inside experiment_folder)

    Returns:
    - Path to results directory
    """
    # Create results directory if not provided
    if results_dir is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # normpath strips a trailing separator so basename is not empty
        exp_name = os.path.basename(os.path.normpath(experiment_folder))
        results_dir = os.path.join(experiment_folder, f"{exp_name}_Analysis_{timestamp}")

    os.makedirs(results_dir, exist_ok=True)

    # Get all TIFF images in experiment folder (no CZI handling). The extension
    # match is case-insensitive so '.TIF' exports are included; sorting makes the
    # processing order (and batch composition) deterministic.
    image_files = sorted(
        f for f in os.listdir(experiment_folder)
        if f.lower().endswith(('.tif', '.tiff')) and not f.startswith('.')
    )

    if not image_files:
        print(f"No TIF/TIFF files found in {experiment_folder}")
        print("Note: If you have CZI files, please extract them to TIFF first using external tools")
        return results_dir

    print(f"Found {len(image_files)} TIFF files to process in {experiment_folder}")

    # Create list of image paths
    all_image_paths = [os.path.join(experiment_folder, f) for f in image_files]

    print(f"Total images to process: {len(all_image_paths)}")

    cp_model = config.get('cellpose_model', 'cpsam')  # Default to 'cpsam' if not specified
    print(f"Initializing CellPose model {cp_model}...")
    cellpose_model = models.CellposeModel(gpu=config.get('use_gpu', True), pretrained_model=cp_model)

    # Extract configuration
    is_3d = config.get('use_3d', False)
    spatial_ndim = 3 if is_3d else 2
    # Quick size filtering replaces the nuclei-alignment step when enabled
    b_quick_filtering = config.get('quick_size_filtering', False)
    # Nuclei segmentation method: 'background_mask', 'threshold' or 'cellpose'
    nuclei_method = config.get('nuclei_segmentation_method', 'background_mask')
    print(f"Using {nuclei_method} method for nuclei segmentation...(CellPose will be used for cells)")
    downsample_factor = config.get('downsample_factor', 1.0)
    is_downsampled = downsample_factor < 1.0

    # Process in batches
    batch_size = config.get('batch_size', 4)
    n_batches = (len(all_image_paths) + batch_size - 1) // batch_size
    all_results = []

    for batch_idx in range(0, len(all_image_paths), batch_size):
        batch_paths = all_image_paths[batch_idx:batch_idx + batch_size]

        print(f"Processing batch {batch_idx//batch_size + 1}/{n_batches} "
              f"({len(batch_paths)} images)")

        # Force cleanup before each batch
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # 1. LOAD BATCH IMAGES
        loaded_images = []
        for path in tqdm(batch_paths, desc="Loading images", leave=False):
            try:
                image = tiff.imread(path)
                img_name = os.path.basename(path)

                # Move the shortest axis (assumed to be channels) to the last index
                if is_3d:
                    # For 3D, image shape should end up as (Z, Y, X, C)
                    dims = list(image.shape)
                    if min(dims) < 5:  # Likely channel dimension if small
                        channel_axis = dims.index(min(dims))
                        image = np.moveaxis(image, channel_axis, -1)
                    elif len(image.shape) == 3:
                        # Assume standard ordering Z,Y,X and add a channel dimension
                        image = image[..., np.newaxis]
                else:
                    # Standard 2D case
                    shortest_axis = np.argmin(image.shape)
                    image = np.moveaxis(image, shortest_axis, -1)

                loaded_images.append({
                    'path': path,
                    'name': img_name,
                    'image': image
                })
            except Exception as e:
                print(f"Error loading {os.path.basename(path)}: {str(e)}")

        if not loaded_images:
            continue

        # For very large 3D volumes, use more aggressive garbage collection
        if is_3d and config.get('aggressive_gc', False):
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # 2. PREPARE FOR BATCH PROCESSING
        cellpose_inputs = []
        bg_models_by_image = {}

        # Process background sequentially for each image
        for img_data in loaded_images:
            path = img_data['path']
            image = img_data['image']
            img_name = img_data['name']
            img_base = os.path.splitext(img_name)[0]

            # Calculate background using GMM fitting
            bg_models = estimate_background_gmm(image, config, n_components=2, sample_ratio=0.05,
                                                max_iter=100, max_components=6)

            if config.get('visualize_bg', True):
                for ch_idx in range(image.shape[-1]):
                    visualize_background_mask(
                        image[..., ch_idx],
                        bg_models[ch_idx],
                        os.path.join(results_dir, f"{img_base}_bg_mask_ch{ch_idx+1}.png")
                    )

            # Store background models
            bg_models_by_image[path] = bg_models

            # Prepare image for CellPose
            prepared_img = prepare_image_for_cellpose(image, config, bg_models)

            # Determine where the nuclei channel ended up in the prepared image
            nuclei_channel_idx_original = config.get('nucleus_channel', 1) - 1
            if cp_model in ('nuclei_only', 'cyto3'):
                # nuclei_only: single channel; cyto3: nuclei is stacked first
                nuclei_channel_idx_prepared = 0
            else:
                # For custom segmentation_channels, find where the original nuclei channel ended up
                segmentation_channels = config.get('segmentation_channels', list(range(image.shape[-1])))
                try:
                    nuclei_channel_idx_prepared = segmentation_channels.index(nuclei_channel_idx_original)
                except ValueError:
                    # Nuclei channel not in segmentation channels, use first channel as fallback
                    print(f"Warning: Nuclei channel {nuclei_channel_idx_original+1} not in segmentation channels. Using first channel.")
                    nuclei_channel_idx_prepared = 0

            # Spatial shape of the full-resolution image, used to upscale masks back
            orig_shape = image.shape[:3] if is_3d else image.shape[:2]

            # Apply downsampling if needed
            if is_downsampled:
                prepared_img = resample_image(prepared_img, downsample_factor)

            cellpose_inputs.append({
                'image': prepared_img,
                'orig_path': path,
                'orig_shape': orig_shape,
                'nuclei_channel_idx': nuclei_channel_idx_prepared  # Index within the prepared image
            })

        # 3. BATCH SEGMENTATION WITH CELLPOSE
        print("Running CellPose segmentation in batch mode...")
        # All mask dicts are reset per batch so a failed batch can never pick up
        # masks left over from the previous one.
        cell_masks = {}
        nuclei_masks = {}
        cell_masks_orig = {}  # Cell masks before size filtering / nuclei alignment

        # Note: Cellpose expects images in (Y,X,C) format for 2D or (Z,Y,X,C) for 3D
        batch_images = [item['image'] for item in cellpose_inputs]
        batch_paths = [item['orig_path'] for item in cellpose_inputs]

        # Extract nuclei images using the correct channel indices
        batch_images_nuclei = []
        if not b_quick_filtering:
            for item in cellpose_inputs:
                prepared = item['image']
                nuclei_idx = item['nuclei_channel_idx']
                if prepared.ndim == spatial_ndim:
                    # Single-channel input (e.g. 'nuclei_only'): there is no channel
                    # axis, so indexing [..., idx] would select a pixel column instead
                    batch_images_nuclei.append(prepared)
                elif nuclei_idx < prepared.shape[-1]:
                    batch_images_nuclei.append(prepared[..., nuclei_idx])
                else:
                    # Fallback to first channel if index is out of range
                    print(f"Warning: Nuclei channel index {nuclei_idx} out of range, using channel 0")
                    batch_images_nuclei.append(prepared[..., 0])

        try:
            # Run batch segmentation for cells
            print(f"Segmenting cells... with channels {config.get('segmentation_channels', 'all')}")
            masks, _, _ = cellpose_model.eval(
                batch_images,
                normalize=True,
                do_3D=is_3d,
                flow_threshold=config.get('flow_threshold', 0.4),
                cellprob_threshold=config.get('cellprob_threshold', 0.0),
                anisotropy=config.get('anisotropy', 3.0),
            )

            if b_quick_filtering:
                # No nuclei segmentation: empty nuclei masks keep the outputs uniform
                masks_nuclei = [np.zeros_like(masks[i]) for i in range(len(batch_paths))]
            elif nuclei_method == 'background_mask':
                print("Segmenting nuclei using background masks inversion...")
                # Background models of the nuclei channel, in loaded_images order
                nuclei_channel_idx = config.get('nucleus_channel', 1) - 1
                bg_models_nuclei = [
                    bg_models_by_image[img_data['path']][nuclei_channel_idx]
                    for img_data in loaded_images
                ]
                masks_nuclei = segment_nuclei_from_background_mask(bg_models_nuclei, config)
            elif nuclei_method == 'threshold':
                print("Segmenting nuclei using thresholding method...")
                masks_nuclei = segment_nuclei_threshold(batch_images_nuclei, config)
            else:
                print(f"Segmenting nuclei... with index {cellpose_inputs[0]['nuclei_channel_idx']}")
                masks_nuclei, _, _ = cellpose_model.eval(
                    batch_images_nuclei,
                    normalize=True,
                    do_3D=is_3d,
                    flow_threshold=config.get('flow_threshold', 0.4),
                    cellprob_threshold=config.get('cellprob_threshold', 0.0),
                    anisotropy=config.get('anisotropy', 3.0),
                )

            # Map masks back to their images, upscaling if the inputs were downsampled
            orig_shapes = {item['orig_path']: item['orig_shape'] for item in cellpose_inputs}
            for i, path in enumerate(batch_paths):
                orig_shape = orig_shapes[path]
                if is_downsampled:
                    # order=0 (nearest neighbour) keeps integer labels intact
                    cell_masks[path] = resize(masks[i], orig_shape,
                                              order=0, preserve_range=True).astype(np.int32)
                    nuclei_masks[path] = resize(masks_nuclei[i], orig_shape,
                                                order=0, preserve_range=True).astype(np.int32)
                else:
                    cell_masks[path] = masks[i]
                    nuclei_masks[path] = masks_nuclei[i]

            # Clean up memory
            del masks, masks_nuclei
            gc.collect()

            if b_quick_filtering:
                print("Using quick size filtering instead of nuclei alignment...")
                # Apply size filtering to cell masks
                filtered_masks = {}
                for path in tqdm(batch_paths, desc="Filtering masks by size"):
                    if path in cell_masks:
                        filtered_masks[path] = filter_masks_by_size(
                            cell_masks[path],
                            config=config
                        )
                    else:
                        filtered_masks[path] = np.zeros(orig_shapes[path], dtype=np.int32)

                # Keep the original cell masks and use filtered masks for measurements
                cell_masks_orig = cell_masks.copy()
                cell_masks = filtered_masks

            else:
                # Align cell masks with nuclei masks to create combined segmentation
                print("Aligning cell masks with nuclei masks (nucleus-based approach)...")
                aligned_masks = {}

                for path in tqdm(batch_paths, desc="Aligning masks"):
                    if path in cell_masks and path in nuclei_masks:
                        aligned_masks[path] = align_cell_masks_to_nuclei(
                            nuclei_masks[path],
                            cell_masks[path],
                            is_3d=is_3d
                        )
                    else:
                        # If one of the masks is missing, use nuclei masks as foundation
                        aligned_masks[path] = nuclei_masks.get(
                            path, np.zeros(orig_shapes[path], dtype=np.int32))

                # Keep the original cell masks and use the aligned masks for measurements
                cell_masks_orig = cell_masks.copy()
                cell_masks = aligned_masks

        except Exception as e:
            print(f"Batch segmentation error: {str(e)}")
            traceback.print_exc()

            # Create empty masks for failures. Nuclei and original cell masks are
            # filled too, so the measurement step below can report (empty) results
            # for every image instead of failing on a missing key.
            for input_data in cellpose_inputs:
                path = input_data['orig_path']
                empty_mask = np.zeros(input_data['orig_shape'], dtype=np.int32)
                cell_masks[path] = empty_mask
                nuclei_masks[path] = empty_mask.copy()
                cell_masks_orig[path] = empty_mask.copy()

        print("Batch segmentation complete!")

        # 4. CTCF MEASUREMENT AND RESULT GENERATION
        batch_results = []

        for img_data in loaded_images:
            path = img_data['path']
            img_name = img_data['name']
            img_base = os.path.splitext(img_name)[0]

            if path not in cell_masks:
                batch_results.append({
                    'image_name': img_name,
                    'error': "No mask generated",
                    'total_cells': 0
                })
                continue

            try:
                # Get masks and background models
                masks = cell_masks[path]
                masks_nuclei = nuclei_masks[path]
                masks_cells_orig = cell_masks_orig[path]  # Cell masks before alignment/filtering
                bg_models = bg_models_by_image.get(path, {})

                # Count original objects before alignment (label 0 is background)
                original_nuclei_count = len(np.unique(masks_nuclei)) - 1
                original_cells_count = len(np.unique(masks_cells_orig)) - 1

                # Measure cells with GPU or CPU
                cell_measurements = measure_cells(
                    img_data['image'], masks, bg_models, config
                )

                # Define channels of interest from config
                n_channels = img_data['image'].shape[-1]
                channels_of_interest = config.get('channels_of_interest', list(range(n_channels)))

                # Save cell measurements to CSV with additional metadata
                cell_df = pd.DataFrame([
                    {
                        'image': img_name,
                        'cell_id': i,
                        'area': cell['area'] if 'area' in cell else cell.get('volume', 0),
                        **{f'channel_{ch+1}_ctcf': cell['ctcf'][ch] for ch in channels_of_interest},
                        **{f'channel_{ch+1}_mean': cell['mean'][ch] for ch in channels_of_interest},
                        'centroid_x': cell['centroid'][1] if 'centroid' in cell else cell.get('centroid_3d', [0, 0, 0])[2],
                        'centroid_y': cell['centroid'][0] if 'centroid' in cell else cell.get('centroid_3d', [0, 0, 0])[1],
                        # One positivity flag per channel (a constant key kept only the last channel)
                        **{f'channel_{ch+1}_positive': cell['c_positive'][ch] for ch in channels_of_interest},
                    }
                    for i, cell in enumerate(cell_measurements)
                ])
                cell_df.to_csv(os.path.join(results_dir, f"{img_base}_cells.csv"), index=False)

                # Create visualization if configured
                if config.get('visualize_segmentation', True):
                    create_visualization(
                        img_data['image'],
                        masks,
                        cell_measurements,
                        os.path.join(results_dir, f"{img_name}_analysis.png"),
                        debug=config.get('debug', False)
                    )

                # Create QC region images if configured
                if config.get('save_qc_regions', True):
                    save_segmentation_qc_images(
                        img_data['image'],
                        masks,
                        results_dir,
                        img_name,
                        config
                    )

                # Create nuclei visualization if configured
                if config.get('visualize_nuclei', True):
                    create_visualization(
                        img_data['image'],
                        masks_nuclei,
                        cell_measurements,
                        os.path.join(results_dir, f"{img_name}_nuclei.png"),
                        debug=config.get('debug', False)
                    )

                # Create nuclei QC region images if configured
                if config.get('save_nuclei_qc_regions', True):
                    nuclei_qc_dir = os.path.join(results_dir, "nuclei_qc")
                    os.makedirs(nuclei_qc_dir, exist_ok=True)
                    save_segmentation_qc_images(
                        img_data['image'],
                        masks_nuclei,
                        nuclei_qc_dir,
                        img_name,
                        config,
                    )

                if config.get('save_cellpose_masks', True):
                    # Save cellpose masks
                    mask_cells_path = os.path.join(results_dir, f"{img_base}_cell_masks.tif")
                    mask_nuclei_path = os.path.join(results_dir, f"{img_base}_nuclei_masks.tif")
                    mask_cells_orig_path = os.path.join(results_dir, f"{img_base}_cell_masks_orig.tif")

                    save_mask_as_tiff(masks, mask_cells_path)
                    save_mask_as_tiff(masks_nuclei, mask_nuclei_path)
                    save_mask_as_tiff(masks_cells_orig, mask_cells_orig_path)

                # Summarize results
                summary = {
                    'image_name': img_name,
                    'total_cells_aligned': len(cell_measurements),
                    'original_nuclei_detected': original_nuclei_count,
                    'original_cells_detected': original_cells_count,
                    'nuclei_cell_ratio': original_cells_count / original_nuclei_count if original_nuclei_count > 0 else 0,
                }

                # Add channel statistics with outlier robustness: values are clipped
                # (not dropped) to the 1st-99th percentile range
                for ch in channels_of_interest:
                    if ch >= n_channels:
                        continue
                    ch_ctcf = [cell['ctcf'][ch] for cell in cell_measurements]
                    if ch_ctcf:
                        p1, p99 = np.percentile(ch_ctcf, [1, 99])
                        trimmed_ctcf = np.clip(ch_ctcf, p1, p99)

                        summary[f'channel_{ch+1}_n_pos'] = np.sum([cell['c_positive'][ch] for cell in cell_measurements])
                        summary[f'channel_{ch+1}_mean_ctcf'] = np.mean(trimmed_ctcf)
                        summary[f'channel_{ch+1}_median_ctcf'] = np.median(trimmed_ctcf)
                        summary[f'channel_{ch+1}_std_ctcf'] = np.std(trimmed_ctcf)
                    else:
                        summary[f'channel_{ch+1}_n_pos'] = 0
                        summary[f'channel_{ch+1}_mean_ctcf'] = 0
                        summary[f'channel_{ch+1}_median_ctcf'] = 0
                        summary[f'channel_{ch+1}_std_ctcf'] = 0

                batch_results.append(summary)

                # Clean up memory
                del cell_measurements, cell_df
                gc.collect()

            except Exception as e:
                print(f"Error processing {img_name}: {str(e)}")
                batch_results.append({
                    'image_name': img_name,
                    'error': str(e),
                    'total_cells': 0
                })

        # Append batch results to all results
        all_results.extend(batch_results)

        # Clean up batch memory
        del (loaded_images, cellpose_inputs, cell_masks, nuclei_masks,
             cell_masks_orig, bg_models_by_image, batch_results)
        gc.collect()
        if config.get('use_gpu', True):
            torch.cuda.empty_cache()

    # Save experiment results
    if all_results:
        exp_df = pd.DataFrame(all_results)
        exp_df.to_csv(os.path.join(results_dir, "experiment_results.csv"), index=False)

    print(f"Experiment processing complete. Results saved to: {results_dir}")
    return results_dir

def process_experiments_optimized(main_directory, config):
    """
    Process experiments in a main directory or a single experiment folder.

    If ``main_directory`` directly contains TIF/TIFF files (extension matched
    case-insensitively), it is treated as one experiment and handed to
    ``process_experiment_folder``, which creates its own results folder.
    Otherwise every visible sub-folder (hidden folders and previous
    'CTCF_Analysis' output folders are skipped) is processed as a separate
    experiment, in sorted order, with results written to
    '<main_directory>/CTCF_Analysis_<timestamp>/<experiment name>'. An error in one
    experiment is printed and the remaining experiments still run.

    Parameters:
    - main_directory: Path to main directory with multiple experiment folders
                     OR path to a single experiment folder with TIF/TIFF files
    - config: Configuration dictionary

    Returns:
    - Path to results directory
    """
    # Check if the main_directory contains TIF/TIFF files directly
    tif_files = [f for f in os.listdir(main_directory)
                 if f.lower().endswith(('.tif', '.tiff')) and not f.startswith('.')]

    if tif_files:
        # This is a single experiment folder - process it directly
        print(f"Processing single experiment folder: {main_directory}")
        return process_experiment_folder(main_directory, config)

    # Create results directory for multiple experiments
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = os.path.join(main_directory, f"CTCF_Analysis_{timestamp}")
    os.makedirs(results_dir, exist_ok=True)

    # Find all experiment folders; the context manager closes the scandir iterator
    with os.scandir(main_directory) as entries:
        experiment_folders = sorted(
            entry.path for entry in entries
            if entry.is_dir()
            and not entry.name.startswith('.')
            and "CTCF_Analysis" not in entry.name  # skip previous/current output folders
        )

    # Process each experiment folder
    for experiment_folder in experiment_folders:
        exp_name = os.path.basename(experiment_folder)
        print(f"\nProcessing experiment: {exp_name}")

        # Per-experiment results folder inside the shared results directory
        exp_results_dir = os.path.join(results_dir, exp_name)
        try:
            process_experiment_folder(experiment_folder, config, results_dir=exp_results_dir)
        except Exception as e:
            # Report and continue with the next experiment
            print(f"Error processing experiment {exp_name}: {str(e)}")
            traceback.print_exc()
            continue

    print(f"All experiments processed. Results saved to: {results_dir}")
    return results_dir

### Crop Section Testing
This section defines the functions used to test segmentation on a small crop taken from the center of an image.

In [None]:
def crop_central_region(image, crop_ratio=0.1):
    """
    Crop the central region of an image for quick segmentation testing.

    Parameters:
    - image: array whose first two axes are (H, W); any trailing axes
      (e.g. channels in an (H, W, C) image) are kept unchanged
    - crop_ratio: Size of the crop relative to original image (0.1 = 10% of
      each of height and width). At least one pixel is kept per axis, and the
      crop is clipped to the image bounds (so ratios >= 1 return the whole image).

    Returns:
    - cropped_image: The central cropped region
    - crop_coords: (y_start, y_end, x_start, x_end) for reference (end exclusive)
    """
    # Only the two leading spatial axes matter; 2D single-channel images work too
    h, w = image.shape[:2]

    # Calculate crop size (never empty)
    crop_h = max(1, int(h * crop_ratio))
    crop_w = max(1, int(w * crop_ratio))

    # Calculate central coordinates
    center_y, center_x = h // 2, w // 2

    # End = start + size, so odd crop sizes are not shortened by one pixel
    y_start = center_y - (crop_h // 2)
    y_end = y_start + crop_h
    x_start = center_x - (crop_w // 2)
    x_end = x_start + crop_w

    # Ensure coordinates are within image bounds
    y_start = max(0, y_start)
    y_end = min(h, y_end)
    x_start = max(0, x_start)
    x_end = min(w, x_end)

    # Extract crop, keeping any trailing (channel) axes
    cropped_image = image[y_start:y_end, x_start:x_end, ...]

    return cropped_image, (y_start, y_end, x_start, x_end)

def _crop_test_display_image(img):
    """
    Return a CLAHE-enhanced view of an (H, W, C) image, for plotting only.

    When C >= 3 the first three channels form an RGB composite; otherwise the
    first channel is returned as a 2D image.
    """
    if img.shape[-1] >= 3:
        return np.stack(
            [exposure.equalize_adapthist(img[:, :, i]) for i in range(3)],
            axis=-1,
        )
    return exposure.equalize_adapthist(img[:, :, 0])


def test_segmentation_on_crop(image_path, output_dir, config, crop_ratio=0.1):
    """
    Test segmentation on a central crop of an image

    Runs per-channel GMM background estimation, cell segmentation and CTCF
    measurement on the crop only. Writes the following into
    ``<output_dir>/<image name>_crop_test/``:
    - crop_bg_mask_ch<N>.png: background mask visualization per channel
    - <image name>_crop_segmentation.png: segmentation visualization
    - <image name>_crop_location.png: crop position on the full image plus
      the segmentation overlay on the crop
    - <image name>_crop_cells.csv: per-cell measurements

    Parameters:
    - image_path: Path to the input TIFF image
    - output_dir: Directory to save results
    - config: Configuration dictionary (passed to the segmentation step)
    - crop_ratio: Size of the crop relative to original image (0.1 = 10%)

    Returns:
    - Dictionary with 'image_name', 'crop_region' (y_start, y_end, x_start,
      x_end), 'cell_count' and 'output_dir'
    """
    # Load image
    image = tiff.imread(image_path)
    img_name = os.path.basename(image_path)
    img_base = os.path.splitext(img_name)[0]

    print(f"\nTesting segmentation on cropped region of: {img_name}")

    if image.ndim == 2:
        # Single-channel image: add a channel axis. The shortest-axis
        # heuristic below would otherwise swap the two spatial axes.
        image = image[:, :, np.newaxis]
    else:
        # Assume the shortest axis holds the channels and move it last
        shortest_axis = np.argmin(image.shape)
        image = np.moveaxis(image, shortest_axis, -1)
    n_channels = image.shape[-1]

    # Crop central region
    cropped_image, crop_coords = crop_central_region(image, crop_ratio)
    y_start, y_end, x_start, x_end = crop_coords

    print(f"Original image shape: {image.shape}")
    print(f"Cropped region shape: {cropped_image.shape}")
    print(f"Crop coordinates: (y={y_start}:{y_end}, x={x_start}:{x_end})")

    # Create output directory for this test if it doesn't exist
    crop_output_dir = os.path.join(output_dir, f"{img_base}_crop_test")
    os.makedirs(crop_output_dir, exist_ok=True)

    # 1. Estimate background per channel of the cropped region
    print("Estimating background using GMM...")
    bg_models = {}
    for ch in range(n_channels):
        channel_data = cropped_image[:, :, ch].copy()
        bg_models[ch] = estimate_background_gmm(channel_data)

        # Save background mask visualization
        visualize_background_mask(
            channel_data,
            bg_models[ch],
            os.path.join(crop_output_dir, f"crop_bg_mask_ch{ch+1}.png"),
        )

    # 2. Segment cells on the cropped region
    cell_masks = segment_cells_with_downsampling(cropped_image, config, bg_models)

    # 3. Measure CTCF for each cell in the cropped region
    cell_measurements = measure_cells(cropped_image, cell_masks, bg_models)

    # Save visualization of the segmentation results
    create_visualization(
        cropped_image,
        cell_masks,
        cell_measurements,
        os.path.join(crop_output_dir, f"{img_base}_crop_segmentation.png"),
        debug=True,
    )

    # Comparison figure: crop location on the full image + overlay on the crop.
    # Closed in `finally` so a plotting error does not leak the figure.
    fig, (ax_full, ax_crop) = plt.subplots(1, 2, figsize=(12, 6))
    try:
        ax_full.imshow(_crop_test_display_image(image))
        ax_full.add_patch(plt.Rectangle((x_start, y_start),
                                        x_end - x_start,
                                        y_end - y_start,
                                        fill=False,
                                        edgecolor='red',
                                        linewidth=2))
        ax_full.set_title('Original Image with Crop Region')

        ax_crop.imshow(_crop_test_display_image(cropped_image))
        # Mask out background pixels so only segmented cells are tinted.
        # An unmasked boolean image would paint the whole crop with the
        # colormap's low colour at 70% opacity and hide the underlying image.
        mask_arr = np.asarray(cell_masks)
        cell_overlay = np.ma.masked_where(mask_arr == 0, np.ones(mask_arr.shape))
        ax_crop.imshow(cell_overlay, alpha=0.7, cmap='cool', vmin=0, vmax=1)
        ax_crop.set_title(f'Segmentation on Cropped Region ({len(cell_measurements)} cells)')

        fig.tight_layout()
        fig.savefig(os.path.join(crop_output_dir, f"{img_base}_crop_location.png"), dpi=150)
    finally:
        plt.close(fig)

    # Save cell measurements to CSV
    columns = (
        ['cell_id', 'area']
        + [f'channel_{ch+1}_ctcf' for ch in range(n_channels)]
        + [f'channel_{ch+1}_mean' for ch in range(n_channels)]
        + ['centroid_x', 'centroid_y']
    )
    rows = [
        {
            'cell_id': cell['label'],
            'area': cell['area'],
            **{f'channel_{ch+1}_ctcf': cell['ctcf'][ch] for ch in range(n_channels)},
            **{f'channel_{ch+1}_mean': cell['mean'][ch] for ch in range(n_channels)},
            'centroid_x': cell['centroid'][1],
            'centroid_y': cell['centroid'][0],
        }
        for cell in cell_measurements
    ]
    # Explicit columns keep the CSV header even when no cells were found
    cell_df = pd.DataFrame(rows, columns=columns)
    cell_df.to_csv(os.path.join(crop_output_dir, f"{img_base}_crop_cells.csv"), index=False)

    # Return info about the test
    return {
        'image_name': img_name,
        'crop_region': crop_coords,
        'cell_count': len(cell_measurements),
        'output_dir': crop_output_dir
    }

# Test segmentation on a cropped region before full processing
def test_segmentation_parameters(image_path, config, crop_ratio=0.1):
    """
    Run a quick segmentation check on the central crop of a single image.

    Results are written to a ``segmentation_tests`` folder created next to the
    image. The folder is not removed afterwards, so the outputs can be
    reviewed while tuning ``config``.

    Parameters:
    - image_path: Path to the image to test
    - config: Pipeline configuration dictionary
    - crop_ratio: Crop size relative to the full image (0.1 = 10%)

    Returns:
    - The summary dictionary returned by test_segmentation_on_crop
    """
    tests_dir = os.path.join(os.path.dirname(image_path), "segmentation_tests")
    os.makedirs(tests_dir, exist_ok=True)

    result = test_segmentation_on_crop(
        image_path=image_path,
        output_dir=tests_dir,
        config=config,
        crop_ratio=crop_ratio,
    )

    summary_lines = (
        "\n✓ Segmentation test complete!",
        f"  - Found {result['cell_count']} cells in the cropped region",
        f"  - Results saved to: {result['output_dir']}",
        "\nTIP: Review the results and adjust segmentation parameters in config as needed",
    )
    for line in summary_lines:
        print(line)

    return result

## Process Images
This section first sets the pipeline configuration for segmentation and runs a quick test on an individual file. The second part batch-processes multiple files.

In [None]:
# Define the configuration for the pipeline
config = {
    # Hardware settings
    'use_gpu': True,  # Set to False to force CPU processing
    'auto_detect_gpu': True,  # Auto-detect and use GPU if available

    # Background Estimation settings
    'use_bg_composite': True,  # Use a composite image for background estimation (faster)

    # CellPose settings
    'cellpose_model': 'cpsam',  # CellPose model to use ('cpsam', 'cyto2', 'cyto3', 'nuclei_only')
    'segmentation_channels': [0, 1, 3],  # Channels for CellPose-SAM (cpsam) segmentation (Max 3)
    # NOTE(review): 'channels_of_interest' below uses 0-based indices [0..3], while
    # cytoplasm_channel=4 would be out of range under 0-based indexing for a
    # 4-channel image. Confirm which convention the Cyto3 path expects.
    'cytoplasm_channel': 4,  # Far Red channel for cytoplasm/membrane (Cyto3 - Model)
    'nucleus_channel': 1,    # Blue channel (DAPI) for nuclei         (Cyto3 - Model)
    'cell_diameter': 50.0,     # Approximate diameter in pixels       (Cyto3 - Model)
    'flow_threshold': 0.4,   # Flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
    'cellprob_threshold': 0.0,  # All pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
    'downsample_factor': 2.0,  # Downsample factor for speed (1.0 = no downsampling)

    # Alignment (Nuclei - Cell Setting)
    'quick_size_filtering': True,       # Set to True to use size filtering instead of nuclei alignment
    'filter_by_radius': True,           # True: filter by max extension radius, False: filter by pixel area
    'cell_min_size': 4,                 # Minimum size (radius or area depending on filter_by_radius)
    'cell_max_size': 100,               # Maximum size (radius or area depending on filter_by_radius)

    # Nuclei segmentation settings
    'nuclei_min_size': 5,  # Minimum size for nuclei segmentation
    'nuclei_max_size': 1000,  # Maximum size for nuclei segmentation
    'nuclei_gaussian_sigma': 1.0,  # Gaussian sigma for nuclei segmentation
    'nuclei_segmentation_method': 'cellpose',  # Method for nuclei segmentation ('threshold' or 'cellpose')
    'nuclei_thresh_factor': 1.0,  # Adjust sensitivity

    # Visualization settings
    'visualize_bg': False,  # Set to False if you don't want intermediate visualizations
    'visualize_segmentation': False,  # Show segmentation results
    'save_qc_regions': True,  # Save QC regions for review
    'qc_region_size': 300,  # Size of the QC region in pixels
    'visualize_nuclei': False,  # Show nuclei segmentation results
    'save_nuclei_qc_regions': False,  # Save QC regions for nuclei segmentation
    'save_cellpose_masks': False,  # Save CellPose masks as TIFF files

    # Analysis settings
    'channels_of_interest': [0, 1, 2, 3],  # All channels to measure
    'positive_threshold_method': 'bg_plus_std',  # Method for positive thresholding ('bg_plus_std', 'percentile', 'otsu')
    'positive_threshold_std_multiplier': 2.0,  # Multiplier for background std in 'bg_plus_std' method
    'positive_threshold_percentile': 95,  # Percentile for positive thresholding in 'percentile' method

    # Parallelization settings
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to 1 so min() does not raise TypeError.
    'max_workers': min(os.cpu_count() or 1, 16),  # Maximum number of parallel workers
    'batch_size': 8,                              # Number of images to process in a batch

}

In [None]:
# 3D-specific settings, merged into the main pipeline `config` below.
config_3d_params = dict(
    anisotropy=3.0,          # Z-to-XY resolution ratio
    # NOTE(review): an earlier comment described this as "usually larger than
    # 2D", but 30.0 is smaller than the 2D 'cell_diameter' (50.0) — confirm.
    cell_diameter_3d=30.0,   # Expected cell diameter for 3D segmentation
    use_3d=True,             # Enable the 3D processing path
    gpu_batch_size_3d=2,     # Process fewer 3D images at once
    z_downsampling=False,    # Whether to downsample in the Z dimension
    aggressive_gc=True,      # More aggressive garbage collection for 3D
)

# Merge into the existing config (keys already present are overwritten).
config.update(config_3d_params)

### Cropped segmentation and background subtraction testing

In [None]:
# Ask the user for the folder holding the image(s) used for the crop test.
folder_path_test = select_folder()
print('>>> Selected folder: {}'.format(folder_path_test))

In [None]:
# Run a quick crop segmentation test on the first TIFF image of the test folder
# before batch processing.
#
# The glob module is imported under an alias because an earlier cell runs
# `from glob import glob`, which rebinds the name `glob` to the function;
# `glob.glob(...)` would then raise AttributeError.
import glob as glob_module

# Match both common TIFF extensions. Sort so "the first image" is deterministic
# (glob returns files in arbitrary order).
image_paths = sorted({
    path
    for pattern in ("*.tif", "*.tiff")
    for path in glob_module.glob(os.path.join(folder_path_test, pattern))
})
if image_paths:
    # Test segmentation on first image before batch processing
    test_result = test_segmentation_parameters(image_paths[0], config, crop_ratio=0.1)
else:
    print(f"No TIFF images found in {folder_path_test}")

### Batch Processing

In [None]:
# Ask the user for the experiment folder to batch-process.
exp_folder = select_folder()
print('>>> Selected folder: {}'.format(exp_folder))

In [None]:
# Run the pipeline on the single experiment folder selected above, using the
# configuration defined earlier, and report where the results were written.
results_dir = process_experiment_folder(exp_folder, config)
print('>>> Experiment processing complete. Results saved to: {}'.format(results_dir))