In [None]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from tkinter import Tk, filedialog
from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist
from lmfit import Model
import tifffile
from skimage import morphology
import cv2
from tqdm import tqdm
import concurrent.futures
import matplotlib.patches as patches

# Colocalization Analysis and Time-Delay Pipeline

## Function Call Hierarchy

```text
run_full_analysis
├── main_data_loading
│   ├── load_tracks_xml
│   └── tracks_to_dataframe
├── run_main_filtering
│   ├── compute_first_and_last_appearance
│   ├── apply_roi_foi_filter
│   ├── apply_already_activated_filter
│   ├── apply_revived_filter_matlab
│   └── apply_overlapping_filter_matlab
├── find_colocalized_pairs_matlab
│   └── internal_loop  (pairwise distance calculation)
├── main_time_delay_fitting
│   ├── fit_time_delay_exp1decay
│   │   ├── exp1decay_func
│   │   └── compute_r_squared
│   ├── plot_time_delay_exp1decay
│   └── plot_time_delay_exp1decay_inverse_x
├── main_analysis_after_dark_pairs
│   ├── load_multistack
│   ├── analyze_dark_pairs
│   │   └── visualize_dark_pairs
│   └── main_time_delay_fitting  ← called again
└── making_mark_start

marker_df_2
├── patch generation per frame
└── merging patch tables

plot_patch_df_grid_grouped
├── render_single_row
└── center_channel_table

combine_figures_to_one_subplot
└── render_to_numpy
```

---

## Key Variable Summary

| Variable          | Description                                                              |
|-------------------|--------------------------------------------------------------------------|
| `df_g`, `df_r`    | Raw green/red tracking data from XML                                     |
| `df_g_ol`, `df_r_ol` | Final filtered green/red tracks (revived/overlap removed)           |
| `primary_pair_df` | Colocalized green-red spot pairs (DataFrame)                            |
| `df_mark`         | Spot pair table with first-frame info per track                         |
| `multistack`      | 4D numpy array of the TIFF image (x, y, t, channel)                      |
| `patch_df`        | Marker visualization patches with RICM overlay                          |
| `settings`        | Dictionary of all analysis parameters                                   |
| `metadata`        | Metadata from TrackMate XML (units, interval, etc.)                     |

---

## Execution Flow (`run_full_analysis()`)

```text
1. Load XML + TIFF via GUI: main_data_loading()
2. Filter green/red tracks: run_main_filtering()
   - ROI, FOI, activation, revived, overlapping
3. Colocalize tracks: find_colocalized_pairs_matlab()
4. Time delay fitting: main_time_delay_fitting()
5. Segment RICM dark regions: main_analysis_after_dark_pairs()
   - Includes fitting for "on dark" and "not on dark" separately
6. Extract start time of tracks: making_mark_start()
7. (Optional) Extract visualization patch: marker_df_2()
8. (Optional) Visualize: plot_patch_df_grid_grouped() → combine_figures_to_one_subplot()
```

---

## Example Usage

```python
df_mark = run_full_analysis(criInter=0.6, criIntra=0.5, criOverlap=2.5, criBlink=2)

# Optional visualization
patch_df = marker_df_2(multistack, df_mark, df_mark)
figs = plot_patch_df_grid_grouped(patch_df)
combine_figures_to_one_subplot(figs)
```

---

## Output Preview

- `df_mark`: Initial frame (`first_g_T`, `first_r_T`) of each paired track
- `patch_df`: List of visual patches for each track with color-coded marker (red/yellow/blue)
- Time-delay histogram plot (τ, R²)
- Overlay visualization of colocalized tracks on RICM


In [None]:
def get_analysis_settings(criInter=0.6, criIntra=0.5, criOverlap=2.5, criBlink=2):
    """
    Build the parameter dictionary used by the colocalization pipeline.

    The four ``cri*`` arguments are copied into the result unchanged; every
    other entry is a fixed default. A new dictionary (with fresh list/array
    values) is created on every call, so callers may mutate it freely.

    Parameters
    ----------
    criInter : float, default 0.6
        Inter-channel (green/red) distance threshold. In this module,
        ``find_colocalized_pairs_matlab`` accepts a pair only when the
        center-to-center distance is strictly below its ``criInter`` argument.

    criIntra : float, default 0.5
        Intra-channel distance threshold. ``apply_revived_filter_matlab``
        treats its ``cri_intra`` argument as a distance radius.
        NOTE(review): presumably this setting is what is passed there —
        confirm in ``run_main_filtering``.

    criOverlap : float, default 2.5
        Overlap threshold. ``apply_overlapping_filter_matlab`` compares its
        ``cri_overlap`` argument against center-to-center distances, i.e. it is
        a distance there, not a frame count.
        NOTE(review): an earlier comment described this as a minimum overlap
        duration in frames — confirm how the pipeline actually uses it.

    criBlink : int, default 2
        Blink tolerance. Not consumed by any function visible in this section;
        presumably the maximum number of missing frames allowed in a track —
        TODO confirm.

    Returns
    -------
    dict
        Keys, in insertion order: path0, fTest, RICMchannel, nChan,
        frameToOverlay, TimeDelayFitting, imgsize, edge, frameCnt, roi,
        criInter, criIntra, criOverlap, criBlink, interval, lateralOffset,
        criLengthMin1, criLengthMin2, timeHistEdges.
    """
    settings = dict(
        # --- I/O and run mode ---
        path0=r"D:/",          # root directory for input/output files
        fTest=False,           # test/debug flag
        # --- image layout ---
        RICMchannel=1,         # channel index holding the RICM image
        nChan=3,               # number of image channels in the stack
        frameToOverlay=40,     # frame used for overlay figures (presumably)
        TimeDelayFitting=15,   # time-delay fitting window (frames, presumably)
        imgsize=800,           # image width/height in pixels
        edge=100,              # edge margin in pixels
        frameCnt=200,          # number of frames in the movie
        roi=[100, 700, 100, 700],  # [x_min, x_max, y_min, y_max]
        # --- caller-supplied thresholds ---
        criInter=criInter,
        criIntra=criIntra,
        criOverlap=criOverlap,
        criBlink=criBlink,
        # --- timing, offsets and minimum track lengths ---
        # NOTE(review): apply_already_activated_filter takes a parameter of the
        # same name as a minimum start frame — confirm which meaning applies.
        interval=15,
        lateralOffset=[0.0, 0.0],  # (X, Y) shift between channels
        criLengthMin1=2,           # minimum track length, channel 1 (green)
        criLengthMin2=3,           # minimum track length, channel 2 (red)
        # Unit-width bins centred on integers -35 … 35 (72 edges).
        timeHistEdges=np.arange(-35.5, 36.5, 1),
    )
    return settings

def load_tracks_xml(filepath):
    """
    Read tracks and global metadata from a TrackMate-style XML export.

    Every ``<particle>`` element below the root is treated as one track. Its
    ``<detection>`` descendants, in document order, become the rows of that
    track. A detection attribute that is missing is read as 0. Particles
    without any detection are skipped.

    Parameters
    ----------
    filepath : str, path-like or file object
        Anything accepted by ``xml.etree.ElementTree.parse``.

    Returns
    -------
    tracks : list of pandas.DataFrame
        One DataFrame per non-empty particle with float columns
        ``T``, ``X``, ``Y``, ``Z`` taken from the ``t``, ``x``, ``y``, ``z``
        attributes.

    metadata : dict
        Values read from attributes of the root element:
        - spaceUnits    : ``spaceUnits`` attribute (default "pixels")
        - timeUnits     : ``timeUnits`` attribute (default "frames")
        - frameInterval : ``frameInterval`` attribute as float (default 1.0)
        - date          : ``generationDateTime`` attribute (default "")
        - source        : ``from`` attribute (default "")
    """
    root = ET.parse(filepath).getroot()
    root_attrs = root.attrib

    metadata = {
        "spaceUnits": root_attrs.get("spaceUnits", "pixels"),
        "timeUnits": root_attrs.get("timeUnits", "frames"),
        "frameInterval": float(root_attrs.get("frameInterval", 1.0)),
        "date": root_attrs.get("generationDateTime", ""),
        "source": root_attrs.get("from", ""),
    }

    attr_keys = ("t", "x", "y", "z")
    tracks = []
    # ".//particle" = all descendants of the root, excluding the root itself.
    for particle in root.findall(".//particle"):
        # particle.iter("detection") visits the same descendants, in the same
        # order, as particle.findall(".//detection").
        rows = [
            [float(det.attrib.get(key, 0)) for key in attr_keys]
            for det in particle.iter("detection")
        ]
        if rows:
            tracks.append(pd.DataFrame(rows, columns=["T", "X", "Y", "Z"]))

    return tracks, metadata

def tracks_to_dataframe(tracks):
    """
    Concatenate a list of per-track DataFrames into one long-format DataFrame.

    Track IDs are the positions of the tracks in the input list (0, 1, 2, …).
    An empty track consumes an ID but contributes no rows. The ``Z`` column of
    the inputs is ignored.

    Parameters
    ----------
    tracks : list of pandas.DataFrame
        Each DataFrame must provide columns ``T``, ``X`` and ``Y``.

    Returns
    -------
    pandas.DataFrame
        Columns:
        - track_id : int, position of the source track in ``tracks``
        - T        : int64, ``T`` rounded half-to-even (same rule as ``round``)
        - X        : float
        - Y        : float
        An empty input yields an empty DataFrame with these columns.
    """
    columns = ["track_id", "T", "X", "Y"]

    # Vectorised per track instead of iterrows(). np.round rounds half to even,
    # matching Python's round(). The explicit int64 cast guarantees the
    # documented integer T column; round(np.float64) returns a float64 on
    # older NumPy versions.
    parts = [
        pd.DataFrame({
            "track_id": track_id,
            "T": np.round(df["T"].to_numpy(dtype=float)).astype(np.int64),
            "X": df["X"].to_numpy(dtype=float),
            "Y": df["Y"].to_numpy(dtype=float),
        })
        for track_id, df in enumerate(tracks)
        if len(df) > 0
    ]

    if not parts:
        return pd.DataFrame([], columns=columns)

    return pd.concat(parts, ignore_index=True)[columns]

def compute_first_and_last_appearance(df, min_length, method="median"):
    """
    Summarise each sufficiently long track by its time span and center.

    Tracks with fewer than ``min_length`` non-null ``T`` values are discarded.
    For every remaining track, this returns the first and last frame and a
    central X/Y position.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format detections with columns ``track_id``, ``T``, ``X``, ``Y``.

    min_length : int
        Minimum number of non-null ``T`` entries a track must have.

    method : str, default "median"
        ``"mean"`` uses the mean of X/Y; any other value uses the median.

    Returns
    -------
    pandas.DataFrame
        One row per kept track (sorted by ``track_id``), with columns
        ``track_id``, ``T_first`` (min T), ``T_last`` (max T),
        ``X_center`` and ``Y_center``.
    """
    # Only the exact string "mean" switches estimator; everything else is median.
    center_stat = "mean" if method == "mean" else "median"

    points_per_track = df.groupby("track_id")["T"].count()
    long_ids = points_per_track[points_per_track >= min_length].index
    long_tracks = df.loc[df["track_id"].isin(long_ids)]

    summary = long_tracks.groupby("track_id").agg(
        T_first=("T", "min"),
        T_last=("T", "max"),
        X_center=("X", center_stat),
        Y_center=("Y", center_stat),
    )
    return summary.reset_index()

