# Compare true vs artificial trajectories

This notebook compares the features from the true and artificial trajectories, performing a cell-to-cell matching based on their initial positions.

---

# Imports

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.spatial.distance import cdist

In [None]:
# Global plotting configuration: ggplot matplotlib style and seaborn's "paper" context
plt.style.use("ggplot")
sns.set_context("paper")

# True trajectories

In [None]:
# Load the features of the true trajectories (path is relative to the working directory)
true_features = pd.read_csv(
    Path("analyses") / "biotine_full" / "features_through_time_of_full_lifetime_simple_objects.csv"
)
true_features

# Generated trajectories

In [None]:
# Absolute location of the analysis outputs of the generated trajectories
experiment_path = Path(
    "/",
    "projects",
    "static2dynamic",
    "Thomas",
    "experiments",
    "GaussianProxy",
    "biotine_all_paired_new_jz_MANUAL_WEIGHTS_DOWNLOAD_FROM_JZ_11-02-2025_14h31",
    "inferences",
    # NOTE(review): empty segment; pathlib drops empty components, so it should not
    # affect the resulting path. Looks like an unfilled placeholder -- confirm.
    "",
    "trajectories_-1_1 raw",
    "cp_analysis",
)

In [None]:
# Load the features of the generated trajectories from the experiment folder
gen_features = pd.read_csv(experiment_path / "whole_cell.csv")
gen_features

# Filter true traj on selected plate only

In [None]:
# Pattern searched in `global_object_id` to select the true objects of a single plate
plate_name = "M_13_fld_3"

In [None]:
# Keep the true objects whose id contains the selected plate name.
# NOTE(review): `str.contains` matches substrings (and treats the pattern as a regex by
# default), so an id containing e.g. "M_13_fld_30" would also be selected -- confirm that
# plate names cannot collide this way.
this_plate_true_features = true_features.query("global_object_id.str.contains(@plate_name)")
this_plate_true_features

In [None]:
# Report how many distinct true objects were selected for this plate
print(f"Found {this_plate_true_features['global_object_id'].nunique()} objects in the true data")

# Match cells

In [None]:
# Map each true object id to its (x, y) center at the first timepoint (time == 1).
# The insertion order of this dict follows the order of the unique object ids.
initial_true_positions = {}
time_1_mask = this_plate_true_features["time"] == 1
this_plate_true_features_time_1 = this_plate_true_features[time_1_mask]
# Global sanity check: as many time-1 rows as there are distinct objects
assert (
    len(this_plate_true_features_time_1) == this_plate_true_features["global_object_id"].nunique()
)

for obj_id in this_plate_true_features["global_object_id"].unique():
    this_obj_mask = this_plate_true_features_time_1["global_object_id"] == obj_id
    # One row per time-1 measurement of this object, columns are (x, y)
    this_obj_init_positions = this_plate_true_features_time_1.loc[
        this_obj_mask, ["AreaShape_Center_X", "AreaShape_Center_Y"]
    ].values
    # Per-object check: the global count above cannot tell apart "one object has two
    # rows and another has none" from the expected "exactly one row each"
    assert len(this_obj_init_positions) == 1, (
        f"Expected exactly one initial position for object {obj_id}, "
        f"found {len(this_obj_init_positions)}: {this_obj_init_positions}"
    )
    initial_true_positions[obj_id] = this_obj_init_positions[0]

initial_true_positions

## Process artificial cells

In [None]:
# Check time data consistency: count the rows for each (time, image file name) pair,
# including missing values
gen_features[["Metadata_time", "FileName_images"]].value_counts(dropna=False)

In [None]:
# Keep only the generated cells measured at time 1; their centers are used for the matching
gen_features_mask_time_1 = gen_features["Metadata_time"] == 1
gen_features_time_1 = gen_features[gen_features_mask_time_1]
gen_features_time_1

## Get the L2 distance matrix between all true centers and all generated centers

In [None]:
# Rows follow the insertion order of `initial_true_positions`
initial_true_positions_array = np.array(list(initial_true_positions.values()))

# Rows follow the row order of `gen_features_time_1`
initial_gen_positions_array = gen_features_time_1[
    ["AreaShape_Center_X", "AreaShape_Center_Y"]
].values

# Compute the L2 distance matrix
# distance_matrix[i, j] is the distance between true center i and generated center j
distance_matrix = cdist(
    initial_true_positions_array, initial_gen_positions_array, metric="euclidean"
)


plt.figure(figsize=(10, 10))
plt.imshow(distance_matrix, cmap="viridis", origin="lower")
plt.colorbar()
plt.ylabel("True objects")
plt.xlabel("Generated objects")
plt.title("L2 Distance between all True and Generated Cells")
plt.tight_layout()
# NOTE(review): called without arguments, presumably to switch off the grid that the
# ggplot style draws over the image -- confirm
plt.grid()
plt.show()

In [None]:
# For each true object, plot the distances to its three closest generated objects
plt.figure(figsize=(10, 6))
plt.plot(distance_matrix.min(axis=1), "x", label="Minimum distance")

