# Detections

In [3]:
import cv2
import numpy as np
import pandas as pd
import os
import time

def get_config():
    """
    Define input/output paths and parameters.
    """
    input_video_path = '/Users/Ricardo/Desktop/Y4 Lab code/Cooldown 10/_video (0).mp4'
    output_folder = '/Users/Ricardo/Desktop/Y4 Lab code/Work 0.3K'
    # Vertical crop limits for the ROI
    y_lower_limit = 100
    y_upper_limit = 819
    # Process frames from 0 to this index (inclusive)
    end_frame = 300
    # Detection parameters for small/big particles
    params = {
        'min_small_intensity': 18,
        'min_big_intensity': 12,
        'background_removal_small': True,
        'background_removal_big': True,
        'small_neighborhood_and_integration_size': (5, 5),
        'big_neighborhood_and_integration_size': (9, 9),
        'small_gaussian_blur_kernel_size': (13, 13),
        'big_gaussian_blur_kernel_size': (13, 13),
        'small_position_refinement_size': (9, 9),
        'big_position_refinement_size': (9, 9),
        'small_size': (5, 5),
        'big_size': (9, 9),
        'raw_luminosity_grid_size': (9, 9)
    }
    return input_video_path, output_folder, y_lower_limit, y_upper_limit, end_frame, params