def apply_roi_foi_filter(df_fl: pd.DataFrame,
                         t_min: int,
                         t_max: int,
                         roi: list[int]) -> pd.DataFrame:
    """
    Keep tracks whose first frame and center lie inside the given windows.

    A row is kept when all of these hold (all bounds inclusive):
        t_min  <= T_first  <= t_max
        roi[0] <= X_center <= roi[1]
        roi[2] <= Y_center <= roi[3]
    Rows with NaN in any of these columns are dropped.

    Parameters
    ----------
    df_fl : pandas.DataFrame
        One row per track with columns ``T_first``, ``X_center``, ``Y_center``.
    t_min, t_max : int
        Inclusive frame-of-interest (FOI) bounds for ``T_first``.
    roi : list of int
        ``[x_min, x_max, y_min, y_max]``; only the first four entries are read.

    Returns
    -------
    pandas.DataFrame
        The surviving rows with a fresh 0..n-1 index.
    """
    x_lo, x_hi, y_lo, y_hi = roi[0], roi[1], roi[2], roi[3]

    in_foi = df_fl["T_first"].between(t_min, t_max)
    in_roi = df_fl["X_center"].between(x_lo, x_hi) & df_fl["Y_center"].between(y_lo, y_hi)

    return df_fl.loc[in_foi & in_roi].reset_index(drop=True)

def apply_already_activated_filter(df_g: pd.DataFrame,
                                   df_r: pd.DataFrame,
                                   interval: int):
    """
    Drop tracks that start too early in each channel.

    Green tracks must start at ``T_first >= interval``. Red tracks only need
    ``T_first >= 1``, which removes tracks already present at frame 0.

    Parameters
    ----------
    df_g : pandas.DataFrame
        Green-channel tracks; must contain ``T_first``.
    df_r : pandas.DataFrame
        Red-channel tracks; must contain ``T_first``.
    interval : int
        Minimum ``T_first`` for green tracks (compared in the same units as
        ``T_first``).

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        Filtered green and red tables, each with a fresh 0..n-1 index.
    """
    green_ok = df_g["T_first"].ge(interval)
    red_ok = df_r["T_first"].ge(1)

    green_kept = df_g.loc[green_ok].reset_index(drop=True)
    red_kept = df_r.loc[red_ok].reset_index(drop=True)
    return green_kept, red_kept

def apply_revived_filter_matlab(df_first: pd.DataFrame, cri_intra: float) -> pd.DataFrame:
    """
    Remove "revived" (spatially redundant) tracks with a lower-triangle rule.

    For every pair of row positions (i, j) with i > j, row i is dropped when
    both of these hold:
        • 0 < distance(center_i, center_j) < cri_intra, and
        • T_first[i] >= T_first[j].
    Row j is never dropped by the pair (i, j). A row is dropped if any such
    pair condemns it.

    Consequences of this rule, kept deliberately for parity with the
    MATLAB-style loop this function mirrors:
      - When two rows have equal ``T_first``, the higher-index row is dropped.
      - When the higher-index row starts *earlier* than the lower-index row,
        that pair removes nothing. The result is therefore not guaranteed to
        keep only the earliest track in a neighborhood.
      - Rows whose centers coincide exactly (distance == 0) never condemn each
        other, because of the ``> 0`` test.
    NOTE(review): confirm the last two points match the original MATLAB
    source; if not, the rule (not just these docs) needs changing.

    Parameters
    ----------
    df_first : pandas.DataFrame
        One row per track with columns ``T_first``, ``X_center``, ``Y_center``.
        "Index" above means row position, not ``track_id``.
    cri_intra : float
        Distance threshold (strict ``<``), in the same units as the centers.

    Returns
    -------
    pandas.DataFrame
        The surviving rows, in their original order, with a fresh 0..n-1 index.

    Notes
    -----
    Builds the full N x N distance matrix with ``cdist`` (O(N^2) memory).
    """
    # With fewer than two rows there are no pairs to compare.
    if len(df_first) < 2:
        return df_first.reset_index(drop=True)

    coords = df_first[["X_center", "Y_center"]].to_numpy(float)
    times = df_first["T_first"].to_numpy()

    dist_mat = cdist(coords, coords)

    # close[i, j] is True only for i > j (strict lower triangle) with
    # 0 < distance < cri_intra.
    close = np.tril((dist_mat < cri_intra) & (dist_mat > 0), k=-1)

    # later_or_same[i, j]: row i starts no earlier than row j.
    later_or_same = times[:, None] >= times[None, :]

    # Row i is dropped if it is the later-or-same member of any close pair.
    drop = (close & later_or_same).any(axis=1)

    return df_first.loc[~drop].reset_index(drop=True)

def apply_overlapping_filter_matlab(df: pd.DataFrame, cri_overlap: float) -> pd.DataFrame:
    """
    Remove every row that is too close to some other row (both members go).

    For each pair of row positions (i, j) with i > j and
    ``0 < distance(center_i, center_j) < cri_overlap``, both i and j are
    discarded. Unlike the revived filter, neither member of a close pair survives.

    Exactly coincident centers (distance == 0) are not treated as
    overlapping, because of the strict ``> 0`` test.
    NOTE(review): confirm that this matches the intended MATLAB behaviour.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``X_center`` and ``Y_center``.
    cri_overlap : float
        Distance threshold (strict ``<``), in the same units as the centers.

    Returns
    -------
    pandas.DataFrame
        Rows not involved in any close pair, in original order, with a fresh
        0..n-1 index.

    Notes
    -----
    Uses a full N x N ``cdist`` matrix (O(N^2) memory).
    """
    xy = df[["X_center", "Y_center"]].to_numpy(float)
    pairwise = cdist(xy, xy)

    # Strict lower triangle: each unordered pair is represented once, at
    # (row=i, col=j) with i > j.
    too_close = np.tril((pairwise < cri_overlap) & (pairwise > 0), k=-1)

    # A row is involved if it appears as either the i (row) or the j (column)
    # member of any close pair.
    involved = too_close.any(axis=1) | too_close.any(axis=0)

    return df.loc[~involved].reset_index(drop=True)

def find_colocalized_pairs_matlab(df1: pd.DataFrame,
                                   df2: pd.DataFrame,
                                   criInter: float,
                                   valid_rows: np.ndarray | None = None
                                   ) -> pd.DataFrame:
    """
    Find colocalized spot pairs using MATLAB-style for-loop logic.

    This function mimics the MATLAB colocalization loop exactly:
        • For each green spot (df1), compute Euclidean distance to all red spots (df2)
        • If the minimum distance is less than criInter, save the pair
        • One red spot can be matched to multiple green spots (duplicates allowed)

    Parameters
    ----------
    df1 : pandas.DataFrame
        DataFrame of reference spots (e.g., green channel).
        Must include columns:
        - "track_id"
        - "T_first"
        - "X_center"
        - "Y_center"

    df2 : pandas.DataFrame
        DataFrame of comparison spots (e.g., red channel).
        Same required columns as df1.

    criInter : float
        Maximum distance (in pixels or microns) allowed to consider a spot pair colocalized.

    valid_rows : np.ndarray of bool, optional
        Boolean array (same length as df1) indicating which rows are eligible for matching.
        If None, all rows in df1 are considered valid.

    Returns
    -------
    pandas.DataFrame
        DataFrame of matched colocalized pairs with the following columns:
        - time_delay : T_green - T_red
        - distance   : Euclidean distance between spot centers
        - ref_id     : track_id from df1 (green)
        - ref_T      : T_first from df1
        - ref_X      : X_center from df1
        - ref_Y      : Y_center from df1
        - cmp_id     : track_id from df2 (red)
        - cmp_T      : T_first from df2
        - cmp_X      : X_center from df2
        - cmp_Y      : Y_center from df2
    """

    # Extract coordinate and time arrays
    coords1 = df1[["X_center", "Y_center"]].to_numpy(float)
    coords2 = df2[["X_center", "Y_center"]].to_numpy(float)
    times1  = df1["T_first"].to_numpy()
    times2  = df2["T_first"].to_numpy()

    # Initialize mask for valid rows if not provided
    if valid_rows is None:
        valid_rows = np.ones(len(df1), dtype=bool)
    else:
        valid_rows = np.asarray(valid_rows, dtype=bool)

    pair_rec = []

    # MATLAB-style loop: for i = find(valid_rows)'
    for i in np.where(valid_rows)[0]:
        # Compute distance to all spots in df2
        diff = coords2 - coords1[i]           # Shape: (N2, 2)
        distlist = np.sqrt(np.sum(diff**2, axis=1))

        # Find the nearest red spot
        id = int(distlist.argmin())
        minDist = float(distlist[id])

        # If below colocalization threshold, record the pair
        if minDist < criInter:
            pair_rec.append([
                float(times1[i] - times2[id]),              # Time delay
                minDist,                                     # Distance
                int(df1.iloc[i]["track_id"]),               # Reference (green) track ID
                float(times1[i]), float(coords1[i, 0]), float(coords1[i, 1]),
                int(df2.iloc[id]["track_id"]),              # Comparison (red) track ID
                float(times2[id]), float(coords2[id, 0]), float(coords2[id, 1]),
            ])

    cols = [
        "time_delay", "distance",
        "ref_id", "ref_T", "ref_X", "ref_Y",
        "cmp_id", "cmp_T", "cmp_X", "cmp_Y"
    ]

    return pd.DataFrame(pair_rec, columns=cols)

def exp1decay_func(x, a, b, c):
    """
    Evaluate a single-exponential decay ``f(x) = (a - c) * exp(-x / b) + c``.

    Parameters
    ----------
    x : float or numpy.ndarray
        Evaluation point(s); must support unary minus (a plain list does not).
    a : float
        Value at x = 0, since f(0) = a.
    b : float
        Time constant, in the same units as ``x``.
    c : float
        Baseline approached as x grows, when b > 0.

    Returns
    -------
    float or numpy.ndarray
        f evaluated element-wise at ``x``.
    """
    amplitude = a - c
    decay = np.exp(-x / b)
    return c + amplitude * decay

def compute_r_squared(y_true, y_pred):
    """
    Compute the coefficient of determination R² = 1 - SS_res / SS_tot.

        SS_res = Σ (y_true - y_pred)²
        SS_tot = Σ (y_true - mean(y_true))²

    Parameters
    ----------
    y_true : array-like
        Observed values (lists, tuples and arrays are all accepted).
    y_pred : array-like
        Model predictions; must broadcast against ``y_true``.

    Returns
    -------
    float
        R², at most 1.0 (1.0 = perfect fit, negative = worse than the mean).
        ``np.nan`` when SS_tot is treated as zero, i.e. ``np.allclose(SS_tot, 0)``.
        That check uses NumPy's default absolute tolerance of 1e-8, so data
        with a very small overall scale can also yield ``np.nan``.
    """
    # Convert first: the docstring promises array-like input, and plain
    # Python lists do not support element-wise subtraction.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)

    if np.allclose(ss_tot, 0):
        return np.nan  # undefined when y_true has (numerically) no variance

    return 1 - (ss_res / ss_tot)