# Partial sort: columns 0, 1 and 2 hold the three smallest distances of each row in
# increasing order (this needs at least three generated objects)
closest_objects = np.partition(distance_matrix, [0, 1, 2], axis=1)
# Sanity check: column 0 must be the row-wise minimum
assert np.all(distance_matrix.min(axis=1) == closest_objects[:, 0])

plt.plot(closest_objects[:, 1], "x", color="blue", label="Second closest distance")
plt.plot(closest_objects[:, 2], "x", color="green", label="Third closest distance")

plt.xlabel("True objects")
plt.ylabel("Minimum distance to all generated objects")
plt.title("Minimum distance to all generated objects for each true object")
plt.legend()
plt.tight_layout()
plt.show()

## Find closest generated cell to each true cell within threshold

In [None]:
# Maximum center-to-center distance (in the units of the AreaShape_Center_* columns)
# for a generated cell to be accepted as the match of a true cell
threshold = 50

In [None]:
# Same plot as before, zoomed on [0, 2 * threshold] with the pairing threshold drawn
plt.figure(figsize=(10, 6))
plt.plot(distance_matrix.min(axis=1), "x", label="Minimum distance")

# Partial sort: columns 0, 1 and 2 hold the three smallest distances of each row in
# increasing order (this needs at least three generated objects)
closest_objects = np.partition(distance_matrix, [0, 1, 2], axis=1)
# Sanity check: column 0 must be the row-wise minimum
assert np.all(distance_matrix.min(axis=1) == closest_objects[:, 0])

plt.plot(closest_objects[:, 1], "x", color="blue", label="Second closest distance")
plt.plot(closest_objects[:, 2], "x", color="green", label="Third closest distance")

plt.xlabel("True objects")
plt.ylabel("Minimum distance to all generated objects")
plt.title("Minimum distance to all generated objects for each true object + pairing threshold")
plt.legend()
plt.ylim(0, threshold * 2)
# Dashed horizontal line marking the pairing threshold
plt.axhline(y=threshold, color="k", linestyle="--")
plt.tight_layout()
plt.show()

In [None]:
# Pair every true cell with its nearest generated cell, provided it lies strictly below
# `threshold`. Rows of `distance_matrix` are visited in the order of the unique true ids.
# Each true cell is handled independently, so several true cells may end up paired with
# the same generated cell.
true_cells_to_gen_cells_mapping = {}

for obj_index, obj_id in enumerate(this_plate_true_features["global_object_id"].unique()):
    distances_to_this_obj = distance_matrix[obj_index]
    (candidate_indices,) = np.nonzero(distances_to_this_obj < threshold)

    if len(candidate_indices) == 0:
        # No generated cell is close enough: record this true cell as unmatched
        true_cells_to_gen_cells_mapping[obj_id] = {
            "closest_gen_id": None,
            "min_distance": None,
            "base_true_position": tuple(initial_true_positions[obj_id]),
        }
        continue

    # Nearest generated cell among the candidates below the threshold
    nearest_index = candidate_indices[distances_to_this_obj[candidate_indices].argmin()]
    # Row position in `gen_features_time_1` matches the column of `distance_matrix`
    nearest_gen_label = gen_features_time_1.iloc[nearest_index]["TrackObjects_Label_10"]

    true_cells_to_gen_cells_mapping[obj_id] = {
        "closest_gen_id": int(nearest_gen_label),
        "min_distance": float(distances_to_this_obj[nearest_index]),
        "base_true_position": tuple(initial_true_positions[obj_id]),
    }

true_cells_to_gen_cells_mapping

# Show feature and L2 distance for one cell

## Find pairs with full lifetime

In [None]:
# Keep the true ids whose own trajectory has 19 timepoints and whose matched generated
# trajectory has 50 timepoints
kept_true_ids = []

for true_object_id, pairing in true_cells_to_gen_cells_mapping.items():
    true_rows_mask = this_plate_true_features["global_object_id"] == true_object_id
    this_obj_true_features = this_plate_true_features[true_rows_mask]
    if len(this_obj_true_features) != 19:
        print(
            f"Skipping object {true_object_id} with {len(this_obj_true_features)} != 19 timepoints"
        )
        continue

    matching_gen_id = pairing["closest_gen_id"]
    if matching_gen_id is None:
        # Unmatched true cells are skipped silently
        continue

    gen_rows_mask = gen_features["TrackObjects_Label_10"] == matching_gen_id
    matching_obj_gen_features = gen_features[gen_rows_mask]
    if len(matching_obj_gen_features) != 50:
        print(
            f"Skipping object {true_object_id} with matching generated cell having {len(matching_obj_gen_features)} != 50 timepoints"
        )
        continue

    kept_true_ids.append(true_object_id)

print(f"\nKept {len(kept_true_ids)} true IDs: {kept_true_ids}")

## Select feature

In [None]:
# Feature column compared between the true and the generated trajectories
selected_feature = "AreaShape_Area"

## Select ID

In [None]:
# Use the first kept pair as the example (raises IndexError when no pair was kept)
true_object_id = kept_true_ids[0]
true_object_id

## True features

