# Compare true vs artificial trajectories

This notebook compares the features from the true and artificial trajectories, performing a cell-to-cell matching based on their initial position.

---

# Imports

In [323]:
import math
import random
import re
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Union

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import scipy
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from numpy.typing import ArrayLike
from scipy.spatial.distance import cdist
from tqdm.notebook import tqdm

In [324]:
# Global matplotlib styling for every figure of this notebook: "ggplot" look combined with the "fast" style
plt.style.use(["ggplot", "fast"])

# CWD if using non-local venv

In [325]:
import os

# The relative paths used below (e.g. the "analyses/..." parquet file) are resolved from this directory.
# NOTE(review): machine-specific absolute path — adapt it when running the notebook elsewhere.
os.chdir("/workspaces/biocomp/tboyer/sources/CellProfiler_GaussianProxy_analyses")

# True trajectories

In [326]:
# Features of the true cells, stored relative to the current working directory
true_features_path = (
    Path("analyses") / "biotine_resized" / "features_through_time_of_full_lifetime_cells.parquet"
)
true_features = pd.read_parquet(true_features_path)
true_features

## Comparing trajectories of: filtered nucleus

In [327]:
# keep only the rows whose "file" column is "filtered_nucleus"
is_filtered_nucleus = true_features["file"] == "filtered_nucleus"
true_features = true_features.loc[is_filtered_nucleus]
true_features

# Generated trajectories

Load, add `global_object_id`, and concat

**beware: there is no (trivial) matching between `global_object_id` in true and generated data!!!**

In [328]:
# Folder containing one sub-folder per inference experiment (see `experiments_paths` below)
experiments_base_path_path = Path(
    "/",
    "projects",
    "static2dynamic",
    "Thomas",
    "experiments",
    "GaussianProxy",
    "biotine_all_paired_new_jz_MANUAL_WEIGHTS_DOWNLOAD_FROM_JZ_11-02-2025_14h31",
    "inferences",
)

# `plate_names` and `experiments_names` must be in 1-1 ordered correspondence!
# (commented-out entries are experiments still waiting for their CellProfiler run)
experiments_names = [
    "InvertedRegeneration_100_diffsteps_no_SNR_leading_A_13_fld_2",
    "InvertedRegeneration_100_diffsteps_no_SNR_leading_f32_B_13_fld_4",
    "InvertedRegeneration_100_diffsteps_no_SNR_leading_C_13_fld_3",
    "InvertedRegeneration_100_diffsteps_no_SNR_leading_D_14_fld_1",
    "InvertedRegeneration_100_diffsteps_no_SNR_leading_f32_E_14_fld_1",
    "InvertedRegeneration_100_diffsteps_no_SNR_leading_F_14_fld_4",
    "InvertedRegeneration_100_diffsteps_no_SNR_leading_G_13_fld_1",
    "InvertedRegeneration_100_diffsteps_no_SNR_leading_H_14_fld_2",
    "InvertedRegeneration_100_diffsteps_no_SNR_leading_I_14_fld_3",
    # "InvertedRegeneration_100_diffsteps_no_SNR_leading_J_14_fld_2", # TODO: run the CP pipeline!
    "InvertedRegeneration_100_diffsteps_no_SNR_leading_K_14_fld_1",
    # "InvertedRegeneration_100_diffsteps_no_SNR_leading_L_13_fld_2", # TODO: run the CP pipeline!
    "InvertedRegeneration_100_diffsteps_f32_noSNR_leading_M_13_fld_3",
    # "InvertedRegeneration_100_diffsteps_no_SNR_leading_N_14_fld_1", # TODO: run the CP pipeline!
    # "InvertedRegeneration_100_diffsteps_no_SNR_leading_O_14_fld_4", # TODO: run the CP pipeline!
]

# Plate (field) identifiers; each one is also the suffix of the experiment name at the same index
# and the prefix of the `global_object_id` of the objects of that plate
plate_names = [
    "A_13_fld_2",
    "B_13_fld_4",
    "C_13_fld_3",
    "D_14_fld_1",
    "E_14_fld_1",
    "F_14_fld_4",
    "G_13_fld_1",
    "H_14_fld_2",
    "I_14_fld_3",
    # "J_14_fld_2",
    "K_14_fld_1",
    # "L_13_fld_2",
    "M_13_fld_3",
    # "N_14_fld_1",
    # "O_14_fld_4",
]
# Sanity checks on the two lists: same length, unique plates, and pairwise agreement
assert len(plate_names) == len(experiments_names), "Number of experiments and plates must match!"
assert len(plate_names) == len(set(plate_names)), "Plate names must be unique!"
for idx, (_experiment_name, _plate_name) in enumerate(zip(experiments_names, plate_names)):
    assert _experiment_name.endswith(_plate_name), (
        f"Experiment name {experiments_names[idx]} does not end with plate name {plate_names[idx]}!"
    )

# CellProfiler analysis folder of each experiment, in the same order as `plate_names`
experiments_paths = [
    experiments_base_path_path.joinpath(name, "trajectories_-1_1 raw", "cp_analysis")
    for name in experiments_names
]
experiments_paths

## Load gen data 

In [329]:
gen_features_list = []

# One CSV per experiment; `experiments_paths` and `plate_names` share the same ordering
for exp_idx, exp_p in enumerate(experiments_paths):
    this_plate_name = plate_names[exp_idx]
    this_plate_gen_features = pd.read_csv(exp_p / "filtered_nucleus.csv")

    # tag every row with its plate and build an ID that is unique across plates
    this_plate_gen_features = this_plate_gen_features.assign(
        plate_name=this_plate_name,
        global_object_id=(
            this_plate_name + "-" + this_plate_gen_features["TrackObjects_Label_10"].astype(str)
        ),
    )

    gen_features_list.append(this_plate_gen_features)

gen_features = pd.concat(gen_features_list, ignore_index=True)
gen_features

In [330]:
nb_gen_objects = gen_features["global_object_id"].nunique()
print(f"Found {nb_gen_objects} objects in the synthetic data")

## Add time info to artificial frames

In [331]:
# Frame index parsed from file names such as "frame_07.tiff".
# - the dot is escaped so that only a literal ".tiff" extension matches
# - expand=False returns a Series (not a one-column DataFrame), which is what a column assignment expects
# A file name that does not match yields NaN and makes `astype(int)` fail loudly.
gen_features["time"] = (
    gen_features["FileName_images"].str.extract(r"^frame_(\d+)\.tiff$", expand=False).astype(int)
)

In [332]:
# check time data consistency: number of rows per (time, file name) pair, missing values included
gen_features[["time", "FileName_images"]].value_counts(dropna=False)

# Filter true traj on selected plate only

In [333]:
# Keep only the true cells of the selected plates.
# Object IDs have the form "<plate_name>-<label>", so the "-" separator is made part of the prefix:
# without it, a plate whose name merely starts with a selected plate name would be selected too
# (and would later break the `obj_id.split("-")[0] in plate_names` assertion of the matching step).
selected_plate_prefixes = tuple(f"{name}-" for name in plate_names)
these_plates_true_features = true_features[
    true_features["global_object_id"].str.startswith(selected_plate_prefixes)
].copy(deep=True)
# deep clone for in-place modif later

these_plates_true_features

In [334]:
nb_true_objects = these_plates_true_features["global_object_id"].nunique()
print(f"Found {nb_true_objects} objects in the true data")

# Match cells

## Get initial positions of true cells

### Sanity check: only one starting pos per object

In [335]:
# beware: generated frame index 0 corresponds to true time 1!
gen_features_mask_time_1 = gen_features["time"].eq(0)
gen_features_time_1 = gen_features.loc[gen_features_mask_time_1]
gen_features_time_1

In [336]:
def _single_initial_position(time_1_features, obj_id):
    """Return the (x, y) center of `obj_id` in `time_1_features`, asserting that it has exactly one row."""
    this_obj_init_pos = time_1_features.loc[
        time_1_features["global_object_id"] == obj_id,
        ("AreaShape_Center_X", "AreaShape_Center_Y"),
    ].values
    assert len(this_obj_init_pos) == 1, (
        f"Found more than one initial position for object {obj_id}: {this_obj_init_pos}"
    )
    return this_obj_init_pos[0]


time_1_mask = these_plates_true_features["Metadata_time"] == 1
these_plates_true_features_time_1 = these_plates_true_features[time_1_mask]
# as many time-1 rows as there are distinct true objects
assert (
    these_plates_true_features["global_object_id"].nunique()
    == len(these_plates_true_features_time_1)
)

# used later for pairing
initial_true_positions_object_ids_to_pos = {}

for obj_id in tqdm(these_plates_true_features_time_1["global_object_id"].unique()):
    initial_true_positions_object_ids_to_pos[obj_id] = _single_initial_position(
        these_plates_true_features_time_1, obj_id
    )

# generated objects: only the uniqueness check is needed here, the positions are not stored
for obj_id in tqdm(gen_features_time_1["global_object_id"].unique()):
    _single_initial_position(gen_features_time_1, obj_id)

initial_true_positions_object_ids_to_pos

## Get the L2 distance matrix between all true center and all generated centers

Not forgetting to take into account image size...

In [337]:
# Image sizes in pixels. NOTE(review): a single number per image, so presumably square images — confirm.
TRUE_IMAGE_SIZE = 2040  # true images; used to rescale positions and image extent when plotting
GEN_IMAGE_SIZE = 1024  # generated images; object centers of both datasets are divided by this value

In [338]:
# the true images are resized to gen dim at the very beginning of the CP pipeline so GEN_IMAGE_SIZE for both here!
def _normalized_initial_positions(
    time_1_features: pd.DataFrame, plate_name: str, kind: str
) -> "dict[str, np.ndarray]":
    """
    Map each object of `plate_name` to its initial (x, y) center divided by `GEN_IMAGE_SIZE`.

    Args:
        time_1_features: features restricted to the first time point (one row expected per object)
        plate_name: plate whose objects are selected
        kind: "true" or "gen", only used in the error message

    Returns:
        dict: object id -> array of 2 normalized coordinates, in row order of `time_1_features`

    Raises:
        AssertionError: if an object has more than one row at the first time point
    """
    # Object IDs are "<plate_name>-<label>": the separator is part of the prefix so that
    # a plate name that is itself the beginning of another plate name cannot match the other plate
    plate_mask = time_1_features["global_object_id"].str.startswith(f"{plate_name}-")
    this_plate_df = time_1_features.loc[plate_mask]

    # one uniqueness check instead of one boolean mask per object (which was quadratic)
    duplicated_ids = this_plate_df.loc[
        this_plate_df["global_object_id"].duplicated(), "global_object_id"
    ].unique()
    assert len(duplicated_ids) == 0, (
        f"Found more than one initial position for {kind} objects {list(duplicated_ids)}"
    )

    positions = (
        this_plate_df[["AreaShape_Center_X", "AreaShape_Center_Y"]].to_numpy() / GEN_IMAGE_SIZE
    )
    return dict(zip(this_plate_df["global_object_id"], positions))


# ! true images are resized to gen dim at the very beginning of the CP pipeline so GEN_IMAGE_SIZE for both here !

# dict[plate name -> dict[object id -> position]]
initial_true_positions_arrays: "dict[str, dict[str, np.ndarray]]" = {
    plate_name: _normalized_initial_positions(these_plates_true_features_time_1, plate_name, "true")
    for plate_name in plate_names
}
# dict[plate name -> dict[object id -> position]]
initial_gen_positions_arrays: "dict[str, dict[str, np.ndarray]]" = {
    plate_name: _normalized_initial_positions(gen_features_time_1, plate_name, "gen")
    for plate_name in plate_names
}

Plot the positions as sanity check