def fit_time_delay_exp1decay(xdata, ydata, frameInterval=1.0, min_points=6):
    """
    Fit ``y = (a - c) * exp(-x / b) + c`` to (xdata, ydata) with lmfit.

    Parameters
    ----------
    xdata : array-like
        Independent variable (e.g. time delay in frames).
    ydata : array-like
        Observed values (e.g. histogram counts).
    frameInterval : float, default 1.0
        Multiplier applied to the fitted ``b`` to obtain ``tau``. Presumably
        seconds per frame — confirm with the caller.
    min_points : int, default 6
        If either input has fewer points than this, no fit is attempted.

    Returns
    -------
    (result, tau, r_squared)
        result    : lmfit ModelResult
        tau       : fitted ``b`` * ``frameInterval``
        r_squared : ``compute_r_squared(ydata, result.eval(x=xdata))``
        ``(None, None, None)`` when there are too few points (a message is
        printed) or when any exception is raised while fitting (the error is
        printed).

    Notes
    -----
    Starting values: ``a = ydata[0]`` and ``c = 0``. For ``b``, take the
    position in ``ydata`` (counting from index 1) whose value is closest to
    ``ydata[0] / 2``, then add 1.
    NOTE(review): this starting ``b`` is in sample-index units, which equals
    x units only when ``xdata`` is spaced by 1.
    """
    x_arr = np.asarray(xdata)
    y_arr = np.asarray(ydata)

    if min(len(x_arr), len(y_arr)) < min_points:
        print(f"Insufficient data: xData={x_arr}, yData={y_arr}")
        return None, None, None

    try:
        start_value = y_arr[0]
        # Position (in the full array) nearest to half the starting value.
        half_pos = np.argmin(np.abs(y_arr[1:] - 0.5 * start_value)) + 1

        decay_model = Model(exp1decay_func)
        initial = decay_model.make_params(a=start_value, b=half_pos + 1, c=0)
        fitted = decay_model.fit(y_arr, initial, x=x_arr)

        tau = fitted.params["b"].value * frameInterval  # frames -> time units
        r_squared = compute_r_squared(y_arr, fitted.eval(x=x_arr))
    except Exception as err:  # best-effort: report and signal failure
        print("Fitting failed:", str(err))
        return None, None, None

    return fitted, tau, r_squared

def plot_time_delay_exp1decay(bin_centers, hist_vals, xData, yData, fit_result, time_delay, r_squared, time_unit="sec"):
    """
    Plot a time-delay histogram and, when available, its exponential fit.

    Draws ``hist_vals`` as bars (width 1.0) at ``bin_centers`` and a dashed
    vertical line at x = 0. When ``fit_result`` is not None, the fitted curve
    is evaluated at 100 evenly spaced points over ``[min(xData), max(xData)]``.
    The figure is shown with ``plt.show()``; nothing is returned.

    Parameters
    ----------
    bin_centers : array-like
        Histogram bin centers (x axis is labelled in frames).
    hist_vals : array-like
        Bar heights for ``bin_centers``.
    xData : array-like
        Defines the x-range over which the fitted curve is drawn.
    yData : array-like
        Not used by this function; kept for signature compatibility.
    fit_result : object or None
        Fitted model exposing ``eval(x=...)`` (e.g. an lmfit ModelResult), or
        None to skip the curve.
    time_delay : float or None
        τ shown in the legend; rendered as "N/A" if None.
    r_squared : float or None
        R² shown in the legend and title; rendered as "N/A" if None.
    time_unit : str, default "sec"
        Unit label appended to τ in the legend.
    """
    # Pre-format R² so a None value renders as "N/A" rather than raising
    # TypeError in a ':.3f' format spec.
    r2_text = "N/A" if r_squared is None else f"{r_squared:.3f}"

    plt.figure(figsize=(6, 4))
    plt.bar(bin_centers, hist_vals, width=1.0, alpha=0.5, label="histogram")
    plt.axvline(0, linestyle="--", color="blue", label="x=0")

    if fit_result is not None:
        tau_text = "N/A" if time_delay is None else f"{time_delay:.2f}"
        x_fit = np.linspace(min(xData), max(xData), 100)
        y_fit = fit_result.eval(x=x_fit)
        plt.plot(x_fit, y_fit, "r-", label=f"Exp Fit (τ={tau_text} {time_unit}, R²={r2_text})")

    plt.xlabel("Time delay (frame)")
    plt.ylabel("Count")
    plt.title(f"Exp Fit (R²={r2_text})")
    plt.legend()
    plt.show()

def plot_time_delay_exp1decay_inverse_x(bin_centers, hist_vals, xData, yData, fit_result, time_delay, r_squared, time_unit="sec"):
    """
    Plot the time-delay histogram and exponential fit on an inverse x-axis (1/x).

    Short delays map to large |1/x| values, which spreads out the bins close
    to x = 0.

    Parameters
    ----------
    bin_centers : array-like
        Center values of the time-delay histogram bins (x-axis, in frame units).

    hist_vals : array-like
        Histogram counts corresponding to bin_centers.

    xData : array-like
        x values used for fitting (same units as bin_centers). Only used to
        span the range of the fitted curve.

    yData : array-like
        y values used for fitting (kept for signature symmetry; not plotted).

    fit_result : lmfit.model.ModelResult or None
        Result of fitting an exponential decay model. If None, no fit curve is shown.

    time_delay : float or None
        Fitted decay time constant τ, shown in the legend (in units of time_unit).

    r_squared : float or None
        R² goodness of fit. Shown in title and legend; "N/A" when None.

    time_unit : str, default="sec"
        Unit label for τ (e.g. 'sec', 'ms').

    Notes
    -----
    - Bins centred exactly at x=0 are dropped before inverting (1/0 is undefined).
    - Negative delays map to negative 1/x values, so both signs are shown.
    - The fit curve is evaluated on the original x grid and drawn against 1/x.
    - A dashed vertical line at 1/x = 0 is drawn for reference.
    - A fit with a missing τ or R² is still drawn; the missing value reads "N/A".
    """
    # Coerce to arrays so list inputs also support the boolean mask below.
    bin_centers = np.asarray(bin_centers, dtype=float)
    hist_vals = np.asarray(hist_vals)

    plt.figure(figsize=(6, 4))

    # Exclude x=0 to avoid divide-by-zero
    nonzero_mask = bin_centers != 0
    inv_bin_centers = 1 / bin_centers[nonzero_mask]
    inv_hist_vals = hist_vals[nonzero_mask]

    # Plot histogram on inverted x-axis
    plt.bar(inv_bin_centers, inv_hist_vals, width=0.05, alpha=0.5, label="histogram")
    plt.axvline(0, linestyle="--", color="blue", label="x=0")

    # Plot the fitted exponential curve, if available
    if fit_result is not None:
        x_fit = np.linspace(min(xData), max(xData), 100)
        x_fit = x_fit[x_fit != 0]  # Exclude zero
        y_fit = fit_result.eval(x=x_fit)
        # Format defensively: a fit object may exist while τ or R² is None,
        # and ":.2f"/":.3f" on None would raise TypeError.
        tau_txt = f"{time_delay:.2f}" if time_delay is not None else "N/A"
        r2_txt = f"{r_squared:.3f}" if r_squared is not None else "N/A"
        plt.plot(1 / x_fit, y_fit, "r-", label=f"Exp Fit (τ={tau_txt} {time_unit}, R²={r2_txt})")

    # Axis labels and title
    plt.xlabel("1 / Time delay (1/frame)")
    plt.ylabel("Count")

    if r_squared is not None:
        plt.title(f"Exp Fit (R²={r_squared:.3f})")
    else:
        plt.title("Exp Fit (R²=N/A)")

    plt.legend()
    plt.tight_layout()
    plt.show()

def load_multistack(imagefile, nChan=None):
    """
    Read a 4D TIFF stack and return it in MATLAB-style axis order.

    The file must be laid out as (frames, channels, height, width), as written
    by ImageJ / BioFormats multichannel exports. The returned view is reordered
    to (width, height, frames, channels), i.e. (nx, ny, frameCnt, nChan).

    Parameters
    ----------
    imagefile : str
        Path to the TIFF file.

    nChan : int or None, optional
        Expected number of channels. The file is authoritative: if this value
        disagrees with the file, a warning is printed and the file's channel
        count is used. None skips the check.

    Returns
    -------
    multistack : np.ndarray
        Array of shape (nx, ny, frameCnt, nChan), i.e.
        (width, height, time, channels).

    Raises
    ------
    ValueError
        If the TIFF does not decode to a 4D array.
    """
    raw = tifffile.imread(imagefile)

    if raw.ndim != 4:
        raise ValueError(
            f"Expected a 4D TIFF (frames, channels, height, width). "
            f"Got shape {raw.shape}."
        )

    n_frames, n_channels, height, width = raw.shape
    print(f"Loaded TIFF shape: (frames={n_frames}, channels={n_channels}, "
          f"height={height}, width={width})")
    print(f"stack0 dtype: {raw.dtype}")

    # A user-supplied channel count only triggers a warning; the file wins.
    if nChan is not None and nChan != n_channels:
        print(f"Warning: The file has {n_channels} channels, but nChan={nChan} was given. "
              f"Using {n_channels} from the file.")

    # (frames, channels, height, width) -> (width, height, frames, channels)
    multistack = raw.transpose((3, 2, 0, 1))

    print(f"Final multistack shape: {multistack.shape} "
          "(nx, ny, frames, channels)")

    return multistack