In [None]:
# Rows of the selected true object, in chronological order: the interpolation and the
# plots below treat the row order as the time order, so do not rely on the CSV layout
this_obj_true_features = this_plate_true_features.query(
    "global_object_id == @true_object_id"
).sort_values("time", kind="stable")
this_obj_true_features

## Generated features

In [None]:
# Label of the generated cell paired with the selected true object
matching_gen_id = true_cells_to_gen_cells_mapping[true_object_id]["closest_gen_id"]
matching_gen_id

In [None]:
# Rows of the matched generated cell, in chronological order: the interpolation and the
# plots below treat the row order as the time order, so do not rely on the CSV layout
matching_obj_gen_features = gen_features.query(
    "TrackObjects_Label_10 == @matching_gen_id"
).sort_values("Metadata_time", kind="stable")
matching_obj_gen_features

## Plot feature and compute L2 distance of timeseries

In [None]:
def compute_trajectory_l2_distance(true_traj: np.ndarray, gen_traj: np.ndarray):
    """
    Resample the generated trajectory onto the true one and measure their L2 gap.

    Both trajectories are assumed to be uniformly sampled over the same time span:
    each one is placed on an evenly spaced grid over [0, 1], and the generated values
    are linearly interpolated at the grid positions of the true trajectory.

    Args:
        true_traj: numpy array of shape (n_true_points,)
        gen_traj: numpy array of shape (n_gen_points,)

    Returns:
        gen_traj_interp: numpy array of shape (n_true_points,), the generated
            trajectory resampled at the true time points
        l2_dist: Euclidean norm of `true_traj - gen_traj_interp`
    """
    # Evenly spaced pseudo-times, one grid per trajectory
    true_grid = np.linspace(0, 1, len(true_traj))
    gen_grid = np.linspace(0, 1, len(gen_traj))

    # Generated values linearly interpolated at the true pseudo-times
    gen_traj_interp = np.interp(true_grid, gen_grid, gen_traj)

    residual = true_traj - gen_traj_interp
    l2_dist = np.linalg.norm(residual)
    return gen_traj_interp, l2_dist

In [None]:
# Resample the generated trajectory of the selected feature at the true timepoints and
# get the L2 distance between the two series
gen_traj_interp, l2_dist = compute_trajectory_l2_distance(
    this_obj_true_features[selected_feature].values,
    matching_obj_gen_features[selected_feature].values,
)

In [None]:
# Plot the true, generated and interpolated-generated series on the true time axis
plt.figure(figsize=(12, 6))
true_times = this_obj_true_features["time"]
# true
plt.plot(true_times, this_obj_true_features[selected_feature], "x-", label="True")
# generated: min-max rescale the generated times onto the true time range, so that the
# first/last generated points sit on the first/last true timepoints. This is the same
# alignment as the one used for the interpolation; dividing by the max alone would shift
# the curve whenever the generated times do not start at 0.
gen_times = matching_obj_gen_features["Metadata_time"]
true_time_span = true_times.max() - true_times.min()
normalized_gen_times = (
    (gen_times - gen_times.min()) / (gen_times.max() - gen_times.min()) * true_time_span
    + true_times.min()
)
plt.plot(
    normalized_gen_times,
    matching_obj_gen_features[selected_feature],
    "x-",
    label="Generated",
)
# interpolated generated
plt.plot(
    true_times,
    gen_traj_interp,
    "x-",
    label="Interpolated Generated",
)

plt.title(f"{selected_feature} for true object {true_object_id}")
plt.xlabel("Time")
plt.ylabel(selected_feature)
plt.xticks(np.arange(1, 20, 1))
plt.legend()
plt.show()

In [None]:
# L2 distance between the true and interpolated generated series of the selected cell
l2_dist

# Show L2 on all cells

In [None]:
# Reminder of the feature the distances below are computed on
selected_feature

In [None]:
# L2 distance between true and generated trajectories of `selected_feature`, for every
# kept pair of this plate
l2_dists = []

for true_object_id in kept_true_ids:
    # true features, in chronological order (the interpolation treats row order as time order)
    this_obj_true_features = this_plate_true_features.query(
        "global_object_id == @true_object_id"
    ).sort_values("time", kind="stable")
    # generated features, in chronological order
    matching_gen_id = true_cells_to_gen_cells_mapping[true_object_id]["closest_gen_id"]
    matching_obj_gen_features = gen_features.query(
        "TrackObjects_Label_10 == @matching_gen_id"
    ).sort_values("Metadata_time", kind="stable")
    # compute L2 distance
    _, l2_dist = compute_trajectory_l2_distance(
        this_obj_true_features[selected_feature].values,
        matching_obj_gen_features[selected_feature].values,
    )
    l2_dists.append(l2_dist)

# Distribution of the distances: box plot (outliers hidden) with every pair overlaid
plt.figure(figsize=(4, 6))
ax = sns.boxplot(l2_dists, showfliers=False)
ax = sns.swarmplot(l2_dists, color=".25")
plt.xlabel(plate_name)
# The plotted values are distances, not raw feature values
plt.ylabel(f"L2 distance on {selected_feature}")
plt.show()