In [339]:
def plot_starting_pos(
    plate_name: str,
    experiments_path: Path,
    this_plate_initial_true_positions_flattened_array: np.ndarray,
    this_plate_initial_gen_positions_flattened_array: np.ndarray,
    circles: bool = False,
    threshold: Union[float, None] = None,
):
    """
    Show the true and synthetic starting images side by side, with the initial object centers on both.

    Args:
        plate_name: plate to display (used to locate the true image and in the title)
        experiments_path: CellProfiler analysis folder; the synthetic frame is read from its parent
        this_plate_initial_true_positions_flattened_array: (x, y) rows of the true objects,
            in normalized units (they are multiplied by `TRUE_IMAGE_SIZE` for display)
        this_plate_initial_gen_positions_flattened_array: same for the synthetic objects
        circles: whether to draw a disk of radius `threshold` around each true position
        threshold: disk radius in normalized units, required when `circles` is True
    """
    true_xy = this_plate_initial_true_positions_flattened_array
    gen_xy = this_plate_initial_gen_positions_flattened_array

    # background images
    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    true_ax, gen_ax = axs[0], axs[1]
    true_image = plt.imread(
        f"/projects/static2dynamic/datasets/biotine/3_channels_min_99_perc_normalized_rgb_stacks_png/{plate_name}_time_01.png"
    )
    true_ax.imshow(true_image, alpha=0.7, origin="lower")
    true_ax.set_title("True starting image")
    gen_image = plt.imread(experiments_path.parent / "frame_00.tiff")
    gen_img = gen_ax.imshow(gen_image, alpha=0.7, origin="lower")
    # stretch the synthetic image so that both panels share the same coordinate system
    gen_img.set_extent(np.array(gen_img.get_extent()) * TRUE_IMAGE_SIZE / GEN_IMAGE_SIZE)
    gen_ax.set_title("Synthetic starting image")

    # same overlays on both panels
    for ax in axs:
        if circles:
            assert threshold is not None, "Threshold must be provided if circles are to be drawn"
            for x, y in true_xy:
                ax.add_patch(
                    plt.Circle(
                        (x * TRUE_IMAGE_SIZE, y * TRUE_IMAGE_SIZE),
                        threshold * TRUE_IMAGE_SIZE,
                        fill=True,
                        color="red",
                        alpha=0.5,
                    )
                )
        # synthetic first, then true: this order is also the legend order
        for xy, label, color in ((gen_xy, "Synthetic", "blue"), (true_xy, "True", "red")):
            ax.scatter(
                xy[:, 0] * TRUE_IMAGE_SIZE,
                xy[:, 1] * TRUE_IMAGE_SIZE,
                label=label,
                s=6,
                marker="x",
                c=color,
            )
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.grid(False)

    fig.suptitle(
        f"Initial positions of objects in {plate_name} (rescaled to {TRUE_IMAGE_SIZE}px)",
        y=0.9,
    )
    # a single legend for the whole figure, built from the right panel
    lines, labels = gen_ax.get_legend_handles_labels()
    fig.legend(lines, labels, bbox_to_anchor=(0.9, 0.95), fontsize=12)
    fig.show()


for idx, (plate_name, experiments_path) in enumerate(zip(plate_names, experiments_paths)):
    # get starting positions in handy format: one (x, y) row per object
    this_plate_initial_true_positions_flattened_array = np.array(
        [*initial_true_positions_arrays[plate_name].values()]
    )
    this_plate_initial_gen_positions_flattened_array = np.array(
        [*initial_gen_positions_arrays[plate_name].values()]
    )
    nb_true_objects_this_plate = this_plate_initial_true_positions_flattened_array.shape[0]
    nb_gen_objects_this_plate = this_plate_initial_gen_positions_flattened_array.shape[0]
    print(
        f"Plate {plate_name} - Experiment {experiments_names[idx]}\n"
        f"    Found {nb_true_objects_this_plate} objects in the true data\n"
        f"    Found {nb_gen_objects_this_plate} objects in the synthetic data"
    )
    plot_starting_pos(
        plate_name,
        experiments_path,
        this_plate_initial_true_positions_flattened_array,
        this_plate_initial_gen_positions_flattened_array,
    )

In [340]:
# Compute the L2 distance matrices: rows are true objects, columns are generated objects
distance_matrices = {}

# `plate_names` and `experiments_names` are in 1-1 ordered correspondence, so they are iterated
# together to label each figure with the experiment of *its own* plate (and not always the first one)
for plate_name, experiment_name in zip(plate_names, experiments_names):
    distance_matrices[plate_name] = cdist(
        np.array(list(initial_true_positions_arrays[plate_name].values())),
        np.array(list(initial_gen_positions_arrays[plate_name].values())),
        metric="euclidean",
    )

    fig, ax = plt.subplots(figsize=(10, 10))
    distance_image = ax.imshow(distance_matrices[plate_name], cmap="viridis", origin="lower")
    # colorbar in a dedicated axis so that it has the same height as the matrix
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(distance_image, cax=cax)
    ax.set_ylabel("True objects")
    ax.set_xlabel("Generated objects")
    ax.set_title(
        "L2 Distance between all True and Generated Cells\n(positions normalized by image size)"
    )
    fig.suptitle(
        f"Plate {plate_name} - Experiment {experiment_name}",
        y=0.9,
    )
    ax.grid(False)
    fig.tight_layout()
    plt.show()

In [341]:
# Plot the log-scaled L2 distance matrices (natural log of the distances, not a log-log plot)
# `plate_names` and `experiments_names` are in 1-1 ordered correspondence: iterate them together
# so that each figure is labelled with the experiment of its own plate
for plate_name, experiment_name in zip(plate_names, experiments_names):
    fig, ax = plt.subplots(figsize=(10, 10))
    # NOTE: an exactly-zero distance gives -inf here (with a numpy RuntimeWarning)
    log_distance_image = ax.imshow(
        np.log(distance_matrices[plate_name]), cmap="viridis", origin="lower"
    )
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(log_distance_image, cax=cax)
    ax.set_ylabel("True objects")
    ax.set_xlabel("Generated objects")
    ax.set_title(
        "L2 Distance (log-scaled) between all True and Generated Cells\n(positions normalized by image size)"
    )
    fig.suptitle(
        f"Plate {plate_name} - Experiment {experiment_name}",
        y=0.9,
    )
    fig.tight_layout()
    ax.grid(False)
    plt.show()

In [342]:
# For each true object: distance to its 1st, 2nd and 3rd closest generated objects.
# `plate_names` and `experiments_names` are in 1-1 ordered correspondence: iterate them together
# so that each figure is labelled with the experiment of its own plate (not always the first one)
for plate_name, experiment_name in zip(plate_names, experiments_names):
    this_plate_distances = distance_matrices[plate_name]

    plt.figure(figsize=(10, 6))
    plt.plot(this_plate_distances.min(axis=1), "x", label="Minimum distance")

    # partial sort of each row: columns 0, 1 and 2 hold the three smallest distances, in order
    closest_objects = np.partition(this_plate_distances, [0, 1, 2], axis=1)
    assert np.all(this_plate_distances.min(axis=1) == closest_objects[:, 0])

    plt.plot(closest_objects[:, 1], "x", color="blue", label="Second closest distance")
    plt.plot(closest_objects[:, 2], "x", color="green", label="Third closest distance")

    plt.xlabel("True objects")
    plt.ylabel("Minimum distance to all generated objects")
    plt.title("Minimum distance to all generated objects for each true object")
    plt.legend()
    plt.suptitle(f"Plate {plate_name} - Experiment {experiment_name}")
    plt.ylim(0, 0.07)
    plt.tight_layout()
    plt.show()

## Find closest generated cell to each true cell within threshold

In [343]:
# Maximum distance allowed between a true cell and its matched generated cell, in the same
# normalized units as `distance_matrices` (centers divided by GEN_IMAGE_SIZE),
# i.e. 0.007 * 1024 ≈ 7.2 px in the generated images
threshold = 0.007

In [344]:
# Same plot as before, zoomed around the matching threshold (dashed line).
# `plate_names` and `experiments_names` are in 1-1 ordered correspondence: iterate them together
# so that each figure is labelled with the experiment of its own plate (not always the first one)
for plate_name, experiment_name in zip(plate_names, experiments_names):
    this_plate_distances = distance_matrices[plate_name]

    plt.figure(figsize=(10, 6))
    plt.plot(this_plate_distances.min(axis=1), "x", label="Minimum distance")

    # partial sort of each row: columns 0, 1 and 2 hold the three smallest distances, in order
    closest_objects = np.partition(this_plate_distances, [0, 1, 2], axis=1)
    assert np.all(this_plate_distances.min(axis=1) == closest_objects[:, 0])

    plt.plot(closest_objects[:, 1], "x", color="blue", label="Second closest distance")
    plt.plot(closest_objects[:, 2], "x", color="green", label="Third closest distance")

    plt.xlabel("True objects")
    plt.ylabel("Minimum distance to all generated objects")
    plt.title("Minimum distance to all generated objects for each true object")
    plt.legend()
    plt.suptitle(f"Plate {plate_name} - Experiment {experiment_name}")
    plt.ylim(0, threshold * 2)
    plt.axhline(y=threshold, color="k", linestyle="--")
    plt.tight_layout()
    plt.show()

In [345]:
for idx, (plate_name, experiments_path) in enumerate(zip(plate_names, experiments_paths)):
    # get starting positions in handy format: one (x, y) row per object
    this_plate_initial_true_positions_flattened_array = np.array(
        [*initial_true_positions_arrays[plate_name].values()]
    )
    this_plate_initial_gen_positions_flattened_array = np.array(
        [*initial_gen_positions_arrays[plate_name].values()]
    )
    nb_true_objects_this_plate = this_plate_initial_true_positions_flattened_array.shape[0]
    nb_gen_objects_this_plate = this_plate_initial_gen_positions_flattened_array.shape[0]
    print(
        f"Plate {plate_name} - Experiment {experiments_names[idx]}\n"
        f"    Found {nb_true_objects_this_plate} objects in the true data\n"
        f"    Found {nb_gen_objects_this_plate} objects in the synthetic data"
    )
    # same figure as the earlier sanity check, with the matching radius drawn around each true cell
    plot_starting_pos(
        plate_name,
        experiments_path,
        this_plate_initial_true_positions_flattened_array,
        this_plate_initial_gen_positions_flattened_array,
        circles=True,
        threshold=threshold,
    )

## Perform the matching

In [346]:
true_cells_to_gen_cells_mapping = {}

# Labels of the rows (true objects) and columns (generated objects) of each distance matrix.
# They are built once per plate: rebuilding a list and scanning it with `.index()` for every
# object made the matching quadratic in the number of objects.
true_row_index_per_plate = {
    plate_name: {
        true_id: row for row, true_id in enumerate(initial_true_positions_arrays[plate_name])
    }
    for plate_name in plate_names
}
gen_ids_per_plate = {
    plate_name: list(initial_gen_positions_arrays[plate_name]) for plate_name in plate_names
}

# NOTE: each true cell is matched independently, so several true cells may share one generated cell
for obj_id in tqdm(these_plates_true_features_time_1["global_object_id"].unique()):
    plate_name = obj_id.split("-")[0]
    assert plate_name in plate_names, (
        f"Plate name {plate_name} not found in plate names {plate_names}"
    )

    obj_index = true_row_index_per_plate[plate_name][obj_id]

    distances_to_this_obj = distance_matrices[plate_name][obj_index]
    indices_below_threshold = np.nonzero(distances_to_this_obj < threshold)[0]

    if len(indices_below_threshold) > 0:
        # Find the index with minimum distance among those below threshold
        min_distance_idx = indices_below_threshold[
            np.argmin(distances_to_this_obj[indices_below_threshold])
        ]

        # Get the corresponding generated cell ID
        closest_gen_id = gen_ids_per_plate[plate_name][min_distance_idx]

        true_cells_to_gen_cells_mapping[obj_id] = {
            "closest_gen_id": closest_gen_id,
            "min_distance": float(distances_to_this_obj[min_distance_idx]),
            "base_true_position": tuple(initial_true_positions_object_ids_to_pos[obj_id]),
        }
    else:
        # No matches below threshold
        true_cells_to_gen_cells_mapping[obj_id] = {
            "closest_gen_id": None,
            "min_distance": None,
            "base_true_position": tuple(initial_true_positions_object_ids_to_pos[obj_id]),
        }

no_matches = [
    (true_id, mapping)
    for true_id, mapping in true_cells_to_gen_cells_mapping.items()
    if mapping["closest_gen_id"] is None
]
print(
    f"{len(no_matches)} out of {len(true_cells_to_gen_cells_mapping)} true cells have no matches in the generated data ({len(no_matches) / len(true_cells_to_gen_cells_mapping) * 100:.1f}%)"
)
print("Mapping:")
display(true_cells_to_gen_cells_mapping)

# Sanity checks on every matched pair
for true_id, mapping in true_cells_to_gen_cells_mapping.items():
    if mapping["closest_gen_id"] is not None:
        plate_name = true_id.split("-")[0]
        gen_id = mapping["closest_gen_id"]
        min_dist = mapping["min_distance"]

        assert plate_name == gen_id.split("-")[0], (
            f"Plate names do not match: {true_id} vs {gen_id}"
        )
        assert mapping["base_true_position"] == tuple(
            initial_true_positions_object_ids_to_pos[true_id]
        ), (
            f"True position does not match: {mapping['base_true_position']} vs {initial_true_positions_object_ids_to_pos[true_id]}"
        )
        # the stored distance must be the smallest one of the whole row of this true object
        row_min_dist = np.min(
            distance_matrices[plate_name][true_row_index_per_plate[plate_name][true_id]]
        )
        assert min_dist == row_min_dist, (
            f"Minimum distance does not match: {min_dist} vs {row_min_dist}"
        )
        assert min_dist < threshold, (
            f"Minimum distance {min_dist} is not below threshold {threshold}"
        )