def _dark_region_mask(frame_img, kernel):
    """
    Segment dark regions in a single 2D RICM frame.

    Steps: min-max normalise to uint8, invert, white top-hat with ``kernel``,
    invert back, Otsu threshold, invert, 3x3 median blur, dilate with
    ``kernel``, then fill holes smaller than 1000 pixels.

    Returns a boolean array with the same 2D shape as ``frame_img``.
    """
    img = cv2.normalize(frame_img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    inv = cv2.bitwise_not(img)
    top_hat = cv2.morphologyEx(inv, cv2.MORPH_TOPHAT, kernel)
    th_comp = cv2.bitwise_not(top_hat)

    _, bw = cv2.threshold(th_comp, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    bw = cv2.bitwise_not(bw)
    bw = cv2.medianBlur(bw, 3)
    bw = cv2.dilate(bw, kernel)
    return morphology.remove_small_holes(bw > 0, area_threshold=1000)


def analyze_dark_pairs(multistack, pair, settings, frameToShow=100, visualize=True, debug=False):
    """
    Split colocalized spot pairs by whether both spots fall into RICM 'dark' regions.

    Steps performed:
        1. Build a binary dark-region mask for every frame of the RICM channel
           (see ``_dark_region_mask``).
        2. Mark a pair as on-dark when the green spot AND the red spot each hit
           the mask at their own frame or an adjacent one (frame ±1).
        3. Optionally visualize one frame.
        4. Return the on-dark / not-on-dark subsets and the mask stack.

    Parameters
    ----------
    multistack : np.ndarray
        4D image stack with shape (nx, ny, frameCnt, nChan).

    pair : pd.DataFrame or np.ndarray
        Colocalized spot pairs, read by column POSITION:
        columns 3, 4, 5 = green (T, X, Y) and columns 7, 8, 9 = red (T, X, Y).
        Coordinates are rounded to the nearest integer and used directly as
        (x, y) indices into the mask.

    settings : dict
        Must contain "RICMchannel" (1-based index of the RICM channel).

    frameToShow : int, default=100
        Frame to visualize. Values past the last frame are clamped to the last frame.

    visualize : bool, default=True
        Whether to display the RICM frame, its mask and the pair overlay.

    debug : bool, default=False
        If True, print pairs where the green spot's exact-frame mask value
        differs from the pair's classification. Differences are expected
        whenever the ±1-frame window or the red spot decided the outcome.

    Returns
    -------
    pairOnDark : pd.DataFrame or np.ndarray
        Pairs whose green and red spots both fall on dark regions (±1 frame).
        Same type as ``pair``. DataFrames keep their column dtypes and get a
        fresh 0..n-1 index.

    pairNotOnDark : pd.DataFrame or np.ndarray
        The remaining pairs, same conventions as ``pairOnDark``.

    maskStack : np.ndarray (bool)
        Dark-region mask stack, shape (nx, ny, frameCnt).
    """
    is_df = isinstance(pair, pd.DataFrame)
    # NumPy view used only for positional coordinate look-ups.
    pair_np = pair.to_numpy() if is_df else np.asarray(pair)

    nx, ny, frameCnt, _ = multistack.shape
    RICMchannel = settings["RICMchannel"] - 1  # settings use 1-based channel numbers

    # 1) Dark-region mask for every frame of the RICM channel
    ricmStack = multistack[:, :, :, RICMchannel]
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21, 21))
    maskStack = np.zeros((nx, ny, frameCnt), dtype=bool)
    for fr in range(frameCnt):
        maskStack[:, :, fr] = _dark_region_mask(ricmStack[:, :, fr], kernel)

    def in_mask(x, y, f):
        """True when (x, y, f) lies inside the stack and on a dark pixel."""
        if not (0 <= x < nx and 0 <= y < ny and 0 <= f < frameCnt):
            return False
        return bool(maskStack[x, y, f])

    def spot_on_dark(f, x, y):
        """True when the spot hits the mask at frame f-1, f or f+1."""
        return any(in_mask(x, y, f + dt) for dt in (-1, 0, 1))

    # 2) A pair is on-dark only if both the green and the red spot are
    pairCnt = pair_np.shape[0]
    onDark = np.zeros(pairCnt, dtype=bool)
    for i in range(pairCnt):
        fr_g, xg, yg = (int(round(v)) for v in pair_np[i, 3:6])
        fr_r, xr, yr = (int(round(v)) for v in pair_np[i, 7:10])
        onDark[i] = spot_on_dark(fr_g, xg, yg) and spot_on_dark(fr_r, xr, yr)

    # 3) Split. DataFrames are sliced from the original so per-column dtypes
    #    (e.g. integer IDs) are not coerced through a common NumPy dtype.
    if is_df:
        pairOnDark = pair.iloc[onDark].reset_index(drop=True)
        pairNotOnDark = pair.iloc[~onDark].reset_index(drop=True)
    else:
        pairOnDark = pair_np[onDark]
        pairNotOnDark = pair_np[~onDark]

    # Optional debugging aid. It is bounds-safe, so out-of-image coordinates
    # count as "not dark" instead of raising IndexError.
    if debug:
        for i in range(pairCnt):
            x, y, frame = int(pair_np[i, 4]), int(pair_np[i, 5]), int(pair_np[i, 3])
            g_exact = in_mask(x, y, frame)
            if g_exact != onDark[i]:
                print(f"Mismatch at index {i}: maskStack={g_exact}, onDark={onDark[i]}")

    # 4) Visualization (clamp so the default frame also works on short stacks)
    if visualize and frameCnt > 0:
        shown = frameToShow
        if shown >= frameCnt:
            shown = frameCnt - 1
            print(f"frameToShow={frameToShow} exceeds the last frame; showing frame {shown}.")
        visualize_dark_pairs(ricmStack, maskStack, pair_np, onDark, shown)

    return pairOnDark, pairNotOnDark, maskStack

def visualize_dark_pairs(ricmStack, maskStack, pair, onDark, frameToShow):
    """
    Show one RICM frame, its dark-region mask, and the pair positions side by side.

    Panels (left to right): raw RICM frame, binary dark mask, and the RICM
    frame with the green-spot coordinates of every pair scattered on top.
    On-dark pairs are drawn red and the others blue. The stacks are stored as
    (nx, ny, frames), so each frame is transposed before display.

    Parameters
    ----------
    ricmStack : np.ndarray
        RICM image stack, shape (nx, ny, frameCnt).

    maskStack : np.ndarray (bool)
        Dark-region mask stack, same shape as ``ricmStack``.

    pair : np.ndarray
        Pair table indexed by position; columns 4 and 5 are used as the
        X and Y of the green (reference) spot.

    onDark : np.ndarray (bool)
        Per-pair flag, same length as ``pair``.

    frameToShow : int
        Index of the frame to display.

    Returns
    -------
    None
    """
    fig, (ax_img, ax_mask, ax_overlay) = plt.subplots(1, 3, figsize=(15, 5))

    ricm_frame = ricmStack[:, :, frameToShow].T

    # Panel 1: raw RICM frame
    ax_img.imshow(ricm_frame, cmap='gray', origin='lower')
    ax_img.set_title(f"RICM Image (Frame {frameToShow})")
    ax_img.axis("off")

    # Panel 2: dark-region mask
    ax_mask.imshow(maskStack[:, :, frameToShow].T, cmap='binary', origin='lower')
    ax_mask.set_title(f"Dark Region Mask (Frame {frameToShow})")
    ax_mask.axis("off")

    # Panel 3: RICM frame with pair positions
    ax_overlay.imshow(ricm_frame, cmap='gray', origin='lower')
    groups = (
        (onDark, 'red', "On Dark"),
        (~onDark, 'blue', "Not On Dark"),
    )
    for selection, colour, caption in groups:
        ax_overlay.scatter(pair[selection, 4], pair[selection, 5], c=colour, s=10, label=caption)

    ax_overlay.set_title("Pairs (Red: On Dark, Blue: Not On Dark)")
    ax_overlay.axis("off")
    ax_overlay.legend()

    plt.tight_layout()
    plt.show()

def main_data_loading(criInter=0.6, criIntra=0.6, criOverlap=2.5, criBlink=2.0):
    """
    Load the image path, TrackMate tracks and metadata for the analysis.

    Steps:
    1. Build analysis settings from the given filter thresholds.
    2. Change the working directory to ``settings["path0"]`` if it exists.
    3. Ask the user for a TIFF image file (used later for overlays).
    4. Look for ``<image>_g.xml`` / ``<image>_r.xml`` next to the image;
       if either is missing, ask the user for both XML files.
    5. Parse both XML files and flatten them into DataFrames.
    6. Print key metadata (space/time units, frame interval).

    Parameters
    ----------
    criInter : float
        Maximum allowed distance between green/red tracks for colocalization.

    criIntra : float
        Radius threshold for intra-channel duplicate removal.

    criOverlap : float
        Distance threshold for excluding spatially overlapping tracks.

    criBlink : float
        Maximum tolerated blinking gap during spot tracking.

    Returns
    -------
    tuple of 7 items
        ``(df_g, df_r, imagefile, pathImgFile, filenamehead, metadata_g, settings)``:

        - df_g, df_r : pd.DataFrame -- green (reference) / red (comparison) detections.
        - imagefile : str -- full path to the selected TIFF image.
        - pathImgFile : str -- directory containing the TIFF image.
        - filenamehead : str -- image filename without extension.
        - metadata_g : dict -- metadata parsed from the green-channel XML.
        - settings : dict -- all analysis parameters.

        On every early exit (``settings["fTest"]`` set, or a dialog cancelled)
        the first six items are None and only ``settings`` is filled. The tuple
        length is always 7, so callers can unpack it unconditionally.
    """
    # Load default analysis settings
    settings = get_analysis_settings(criInter, criIntra, criOverlap, criBlink)
    path0 = settings["path0"]

    # Same arity as the success return so callers can always unpack 7 values.
    aborted = (None, None, None, None, None, None, settings)

    # Only chdir into an existing directory; os.chdir would raise otherwise.
    if os.path.isdir(path0):
        os.chdir(path0)
    else:
        print(f"Warning: path0 does not exist: {path0}")

    # Early exit if testing mode is enabled
    if settings["fTest"]:
        print("Skipping data loading due to fTest=True.")
        return aborted

    # Hidden Tk root hosting the file dialogs; destroyed in `finally` so each
    # call does not leave a stray root window behind.
    root = Tk()
    root.withdraw()
    try:
        imagefile = filedialog.askopenfilename(
            parent=root,
            title="Open an image file to overlay",
            filetypes=[("TIFF files", "*.tif")],
        )
        if not imagefile:
            print("No image selected. Abort.")
            return aborted

        pathImgFile, imageName = os.path.split(imagefile)
        filenamehead, _ = os.path.splitext(imageName)

        # Expected XML filenames based on TIFF basename
        trjfile1 = os.path.join(pathImgFile, f"{filenamehead}_g.xml")
        trjfile2 = os.path.join(pathImgFile, f"{filenamehead}_r.xml")
        tracks_found = os.path.isfile(trjfile1) and os.path.isfile(trjfile2)

        # If not found, manually prompt user for XML files
        if not tracks_found:
            print("Select reference track file (g) & compare track file (r).")
            trjfile1 = filedialog.askopenfilename(
                parent=root,
                title="Choose a reference track file (high force)",
                filetypes=[("XML files", "*.xml")])
            trjfile2 = filedialog.askopenfilename(
                parent=root,
                title="Choose a track file to be compared (low force)",
                filetypes=[("XML files", "*.xml")])
            if not trjfile1 or not trjfile2:
                print("Canceled track file selection.")
                return aborted
    finally:
        root.destroy()

    if tracks_found:
        os.chdir(pathImgFile)
    else:
        os.chdir(os.path.dirname(trjfile1))

    # Load XML files
    print("Loading TrackMate xml files...")
    t0 = time.time()
    tracks1, metadata_g = load_tracks_xml(trjfile1)
    tracks2, _ = load_tracks_xml(trjfile2)

    # Convert track lists to flat DataFrames
    df_g = tracks_to_dataframe(tracks1)
    df_r = tracks_to_dataframe(tracks2)

    elapsed = time.time() - t0
    print(f"XML load done in {elapsed:.2f}s")
    print("Reference (g) tracks:", df_g["track_id"].nunique())
    print("Compare (r) tracks:", df_r["track_id"].nunique())

    # Print key metadata
    print("spaceUnits:", metadata_g["spaceUnits"])
    print("timeUnits:", metadata_g["timeUnits"])
    frameInterval = round(metadata_g["frameInterval"], 2)
    print("frameInterval:", frameInterval)

    return df_g, df_r, imagefile, pathImgFile, filenamehead, metadata_g, settings