def find_local_minima(Z, neighborhood_size=(3,3)):
    """
    Find local minima in a 2D array (image).
    Uses morphological erosion to find pixels that are lower than or equal
    to their neighbours.
    """
    nh, nw = neighborhood_size
    kernel = np.ones((nh, nw), dtype=np.uint8)
    # Pixels equal to the eroded image are local minima
    local_min = (cv2.erode(Z, kernel) == Z)
    # Remove border artefacts where the neighbourhood would spill outside
    local_min[:nh//2, :] = False
    local_min[-(nh//2):, :] = False
    local_min[:, :nw//2] = False
    local_min[:, -(nw//2):] = False
    return local_min


def fit_quadratic_surface(Z_blurred, neighborhood_size=(3,3)):
    """
    Fit a quadratic surface to local minima points of the image.
    Used for background estimation/removal.
    """
    # Construct coordinate grids
    X, Y = np.meshgrid(np.arange(Z_blurred.shape[1]), np.arange(Z_blurred.shape[0]))
    # Use minima as proxy for background locations
    minima_mask = find_local_minima(Z_blurred, neighborhood_size)
    X_min, Y_min, Z_min = X[minima_mask], Y[minima_mask], Z_blurred[minima_mask]

    # Require enough points for least squares (6 params)
    if X_min.size < 6:
        return None

    # Quadratic surface model: ax^2 + by^2 + cxy + dx + ey + f
    A = np.c_[X_min**2, Y_min**2, X_min*Y_min, X_min, Y_min, np.ones_like(X_min)]
    try:
        coefficients, _, _, _ = np.linalg.lstsq(A, Z_min, rcond=None)
    except np.linalg.LinAlgError:
        return None
    return coefficients


def remove_background(Z, coefficients):
    """
    Remove background using fitted quadratic surface.
    Ensures negative pixel values are set to zero.
    """
    a, b, c, d, e, f = coefficients
    height, width = Z.shape
    X, Y = np.meshgrid(np.arange(width), np.arange(height))

    # Compute fitted background surface for the whole frame
    Z_fit = a * X**2 + b * Y**2 + c * X * Y + d * X + e * Y + f
    # Subtract background and clamp to [0, 255] via uint8 cast later
    Z_corrected = Z.astype(np.float32) - Z_fit
    Z_corrected[Z_corrected < 0] = 0
    return Z_corrected.astype(np.uint8)


def refine_centroid(frame, x, y, window_size=(5,5)):
    """
    Refine particle centroid using image moments inside a local window.
    Returns sub-pixel accurate coordinates.
    """
    # Define local window bounds around the initial (x, y)
    half_window_x = window_size[0] // 2
    half_window_y = window_size[1] // 2
    y_min = max(int(y) - half_window_y, 0)
    y_max = min(int(y) + half_window_y + 1, frame.shape[0])
    x_min = max(int(x) - half_window_x, 0)
    x_max = min(int(x) + half_window_x + 1, frame.shape[1])

    window = frame[y_min:y_max, x_min:x_max]
    if window.size == 0:
        return x, y

    # Calculate centroid using raw image moments
    m = cv2.moments(window.astype(np.uint8))
    if m['m00'] == 0:
        return x, y
    refined_x = x_min + (m['m10'] / m['m00'])
    refined_y = y_min + (m['m01'] / m['m00'])
    return refined_x, refined_y


def local_maxima(img, min_distance):
    """
    Find local maxima in an image, separated by at least min_distance.
    Uses dilation + connected components.
    """
    # Square structuring element sized to enforce min_distance separation
    kernel = np.ones((2 * min_distance + 1, 2 * min_distance + 1), np.uint8)
    dilated = cv2.dilate(img, kernel)
    # Pixels equal to local dilation are local peaks
    local_max_mask = (img == dilated)

    # Suppress peaks too close to borders (cannot form a full neighbourhood)
    local_max_mask[:min_distance, :] = False
    local_max_mask[-min_distance:, :] = False
    local_max_mask[:, :min_distance] = False
    local_max_mask[:, -min_distance:] = False

    # Label connected maxima regions and pick the brightest pixel per region
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(local_max_mask.astype(np.uint8))
    coordinates = []

    for label in range(1, num_labels):  # skip background
        x, y, w, h, _ = stats[label]
        sub_img = img[y:y + h, x:x + w]
        _, _, _, maxLoc = cv2.minMaxLoc(sub_img)
        global_y = y + maxLoc[1]
        global_x = x + maxLoc[0]
        coordinates.append((global_y, global_x))

    if coordinates:
        return np.array(coordinates)
    else:
        return np.empty((0, 2), dtype=int)


def find_particles(frame, original_roi, min_distance, min_intensity,
                   neighborhood_size, gaussian_blur_kernel_size,
                   position_refinement_size, mass_grid_size):
    """
    Detect particles in a frame:
    - Blur image
    - Find local maxima
    - Filter by intensity threshold
    - Refine positions
    - Compute raw luminosity using integral image
    """
    # Approximate sigma so the given kernel spans ~±3σ
    sigma_x = gaussian_blur_kernel_size[0] / 6.0
    sigma_y = gaussian_blur_kernel_size[1] / 6.0

    # Blur to reduce noise and small-scale texture
    blurred = cv2.GaussianBlur(frame, gaussian_blur_kernel_size, sigmaX=sigma_x, sigmaY=sigma_y)
    # Local mean for intensity thresholding
    averaged = cv2.boxFilter(frame, ddepth=-1, ksize=neighborhood_size, normalize=True)

    # Detect local maxima candidates
    coordinates = local_maxima(blurred, min_distance)
    if coordinates.size == 0:
        return pd.DataFrame(columns=['x', 'y', 'raw_luminosity', 'frame', 'tracker_type'])

    # Filter by local mean intensity threshold
    intensities = averaged[coordinates[:, 0], coordinates[:, 1]]
    valid_coords = coordinates[intensities >= min_intensity]
    if valid_coords.size == 0:
        return pd.DataFrame(columns=['x', 'y', 'raw_luminosity', 'frame', 'tracker_type'])

    # Refine positions (sub-pixel) around each valid coordinate
    refined_positions = np.array([refine_centroid(blurred, x, y, window_size=position_refinement_size) for y, x in valid_coords])
    x_positions = refined_positions[:, 0]
    y_positions = refined_positions[:, 1]

    # Compute luminosity using a summed-area table (integral image)
    integral = cv2.integral(original_roi)
    raw_luminosities = []
    half_grid = mass_grid_size[0] // 2
    for x, y in refined_positions:
        x_int = int(round(x))
        y_int = int(round(y))
        # Window bounds inside the ROI for average intensity
        y_min = max(y_int - half_grid, 0)
        y_max = min(y_int + half_grid + 1, original_roi.shape[0])
        x_min = max(x_int - half_grid, 0)
        x_max = min(x_int + half_grid + 1, original_roi.shape[1])
        # Fast sum via integral image
        sum_window = integral[y_max, x_max] - integral[y_min, x_max] - integral[y_max, x_min] + integral[y_min, x_min]
        area = (y_max - y_min) * (x_max - x_min)
        raw_luminosities.append(sum_window / area if area > 0 else np.nan)

    return pd.DataFrame({'x': x_positions, 'y': y_positions, 'raw_luminosity': raw_luminosities, 'frame': np.nan})


def temporal_gaussian_blur(frames, kernel_size):
    """
    Apply Gaussian blur across the temporal (time) dimension of a stack of frames.
    """
    # Build a 1D Gaussian kernel over the time axis
    sigma = kernel_size / 6
    k = np.arange(kernel_size) - kernel_size // 2
    kernel = np.exp(-0.5 * (k / sigma) ** 2)
    kernel /= kernel.sum()

    # Pad in time to avoid edge effects (replicate ends)
    padded = np.pad(frames, ((kernel_size // 2, kernel_size // 2), (0, 0), (0, 0)), mode='edge')
    blurred = np.empty_like(frames, dtype=np.float32)

    # Convolve per-pixel along the temporal dimension
    for i in range(frames.shape[0]):
        blurred[i] = np.tensordot(kernel, padded[i:i + kernel_size], axes=(0, 0))
    return blurred.astype(np.uint8)


def read_video_frames(input_video_path, y_lower_limit, y_upper_limit, end_frame):
    """
    Read frames from video up to end_frame.
    Extract region of interest (ROI) between y_lower_limit and y_upper_limit.
    """
    cap = cv2.VideoCapture(input_video_path)
    frames = []
    frame_number = 0

    while frame_number <= end_frame:
        ret, frame = cap.read()
        if not ret:
            break
        # Convert to grayscale to simplify processing
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        # Crop to the vertical ROI; keep all columns
        roi = gray[y_lower_limit:y_upper_limit, :]
        frames.append(roi)
        frame_number += 1

    cap.release()
    return np.array(frames)


def process_frames(frames, y_lower_limit, end_frame,
                   min_small_intensity, min_big_intensity,
                   background_removal_small, background_removal_big,
                   small_neighborhood_and_integration_size, big_neighborhood_and_integration_size,
                   small_gaussian_blur_kernel_size, big_gaussian_blur_kernel_size,
                   small_position_refinement_size, big_position_refinement_size,
                   small_size, big_size,
                   raw_luminosity_grid_size,
                   temporal_gaussian_blur_kernel_size=None):
    """
    Full pipeline for detecting small and big particles in each frame:
    - Optionally apply temporal blur
    - Optionally remove background
    - Detect particles with different parameters for small/big
    - Collect detections in DataFrame
    """
    # Optional temporal smoothing to stabilise detections across frames
    proc = temporal_gaussian_blur(frames, temporal_gaussian_blur_kernel_size) if temporal_gaussian_blur_kernel_size else frames
    all_small, all_big = [], []

    for i in range(frames.shape[0]):
        roi = proc[i]
        # When temporal blur is used, measure luminosity on original frame to avoid bias
        roi_orig = frames[i] if temporal_gaussian_blur_kernel_size else roi

        # Fit background once per frame (on a blurred version) if any removal is needed
        coeff = fit_quadratic_surface(cv2.GaussianBlur(roi, small_gaussian_blur_kernel_size, 0),
                                      small_neighborhood_and_integration_size) if (background_removal_small or background_removal_big) else None

        # Apply background removal independently for small/big streams
        roi_small = remove_background(roi, coeff) if background_removal_small and coeff is not None else roi.copy()
        roi_big = remove_background(roi, coeff) if background_removal_big and coeff is not None else roi.copy()

        # Minimum distance between peaks based on expected particle size
        min_d_small = max(small_size) // 2
        min_d_big = max(big_size) // 2

        # Detect small particles
        det_small = find_particles(roi_small, roi_orig, min_d_small, min_small_intensity,
                                   small_neighborhood_and_integration_size, small_gaussian_blur_kernel_size,
                                   small_position_refinement_size, raw_luminosity_grid_size)
        if not det_small.empty:
            # Shift back to full-frame coordinates on y-axis
            det_small['y'] += y_lower_limit
            det_small['frame'] = i
            det_small['tracker_type'] = 'small'
            all_small.append(det_small)

        # Detect big particles
        det_big = find_particles(roi_big, roi_orig, min_d_big, min_big_intensity,
                                 big_neighborhood_and_integration_size, big_gaussian_blur_kernel_size,
                                 big_position_refinement_size, raw_luminosity_grid_size)
        if not det_big.empty:
            det_big['y'] += y_lower_limit
            det_big['frame'] = i
            det_big['tracker_type'] = 'big'
            all_big.append(det_big)

    # Combine results; ensure columns exist if no detections
    df_small = pd.concat(all_small, ignore_index=True) if all_small else pd.DataFrame(columns=['x', 'y', 'raw_luminosity', 'frame', 'tracker_type'])
    df_big = pd.concat(all_big, ignore_index=True) if all_big else pd.DataFrame(columns=['x', 'y', 'raw_luminosity', 'frame', 'tracker_type'])
    return pd.concat([df_small, df_big], ignore_index=True)


def main():
    """
    Main function:
    - Read video frames
    - Process frames to detect particles
    - Save results to CSV
    """
    input_video_path, output_folder, y_lower_limit, y_upper_limit, end_frame, params = get_config()

    # Ensure output folder exists
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Read video and extract ROI frames
    frames = read_video_frames(input_video_path, y_lower_limit, y_upper_limit, end_frame)

    # Run detection pipeline
    detections = process_frames(frames, y_lower_limit, end_frame,
                                params['min_small_intensity'], params['min_big_intensity'],
                                params['background_removal_small'], params['background_removal_big'],
                                params['small_neighborhood_and_integration_size'], params['big_neighborhood_and_integration_size'],
                                params['small_gaussian_blur_kernel_size'], params['big_gaussian_blur_kernel_size'],
                                params['small_position_refinement_size'], params['big_position_refinement_size'],
                                params['small_size'], params['big_size'],
                                params['raw_luminosity_grid_size'])

    # Save results for downstream linking
    detections['video'] = os.path.basename(input_video_path)
    detections.to_csv(os.path.join(output_folder, 'particle_detections.csv'), index=False)
    print(len(detections), "rows exported for particle_detections.csv")


if __name__ == "__main__":
    main()


15905 rows exported for particle_detections.csv


# Track Linking

Version 1

In [4]:
import os
import pandas as pd
import numpy as np
from numba import njit
from scipy.spatial import cKDTree
from collections import defaultdict

def get_config():
    """
    Define I/O and all tracking hyperparameters for forward and reverse linking.
    """
    base = "/Users/Ricardo/Desktop/Y4 Lab code/Work 0.3K"

    # Forward-pass constraints (when linking in ascending frame order)
    fwd_params = {
        "memory": 0,  # how many consecutive frames a track can be unmatched before deletion
        # Speed thresholds (pixels/frame) that define low/medium/high regimes
        "slow_max_speed": 6, "medium_max_speed": 12, "high_max_speed": 65,
        # Allowed change in direction (radians) for each regime
        "low_direction_change_limits": {"lower": -np.pi, "upper": np.pi},
        "medium_direction_change_limits": {"lower": -0.4 * np.pi, "upper": 0.4 * np.pi},
        "high_direction_change_limits": {"lower": -0.15 * np.pi, "upper": 0.15 * np.pi},
        # Allowed change in speed (Δ pixels/frame) for each regime
        "low_speed_change_limits": {"lower": -30, "upper": 50},
        "medium_speed_change_limits": {"lower": -40, "upper": 55},
        "high_speed_change_limits": {"lower": -50, "upper": 60},
        # Distance threshold (pixels) to look back in a track’s history for a “previous” point
        "directional_threshold": 12
    }

    # Reverse-pass constraints (when linking in descending frame order)
    # Direction limits are symmetric; speed-change limits are mirrored
    rev_params = {
        "memory": 0,
        "slow_max_speed": 6, "medium_max_speed": 12, "high_max_speed": 65,
        "low_direction_change_limits": {"lower": -np.pi, "upper": np.pi},
        "medium_direction_change_limits": {"lower": -0.4 * np.pi, "upper": 0.4 * np.pi},
        "high_direction_change_limits": {"lower": -0.15 * np.pi, "upper": 0.15 * np.pi},
        "low_speed_change_limits": {"lower": -50, "upper": 30},
        "medium_speed_change_limits": {"lower": -55, "upper": 40},
        "high_speed_change_limits": {"lower": -60, "upper": 50},
        "directional_threshold": 12
    }

    end_frame = 600   # not used directly here, but returned for potential future use
    min_links = 10    # minimum detections per track to keep after filtering
    return (None, None, None, fwd_params, rev_params, end_frame, min_links)


class MemoryManager:
    """
    Tracks currently-active particle tracks and basic per-track stats.
    """
    def __init__(self, memory):
        self.memory = memory               # max tolerated consecutive missed frames
        self.active = {}                   # active tracks: id -> state dict
        self.length = {}                   # total assigned detections per track id

    def add(self, ids, positions):
        """
        Start new tracks with given ids at given positions.
        """
        for p in ids:
            self.active[p] = {
                "position": positions[p],      # last known (x, y)
                "history": [positions[p]],     # full position history (list of (x, y))
                "frames_lost": 0,              # # of consecutive misses
                "previous_speed": 0,           # last observed speed (pixels/frame)
                "speed_mode": "low"            # last regime: 'low'|'medium'|'high'
            }
            self.length[p] = 1

    def update(self, ids, new_positions):
        """
        Update existing tracks (append to history and bump length counter).
        """
        for p in ids:
            if p in self.active:
                self.active[p]["history"].append(new_positions[p])
                self.active[p]["position"] = new_positions[p]
                self.length[p] += 1


class Statistics:
    """
    Collect diagnostics on rejected candidate links (angles and speed changes)
    and total distance travelled per regime.
    """
    def __init__(self):
        self.dir_stats = {"low": [], "medium": [], "high": []}     # rejected direction-change values
        self.speed_stats = {"low": [], "medium": [], "high": []}   # rejected speed-change values
        self.distances = {"low": 0, "medium": 0, "high": 0}        # accepted path length per regime

    def update_dir(self, mode, angle, params):
        self.dir_stats[mode].append(angle)

    def update_speed(self, mode, speed_change, params):
        self.speed_stats[mode].append(speed_change)

    def add_distance(self, mode, distance):
        self.distances[mode] += distance


def calc_angle_jit(prev, cur, nw):
    """
    Compute signed turn angle between vectors (prev->cur) and (cur->nw).
    Returns angle in radians in [-pi, pi]. If any segment has zero length, returns 0.
    """
    m1 = cur - prev
    m2 = nw - cur
    n1 = np.sqrt(m1[0] * m1[0] + m1[1] * m1[1])
    n2 = np.sqrt(m2[0] * m2[0] + m2[1] * m2[1])
    if n1 == 0 or n2 == 0:
        return 0.0
    dp = (m1[0] * m2[0] + m1[1] * m2[1]) / (n1 * n2)
    dp = min(max(dp, -1.0), 1.0)  # numerical clamp for arccos domain
    ang = np.arccos(dp)
    # Determine sign using 2D cross product (z-component)
    if m1[0] * m2[1] - m1[1] * m2[0] < 0:
        ang = -ang
    return ang


def calc_angle(prev, cur, nw):
    """
    Wrapper that handles None 'prev' point and ensures float inputs.
    """
    if prev is None:
        return 0.0
    return calc_angle_jit(
        np.asarray(prev, dtype=np.float64),
        np.asarray(cur, dtype=np.float64),
        np.asarray(nw, dtype=np.float64)
    )


def find_prev(hist, cur, thresh):
    """
    Walk history backwards to find the latest point far enough from 'cur'
    (distance > thresh). Helps stabilise direction-change estimation.
    """
    for prev in reversed(hist[:-1]):
        if np.linalg.norm(cur - prev) > thresh:
            return prev
    return None


def custom_link(dets, params, stats, asc=True):
    """
    Greedy frame-by-frame linker with physics-inspired constraints.
    - dets: DataFrame with columns ['frame','x','y']
    - params: hyperparameters dict (see get_config)
    - stats: Statistics object to collect diagnostics
    - asc: True for forward pass (increasing frame); False for reverse
    Returns: (linked DataFrame, count_rank1_matches, count_rank2_matches)
    """
    mem = params["memory"]
    max_dist = params["high_max_speed"]     # KD-tree search radius
    thresh = params["directional_threshold"]
    pid = 0                                 # next new track id

    # Sort detections by frame in chosen direction
    dets = dets.sort_values("frame", ascending=asc).reset_index(drop=True)
    dets["particle"] = np.nan

    mm = MemoryManager(mem)
    l1 = l2 = 0  # number of first-choice and second-choice links accepted

    for f, fd in dets.groupby("frame"):
        pos = fd[["x", "y"]].values
        idxs = fd.index.values

        # If no active tracks, start new ones for all detections
        if not mm.active:
            new_ids = list(range(pid, pid + len(fd)))
            new_pos = {p: pos[i] for i, p in enumerate(new_ids)}
            mm.add(new_ids, new_pos)
            dets.loc[idxs, "particle"] = new_ids
            pid += len(fd)
            continue

        # Rank active ids by their current length (prefer longer, more stable tracks)
        act_ids = sorted(mm.active.keys(), key=lambda p: mm.length[p], reverse=True)
        active = mm.active

        # If still empty (paranoia), spawn all as new
        if not act_ids:
            new_ids = list(range(pid, pid + len(fd)))
            new_pos = {p: pos[i] for i, p in enumerate(new_ids)}
            mm.add(new_ids, new_pos)
            dets.loc[idxs, "particle"] = new_ids
            pid += len(fd)
            continue

        # Build KD-tree over last positions of active tracks to find candidate links
        act_pos = np.array([active[p]["position"] for p in act_ids])
        tree = cKDTree(act_pos)
        # Query up to 2 nearest active tracks for each detection (within max_dist)
        dists, a_idxs = tree.query(pos, k=2, distance_upper_bound=max_dist)

        # Prepare all possible (track, detection) pairings with angle & distance
        poss_list = []
        for i in range(len(pos)):
            for k in range(2):  # consider 1st and 2nd nearest
                if a_idxs[i, k] >= len(act_ids) or dists[i, k] == np.inf:
                    continue
                aid = act_ids[a_idxs[i, k]]
                part = active[aid]
                cur_pos = part["position"]
                prev_pt = find_prev(part["history"], cur_pos, thresh)
                angle = calc_angle(prev_pt, cur_pos, pos[i])
                poss_list.append((aid, i, idxs[i], angle, dists[i, k], k+1))  # rank=k+1

        # Sort candidates:
        #   1) longer tracks first
        #   2) smaller |angle| (straighter continuation)
        #   3) smaller spatial distance
        poss_list.sort(key=lambda x: (-mm.length[x[0]], abs(x[3]), x[4]))

        used_a = set()  # already used active track ids
        used_d = set()  # already used detection indices
        fm = []         # accepted matches: (aid, det_idx, rank)

        for aid, di, dii, angle, dist_val, rank in poss_list:
            if aid in used_a or di in used_d:
                continue

            part = active.get(aid)
            if part is None:
                continue

            # Compute angle and speed using most recent two points
            prev_pt = part["history"][-2] if len(part["history"]) >= 2 else None
            angle = calc_angle(prev_pt, part["position"], pos[di])
            disp = pos[di] - part["position"]
            spd = np.linalg.norm(disp)
            spd_ch = spd - part["previous_speed"]

            # Determine current speed regime of candidate
            cm = "low" if spd <= params["slow_max_speed"] else ("medium" if spd <= params["medium_max_speed"] else "high")
            pm = part["speed_mode"]  # previous regime

            # Compose cumulative limits across regime transition path
            modes = ["low", "medium", "high"]
            pi, ci = modes.index(pm), modes.index(cm)
            seq = modes[pi:ci+1] if pi <= ci else modes[ci:pi+1]

            cs_lim = {"lower": -np.inf, "upper": np.inf}  # cumulative speed-change limits
            cd_lim = {"lower": -np.inf, "upper": np.inf}  # cumulative direction-change limits
            for m in seq:
                sp_lim = params[f"{m}_speed_change_limits"]
                d_lim = params[f"{m}_direction_change_limits"]
                cs_lim["lower"] = max(cs_lim["lower"], sp_lim["lower"])
                cs_lim["upper"] = min(cs_lim["upper"], sp_lim["upper"])
                cd_lim["lower"] = max(cd_lim["lower"], d_lim["lower"])
                cd_lim["upper"] = min(cd_lim["upper"], d_lim["upper"])

            # Apply constraints; if violated, keep diagnostics and possibly drop track
            rej = False
            if not (cd_lim["lower"] <= angle <= cd_lim["upper"]):
                stats.update_dir(cm, angle, params)
                rej = True
            if not (cs_lim["lower"] <= spd_ch <= cs_lim["upper"]):
                stats.update_speed(cm, spd_ch, params)
                rej = True
            if rej:
                part["frames_lost"] += 1
                if part["frames_lost"] > params["memory"]:
                    # Forget track if memory exceeded
                    del active[aid]
                    del mm.length[aid]
                continue

            # Accept match: assign detection to track
            dets.at[dii, "particle"] = aid
            fm.append((aid, di, rank))
            used_a.add(aid)
            used_d.add(di)
            l1 += (rank == 1)
            l2 += (rank == 2)

        # Commit accepted matches: update positions & per-track speed regime
        new_pos = {}
        mids = []
        for aid, di, _ in fm:
            new_pos[aid] = pos[di]
            spd = np.linalg.norm(pos[di] - active[aid]["position"])
            sm_val = "low" if spd <= params["slow_max_speed"] else ("medium" if spd <= params["medium_max_speed"] else "high")
            active[aid]["speed_mode"] = sm_val
            active[aid]["previous_speed"] = spd
            stats.add_distance(sm_val, spd)
            mids.append(aid)
        mm.update(mids, new_pos)

        # Any detections not matched start new tracks
        un = set(range(len(fd))) - used_d
        if un:
            new_ids = list(range(pid, pid + len(un)))
            new_pos = {p: pos[i] for p, i in zip(new_ids, un)}
            mm.add(new_ids, new_pos)
            for p, i in zip(new_ids, un):
                dets.at[idxs[i], "particle"] = p
            pid += len(un)

    # Ensure nullable integer dtype for particle ids
    dets["particle"] = dets["particle"].astype("Int64")

    # Return frame order consistent with pass direction
    return (dets.sort_values("frame").reset_index(drop=True) if not asc else dets, l1, l2)


def process_tracker_type(args):
    """
    Run forward and reverse linking for a given tracker_type subset.
    Adds metadata columns and returns both DataFrames.
    """
    t, sub, fwd_params, rev_params = args
    stats_f = Statistics()
    stats_r = Statistics()

    # Forward pass
    lf, _, _ = custom_link(sub.copy(), fwd_params, stats_f, asc=True)
    lf["link_direction"] = "forward"
    lf["tracker_type"] = t + "_forward"
    lf["unique_id"] = lf["particle"].apply(lambda x: f"{t}_forward_{int(x)}")

    # Reverse pass
    lr, _, _ = custom_link(sub.copy(), rev_params, stats_r, asc=False)
    lr["link_direction"] = "reverse"
    lr["tracker_type"] = t + "_reverse"
    lr["unique_id"] = lr["particle"].apply(lambda x: f"{t}_reverse_{int(x)}")

    return lf, lr


def find_overlaps(df, dt, mcf):
    """
    Identify overlapping tracks (close in space within 'dt' pixels) that co-exist
    for at least 'mcf' consecutive frames. Returns:
      - osm: dict[(idA,idB)] -> list of (start_frame, end_frame) segments
      - tl: dict unique_id -> total length (# detections) of each track
    """
    # Group detections per frame for efficient per-frame proximity queries
    frame_groups = {f: group for f, group in df.groupby("frame")}
    # Track lengths
    tl = df.groupby("unique_id").size().to_dict()

    # Candidate overlaps (per pair -> list of frames where distance < dt)
    od = defaultdict(list)
    for f, group in frame_groups.items():
        pos = group[["x", "y"]].values
        parts = group["unique_id"].values
        if len(pos) == 0:
            continue
        tree = cKDTree(pos)
        for i, j in tree.query_pairs(dt):
            od[tuple(sorted((parts[i], parts[j])))] .append(f)

    # Convert scattered frames into contiguous segments of length >= mcf
    osm = {}
    for c, fs in od.items():
        fs = sorted(fs)
        seg = []
        s = fs[0]
        p = fs[0]
        for f in fs[1:]:
            if f == p + 1:
                p = f
            else:
                if p - s + 1 >= mcf:
                    seg.append((s, p))
                s = f
                p = f
        if p - s + 1 >= mcf:
            seg.append((s, p))
        if seg:
            osm[c] = seg
    return osm, tl


def filter_tracks(df, osm, tl):
    """
    Remove the shorter member of overlapping track pairs over their overlapping
    segments, keeping the longer one intact.
    """
    rem = set()
    for cl, segs in osm.items():
        # pick the longest track in the conflicted pair
        cll = {p: tl[p] for p in cl}
        lp = max(cll, key=cll.get)         # keep this one
        sp = [p for p in cl if p != lp]    # candidates to remove within overlap
        # Remove only the overlapping portions of the shorter track(s)
        for s, e in segs:
            for p in sp:
                rem.update(df[(df["unique_id"] == p) & (df["frame"] >= s) & (df["frame"] <= e)].index.tolist())
    return df.drop(rem).reset_index(drop=True)


def main():
    """
    Entry point:
      1) Load detections
      2) Track separately per original tracker_type (e.g., 'small'/'big')
      3) Run forward and reverse linkers; combine
      4) Detect spatial overlaps and filter shorter duplicates
      5) Keep only tracks with >= min_links detections
      6) Save CSVs
    """
    (_, _, _, fwd_params, rev_params, end_frame, min_links) = get_config()

    # List of detection CSVs to process
    det_files = ["/Users/Ricardo/Desktop/Y4 Lab code/Work 0.3K/particle_detections.csv"]

    for det_file in det_files:
        # Only required columns here; original 'tracker_type' comes from detection stage
        det = pd.read_csv(det_file, usecols=["frame", "x", "y", "tracker_type"])

        # Track each detection subtype independently (e.g., 'small' and 'big')
        unique_types = det["tracker_type"].unique()
        res = []
        for t in unique_types:
            sub = det[det["tracker_type"] == t].copy()
            lf, lr = process_tracker_type((t, sub, fwd_params, rev_params))
            res.extend([lf, lr])

        # Combine forward+reverse passes and export raw tracks
        comb = pd.concat(res, ignore_index=True)
        out_track = os.path.join(os.path.dirname(det_file), "particle_tracks_v1.csv")
        out_filt  = os.path.join(os.path.dirname(det_file), "filtered_particle_tracks_v1.csv")
        comb.to_csv(out_track, index=False)

        # Overlap filtering parameters:
        dt = 3   # spatial proximity for overlap (pixels)
        mcf = 2  # minimum consecutive frames to consider as real overlap

        # Find overlaps and track lengths, then filter
        osm, tl = find_overlaps(comb, dt, mcf)
        filt = filter_tracks(comb, osm, tl)

        # Drop short tracks (keep only those with at least 'min_links' detections)
        cnt = filt["unique_id"].value_counts()
        filt = filt[filt["unique_id"].isin(cnt[cnt >= min_links].index)].reset_index(drop=True)

        # Save filtered result
        filt.to_csv(out_filt, index=False)


if __name__ == "__main__":
    main()


Version 2a

In [5]:
import os
import pandas as pd
import numpy as np
from numba import njit
from scipy.spatial import cKDTree
from collections import defaultdict

def get_config():
    """
    Define tracking hyperparameters (forward and reverse passes) and global settings.
    """
    base = "/Users/Ricardo/Desktop/Y4 Lab code/Work 0.3K"

    # Forward-pass parameters (ascending frame order)
    fwd_params = {
        "memory": 0,  # how many consecutive frames a track can be unmatched before being dropped

        # Speed thresholds (pixels/frame) defining regimes
        "slow_max_speed": 6,
        "medium_max_speed": 12,
        "high_max_speed": 65,

        # Allowed direction change (radians) per regime
        "low_direction_change_limits": {"lower": -np.pi, "upper": np.pi},
        "medium_direction_change_limits": {"lower": -0.4 * np.pi, "upper": 0.4 * np.pi},
        "high_direction_change_limits": {"lower": -0.15 * np.pi, "upper": 0.15 * np.pi},

        # Allowed speed change (Δ pixels/frame) per regime
        "low_speed_change_limits": {"lower": -30, "upper": 50},
        "medium_speed_change_limits": {"lower": -40, "upper": 55},
        "high_speed_change_limits": {"lower": -50, "upper": 60},

        # Multi-objective matching weights & constraints
        "w_distance": 1.0,   # weight for spatial distance cost
        "w_angle": 1.0,      # weight for turning angle cost
        "w_lum": 0.1,        # weight for luminosity difference cost
        "max_lum_diff": 100  # hard gate on allowed luminosity difference
    }

    # Reverse-pass parameters (descending frame order)
    # Same geometry limits; speed-change limits are mirrored
    rev_params = {
        "memory": 0,
        "slow_max_speed": 6,
        "medium_max_speed": 12,
        "high_max_speed": 65,
        "low_direction_change_limits": {"lower": -np.pi, "upper": np.pi},
        "medium_direction_change_limits": {"lower": -0.4 * np.pi, "upper": 0.4 * np.pi},
        "high_direction_change_limits": {"lower": -0.15 * np.pi, "upper": 0.15 * np.pi},
        "low_speed_change_limits": {"lower": -50, "upper": 30},
        "medium_speed_change_limits": {"lower": -55, "upper": 40},
        "high_speed_change_limits": {"lower": -60, "upper": 50},
        "w_distance": 1.0,
        "w_angle": 1.0,
        "w_lum": 0.1,
        "max_lum_diff": 100
    }

    end_frame = 600  # returned for potential future use (not consumed here)
    min_links = 10   # minimum detections per track to retain after filtering
    return (None, None, None, fwd_params, rev_params, end_frame, min_links)


class MemoryManager:
    """
    Keeps state for active tracks and their histories.
    """
    def __init__(self, memory):
        self.memory = memory              # tolerated missed frames
        self.active = {}                  # track_id -> state dict
        self.length = {}                  # track_id -> count of assigned detections

    def add(self, ids, positions, luminosities):
        """
        Start new tracks with given ids, initial positions and luminosities.
        """
        for i, p in enumerate(ids):
            self.active[p] = {
                "position": positions[i],         # last position (x, y)
                "history": [positions[i]],        # full trajectory (list of (x, y))
                "luminosity": luminosities[i],    # last luminosity
                "lum_history": [luminosities[i]], # luminosity history
                "frames_lost": 0,                 # consecutive unmatched frames
                "previous_speed": 0,              # last speed (pixels/frame)
                "speed_mode": "low"               # last regime: 'low'|'medium'|'high'
            }
            self.length[p] = 1

    def update(self, ids, new_positions, new_luminosities):
        """
        Update existing tracks that were matched this frame.
        """
        for p in ids:
            if p in self.active:
                self.active[p]["history"].append(new_positions[p])
                self.active[p]["position"] = new_positions[p]
                self.active[p]["lum_history"].append(new_luminosities[p])
                self.active[p]["luminosity"] = new_luminosities[p]
                self.length[p] += 1


class Statistics:
    """
    Collect diagnostics (for debugging/tuning) and cumulative distances per regime.
    """
    def __init__(self):
        self.dir_stats = {"low": [], "medium": [], "high": []}     # rejected direction changes
        self.speed_stats = {"low": [], "medium": [], "high": []}   # rejected speed changes
        self.distances = {"low": 0, "medium": 0, "high": 0}        # accepted distances

    def update_dir(self, mode, angle, params):
        self.dir_stats[mode].append(angle)

    def update_speed(self, mode, speed_change, params):
        self.speed_stats[mode].append(speed_change)

    def add_distance(self, mode, distance):
        self.distances[mode] += distance


def calc_angle_jit(prev, cur, nw):
    """
    Compute signed turning angle between (prev->cur) and (cur->nw).
    Returns radians in [-pi, pi]. If a segment is degenerate, returns 0.
    """
    m1 = cur - prev
    m2 = nw - cur
    n1 = np.sqrt(m1[0]**2 + m1[1]**2)
    n2 = np.sqrt(m2[0]**2 + m2[1]**2)
    if n1 == 0 or n2 == 0:
        return 0.0
    dp = (m1[0]*m2[0] + m1[1]*m2[1]) / (n1*n2)
    dp = min(max(dp, -1.0), 1.0)  # numerical clamp
    ang = np.arccos(dp)
    # Use 2D cross product sign for direction
    if m1[0]*m2[1] - m1[1]*m2[0] < 0:
        ang = -ang
    return ang


def calc_angle(prev, cur, nw):
    """
    Safe wrapper: handle None and ensure float arrays.
    """
    if prev is None:
        return 0.0
    return calc_angle_jit(np.asarray(prev, dtype=np.float64),
                          np.asarray(cur,  dtype=np.float64),
                          np.asarray(nw,   dtype=np.float64))


def find_prev(hist, cur, thresh):
    """
    Walk back along a track’s history to find the most recent point
    sufficiently far from 'cur' (distance > thresh). Stabilises angle.
    """
    for prev in reversed(hist[:-1]):
        if np.linalg.norm(cur - prev) > thresh:
            return prev
    return None


def custom_link(dets, params, stats, asc=True, tracker_name=""):
    """
    Greedy frame-by-frame linker with multi-objective matching:
      - distance (nearest neighbours via KD-tree)
      - turning angle (trajectory smoothness)
      - luminosity similarity (appearance)
    Candidate pairs outside geometry/luminosity gates are skipped.
    Accepted pairs are then tested against regime-specific direction/speed-change limits.
    """
    mem = params["memory"]
    max_dist = params["high_max_speed"]   # KD-tree search radius (pixels)

    # Cost weights & hard gate on luminosity
    w_distance = params["w_distance"]
    w_angle = params["w_angle"]
    w_lum = params["w_lum"]
    max_lum_diff = params["max_lum_diff"]

    pid = 0  # next new track id

    # Sort detections by frame (ascending for forward, descending for reverse)
    dets = dets.sort_values("frame", ascending=asc).reset_index(drop=True)
    dets["particle"] = np.nan

    mm = MemoryManager(mem)

    # Counters for diagnostics
    l1 = l2 = 0                            # accepted first-/second-nearest links
    overall_candidate_count = 0            # number of evaluated candidates
    overall_distance_cost = 0.0
    overall_angle_cost = 0.0
    overall_lum_cost = 0.0
    overall_skip_count = 0                 # luminosity gate rejections

    # Process one frame at a time
    for f, fd in dets.groupby("frame"):
        # Per-frame diagnostics
        frame_distance_cost = 0.0
        frame_angle_cost = 0.0
        frame_lum_cost = 0.0
        frame_candidate_count = 0
        frame_skip_count = 0

        # Current detections: positions and luminosities
        pos = fd[["x", "y"]].values
        lum_arr = fd["raw_luminosity"].values
        idxs = fd.index.values

        # If no active tracks, spawn new ones for every detection
        if not mm.active:
            new_ids = list(range(pid, pid + len(fd)))
            mm.add(new_ids, list(pos), list(lum_arr))
            dets.loc[idxs, "particle"] = new_ids
            pid += len(fd)
            continue

        # Rank active tracks by length (prefer longer, more stable tracks)
        act_ids = sorted(mm.active.keys(), key=lambda p: mm.length[p], reverse=True)
        active = mm.active

        # If empty after sorting (edge case), spawn all as new
        if not act_ids:
            new_ids = list(range(pid, pid + len(fd)))
            mm.add(new_ids, list(pos), list(lum_arr))
            dets.loc[idxs, "particle"] = new_ids
            pid += len(fd)
            continue

        # KD-tree over last positions of active tracks
        act_pos = np.array([active[p]["position"] for p in act_ids])
        tree = cKDTree(act_pos)

        # Query up to 2 nearest active tracks per detection within 'max_dist'
        dists, a_idxs = tree.query(pos, k=2, distance_upper_bound=max_dist)

        # Build candidate list with composite costs
        poss_list = []
        for i in range(len(pos)):
            for k in range(2):
                # Skip if outside radius or invalid index
                if a_idxs[i, k] >= len(act_ids) or dists[i, k] == np.inf:
                    continue
                aid = act_ids[a_idxs[i, k]]
                part = active[aid]

                # Turning angle using previous point (if available)
                prev_pt = part["history"][-2] if len(part["history"]) >= 2 else None
                angle = calc_angle(prev_pt, part["position"], pos[i])

                # Appearance gate: reject if luminosity differs too much
                lum_diff = abs(lum_arr[i] - part["luminosity"])
                if lum_diff > max_lum_diff:
                    frame_skip_count += 1
                    continue

                # Composite cost terms
                distance_cost = w_distance * dists[i, k]
                angle_cost = w_angle * abs(angle)
                lum_cost = w_lum * lum_diff
                total_cost = distance_cost + angle_cost + lum_cost

                # Accumulate diagnostics
                frame_distance_cost += distance_cost
                frame_angle_cost += angle_cost
                frame_lum_cost += lum_cost
                frame_candidate_count += 1

                # Store candidate: (track_id, det_idx_in_pos, det_row_index, angle, dist, rank(1|2), total_cost)
                poss_list.append((aid, i, idxs[i], angle, dists[i, k], k+1, total_cost))

        # Update global diagnostics
        overall_candidate_count += frame_candidate_count
        overall_distance_cost += frame_distance_cost
        overall_angle_cost += frame_angle_cost
        overall_lum_cost += frame_lum_cost
        overall_skip_count += frame_skip_count

        # Greedy assignment by ascending total_cost
        poss_list.sort(key=lambda x: x[6])
        used_a = set()  # tracks already matched this frame
        used_d = set()  # detections already matched this frame
        fm = []         # accepted matches (aid, det_idx_in_pos, rank)

        for candidate in poss_list:
            aid, i, dii, angle, dist_val, rank, cost = candidate
            if aid in used_a or i in used_d:
                continue

            part = active.get(aid)
            if part is None:
                continue

            # Recompute angle/speed using the most recent point(s)
            prev_pt = part["history"][-2] if len(part["history"]) >= 2 else None
            angle = calc_angle(prev_pt, part["position"], pos[i])
            disp = pos[i] - part["position"]
            spd = np.linalg.norm(disp)
            spd_ch = spd - part["previous_speed"]

            # Determine current/previous speed regimes
            cm = "low" if spd <= params["slow_max_speed"] else ("medium" if spd <= params["medium_max_speed"] else "high")
            pm = part["speed_mode"]

            # Build cumulative constraints along regime transition path
            modes = ["low", "medium", "high"]
            pi, ci = modes.index(pm), modes.index(cm)
            seq = modes[pi:ci+1] if pi <= ci else modes[ci:pi+1]

            cs_lim = {"lower": -np.inf, "upper": np.inf}  # speed-change limits
            cd_lim = {"lower": -np.inf, "upper": np.inf}  # direction-change limits
            for m in seq:
                sp_lim = params[f"{m}_speed_change_limits"]
                d_lim = params[f"{m}_direction_change_limits"]
                cs_lim["lower"] = max(cs_lim["lower"], sp_lim["lower"])
                cs_lim["upper"] = min(cs_lim["upper"], sp_lim["upper"])
                cd_lim["lower"] = max(cd_lim["lower"], d_lim["lower"])
                cd_lim["upper"] = min(cd_lim["upper"], d_lim["upper"])

            # Enforce constraints; record rejects for diagnostics
            rej = False
            if not (cd_lim["lower"] <= angle <= cd_lim["upper"]):
                stats.update_dir(cm, angle, params)
                rej = True
            if not (cs_lim["lower"] <= spd_ch <= cs_lim["upper"]):
                stats.update_speed(cm, spd_ch, params)
                rej = True
            if rej:
                part["frames_lost"] += 1
                if part["frames_lost"] > params["memory"]:
                    # Drop track if memory exceeded
                    del active[aid]
                    del mm.length[aid]
                continue

            # Accept match
            dets.at[dii, "particle"] = aid
            fm.append((aid, i, rank))
            used_a.add(aid)
            used_d.add(i)
            l1 += (rank == 1)
            l2 += (rank == 2)

        # Commit updates for accepted matches and update per-track regimes
        new_pos = {}
        new_lum = {}
        mids = []
        for aid, i, _ in fm:
            new_pos[aid] = pos[i]
            new_lum[aid] = lum_arr[i]
            spd = np.linalg.norm(pos[i] - active[aid]["position"])
            sm_val = "low" if spd <= params["slow_max_speed"] else ("medium" if spd <= params["medium_max_speed"] else "high")
            active[aid]["speed_mode"] = sm_val
            active[aid]["previous_speed"] = spd
            stats.add_distance(sm_val, spd)
            mids.append(aid)
        mm.update(mids, new_pos, new_lum)

        # Any unused detections spawn new tracks
        un = set(range(len(fd))) - used_d
        if un:
            new_ids = list(range(pid, pid + len(un)))
            pos_un = [pos[i] for i in un]
            lum_un = [lum_arr[i] for i in un]
            mm.add(new_ids, pos_un, lum_un)
            for p, i in zip(new_ids, un):
                dets.at[idxs[i], "particle"] = p
            pid += len(un)

    # Summary diagnostics for this tracker subset
    num_tracks = dets["particle"].nunique()
    num_links = len(dets) - num_tracks
    print("{}: Processed {} candidates; Total distance cost: {:.3f}, Total angle cost: {:.3f}, Total lum cost: {:.3f}; Total skipped (max_lum_diff): {}; Number of tracks: {}; Number of links: {}".format(
        tracker_name, overall_candidate_count, overall_distance_cost, overall_angle_cost, overall_lum_cost,
        overall_skip_count, num_tracks, num_links))

    # Nullable integer dtype for track ids
    dets["particle"] = dets["particle"].astype("Int64")

    # Preserve direction’s natural ordering of frames in the output
    return (dets.sort_values("frame").reset_index(drop=True) if not asc else dets, l1, l2)


def process_tracker_type(args):
    """
    Run both forward and reverse linking for one detection subtype (e.g. 'small' or 'big').
    Adds direction/type tags and builds a unique_id per produced track.
    """
    t, sub, fwd_params, rev_params = args
    stats_f = Statistics()
    stats_r = Statistics()

    # Forward pass
    lf, _, _ = custom_link(sub.copy(), fwd_params, stats_f, asc=True, tracker_name=t + "_forward")
    lf["link_direction"] = "forward"
    lf["tracker_type"] = t + "_forward"
    lf["unique_id"] = lf["particle"].apply(lambda x: f"{t}_forward_{int(x)}")

    # Reverse pass
    lr, _, _ = custom_link(sub.copy(), rev_params, stats_r, asc=False, tracker_name=t + "_reverse")
    lr["link_direction"] = "reverse"
    lr["tracker_type"] = t + "_reverse"
    lr["unique_id"] = lr["particle"].apply(lambda x: f"{t}_reverse_{int(x)}")

    return lf, lr


def find_overlaps(df, dt, mcf):
    """
    Find overlapping track pairs that are within 'dt' pixels for at least 'mcf'
    consecutive frames. Returns:
      - osm: dict[(idA,idB)] -> list of (start_frame, end_frame) segments
      - tl: dict unique_id -> track length (# rows)
    """
    # Group by frame to perform per-frame proximity queries
    frame_groups = {f: group for f, group in df.groupby("frame")}
    tl = df.groupby("unique_id").size().to_dict()  # track lengths

    # For each frame, record pairs closer than dt
    od = defaultdict(list)
    for f, group in frame_groups.items():
        pos = group[["x", "y"]].values
        parts = group["unique_id"].values
        if len(pos) == 0:
            continue
        tree = cKDTree(pos)
        for i, j in tree.query_pairs(dt):
            od[tuple(sorted((parts[i], parts[j])))].append(f)

    # Convert per-frame hits into contiguous segments of length >= mcf
    osm = {}
    for c, fs in od.items():
        fs = sorted(fs)
        seg = []
        s = fs[0]
        p = fs[0]
        for f in fs[1:]:
            if f == p + 1:
                p = f
            else:
                if p - s + 1 >= mcf:
                    seg.append((s, p))
                s = f
                p = f
        if p - s + 1 >= mcf:
            seg.append((s, p))
        if seg:
            osm[c] = seg
    return osm, tl


def filter_tracks(df, osm, tl):
    """
    For each overlapping pair, keep the longer track and remove the shorter
    within the overlapping frame ranges only.
    """
    rem = set()
    for cl, segs in osm.items():
        cll = {p: tl[p] for p in cl}
        lp = max(cll, key=cll.get)       # track to keep
        sp = [p for p in cl if p != lp]  # tracks to prune in overlap
        for s, e in segs:
            for p in sp:
                indices = df[(df["unique_id"] == p) & (df["frame"] >= s) & (df["frame"] <= e)].index.tolist()
                rem.update(indices)
    return df.drop(rem).reset_index(drop=True)


def main():
    """
    Pipeline:
      1) Load detections (+ raw luminosity).
      2) Split by original tracker_type (e.g. 'small'/'big').
      3) Run forward and reverse linking; combine results.
      4) Save combined tracks CSV (v2).
      5) Detect overlaps and prune shorter duplicates.
      6) Drop short tracks (< min_links) and save filtered CSV.
    """
    _, _, _, fwd_params, rev_params, end_frame, min_links = get_config()

    # List of detection CSVs to process (single file here)
    det_files = ["/Users/Ricardo/Desktop/Y4 Lab code/Work 0.3K/particle_detections.csv"]

    for det_file in det_files:
        # Load detections with luminosity for the appearance term
        det = pd.read_csv(det_file, usecols=["frame", "x", "y", "tracker_type", "raw_luminosity"])

        # Track each subtype independently (e.g., 'small' and 'big')
        unique_types = det["tracker_type"].unique()
        res = []
        for t in unique_types:
            sub = det[det["tracker_type"] == t].copy()
            lf, lr = process_tracker_type((t, sub, fwd_params, rev_params))
            res.extend([lf, lr])

        # Combine forward+reverse outputs
        comb = pd.concat(res, ignore_index=True)
        total_tracks = comb["unique_id"].nunique()
        total_links = len(comb) - total_tracks
        print("Combined Tracker: Total tracks: {}; Total links: {}".format(total_tracks, total_links))

        # Move 'raw_luminosity' to the end of the column order (purely cosmetic)
        cols = [col for col in comb.columns if col != "raw_luminosity"] + ["raw_luminosity"]
        comb = comb[cols]

        # Save combined tracks
        out_track = os.path.join(os.path.dirname(det_file), "particle_tracks_v2.csv")
        comb.to_csv(out_track, index=False)

        # Overlap detection parameters
        dt = 3   # spatial radius to consider overlap (pixels)
        mcf = 2  # minimum consecutive frames to qualify as an overlapping segment

        # Compute overlaps and filter shorter track segments in conflict
        osm, tl = find_overlaps(comb, dt, mcf)
        filt = filter_tracks(comb, osm, tl)

        # Keep only tracks with at least 'min_links' detections after pruning
        cnt = filt["unique_id"].value_counts()
        filt = filt[filt["unique_id"].isin(cnt[cnt >= min_links].index)].reset_index(drop=True)

        # Report filtered stats
        filt_tracks = filt["unique_id"].nunique()
        filt_links = len(filt) - filt_tracks
        print("Filtered Tracker: Total tracks: {}; Total links: {}".format(filt_tracks, filt_links))

        # Reorder columns again for consistency and save
        cols = [col for col in filt.columns if col != "raw_luminosity"] + ["raw_luminosity"]
        filt = filt[cols]
        out_filt = os.path.join(os.path.dirname(det_file), "filtered_particle_tracks_v2.csv")
        filt.to_csv(out_filt, index=False)


if __name__ == "__main__":
    main()


small_forward: Processed 14889 candidates; Total distance cost: 231078.434, Total angle cost: 13452.558, Total lum cost: 22525.845; Total skipped (max_lum_diff): 1134; Number of tracks: 2319; Number of links: 5969
small_reverse: Processed 14809 candidates; Total distance cost: 228085.767, Total angle cost: 12648.689, Total lum cost: 22447.516; Total skipped (max_lum_diff): 1141; Number of tracks: 2583; Number of links: 5705
big_forward: Processed 13773 candidates; Total distance cost: 219520.097, Total angle cost: 12261.695, Total lum cost: 21042.776; Total skipped (max_lum_diff): 929; Number of tracks: 2088; Number of links: 5529
big_reverse: Processed 13697 candidates; Total distance cost: 218273.187, Total angle cost: 11548.801, Total lum cost: 21001.235; Total skipped (max_lum_diff): 938; Number of tracks: 2330; Number of links: 5287
Combined Tracker: Total tracks: 9320; Total links: 22490
Filtered Tracker: Total tracks: 123; Total links: 2752


Version 2b

In [6]:
import os
import pandas as pd
import numpy as np
from numba import njit
from scipy.spatial import cKDTree
from collections import defaultdict

def get_config():
    """
    Define tracking hyperparameters (forward and reverse) and global settings.
    """
    base = "/Users/Ricardo/Desktop/Y4 Lab code/Work 0.3K"

    # Forward-pass parameters (ascending frame order)
    fwd_params = {
        "memory": 0,  # tolerated consecutive missed frames before dropping a track

        # Speed regime thresholds (pixels/frame)
        "slow_max_speed": 6,
        "medium_max_speed": 12,
        "high_max_speed": 65,

        # Allowed direction change (radians) per regime
        "low_direction_change_limits": {"lower": -np.pi, "upper": np.pi},
        "medium_direction_change_limits": {"lower": -0.4*np.pi, "upper": 0.4*np.pi},
        "high_direction_change_limits": {"lower": -0.15*np.pi, "upper": 0.15*np.pi},

        # Allowed speed change (Δ pixels/frame) per regime
        "low_speed_change_limits": {"lower": -30, "upper": 50},
        "medium_speed_change_limits": {"lower": -40, "upper": 55},
        "high_speed_change_limits": {"lower": -50, "upper": 60},

        # Multi-objective matching weights and luminosity gate
        "w_distance": 1.0,  # weight for spatial distance
        "w_angle": 1.0,     # weight for turning angle
        "w_lum": 0.1,       # weight for luminosity difference
        "max_lum_diff": 100 # hard cutoff for appearance mismatch
    }

    # Reverse-pass parameters (descending frame order)
    rev_params = {
        "memory": 0,
        "slow_max_speed": 6,
        "medium_max_speed": 12,
        "high_max_speed": 65,
        "low_direction_change_limits": {"lower": -np.pi, "upper": np.pi},
        "medium_direction_change_limits": {"lower": -0.4*np.pi, "upper": 0.4*np.pi},
        "high_direction_change_limits": {"lower": -0.15*np.pi, "upper": 0.15*np.pi},
        "low_speed_change_limits": {"lower": -50, "upper": 30},
        "medium_speed_change_limits": {"lower": -55, "upper": 40},
        "high_speed_change_limits": {"lower": -60, "upper": 50},
        "w_distance": 1.0,
        "w_angle": 1.0,
        "w_lum": 0.1,
        "max_lum_diff": 100
    }

    end_frame = 600  # not directly used below; kept for consistency with your pipeline
    min_links = 10   # minimum detections per track to retain
    return (None, None, None, fwd_params, rev_params, end_frame, min_links)


class MemoryManager:
    """
    Holds active track states and simple lifecycle counters.
    """
    def __init__(self, memory):
        self.memory = memory          # tolerated consecutive misses
        self.active = {}              # track_id -> state dict
        self.length = {}              # track_id -> number of assigned detections

    def add(self, ids, positions, luminosities):
        """
        Start new tracks with initial (x,y) positions and luminosities.
        """
        for i, p in enumerate(ids):
            self.active[p] = {
                "position": positions[i],          # last known position
                "history": [positions[i]],         # trajectory history
                "luminosity": luminosities[i],     # last luminosity
                "lum_history": [luminosities[i]],  # luminosity history
                "frames_lost": 0,                  # consecutive frames unmatched
                "previous_speed": 0,               # last speed (pixels/frame)
                "speed_mode": "low"                # 'low'|'medium'|'high'
            }
            self.length[p] = 1

    def update(self, ids, new_positions, new_luminosities):
        """
        Update matched tracks with new positions and luminosities.
        """
        for p in ids:
            if p in self.active:
                self.active[p]["history"].append(new_positions[p])
                self.active[p]["position"] = new_positions[p]
                self.active[p]["lum_history"].append(new_luminosities[p])
                self.active[p]["luminosity"] = new_luminosities[p]
                self.length[p] += 1


class Statistics:
    """
    Collects diagnostics for rejected candidates and cumulative distances per regime.
    """
    def __init__(self):
        self.dir_stats = {"low": [], "medium": [], "high": []}     # rejected angle values
        self.speed_stats = {"low": [], "medium": [], "high": []}   # rejected speed-change values
        self.distances = {"low": 0, "medium": 0, "high": 0}        # accepted distances

    def update_dir(self, mode, angle, params):
        self.dir_stats[mode].append(angle)

    def update_speed(self, mode, speed_change, params):
        self.speed_stats[mode].append(speed_change)

    def add_distance(self, mode, distance):
        self.distances[mode] += distance


def calc_angle_jit(prev, cur, nw):
    """
    Signed turning angle between (prev->cur) and (cur->nw) in radians.
    Returns 0 for degenerate segments.
    """
    m1 = cur - prev
    m2 = nw - cur
    n1 = np.sqrt(m1[0]**2 + m1[1]**2)
    n2 = np.sqrt(m2[0]**2 + m2[1]**2)
    if n1 == 0 or n2 == 0:
        return 0.0
    dp = (m1[0]*m2[0] + m1[1]*m2[1]) / (n1*n2)
    dp = min(max(dp, -1.0), 1.0)  # numerical safety
    ang = np.arccos(dp)
    # Use cross-product sign to assign orientation
    if m1[0]*m2[1] - m1[1]*m2[0] < 0:
        ang = -ang
    return ang


def calc_angle(prev, cur, nw):
    """
    Safe wrapper: handle None and ensure float inputs.
    """
    if prev is None:
        return 0.0
    return calc_angle_jit(np.asarray(prev, dtype=np.float64),
                          np.asarray(cur,  dtype=np.float64),
                          np.asarray(nw,   dtype=np.float64))


def custom_link(dets, params, stats, asc=True, tracker_name=""):
    """
    Greedy, per-frame linking with composite matching cost and physical constraints:
      • Composite cost = w_distance*distance + w_angle*|angle| + w_lum*|Δlum|
      • Candidate gate on |Δlum| <= max_lum_diff
      • After provisional pairing, enforce regime-specific direction & speed-change limits
    """
    mem = params["memory"]
    max_dist = params["high_max_speed"]  # KD-tree search radius (pixels)

    # Cost weights and luminosity gate
    w_distance = params["w_distance"]
    w_angle = params["w_angle"]
    w_lum = params["w_lum"]
    max_lum_diff = params["max_lum_diff"]

    pid = 0  # next track id to allocate

    # Sort by frame according to pass direction
    dets = dets.sort_values("frame", ascending=asc).reset_index(drop=True)
    dets["particle"] = np.nan

    mm = MemoryManager(mem)

    # Diagnostics
    l1 = l2 = 0                   # number of accepted 1st/2nd-nearest links
    overall_candidate_count = 0
    overall_distance_cost = 0.0
    overall_angle_cost = 0.0
    overall_lum_cost = 0.0
    overall_skip_count = 0        # candidates skipped by luminosity gate

    # Iterate frames
    for f, fd in dets.groupby("frame"):
        # Per-frame diagnostics
        frame_distance_cost = 0.0
        frame_angle_cost = 0.0
        frame_lum_cost = 0.0
        frame_candidate_count = 0
        frame_skip_count = 0

        # Current detections
        pos = fd[["x", "y"]].values
        lum_arr = fd["raw_luminosity"].values
        idxs = fd.index.values

        # No active tracks: spawn all as new
        if not mm.active:
            new_ids = list(range(pid, pid + len(fd)))
            mm.add(new_ids, list(pos), list(lum_arr))
            dets.loc[idxs, "particle"] = new_ids
            pid += len(fd)
            continue

        # Prefer longer-lived active tracks first
        act_ids = sorted(mm.active.keys(), key=lambda p: mm.length[p], reverse=True)
        active = mm.active

        # Edge case: still none
        if not act_ids:
            new_ids = list(range(pid, pid + len(fd)))
            mm.add(new_ids, list(pos), list(lum_arr))
            dets.loc[idxs, "particle"] = new_ids
            pid += len(fd)
            continue

        # KD-tree over active last positions
        act_pos = np.array([active[p]["position"] for p in act_ids])
        tree = cKDTree(act_pos)

        # For each detection, query up to two nearest active tracks
        dists, a_idxs = tree.query(pos, k=2, distance_upper_bound=max_dist)

        # Build candidate list with composite cost
        poss_list = []
        for i in range(len(pos)):
            for k in range(2):
                if a_idxs[i, k] >= len(act_ids) or dists[i, k] == np.inf:
                    continue
                aid = act_ids[a_idxs[i, k]]
                part = active[aid]

                # Angle using previous point (if available)
                prev_pt = part["history"][-2] if len(part["history"]) >= 2 else None
                angle = calc_angle(prev_pt, part["position"], pos[i])

                # Appearance gate
                lum_diff = abs(lum_arr[i] - part["luminosity"])
                if lum_diff > max_lum_diff:
                    frame_skip_count += 1
                    continue

                # Composite cost
                distance_cost = w_distance * dists[i, k]
                angle_cost = w_angle * abs(angle)
                lum_cost = w_lum * lum_diff
                total_cost = distance_cost + angle_cost + lum_cost

                # Accumulate diagnostics
                frame_distance_cost += distance_cost
                frame_angle_cost += angle_cost
                frame_lum_cost += lum_cost
                frame_candidate_count += 1

                # (track_id, det_idx_in_pos, det_row_idx, angle, dist, rank(1|2), total_cost)
                poss_list.append((aid, i, idxs[i], angle, dists[i, k], k+1, total_cost))

        # Update global diagnostics
        overall_candidate_count += frame_candidate_count
        overall_distance_cost += frame_distance_cost
        overall_angle_cost += frame_angle_cost
        overall_lum_cost += frame_lum_cost
        overall_skip_count += frame_skip_count

        # Greedy assignment by lowest total_cost
        poss_list.sort(key=lambda x: x[6])
        used_a = set()  # matched tracks this frame
        used_d = set()  # matched detections this frame
        fm = []         # accepted matches: (aid, i, rank)

        for candidate in poss_list:
            aid, i, dii, angle, dist_val, rank, cost = candidate
            if aid in used_a or i in used_d:
                continue

            part = active.get(aid)
            if part is None:
                continue

            # Recompute final checks using freshest state
            prev_pt = part["history"][-2] if len(part["history"]) >= 2 else None
            angle = calc_angle(prev_pt, part["position"], pos[i])
            disp = pos[i] - part["position"]
            spd = np.linalg.norm(disp)
            spd_ch = spd - part["previous_speed"]

            # Determine old/new regimes to build cumulative limits
            cm = "low" if spd <= params["slow_max_speed"] else ("medium" if spd <= params["medium_max_speed"] else "high")
            pm = part["speed_mode"]
            modes = ["low", "medium", "high"]
            pi, ci = modes.index(pm), modes.index(cm)
            seq = modes[pi:ci+1] if pi <= ci else modes[ci:pi+1]

            # Aggregate limits along the transition path
            cs_lim = {"lower": -np.inf, "upper": np.inf}  # speed-change
            cd_lim = {"lower": -np.inf, "upper": np.inf}  # direction-change
            for m in seq:
                sp_lim = params[f"{m}_speed_change_limits"]
                d_lim = params[f"{m}_direction_change_limits"]
                cs_lim["lower"] = max(cs_lim["lower"], sp_lim["lower"])
                cs_lim["upper"] = min(cs_lim["upper"], sp_lim["upper"])
                cd_lim["lower"] = max(cd_lim["lower"], d_lim["lower"])
                cd_lim["upper"] = min(cd_lim["upper"], d_lim["upper"])

            # Enforce constraints; record rejects and optionally retire stale tracks
            rej = False
            if not (cd_lim["lower"] <= angle <= cd_lim["upper"]):
                stats.update_dir(cm, angle, params)
                rej = True
            if not (cs_lim["lower"] <= spd_ch <= cs_lim["upper"]):
                stats.update_speed(cm, spd_ch, params)
                rej = True
            if rej:
                part["frames_lost"] += 1
                if part["frames_lost"] > params["memory"]:
                    del active[aid]
                    del mm.length[aid]
                continue

            # Accept pairing
            dets.at[dii, "particle"] = aid
            fm.append((aid, i, rank))
            used_a.add(aid)
            used_d.add(i)
            l1 += (rank == 1)
            l2 += (rank == 2)

        # Commit accepted updates and update per-track regime/speed
        new_pos = {}
        new_lum = {}
        mids = []
        for aid, i, _ in fm:
            new_pos[aid] = pos[i]
            new_lum[aid] = lum_arr[i]
            spd = np.linalg.norm(pos[i] - active[aid]["position"])
            sm_val = "low" if spd <= params["slow_max_speed"] else ("medium" if spd <= params["medium_max_speed"] else "high")
            active[aid]["speed_mode"] = sm_val
            active[aid]["previous_speed"] = spd
            stats.add_distance(sm_val, spd)
            mids.append(aid)
        mm.update(mids, new_pos, new_lum)

        # Unmatched detections: spawn new tracks
        un = set(range(len(fd))) - used_d
        if un:
            new_ids = list(range(pid, pid + len(un)))
            pos_un = [pos[i] for i in un]
            lum_un = [lum_arr[i] for i in un]
            mm.add(new_ids, pos_un, lum_un)
            for p, i in zip(new_ids, un):
                dets.at[idxs[i], "particle"] = p
            pid += len(un)

    # Summary diagnostics for this pass
    num_tracks = dets["particle"].nunique()
    num_links = len(dets) - num_tracks
    print("{}: Processed {} candidates; Total distance cost: {:.3f}, Total angle cost: {:.3f}, Total lum cost: {:.3f}; Total skipped: {}; Number of tracks: {}; Number of links: {}".format(
        tracker_name, overall_candidate_count, overall_distance_cost, overall_angle_cost, overall_lum_cost,
        overall_skip_count, num_tracks, num_links))

    # Nullable integer for track ids
    dets["particle"] = dets["particle"].astype("Int64")

    # Return as-is (already ordered per 'asc' during build)
    return dets, l1, l2


def process_tracker_type(args):
    """
    Run forward and reverse linking for a given detection subtype,
    then tag outputs with direction and construct unique ids.
    """
    t, sub, fwd_params, rev_params = args
    stats_f = Statistics()
    stats_r = Statistics()

    # Forward pass
    lf, l1_f, l2_f = custom_link(sub.copy(), fwd_params, stats_f, asc=True, tracker_name=t+"_forward")
    lf["link_direction"] = "forward"
    lf["tracker_type"] = t+"_forward"
    lf["unique_id"] = lf["particle"].apply(lambda x: f"{t}_forward_{int(x)}")

    # Reverse pass
    lr, l1_r, l2_r = custom_link(sub.copy(), rev_params, stats_r, asc=False, tracker_name=t+"_reverse")
    lr["link_direction"] = "reverse"
    lr["tracker_type"] = t+"_reverse"
    lr["unique_id"] = lr["particle"].apply(lambda x: f"{t}_reverse_{int(x)}")

    return lf, lr


def main():
    """
    Pipeline:
      1) Load detections with luminosity.
      2) For each original tracker_type (e.g., 'small'/'big'), run forward & reverse linking.
      3) Collect all candidate track sets.
      4) Select the single best set by maximum #links (len(rows) - #tracks).
      5) Drop short tracks (< min_links) and export CSV.
    """
    _, _, _, fwd_params, rev_params, end_frame, min_links = get_config()

    det_files = ["/Users/Ricardo/Desktop/Y4 Lab code/Work 0.3K/particle_detections.csv"]
    candidates = []  # store each (type, direction) tracking output

    for det_file in det_files:
        # Load detections required by the linker
        det = pd.read_csv(det_file, usecols=["frame", "x", "y", "tracker_type", "raw_luminosity"])

        # Build candidates per subtype
        unique_types = det["tracker_type"].unique()
        for t in unique_types:
            sub = det[det["tracker_type"] == t].copy()
            lf, lr = process_tracker_type((t, sub, fwd_params, rev_params))
            candidates.append(lf)
            candidates.append(lr)

    # Choose the best candidate set by #links (a proxy for continuity)
    best_df = None
    best_links = -1
    best_tracker = ""
    for df in candidates:
        tracks = df["unique_id"].nunique()
        links = len(df) - tracks
        if links > best_links:
            best_links = links
            best_df = df
            best_tracker = df["tracker_type"].iloc[0]

    # Prune short tracks
    cnt = best_df["unique_id"].value_counts()
    best_df = best_df[best_df["unique_id"].isin(cnt[cnt >= min_links].index)].reset_index(drop=True)

    # Final stats and export
    total_tracks = best_df["unique_id"].nunique()
    total_links = len(best_df) - total_tracks
    print("Selected Tracker: {}: Total tracks: {}; Total links: {}".format(best_tracker, total_tracks, total_links))

    out_track = os.path.join(os.path.dirname(det_files[0]), "filtered_particle_tracks_v2.csv")

    # Optional: move 'raw_luminosity' to the end for readability
    cols = [col for col in best_df.columns if col != "raw_luminosity"] + ["raw_luminosity"]
    best_df = best_df[cols]

    best_df.to_csv(out_track, index=False)


if __name__ == "__main__":
    main()


small_forward: Processed 14889 candidates; Total distance cost: 231078.434, Total angle cost: 13452.558, Total lum cost: 22525.845; Total skipped: 1134; Number of tracks: 2319; Number of links: 5969
small_reverse: Processed 14809 candidates; Total distance cost: 228085.767, Total angle cost: 12648.689, Total lum cost: 22447.516; Total skipped: 1141; Number of tracks: 2583; Number of links: 5705
big_forward: Processed 13773 candidates; Total distance cost: 219520.097, Total angle cost: 12261.695, Total lum cost: 21042.776; Total skipped: 929; Number of tracks: 2088; Number of links: 5529
big_reverse: Processed 13697 candidates; Total distance cost: 218273.187, Total angle cost: 11548.801, Total lum cost: 21001.235; Total skipped: 938; Number of tracks: 2330; Number of links: 5287
Selected Tracker: small_forward: Total tracks: 110; Total links: 2497


# Tracked Videos

In [None]:
import cv2
import numpy as np
import os
import pandas as pd

def get_config():
    """
    Defines input and output paths, video processing parameters, 
    and returns them as a tuple for use in the pipeline.
    """
    import glob
    base = "/Users/Ricardo/Desktop/Y4 Lab code/Work 0.3K"
    video_dir = "/Users/Ricardo/Desktop/Y4 Lab code/Cooldown 10"

    # Try to locate the input video
    heatmap_inp = os.path.join(video_dir, "_video (0).mp4")
    if not os.path.exists(heatmap_inp):
        hits = glob.glob(os.path.join(video_dir, "**", "_video (0).mp4"), recursive=True)
        if hits:
            heatmap_inp = hits[0]
        else:
            raise FileNotFoundError(f"Couldn't find '_video (0).mp4' under {video_dir}")

    # Define outputs for different processing steps
    out0 = os.path.join(base, "Original_Heatmap_NoDetections.mp4")
    out1 = os.path.join(base, "Original_Heatmap.mp4")
    out2 = os.path.join(base, "Background_Removed_Grayscale.mp4")
    out3 = os.path.join(base, "Background_Removed_Heatmap.mp4")
    out4 = os.path.join(base, "Temporal_Gaussian_Blur_Heatmap.mp4")

    # Region of interest (ROI) vertical limits
    yl = 100
    yu = 819
    end_frame = 600

    # Heatmap colour map
    cmap = cv2.COLORMAP_JET

    # Paths for track/detection data
    filt_v1 = os.path.join(base, "filtered_particle_tracks_v1.csv")
    filt_v2 = os.path.join(base, "filtered_particle_tracks_v2.csv")
    det = os.path.join(base, "particle_detections.csv")
    output_video = os.path.join(base, "Show All Tracks.mp4")

    # Forwarding parameters: thresholds for motion speeds
    fwd_params = {"slow_max_speed": 6, "medium_max_speed": 12, "high_max_speed": 65}

    # Video drawing parameters
    vid_params = {"line_thickness": 2, "distance_threshold": 3, "min_consecutive_frames": 2}

    # Colour palettes for track visualisation
    color_palettes = {
        'small': (0, 255, 255),
        'big': (0, 255, 0),
        'small_reverse': (0, 0, 255),
        'big_reverse': (0, 165, 255)
    }

    # Detection marker parameters
    detection_params = {"marker": "ring", "size": 5}

    # General processing parameters
    proc_params = {
        "speed_factor": 1,          # Playback speed (1 = real-time)
        "target_height": 1080,      # Final video height
        "y_min": 0,                 # Crop top
        "y_max": 900,               # Crop bottom
        "start_frame": 11,          # Starting frame index
        "end_frame": 600            # Ending frame index
    }

    return (heatmap_inp, out1, out2, out3, out4, yl, yu, end_frame, cmap,
            filt_v1, filt_v2, det, output_video, fwd_params,
            vid_params, color_palettes, detection_params, proc_params)


def apply_heatmap(frame, gray_roi, yl, yu, gmax, cmap):
    """
    Normalises grayscale ROI and applies a heatmap colour map.
    Places the coloured ROI back into the frame.
    """
    norm = (gray_roi / gmax * 255).astype(np.uint8) if gmax > 0 else np.zeros_like(gray_roi, dtype=np.uint8)
    heat = cv2.applyColorMap(np.clip(norm, 0, 255), cmap)
    out = frame.copy()
    out[yl:yu, :] = heat
    return out


def find_global_max(inp, coeffs, yl, yu, end_frame):
    """
    Finds global maximum grayscale intensity in ROI across frames.
    Used for consistent heatmap normalisation.
    """
    cap = cv2.VideoCapture(inp)
    if not cap.isOpened():
        return None, None
    gm1 = 0
    gm2 = 0
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

    for frm in range(end_frame + 1):
        ret, frame = cap.read()
        if not ret:
            break
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        roi = gray[yl:yu, :]
        gm1 = max(gm1, roi.max())
        gm2 = max(gm2, roi.max())
    cap.release()
    return gm1, gm2


def temporal_gaussian_blur(frames, kernel_size, sigma=1.0):
    """
    Applies temporal Gaussian blur across a sequence of frames 
    (smoothing motion over time).
    """
    radius = kernel_size // 2
    kernel = np.exp(-0.5 * (np.arange(-radius, radius + 1) / sigma) ** 2)
    kernel = kernel / kernel.sum()

    blurred = []
    n = len(frames)

    for i in range(n):
        acc = np.zeros_like(frames[i], dtype=np.float32)
        for j in range(-radius, radius + 1):
            idx = i + j
            # Handle boundary cases
            if idx < 0:
                idx = 0
            elif idx >= n:
                idx = n - 1
            acc += frames[idx].astype(np.float32) * kernel[j + radius]
        blurred.append(np.clip(acc, 0, 255).astype(np.uint8))
    return blurred


def prepare_segments(tracks):
    """
    Converts dataframe of tracks into dictionary of segments by unique ID and frame.
    """
    segs = {}
    for _, row in tracks.iterrows():
        f = int(row["frame"])
        pid = row["unique_id"]
        pos = (int(round(row["x"])), int(round(row["y"])))
        tracker = row["tracker_type"] if "tracker_type" in row else None
        segs.setdefault(pid, {}).setdefault(f, []).append((pos, tracker))
    return segs


def draw_grey_tracks(frame, segs, history, fc, last_det, sm, mm_speed, th, color_palettes, dot_radius):
    """
    Draws particle tracks (grey-scale for speed classification) on a frame.
    Maintains history of tracks for continuity.
    """
    def get_color(tracker, palettes):
        if tracker is None:
            return (255, 255, 255)
        key = tracker.replace('_forward', '') if tracker.endswith('_forward') else tracker
        return palettes.get(key, (255, 255, 255))

    detected = set()

    # Update current detections
    for pid, frames in segs.items():
        if fc in frames:
            detected.add(pid)
            # Reset history if track skipped a frame
            if pid in last_det and fc != last_det[pid] + 1:
                history[pid] = []
            last_det[pid] = fc
            history.setdefault(pid, []).extend(frames[fc])

    # Draw tracks
    for pid in list(history):
        if pid in detected:
            pts = history[pid]
            for (pt1, _), (pt2, _) in zip(pts, pts[1:]):
                d = np.linalg.norm(np.array(pt2) - np.array(pt1))
                # Colour line based on distance (speed proxy)
                line_color = (0, 0, 0) if d < sm else ((128, 128, 128) if d < mm_speed else (255, 255, 255))
                cv2.line(frame, pt1, pt2, line_color, th)
            for (pt, tracker) in pts:
                dot_color = get_color(tracker, color_palettes)
                cv2.circle(frame, pt, dot_radius, dot_color, thickness=-1)
        else:
            # Remove old tracks not detected anymore
            history.pop(pid)
            last_det.pop(pid, None)

    return frame


def process_combined_video_only(heatmap_inp, df_tracks_v1, df_tracks_v2, df_det,
                                fwd_params, vid_params, color_palettes, detection_params,
                                proc_params, gm1, cmap, output_video, yl, yu):
    """
    Main video processing pipeline:
    - Reads video frames
    - Applies heatmap and temporal blur
    - Draws detections and tracks
    - Combines multiple visualisations side-by-side
    - Writes combined output video
    """
    def get_color(tracker, palettes):
        key = tracker.replace('_forward', '') if isinstance(tracker, str) and tracker.endswith('_forward') else tracker
        return palettes.get(key, (255, 255, 255))

    cap = cv2.VideoCapture(heatmap_inp)
    if not cap.isOpened():
        raise FileNotFoundError(f"Could not open video: {heatmap_inp}")

    # Get video info
    total_frames_in_video = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

    start_frame = max(0, int(proc_params["start_frame"]))
    end_frame_req = int(proc_params["end_frame"])

    # Adjust end frame if video shorter
    if total_frames_in_video > 0:
        end_frame = min(end_frame_req, total_frames_in_video - 1)
    else:
        end_frame = end_frame_req

    if start_frame > end_frame:
        raise ValueError(f"start_frame ({start_frame}) is after end_frame ({end_frame}). "
                         f"Video frame count: {total_frames_in_video}")

    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

    heatmap_frames = []
    roi_frames = []

    # Read frames and apply heatmap
    for fidx in range(start_frame, end_frame + 1):
        ret, frame = cap.read()
        if not ret:
            break
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        roi = gray[yl:yu, :]
        hm = apply_heatmap(frame, roi, yl, yu, gm1, cmap)
        heatmap_frames.append(hm)
        roi_frames.append(hm[yl:yu, :])
    cap.release()

    if len(heatmap_frames) == 0:
        raise RuntimeError(
            f"No frames were read. Check that the video path is correct, and that "
            f"start_frame ({start_frame}) <= end_frame ({end_frame}). "
            f"Detected total_frames_in_video={total_frames_in_video}."
        )

    # Temporal blur of ROI frames
    kernel_size = 5 if len(roi_frames) >= 3 else max(1, len(roi_frames) | 1)
    blurred_roi_frames = temporal_gaussian_blur(roi_frames, kernel_size, 1.0)

    # Video output setup
    h_frame, w_frame, _ = heatmap_frames[0].shape
    target_height = int(proc_params["target_height"])
    y_min_final = int(proc_params["y_min"])
    y_max_final = int(proc_params["y_max"])
    crop_height = max(1, y_max_final - y_min_final)
    scale = target_height / crop_height
    resized_width = int(round(w_frame * scale))
    total_width = 5 * resized_width
    fourcc = cv2.VideoWriter_fourcc(*"avc1")
    out = cv2.VideoWriter(output_video, fourcc, fps, (total_width, target_height))

    # Prepare tracking data
    segs_v1 = prepare_segments(df_tracks_v1)
    segs_v2 = prepare_segments(df_tracks_v2)
    det_frame_groups = {int(f): group for f, group in df_det.groupby("frame")}
    history_grey_v1, last_det_v1 = {}, {}
    history_grey_v2, last_det_v2 = {}, {}

    sm = fwd_params["slow_max_speed"]
    mm_speed = fwd_params["medium_max_speed"]
    th = vid_params["line_thickness"]
    speed_factor = proc_params["speed_factor"]
    dot_radius = 1

    frame_num = start_frame
    i = 0
    total_frames = len(heatmap_frames)

    # Main frame loop
    while i < total_frames:
        base_frame = heatmap_frames[i]

        # Add temporal Gaussian blur
        tgb_frame = base_frame.copy()
        tgb_frame[yl:yu, :] = blurred_roi_frames[i]

        # Add detections
        frame_detect = base_frame.copy()
        if frame_num in det_frame_groups:
            cur_det = det_frame_groups[frame_num]
            for _, row in cur_det.iterrows():
                x = int(row["x"]); y = int(row["y"])
                tracker = row.get("tracker_type", None)
                color = get_color(tracker, color_palettes)
                mk = detection_params["marker"]
                sz = int(detection_params["size"])
                if mk == "circle":
                    cv2.circle(frame_detect, (x, y), sz, color, -1)
                elif mk == "cross":
                    cv2.line(frame_detect, (x - sz, y - sz), (x + sz, y + sz), color, th)
                    cv2.line(frame_detect, (x - sz, y + sz), (x + sz, y - sz), color, th)
                else:  # default: ring
                    cv2.circle(frame_detect, (x, y), sz, color, 1)

        # Add grey tracks from v1 and v2
        frame_grey_v1 = draw_grey_tracks(base_frame.copy(), segs_v1, history_grey_v1,
                                         frame_num, last_det_v1, sm, mm_speed, th, color_palettes, dot_radius)
        frame_grey_v2 = draw_grey_tracks(base_frame.copy(), segs_v2, history_grey_v2,
                                         frame_num, last_det_v2, sm, mm_speed, th, color_palettes, dot_radius)

        # Combine the five different views horizontally
        ordered = [base_frame, tgb_frame, frame_detect, frame_grey_v1, frame_grey_v2]
        crops = []
        for fr in ordered:
            crop = fr[y_min_final:y_max_final, :]
            resized = cv2.resize(crop, (resized_width, target_height))
            crops.append(resized)
        combined_frame = cv2.hconcat(crops)

        # Write output
        out.write(combined_frame)

        # Update frame counters (account for speed factor)
        i += 1
        frame_num += 1
        if speed_factor > 1:  # skip frames
            skip = int(round(speed_factor))
            i += skip - 1
            frame_num += skip - 1
        elif speed_factor < 1:  # duplicate frames
            rep = int(round(1 / speed_factor)) - 1
            for _ in range(rep):
                out.write(combined_frame)

    out.release()
    print("Combined video processing complete")


def main():
    """
    Main entry point:
    - Loads paths and parameters
    - Finds global max for heatmap normalisation
    - Loads track/detection CSVs
    - Runs combined video processing
    """
    (heatmap_inp, out1, out2, out3, out4, yl, yu, end_frame, cmap,
     filt_v1, filt_v2, det, output_video, fwd_params,
     vid_params, color_palettes, detection_params, proc_params) = get_config()

    # Default coefficients (unused in this version)
    default_coeff = {frm: [0, 0, 0, 0, 0, 0] for frm in range(end_frame + 1)}

    # Find global max values for normalisation
    gm1, gm2 = find_global_max(heatmap_inp, default_coeff, yl, yu, end_frame)

    # Load CSVs with track and detection data
    df_tracks_v1 = pd.read_csv(filt_v1)
    df_tracks_v2 = pd.read_csv(filt_v2)
    df_det = pd.read_csv(det)

    print("Detections read:", len(df_det),
          "Unique tracks v1:", len(df_tracks_v1["unique_id"].unique()),
          "Unique tracks v2:", len(df_tracks_v2["unique_id"].unique()))

    # Count links (track continuity measure)
    links_v1 = df_tracks_v1.groupby("unique_id").size().apply(lambda n: n - 1 if n > 0 else 0).sum()
    links_v2 = df_tracks_v2.groupby("unique_id").size().apply(lambda n: n - 1 if n > 0 else 0).sum()
    print("Number of links v1:", links_v1, "Number of links v2:", links_v2)

    # Process video
    process_combined_video_only(heatmap_inp, df_tracks_v1, df_tracks_v2, df_det,
                                fwd_params, vid_params, color_palettes, detection_params,
                                proc_params, gm1, cmap, output_video, yl, yu)


if __name__ == "__main__":
    main()


Detections read: 15905 Unique tracks v1: 130 Unique tracks v2: 110
Number of links v1: 2529 Number of links v2: 2497


# Manual Tracks

In [None]:
import cv2
import numpy as np
import os
import pandas as pd
import ipywidgets as widgets
import ipyevents
from IPython.display import display, HTML
import copy
import asyncio
import uuid

def get_config():
    """
    Centralised configuration: file paths, UI sizes, and processing parameters.
    """
    base = "/Users/Ricardo/Desktop/Y4 Lab code/Work 0.3K"
    heatmap_inp = os.path.join("/Users/Ricardo/Desktop/Y4 Lab code/Cooldown 10", "_video (0).mp4")

    # ROI vertical bounds and frame limit for scanning
    yl, yu, end_frame = 100, 819, 600

    # Colour map for heatmap overlay
    cmap = cv2.COLORMAP_JET

    # Tracks CSV (auto-replaced by updated_file if present)
    tracks_csv_path = os.path.join(base, "filtered_particle_tracks_v2.csv")

    # Frame range to load from the video
    proc_params = {"start_frame": 11, "end_frame": 600}

    # Widget display width for side-by-side panels
    display_width = 200

    # Line thickness for track segments
    linke_thickness = 1

    # Where interactive edits are saved
    updated_file = os.path.join(base, "filtered_particle_tracks_updated.csv")

    return (heatmap_inp, yl, yu, end_frame, cmap, tracks_csv_path,
            proc_params, display_width, linke_thickness, updated_file)


def temporal_gaussian_blur(frames, kernel_size, sigma=1.0):
    """
    Apply 1D Gaussian blur across the time axis to a list of ROI frames.
    Edges are clamped to the nearest valid frame.
    """
    radius = kernel_size // 2
    kernel = np.exp(-0.5 * (np.arange(-radius, radius + 1) / sigma) ** 2)
    kernel /= kernel.sum()

    blurred = []
    n = len(frames)
    for i in range(n):
        acc = np.zeros_like(frames[i], dtype=np.float32)
        for j in range(-radius, radius + 1):
            idx = i + j
            idx = max(0, min(idx, n - 1))  # clamp at boundaries
            acc += frames[idx].astype(np.float32) * kernel[j + radius]
        blurred.append(np.clip(acc, 0, 255).astype(np.uint8))
    return blurred


def find_global_max(video_path, yl, yu, end_frame):
    """
    Scan the first end_frame+1 frames of the input video to get the global
    maximum intensity in the ROI (yl:yu). Used for heatmap normalisation.
    """
    cap = cv2.VideoCapture(video_path)
    gm = 0
    for _ in range(end_frame + 1):
        ret, frame = cap.read()
        if not ret:
            break
        roi = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)[yl:yu, :]
        gm = max(gm, int(roi.max()))
    cap.release()
    return gm


def prepare_segments(tracks):
    """
    Convert a tracks DataFrame into a nested dict:
      {unique_id: {frame: [(x, y, BGR_color), ...], ...}, ...}
    If a 'color' column exists, map string labels to BGR tuples; default is brown.
    """
    segs = {}
    has_color = "color" in tracks.columns

    for _, r in tracks.iterrows():
        pid = r["unique_id"]
        f = int(r["frame"])
        x = int(round(r["x"]))
        y = int(round(r["y"]))

        # Default BGR colour (brown-ish)
        color_val = (29, 101, 181)

        if has_color:
            col = r["color"]
            if isinstance(col, str):
                # Map string to BGR
                if col.lower() == "brown":      # default mapping
                    color_val = (29, 101, 181)
                elif col.lower() == "green":
                    color_val = (0, 255, 0)
                elif col.lower() == "yellow":
                    color_val = (0, 255, 255)
                else:
                    color_val = (29, 101, 181)

        segs.setdefault(pid, {}).setdefault(f, []).append((x, y, color_val))
    return segs


def draw_tracks(frame, tracks, history, fc, last_det, linke_thickness):
    """
    Draw ‘existing’ tracks on a frame using their accumulated history.
    History is reset if there’s a gap between consecutive detections.
    """
    dot_radius = 1
    detected = set()

    # Accumulate detections for this frame into history
    for pid, frames_dict in tracks.items():
        if fc in frames_dict:
            detected.add(pid)
            if pid in last_det and fc != last_det[pid] + 1:
                history[pid] = []          # reset path if there was a gap
            last_det[pid] = fc
            history.setdefault(pid, []).extend(frames_dict[fc])

    # Draw paths + points for any track that has data in history at this frame
    for pid in list(history):
        if pid in detected:
            pts = history[pid]
            # Lines between consecutive points
            for pt1, pt2 in zip(pts, pts[1:]):
                cv2.line(frame, pt1[:2], pt2[:2], (0, 0, 0), linke_thickness)
            # Dots at each detection, coloured per-track
            for pt in pts:
                col = pt[2] if len(pt) >= 3 else (29, 101, 181)
                cv2.circle(frame, pt[:2], dot_radius, col, thickness=-1)
        else:
            # Track not seen this frame; drop its history
            history.pop(pid)
            last_det.pop(pid)
    return frame


def draw_manual_track(track, frame_img, current_rel_frame, linke_thickness):
    """
    Draw a user-built manual track up to the current relative frame.
    Track items: (rel_frame, (x,y) or (x,y,?), color[, snapped_track_id])
    """
    dot_radius = 1
    pts = [pt for pt in track if pt[0] <= current_rel_frame]
    if not pts:
        return frame_img

    # Lines only between different frames for visual continuity
    for (f1, pos1, *_), (f2, pos2, *_) in zip(pts, pts[1:]):
        cv2.line(frame_img, pos1[:2], pos2[:2], (0, 0, 0), linke_thickness)

    # Points
    for (f, pos, col, *_) in pts:
        cv2.circle(frame_img, pos[:2] if len(pos) > 2 else pos, dot_radius, col, thickness=-1)
    return frame_img


def draw_merged_track(merged, frame_img, current_abs_frame, linke_thickness):
    """
    Draw a merged track (list of (abs_frame, (x,y[,?]), color)) up to current absolute frame.
    """
    dot_radius = 1
    pts = [pt for pt in merged if pt[0] <= current_abs_frame]
    if not pts:
        return frame_img

    for (f1, pos1, _), (f2, pos2, _) in zip(pts, pts[1:]):
        cv2.line(frame_img, pos1[:2], pos2[:2], (0, 0, 0), linke_thickness)

    for f, pos, col in pts:
        cv2.circle(frame_img, pos[:2] if len(pos) > 2 else pos, dot_radius, col, thickness=-1)
    return frame_img


def interactive_video_widget(normal_frames, summed_frames, segs, start_frame,
                             display_width, linke_thickness, yl, yu, updated_file):
    """
    Build and display the interactive widget:
      • Left pane: current frame (cropped & scaled)
      • Centre pane: working panel with tracks/detections drawn
      • Right pane: controls (toggles, slider, save/undo, simple trim editor)
    Includes:
      - Drawing manual tracks (optionally snap to detections)
      - Merging manual segments into existing tracks
      - Undo stack and keyboard shortcuts
      - Save to CSV (updated_file)
    """
    # Helper to assign unique ids to standalone manual tracks
    def generate_manual_track_id():
        return "manual_" + uuid.uuid4().hex[:8]

    # Nested draw function specific to this widget context (includes frame tags)
    def draw_tracks(frame, tracks, history, fc, last_det, linke_thickness):
        dot_radius = 1
        detected = set()
        for pid, frames_dict in tracks.items():
            if fc in frames_dict:
                detected.add(pid)
                if pid in last_det and fc != last_det[pid] + 1:
                    history[pid] = []
                last_det[pid] = fc
                # Keep the frame index with the point for continuity checks
                for (x, y, col) in frames_dict[fc]:
                    history.setdefault(pid, []).append((x, y, col, fc))
        for pid in list(history):
            if pid in detected:
                pts = history[pid]
                for pt1, pt2 in zip(pts, pts[1:]):
                    if pt1[3] != pt2[3]:
                        cv2.line(frame, pt1[:2], pt2[:2], (0, 0, 0), linke_thickness)
                for pt in pts:
                    cv2.circle(frame, pt[:2], dot_radius, pt[2], thickness=-1)
            else:
                history.pop(pid)
                last_det.pop(pid)
        return frame

    def draw_manual_track(track, frame_img, current_rel_frame, linke_thickness):
        dot_radius = 1
        pts = [pt for pt in track if pt[0] <= current_rel_frame]
        if not pts:
            return frame_img
        for (f1, pos1, col1, snap1), (f2, pos2, col2, snap2) in zip(pts, pts[1:]):
            if f1 != f2:
                cv2.line(frame_img, pos1[:2], pos2[:2], (0, 0, 0), linke_thickness)
        for (f, pos, col, snap) in pts:
            cv2.circle(frame_img, pos[:2], dot_radius, col, thickness=-1)
        return frame_img

    def draw_merged_track(merged, frame_img, current_abs_frame, linke_thickness):
        dot_radius = 1
        pts = [pt for pt in merged if pt[0] <= current_abs_frame]
        if not pts:
            return frame_img
        for (f1, pos1, col1), (f2, pos2, col2) in zip(pts, pts[1:]):
            if f1 != f2:
                cv2.line(frame_img, pos1[:2], pos2[:2], (0, 0, 0), linke_thickness)
        for f, pos, col in pts:
            cv2.circle(frame_img, pos[:2], dot_radius, col, thickness=-1)
        return frame_img

    # Try to load detections to enable “snap to nearest detection” behaviour
    try:
        df_detections = pd.read_csv(os.path.join(os.path.dirname(updated_file), "particle_detections.csv"))
        detection_markers = {}
        for _, r in df_detections.iterrows():
            frm = int(r["frame"])
            pos = (int(round(r["x"])), int(round(r["y"])))
            detection_markers.setdefault(frm, []).append(pos)
    except Exception:
        detection_markers = {}

    # UI controls (with a neutral button colour theme)
    button_color = "#808080"
    video_mode_toggle = widgets.ToggleButton(value=False, description="Summed Video", button_style='')
    video_mode_toggle.style.button_color = button_color

    snap_track_toggle = widgets.ToggleButton(value=True, description="Snap Track", button_style='')
    snap_track_toggle.style.button_color = button_color

    allow_merge_toggle = widgets.ToggleButton(value=False, description="Allow Merge", button_style='')
    allow_merge_toggle.style.button_color = button_color

    draw_toggle = widgets.ToggleButton(value=True, description="Draw Track", button_style='')
    draw_toggle.style.button_color = button_color

    undo_button = widgets.Button(description="Undo", button_style='')
    undo_button.style.button_color = button_color

    save_button = widgets.Button(description="Save Changes", button_style='')
    save_button.style.button_color = button_color

    # Helper to select which video buffer is active (normal vs temporally blurred)
    def get_current_frames():
        return summed_frames if video_mode_toggle.value else normal_frames

    # Render the frame at index i with tracks up to that frame
    def compute_frame(i, tracks):
        base = get_current_frames()
        history = {}
        last_det = {}
        frame_number = start_frame
        interactive_frame = None
        for idx in range(i + 1):
            frame = base[idx].copy()
            frame = draw_tracks(frame, tracks, history, frame_number, last_det, linke_thickness)
            interactive_frame = frame
            frame_number += 1
        return interactive_frame, history

    # Small wrappers to pass configured thickness
    def draw_manual_track_wrapper(track, frame_img, current_abs_frame):
        return draw_manual_track(track, frame_img, current_abs_frame, linke_thickness)

    def draw_merged_track_wrapper(merged, frame_img, current_abs_frame):
        return draw_merged_track(merged, frame_img, current_abs_frame, linke_thickness)

    # Left static image (without overlays) for reference
    left_img_widget = widgets.Image(format="png", layout=widgets.Layout(width=f"{display_width}px"))

    def update_left_image(i):
        frame = get_current_frames()[i]
        cropped = frame[yl:yu, :]
        h, w = cropped.shape[:2]
        scale = display_width / w
        resized = cv2.resize(cropped, (display_width, int(h * scale)))
        _, buf = cv2.imencode(".png", resized)
        left_img_widget.value = buf.tobytes()

    # Main image with overlays (tracks / manual edits)
    img_widget = widgets.Image(format="png", layout=widgets.Layout(width=f"{display_width}px"))

    # Frame slider (0-based index into loaded frames array)
    slider = widgets.IntSlider(value=0, min=0, max=len(normal_frames) - 1, description="", readout=True)

    # Working copies / state
    active_segs = copy.deepcopy(segs)  # can be modified interactively
    merged_tracks = []                 # list of (snap_track_id, merged_points_list)
    current_history = {}
    selected_track = None

    # Manual drawing state
    manual_tracks = []                 # completed manual tracks
    current_manual_track = None        # in-progress manual track
    undo_stack = []                    # actions to allow undo

    # Common colours (BGR)
    brown = (29, 101, 181)
    yellow = (0, 255, 255)
    green = (0, 255, 0)

    # Re-render the UI images for the current slider position
    def update_image(i):
        nonlocal current_history
        update_left_image(i)

        # Base frame with current active track overlays
        frame, history = compute_frame(i, active_segs)
        current_history = history

        current_abs_frame = slider.value + start_frame

        # Draw fully completed manual tracks up to current frame
        for track in manual_tracks:
            if current_abs_frame <= track[-1][0]:
                frame = draw_manual_track_wrapper(track, frame, current_abs_frame)

        # Draw in-progress manual track if present
        if current_manual_track is not None:
            frame = draw_manual_track_wrapper(current_manual_track, frame, current_abs_frame)

        # Draw any merged tracks up to current frame
        for snap_track_id, merged in merged_tracks:
            if current_abs_frame <= merged[-1][0]:
                frame = draw_merged_track_wrapper(merged, frame, current_abs_frame)

        # Crop, scale, and display
        cropped = frame[yl:yu, :]
        h, w = cropped.shape[:2]
        scale = display_width / w
        resized = cv2.resize(cropped, (display_width, int(h * scale)))
        _, buf = cv2.imencode(".png", resized)
        img_widget.value = buf.tobytes()

    # Initial render
    update_image(0)

    # React to slider movement
    slider.observe(lambda change: update_image(change["new"]) if change["name"] == "value" else None, names="value")

    # Small inline editor (trim start/end of a selected track)
    track_label = widgets.Label(value="")
    remove_label = widgets.Label(value="Remove")
    edit_start = widgets.IntText(value=0, description="Start")
    edit_end = widgets.IntText(value=0, description="End")
    apply_button = widgets.Button(description="Apply")
    cancel_button = widgets.Button(description="Cancel")
    edit_controls = widgets.VBox([edit_start, edit_end])
    button_box = widgets.HBox([apply_button, cancel_button])
    edit_box = widgets.VBox([remove_label, edit_controls, button_box])

    # Merge a manual segment into an existing track if snapped at one/both ends
    def merge_manual_track(manual_track):
        start_snap = manual_track[0][3] if len(manual_track[0]) > 3 else None
        end_snap = manual_track[-1][3] if len(manual_track[-1]) > 3 else None
        if start_snap is None and end_snap is None:
            return None

        # Convert to absolute (frame, pos, colour), mark endpoints in green
        manual_abs = []
        for idx, pt in enumerate(manual_track):
            f_val = pt[0]
            pos = pt[1]
            col = pt[2]
            if (idx == 0 and start_snap is not None) or (idx == len(manual_track) - 1 and end_snap is not None):
                col = green
            manual_abs.append((f_val, pos, col))

        merged = []

        # If snapped to two different tracks, stitch start + manual + end
        if start_snap is not None and end_snap is not None and start_snap != end_snap:
            if start_snap in active_segs and end_snap in active_segs:
                existing_pts_start = []
                for f in sorted(active_segs[start_snap].keys()):
                    for p in active_segs[start_snap][f]:
                        existing_pts_start.append((f, (p[0], p[1]), p[2]))

                existing_pts_end = []
                for f in sorted(active_segs[end_snap].keys()):
                    for p in active_segs[end_snap][f]:
                        existing_pts_end.append((f, (p[0], p[1]), p[2]))

                # Avoid duplicating join points
                if existing_pts_start and existing_pts_start[-1][0] == manual_abs[0][0]:
                    existing_pts_start.pop()
                if existing_pts_end and existing_pts_end[0][0] == manual_abs[-1][0]:
                    existing_pts_end.pop(0)

                merged = existing_pts_start + manual_abs + existing_pts_end
            else:
                # If one side is missing, extend the one that exists
                snap_track = start_snap if start_snap in active_segs else end_snap
                if snap_track not in active_segs:
                    return None
                existing_pts = []
                for f in sorted(active_segs[snap_track].keys()):
                    for p in active_segs[snap_track][f]:
                        existing_pts.append((f, (p[0], p[1]), p[2]))
                if start_snap is not None:
                    if existing_pts and existing_pts[-1][0] == manual_abs[0][0]:
                        existing_pts.pop()
                    merged = existing_pts + manual_abs
                else:
                    if existing_pts and existing_pts[0][0] == manual_abs[-1][0]:
                        existing_pts.pop(0)
                    merged = manual_abs + existing_pts
        else:
            # Snapped to only one track: prepend or append manual segment
            snap_track = start_snap if start_snap is not None else end_snap
            if snap_track not in active_segs:
                return None
            existing_pts = []
            for f in sorted(active_segs[snap_track].keys()):
                for p in active_segs[snap_track][f]:
                    existing_pts.append((f, (p[0], p[1]), p[2]))
            if start_snap is not None:
                if existing_pts and existing_pts[-1][0] == manual_abs[0][0]:
                    existing_pts.pop()
                merged = existing_pts + manual_abs
            else:
                if existing_pts and existing_pts[0][0] == manual_abs[-1][0]:
                    existing_pts.pop(0)
                merged = manual_abs + existing_pts

        merged.sort(key=lambda x: x[0])
        return (start_snap if start_snap is not None else end_snap, merged)

    # Track where a manual drawing action began in the undo stack
    manual_track_undo_start = None

    # Toggle for drawing mode
    def on_draw_toggle_change(change):
        nonlocal current_manual_track, manual_tracks, active_segs, merged_tracks, manual_track_undo_start
        if change['new']:
            # Started drawing
            current_manual_track = []
            manual_track_undo_start = len(undo_stack)
        else:
            # Stopped drawing: finalise or merge
            if current_manual_track is not None and len(current_manual_track) > 1:
                if allow_merge_toggle.value:
                    merge_result = merge_manual_track(current_manual_track)
                    if merge_result is not None:
                        snap_track_id, merged = merge_result
                        removed_data = active_segs.pop(snap_track_id, None)  # remove old
                        merged_tracks.append((snap_track_id, merged))        # add merged
                        undo_stack.append(("merge", snap_track_id, removed_data))
                        track_label.value = "Merged track " + str(snap_track_id)
                    else:
                        # Not snapped: store as standalone manual track
                        manual_tracks.append(current_manual_track)
                        if manual_track_undo_start is not None:
                            # Remove in-progress draw actions from undo stack
                            while len(undo_stack) > manual_track_undo_start:
                                undo_stack.pop()
                            manual_track_undo_start = None
                        undo_stack.append(("complete_track",))
                else:
                    manual_tracks.append(current_manual_track)
                    undo_stack.append(("complete_track",))
            current_manual_track = None
            update_image(slider.value)

    draw_toggle.observe(on_draw_toggle_change, names="value")

    # Mouse click handler for picking/placing points
    def handle_click(event):
        nonlocal selected_track, current_manual_track
        x = event.get("relativeX", event.get("offsetX"))
        y = event.get("relativeY", event.get("offsetY"))
        if x is None or y is None:
            return

        # Map widget coords back to original ROI coordinates
        cropped = compute_frame(slider.value, active_segs)[0][yl:yu, :]
        h, w = cropped.shape[:2]
        scale = display_width / w
        orig_x = int(x / scale)
        orig_y = int(y / scale + yl)

        if draw_toggle.value:
            # Enforce strictly consecutive frames when drawing
            if slider.value == 0:
                return
            absolute_frame = slider.value + start_frame
            if current_manual_track:
                last_frame = current_manual_track[-1][0]
                if absolute_frame != last_frame + 1:
                    return

            dot_color = yellow
            snapped_track = None
            current_frame_number = absolute_frame

            # Optionally snap to closest detection to assist accurate clicking
            if snap_track_toggle.value and current_frame_number in detection_markers:
                best = None
                best_dist = 15  # pixel threshold for snapping
                for d in detection_markers[current_frame_number]:
                    dist = np.hypot(orig_x - d[0], orig_y - d[1])
                    if dist < best_dist:
                        best = d
                        best_dist = dist
                if best is not None:
                    orig_x, orig_y = best

                    # If merging is allowed, check if near an existing track endpoint
                    _, history = compute_frame(slider.value, active_segs)
                    if allow_merge_toggle.value:
                        for track_id, pts in history.items():
                            if pts:
                                first_pt = pts[0]
                                last_pt = pts[-1]
                                if np.hypot(orig_x - first_pt[0], orig_y - first_pt[1]) < 15:
                                    dot_color = green
                                    snapped_track = track_id
                                    break
                                elif np.hypot(orig_x - last_pt[0], orig_y - last_pt[1]) < 15:
                                    dot_color = green
                                    snapped_track = track_id
                                    break

            # Begin or extend manual track
            if current_manual_track is None:
                current_manual_track = []
            current_manual_track.append((absolute_frame, (orig_x, orig_y), dot_color, snapped_track))
            undo_stack.append(("draw_point",))
            update_image(slider.value)
        else:
            # Selection mode: pick an existing track near the click
            threshold = 5
            sel_track = ""
            min_dist = float("inf")
            for pid, pts in current_history.items():
                for pt in pts:
                    dist = np.hypot(orig_x - pt[0], orig_y - pt[1])
                    if dist < threshold and dist < min_dist:
                        min_dist = dist
                        sel_track = pid
            selected_track = sel_track
            track_label.value = str(selected_track)

    # Connect click events to the main image widget
    click_ev = ipyevents.Event(source=img_widget, watched_events=["click"])
    click_ev.on_dom_event(handle_click)

    # Apply trimming edits to a selected track (remove start/end segments)
    def apply_edit(b):
        nonlocal active_segs, selected_track
        if not selected_track:
            return
        frames = sorted(active_segs[selected_track].keys())
        if not frames:
            return

        # Save for undo
        original_data = copy.deepcopy(active_segs[selected_track])
        undo_stack.append(("remove_link", selected_track, original_data))

        rs = edit_start.value
        re = edit_end.value
        if rs + re >= len(frames):
            # If removing everything, drop the track entirely
            del active_segs[selected_track]
        else:
            # Keep the middle segment after trimming
            new_frames = frames[rs:len(frames) - re]
            new_data = {f: active_segs[selected_track][f] for f in new_frames}
            active_segs[selected_track] = new_data
        update_image(slider.value)

    def cancel_edit(b):
        nonlocal selected_track
        selected_track = None

    apply_button.on_click(apply_edit)
    cancel_button.on_click(cancel_edit)

    # Undo last action from the undo stack
    def on_undo_button_clicked(b):
        nonlocal current_manual_track, manual_tracks, active_segs, selected_track
        if not undo_stack:
            track_label.value = "Nothing to undo"
            return
        action = undo_stack.pop()

        if action[0] == "draw_point":
            if draw_toggle.value and current_manual_track:
                current_manual_track.pop()
            else:
                if manual_tracks:
                    manual_tracks.pop()

        elif action[0] == "complete_track":
            if manual_tracks:
                manual_tracks.pop()

        elif action[0] == "remove_link":
            track_id, original_data = action[1], action[2]
            active_segs[track_id] = original_data

        elif action[0] == "merge":
            snap_track_id = action[1]
            removed_data = action[2]
            # Remove the merged overlay and restore the original track data
            merged_tracks[:] = [mt for mt in merged_tracks if mt[0] != snap_track_id]
            if removed_data is not None:
                active_segs[snap_track_id] = removed_data

        update_image(slider.value)

    undo_button.on_click(on_undo_button_clicked)

    # Save the current (possibly edited) set of tracks to CSV
    def on_save_button_clicked(b):
        brown = (29, 101, 181)
        yellow = (0, 255, 255)
        green = (0, 255, 0)

        def color_to_str(color):
            if color == brown:
                return "brown"
            elif color == green:
                return "green"
            elif color == yellow:
                return "yellow"
            else:
                return "brown"

        rows = []

        # Existing (possibly trimmed) active tracks
        for tid, frames_dict in active_segs.items():
            for f in sorted(frames_dict.keys()):
                for pos in frames_dict[f]:
                    rows.append({
                        "unique_id": tid,
                        "frame": f,
                        "x": pos[0],
                        "y": pos[1],
                        "color": color_to_str(pos[2])
                    })

        # Merged tracks are stored with a suffixed id
        for snap_track_id, merged in merged_tracks:
            new_id = str(snap_track_id) + "_manual"
            for (f, pos, col) in merged:
                rows.append({
                    "unique_id": new_id,
                    "frame": f,
                    "x": pos[0],
                    "y": pos[1],
                    "color": color_to_str(col)
                })

        # Completed standalone manual tracks get autogenerated ids
        for track in manual_tracks:
            new_id = generate_manual_track_id()
            for (f, pos, col, _) in track:
                rows.append({
                    "unique_id": new_id,
                    "frame": f,
                    "x": pos[0],
                    "y": pos[1],
                    "color": color_to_str(col)
                })

        # If currently drawing, persist partial manual track as well
        if current_manual_track and len(current_manual_track) > 0:
            new_id = generate_manual_track_id()
            for (f, pos, col, _) in current_manual_track:
                rows.append({
                    "unique_id": new_id,
                    "frame": f,
                    "x": pos[0],
                    "y": pos[1],
                    "color": color_to_str(col)
                })

        df = pd.DataFrame(rows)
        df.to_csv(updated_file, index=False)
        track_label.value = "Saved changes"

    save_button.on_click(on_save_button_clicked)

    # Assemble the UI
    controls = widgets.VBox([
        video_mode_toggle, slider, draw_toggle, snap_track_toggle,
        allow_merge_toggle, undo_button, save_button, track_label, edit_box
    ])
    container = widgets.HBox([left_img_widget, img_widget, controls],
                             layout=widgets.Layout(align_items="center"))
    container.add_class("focusable")

    # Keyboard shortcuts with key repeat for frame stepping
    keys_down = set()
    repeating_tasks = {}

    def handle_keyboard_event(event):
        nonlocal keys_down, repeating_tasks
        key = event.get("key")

        if event["type"] == "keydown":
            if key == "e":
                snap_track_toggle.value = not snap_track_toggle.value
            elif key == "r":
                on_undo_button_clicked(None)
            elif key == "d":
                draw_toggle.value = not draw_toggle.value
            elif key == "f":
                on_save_button_clicked(None)
            elif key == "c":
                allow_merge_toggle.value = not allow_merge_toggle.value
            elif key in ("q", "w"):
                if key not in keys_down:
                    keys_down.add(key)

                    async def repeat_key_action():
                        while key in keys_down:
                            if key == "q":
                                slider.value = max(slider.min, slider.value - 1)
                            elif key == "w":
                                slider.value = min(slider.max, slider.value + 1)
                            await asyncio.sleep(0.1)

                    repeating_tasks[key] = asyncio.create_task(repeat_key_action())

        elif event["type"] == "keyup":
            if key in keys_down:
                keys_down.remove(key)
            if key in repeating_tasks:
                repeating_tasks[key].cancel()
                del repeating_tasks[key]

    # Attach keyboard listeners to the whole container
    keyboard_ev = ipyevents.Event(source=container, watched_events=["keydown", "keyup"])
    keyboard_ev.on_dom_event(handle_keyboard_event)

    # Simple dark theme for the container
    container.add_class("dark-container")
    display(HTML("""
    <style>
      .dark-container { background-color: black; padding: 10px; }
      .dark-container .widget-label, .dark-container .widget-readout, .widget-value { color: white; }
      .focusable { outline: none; }
    </style>
    """))

    # Show the widget
    display(container)


def main():
    """
    Load config and data, create normal and temporally blurred panels,
    then launch the interactive widget for manual review/editing.
    """
    (heatmap_inp, yl, yu, end_frame, cmap, tracks_csv_path, proc_params,
     display_width, linke_thickness, updated_file) = get_config()

    # Determine global max for heatmap normalisation
    gm = find_global_max(heatmap_inp, yl, yu, end_frame)

    # Prefer previously saved edits, else original tracks file
    if os.path.exists(updated_file):
        df_tracks = pd.read_csv(updated_file)
    else:
        df_tracks = pd.read_csv(tracks_csv_path)

    # Load frames from the requested range
    cap = cv2.VideoCapture(heatmap_inp)
    cap.set(cv2.CAP_PROP_POS_FRAMES, proc_params["start_frame"])
    frames = []
    for _ in range(proc_params["start_frame"], proc_params["end_frame"] + 1):
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()

    # Build normal (JET heatmap) frames and ROI-only frames for temporal blur
    normal_frames = []
    roi_frames = []
    for frame in frames:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        roi = gray[yl:yu, :]

        # Normalise to [0,255] and invert before applying the colour map
        norm = (roi / gm * 255).astype(np.uint8) if gm > 0 else np.zeros_like(roi, dtype=np.uint8)
        heat = cv2.applyColorMap(255 - np.clip(norm, 0, 255), cmap)

        hm = frame.copy()
        hm[yl:yu, :] = heat
        rgb_frame = cv2.cvtColor(hm, cv2.COLOR_BGR2RGB)

        normal_frames.append(rgb_frame)
        roi_frames.append(rgb_frame[yl:yu, :].copy())

    # Temporal blur on ROI, then stitch back to full-height frames as a “summed” view
    blurred_roi_frames = temporal_gaussian_blur(roi_frames, 5, 1.0)
    summed_frames = []
    for i, frame in enumerate(normal_frames):
        new_frame = frame.copy()
        new_frame[yl:yu, :] = blurred_roi_frames[i]
        summed_frames.append(new_frame)

    # Convert the tracks DataFrame to quick-lookup segments by id and frame
    segs = prepare_segments(df_tracks)

    # Launch interactive widget
    interactive_video_widget(normal_frames, summed_frames, segs, proc_params["start_frame"],
                             display_width, linke_thickness, yl, yu, updated_file)


if __name__ == "__main__":
    main()


# Refined Tracked Videos

In [None]:
import cv2
import numpy as np
import os
import pandas as pd

def get_config():
    """
    Centralise file paths and rendering/processing parameters.
    Returns:
        heatmap_inp (str): path to input video.
        cmap (int): OpenCV colour map for heatmap overlay.
        tracks_csv_path (str): path to CSV with track points.
        proc_params (dict): processing config (frame range, sizes, etc.).
        base (str): base output directory.
        output_format (str): container/codec choice ("mp4" or "avi").
    """
    base = "/Users/Ricardo/Desktop/Y4 Lab code/Work 0.3K"
    video_folder = "/Users/Ricardo/Desktop/Y4 Lab code/Cooldown 10"
    video_filename = "_video (0).mp4"
    heatmap_inp = os.path.join(video_folder, video_filename)

    cmap = cv2.COLORMAP_JET

    # Input tracks CSV: columns expected ['unique_id','frame','x','y']
    tracks_csv_path = os.path.join(base, "draw_tracks.csv")

    # Main processing parameters
    proc_params = {
        "start_frame": 11,     # index of first frame to load/process (0-based)
        "end_frame": 311,      # last absolute frame index to process (inclusive)
        "target_height": 1080, # output panel height (each panel is resized to this)
        "y_min": 100,          # top of ROI (rows are [y_min:y_max])
        "y_max": 900,          # bottom of ROI (exclusive)
        "speed_factor": 1,     # reserved (not used here) for fast/slow playback
        "bar_thickness": 35    # width of centre separator with frame counter text
    }

    output_format = "mp4"      # choose "mp4" (H.264) or "avi" (MJPG)

    return heatmap_inp, cmap, tracks_csv_path, proc_params, base, output_format


def find_global_max(video_path, y_min, y_max, end_frame):
    """
    Scan frames up to end_frame to get the global max grey intensity in the ROI.
    Used to normalise the heatmap so colour scaling is consistent across frames.

    Args:
        video_path: path to video.
        y_min, y_max: ROI vertical bounds.
        end_frame: last frame index to consider.

    Returns:
        gm (int): global maximum pixel value within ROI across scanned frames.
    """
    cap = cv2.VideoCapture(video_path)
    gm = 0
    for _ in range(end_frame + 1):
        ret, frame = cap.read()
        if not ret:
            break
        roi = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)[y_min:y_max, :]
        gm = max(gm, int(roi.max()))
    cap.release()
    return gm


def prepare_segments(tracks):
    """
    Convert the tracks DataFrame into an indexable structure by id then frame.

    Input DataFrame columns expected:
        - unique_id: track identifier (string or int)
        - frame: absolute frame number (int)
        - x, y: integer pixel coordinates

    Returns:
        segs (dict):
            {
              unique_id: {
                frame_number: [(x, y), (x, y), ...],
                ...
              },
              ...
            }
    """
    segs = {}
    for _, r in tracks.iterrows():
        pid = r["unique_id"]
        f, x, y = int(r["frame"]), int(round(r["x"])), int(round(r["y"]))
        segs.setdefault(pid, {}).setdefault(f, []).append((x, y))
    return segs


def draw_tracks(frame, tracks, history, fc, last_det, link_thickness):
    """
    Draw track paths and points up to the current frame.

    Args:
        frame (np.ndarray): BGR frame to draw on.
        tracks (dict): output of prepare_segments().
        history (dict): mutable; stores concatenated points per track across frames.
        fc (int): current absolute frame number being drawn.
        last_det (dict): mutable; last frame each track was detected (for gap reset).
        link_thickness (int): line thickness for path segments.

    Returns:
        The input frame with white points and black path lines rendered.
    """
    dot_radius = 1
    detected = set()

    # Aggregate current-frame detections into the running history
    for pid, frames_dict in tracks.items():
        if fc in frames_dict:
            detected.add(pid)
            # If a gap occurred, reset the path so lines don't jump across time
            if pid in last_det and fc != last_det[pid] + 1:
                history[pid] = []
            last_det[pid] = fc
            history.setdefault(pid, []).extend(frames_dict[fc])

    # Draw each track's accumulated path and points
    for pid in list(history):
        if pid in detected:
            pts = history[pid]
            if len(pts) >= 2:
                # Path as black polylines between consecutive points
                for pt1, pt2 in zip(pts, pts[1:]):
                    cv2.line(frame, pt1, pt2, (0, 0, 0), link_thickness)
                # Points as small white dots
                for pt in pts:
                    cv2.circle(frame, pt, dot_radius, (255, 255, 255), -1)
            else:
                # Single isolated point: draw just the dot
                pt = (int(pts[0][0]), int(pts[0][1]))
                cv2.circle(frame, pt, dot_radius, (255, 255, 255), -1)
        else:
            # If the track doesn't appear this frame, drop it from history
            history.pop(pid)
            last_det.pop(pid, None)
    return frame


def apply_heatmap(frame, y_min, y_max, gm, cmap):
    """
    Colourise a frame's ROI using an OpenCV colour map and global normalisation.

    Args:
        frame: BGR frame (modified in place).
        y_min, y_max: ROI vertical bounds.
        gm: global maximum for normalisation (avoid frame-to-frame colour shifts).
        cmap: cv2 colormap constant (e.g., COLORMAP_JET).

    Returns:
        The same frame with its ROI replaced by the heatmap colourised ROI.
    """
    roi = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)[y_min:y_max, :]
    # Normalise to [0,255] using global max; clip for safety
    norm = (roi / gm * 255).astype(np.uint8)
    frame[y_min:y_max, :] = cv2.applyColorMap(np.clip(norm, 0, 255), cmap)
    return frame


def main():
    # Load configuration and inputs
    heatmap_inp, cmap, tracks_csv_path, proc_params, base, output_format = get_config()

    # Read track CSV and pre-index by track id then frame
    df_tracks = pd.read_csv(tracks_csv_path)
    segs = prepare_segments(df_tracks)

    # Determine global max intensity inside ROI for consistent heatmap scaling
    gm = find_global_max(heatmap_inp, proc_params["y_min"], proc_params["y_max"], proc_params["end_frame"])

    # Read the needed span of frames from the source video
    cap = cv2.VideoCapture(heatmap_inp)
    cap.set(cv2.CAP_PROP_POS_FRAMES, proc_params["start_frame"])
    frames = []
    while True:
        ret, frame = cap.read()
        # Stop if out of frames or we passed the requested end_frame
        if not ret or cap.get(cv2.CAP_PROP_POS_FRAMES) > proc_params["end_frame"]:
            break
        frames.append(frame)
    cap.release()

    # Create heatmap-coloured frames (modifies ROI only)
    normal_frames = [apply_heatmap(f, proc_params["y_min"], proc_params["y_max"], gm, cmap) for f in frames]

    # State for drawing paths across frames
    history, last_det = {}, {}

    # Compute output panel sizes based on ROI scaling to target height
    roi_height = proc_params["y_max"] - proc_params["y_min"]
    target_height = proc_params["target_height"]
    scale = target_height / roi_height
    resized_width = int(normal_frames[0].shape[1] * scale)

    # Many codecs require even dimensions — pad width/height if needed
    if resized_width % 2 != 0:
        resized_width += 1
    if target_height % 2 != 0:
        target_height += 1

    # Get FPS from the video for consistent output timing
    cap = cv2.VideoCapture(heatmap_inp)
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    # Combined output is: [heatmap panel] + [vertical bar] + [heatmap+tracks panel]
    combined_width = 2 * resized_width + proc_params["bar_thickness"]

    # Choose codec based on output format
    if output_format.lower() == "mp4":
        fourcc = cv2.VideoWriter_fourcc(*"avc1")  # H.264 in MP4 container (platform dependent)
    elif output_format.lower() == "avi":
        fourcc = cv2.VideoWriter_fourcc(*"MJPG")  # Motion JPEG in AVI container
    else:
        raise ValueError(f"Unsupported output format: {output_format}")

    combined_video = os.path.join(base, f"Combined_Video_WhiteTracks.{output_format}")
    out_combined = cv2.VideoWriter(combined_video, fourcc, fps, (combined_width, target_height))

    # Render loop: build each combined frame and write it out
    for i in range(len(normal_frames)):
        fc = proc_params["start_frame"] + i           # absolute frame counter
        display_frame = i + 1                         # human-friendly (1-based) counter

        normal_frame = normal_frames[i]

        # Draw tracks on a copy so we can show both raw heatmap and overlay side-by-side
        raw_frame = draw_tracks(normal_frame.copy(), segs, history, fc, last_det, 1)

        # Crop to ROI, then resize each panel to the target height
        normal_cropped = normal_frame[proc_params["y_min"]:proc_params["y_max"], :]
        raw_cropped = raw_frame[proc_params["y_min"]:proc_params["y_max"], :]
        normal_resized = cv2.resize(normal_cropped, (resized_width, target_height))
        raw_resized = cv2.resize(raw_cropped, (resized_width, target_height))

        # Separator bar with frame number text (white on black)
        bar = np.zeros((target_height, proc_params["bar_thickness"], 3), dtype=np.uint8)
        text = str(display_frame)
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 1
        text_size, _ = cv2.getTextSize(text, font, font_scale, thickness)
        text_x = (proc_params["bar_thickness"] - text_size[0]) // 2
        text_y = 50  # fixed vertical offset for readability
        cv2.putText(bar, text, (text_x, text_y), font, font_scale, (255, 255, 255), thickness)

        # Concatenate panels: left (heatmap only) | bar | right (heatmap + tracks)
        combined_frame = np.hstack([normal_resized, bar, raw_resized])

        out_combined.write(combined_frame)

    out_combined.release()
    print("Combined video with raw tracks created:", combined_video)


if __name__ == "__main__":
    main()