# Process not-simple objects

That is: (true and/or generated) objects that merge, or split

## Filter pairs on full lifetime (true and) generated objects

### Get expected full lifetime per generation experiment (= per plate)

In [347]:
# the largest final age found on a plate is used as the full lifetime of that plate
final_ages_by_plate = gen_features.groupby("plate_name")["TrackObjects_FinalAge_10"]
nb_gen_frames_per_field = final_ages_by_plate.max().astype(int)
nb_gen_frames_per_field

In [348]:
# expected number of time points of a full-lifetime object: fixed for true data, per plate for generated data
NB_TRUE_FRAMES = 19
NB_GEN_FRAMES = nb_gen_frames_per_field.to_dict()

for constant_name, constant_value in (
    ("NB_TRUE_FRAMES:", NB_TRUE_FRAMES),
    ("NB_GEN_FRAMES:", NB_GEN_FRAMES),
):
    print(constant_name, constant_value)

# every selected plate must have generated data, and no other plate may be present
assert set(plate_names) == set(NB_GEN_FRAMES), (
    f"Plate names in NB_GEN_FRAMES do not match the expected plate names: {plate_names}. Found: {NB_GEN_FRAMES.keys()}"
)

### Filter

In [349]:
kept_true_ids = []

# these reasons are mutually exclusive and ordered!
# (only "no_matching_gen_cell" and "gen_no_full_lifetime" are incremented below; the other two
# keys are kept for the reporting cells, which iterate over this dict)
reasons_skipped_count = {
    "no_matching_gen_cell": 0,
    "gen_no_full_lifetime": 0,
    "not_simple_true_object": 0,
    "not_simple_gen_object": 0,
}


print(f"Getting object types and skipping true objects in {plate_names}:")
for true_object_id in tqdm(true_cells_to_gen_cells_mapping):
    plate_name = true_object_id.split("-")[0]

    this_mapping = true_cells_to_gen_cells_mapping[true_object_id]
    matching_gen_id = this_mapping["closest_gen_id"]

    # tags describing this pair; a pair can carry several of them
    this_mapping["object_type"] = []

    # 1. No initial match
    if matching_gen_id is None:
        print(f"{true_object_id} with no matching generated cell")
        reasons_skipped_count["no_matching_gen_cell"] += 1
        this_mapping["object_type"].append("no_matching_gen_cell")
        continue  # skip this true object

    # The feature rows are only looked up once a match is known to exist: this avoids two useless
    # full-table scans per unmatched cell, and comparing the ID column against `None`
    this_obj_true_features = these_plates_true_features.query("global_object_id == @true_object_id")
    matching_obj_gen_features = gen_features.query("global_object_id == @matching_gen_id")

    # 2. Not simple true object
    # this filtering should have been performed before
    assert len(this_obj_true_features) >= NB_TRUE_FRAMES, (
        f"Expected only full lifetime true cells, but found {len(this_obj_true_features)} timepoints for {true_object_id}"
    )
    if len(this_obj_true_features) != NB_TRUE_FRAMES:  # so >
        this_mapping["object_type"].append("not_simple_true_object")

    # 3. Not simple generated object
    if len(matching_obj_gen_features) > NB_GEN_FRAMES[plate_name]:
        this_mapping["object_type"].append("not_simple_gen_object")

    # 4. Not full lifetime generated object
    gen_final_age = matching_obj_gen_features["TrackObjects_FinalAge_10"].dropna().unique()
    assert len(gen_final_age) == 1, (
        f"Expected only one final age, but found {len(gen_final_age)} for true id {true_object_id} & matching gen id {matching_gen_id}"
    )
    gen_final_age = gen_final_age[0]
    if gen_final_age != NB_GEN_FRAMES[plate_name]:
        print(
            f"{true_object_id} with no full lifetime of matching generated cell having {int(gen_final_age)} < {NB_GEN_FRAMES[plate_name]} final age"
        )
        reasons_skipped_count["gen_no_full_lifetime"] += 1
        this_mapping["object_type"].append("gen_no_full_lifetime")
        continue  # skip this true object

    kept_true_ids.append(true_object_id)

    # no tag so far: both the true and the generated trajectories are simple
    if this_mapping["object_type"] == []:
        this_mapping["object_type"].append("simple_true_&_gen")

orig_kept_true_ids = kept_true_ids.copy()
print(f"\nKept {len(kept_true_ids)} true IDs: {kept_true_ids}")

In [350]:
nb_true_cells_considered = len(true_cells_to_gen_cells_mapping)
print(
    "Reasons of filtering (ordered & mutually exclusive, some reasons actually do not filter anymore):"
)
# one line per reason, with its share of all the true cells considered
for reason, nb_skipped in reasons_skipped_count.items():
    print(f"    {reason}: {nb_skipped} ({nb_skipped / nb_true_cells_considered * 100:.1f}%)")

### Check nb cells kept per plate

In [351]:
nb_kept_cells = len(kept_true_ids)
nb_candidate_cells = len(true_cells_to_gen_cells_mapping)
print(
    f"Kept {nb_kept_cells} out of {nb_candidate_cells} true cells ({nb_kept_cells / nb_candidate_cells * 100:.1f}%) from {len(plate_names)} plates"
)
print("Number of true cells per plate:")
for plate_name in plate_names:
    # count the IDs of this plate once, instead of rebuilding the same lists inside the f-string
    nb_kept_this_plate = sum(1 for k in kept_true_ids if k.startswith(plate_name))
    nb_candidates_this_plate = sum(
        1 for k in true_cells_to_gen_cells_mapping if k.startswith(plate_name)
    )
    print(
        f"    {plate_name}: {nb_kept_this_plate} out of {nb_candidates_this_plate} ({nb_kept_this_plate / nb_candidates_this_plate * 100:.1f}%)"
    )

### Check object types

In [352]:
# Count the object-type tags, both over all true cells and over the kept ones only.
# A cell skipped for "gen_no_full_lifetime" can still carry "not_simple_*" tags, so the
# "% of kept" figure must use a numerator restricted to kept cells (dividing the total
# count by the number of kept cells could exceed 100%).
kept_true_ids_lookup = set(kept_true_ids)
object_type_counts = {}
object_type_counts_among_kept = {}
for true_obj_id, data in true_cells_to_gen_cells_mapping.items():
    for object_type in data["object_type"]:
        object_type_counts[object_type] = object_type_counts.get(object_type, 0) + 1
        if true_obj_id in kept_true_ids_lookup:
            object_type_counts_among_kept[object_type] = (
                object_type_counts_among_kept.get(object_type, 0) + 1
            )

nb_true_cells_total = len(true_cells_to_gen_cells_mapping)
nb_true_cells_kept = len(kept_true_ids)

print("Type of objects (unordered & not exclusive!):")
for k, v in object_type_counts.items():
    line = f"    {k}: {v} ({v / nb_true_cells_total * 100:.1f}% of total)"
    if k in ("simple_true_&_gen", "not_simple_true_object", "not_simple_gen_object"):
        v_kept = object_type_counts_among_kept.get(k, 0)
        # no kept cell at all: report "n/a" instead of dividing by zero
        kept_share = f"{v_kept / nb_true_cells_kept * 100:.1f}%" if nb_true_cells_kept else "n/a"
        line += f" ({v_kept} among kept cells, {kept_share} of kept)"
    print(line)


# every true cell is either kept or skipped for exactly one of the two filtering reasons
assert (
    len(kept_true_ids)
    == len(true_cells_to_gen_cells_mapping)
    - reasons_skipped_count["no_matching_gen_cell"]
    - reasons_skipped_count["gen_no_full_lifetime"]
), (
    f"Something went wrong: {len(kept_true_ids)} != {len(true_cells_to_gen_cells_mapping)} - {reasons_skipped_count['no_matching_gen_cell']} - {reasons_skipped_count['gen_no_full_lifetime']}"
)

## "Duplicate" non-simple objects by reconstructing their full trajectory with another ID

### Util funcs

In [353]:
def build_tracking_tree(df: pd.DataFrame) -> tuple:
    """
    Turn tracking rows into a directed graph whose edges go from a parent object to its child.

    Nodes are `(time, ObjectNumber, global_object_id)` tuples: the object number (and not the
    tracking label) identifies a node because CP links objects at t and t+1 through it.
    The time column is "Metadata_time" when present, "time" otherwise.

    Args:
        df: DataFrame containing trajectory data

    Returns:
        tuple: (DiGraph, positions_dict) where:
            - DiGraph is the networkx graph of the tracking tree
            - positions_dict maps each node to an (x, y) plotting position, x being the time
    """
    time_key = "time" if "Metadata_time" not in df.columns else "Metadata_time"
    graph = nx.DiGraph()
    nodes_per_time = {}

    # Rows are visited chronologically, so the parent (at time t - 1) of a node
    # is already in the graph when the node itself (at time t) is added
    for _, record in df.sort_values(time_key).iterrows():
        t = int(record[time_key])
        object_number = int(record["ObjectNumber"])  # per-time id
        global_id = record["global_object_id"]  # per-object id

        node = (t, object_number, global_id)
        graph.add_node(
            node,
            label=f"{object_number}\n({record['TrackObjects_Label_10']})",
            time=t,
        )
        nodes_per_time.setdefault(t, []).append(node)

        # Link to the parent only when one is declared and is actually part of the graph
        parent_number = record["TrackObjects_ParentObjectNumber_10"]
        if pd.isna(parent_number) or not parent_number > 0:
            continue
        parent_node = (t - 1, parent_number, global_id)
        if parent_node in graph.nodes:
            graph.add_edge(parent_node, node)

    # One column per time point; the nodes of a column are ordered by object number
    # and centered vertically around 0
    layout = {}
    for t, nodes in nodes_per_time.items():
        ordered_nodes = sorted(nodes, key=lambda n: n[1])
        centering_offset = (len(ordered_nodes) - 1) / 2
        for rank, node in enumerate(ordered_nodes):
            layout[node] = (t, rank - centering_offset)

    return graph, layout


def get_all_tree_paths(G: nx.DiGraph, progress: bool = False) -> list:
    """
    Lists every root-to-leaf path of the tracking tree.

    A root is a node without predecessor, a leaf a node without successor;
    an isolated node is therefore returned as a one-node path.

    Args:
        G: NetworkX DiGraph object representing the tracking tree
        progress: Whether to show a progress bar (one step per root node)

    Returns:
        list: List of paths, where each path is a list of node IDs from root to leaf
    """
    roots = [candidate for candidate in G.nodes() if G.in_degree(candidate) == 0]
    pbar = tqdm(total=len(roots), desc="Processing root nodes") if progress else None

    complete_paths = []
    for root in roots:
        # iterative depth-first search: each pending entry is (node, path leading to that node)
        pending = [(root, [root])]
        while pending:
            current, trail = pending.pop()
            children = list(G.successors(current))
            if not children:
                # leaf reached: the trail is a complete path
                complete_paths.append(trail)
                continue
            pending.extend((child, [*trail, child]) for child in children)

        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.close()

    return complete_paths