def run_main_filtering(df_g, df_r, settings):
    """
    Apply the spatial and temporal filter chain to green/red track data.

    Steps:
    1. Per-track first/last appearance summary with a minimum-length filter
       (green uses method="mean", red uses method="median").
    2. ROI (spatial) and FOI (temporal) filtering.
    3. Removal of tracks that were already active too early.
    4. Intra-channel "revived" spot suppression (radius ``criIntra``).
    5. Intra-channel overlap exclusion (distance ``criOverlap``).
    6. Chromatic lateral offset: ``settings["lateralOffset"]`` (x, y) is added
       to the red-channel centres. It is always applied; a zero offset is a no-op.

    Parameters
    ----------
    df_g : pd.DataFrame
        Green channel track detections (columns: track_id, T, X, Y).

    df_r : pd.DataFrame
        Red channel track detections (columns: track_id, T, X, Y).

    settings : dict
        Filtering thresholds and metadata from ``get_analysis_settings()``.
        Keys read here: criLengthMin1, criLengthMin2, frameCnt, interval, roi,
        criIntra, criOverlap, lateralOffset.

    Returns
    -------
    df_g_ol : pd.DataFrame
        Final filtered green-channel track summary (one row per track).

    df_r_ol : pd.DataFrame
        Final filtered red-channel track summary (one row per track), with the
        lateral offset applied to X_center / Y_center.

    df_g_roi : pd.DataFrame
        Green tracks after the ROI & FOI filter (step 2), before the
        already-activated, revived and overlap filters.

    df_r_roi : pd.DataFrame
        Red tracks after the ROI & FOI filter (step 2), before the
        already-activated, revived and overlap filters.
    """

    # 1. Compute first and last appearance frame of each track
    df_first_g = compute_first_and_last_appearance(df_g, settings["criLengthMin1"], method="mean")
    df_first_r = compute_first_and_last_appearance(df_r, settings["criLengthMin2"], method="median")

    print(f"After filtering, df_first_g: {len(df_first_g)} tracks, df_first_r: {len(df_first_r)} tracks")

    # 2. Apply ROI (spatial) and FOI (temporal) filters
    t_max = settings["frameCnt"] - settings["interval"]

    df_g_roi = apply_roi_foi_filter(df_first_g,
                                    settings["interval"], t_max,
                                    settings["roi"])
    # NOTE(review): the red channel uses a fixed lower frame bound of 2 instead
    # of settings["interval"]; presumably ported from MATLAB -- confirm intent.
    df_r_roi = apply_roi_foi_filter(df_first_r,
                                    2, t_max,
                                    settings["roi"])

    # 3. Remove tracks that appeared too early (based on activation delay)
    df_g_1, df_r_1 = apply_already_activated_filter(df_g_roi, df_r_roi, settings["interval"])

    print(f"After filtering, df_g_1: {len(df_g_1)} tracks, df_r_1: {len(df_r_1)} tracks")

    # 4. Apply revived spot filter to avoid duplicate detections within intra-channel radius
    df_g_rev = apply_revived_filter_matlab(df_g_1, settings["criIntra"])
    df_r_rev = apply_revived_filter_matlab(df_r_1, settings["criIntra"])

    print(f"After filtering, df_g_rev: {len(df_g_rev)} tracks, df_r_rev: {len(df_r_rev)} tracks")

    # 5. Remove overlapping tracks within each channel
    df_g_ol = apply_overlapping_filter_matlab(df_g_rev, settings["criOverlap"])
    df_r_ol = apply_overlapping_filter_matlab(df_r_rev, settings["criOverlap"])

    # 6. Chromatic offset correction. `assign` builds a new DataFrame, so the
    #    shift cannot leak into a frame shared with an earlier stage (e.g. when
    #    a filter returns its input unchanged). It also avoids pandas'
    #    SettingWithCopyWarning on filtered slices.
    dx, dy = settings["lateralOffset"][0], settings["lateralOffset"][1]
    df_r_ol = df_r_ol.assign(
        X_center=df_r_ol["X_center"] + dx,
        Y_center=df_r_ol["Y_center"] + dy,
    )

    print(f"After filtering, df_g_ol: {len(df_g_ol)} tracks, df_r_ol: {len(df_r_ol)} tracks")

    return df_g_ol, df_r_ol, df_g_roi, df_r_roi

def main_time_delay_fitting(pair_df: pd.DataFrame,
                            metadata: dict,
                            settings: dict):
    """
    Fit an exponential decay to the time-delay histogram of colocalized pairs.

    Port of the MATLAB "Time Delay Fitting" block:
    1. Recompute ``time_delay = round(ref_T) - round(cmp_T)`` (in frames).
    2. Histogram the delays with unit-width bins centred on the integers -35..35.
    3. Fit ``fit_time_delay_exp1decay`` on the bins centred at +1 .. +fit_len
       (``fit_len = settings.get("TimeDelayFitting", 15)``).
    4. Plot the histogram with the fit, on a normal and on an inverse (1/x) axis.

    Parameters
    ----------
    pair_df : pd.DataFrame
        Colocalized pairs. Must contain 'time_delay', 'ref_T' and 'cmp_T'.
        NOTE: the 'time_delay' column is overwritten IN PLACE with the
        recomputed integer delays.

    metadata : dict
        XML metadata; 'frameInterval' (default 1.0) is passed to the fit.

    settings : dict
        Analysis settings; 'TimeDelayFitting' (default 15) sets the number of
        bins used for the fit.

    Returns
    -------
    None
        Results are only plotted. If pair_df is empty or lacks a required
        column, a message is printed and nothing is plotted.
    """

    # Early exit if DataFrame is empty or malformed
    if pair_df.empty or "time_delay" not in pair_df.columns:
        print("pair_df is empty or missing 'time_delay' column.")
        return
    missing = [col for col in ("ref_T", "cmp_T") if col not in pair_df.columns]
    if missing:
        print(f"pair_df is missing column(s) needed to recompute time_delay: {missing}")
        return

    # Recalculate time delay from ref_T and cmp_T (rounded to frame integers).
    # Written back into the existing column so its position is unchanged.
    pair_df["time_delay"] = (
        pair_df["ref_T"].round().astype(int) -
        pair_df["cmp_T"].round().astype(int)
    )

    # Edges -35.5, -34.5, ..., +35.5 -> 71 unit-width bins centred on the
    # integers -35 .. +35. Delays outside this range are not counted.
    bin_edges = np.arange(-35.5, 36.5, 1)
    bin_centers = bin_edges[:-1] + 0.5
    hist_vals, _ = np.histogram(pair_df["time_delay"], bins=bin_edges)

    # Retrieve frame interval from metadata (default = 1.0)
    frameInterval = metadata.get("frameInterval", 1.0)

    # Index of the bin centred at 0
    idx0 = int(np.flatnonzero(np.isclose(bin_centers, 0))[0])

    # Fit the fit_len bins to the right of zero (centres +1 .. +fit_len)
    fit_len = settings.get("TimeDelayFitting", 15)
    start_idx = idx0 + 1
    end_idx = min(start_idx + fit_len, len(bin_centers))

    xData = bin_centers[start_idx:end_idx]
    yData = hist_vals[start_idx:end_idx]

    # Exponential decay fit; the model is (a - c) * exp(-x / b) + c according
    # to the original comment in this module.
    fit_result, tau, r2 = fit_time_delay_exp1decay(
        xData, yData, frameInterval, min_points=6
    )

    # Plot histogram and fitted curve (normal x-axis)
    plot_time_delay_exp1decay(
        bin_centers, hist_vals, xData, yData,
        fit_result, tau, r2, time_unit="sec"
    )

    # Plot with inverse x-axis (1/x) for visualizing early delays
    plot_time_delay_exp1decay_inverse_x(
        bin_centers, hist_vals, xData, yData,
        fit_result, tau, r2, time_unit="sec"
    )

def main_analysis_after_dark_pairs(imagefile, pair, metadata, settings):
    """
    Split colocalized pairs by RICM dark regions and fit each subset's time delays.

    The image stack is loaded with ``load_multistack``, and the pairs are
    classified with ``analyze_dark_pairs`` using its default visualization
    settings. ``main_time_delay_fitting`` then runs twice: first on the on-dark
    pairs, then on the remaining pairs. Each run is preceded by a printed
    heading.

    Parameters
    ----------
    imagefile : str or Path
        Path to the 4D TIFF image file.

    pair : pd.DataFrame
        Colocalized spot pairs in the positional layout expected by
        ``analyze_dark_pairs``.

    metadata : dict
        Track metadata (e.g. frame interval) forwarded to the fitting step.

    settings : dict
        Analysis configuration (RICM channel index, fitting parameters, ...).

    Returns
    -------
    multistack : np.ndarray
        Loaded image stack, shape (nx, ny, frames, channels).

    pairOnDark : pd.DataFrame
        Subset of ``pair`` classified as lying on dark regions.
    """
    multistack = load_multistack(imagefile)
    on_dark, not_on_dark, _mask_stack = analyze_dark_pairs(multistack, pair, settings)

    subsets = (
        ("pairs on dark", on_dark),
        ("pairs not on dark", not_on_dark),
    )
    for heading, subset in subsets:
        print(heading)
        main_time_delay_fitting(subset, metadata, settings)

    return multistack, on_dark

In [None]:
def making_mark_start(df_g, df_r, pairOnDark):
    """
    Attach the first frame of each paired green and red track to the pair table.

    For every row of ``pairOnDark``, the green track ``ref_id`` is looked up in
    ``df_g`` and the red track ``cmp_id`` in ``df_r``. The smallest ``T`` of
    each track is stored.

    Parameters
    ----------
    df_g : pd.DataFrame
        Flattened green-channel detections; needs 'track_id' and 'T'.

    df_r : pd.DataFrame
        Flattened red-channel detections; needs 'track_id' and 'T'.

    pairOnDark : pd.DataFrame
        Pair table with at least 'ref_id', 'ref_X', 'ref_Y' and 'cmp_id'.
        Any index is accepted; results are aligned by row position.

    Returns
    -------
    df_mark : pd.DataFrame
        Same index as ``pairOnDark``, with columns:
        - ref_id, ref_X, ref_Y copied from pairOnDark
        - first_T, last_T: legacy placeholders, always NaN
        - first_g_T: earliest T of the green track (NaN if the id is absent)
        - first_r_T: earliest T of the red track (NaN if the id is absent)
    """
    df_mark = pairOnDark[["ref_id", "ref_X", "ref_Y"]].copy()

    # Legacy fields kept for downstream compatibility; not filled here.
    df_mark["first_T"] = np.nan
    df_mark["last_T"] = np.nan

    # Earliest frame per track: one groupby per channel instead of a
    # full-table scan for every pair. Plain dict lookup lets integer and
    # float ids (e.g. 3 vs 3.0) match, just like an `==` comparison would.
    first_g = df_g.groupby("track_id")["T"].min().to_dict()
    first_r = df_r.groupby("track_id")["T"].min().to_dict()

    # Assign positionally (arrays, not label lookups): with a non 0..n-1
    # index, label-based `.at[i, ...]` would append new rows instead.
    df_mark["first_g_T"] = np.array(
        [first_g.get(track_id, np.nan) for track_id in pairOnDark["ref_id"]], dtype=float
    )
    df_mark["first_r_T"] = np.array(
        [first_r.get(track_id, np.nan) for track_id in pairOnDark["cmp_id"]], dtype=float
    )

    return df_mark

