In [None]:
import numpy as np
import pandas as pd
from skimage import io
from scipy.io import loadmat
from skimage.transform import warp
from tifffile import imwrite
from pathlib import Path
import math
import xml.etree.ElementTree as ET
from numpy.fft import fftn, ifftn
from scipy.ndimage import fourier_shift
from skimage.transform import AffineTransform, warp
import os, re

## ImageJ + TrackMate Integration Setup (via pyimagej & scyjava)

This section initializes the **ImageJ** environment with full access to **TrackMate** and other legacy plugins through `pyimagej` and `scyjava`.

### Key Features:
- Allocates up to **16 GB of Java heap memory** for large image processing tasks.
- Launches **Fiji (ImageJ distribution)** in **headless mode**, allowing non-GUI execution.
- Loads all necessary **Java classes** to control ImageJ and TrackMate via Python.
- Provides aliases for essential TrackMate modules: detection, tracking, filtering, and export.

### Java Classes Imported:
| Category              | Classes / Purpose                                                       |
|-----------------------|-------------------------------------------------------------------------|
| **Core ImageJ**       | `ChannelSplitter`, `BackgroundSubtracter` (imported as `BackgroundSub`), `WindowManager`, `ImagePlus` |
| **Java Utilities**    | `Integer`, `HashMap`, `File`                                            |
| **TrackMate Core**    | `Model`, `Settings`, `TrackMate`, `Logger`                              |
| **TrackMate Tools**   | `LogDetectorFactory`, `SparseLAPTrackerFactory`, `FeatureFilter`, `ExportTracksToXML` |

### Output
Once executed, the script will print the current ImageJ version to confirm successful initialization.

In [None]:
import imagej, scyjava

# Set the maximum heap memory for the JVM to 16 GB.
# This has to happen before imagej.init() below: JVM options are applied when
# the JVM starts and cannot be changed afterwards.
scyjava.config.add_options("-Xmx16g")

# Initialize ImageJ ("sc.fiji:fiji" = the Fiji distribution) headless, with
# the ImageJ 1.x legacy layer enabled.
# It is important to use add_legacy=True to enable TrackMate and other legacy plugins.
# Note: 'headless=True' is replaced by 'mode="headless"' in modern ImageJ.
ij = imagej.init("sc.fiji:fiji", mode="headless", add_legacy=True)

# Import key ImageJ Java classes using scyjava.jimport
ChannelSplitter  = scyjava.jimport('ij.plugin.ChannelSplitter')                  # Used to split multichannel images
BackgroundSub    = scyjava.jimport('ij.plugin.filter.BackgroundSubtracter')      # Used for background subtraction
WindowManager    = scyjava.jimport('ij.WindowManager')                           # Manages open image windows in ImageJ
ImagePlus        = scyjava.jimport('ij.ImagePlus')                               # Core image container class
JavaArray        = scyjava.jarray(ImagePlus, 0)                                  # Empty (length-0) Java ImagePlus[]; presumably a type template for calls elsewhere — confirm

# Print the current version of ImageJ to confirm initialization succeeded
print("ImageJ:", ij.getVersion())

# Aliases for frequently used Java classes
Integer = scyjava.jimport("java.lang.Integer")       # Java Integer class (boxes int-valued TrackMate settings, e.g. TARGET_CHANNEL)
HashMap = scyjava.jimport("java.util.HashMap")       # Java HashMap for TrackMate detector/tracker parameters
File    = scyjava.jimport("java.io.File")            # Java File class (target of the TrackMate XML export)

# Import core TrackMate classes for tracking
Model     = scyjava.jimport("fiji.plugin.trackmate.Model")                       # Represents the tracking model
Settings  = scyjava.jimport("fiji.plugin.trackmate.Settings")                    # Stores all settings for tracking
TrackMate = scyjava.jimport("fiji.plugin.trackmate.TrackMate")                   # Main execution class for TrackMate

# Import TrackMate utilities for detection, tracking, filtering, and export
FeatureFilter            = scyjava.jimport("fiji.plugin.trackmate.features.FeatureFilter")                # Used to filter features like quality, duration, displacement
LogDetectorFactory       = scyjava.jimport("fiji.plugin.trackmate.detection.LogDetectorFactory")          # LoG-based spot detector
SparseLAPTrackerFactory  = scyjava.jimport("fiji.plugin.trackmate.tracking.jaqaman.SparseLAPTrackerFactory")  # Core tracker factory using LAP framework
ExportTracksToXML        = scyjava.jimport("fiji.plugin.trackmate.action.ExportTracksToXML")              # Export tracking results to XML
Logger                   = scyjava.jimport("fiji.plugin.trackmate.Logger")     # TrackMate logger (Logger.IJ_LOGGER is attached to each Model)

## Configuration & Input Paths for TrackMate-based Image Analysis

This section defines all necessary input paths and parameter sets for running the full TrackMate-based pipeline, including drift correction and chromatic registration.

---

### Input Paths

| Variable       | Description                                |
|----------------|--------------------------------------------|
| `nd2_path`     | Path to the input `.nd2` file (raw microscopy stack) |
| `mat_path`     | Path to the `.mat` file containing 3×3 affine matrix (`T`) for chromatic registration |
| `imagej_path`  | Path to the ImageJ executable (if needed for launching manually or troubleshooting) |

---

### General Drift Correction Configuration (`config`)

This configuration is used for the initial tracking and drift correction phase using a specific channel.

| Key                     | Description |
|-------------------------|-------------|
| `n_channels`            | Total number of channels in the image stack (e.g., 3) |
| `target_chan`           | Channel used for drift tracking, 1-based as in ImageJ/TrackMate (e.g., 3 = red) |
| `spot_radius`           | Approximate radius of features/spots |
| `threshold`             | Minimum intensity for spot detection |
| `subpixel`              | Enable subpixel refinement for detection |
| `median_filter`         | Apply median filtering before detection |
| `linking_max`           | Max distance for linking spots |
| `gapclosing_max`        | Max distance for closing temporal gaps |
| `max_frame_gap`         | Max number of allowed gap frames |
| `allow_split` / `allow_merge` | Whether to allow track splitting / track merging |
| `filter_quality`        | Apply spot filtering by quality score |
| `quality_cutoff`        | Minimum quality value |
| `filter_displacement`   | Apply track filtering by displacement |
| `displacement_cutoff`   | Minimum allowed displacement |
| `filter_duration`       | Apply track filtering by duration |
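| `duration_cutoff_time`  | Minimum track duration in the image's calibrated time unit, typically seconds (NaN = auto-computed as `frame_interval × (n_frames − 1.5)`, i.e. only tracks spanning essentially the whole movie are kept) |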
| `target_registration_chan` | Channel to apply chromatic registration |

---

### Green Channel Tracking Configuration (`config_green`)

Used for running TrackMate on channel 2 (green) during dual-channel tracking.

- Higher detection threshold (`threshold = 100.0`, vs. `40.0` for red) due to the brighter green signal
- No displacement filter
- Includes filtering based on duration and early track start

---

### Red Channel Tracking Configuration (`config_red`)

Used for running TrackMate on channel 3 (red) during dual-channel tracking.

- Lower threshold due to dimmer red signal
- Similar structure to `config_green` with independent settings

---

### Purpose

These configurations feed into:
- `run_precorrection()` for full pipeline execution (drift + chromatic correction + merge)
- `run_trackmate_dual_channel_from_config()` for channel-specific dual tracking
- XML export for downstream time delay or colocalization analysis

In [None]:
# ---------------------------------------------------------------------------
# Input paths (machine-specific; edit before running)
# ---------------------------------------------------------------------------
nd2_path = "D:/Jo_Lab/230501_THP1_LDVPdp2036_2.5s_paBleb/cham1_paBleb10uM/new/1.ND2_file/paBleb10uM_cham1_002.nd2"  # raw ND2 stack
mat_path = "D:/Jo_Lab/chromaticErr/tform_matrix_only_2.mat"  # .mat file holding the 3x3 affine matrix 'T'
imagej_path = "C:/Users/KU/Desktop/fiji-win64/Fiji.app/ImageJ-win64.exe"  # Fiji executable, for manual launching / troubleshooting

# ---------------------------------------------------------------------------
# Drift-correction run: TrackMate on one channel of the raw ND2, followed by
# chromatic registration of `target_registration_chan`.
# ---------------------------------------------------------------------------
config = dict(
    n_channels=3,                   # total number of channels in the image stack
    target_chan=3,                  # channel tracked for drift (1-based, ImageJ/TrackMate convention)

    # Spot detection (LoG detector)
    spot_radius=2.5,                # approximate spot radius, in pixels
    threshold=20.0,                 # detection threshold
    subpixel=True,                  # sub-pixel localization refinement
    median_filter=False,            # median filtering before detection

    # Linking / gap closing (sparse LAP tracker)
    linking_max=1.0,                # max frame-to-frame linking distance, in pixels
    gapclosing_max=1.0,             # max gap-closing distance
    max_frame_gap=0,                # max number of frames a closed gap may span

    # Track topology
    allow_split=False,              # no track splitting
    allow_merge=False,              # no track merging

    # Spot / track filters
    filter_quality=True,            # filter spots on quality
    quality_cutoff=0.0,             # minimum spot quality
    filter_displacement=True,       # filter tracks on total displacement
    displacement_cutoff=0.0,        # minimum track displacement
    filter_duration=True,           # filter tracks on duration
    duration_cutoff_time=math.nan,  # NaN -> derived from the frame count by run_trackmate_nd2

    # Chromatic registration
    target_registration_chan=3,     # channel the affine chromatic correction is applied to
)

# ---------------------------------------------------------------------------
# Dual-channel tracking on the corrected stack: green channel
# ---------------------------------------------------------------------------
config_green = dict(
    target_chan=2,                  # green channel

    # Spot detection
    spot_radius=2.5,
    threshold=100.0,                # bright signal -> high threshold
    subpixel=True,
    median_filter=False,

    # Linking / gap closing
    linking_max=1.0,
    gapclosing_max=0.6,
    max_frame_gap=5,

    # Track topology
    allow_split=False,
    allow_merge=False,

    # Track filters
    displacement_cutoff=1.5,        # only used when filter_displacement is True
    filter_displacement=False,
    duration_cutoff_time=2.5,       # minimum track duration (time units)
    filter_duration=True,
    filter_quality=True,
    quality_cutoff=0.0,
    filter_track_start=True,        # drop tracks that start too early
    track_start_time=1,             # earliest allowed start (frame index)
)

# Red channel: identical settings except the channel and a lower threshold
# (dimmer signal). Overriding existing keys keeps the key order of config_green.
config_red = {**config_green, "target_chan": 3, "threshold": 40.0}

# MasterMacro Pipeline

---

## Function Call Hierarchy

```text
run_precorrection
├── run_trackmate_nd2
│   └── TrackMate tracking & export (XML, TIFF, metadata)
├── load_tracks_xml
├── compute_median_drift_from_tracks
├── save_drift_trace_txt
├── apply_drift_correction_to_tiff
│   └── Fourier shift applied per-channel
├── apply_chromatic_registration_from_mat
│   └── AffineTransform (scikit-image warp)
└── combine_tif_channels
    └── ImageJ merge via pyimagej

run_trackmate_dual_channel_from_config
├── run_single_channel (green)
└── run_single_channel (red)
    └── TrackMate on multichannel TIFF

subtract_background_and_merge
├── ImageJ channel split
├── Background subtraction (C2, C3)
└── Channel merge (C1 + corrected C2/C3)
```

---

## Key Variables & Paths

| Variable | Description |
|---------|-------------|
| `nd2_path` | Input raw image file in ND2 format |
| `track_dir` | Directory for TrackMate XML and TIF output |
| `reg_dir` | Directory for drift-corrected results |
| `mat_path` | MATLAB `.mat` file with 3x3 affine transform |
| `drift_trace` | NumPy array of XY drift per frame |
| `description_template` | Metadata string for TIFF description field |

---

## Full Execution Flow

### `run_precorrection()`
1. Create output folders (2, 3)
2. Run TrackMate on ND2 → XML + TIF export
3. Load tracks and compute median XY drift
4. Apply drift correction to all channels
5. Apply chromatic correction to registration channel
6. Merge all corrected channels into multichannel TIF

### `run_trackmate_dual_channel_from_config()`
1. Open multichannel TIFF
2. Run TrackMate separately for green and red channels
3. Apply spot filtering and duration filters
4. Export tracking results as XML (g and r)

### `subtract_background_and_merge()`
1. Load multichannel TIFF and split into C1, C2, C3
2. Apply background subtraction on C2 and C3
3. Merge C1 + corrected C2 + corrected C3 (in slot 4)
4. Save merged image as new TIFF

---

## Sample Output File Names
```text
ND2 input:             sample.nd2
TrackMate XML:         sample_Tracks.xml
Raw TIFF:              sample.tif
Drift trace:           sample_drift.txt
Corrected TIFFs:       sample_chan1_drftc.tif, sample_chan2_drftc.tif, ...
Registered TIF:        sample_drftc_reg.tif
Dual XML (g, r):       sample_drftc_reg_g.xml, sample_drftc_reg_r.xml
Backsub TIFF:          sample_back.tif
```
---

## Example Image Shapes

| File Type          | Shape         | Description             |
|--------------------|---------------|-------------------------|
| Raw ND2            | (T, Y, X, C)  | 4D hyperstack           |
| Track TIFF         | (T, Y, X, C)  | Uncorrected export      |
| Corrected TIFF     | (T, Y, X)     | Per-channel TIFF        |
| Merged TIF         | (T, Y, X, C)  | Final drift-corrected   |
| Background-sub TIFF| (T, Y, X, C)  | After `subtract_background_and_merge()` |

In [None]:
def run_trackmate_nd2(
    nd2_paths,
    out_dir,
    target_chan=3,
    spot_radius=2.5,
    threshold=20.0,
    subpixel=True,
    median_filter=False,
    linking_max=1.0,
    gapclosing_max=1.0,
    max_frame_gap=0,
    allow_split=False,
    allow_merge=False,
    filter_quality=True,
    quality_cutoff=0.0,
    filter_displacement=True,
    displacement_cutoff=0.0,
    filter_duration=True,
    duration_cutoff_time=math.nan,
):
    '''
    Run TrackMate on ND2 files for spot detection and tracking,
    and export results including XML tracks, TIF images, and metadata.

    Each input is processed as follows:
    - It is opened with ``ij.IJ.openImage``.
    - Its spatial calibration is reset to 1 (unitless) per pixel.
    - TrackMate is run with a LoG detector and a sparse LAP tracker.

    The following files are then written to ``out_dir``:
    - ``<stem>_Tracks.xml``
    - ``<stem>.tif``
    - ``<stem>_metadata.txt``, only if the image has an "Info" property.

    Parameters
    ----------
    nd2_paths : str or Path or list
        Path(s) to ND2 image file(s) to be processed.
    out_dir : str or Path
        Output directory where results (XML, TIF, metadata) will be saved.
        Created (with parents) if missing.
    target_chan : int, default=3
        Channel index (1-based) for spot detection.
    spot_radius : float, default=2.5
        Radius of spots to be detected (pixels, since calibration is reset).
    threshold : float, default=20.0
        Detection threshold passed to the LoG detector.
    subpixel : bool, default=True
        Whether to enable subpixel localization.
    median_filter : bool, default=False
        Whether to apply median filter before detection.
    linking_max : float, default=1.0
        Maximum linking distance for track continuation.
    gapclosing_max : float, default=1.0
        Maximum distance allowed for closing temporal gaps.
    max_frame_gap : int, default=0
        Maximum number of frames allowed for gap closing.
    allow_split : bool, default=False
        Whether to allow track splitting.
    allow_merge : bool, default=False
        Whether to allow track merging.
    filter_quality : bool, default=True
        Apply filtering based on spot quality.
    quality_cutoff : float, default=0.0
        Minimum quality value for spots to be retained.
    filter_displacement : bool, default=True
        Filter tracks based on displacement.
    displacement_cutoff : float, default=0.0
        Minimum displacement required to keep the track.
    filter_duration : bool, default=True
        Filter tracks based on duration.
    duration_cutoff_time : float or None, default=nan
        Minimum track duration in the image's calibrated time unit.
        If NaN (or None), it is set to ``frame_interval * (n_frames - 1.5)``.
        Only tracks spanning essentially the whole movie then pass.

    Returns
    -------
    out_xmls : list of Path
        Paths to the exported TrackMate XML files, one per input that was
        processed successfully. Some inputs are skipped with a printed
        message: those that cannot be opened, and those for which TrackMate's
        ``checkInput``/``process`` fails.

    Notes
    -----
    Settings are coerced to Python ``float``/``int``/``bool`` before they are
    stored in the Java ``HashMap``s. JPype boxes a Python ``int`` as
    ``java.lang.Long``, but TrackMate checks keys such as ``RADIUS`` and
    ``THRESHOLD`` for ``java.lang.Double``. Without the coercion, an integer
    argument like ``threshold=20`` would fail ``checkInput``.
    '''

    if isinstance(nd2_paths, (str, Path)):
        nd2_paths = [nd2_paths]

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Both NaN and None mean "derive the duration cutoff from the movie length".
    auto_duration = duration_cutoff_time is None or math.isnan(duration_cutoff_time)

    out_xmls = []

    for nd2 in map(Path, nd2_paths):
        print(f"> Processing {nd2.name}")

        # Load ND2 image using Bio-Formats
        imp = ij.IJ.openImage(str(nd2))
        if imp is None:
            print("  Warning: Failed to load with Bio-Formats:", nd2)
            continue

        # try/finally guarantees the ImagePlus is released even if TrackMate
        # or one of the exports raises, so one bad file does not leak memory
        # for the rest of a batch.
        try:
            imp.show()

            # Reset spatial scale to unitless (match MATLAB convention)
            cal = imp.getCalibration()
            cal.setUnit("")
            cal.pixelWidth = cal.pixelHeight = cal.pixelDepth = 1.0
            imp.setCalibration(cal)

            # Duration cutoff: explicit value, or derived from the frame count.
            frame_interval = cal.frameInterval
            n_frames = imp.getDimensions()[4]  # [width, height, channels, slices, frames]
            if auto_duration:
                duration_cutoff = frame_interval * (n_frames - 1.5)
            else:
                duration_cutoff = duration_cutoff_time

            # Initialize TrackMate model and logger
            model = Model()
            model.setLogger(Logger.IJ_LOGGER)
            model.setPhysicalUnits(cal.getUnit(), cal.getTimeUnit())

            # Create settings from the image
            settings = Settings(imp)

            # Detector: LoG. Explicit casts give TrackMate the boxed Java types
            # it validates (Double / Integer / Boolean).
            settings.detectorFactory = LogDetectorFactory()
            det = HashMap()
            det.put("DO_SUBPIXEL_LOCALIZATION", bool(subpixel))
            det.put("RADIUS", float(spot_radius))
            det.put("TARGET_CHANNEL", Integer.valueOf(int(target_chan)))
            det.put("THRESHOLD", float(threshold))
            det.put("DO_MEDIAN_FILTERING", bool(median_filter))
            settings.detectorSettings = det

            # Spot quality filter
            if filter_quality:
                settings.addSpotFilter(FeatureFilter("QUALITY", float(quality_cutoff), True))

            # Tracker: sparse LAP, starting from its default settings
            settings.trackerFactory = SparseLAPTrackerFactory()
            trk = settings.trackerFactory.getDefaultSettings()
            trk.put("MAX_FRAME_GAP", Integer.valueOf(int(max_frame_gap)))
            trk.put("LINKING_MAX_DISTANCE", float(linking_max))
            trk.put("GAP_CLOSING_MAX_DISTANCE", float(gapclosing_max))
            trk.put("ALLOW_TRACK_SPLITTING", bool(allow_split))
            trk.put("ALLOW_TRACK_MERGING", bool(allow_merge))
            settings.trackerSettings = trk

            # Add all analyzers and optional track filters
            settings.addAllAnalyzers()
            if filter_displacement:
                settings.addTrackFilter(
                    FeatureFilter("TRACK_DISPLACEMENT", float(displacement_cutoff), True)
                )
            if filter_duration:
                settings.addTrackFilter(
                    FeatureFilter("TRACK_DURATION", float(duration_cutoff), True)
                )

            # Execute TrackMate
            tm = TrackMate(model, settings)
            if not tm.checkInput() or not tm.process():
                print("  Error:", tm.getErrorMessage())
                continue  # the finally clause still closes imp

            # Export tracking results to XML
            xml_path = out_dir / f"{nd2.stem}_Tracks.xml"
            ExportTracksToXML.export(model, settings, File(str(xml_path)))
            print("  XML saved:", xml_path.name)

            # Save the image stack as TIF
            tiff_path = out_dir / f"{nd2.stem}.tif"
            ij.IJ.saveAs(imp, "Tiff", str(tiff_path))

            # Extract and save metadata from image if available
            info_str = imp.getProperty("Info")
            if info_str:
                meta_path = out_dir / f"{nd2.stem}_metadata.txt"
                with open(meta_path, "w", encoding="utf-8") as f:
                    f.write(str(info_str))
                print("  Metadata saved:", meta_path.name)
            else:
                print("  Warning: No metadata found.")

            out_xmls.append(xml_path)
        finally:
            imp.close()

    return out_xmls

def load_tracks_xml(filepath):
    """
    Read a TrackMate "Export tracks to XML" file into one DataFrame per track.

    Every ``<particle>`` element is treated as a track. Its ``<detection>``
    descendants become the rows, kept in document order. The ``t``, ``x``,
    ``y`` and ``z`` attributes are parsed as floats; a missing attribute
    reads as 0. Particles without any detection are dropped.

    Parameters
    ----------
    filepath : str or Path or file object
        Anything accepted by ``xml.etree.ElementTree.parse``.

    Returns
    -------
    list of pandas.DataFrame
        One frame per non-empty track, with columns ``["T", "X", "Y", "Z"]``.
    """

    def _detection_row(node):
        # Attribute order defines the column order below.
        return [float(node.attrib.get(key, 0)) for key in ("t", "x", "y", "z")]

    root = ET.parse(filepath).getroot()
    column_names = ["T", "X", "Y", "Z"]

    rows_per_particle = (
        [_detection_row(node) for node in particle.findall(".//detection")]
        for particle in root.findall(".//particle")
    )
    return [
        pd.DataFrame(rows, columns=column_names)
        for rows in rows_per_particle
        if rows
    ]

def compute_median_drift_from_tracks(tracks) -> np.ndarray:
    """
    Estimate per-frame XY drift as the median displacement of all tracks.

    Each track is expressed relative to its own first detection. The value at
    frame ``f`` is then the NaN-ignoring median of those displacements over
    all tracks that have a detection at ``f``. Frames covered by no track come
    out as NaN.

    Parameters
    ----------
    tracks : list of pandas.DataFrame
        Tracks as produced by ``load_tracks_xml``. Each has columns "T"
        (frame index), "X", "Y" and "Z".

    Returns
    -------
    np.ndarray of shape (max_frame + 1, 2)
        Row ``f`` holds ``[median_dx, median_dy]`` for frame ``f``.

    Raises
    ------
    ValueError
        If ``tracks`` is empty.
    """
    if not tracks:
        raise ValueError("Input track list is empty.")

    n_frames = int(max(df["T"].max() for df in tracks)) + 1

    # frames x tracks displacement matrices; NaN where a track is absent
    disp_x = np.full((n_frames, len(tracks)), np.nan)
    disp_y = np.full((n_frames, len(tracks)), np.nan)

    for column, df in enumerate(tracks):
        frame_idx = df["T"].to_numpy(dtype=int)
        xs = df["X"].to_numpy()
        ys = df["Y"].to_numpy()
        disp_x[frame_idx, column] = xs - xs[0]
        disp_y[frame_idx, column] = ys - ys[0]

    return np.column_stack(
        [np.nanmedian(disp_x, axis=1), np.nanmedian(disp_y, axis=1)]
    )

def save_drift_trace_txt(trace_array, save_path):
    """
    Write a per-frame drift trace to a plain-text file.

    ``np.savetxt`` writes one line per row. Values are space separated and
    formatted with eight decimal places.

    Parameters
    ----------
    trace_array : np.ndarray
        Array of shape (T, 2); row ``t`` is the drift ``[dx, dy]`` of frame ``t``.
    save_path : str or Path
        Destination text file (overwritten if it exists).

    Returns
    -------
    None
    """
    eight_decimals = '%.8f'
    np.savetxt(save_path, trace_array, fmt=eight_decimals)

def sanitize_description(desc: str) -> str:
    """
    Drop every non-ASCII character from a metadata string.

    Characters outside the ASCII range are discarded (not replaced), so the
    result is safe for TIFF ImageDescription fields.

    Parameters
    ----------
    desc : str
        Metadata text that may contain non-ASCII characters.

    Returns
    -------
    str
        ``desc`` restricted to its ASCII characters, in the original order.
    """
    ascii_bytes = desc.encode("ascii", errors="ignore")
    return ascii_bytes.decode("ascii")

def apply_drift_correction_to_tiff(input_tiff_path, drift_trace, output_prefix, n_channels=3, description_template=None):
    """
    Apply frame-wise XY drift correction to a multi-channel 4D TIFF stack.

    Each channel is corrected independently:
    - Every frame is shifted by ``(-dy, -dx)`` using FFT-based subpixel
      shifting.
    - The result is rounded and clipped to the uint16 range.
    - It is saved as a grayscale multipage TIFF with photometric mode
      'minisblack'.

    Parameters
    ----------
    input_tiff_path : str or Path
        Path to the input 4D TIFF image. After reading, it must have shape
        (T, Y, X, C).

    drift_trace : array-like
        Drift array of shape (T, 2), where each row is [dx, dy] for a frame.
        Some frames are left unshifted:
        - frames beyond the end of the trace;
        - frames whose drift is not finite. ``compute_median_drift_from_tracks``
          yields NaN for frames that no track covers.

    output_prefix : str or Path
        Prefix for the output file names. Each channel will be saved as:
        <output_prefix>_chanX_drftc.tif

    n_channels : int, default=3
        Number of channels expected in the input image.

    description_template : str, optional
        Optional metadata string to be included in each output TIFF.
        Non-ASCII characters are removed before writing.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If the image is not 4D or its channel count differs from ``n_channels``.
    """

    input_path = Path(input_tiff_path)
    # NOTE(review): ImageJ hyperstacks are presumably (T, C, Y, X) on disk.
    # skimage.io.imread moves a length-3/4 third-from-last axis to the end,
    # which is what yields (T, Y, X, C) here. Confirm for other channel counts.
    stack = io.imread(input_path)

    # Verify input shape is 4D
    assert stack.ndim == 4, "Input TIFF must be a 4D array (T, Y, X, C)"
    n_frames, height, width, n_chan_inferred = stack.shape
    assert n_chan_inferred == n_channels, (
        f"Expected {n_channels} channels, but found {n_chan_inferred} in the image."
    )

    # Resolve each frame's shift once; the same shifts apply to every channel.
    shifts = []
    n_nonfinite = 0
    for t in range(n_frames):
        # Frames past the end of the trace are not shifted.
        dx, dy = drift_trace[t] if t < len(drift_trace) else (0, 0)
        # A NaN/inf shift would make fourier_shift return an all-NaN frame.
        # The uint16 cast would then silently produce undefined pixel values.
        if not (np.isfinite(dx) and np.isfinite(dy)):
            dx, dy = 0, 0
            n_nonfinite += 1
        shifts.append((dx, dy))
    if n_nonfinite:
        print(f"  Warning: {n_nonfinite} frame(s) with non-finite drift left unshifted.")

    # Sanitize metadata once; a falsy template is passed through unchanged.
    image_description = description_template
    if image_description:
        image_description = sanitize_description(image_description)

    for c in range(n_channels):
        corrected_channel = np.empty((n_frames, height, width), dtype=np.uint16)

        for t, (dx, dy) in enumerate(shifts):
            # Subpixel shift in the frequency domain (shift is given in [y, x] order)
            spectrum = fourier_shift(fftn(stack[t, :, :, c]), shift=[-dy, -dx])
            shifted = ifftn(spectrum).real

            # Round, clip to the uint16 range and convert
            corrected_channel[t] = np.clip(np.round(shifted), 0, 65535).astype(np.uint16)

        # Save corrected grayscale stack
        output_path = Path(f"{output_prefix}_chan{c+1}_drftc.tif")
        imwrite(
            output_path,
            corrected_channel,  # Shape: (T, Y, X)
            photometric='minisblack',
            description=image_description
        )
        print(f"Drift-corrected grayscale saved → {output_path}")

def apply_chromatic_registration_from_mat(input_tiff_path, output_tiff_path, mat_path, description_template=None):
    """
    Apply chromatic registration to a grayscale image stack using a 3x3 affine matrix from a MATLAB .mat file.

    The function proceeds as follows:
    - It reads the matrix stored under key 'T' and transposes it (presumably
      stored in MATLAB's row-vector convention).
    - It applies the transform to every frame of the TIFF stack with bilinear
      interpolation.
    - It saves the registered stack.

    For integer stacks, the interpolated values are rounded and clipped to the
    dtype range before casting, as in the drift-correction step. They are not
    truncated.

    Parameters
    ----------
    input_tiff_path : str or Path
        Path to the input grayscale multipage TIFF (shape: T, Y, X).

    output_tiff_path : str or Path
        Path where the registered image stack will be saved.

    mat_path : str or Path
        Path to the .mat file that contains the 3x3 affine matrix 'T'.

    description_template : str, optional
        Optional metadata string to be embedded into the TIFF.
        Non-ASCII characters are removed before writing.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If the .mat file has no variable 'T'.
    AssertionError
        If the TIFF is not a 3D (frames, height, width) stack.
    """
    # tifffile returns the pages in file order. skimage.io.imread instead
    # applies a channel-axis heuristic (it moves a length-3/4 third-from-last
    # axis to the end), which would turn a 3- or 4-frame (T, Y, X) stack into
    # (Y, X, T).
    from tifffile import imread as tiff_imread

    input_path = Path(input_tiff_path)
    output_path = Path(output_tiff_path)
    mat_path = Path(mat_path)

    # Step 1: Load affine transformation matrix 'T' from .mat file
    mat_data = loadmat(mat_path, simplify_cells=True)
    if "T" not in mat_data:
        raise KeyError(f"Affine matrix 'T' not found in {mat_path.name}.")
    T = np.asarray(mat_data["T"], dtype=np.float64).T  # Transpose to match ImageJ/MATLAB forward transform

    # Step 2: Create an AffineTransform object
    tform = AffineTransform(matrix=T)  # Forward transform (as in MATLAB)

    # Step 3: Load image stack
    stack = tiff_imread(str(input_path))  # Expected shape: (T, Y, X)
    assert stack.ndim == 3, "Input TIFF must have shape (frames, height, width)."
    n_frames, height, width = stack.shape
    registered = np.empty_like(stack)

    # Converter from warp's float output back to the input dtype.
    if np.issubdtype(stack.dtype, np.integer):
        limits = np.iinfo(stack.dtype)

        def _to_input_dtype(frame):
            return np.clip(np.rint(frame), limits.min, limits.max).astype(stack.dtype)
    else:
        def _to_input_dtype(frame):
            return frame.astype(stack.dtype)

    # Step 4: Apply the affine transform to each frame
    for i in range(n_frames):
        warped = warp(
            stack[i],
            inverse_map=tform.inverse,  # Use inverse map to achieve forward transform effect
            order=1,                    # Bilinear interpolation (default in MATLAB)
            preserve_range=True,
            output_shape=(height, width)
        )
        registered[i] = _to_input_dtype(warped)

    # Step 5: Sanitize and apply metadata if provided
    image_description = description_template
    if image_description:
        image_description = sanitize_description(image_description)

    # Step 6: Save the registered stack as grayscale TIFF
    imwrite(
        output_path,
        registered,  # shape: (T, Y, X)
        photometric='minisblack',
        description=image_description
    )

    print(f"Chromatic-registered stack saved → {output_path}")

def combine_tif_channels(
    root: str | Path,
    *,
    n_chan: int,
    ij,
    postfix: str = "_drftc",
    delete_intermediate: bool = True,
    description_template: str | None = None,
):
    """
    Combine individual single-channel TIFF stacks into a single multichannel TIFF.

    This function loads grayscale image stacks for each channel and merges them
    using ImageJ's "Merge Channels..." function. The merged result is saved as
    a multichannel hyperstack TIFF with optional metadata embedding.

    Parameters
    ----------
    root : str or Path
        Directory containing the single-channel TIFF files.

    n_chan : int
        Number of channels to merge (currently supports 3 or 4).

    ij : imagej.ImageJ
        An initialized ImageJ instance (from pyimagej).

    postfix : str, default="_drftc"
        Common suffix used in the filenames for channel images.

    delete_intermediate : bool, default=True
        Whether to delete the original single-channel TIFF files after merging.

    description_template : str or None, optional
        Optional metadata string to be embedded into the final TIFF file.

    Returns
    -------
    out_path : Path
        Path to the final merged TIFF file.
    """
    root = Path(root)
    assert n_chan in (3, 4), "Only 3-channel or 4-channel merging is supported."

    # Find base filenames to merge (those with channel 1 files)
    pattern = re.compile(rf"(.+)_chan1{postfix}\.tif$")
    groups = [m.group(1) for m in map(pattern.match, os.listdir(root)) if m]
    if not groups:
        print("No matching TIFFs found for merging.")
        return

    print(f"> Found {len(groups)} merge targets with {n_chan} channels each.")

    # Define lookup table (LUT) order for color assignment
    lut_order_dict = {
        3: ["Grays", "Green", "Red"],
        4: ["Grays", "Blue", "Green", "Red"]
    }

    for base in groups:
        filenames = [f"{base}_chan{c}{postfix}.tif" for c in range(1, n_chan + 1)]
        filenames2delete = filenames.copy()
        save_name = f"{base}{postfix}_reg.tif"

        # Close all ImageJ windows before starting merge
        ij.IJ.run("Close All")
        imps = []

        # Open each channel TIFF and register window
        for f in filenames:
            imp = ij.IJ.openImage(str(root / f))
            if imp is None:
                raise FileNotFoundError(f"File not found: {f}")
            imp.show()
            imps.append(imp)

        titles = ij.WindowManager.getImageTitles()

        # Merge channels using ImageJ GUI command
        merge_args = " ".join([f"c{i+1}={titles[i]}" for i in range(n_chan)]) + " create"
        ij.IJ.run("Merge Channels...", merge_args)
        comp = ij.IJ.getImage()

        # Assign LUTs (color maps) to each channel
        for idx, lut in enumerate(lut_order_dict[n_chan], start=1):
            comp.setPosition(idx)
            ij.IJ.run(comp, lut, "")

        # Correct dimensional metadata if necessary (swap Z and T)
        dims = comp.getDimensions()  # Returns [X, Y, C, Z, T]
        if dims[3] > 1 and dims[4] == 1:
            comp.setDimensions(dims[2], dims[4], dims[3])  # C, T, Z order

        # Extract original metadata from the first channel image
        first_channel_path = root / filenames[0]
        ref_img = ij.IJ.openImage(str(first_channel_path))
        info = ref_img.getProperty("Info")
        ref_img.close()

        if info is None:
            raise ValueError(f"Metadata 'Info' not found in: {first_channel_path}")
        info_str = str(info)

        # Extract number of frames and time interval from metadata
        match_frames = re.search(r"SizeT\s*=\s*(\d+)", info_str)
        match_interval = re.search(r"finterval\s*=\s*([\d\.]+)", info_str)

        if not match_frames:
            raise ValueError("'SizeT' (number of frames) not found in metadata.")

        frames = match_frames.group(1)
        finter = match_interval.group(1) if match_interval else "1.0"  # Default to 1.0 sec
        ij.IJ.run("Properties...", f"channels={n_chan} slices=1 frames={frames} interval={finter}")

        # Apply metadata (ImageDescription)
        image_description = description_template
        if image_description:
            image_description = sanitize_description(image_description)
            comp.setProperty("Info", image_description)

        # Save the merged image
        out_path = root / save_name
        ij.IJ.saveAs(comp, "Tiff", str(out_path))
        comp.close()
        print(f"Saved merged TIFF → {out_path.name}")

        # Optionally delete intermediate single-channel files
        if delete_intermediate:
            for f in filenames2delete:
                (root / f).unlink(missing_ok=True)

    return out_path

In [None]:
def run_precorrection(nd2_path, config, mat_path, ij):
    """
    Run the full drift and chromatic correction pipeline on a multichannel ND2 image.

    This function performs:
    1. Spot detection and tracking using TrackMate (``run_trackmate_nd2``)
    2. A median drift trace from the tracks, saved as ``<name>_drift.txt``
    3. Frame-wise XY drift correction across all channels
    4. Chromatic registration of channel ``config["target_registration_chan"]``
       using the affine transform stored in ``mat_path``. The channel file is
       overwritten in place.
    5. Merging of the corrected channels into one multichannel TIFF

    Outputs go into ``2.Tracks_for_driftc`` and ``3.driftc_reg``. Both are
    created inside the parent of the directory that contains the ND2 file.

    Parameters
    ----------
    nd2_path : str
        Path to the input ND2 image file.

    config : dict
        Dictionary containing all TrackMate and processing parameters.
        It must provide ``n_channels``, ``target_registration_chan`` and every
        TrackMate key forwarded to ``run_trackmate_nd2`` below.

    mat_path : str or Path
        Path to the .mat file containing the 3x3 chromatic affine matrix ('T').

    ij : imagej.ImageJ
        Initialized ImageJ instance.

    Returns
    -------
    tif_path : Path
        ``3.driftc_reg/<name>_drftc_reg.tif``, the drift- and
        chromatic-corrected multichannel TIFF for THIS ND2 file.

    Raises
    ------
    FileNotFoundError
        If the merge step did not produce the expected output file.
    """

    # [0] Prepare output directories and names
    base_dir = os.path.abspath(os.path.join(os.path.dirname(nd2_path), os.pardir))
    track_dir = os.path.join(base_dir, "2.Tracks_for_driftc")
    reg_dir = os.path.join(base_dir, "3.driftc_reg")
    os.makedirs(track_dir, exist_ok=True)
    os.makedirs(reg_dir, exist_ok=True)

    filename_head = os.path.splitext(os.path.basename(nd2_path))[0]
    n_chan = config["n_channels"]
    target_reg_chan = config["target_registration_chan"]

    # [1] Run tracking using TrackMate with given config
    print("[1] Tracking and filtering...")
    xml_path = run_trackmate_nd2(
        nd2_path,
        track_dir,
        target_chan=config["target_chan"],
        spot_radius=config["spot_radius"],
        threshold=config["threshold"],
        subpixel=config["subpixel"],
        median_filter=config["median_filter"],
        linking_max=config["linking_max"],
        gapclosing_max=config["gapclosing_max"],
        max_frame_gap=config["max_frame_gap"],
        allow_split=config["allow_split"],
        allow_merge=config["allow_merge"],
        filter_quality=config["filter_quality"],
        quality_cutoff=config["quality_cutoff"],
        filter_displacement=config["filter_displacement"],
        displacement_cutoff=config["displacement_cutoff"],
        filter_duration=config["filter_duration"],
        duration_cutoff_time=config["duration_cutoff_time"]
    )

    # Load tracking data from the first exported XML.
    # NOTE(review): assumes run_trackmate_nd2 returns a sequence of XML paths.
    df_list = load_tracks_xml(str(xml_path[0]))

    # [2] Compute and save drift trace from tracking results
    print("[2] Drift trace generation...")
    drift_trace = compute_median_drift_from_tracks(df_list)
    drift_txt_path = os.path.join(track_dir, f"{filename_head}_drift.txt")
    save_drift_trace_txt(drift_trace, drift_txt_path)

    # [3] Apply drift correction to all channels.
    # The metadata text and the raw TIFF are read from track_dir, presumably
    # written there by run_trackmate_nd2 (verify against that function).
    print("[3] Drift correction...")
    description_path = os.path.join(track_dir, f"{filename_head}_metadata.txt")
    description_template = None
    if os.path.exists(description_path):
        with open(description_path, "r", encoding="utf-8") as f:
            description_template = f.read()

    raw_stack_path = os.path.join(track_dir, f"{filename_head}.tif")
    corrected_prefix = os.path.join(reg_dir, filename_head)
    apply_drift_correction_to_tiff(
        input_tiff_path=raw_stack_path,
        drift_trace=drift_trace,
        output_prefix=corrected_prefix,
        n_channels=n_chan,
        description_template=description_template
    )

    # [4] Apply chromatic registration using the affine matrix from MATLAB.
    # Input and output are the same file: the target channel is replaced in place.
    print("[4] Chromatic error correction...")
    target_chan_path = os.path.join(reg_dir, f"{filename_head}_chan{target_reg_chan}_drftc.tif")
    apply_chromatic_registration_from_mat(
        target_chan_path,
        target_chan_path,
        mat_path,
        description_template=description_template
    )

    # [5] Merge corrected channels into final multichannel TIFF using ImageJ
    print("[5] Merging channels to multichannel TIF...")
    combine_tif_channels(
        reg_dir,
        n_chan=n_chan,
        ij=ij,
        postfix="_drftc",
        delete_intermediate=True,
        description_template=description_template
    )

    # combine_tif_channels merges every "*_chan1_drftc.tif" group present in
    # reg_dir and returns only the last one it wrote. Resolve this ND2's own
    # output explicitly so leftovers from other runs cannot be returned instead.
    tif_path = Path(reg_dir) / f"{filename_head}_drftc_reg.tif"
    if not tif_path.exists():
        raise FileNotFoundError(f"Merged output not found: {tif_path}")

    print("Full pipeline completed.")

    return tif_path

In [None]:
def get_output_paths_for_g_r_tracks(tif_path: str | Path) -> tuple[Path, Path, Path]:
    """
    Derive output paths for green and red channel tracking results based on the TIFF file path.

    Parameters
    ----------
    tif_path : str or Path
        Path to the registered TIFF file (typically ending with '_drftc_reg.tif').

    Returns
    -------
    delay_dir : Path
        Path to the output directory (4.Tracks_for_TimeDelay).

    xml_g_path : Path
        Output path for TrackMate result XML of the green channel.

    xml_r_path : Path
        Output path for TrackMate result XML of the red channel.
    """
    tif_path = Path(tif_path)
    parent_dir = tif_path.parent.parent  # Navigate two levels up to reach base directory
    delay_dir = parent_dir / "4.Tracks_for_TimeDelay"
    delay_dir.mkdir(exist_ok=True)

    filename_head = tif_path.stem.replace("_drftc_reg", "")
    xml_g_path = delay_dir / f"{filename_head}_drftc_reg_g.xml"
    xml_r_path = delay_dir / f"{filename_head}_drftc_reg_r.xml"

    return delay_dir, xml_g_path, xml_r_path

def run_trackmate_dual_channel_from_config(ij, tif_path, config_green, config_red, frame_interval=2.5):
    """
    Run TrackMate analysis independently on green and red channels of a multichannel TIFF,
    and export the tracking results as XML files.

    The image calibration is overwritten before tracking: pixel size is 1.0
    "pixels" and the time unit is "sec" with ``frame_interval`` between frames.
    XMLs are written to ``4.Tracks_for_TimeDelay``, two levels above
    ``tif_path``, as ``<base>_drftc_reg_g.xml`` and ``<base>_drftc_reg_r.xml``.

    Parameters
    ----------
    ij : imagej.ImageJ
        An initialized ImageJ instance (from pyimagej).

    tif_path : str or Path
        Path to the input multichannel TIFF file (e.g., with '_drftc_reg' suffix).

    config_green : dict
        Configuration dictionary for green channel tracking parameters.

    config_red : dict
        Configuration dictionary for red channel tracking parameters.

    frame_interval : float, default=2.5
        Frame interval in seconds written into the calibration. It is also used
        to derive the automatic duration cutoff when
        ``config["duration_cutoff_time"]`` is NaN or None.

    Returns
    -------
    None

    Raises
    ------
    FileNotFoundError
        If the TIFF cannot be opened.
    RuntimeError
        If TrackMate input validation or processing fails.
    """
    tif_path = Path(tif_path)
    base_dir = tif_path.parent.parent
    output_dir = base_dir / "4.Tracks_for_TimeDelay"
    output_dir.mkdir(parents=True, exist_ok=True)

    base_name = tif_path.stem.replace("_drftc_reg", "")
    xml_g_path = output_dir / f"{base_name}_drftc_reg_g.xml"
    xml_r_path = output_dir / f"{base_name}_drftc_reg_r.xml"

    # Load image in ImageJ
    imp = ij.IJ.openImage(str(tif_path))
    if imp is None:
        raise FileNotFoundError(f"Cannot open image: {tif_path}")
    imp.show()

    try:
        # Standardize calibration metadata
        cal = imp.getCalibration()
        cal.setUnit("pixels")
        cal.pixelWidth = cal.pixelHeight = cal.pixelDepth = 1.0
        cal.setTimeUnit("sec")
        cal.frameInterval = float(frame_interval)
        imp.setCalibration(cal)

        # getDimensions() -> [X, Y, C, Z, T]. If frames sit on Z, move them to
        # T: setDimensions(channels, slices, frames).
        dims = imp.getDimensions()
        if len(dims) >= 5 and dims[3] > 1 and dims[4] == 1:
            imp.setDimensions(dims[2], dims[4], dims[3])

        def run_single_channel(config, xml_out_path):
            """
            Run TrackMate on a single channel using the provided config,
            and export the result to XML.
            """
            model = Model()
            model.setLogger(Logger.IJ_LOGGER)
            model.setPhysicalUnits(cal.getUnit(), cal.getTimeUnit())

            settings = Settings(imp)

            # Set up detector. TrackMate validates RADIUS/THRESHOLD as Double,
            # so cast explicitly in case the config holds ints.
            settings.detectorFactory = LogDetectorFactory()
            det = HashMap()
            det.put("DO_SUBPIXEL_LOCALIZATION", config["subpixel"])
            det.put("RADIUS", float(config["spot_radius"]))
            det.put("TARGET_CHANNEL", Integer(config["target_chan"]))
            det.put("THRESHOLD", float(config["threshold"]))
            det.put("DO_MEDIAN_FILTERING", config["median_filter"])
            settings.detectorSettings = det

            # Spot quality filter (optional)
            if config.get("filter_quality", True):
                settings.addSpotFilter(
                    FeatureFilter("QUALITY", config.get("quality_cutoff", 0.0), True)
                )

            # Set up tracker (distances must be Double as well)
            settings.trackerFactory = SparseLAPTrackerFactory()
            trk = settings.trackerFactory.getDefaultSettings()
            trk.put("MAX_FRAME_GAP", Integer(config["max_frame_gap"]))
            trk.put("LINKING_MAX_DISTANCE", float(config["linking_max"]))
            trk.put("GAP_CLOSING_MAX_DISTANCE", float(config["gapclosing_max"]))
            trk.put("ALLOW_TRACK_SPLITTING", config["allow_split"])
            trk.put("ALLOW_TRACK_MERGING", config["allow_merge"])
            settings.trackerSettings = trk

            # Duration cutoff: NaN/None means "derive from the movie length",
            # i.e. keep tracks lasting at least (nT - 1.5) frame intervals.
            fi = cal.frameInterval
            nT = imp.getDimensions()[4]
            duration_cutoff = config["duration_cutoff_time"]
            if duration_cutoff is None or math.isnan(duration_cutoff):
                duration_cutoff = fi * (nT - 1.5)

            settings.addAllAnalyzers()

            # Optional track filters
            if config.get("filter_displacement", True):
                settings.addTrackFilter(
                    FeatureFilter("TRACK_DISPLACEMENT", config["displacement_cutoff"], True)
                )

            if config.get("filter_duration", True):
                settings.addTrackFilter(
                    FeatureFilter("TRACK_DURATION", duration_cutoff, True)
                )

            if config.get("filter_track_start", True):
                settings.addTrackFilter(
                    FeatureFilter("TRACK_START", config.get("track_start_time", 1), True)
                )

            # Execute TrackMate
            tm = TrackMate(model, settings)
            if not tm.checkInput() or not tm.process():
                raise RuntimeError(tm.getErrorMessage())

            ExportTracksToXML.export(model, settings, File(str(xml_out_path)))
            print(f"TrackMate result saved → {xml_out_path.name}")

        # Run for green channel
        print("TrackMate: Green channel...")
        run_single_channel(config_green, xml_g_path)

        # Run for red channel
        print("TrackMate: Red channel...")
        run_single_channel(config_red, xml_r_path)
    finally:
        # Release the image even if TrackMate fails.
        imp.close()

In [None]:
def subtract_background_and_merge(tif_path, rolling=5, ij=ij):
    """
    Subtract background from channels 2 and 3 of a multichannel TIFF using ImageJ,
    then merge channels 1-3 into a new composite TIFF and save the result.

    Channel 1 is passed through unchanged. The result is written to
    ``5.Background_Corrected/<stem>_back.tif``, two levels above ``tif_path``.

    Parameters
    ----------
    tif_path : str or Path
        Path to the original multichannel TIFF image.

    rolling : int, default=5
        Radius of the rolling ball for background subtraction (in pixels).

    ij : imagej.ImageJ
        An initialized ImageJ instance for executing macro operations.
        NOTE: the default is the module-level ``ij`` captured when this
        function is defined.

    Returns
    -------
    out_path : Path
        Path to the saved background-corrected and merged TIFF file.

    Raises
    ------
    FileNotFoundError
        If the input TIFF cannot be opened.
    RuntimeError
        If any of the split channel images C1-C3 is missing (for example,
        when the input has fewer than three channels).
    """
    tif_path = Path(tif_path)
    tif_name = tif_path.stem

    # Create output directory
    out_dir = tif_path.parent.parent / "5.Background_Corrected"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{tif_name}_back.tif"

    # Load the original multichannel TIFF
    imp = ij.IJ.openImage(str(tif_path))
    if imp is None:
        raise FileNotFoundError(f"Cannot open image: {tif_path}")
    # "Split Channels" names its outputs "C<n>-<title>". Use the real title
    # instead of assuming "<stem>.tif", which fails for e.g. ".tiff" files.
    title = str(imp.getTitle())
    imp.show()

    try:
        # Split the image into separate channels
        ij.IJ.run("Split Channels")

        channel_titles = [f"C{n}-{title}" for n in (1, 2, 3)]
        c1, c2_corr, c3_corr = (WindowManager.getImage(t) for t in channel_titles)
        missing = [t for t, img in zip(channel_titles, (c1, c2_corr, c3_corr)) if img is None]
        if missing:
            raise RuntimeError(f"Split channel image(s) not found: {missing}")

        # Subtract Background modifies the stacks in place, so c2_corr/c3_corr
        # hold the corrected data afterwards.
        for img in (c2_corr, c3_corr):
            ij.IJ.run(img, "Subtract Background...", f"rolling={rolling} stack")

        # Short, space-free titles make the merge option string unambiguous.
        c1.setTitle("C1")
        c2_corr.setTitle("C2_corr")
        c3_corr.setTitle("C3_corr")

        # Merge corrected channels. Slot c4 gives C3_corr ImageJ's gray LUT.
        # With "create" only the supplied images are stacked, so the composite
        # has three channels (C1, C2_corr, C3_corr).
        ij.IJ.run("Merge Channels...", "c1=C1 c2=C2_corr c4=C3_corr create")
        merged = WindowManager.getCurrentImage()

        # NOTE(review): resetDisplayRange() likely supersedes the preceding
        # setMinAndMax for the current channel; kept to preserve output.
        ij.IJ.setMinAndMax(merged, 0, 65535)
        merged.resetDisplayRange()

        # Save the result as a new TIFF file
        ij.IJ.saveAsTiff(merged, str(out_path))
    finally:
        # Clean up all open ImageJ windows, also on failure
        ij.IJ.run("Close All")

    print(f"Saved background-corrected TIFF: {out_path}")
    return out_path

## Final Execution: Full Preprocessing and Tracking Pipeline

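The cells below run the full preprocessing and analysis workflow on a multichannel ND2 file. They expect `nd2_path`, `config`, `mat_path`, `config_green`, and `config_red` to have been defined in earlier cells. The workflow consists of: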

1. Drift correction
2. Chromatic registration
3. Dual-channel TrackMate tracking
4. Background subtraction and re-merging

In [None]:
# Drift + chromatic correction; yields the merged, registered multichannel TIFF.
tif_path = run_precorrection(nd2_path=nd2_path, config=config, mat_path=mat_path, ij=ij)

In [None]:
# Track the green and red channels of the corrected stack independently.
run_trackmate_dual_channel_from_config(
    ij=ij,
    tif_path=tif_path,
    config_green=config_green,
    config_red=config_red,
)

In [None]:
# Background-subtract channels 2 and 3 and save a re-merged copy.
back_path = subtract_background_and_merge(tif_path=tif_path, ij=ij)