def plot_tracking_tree(G: nx.DiGraph, positions: dict, ax=None, figsize: tuple = (8, 6)):
    """
    Plots a visual tree representation of the tracking data showing parent-child relationships.

    Args:
        G: NetworkX DiGraph object representing the tracking tree (nodes need "label" and "time" attributes)
        positions: Dictionary mapping node IDs to (x,y) positions
        ax: Optional matplotlib axis to plot on; a new figure is created when omitted
        figsize: Figure size as (width, height) in inches, only used when `ax` is omitted

    Returns:
        The matplotlib axis the tree was drawn on
    """
    # Create plot
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Draw the graph (edges, nodes, labels), everything on the *same* axis:
    # without `ax=ax` networkx draws on the current pyplot axis, which is not
    # necessarily the one given by the caller
    nx.draw_networkx_edges(G, positions, ax=ax, arrows=True, arrowstyle="->", width=1.2)
    nx.draw_networkx_nodes(G, positions, ax=ax, node_color="lightblue")
    nx.draw_networkx_labels(
        G,
        positions,
        labels={n: G.nodes[n]["label"] for n in G.nodes()},
        font_size=5,
        font_color="black",
        ax=ax,
    )

    # Get the time range from the graph
    time_points = sorted(set(G.nodes[n]["time"] for n in G.nodes()))

    # Customize plot appearance
    if time_points:
        time_min, time_max = min(time_points), max(time_points)
        ax.set_title(f"Tracking Tree (Time {time_min} to {time_max})")
    else:
        ax.set_title("Tracking Tree")

    ax.set_xlabel("Time")

    # Remove y-axis ticks as they don't represent anything specific
    ax.set_yticks([])

    # Reset x-axis formatting completely
    ax.tick_params(reset=True)

    # Force x-axis visibility
    ax.spines["bottom"].set_visible(True)
    ax.spines["bottom"].set_linewidth(1.0)
    ax.set_xticks(time_points)
    ax.set_xticklabels(time_points)
    ax.tick_params(
        axis="x",
        which="both",
        length=4,
        width=1,
        colors="black",
        labelcolor="black",
        bottom=True,
        labelbottom=True,
    )
    ax.xaxis.set_tick_params(which="both", bottom=True, labelbottom=True)

    # Add a light grid
    ax.grid(axis="x", linestyle="--", alpha=0.7)

    # Make plot more compact (on the figure owning `ax`, which may not be the current pyplot figure)
    ax.figure.tight_layout()

    return ax

### Visualize few non-simple examples

In [354]:
true_ids_non_simple_true_trajs = [
    k
    for k, v in true_cells_to_gen_cells_mapping.items()
    if "not_simple_true_object" in v["object_type"]
]

# At most 3 random examples: `random.sample` raises a ValueError when asked for more items than available
nb_random_examples = min(3, len(true_ids_non_simple_true_trajs))
# plus one fixed, hand-picked example
example_true_ids = random.sample(true_ids_non_simple_true_trajs, nb_random_examples) + [
    "A_13_fld_2-14.0"
]

# (the loop variable is not called `id`, which would shadow the builtin for the rest of the notebook)
for true_id in example_true_ids:
    print("True ID", true_id)
    tmp_df = these_plates_true_features.query("global_object_id == @true_id")

    # Build the tracking tree
    G, positions = build_tracking_tree(tmp_df)

    # Find all paths from roots to leaves
    all_paths = get_all_tree_paths(G)

    # Print path information (object numbers only)
    print("\nPaths through tracking tree:")
    for i, path in enumerate(all_paths):
        path_str = " → ".join([str(obj_num) for _, obj_num, _ in path])
        print(f"Path {i + 1}: {path_str}")

    # Plot tree visualization
    fig, ax = plt.subplots(figsize=(10, 5), dpi=150)
    plot_tracking_tree(G, positions, ax=ax)
    plt.show()

    # one node per row: two rows sharing (time, object number, object id) would collapse into one node
    assert len(tmp_df) == len(G), (
        f"Expected the same number of nodes in the graph than rows in the DataFrame, got {len(tmp_df)} rows and {len(G)} nodes"
    )

### Visualize the full tree of non-simple true objects

In [355]:
# Build the full tracking tree, restricted to the rows of the non-simple true objects
non_simple_true_features = these_plates_true_features.query(
    "global_object_id in @true_ids_non_simple_true_trajs"
)
G, positions = build_tracking_tree(non_simple_true_features)

# Find all paths from roots to leaves
all_paths = get_all_tree_paths(G)

# Print path information (object numbers only)
print(
    f"Paths through tracking tree of {len(true_ids_non_simple_true_trajs)} non-simple true objects:"
)
for path_number, path in enumerate(all_paths, start=1):
    object_numbers = (str(node[1]) for node in path)
    print(f"Path {path_number}: {' → '.join(object_numbers)}")

# # Plot tree visualization  (commented because matplotlib is SOOO SLOOOW)
# fig, ax = plt.subplots(figsize=(30, 60), dpi=150)
# plot_tracking_tree(G, positions, ax=ax)
# plt.show()

No mergers appear here because the CP pipeline keeps track of only *the* closest parent cell of each object.

In [356]:
# roots have no incoming edge, leaves no outgoing edge
root_nodes = [node for node, in_deg in G.in_degree() if in_deg == 0]
leaf_nodes = [node for node, out_deg in G.out_degree() if out_deg == 0]
print(f"Found {len(root_nodes)} root nodes and {len(leaf_nodes)} leaf nodes")

### Process the data

#### Utils

In [357]:
def process_splitting_trajs(
    obj_id: str, this_obj_features: pd.DataFrame, features: pd.DataFrame, full_lifetime: int
):
    """
    Splits the (possibly branching) trajectory of `obj_id` into its full-lifetime sub-trajectories.

    The tracking tree is built from `this_obj_features`. Every root-to-leaf path made of exactly
    `full_lifetime` nodes is copied into `features` under a new global object id
    (see `add_path_to_df`); every node of the tree that belongs to no such path is then removed
    from the rows of `obj_id` (see `remove_nodes_from_df`).

    Args:
        obj_id: global object id of the trajectory to process
        this_obj_features: rows of `obj_id`, used to build the tracking tree
        features: the full features DataFrame to update
        full_lifetime: number of nodes a path must have to be kept

    Returns:
        The updated DataFrame. It is a new, reindexed object as soon as one path was copied;
        otherwise it is the `features` object that was passed in, pruned in place.

    Beware: this function does not check if `features` was already processed (which would result
    in duplicated rows), and is far from optimized for speed.
    """
    tracking_tree, _ = build_tracking_tree(this_obj_features)

    # nodes are unpacked as 3-tuples by the helpers (time, object number, third field)
    kept_nodes: set = set()
    for root_to_leaf_path in get_all_tree_paths(tracking_tree):
        if len(root_to_leaf_path) != full_lifetime:
            continue
        # copy this full-lifetime path under a fresh global object id...
        features = add_path_to_df(root_to_leaf_path, features, obj_id)
        # ...and protect its nodes from the pruning below
        kept_nodes.update(root_to_leaf_path)

    # prune the tree: drop whatever is not part of a full-lifetime path
    pruned_nodes = set(tracking_tree.nodes()) - kept_nodes
    if pruned_nodes:
        remove_nodes_from_df(pruned_nodes, features, obj_id)

    return features


def remove_nodes_from_df(nodes: set, features: pd.DataFrame, obj_id: str):
    """
    Removes `nodes` of the trajectories of `obj_id` in the `features` DataFrame.

    Args:
        nodes: set of 3-tuples whose first two fields are the time and the `ObjectNumber`
            of the row to remove (the third field is not used here)
        features: DataFrame with a `global_object_id` column, an `ObjectNumber` column and a
            time column (`Metadata_time` if present, `time` otherwise)
        obj_id: global object id whose rows are pruned

    Raises:
        ValueError: if `obj_id` has fewer than 2 rows at the time of one of the nodes
            (removing a node is only expected where a sibling exists at that time)
        AssertionError: if a node does not match exactly one row, or if `nodes` is empty

    Beware: this function modifies the DataFrame in place.
    """
    this_obj_mask = features["global_object_id"] == obj_id
    time_key = "Metadata_time" if "Metadata_time" in features.columns else "time"
    indexes_to_remove = []
    for t, o, _ in nodes:
        # all the rows of this object at this time
        this_time_this_obj = features[(features[time_key] == t) & this_obj_mask]
        if len(this_time_this_obj) <= 1:
            # if we remove something at this time, another one should exist
            raise ValueError(
                f"Expected at least 2 object numbers at time {t} for global object {obj_id}, but found {len(this_time_this_obj)}"
            )
        # now get this subpath for this global object at this time
        this_time_this_obj_this_subtraj = this_time_this_obj[
            this_time_this_obj["ObjectNumber"] == o
        ]
        assert len(this_time_this_obj_this_subtraj) == 1, (
            f"Expected one subobject at time {t} for obj_id={obj_id} and subobject {o}, but found {len(this_time_this_obj_this_subtraj)}"
        )
        indexes_to_remove.append(this_time_this_obj_this_subtraj.index[0])
    assert len(indexes_to_remove) > 0, (
        f"Expected to remove at least one timepoint for {obj_id}, but found none"
    )
    features.drop(index=indexes_to_remove, inplace=True)
    print(f"  Removed {len(indexes_to_remove)} timepoints from {obj_id}")


def add_path_to_df(path: list, features: pd.DataFrame, obj_id: str):
    """
    Copies the rows of `obj_id` lying on `path` and appends them under a new global object id.

    The new id is `{obj_id}-extracted_subtraj_{n}`, with `n` the smallest non-negative integer
    not yet used in `features`.

    Args:
        path: list of 3-tuples whose first two fields are the time and the `ObjectNumber`
            of each row of the path
        features: DataFrame with a `global_object_id` column, an `ObjectNumber` column and a
            time column (`Metadata_time` if present, `time` otherwise)
        obj_id: global object id the path belongs to

    Raises:
        ValueError: if `obj_id` has no row at the time of one of the path nodes
        AssertionError: if a node does not match exactly one row

    Beware: this function returns a new, reindexed DataFrame; the input one is left untouched.
    """
    time_col = "Metadata_time" if "Metadata_time" in features.columns else "time"
    belongs_to_obj = features["global_object_id"] == obj_id

    # collect the index label of the single row matching each node of the path
    row_labels = []
    for t, o, _ in path:
        rows_at_t = features[belongs_to_obj & (features[time_col] == t)]
        if len(rows_at_t) == 0:
            raise ValueError(
                f"Expected at least one object at time {t} for {obj_id}, but found none"
            )
        node_rows = rows_at_t[rows_at_t["ObjectNumber"] == o]
        assert len(node_rows) == 1, (
            f"Expected one subobject at time {t} for obj_id={obj_id} and subobject {o}, but found {len(node_rows)}"
        )
        row_labels.append(node_rows.index[0])
    assert len(set(row_labels)) == len(row_labels) == len(path)

    # first free sub-trajectory number for this object
    taken_ids = set(features["global_object_id"].values)
    subtraj_nb = 0
    while f"{obj_id}-extracted_subtraj_{subtraj_nb}" in taken_ids:
        subtraj_nb += 1
    new_obj_id = f"{obj_id}-extracted_subtraj_{subtraj_nb}"

    # append the copied rows under the new id
    copied_rows = features.loc[row_labels].copy(deep=True)
    copied_rows["global_object_id"] = new_obj_id
    features = pd.concat([features, copied_rows], ignore_index=True)
    print(
        f"  Added {len(copied_rows)} timepoints to {new_obj_id} ({len(path)} timepoints in the path)"
    )
    return features

#### Perform the processing

In [358]:
# Split every non-simple (branching) true and/or generated trajectory into its full-lifetime
# sub-trajectories, then drop the original non-simple object.
# Both `these_plates_true_features` and `gen_features` are rebound and mutated in place here,
# so this cell is NOT idempotent: the assertions below make a second run fail early.

# some checks to ensure no double-processing
assert not these_plates_true_features["global_object_id"].str.contains("-extracted_subtraj_").any()
assert not gen_features["global_object_id"].str.contains("-extracted_subtraj_").any()

for true_obj_id, data in tqdm(true_cells_to_gen_cells_mapping.items()):
    # nothing to split for simple pairs, unmatched true cells,
    # or pairs whose generated cell has no full lifetime
    if data["object_type"] == ["simple_true_&_gen"]:
        continue
    elif data["object_type"] == ["no_matching_gen_cell"]:
        continue
    elif "gen_no_full_lifetime" in data["object_type"]:
        continue
    else:
        # the plate name is the part of the ID before the first "-";
        # note that `plate_name` stays defined at notebook level after this loop
        plate_name = true_obj_id.split("-")[0]
        # rows of this true object and of its matched generated object, taken *before* any split
        this_obj_true_features = these_plates_true_features.query(
            "global_object_id == @true_obj_id"
        )
        matching_gen_id = true_cells_to_gen_cells_mapping[true_obj_id]["closest_gen_id"]
        # NOTE(review): if the same generated ID is the closest match of several true IDs, it has
        # already been split and dropped when it is met again, so this selection is empty —
        # confirm `build_tracking_tree` copes with an empty DataFrame
        matching_obj_gen_features = gen_features.query("global_object_id == @matching_gen_id")

        # set as soon as one of the two known non-simple types is handled
        object_type_ok = False

        if "not_simple_true_object" in data["object_type"]:
            object_type_ok = True
            print(f"Processing true {true_obj_id}")
            these_plates_true_features = process_splitting_trajs(
                true_obj_id, this_obj_true_features, these_plates_true_features, NB_TRUE_FRAMES
            )
            # remove the original non-simple objects
            these_plates_true_features.drop(
                these_plates_true_features.query("global_object_id == @true_obj_id").index,
                inplace=True,
            )

        if "not_simple_gen_object" in data["object_type"]:
            object_type_ok = True
            print(f"Processing gen {matching_gen_id}")
            # the expected number of generated frames is looked up per plate
            gen_features = process_splitting_trajs(
                matching_gen_id, matching_obj_gen_features, gen_features, NB_GEN_FRAMES[plate_name]
            )
            # remove the original non-simple objects
            gen_features.drop(
                gen_features.query("global_object_id == @matching_gen_id").index,
                inplace=True,
            )

        if not object_type_ok:
            raise ValueError(f"Unknown object type: {data['object_type']}")