def marker_df_2(
    multistack,
    df_mark_1, df_mark_2,
    alpha=0.6,
    green_gain=1.0,
    red_gain=1.0,
    green_contrast=1.0,
    red_contrast=1.0,
    green_gamma=1.0,
    red_gamma=1.0,
    use_gamma=False,
    marker_radius=5,
    radius_px=13,
    max_workers=16,
    use_combined=True,
    frame_window=10,
):
    '''
    Cut marked image patches around each marker near its first green/red frame.

    For every frame:
    - the green (channel 1) and red (channel 2) images are min-max normalised
      to 0..255, then contrast-scaled and optionally gamma-corrected;
    - they are placed in the G and R planes of an RGB overlay and blended over
      the RICM (channel 0) background.

    For each marker whose ``first_g_T`` (or ``first_r_T``) lies within
    ``frame_window`` frames of the current frame, a ``radius_px`` x
    ``radius_px`` patch centred on (ref_X, ref_Y) is cut out, zero-padded
    where it leaves the image. A circle is drawn at the patch centre; its RGB
    colour encodes group membership: green = only in df_mark_2,
    blue = only in df_mark_1, yellow = in both.

    Parameters
    ----------
    multistack : np.ndarray
        4D image stack (nx, ny, n_frames, n_channels).
    df_mark_1, df_mark_2 : pd.DataFrame
        Marker tables with columns ref_id, ref_X, ref_Y, first_g_T, first_r_T.
    alpha : float
        Blending weight of the colour overlay versus the RICM background.
    green_gain, red_gain : float
        Multipliers applied before clipping the overlay planes to 0..255.
    green_contrast, red_contrast : float
        Multipliers applied right after normalisation.
    green_gamma, red_gamma : float
        Gamma exponents, used only when ``use_gamma`` is True.
    use_gamma : bool
        Whether to apply gamma correction.
    marker_radius : int
        Radius of the circle drawn at the patch centre.
    radius_px : int
        Patch width and height in pixels. The marker sits at pixel
        (radius_px // 2, radius_px // 2).
    max_workers : int
        Number of threads used to process frames.
    use_combined : bool
        If True, patches come from the blended image; otherwise from the
        green-only (g) or red-only (r) overlay plane.
    frame_window : int, default=10
        Half-width, in frames, of the window around first_g_T / first_r_T for
        which patches are produced.

    Returns
    -------
    patch_df : pd.DataFrame
        One row per unique marker row (index 0..n-1) with columns ref_id,
        first_g_T, first_r_T, ref_X, ref_Y, id_group ("common" / "only1" /
        "only2" / "unknown") and patch. Each ``patch`` entry is a DataFrame with
        columns frame, g_frame, g_image, r_frame, r_image, is_first_g_T,
        is_first_r_T, sorted by frame.
    '''
    nx, ny, frameCnt, _ = multistack.shape
    greenchannel = 1
    redchannel = 2
    RICMchannel = 0

    greenStack = multistack[:, :, :, greenchannel].astype(np.float32)
    redStack = multistack[:, :, :, redchannel].astype(np.float32)
    ricmStack = multistack[:, :, :, RICMchannel]

    # Merge both marker lists and drop duplicate rows. reset_index is required:
    # both inputs usually carry their own 0-based index, and the per-marker
    # results below are stored by position.
    combined_df = (
        pd.concat([df_mark_1, df_mark_2])
        .drop_duplicates()
        .reset_index(drop=True)
    )

    # Determine group membership based on marker presence
    ids1 = set(df_mark_1["ref_id"])
    ids2 = set(df_mark_2["ref_id"])
    common_ids = ids1 & ids2
    only1_ids = ids1 - ids2
    only2_ids = ids2 - ids1

    def get_group_type(ref_id):
        if ref_id in common_ids:
            return "common"
        elif ref_id in only1_ids:
            return "only1"
        elif ref_id in only2_ids:
            return "only2"
        else:
            return "unknown"

    patch_df = combined_df[["ref_id", "first_g_T", "first_r_T", "ref_X", "ref_Y"]].copy()
    patch_df["id_group"] = patch_df["ref_id"].apply(get_group_type)

    # (id set, RGB circle colour, channel) per marker group
    marker_groups = [
        (only2_ids, (0, 255, 0), "g"),
        (only2_ids, (0, 255, 0), "r"),
        (only1_ids, (0, 0, 255), "g"),
        (only1_ids, (0, 0, 255), "r"),
        (common_ids, (255, 255, 0), "g"),
        (common_ids, (255, 255, 0), "r"),
    ]

    raw_cols = ["g_frame", "g_image", "r_frame", "r_image"]
    out_cols = ["frame", "g_frame", "g_image", "r_frame", "r_image", "is_first_g_T", "is_first_r_T"]

    # Snapshot rows once as plain dicts: read-only and shared by all worker
    # threads, and cheaper than re-running iterrows() for every frame.
    records = patch_df.to_dict("records")
    half = radius_px // 2

    def render_frame(frame):
        """Return the uint8 RGB images patches can be cut from for one frame."""
        ricm = ricmStack[:, :, frame]

        green = cv2.normalize(greenStack[:, :, frame], None, 0, 255, cv2.NORM_MINMAX)
        green *= green_contrast
        if use_gamma:
            green = 255 * ((green / 255) ** green_gamma)

        red = cv2.normalize(redStack[:, :, frame], None, 0, 255, cv2.NORM_MINMAX)
        red *= red_contrast
        if use_gamma:
            red = 255 * ((red / 255) ** red_gamma)

        ricm_norm = cv2.normalize(ricm, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        ricm_rgb = cv2.cvtColor(ricm_norm, cv2.COLOR_GRAY2RGB)

        green_rgb = np.zeros_like(ricm_rgb, dtype=np.float32)
        green_rgb[:, :, 1] = np.clip(green_gain * green, 0, 255)
        red_rgb = np.zeros_like(ricm_rgb, dtype=np.float32)
        red_rgb[:, :, 0] = np.clip(red_gain * red, 0, 255)

        blended = np.clip((1 - alpha) * ricm_rgb + alpha * (green_rgb + red_rgb), 0, 255)
        layers = {"combined": blended.astype(np.uint8)}
        if not use_combined:
            layers["g"] = green_rgb.astype(np.uint8)
            layers["r"] = red_rgb.astype(np.uint8)
        return layers

    def crop_patch(img, x, y):
        """Cut a radius_px x radius_px window with (x, y) at pixel (half, half); pad outside with black."""
        x0, y0 = x - half, y - half  # top-left corner of the full window
        r1, r2 = max(x0, 0), min(x0 + radius_px, nx)
        c1, c2 = max(y0, 0), min(y0 + radius_px, ny)
        if r1 >= r2 or c1 >= c2:
            # Window lies completely outside the image
            return np.zeros((radius_px, radius_px, 3), dtype=np.uint8)
        patch = img[r1:r2, c1:c2]
        # Pad only on the clipped side so the spot stays at the patch centre.
        top, left = r1 - x0, c1 - y0
        bottom = radius_px - top - patch.shape[0]
        right = radius_px - left - patch.shape[1]
        return cv2.copyMakeBorder(patch, top, bottom, left, right,
                                  cv2.BORDER_CONSTANT, value=(0, 0, 0))

    def process_frame(frame):
        """Return, per marker position, a DataFrame of this frame's patches or None."""
        layers = render_frame(frame)
        result = [None] * len(records)
        for pos, row in enumerate(records):
            x = int(round(row["ref_X"]))
            y = int(round(row["ref_Y"]))

            new_rows = []
            for ids_set, color, channel in marker_groups:
                if row["ref_id"] not in ids_set:
                    continue
                first_t = row[f"first_{channel}_T"]
                if not (first_t - frame_window <= frame <= first_t + frame_window):
                    continue

                base_img = layers["combined"] if use_combined else layers[channel]
                marked_img = crop_patch(base_img, x, y)
                cv2.circle(marked_img, (half, half), marker_radius, color, thickness=1)

                new_row = dict.fromkeys(raw_cols)
                new_row[f"{channel}_frame"] = frame
                new_row[f"{channel}_image"] = marked_img
                new_rows.append(new_row)

            if new_rows:
                result[pos] = pd.DataFrame(new_rows, columns=raw_cols)
        return result

    # Process frames in parallel. executor.map yields results in frame order;
    # pieces are collected and concatenated once per marker, because repeated
    # concat inside the loop would be quadratic in the number of frames.
    pieces = [[] for _ in records]
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        for frame_result in tqdm(executor.map(process_frame, range(frameCnt)), total=frameCnt, desc="Processing frames"):
            for pos, table in enumerate(frame_result):
                if table is not None:
                    pieces[pos].append(table)

    def merge_patch_table(patch_table, row):
        """Align g/r patches on a shared 'frame' column and flag the first frames."""
        g_df = patch_table[["g_frame", "g_image"]].dropna(subset=["g_frame"]).copy()
        r_df = patch_table[["r_frame", "r_image"]].dropna(subset=["r_frame"]).copy()
        g_df["frame"] = g_df["g_frame"].astype(int)
        r_df["frame"] = r_df["r_frame"].astype(int)
        merged = pd.merge(g_df, r_df, on="frame", how="outer").sort_values("frame").reset_index(drop=True)
        for col in ["g_frame", "r_frame", "g_image", "r_image"]:
            if col in merged.columns:
                merged[col] = merged[col].where(pd.notna(merged[col]), None)
        merged["is_first_g_T"] = merged["frame"] == row["first_g_T"]
        merged["is_first_r_T"] = merged["frame"] == row["first_r_T"]
        return merged[out_cols]

    patch_df["patch"] = [
        merge_patch_table(pd.concat(parts, ignore_index=True), records[pos])
        if parts else pd.DataFrame(columns=out_cols)
        for pos, parts in enumerate(pieces)
    ]

    return patch_df

def run_full_analysis(criInter, criIntra, criOverlap, criBlink):
    """
    Run the complete colocalization and dark-region-based analysis pipeline.

    Steps:
    1. Load green/red track tables, the image file, metadata and analysis settings
       (``main_data_loading``).
    2. Filter the green/red tracks (ROI/FOI, revived, overlapping) (``run_main_filtering``).
    3. Find colocalized green-red track pairs using ``settings["criInter"]``.
    4. Fit an exponential decay to the time-delay histogram of those pairs.
    5. Load the image stack and keep only the pairs that lie in dark regions of the
       RICM channel (``main_analysis_after_dark_pairs``).
    6. Build the table of start coordinates / first-appearance frames for the kept pairs.

    Parameters
    ----------
    criInter : float
        Green-red distance threshold for colocalization (forwarded via ``settings["criInter"]``).
    criIntra : float
        Intra-channel distance threshold; presumably the radius of the revived-spot
        filter — confirm in ``run_main_filtering``.
    criOverlap : float
        Distance threshold for the overlapping-track filter.
    criBlink : float
        Blinking-duration threshold for the revived (blinking) filter; units presumably
        frames — confirm.

    All four values are passed to ``main_data_loading``, which is assumed to place them
    in the returned ``settings`` dict.

    Returns
    -------
    multistack : object
        Image stack returned by ``main_analysis_after_dark_pairs`` (presumably the full
        multichannel TIFF stack — confirm).
    df_mark : pd.DataFrame
        Output of ``making_mark_start``: ref_id, coordinates and start frame of the marked
        green/red colocalized tracks that lie in dark regions of the RICM channel.
    """
    # Step 1: Load track data (from XML), image path, metadata, and analysis settings
    df_g, df_r, imagefile, pathImgFile, filenamehead, metadata, settings = main_data_loading(
        criInter, criIntra, criOverlap, criBlink
    )

    # Step 2: Apply filtering on green/red track data (ROI/FOI + revived + overlap removal).
    # The first-appearance tables also returned here are not needed by this pipeline.
    df_g_ol, df_r_ol, _df_first_g, _df_first_r = run_main_filtering(df_g, df_r, settings)

    # Step 3: Detect primary green-red colocalized pairs based on proximity
    primary_pair_df = find_colocalized_pairs_matlab(df_g_ol, df_r_ol, settings["criInter"])

    # Step 4: Perform time delay histogram analysis and exponential decay fitting (tau, R^2)
    main_time_delay_fitting(primary_pair_df, metadata, settings)

    # Step 5: Load the full image stack and keep colocalized pairs that overlap dark RICM
    # regions. This is the step that loads the RICM image, so the progress messages
    # bracket this call.
    print("before loading RICM image")
    multistack, pairOnDark = main_analysis_after_dark_pairs(imagefile, primary_pair_df, metadata, settings)
    print("after loading RICM image")

    # Step 6: For each selected colocalized pair, extract the green/red track's first appearance frame.
    # NOTE(review): the raw (unfiltered) df_g/df_r are passed here; presumably making_mark_start
    # looks tracks up by id, so the unfiltered tables suffice — confirm.
    df_mark = making_mark_start(df_g, df_r, pairOnDark)

    return multistack, df_mark

def plot_patch_df_grid_grouped(patch_df, pad=10, max_workers=16):
    """
    Visualize side-by-side temporal patches of green and red channel images for each tracked pair.

    This function:
    - Sorts rows by ``id_group`` ("common", "only1", "only2") so related markers are grouped
    - Aligns patches temporally around the first appearance frame of each marker
    - Displays a row of green patches and a row of red patches per track
    - Draws a color-coded box around the center (first_T) patch of each channel
    - Renders rows in a thread pool, but returns the figures in the sorted row order

    Parameters
    ----------
    patch_df : pd.DataFrame
        DataFrame containing per-track patch data (output from `marker_df_2()`). Uses the
        columns ``patch`` (per-track table with frame / g_frame / g_image / r_frame / r_image /
        is_first_g_T / is_first_r_T), ``ref_id`` and ``id_group``.

    pad : int, default=10
        Number of frames to include before and after the center frame.

    max_workers : int, default=16
        Number of threads for parallel rendering.

    Returns
    -------
    figures : list of matplotlib.figure.Figure
        One figure per successfully rendered row, in the same order as the id_group-sorted
        ``patch_df``. Rows whose rendering raised an exception are reported on stdout and skipped.
    """
    def center_channel_table(patch_table, channel='g', pad=10):
        """
        Return one channel's patches on a fixed offset grid centered on its first-appearance frame.

        Parameters
        ----------
        patch_table : pd.DataFrame
            Table of frames and images for a given marker.
        channel : str
            Either 'g' or 'r' for green/red channel.
        pad : int
            How many frames before and after the center to show.

        Returns
        -------
        df_merged : pd.DataFrame or None
            ``2 * pad + 1`` rows (offsets -pad..pad, RangeIndex 0..2*pad) with columns
            ``offset``, ``frame``, ``image`` and ``highlight``; offsets without a patch hold NaN.
            ``None`` when the channel has no highlighted (first-appearance) frame.
        """
        frame_col = f"{channel}_frame"
        image_col = f"{channel}_image"
        highlight_col = f"is_first_{channel}_T"

        df = patch_table[[frame_col, image_col, highlight_col]].copy()
        df.columns = ["frame", "image", "highlight"]
        df = df[pd.notna(df["frame"])].sort_values("frame").reset_index(drop=True)

        # Identify the frame where the marker first appears. `.eq(True)` yields a clean
        # boolean mask even if the column is object-typed (e.g. an empty placeholder table).
        center_rows = df[df["highlight"].eq(True)]
        if center_rows.empty:
            return None
        center_frame = center_rows["frame"].iloc[0]

        # Add relative offset from center frame; keep one patch per offset so the merged
        # grid below has exactly one row per display column.
        df["offset"] = df["frame"] - center_frame
        df = df[df["offset"].between(-pad, pad)].drop_duplicates(subset="offset")

        # Create full offset range and merge to maintain fixed frame layout
        df_full = pd.DataFrame({"offset": np.arange(-pad, pad + 1)})
        return df_full.merge(df, on="offset", how="left")

    def render_single_row(row, pad, n_cols):
        """
        Render one marker as a 2-row image grid (green timeline on top, red below).

        Parameters
        ----------
        row : pd.Series
            A row from the patch_df.
        pad : int
            Padding around center frame.
        n_cols : int
            Total number of columns to display (2*pad + 1).

        Returns
        -------
        fig : matplotlib.figure.Figure
            A rendered figure for one marker.
        """
        patch_table = row["patch"]
        ref_id = row["ref_id"]

        g_table = center_channel_table(patch_table, 'g', pad=pad)
        r_table = center_channel_table(patch_table, 'r', pad=pad)

        fig, axes = plt.subplots(2, n_cols, figsize=(20, 3))
        # plt.subplots squeezes to a 1-D array when n_cols == 1; force a 2-D grid.
        axes = axes.reshape(2, n_cols)

        # Plot each frame's image for both channels
        for i in range(n_cols):
            for j, (table, color, ch) in enumerate([(g_table, 'green', 'g'), (r_table, 'red', 'r')]):
                ax = axes[j, i]
                if table is not None and isinstance(table.at[i, "image"], np.ndarray):
                    img = table.at[i, "image"]
                    ax.imshow(img)

                    frame_val = table.at[i, "frame"]
                    if pd.notna(frame_val):
                        ax.set_title(f"{ch}{int(frame_val)}", fontsize=6)

                    # Highlight the first appearance frame. imshow puts pixel centers on
                    # integer coordinates, so the image spans [-0.5, w - 0.5] x [-0.5, h - 0.5];
                    # anchoring the box at (0, 0) would push its right/bottom edges outside
                    # the axes, where they are clipped away.
                    highlight = table.at[i, "highlight"]
                    if pd.notna(highlight) and bool(highlight):
                        ax.add_patch(patches.Rectangle((-0.5, -0.5), img.shape[1], img.shape[0],
                                                       linewidth=2, edgecolor=color, facecolor='none'))
                ax.axis("off")

        # Add reference ID label at the top center
        mid_ax = axes[0, n_cols // 2]
        mid_ax.text(0.5, 1.1, f"ref_id: {ref_id}", ha='center', va='bottom',
                    transform=mid_ax.transAxes, fontsize=8, color='black')

        fig.subplots_adjust(left=0.01, right=0.99, top=0.85, bottom=0.05, wspace=0.05, hspace=0.05)
        return fig

    # Sort DataFrame by ID group for consistent visual grouping
    patch_df = patch_df.copy()
    patch_df['id_group'] = pd.Categorical(patch_df['id_group'], categories=["common", "only1", "only2"], ordered=True)
    patch_df = patch_df.sort_values("id_group").reset_index(drop=True)

    n_cols = pad * 2 + 1  # Number of columns in grid (before/after center frame)

    figures = []

    # Parallel rendering per marker row.
    # NOTE(review): pyplot is documented as not thread-safe; if rendering glitches appear,
    # lower max_workers to 1.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(render_single_row, row, pad, n_cols) for _, row in patch_df.iterrows()]
        # Collect in submission order (not completion order) so the output follows the
        # id_group sort above.
        for f in tqdm(futures, total=len(futures), desc="Rendering"):
            try:
                fig = f.result()
            except Exception as e:
                print(f"[Error] Failed to render figure: {e}")
                continue
            if fig is not None:
                figures.append(fig)

    return figures

def combine_figures_to_one_subplot(figures, dpi=100):
    """
    Combine a list of matplotlib figures into a single vertically stacked image.

    This function:
    - Rasterizes each individual matplotlib figure into an RGB array (Agg canvas)
    - Closes each source figure after rasterizing it
    - Stacks all images vertically on a white canvas (left-aligned)
    - Displays the result as one figure

    Parameters
    ----------
    figures : list of matplotlib.figure.Figure
        List of matplotlib figures to be combined. An empty list (or None) prints a
        notice and returns without plotting.

    dpi : int, default=100
        Resolution of the display figure; the figure is sized so that one canvas pixel
        maps to one display pixel at this dpi.

    Returns
    -------
    None
        The function directly displays the final combined image using matplotlib.
    """
    figures = list(figures) if figures is not None else []
    if not figures:
        # max() below would otherwise fail with an uninformative ValueError.
        print("[Warning] No figures to combine.")
        return None

    # Imported locally so the function does not depend on a module-level FigureCanvas name.
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    rendered_imgs = []

    for fig in figures:
        # Rasterize the figure with an Agg canvas.
        canvas = FigureCanvasAgg(fig)
        canvas.draw()

        # buffer_rgba() already has shape (height, width, 4); drop alpha and copy so the
        # image no longer references the canvas's render buffer.
        img = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
        rendered_imgs.append(img)

        plt.close(fig)  # Close the figure to free memory

    # Calculate total height and maximum width for the final canvas
    total_height = sum(img.shape[0] for img in rendered_imgs)
    max_width = max(img.shape[1] for img in rendered_imgs)

    # Create a blank white canvas
    final_canvas = np.full((total_height, max_width, 3), 255, dtype=np.uint8)

    # Copy each image into the final canvas (top-down)
    y = 0
    for img in rendered_imgs:
        h, w = img.shape[:2]
        final_canvas[y:y + h, :w] = img
        y += h

    # Display the final combined image; pass dpi so figsize (computed with the same dpi)
    # maps one canvas pixel to one display pixel.
    plt.figure(figsize=(max_width / dpi, total_height / dpi), dpi=dpi)
    plt.imshow(final_canvas)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

#### Example: running the full analysis pipeline end to end

In [None]:
# Both runs share criInter/criIntra/criBlink; only criOverlap differs (1.0 vs 2.5).
multistack_10, df_mark_10 = run_full_analysis(
    criInter=0.6, criIntra=0.6, criOverlap=1.0, criBlink=2.0
)
multistack_25, df_mark_25 = run_full_analysis(
    criInter=0.6, criIntra=0.6, criOverlap=2.5, criBlink=2.0
)

In [None]:
# Build per-track patch tables from the criOverlap=1.0 image stack, using the markings of
# both runs (presumably compared to form the "common"/"only1"/"only2" groups).
patch_df = marker_df_2(
    multistack_10,
    df_mark_10,
    df_mark_25,
    # Patch geometry and parallelism
    radius_px=13,
    marker_radius=5,
    max_workers=16,
    # Green-channel intensity adjustments
    green_gain=1.0,
    green_contrast=1.0,
    green_gamma=1.0,
    # Red-channel intensity adjustments
    red_gain=1.0,
    red_contrast=1.0,
    red_gamma=1.0,
    # Remaining rendering options
    use_gamma=False,
    alpha=1.0,
    use_combined=False,
)

In [None]:
# Replace the index with a clean 0..n-1 RangeIndex before filtering/plotting.
patch_df = patch_df.reset_index(drop=True)

In [None]:
# Keep only pairs whose green track first appears after its red partner (positive g - r delay).
patch_table = (
    patch_df.loc[patch_df["first_g_T"] - patch_df["first_r_T"] > 0]
    .copy()
    .reset_index(drop=True)
)

In [None]:
# Render one green/red timeline figure per selected pair, 10 frames on either side of first appearance.
figures = plot_patch_df_grid_grouped(patch_table, pad=10)

In [None]:
# Stack all per-pair figures vertically and display them as a single image.
combine_figures_to_one_subplot(figures)

#### The functions below are reserved for future improvements and are not currently used by the pipeline.

In [None]:
def apply_revived_filter(df_first: pd.DataFrame, cri_intra: float) -> pd.DataFrame:
    """
    Keep only the earliest-appearing spot within each ``cri_intra`` neighborhood.

    Spots are visited from earliest to latest ``T_first`` (ties keep the original row
    order). A visited spot that has not been suppressed yet is kept, and every spot within
    ``cri_intra`` of it (itself included) is marked as suppressed — a greedy KD-tree pass.

    Parameters
    ----------
    df_first : pd.DataFrame
        One row per spot with columns ``X_center``, ``Y_center`` and ``T_first``.
    cri_intra : float
        Suppression radius, in the same units as the coordinates.

    Returns
    -------
    pd.DataFrame
        The kept rows, in the order they were visited (by ``T_first``), with a fresh index.
    """
    points = df_first[["X_center", "Y_center"]].to_numpy()
    first_times = df_first["T_first"].to_numpy()

    # Stable sort: earliest first, ties resolved by original row position.
    visit_order = np.argsort(first_times, kind="stable")
    suppressed = np.zeros(len(df_first), dtype=bool)
    survivors = []

    kdtree = cKDTree(points)

    for idx in visit_order:
        if suppressed[idx]:
            continue
        survivors.append(idx)
        # The spot itself and all neighbors within the radius are consumed.
        suppressed[kdtree.query_ball_point(points[idx], r=cri_intra)] = True

    return df_first.iloc[survivors].reset_index(drop=True)

def apply_revived_filter_v2(df: pd.DataFrame, cri_intra: float, blinking_duration: float) -> pd.DataFrame:
    """
    Remove spots involved in short "blink" gaps with a nearby spot.

    Spots are visited from earliest to latest ``T_first`` (ties keep the original row
    order). For each still-kept spot ``a`` and each still-kept neighbor ``b`` within
    ``cri_intra``:

    - if ``a`` starts 0 < gap <= ``blinking_duration`` after ``b`` ends, ``b`` is dropped;
    - otherwise, if ``b`` starts 0 < gap <= ``blinking_duration`` after ``a`` ends, ``a``
      is dropped (the remaining neighbors of ``a`` are still examined in that pass).

    Parameters
    ----------
    df : pd.DataFrame
        One row per spot with columns ``X_center``, ``Y_center``, ``T_first`` and ``T_last``.
    cri_intra : float
        Neighbor radius, in the same units as the coordinates.
    blinking_duration : float
        Maximum gap between one spot's end and another's start that counts as a blink.

    Returns
    -------
    pd.DataFrame
        The surviving rows in their original order, with a fresh index.
    """
    xy = df[["X_center", "Y_center"]].to_numpy()
    t_start = df["T_first"].to_numpy()
    t_end = df["T_last"].to_numpy()
    alive = np.ones(len(df), dtype=bool)

    kdtree = cKDTree(xy)

    def is_blink(gap):
        return 0 < gap <= blinking_duration

    for a in np.argsort(t_start, kind="stable"):
        if not alive[a]:
            continue
        for b in kdtree.query_ball_point(xy[a], r=cri_intra):
            if b == a or not alive[b]:
                continue
            if is_blink(t_start[a] - t_end[b]):
                alive[b] = False
            elif is_blink(t_start[b] - t_end[a]):
                alive[a] = False

    return df[alive].reset_index(drop=True)

def apply_overlapping_filter(df: pd.DataFrame, cri_overlap: float) -> pd.DataFrame:
    """
    Keep one spot per group of spatially overlapping spots (greedy, KD-tree based).

    Spots are visited in order of ``T_first`` (earliest first, ties keep the original row
    order); when the ``T_first`` column is absent, the original row order is used. A
    visited spot that has not been suppressed yet is kept, and every spot within
    ``cri_overlap`` of it (itself included) is marked as suppressed.

    Parameters
    ----------
    df : pd.DataFrame
        One row per spot with columns ``X_center``, ``Y_center`` and optionally ``T_first``.
    cri_overlap : float
        Overlap radius, in the same units as the coordinates.

    Returns
    -------
    pd.DataFrame
        The kept rows, in visiting order, with a fresh index.
    """
    xy = df[["X_center", "Y_center"]].to_numpy()
    if "T_first" in df.columns:
        priority = df["T_first"].to_numpy()
    else:
        priority = np.zeros(len(df))

    # Stable sort so equal priorities keep their original row order.
    visit_order = np.argsort(priority, kind="stable")
    taken = np.zeros(len(df), dtype=bool)
    kept_rows = []

    kdtree = cKDTree(xy)

    for idx in visit_order:
        if taken[idx]:
            continue
        kept_rows.append(idx)
        # Consume this spot and every neighbor inside the overlap radius.
        taken[kdtree.query_ball_point(xy[idx], r=cri_overlap)] = True

    return df.iloc[kept_rows].reset_index(drop=True)

def apply_overlapping_filter_v2(df: pd.DataFrame, cri_overlap: float) -> pd.DataFrame:
    """
    Greedy KD-tree overlap filter that only suppresses neighbors overlapping in time.

    Spots are visited in order of ``T_first`` (earliest first, ties keep the original row
    order). Each visited spot that has not been suppressed yet is kept. A neighbor within
    ``cri_overlap`` is suppressed unless its ``[T_first, T_last]`` interval is strictly
    disjoint from the kept spot's (``T_last < kept T_first`` or ``T_first > kept T_last``).
    Touching endpoints and comparisons involving NaN count as overlapping. Spatial
    neighbors that exist at disjoint times survive and may be kept later.

    Missing ``T_first`` / ``T_last`` columns default to zeros.
    NOTE(review): when only ``T_last`` is missing, the zero end times make most intervals
    look disjoint, so few neighbors get suppressed — confirm this fallback is intended.

    Parameters
    ----------
    df : pd.DataFrame
        One row per spot with ``X_center``, ``Y_center`` and optionally ``T_first``/``T_last``.
    cri_overlap : float
        Overlap radius, in the same units as the coordinates.

    Returns
    -------
    pd.DataFrame
        The kept rows, in visiting order, with a fresh index.
    """
    n_spots = len(df)
    xy = df[["X_center", "Y_center"]].to_numpy()
    start = df["T_first"].to_numpy() if "T_first" in df.columns else np.zeros(n_spots)
    end = df["T_last"].to_numpy() if "T_last" in df.columns else np.zeros(n_spots)

    visit_order = np.argsort(start, kind="stable")
    consumed = np.zeros(n_spots, dtype=bool)
    kept_rows = []

    kdtree = cKDTree(xy)

    for idx in visit_order:
        if consumed[idx]:
            continue
        kept_rows.append(idx)
        consumed[idx] = True

        for nb in kdtree.query_ball_point(xy[idx], r=cri_overlap):
            if nb == idx:
                continue
            disjoint_in_time = end[nb] < start[idx] or start[nb] > end[idx]
            if not disjoint_in_time:
                consumed[nb] = True

    return df.iloc[kept_rows].reset_index(drop=True)

def find_colocalized_pairs(df1, df2, criInter):
    """
    Pair each df1 spot with its nearest df2 spot within ``criInter`` (greedy, one-to-one).

    For every df1 spot, only its single nearest df2 spot inside the distance bound is
    considered. Candidates are then accepted in order of increasing distance (ties keep df1
    row order); a df2 spot already claimed by a closer df1 spot is not reassigned, so that
    df1 spot stays unpaired.

    Parameters
    ----------
    df1, df2 : pd.DataFrame
        Spot tables with columns ``track_id``, ``X_center``, ``Y_center`` and ``T_first``.
        df1 provides the ``ref_*`` side of each pair, df2 the ``cmp_*`` side.
    criInter : float
        Distance bound passed to ``cKDTree.query(distance_upper_bound=...)``.

    Returns
    -------
    pd.DataFrame
        Columns ``time_delay`` (round(ref_T) - round(cmp_T)), ``distance``, ``ref_id``,
        ``ref_T``, ``ref_X``, ``ref_Y``, ``cmp_id``, ``cmp_T``, ``cmp_X``, ``cmp_Y``.
        Empty (with these columns) when either input is empty or nothing is in range.
    """
    columns = [
        "time_delay", "distance",
        "ref_id", "ref_T", "ref_X", "ref_Y",
        "cmp_id", "cmp_T", "cmp_X", "cmp_Y",
    ]
    # cKDTree.query needs at least one df2 point; nothing can pair when either side is empty.
    if len(df1) == 0 or len(df2) == 0:
        return pd.DataFrame([], columns=columns)

    coords1 = df1[["X_center", "Y_center"]].to_numpy()
    coords2 = df2[["X_center", "Y_center"]].to_numpy()
    times1 = df1["T_first"].to_numpy()
    times2 = df2["T_first"].to_numpy()
    # Column-wise extraction keeps track_id's own dtype (a per-row iloc would upcast it to
    # float whenever the other columns are float).
    ids1 = df1["track_id"].to_numpy()
    ids2 = df2["track_id"].to_numpy()
    n2 = len(coords2)

    # One vectorized nearest-neighbor query. With k=1 the results are 1-D (one entry per
    # df1 spot); spots with no df2 neighbor inside the bound get dist=inf and idx=n2.
    nearest_dist, nearest_idx = cKDTree(coords2).query(
        coords1, k=1, distance_upper_bound=criInter
    )
    pair_candidates = [
        (i, int(j), dist)
        for i, (dist, j) in enumerate(zip(nearest_dist, nearest_idx))
        if j != n2 and np.isfinite(dist)
    ]

    # Greedy de-duplication of df2: closest candidates claim their df2 spot first.
    used_df2 = set()
    pairs = []

    for i, j, dist in sorted(pair_candidates, key=lambda cand: cand[2]):
        if j in used_df2:
            continue
        used_df2.add(j)

        t1 = times1[i]
        t2 = times2[j]
        timedelay = int(round(t1)) - int(round(t2))

        pairs.append([
            timedelay,
            dist,
            ids1[i], t1, coords1[i, 0], coords1[i, 1],
            ids2[j], t2, coords2[j, 0], coords2[j, 1],
        ])

    return pd.DataFrame(pairs, columns=columns)