In [360]:
# All the (sub)trajectories derived from one hand-picked true object
is_shown_object = these_plates_true_features["global_object_id"].str.startswith(
    "M_13_fld_3-76.0"
)
tmp_df = these_plates_true_features[is_shown_object]

print(f"Shown gloabl_object_id: {tmp_df['global_object_id'].unique()}")

G, positions = build_tracking_tree(tmp_df)

# Enumerate every root-to-leaf path
all_paths = get_all_tree_paths(G)

# Report each path as the chain of its object numbers
print("\nPaths through tracking tree:")
for path_idx, path in enumerate(all_paths):
    object_numbers = [str(obj_nb) for _, obj_nb, _ in path]
    print(f"Path {path_idx + 1}: {' → '.join(object_numbers)}")

# Draw the tree
fig, ax = plt.subplots(figsize=(10, 5), dpi=150)
plot_tracking_tree(G, positions, ax=ax)
plt.show()

In [361]:
these_plates_true_features.loc[
    these_plates_true_features["global_object_id"].str.startswith("M_13_fld_3-76.0")
]

#### Check the result

In [362]:
these_plates_true_features["global_object_id"].unique()

## Update `kept_true_ids` with the new IDs

In [363]:
object_type_counts

In [364]:
# Replace, in `kept_true_ids`, every original ID that was split by the IDs of its extracted
# sub-trajectories.
# The true features are not modified in this loop, so the IDs they contain are computed once
# (in order of appearance) instead of once per iteration.
present_true_ids = list(these_plates_true_features["global_object_id"].unique())
present_true_ids_set = set(present_true_ids)

# NOTE(review): assumes `orig_kept_true_ids` is a copy of `kept_true_ids`, not the same list
# (which is mutated while iterating here) — confirm where it is defined.
# The loop variable is not named `id`, which would shadow the builtin for all later cells.
for orig_true_id in tqdm(orig_kept_true_ids):
    assert "-extracted_subtraj_" not in orig_true_id, (
        f"Expected {orig_true_id} to be a simple object, but found -extracted_subtraj_ in the ID"
    )
    if orig_true_id in present_true_ids_set:
        continue

    # check the now missing IDs are only non-simple objects
    obj_type = true_cells_to_gen_cells_mapping[orig_true_id]["object_type"]
    assert "not_simple_true_object" in obj_type or "not_simple_gen_object" in obj_type, (
        f"Expected {orig_true_id} to be a non-simple object, but found {obj_type}"
    )
    # add the newly created IDs to the kept IDs; matching on the full
    # "<id>-extracted_subtraj_" prefix cannot pick up an unrelated ID that merely starts alike
    subtraj_prefix = orig_true_id + "-extracted_subtraj_"
    new_ids = [
        present_id
        for present_id in present_true_ids
        if present_id.startswith(subtraj_prefix) and present_id not in kept_true_ids
    ]
    kept_true_ids.extend(new_ids)
    # remove the original ID from the kept IDs (guarded so that re-running the cell is harmless)
    if orig_true_id in kept_true_ids:
        kept_true_ids.remove(orig_true_id)

print(f"Now having {len(kept_true_ids)} true cell IDs (vs {len(orig_kept_true_ids)} before)")

## More checks after data processing

In [365]:
# Rebuild the tracking tree, now restricted to the kept (processed) true IDs
G, positions = build_tracking_tree(
    these_plates_true_features.query("global_object_id in @kept_true_ids")
)

# Enumerate every root-to-leaf path
all_paths = get_all_tree_paths(G)

# Report each path as the chain of its object numbers
print("\nPaths through tracking tree:")
for path_idx, path in enumerate(all_paths):
    object_numbers = [str(obj_nb) for _, obj_nb, _ in path]
    print(f"Path {path_idx + 1}: {' → '.join(object_numbers)}")

# Plot tree visualization (commented because matplotlib is SOOO SLOOOW)
# fig, ax = plt.subplots(figsize=(30, 60), dpi=150)
# plot_tracking_tree(G, positions, ax=ax)
# plt.show()

In [366]:
# Roots are the nodes without predecessor, leaves the nodes without successor
root_nodes = [node for node, in_deg in G.in_degree() if in_deg == 0]
leaf_nodes = [node for node, out_deg in G.out_degree() if out_deg == 0]

print(len(root_nodes), "root nodes")
print(len(leaf_nodes), "leaf nodes")
# after processing, no trajectory should branch anymore: as many starts as ends
assert len(root_nodes) == len(leaf_nodes)

# Show feature and L2/cosine/... distance for one cell

## Select feature

In [367]:
selected_feature = "AreaShape_Area"

## Select ID

In [368]:
# Pick the true object to inspect (alternative: a random kept ID, see the commented line).
# NOTE(review): `plate_name` is not set in this cell; it holds whatever value the last executed
# loop left behind (e.g. the splitting loop above), so the selected object depends on
# execution order — consider naming the plate explicitly.
# true_object_id = kept_true_ids[np.random.randint(0, len(kept_true_ids))]
true_object_id = plate_name + "-373.0"
true_object_id

## True features

In [369]:
assert (
    true_object_id in kept_true_ids
    and true_object_id in these_plates_true_features["global_object_id"].values
)

this_obj_true_features = these_plates_true_features.query("global_object_id == @true_object_id")
this_obj_true_features

## Generated features

In [370]:
matching_gen_id = true_cells_to_gen_cells_mapping[
    re.sub(r"-extracted_subtraj_\d+", "", true_object_id)
]["closest_gen_id"]  # might not exist! (because the matching gen was itself a subtraj)
matching_gen_id

In [371]:
matching_obj_gen_features = gen_features.query("global_object_id == @matching_gen_id")
matching_obj_gen_features

## Plot the feature and compute the L2 distance / cosine similarity of the time series

In [372]:
def compute_trajectory_comparison_metrics(
    true_traj: Union[np.ndarray, ArrayLike], gen_traj: Union[np.ndarray, ArrayLike]
):
    """
    Compute metrics between trajectories with different numbers of points (by interpolating the generated trajectory at the true times).

    Both trajectories are assumed to be uniformly sampled over the same time span: their times
    are normalized to [0, 1] before interpolating.

    Args:
        true_traj: array-like of shape (n_true_points,)
        gen_traj: array-like of shape (n_gen_points,)

    Returns:
        gen_traj_interp: array of shape (n_true_points,) of the interpolated generated trajectory at the true times
            (the generated trajectory itself when both already have the same length)
        l2: float, L2 distance between the trajectories
        cos_sim: float, cosine similarity between the trajectories
        normd_l2: float, L2 distance between the trajectories normalized by the true trajectory

    Note:
        `cos_sim` and `normd_l2` are not finite (NaN or inf, with a NumPy warning) when a
        trajectory has a zero norm; this is not guarded here.
    """
    # Accept any array-like (lists, tuples, ...), as advertised by the signature
    true_traj = np.asarray(true_traj)
    gen_traj = np.asarray(gen_traj)

    # Checks
    assert true_traj.ndim == 1, "true_traj should be a 1D array"
    assert gen_traj.ndim == 1, "gen_traj should be a 1D array"
    assert len(true_traj) > 0, "true_traj should not be empty"
    assert len(gen_traj) > 0, "gen_traj should not be empty"

    # Create synthetic time values (assuming uniform sampling)
    if len(true_traj) != len(gen_traj):
        normalized_true_times = np.linspace(0, 1, len(true_traj))
        normalized_gen_times = np.linspace(0, 1, len(gen_traj))

        # Interpolate the generated trajectory at the true times
        gen_traj_interp = np.interp(normalized_true_times, normalized_gen_times, gen_traj)
    else:
        gen_traj_interp = gen_traj

    true_norm = np.linalg.norm(true_traj)

    # l2_dist
    l2 = np.linalg.norm(true_traj - gen_traj_interp)

    # cosine_sim
    cos_sim = np.dot(true_traj, gen_traj_interp) / (true_norm * np.linalg.norm(gen_traj_interp))

    # normd_l2
    normd_l2 = l2 / true_norm

    return gen_traj_interp, l2, cos_sim, normd_l2

In [373]:
gen_traj_interp, l2_dist, cosine_sim, normd_l2 = compute_trajectory_comparison_metrics(
    this_obj_true_features[selected_feature].values,
    matching_obj_gen_features[selected_feature].values,
)

In [374]:
fig, ax = plt.subplots(figsize=(12, 6))

true_times = this_obj_true_features["Metadata_time"]

# true trajectory
ax.plot(true_times, this_obj_true_features[selected_feature].values, "x-", label="True")

# generated trajectory, with its times rescaled to the [1, 19] range used for the x axis
gen_times = matching_obj_gen_features["time"]
normalized_gen_times = gen_times / gen_times.max() * 18 + 1
ax.plot(
    normalized_gen_times,
    matching_obj_gen_features[selected_feature].values,
    "x-",
    label="Generated",
)

# generated trajectory interpolated at the true times
ax.plot(true_times, gen_traj_interp, "x-", label="Interpolated Generated")

ax.set_title(f"{selected_feature} for true object {true_object_id}")
ax.set_xlabel("Time")
ax.set_ylabel(selected_feature)
ax.set_xticks(np.arange(1, 20, 1))
ax.legend()
plt.show()

In [375]:
assert cosine_sim == np.dot(this_obj_true_features[selected_feature].values, gen_traj_interp) / (
    np.linalg.norm(this_obj_true_features[selected_feature].values)
    * np.linalg.norm(gen_traj_interp)
)

l2_dist, cosine_sim, normd_l2

# Show L2 / cosine sim on all cells

In [376]:
selected_feature

## Checks

In [377]:
# Count the kept true cells whose original generated match no longer exists
# (because that generated trajectory was split), and how many of them were split themselves.
nb_miss_true_traj_split = 0
nb_orig_matches_miss = 0

for true_object_id in tqdm(kept_true_ids):
    plate_name = true_object_id.split("-")[0]

    # the mapping is keyed by the *original* true ID, so strip any sub-trajectory suffix
    matching_gen_id = true_cells_to_gen_cells_mapping[
        re.sub(r"-extracted_subtraj_\d+", "", true_object_id)
    ]["closest_gen_id"]
    # checked before the ID is used in the selection below
    assert matching_gen_id is not None
    # might be empty! (because the matching gen was itself split into subtrajs)
    matching_obj_gen_features = gen_features.query("global_object_id == @matching_gen_id")

    if len(matching_obj_gen_features) == 0:
        nb_orig_matches_miss += 1

        if "-extracted_subtraj_" in true_object_id:
            nb_miss_true_traj_split += 1

    elif len(matching_obj_gen_features) != NB_GEN_FRAMES[plate_name]:
        # report the expected count actually used for this plate
        raise RuntimeError(
            f"{len(matching_obj_gen_features)} timepoints for matching generated cell {matching_gen_id} of true id {true_object_id}, expected {NB_GEN_FRAMES[plate_name]}"
        )
    # otherwise the match exists with the expected number of timepoints: nothing to do

# percentages are guarded against empty denominators
pct_orig_matches_miss = nb_orig_matches_miss / len(kept_true_ids) * 100 if kept_true_ids else 0.0
print(
    f"\n{nb_orig_matches_miss} in total out of {len(kept_true_ids)} kept true cells ({pct_orig_matches_miss:.1f}%) have no matches anymore in the generated data (because of generated trajectories splitting), and must be rematched"
)
if nb_orig_matches_miss > 0:
    print(
        f"{nb_miss_true_traj_split} of these true trajs have also been split ({nb_miss_true_traj_split / nb_orig_matches_miss * 100:.1f}%)"
    )

## Rematching strategy

In [378]:
# How to rematch a true cell whose generated match was split into sub-trajectories.
# Values handled by `compare_true_to_gen_traj`: "no" (skip the cell), "random" (one split match),
# "all" (every split match, weight 1 each), "all-wheighted" (every split match, weights summing
# to 1; the spelling is the one the code compares against, do not "fix" it here only).
rematching_strategy = "all-wheighted"  # choose from: "no", "random", "all", "all-wheighted" TODO: add "best" options

## Helper function

In [379]:
def compare_true_to_gen_traj(
    these_plates_true_features: pd.DataFrame,
    true_object_id: str,
    true_cells_to_gen_cells_mapping: dict[str, dict[str, Any]],
    gen_features: pd.DataFrame,
    rematching_strategy: str,
    selected_feature: str,
    verbose: bool = False,
    compute_weights_return_ids: bool = False,
):
    """
    Computes the metrics between the true trajectory `true_object_id` and its matching generated
    trajectory/ies in `gen_features`, for `selected_feature`.

    The generated match is looked up in `true_cells_to_gen_cells_mapping` under the *original*
    true ID (any `-extracted_subtraj_<n>` suffix is stripped). If that generated ID is no longer
    in `gen_features` (it was split into sub-trajectories), `rematching_strategy` decides what
    to compare against:
        - "no": skip this true trajectory
        - "random": one of the split generated trajectories, drawn with `np.random.choice`
        - "all": all of them, each with weight 1
        - "all-wheighted": all of them, with weights summing to 1

    Args:
        these_plates_true_features: true features, with a `global_object_id` column
        true_object_id: ID of the true (sub)trajectory
        true_cells_to_gen_cells_mapping: maps original true IDs to a dict with a `closest_gen_id` key
        gen_features: generated features, with a `global_object_id` column
        rematching_strategy: see above
        selected_feature: name of the column to compare
        verbose: whether to print the rematching decisions
        compute_weights_return_ids: whether to return the weights and the (true, gen) ID pairs

    Returns:
        `None` if the trajectory is skipped ("no" strategy), otherwise a tuple
        `(l2_dists, cosine_sims, normd_l2s, weights, ids)` of lists with one entry per compared
        generated trajectory; `weights` and `ids` are `None` unless `compute_weights_return_ids`.
        All the lists are empty when the generated match was split but left no sub-trajectory.

    Raises:
        ValueError: if a rematch is needed and `rematching_strategy` is unknown
    """
    # true features
    this_obj_true_features = these_plates_true_features.query("global_object_id == @true_object_id")
    # generated match, registered under the original (pre-split) true ID
    orig_true_id = re.sub(r"-extracted_subtraj_\d+", "", true_object_id)
    matching_gen_id = true_cells_to_gen_cells_mapping[orig_true_id]["closest_gen_id"]

    if matching_gen_id in gen_features["global_object_id"].values:
        # nothing special to do
        matching_gen_ids = [matching_gen_id]
        weights = [1]
    else:
        # the original match does not exist anymore (the matching gen was itself split)
        # -> so find new matches among its extracted sub-trajectories
        split_matching_gen_ids = [
            str(gen_id)
            for gen_id in gen_features.loc[
                gen_features["global_object_id"].str.startswith(
                    matching_gen_id + "-extracted_subtraj_"
                ),
                "global_object_id",
            ].unique()
        ]
        if rematching_strategy == "no":
            if verbose:
                print("Skipping rematching for", true_object_id)
            return None
        elif rematching_strategy == "random":
            # choose one at random
            chosen_gen_id = np.random.choice(split_matching_gen_ids)
            if verbose:
                print("Rematching for", true_object_id, "at random:", chosen_gen_id)
            matching_gen_ids = [chosen_gen_id]
            weights = [1]
        elif rematching_strategy in ("all", "all-wheighted"):
            matching_gen_ids = split_matching_gen_ids
            if verbose:
                print("Rematching for", true_object_id, "at all matches:", matching_gen_ids)
            nb_matches = len(matching_gen_ids)
            if rematching_strategy == "all":
                weights = [1] * nb_matches
            else:
                # no split match at all: nothing to weight (and nothing to divide by),
                # same empty result as without weights or with the "all" strategy
                weights = [1 / nb_matches] * nb_matches if nb_matches > 0 else []
        else:
            raise ValueError(f"Unknown rematching strategy: {rematching_strategy}")

    ids = [(true_object_id, gen_id) for gen_id in matching_gen_ids]

    # compute distances
    l2_dists = []
    cosine_sims = []
    normd_l2s = []
    for gen_id in matching_gen_ids:
        matching_obj_gen_features = gen_features.query("global_object_id == @gen_id")
        _, l2_dist, cosine_sim, normed_l2 = compute_trajectory_comparison_metrics(
            this_obj_true_features[selected_feature].values,
            matching_obj_gen_features[selected_feature].values,
        )
        l2_dists.append(l2_dist)
        cosine_sims.append(cosine_sim)
        normd_l2s.append(normed_l2)

    return (
        l2_dists,
        cosine_sims,
        normd_l2s,
        weights if compute_weights_return_ids else None,
        ids if compute_weights_return_ids else None,
    )

## Compute & plot

In [380]:
# One entry per compared (true, generated) pair, in the order of `kept_true_ids`
all_cells_l2_dists = []
all_cells_cosine_sims = []
all_cells_normd_l2s = []
all_cells_weights = []
all_pairs_ids = []

print("Using rematching strategy:", rematching_strategy)
print("Using feature:", selected_feature)

# Number of kept true (sub)trajectories per original (pre-split) true ID, computed once
# instead of rescanning `kept_true_ids` for every sub-trajectory
nb_kept_trajs_per_orig_true_id: dict[str, int] = {}
for kept_id in kept_true_ids:
    orig_true_id = re.sub(r"-extracted_subtraj_\d+", "", kept_id)
    nb_kept_trajs_per_orig_true_id[orig_true_id] = (
        nb_kept_trajs_per_orig_true_id.get(orig_true_id, 0) + 1
    )

for true_object_id in tqdm(kept_true_ids):
    res = compare_true_to_gen_traj(
        these_plates_true_features,
        true_object_id,
        true_cells_to_gen_cells_mapping,
        gen_features,
        rematching_strategy,
        selected_feature,
        True,
        True,
    )
    if res is not None:
        l2_dists, cosine_sims, normd_l2s, weights, ids = res
        all_pairs_ids.extend(ids)
        all_cells_l2_dists.extend(l2_dists)
        all_cells_cosine_sims.extend(cosine_sims)
        all_cells_normd_l2s.extend(normd_l2s)
        # if the true object is a subtraj, we need to re-weight *again* the matched generated trajectories,
        # otherwise we are only correcting for the duplication of the generated trajectories,
        # and *not* for that of the true ones!
        # Only done for the weighted strategy: the other ones must keep all weights at 1
        # (this is asserted right below).
        if rematching_strategy == "all-wheighted" and "-extracted_subtraj_" in true_object_id:
            nb_true_subtrajs_duplicated = nb_kept_trajs_per_orig_true_id[
                re.sub(r"-extracted_subtraj_\d+", "", true_object_id)
            ]
            # re-normalize weights so that the sum of the weights of all *true* subjtrajs equals one
            weights = [w / nb_true_subtrajs_duplicated for w in weights]
        all_cells_weights.extend(weights)

assert (
    len(all_cells_l2_dists)
    == len(all_cells_cosine_sims)
    == len(all_cells_normd_l2s)
    == len(all_cells_weights)
    == len(all_pairs_ids)
), (
    f"Expected the same number of distances, weights, and pairs, but found {len(all_cells_l2_dists)}, {len(all_cells_cosine_sims)}, {len(all_cells_normd_l2s)}, {len(all_cells_weights)}, and {len(all_pairs_ids)}"
)
if rematching_strategy != "all-wheighted":
    assert all(w == 1 for w in all_cells_weights), (
        f"Expected all weights to be 1, but found {all_cells_weights}"
    )

if rematching_strategy == "all-wheighted":
    # total weight carried by each original true object, accumulated in one pass over the pairs
    weights_sum_per_orig_true_id: dict[str, float] = {}
    for (pair_true_id, _), pair_weight in zip(all_pairs_ids, all_cells_weights):
        orig_true_id = re.sub(r"-extracted_subtraj_\d+", "", pair_true_id)
        weights_sum_per_orig_true_id[orig_true_id] = (
            weights_sum_per_orig_true_id.get(orig_true_id, 0) + pair_weight
        )
    for true_object_id in kept_true_ids:
        weights_sum = weights_sum_per_orig_true_id.get(
            re.sub(r"-extracted_subtraj_\d+", "", true_object_id), 0
        )
        # the weights are sums of fractions such as 1/3 or 1/6: compare with a tolerance,
        # exact float equality with 1 can fail on rounding alone
        assert math.isclose(weights_sum, 1), (
            f"Expected the sum of weights for true object {true_object_id} to be 1, but found {weights_sum}"
        )

In [381]:
def _plot_weighted_summary(
    ax, values, weights, ylabel: str, xlabel: str, nb_digits: int, text_y: float
):
    """
    Draws on `ax` the weighted summary of `values`: a box built from the weighted quartiles,
    the raw points as a swarm, and the weighted mean as a horizontal line.

    Args:
        ax: matplotlib axes to draw on
        values: one metric value per (true, generated) pair
        weights: one weight per value
        ylabel: label of the y axis (name of the metric)
        xlabel: label of the x axis (name of the feature)
        nb_digits: number of digits shown for the mean and the median
        text_y: vertical position (in axes coordinates) of the mean/median text

    Returns:
        The weighted mean and the weighted median.
    """
    mean = np.average(values, weights=weights)
    q1, med, q3 = np.quantile(values, [0.25, 0.5, 0.75], weights=weights, method="inverted_cdf")
    iqr = q3 - q1
    whislo = max(q1 - 1.5 * iqr, np.min(values))
    whishi = min(q3 + 1.5 * iqr, np.max(values))
    # the box is drawn from this 5-number summary rather than from the raw values,
    # presumably so that it reflects the *weighted* quartiles computed above
    sns.boxplot([whislo, q1, med, q3, whishi], showfliers=False, ax=ax)
    sns.swarmplot(values, color=".25", size=2, ax=ax)
    ax.axhline(y=mean, color="orange", linestyle="-")
    ax.text(
        0.05,
        text_y,
        f"Mean: {round(mean, nb_digits)}\nMedian: {round(med, nb_digits)}",
        transform=ax.transAxes,
        verticalalignment="top",
    )
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    return mean, med


fig, axes = plt.subplots(1, 3, figsize=(12, 6))
axes = axes.flatten()

# Plot 1: L2 distance
mean_l2, _ = _plot_weighted_summary(
    axes[0], all_cells_l2_dists, all_cells_weights, "L2 distance", selected_feature, 1, 0.95
)
axes[0].set_ylim(
    0 - 0.05 * np.max(all_cells_l2_dists),
    np.max(all_cells_l2_dists) + 0.05 * np.max(all_cells_l2_dists),
)

# Plot 2: Cosine similarity
# (the mean is weighted like the L2 one and like the quantiles; text at the bottom of the panel)
mean_cos, _ = _plot_weighted_summary(
    axes[1], all_cells_cosine_sims, all_cells_weights, "Cosine similarity", selected_feature, 3, 0.1
)

# Plot 3: Normalized L2 distance, in percent (weighted mean here too)
all_cells_normd_l2s_pct = np.array(all_cells_normd_l2s) * 100
mean_normd, _ = _plot_weighted_summary(
    axes[2],
    all_cells_normd_l2s_pct,
    all_cells_weights,
    "Normalized L2 distance (% of true traj)",
    selected_feature,
    2,
    0.95,
)

plt.suptitle(f"Statistics on {len(all_cells_l2_dists)} trajectories from {len(plate_names)} fields")
plt.tight_layout()
plt.show()

# Show cosine sim and normalized L2 on all cells on all features

## Select features

In [382]:
list(true_features.columns)

In [383]:
# remove "false" features
features_to_remove = [
    "file",
    "Metadata_time",
    "global_object_id",
    "Location_Center_Z",
    "ObjectNumber",
]
features_to_remove += [
    f
    for f in true_features.columns
    if f.startswith("TrackObjects_")
    or f.startswith("Metadata_")
    or f.startswith("Parent")
    or "BoundingBox" in f
    or f == "EulerNumber"
]
features_to_compare = [f for f in true_features.columns if f not in features_to_remove]
features_to_compare

## Rematching strategy

In [384]:
rematching_strategy  # don't change here, weights must be the same as before!

## Compute

In [385]:
def compute_trajectory_comparison_metrics_for_this_feat(selected_feature: str):
    """
    Computes the cosine similarities and normalized L2 distances of `selected_feature`
    over all the kept true trajectories.

    Relies on notebook-level variables: `kept_true_ids`, `these_plates_true_features`,
    `true_cells_to_gen_cells_mapping`, `gen_features` and `rematching_strategy`.

    Returns:
        A dict with the keys "cosine_sim" and "normd_l2", each holding one value per compared
        (true, generated) pair, in the order of `kept_true_ids`.

    TODO: factorize matching beforehand
    TODO: parallelize along objects too
    """
    metrics = {"cosine_sim": [], "normd_l2": []}
    for true_object_id in kept_true_ids:
        comparison = compare_true_to_gen_traj(
            these_plates_true_features,
            true_object_id,
            true_cells_to_gen_cells_mapping,
            gen_features,
            rematching_strategy,
            selected_feature,
        )
        if comparison is None:
            # this true trajectory was skipped by the rematching strategy
            continue
        _, pair_cosine_sims, pair_normd_l2s, _, _ = comparison
        metrics["cosine_sim"].extend(pair_cosine_sims)
        metrics["normd_l2"].extend(pair_normd_l2s)

    return metrics

In [386]:
# Per-feature metrics, filled in as the workers complete (values are `None` until then)
cosine_sims: "dict[str, list]" = dict.fromkeys(features_to_compare)
normd_l2s: "dict[str, list]" = dict.fromkeys(features_to_compare)

print("Using rematching strategy:", rematching_strategy)

# One task per feature.
# NOTE(review): the worker function is defined in the notebook and reads notebook-level
# variables, which presumably relies on the "fork" start method — confirm on the target platform.
with ProcessPoolExecutor() as executor:
    future_to_feature = {
        executor.submit(
            compute_trajectory_comparison_metrics_for_this_feat, feature_name
        ): feature_name
        for feature_name in features_to_compare
    }
    for future in tqdm(as_completed(future_to_feature), total=len(future_to_feature)):
        # dedicated names: reusing `selected_feature` / `res` here would silently overwrite
        # the notebook-level feature chosen for the single-feature analysis
        feature_name = future_to_feature[future]
        feature_metrics = future.result()  # raises any exception
        assert feature_metrics is not None, (
            f"Expected a result for {feature_name}, but got None"
        )
        cosine_sims[feature_name] = feature_metrics["cosine_sim"]
        normd_l2s[feature_name] = feature_metrics["normd_l2"]

## Metrics plots helper funcs

In [387]:
def plot_metrics_histograms(metrics_dict: dict[str, list[float]], weights: list[float]):
    """
    Plots, side by side, the histograms over features of the weighted mean and of the weighted
    median of each feature's metric values.

    Args:
        metrics_dict: maps each feature name to its metric values (one per compared pair)
        weights: one weight per compared pair, shared by all the features

    Each histogram is expressed in percent of the number of features, is restricted to the
    [-1, 1] range, and shows the mean of means (resp. median of medians) as a dashed red line.
    """
    per_feature_values = list(metrics_dict.values())
    means = [np.average(values, weights=weights) for values in per_feature_values]
    medians = [
        np.quantile(values, [0.5], weights=weights, method="inverted_cdf")[0]
        for values in per_feature_values
    ]

    # Create figure with two subplots
    fig, (ax_means, ax_medians) = plt.subplots(1, 2, figsize=(10, 5), dpi=100)

    # (axes, per-feature statistics, aggregate over features, statistic name, legend, color)
    panels = (
        (ax_means, means, np.mean, "Mean", "Mean of means", "blue"),
        (ax_medians, medians, np.median, "Median", "Median of medians", "green"),
    )
    for ax, stats, aggregate, stat_name, legend_label, color in panels:
        # each feature counts for 100 / nb_features percent
        counts, _, _ = ax.hist(
            stats,
            bins=100,
            range=(-1, 1),
            weights=np.ones_like(stats) / len(stats) * 100,
            alpha=0.75,
            color=color,
            edgecolor="black",
        )
        ax.set_title(f"Histogram of {stat_name} Cosine Similarities")
        ax.set_xlabel(f"{stat_name} Cosine Similarity")
        ax.set_ylabel("Number of Features (%)")
        ax.grid(alpha=0.3)

        # mark the aggregate over all the features
        center = aggregate(stats)
        ax.vlines(
            x=center,
            ymin=0,
            ymax=counts.max(),
            color="red",
            linestyle="--",
            label=legend_label,
        )
        ax.text(
            x=center - 0.2,
            y=counts.max() * 0.9,
            s=f"{center:.2f}",
            color="red",
            fontsize=10,
        )
        ax.legend()

    plt.tight_layout()
    plt.show()


def plot_metrics_boxplots(features_names: list[str], metrics: dict[str, list[float]]):
    """Draw one box plot per feature, with the individual values overlaid.

    Panels are laid out on a grid with 10 columns, in the order given by
    `features_names`, and all use a fixed y range of [-1, 1]. The figure title
    is fixed to "Feature Cosine Similarities ordered by mean". A feature whose
    plotting raises is reported on stdout and its panel is hidden.

    Args:
        features_names: Names of the features to plot, in display order.
        metrics: Maps each feature name to its list of metric values.
    """
    # Calculate the grid size for subplots
    num_features = len(features_names)
    num_cols = 10  # Number of columns in the subplot grid
    num_rows = math.ceil(num_features / num_cols)  # Rows needed to fit every feature
    print("num features", num_features)
    print("num rows", num_rows)
    print("num cols", num_cols)

    if num_features == 0:
        # A grid with zero rows cannot be created, and there is nothing to draw.
        print("No features to plot")
        return

    # squeeze=False makes plt.subplots return a 2D array whatever the grid shape,
    # so a plain flatten() is always correct. With the default squeeze, a single
    # row comes back as a 1D array of `num_cols` axes, which must not be wrapped
    # in a list (that broke the single-feature case).
    _, axes = plt.subplots(
        num_rows, num_cols, figsize=(20, 5 * num_rows), dpi=150, squeeze=False
    )
    axes = axes.flatten()

    # Create a boxplot and a stripplot for each feature
    for i, feature in enumerate(tqdm(features_names)):
        try:
            sns.boxplot(y=metrics[feature], ax=axes[i], showfliers=False)
            sns.stripplot(y=metrics[feature], ax=axes[i], color=".25", size=3)
            axes[i].set_ylabel(feature)
            axes[i].set_ylim(-1, 1)
        except Exception as e:
            # Best effort: report the failing feature and keep plotting the others
            print(f"Error processing feature {feature}: {e}")
            axes[i].set_visible(False)

    # Hide any unused subplots
    for unused_ax in axes[num_features:]:
        unused_ax.set_visible(False)

    plt.suptitle("Feature Cosine Similarities ordered by mean", fontsize=16)
    plt.tight_layout()
    plt.subplots_adjust(top=0.96)  # Adjust to make room for suptitle
    plt.show()

## Cosine sims

### Throw NaN cosine features

In [388]:
# Collect the features whose mean cosine similarity is NaN, printing the
# mean ± std of every feature along the way.
nan_features = []

for k, v in cosine_sims.items():
    print(f"    {k}: {np.mean(v):.3f} ± {np.std(v):.3f}")
    if not np.isnan(np.mean(v)):
        continue
    print(f"WARNING!!! => feature {k} has NaN mean")
    nan_features.append(k)

print("NaN features:", nan_features)
nan_features

In [389]:
# Keep, in their original order, the features that were not flagged as NaN
non_nan_features = []
for k in cosine_sims.keys():
    if k not in nan_features:
        non_nan_features.append(k)
print(f"Kept {len(non_nan_features)} non-nan features out of {len(cosine_sims)}")

# Restrict the cosine similarities to those features
non_nan_cosine_sims = {k: cosine_sims[k] for k in non_nan_features}

### Histograms

In [390]:
# Distribution over features of the weighted mean / median cosine similarity
plot_metrics_histograms(metrics_dict=non_nan_cosine_sims, weights=all_cells_weights)

### All features

In [391]:
# Rank the retained features from highest to lowest (unweighted) mean cosine similarity
features_order_by_mean_cossim = list(non_nan_cosine_sims)
features_order_by_mean_cossim.sort(key=lambda f: np.mean(cosine_sims[f]), reverse=True)
features_order_by_mean_cossim

In [392]:
# One box plot per retained feature, in decreasing order of mean cosine similarity
plot_metrics_boxplots(
    features_names=features_order_by_mean_cossim, metrics=non_nan_cosine_sims
)

## Normalized L2 features

### Throw NaN normalized L2 features

In [393]:
# Collect the features whose mean normalized L2 is NaN, printing the
# mean ± std of every feature along the way.
nan_features = []

for k, v in normd_l2s.items():
    print(f"    {k}: {np.mean(v):.3f} ± {np.std(v):.3f}")
    if not np.isnan(np.mean(v)):
        continue
    print(f"WARNING!!! => feature {k} has NaN mean")
    nan_features.append(k)

print("NaN features:", nan_features)
nan_features

In [394]:
# Keep, in their original order, the features that were not flagged as NaN
non_nan_features = []
for k in normd_l2s.keys():
    if k not in nan_features:
        non_nan_features.append(k)
print(f"Kept {len(non_nan_features)} non-nan features out of {len(normd_l2s)}")

# Restrict the normalized L2 metrics to those features
non_nan_normd_l2s = {k: normd_l2s[k] for k in non_nan_features}

### Normalized L2 Histograms

In [395]:
# Per-feature (unweighted) mean and median of the normalized L2 metric
means = [np.mean(v) for v in non_nan_normd_l2s.values()]
medians = [np.median(v) for v in non_nan_normd_l2s.values()]

# Every feature contributes 100/N so that bar heights are percentages of features
means_weights = np.ones_like(means) / len(means) * 100
medians_weights = np.ones_like(medians) / len(medians) * 100

# One panel for the means (left) and one for the medians (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), dpi=100)

panel_specs = [
    (
        ax1,
        means,
        means_weights,
        np.mean,
        "blue",
        "Histogram of Mean Normalized L2 Norm",
        "Mean Normalized L2 Norm (% of true traj)",
        "Mean of means",
    ),
    (
        ax2,
        medians,
        medians_weights,
        np.median,
        "green",
        "Histogram of Median Normalized L2 Norm",
        "Median Normalized L2 Norm (% of true traj)",
        "Median of medians",
    ),
]

for (
    panel_ax,
    panel_values,
    panel_weights,
    panel_center_fn,
    panel_color,
    panel_title,
    panel_xlabel,
    panel_center_label,
) in panel_specs:
    # Values are displayed in percent; the x axis stops at the largest value
    hist = panel_ax.hist(
        np.array(panel_values) * 100,
        bins=100,
        range=(0, np.max(panel_values) * 100),
        weights=panel_weights,
        alpha=0.75,
        color=panel_color,
        edgecolor="black",
    )
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel(panel_xlabel)
    panel_ax.set_ylabel("Number of Features (%)")
    panel_ax.grid(alpha=0.3)

    # Red dashed line at the mean of means / median of medians, with its value
    panel_center_pct = panel_center_fn(panel_values) * 100
    panel_peak = hist[0].max()
    panel_ax.vlines(
        x=panel_center_pct,
        ymin=0,
        ymax=panel_peak,
        color="red",
        linestyle="--",
        label=panel_center_label,
    )
    panel_ax.text(
        x=panel_center_pct - 2,
        y=panel_peak * 0.9,
        s=f"{panel_center_pct:.2f}",
        color="red",
        fontsize=10,
    )
    panel_ax.legend()

plt.tight_layout()
plt.show()

Redo the histogram of the means, restricting the x axis first to the 0–100% range and then to the 0–1000% range.

In [396]:
# Per-feature (unweighted) mean and median of the normalized L2 metric
means = [np.mean(v) for v in non_nan_normd_l2s.values()]
medians = [np.median(v) for v in non_nan_normd_l2s.values()]

# Every feature contributes 100/N so that bar heights are percentages of features
means_weights = np.ones_like(means) / len(means) * 100
medians_weights = np.ones_like(medians) / len(medians) * 100

# Same histogram of the means twice, over two fixed x ranges (in percent)
for x_upper_bound in (100, 1000):
    fig = plt.figure(figsize=(10, 5), dpi=100)
    plt.hist(
        np.array(means) * 100,
        bins=100,
        range=(0, x_upper_bound),
        weights=means_weights,
        alpha=0.75,
        color="blue",
        edgecolor="black",
    )
    plt.title("Histogram of Mean Normalized L2 Norm")
    plt.xlabel("Mean Normalized L2 Norm (% of true traj)")
    plt.ylabel("Number of Features (%)")
    plt.grid(alpha=0.3)
    plt.show()

In [397]:
import numpy as np
import plotly.express as px

# Per-feature mean normalized L2, scaled by 100 to be displayed in percent
means_data = np.array(means) * 100

# Interactive cumulative histogram of those means; the x axis covers 0 to 500
histogram_options = dict(
    nbins=5000,
    range_x=[0, 500],
    cumulative=True,
    histnorm="percent",
    title="Cumulative Histogram of Mean Normalized L2 Norm",
    opacity=0.5,
)
fig1 = px.histogram(means_data, **histogram_options)

layout_options = dict(
    xaxis_title="Mean Normalized L2 Norm (% of true traj)",
    yaxis_title="Cumulative Percentage of Features",
    height=600,
    width=1000,
    template="plotly_white",
)
fig1.update_layout(**layout_options)

fig1.show()

In [398]:
# Rank the retained features from lowest to highest (unweighted) mean normalized L2
features_order_by_mean_normd_l2 = list(non_nan_normd_l2s)
features_order_by_mean_normd_l2.sort(key=lambda f: np.mean(normd_l2s[f]))
features_order_by_mean_normd_l2

### All features

In [399]:
# Calculate the grid size for subplots
num_features = len(features_order_by_mean_normd_l2)
num_cols = 10  # Number of columns in the subplot grid
num_rows = math.ceil(num_features / num_cols)  # Rows needed to fit every feature
print("num features", num_features)
print("num rows", num_rows)
print("num cols", num_cols)

# squeeze=False makes plt.subplots return a 2D array whatever the grid shape, so
# a plain flatten() is always correct. With the default squeeze, a single row
# comes back as a 1D array of `num_cols` axes, which must not be wrapped in a
# list (that broke the single-feature case).
fig, axes = plt.subplots(
    num_rows, num_cols, figsize=(20, 5 * num_rows), dpi=200, squeeze=False
)
axes = axes.flatten()

# Create a boxplot and a stripplot for each feature, with values in percent
for i, feature in enumerate(tqdm(features_order_by_mean_normd_l2)):
    try:
        feature_values_pct = np.array(normd_l2s[feature]) * 100
        sns.boxplot(y=feature_values_pct, ax=axes[i], showfliers=False)
        sns.stripplot(y=feature_values_pct, ax=axes[i], color=".25", size=3)
        axes[i].set_ylabel(feature + " (%)")
        # Each panel gets its own y range, from 0 up to the feature's largest value
        axes[i].set_ylim(0, np.max(normd_l2s[feature]) * 100)
    except Exception as e:
        # Best effort: report the failing feature and keep plotting the others
        print(f"Error processing feature {feature}: {e}")
        axes[i].set_visible(False)

# Hide any unused subplots
for j in range(num_features, len(axes)):
    axes[j].set_visible(False)

plt.suptitle(
    f"Feature Normalized L2 Metrics (% of true traj) ordered by mean - {plate_name}", fontsize=16
)
plt.tight_layout()
plt.subplots_adjust(top=0.96)  # Adjust to make room for suptitle
plt.show()

# Features correlation

In [400]:
# Display the `features_to_compare` selection of `these_plates_true_features`,
# i.e. the same selection the correlation matrix of the next section is computed from.
these_plates_true_features[features_to_compare]

## Correlation matrix

In [401]:
# Compute the correlation matrix
correlation_matrix = these_plates_true_features[features_to_compare].corr()
print(f"Computed correlation matrix of shape {correlation_matrix.shape}")

# Drop the columns, then the rows, that contain only NaN values
correlation_matrix = correlation_matrix.dropna(axis=1, how="all")
correlation_matrix = correlation_matrix.dropna(axis=0, how="all")
print(f"After dropping NaN values, correlation matrix shape is {correlation_matrix.shape}")

# Create a clustered heatmap visualization
plt.figure(figsize=(25, 20))
# Use hierarchical clustering (Ward linkage on the rows of the correlation
# matrix) to place correlated features next to each other
linkage = scipy.cluster.hierarchy.linkage(correlation_matrix, method="ward")
dendro = scipy.cluster.hierarchy.dendrogram(
    linkage, labels=correlation_matrix.columns, no_plot=True
)
reordered_idx = dendro["leaves"]
reordered_corr = correlation_matrix.iloc[reordered_idx, reordered_idx]

# Find highly correlated pairs: every (i, j) with j < i, i.e. the strict lower
# triangle, visited row by row. Values are read from a plain NumPy array, which
# avoids one slow `.iloc` lookup per matrix entry.
corr_threshold = 0.95  # Correlation threshold
corr_values = correlation_matrix.to_numpy()
corr_labels = correlation_matrix.columns
nb_corr_features = len(corr_labels)
high_corr_pairs = [
    (corr_labels[i], corr_labels[j], corr_values[i, j])
    for i in range(nb_corr_features)
    for j in range(i)
    if abs(corr_values[i, j]) > corr_threshold
]

# Sort by absolute correlation value
high_corr_pairs.sort(key=lambda x: abs(x[2]), reverse=True)

# The share is relative to the number of distinct feature pairs, n * (n - 1) / 2
# (dividing by the number of features, as before, does not give a percentage of pairs)
nb_total_pairs = nb_corr_features * (nb_corr_features - 1) // 2
print(
    f"Found {len(high_corr_pairs)} feature pairs with absolute correlation > {corr_threshold} ({len(high_corr_pairs) / max(nb_total_pairs, 1) * 100:.2f}% of total)"
)
for feature1, feature2, corr in high_corr_pairs:
    print(f"    {feature1} ~ {feature2}: {corr:.3f}")

# Plot the heatmap
mask = np.triu(np.ones_like(reordered_corr))  # Keep only lower triangle
cmap = sns.diverging_palette(230, 20, as_cmap=True)
sns.heatmap(
    reordered_corr,
    mask=mask,
    cmap=cmap,
    vmax=1,
    vmin=-1,
    center=0,
    square=True,
    linewidths=0.005,
    cbar_kws={"shrink": 0.8},
)
# Make colorbar tick labels bigger
cbar = plt.gcf().axes[-1]  # Get the colorbar axes
cbar.tick_params(labelsize=14)  # Set tick label size

# Every feature that appears in at least one highly correlated pair
highly_correlated_features = {
    feature for feature1, feature2, _ in high_corr_pairs for feature in (feature1, feature2)
}

# Get the tick labels
ax = plt.gca()
xticklabels = ax.get_xticklabels()
yticklabels = ax.get_yticklabels()

# Make labels bold for highly correlated features, on both axes
for label in [*xticklabels, *yticklabels]:
    if label.get_text() in highly_correlated_features:
        label.set_fontweight("bold")

# Apply changes
plt.draw()

plt.title("Feature Correlation Matrix (Clustered)", fontsize=16)
plt.tight_layout()
plt.show()

## Plot histograms of uncorrelated features

In [402]:
# Start from every feature with a defined cosine similarity, then break each
# highly correlated pair by dropping one of its two members at random.
# NOTE(review): the result depends on the state of the global `random` generator;
# seed it upstream if this selection must be reproducible.
kept_indep_features = set(non_nan_cosine_sims.keys())

for f1, f2, _ in high_corr_pairs:
    # Only pairs whose two members are both still kept need to be broken
    if f1 in kept_indep_features and f2 in kept_indep_features:
        f_rm = random.choice([f1, f2])
        kept_indep_features.remove(f_rm)
        print(f"Keeping {f1 if f_rm == f2 else f2} and removing {f_rm} from kept features")

print(
    f"Kept {len(kept_indep_features)} independent features out of {len(features_to_compare)} ({len(kept_indep_features) / len(features_to_compare) * 100:.0f}%)"
)

# Sanity check: no remaining pair of kept features exceeds the correlation threshold
tmp_check = []
kept_indep_features_list = list(kept_indep_features)
tmp_filtered_corr_mat = correlation_matrix.loc[
    kept_indep_features_list, kept_indep_features_list
]
nb_kept_columns = len(tmp_filtered_corr_mat.columns)
for i in range(nb_kept_columns):
    for j in range(i):
        assert abs(tmp_filtered_corr_mat.iloc[i, j]) <= corr_threshold, (
            f"{tmp_filtered_corr_mat.iloc[i, j]} > {corr_threshold} for {tmp_filtered_corr_mat.columns[i]} and {tmp_filtered_corr_mat.columns[j]}"
        )

In [403]:
# Same histograms as for all features, restricted to the kept (decorrelated) ones
plot_metrics_histograms(
    metrics_dict={
        k: non_nan_cosine_sims[k] for k in non_nan_cosine_sims if k in kept_indep_features
    },
    weights=all_cells_weights,
)

# Simple vs non-simple objects

First show some stats

In [404]:
# Count the kept true cells by kind: an id containing "-extracted_subtraj_"
# is counted as non-simple, any other id as simple.
nb_nonsimple_cells = 0
nb_simple_cells = 0

for true_id in kept_true_ids:
    if "-extracted_subtraj_" not in true_id:
        nb_simple_cells += 1
        continue

    # The mapping is keyed by the first two "-"-separated parts of the id;
    # check that it also records this object as non-simple
    assert (
        "not_simple_true_object"
        in true_cells_to_gen_cells_mapping["-".join(true_id.split("-")[:2])]["object_type"]
    )
    nb_nonsimple_cells += 1

print(
    f"{nb_simple_cells} simple true cells ({nb_simple_cells / len(kept_true_ids) * 100:.1f}%) and {nb_nonsimple_cells} nonsimple true cells ({nb_nonsimple_cells / len(kept_true_ids) * 100:.1f}%)"
)

## Separate simple and non-simple objects

In [405]:
# Positions, within `all_pairs_ids`, of the two kinds of (true, generated) pairs
simple_pairs_idxes = []  # neither id contains "-extracted_subtraj_"
nonsimple_pairs_idxes = []  # at least one id contains "-extracted_subtraj_"

for idx, (true_id, gen_id) in enumerate(all_pairs_ids):
    has_subtraj = "-extracted_subtraj_" in true_id or "-extracted_subtraj_" in gen_id
    (nonsimple_pairs_idxes if has_subtraj else simple_pairs_idxes).append(idx)

print(
    f"{len(simple_pairs_idxes)} simple pairs ({len(simple_pairs_idxes) / len(all_pairs_ids) * 100:.1f}%) and {len(nonsimple_pairs_idxes)} nonsimple pairs ({len(nonsimple_pairs_idxes) / len(all_pairs_ids) * 100:.1f}%)"
)
print(
    f"In original data before duplication *and filtering*, we had {object_type_counts['not_simple_true_object']} nonsimple cells (out of: we didn't keep track!)"
)

## Plot histograms 

In [406]:
# Split the cosine similarities of every feature into simple / non-simple pairs
non_nan_cosine_sims_simple_cells = {}
non_nan_cosine_sims_nonsimple_cells = {}
for k, v in non_nan_cosine_sims.items():
    non_nan_cosine_sims_simple_cells[k] = np.array(v)[simple_pairs_idxes]
    non_nan_cosine_sims_nonsimple_cells[k] = np.array(v)[nonsimple_pairs_idxes]

# Both groups are plotted with a uniform weight of 1 per pair
print("Non simple pairs:")
plot_metrics_histograms(
    metrics_dict=non_nan_cosine_sims_nonsimple_cells,
    weights=[1] * len(nonsimple_pairs_idxes),
)
print("Simple pairs:")
plot_metrics_histograms(
    metrics_dict=non_nan_cosine_sims_simple_cells,
    weights=[1] * len(simple_pairs_idxes),
)

## Plot histograms of uncorrelated features

In [407]:
# Same split into simple / non-simple pairs, restricted to the kept
# (decorrelated) features
non_nan_cosine_sims_simple_cells_uncorr_feats = {}
non_nan_cosine_sims_nonsimple_cells_uncorr_feats = {}
for k, v in non_nan_cosine_sims.items():
    if k not in kept_indep_features:
        continue
    non_nan_cosine_sims_simple_cells_uncorr_feats[k] = np.array(v)[simple_pairs_idxes]
    non_nan_cosine_sims_nonsimple_cells_uncorr_feats[k] = np.array(v)[nonsimple_pairs_idxes]

# Both groups are plotted with a uniform weight of 1 per pair
print("Non simple pairs:")
plot_metrics_histograms(
    metrics_dict=non_nan_cosine_sims_nonsimple_cells_uncorr_feats,
    weights=[1] * len(nonsimple_pairs_idxes),
)
print("Simple pairs:")
plot_metrics_histograms(
    metrics_dict=non_nan_cosine_sims_simple_cells_uncorr_feats,
    weights=[1] * len(simple_pairs_idxes),
)