In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib widget

from pathlib import Path

import arviz as az
import joblib
import numpy as np
import pandas as pd
import pymc as pm
from matplotlib import pyplot as plt
from sklearn.metrics import (
    adjusted_mutual_info_score,
    adjusted_rand_score,
)
from sklearn.preprocessing import StandardScaler
from sqlalchemy import create_engine

from ppcluster import logger, mcmc
from ppcluster.cvat import (
    filter_dataframe_by_polygons,
    read_polygons_from_cvat,
)
from ppcluster.griddata import create_2d_grid, map_grid_to_points
from ppcluster.mcmc.postproc import (
    aggregate_multiscale_clustering,
    remove_small_grid_components,
)
from ppcluster.preprocessing import (
    apply_2d_gaussian_filter,
    apply_dic_filters,
    preprocess_velocity_features,
    spatial_subsample,
)
from ppcluster.utils.config import ConfigManager
from ppcluster.utils.database import (
    get_dic_analysis_by_ids,
    get_dic_analysis_ids,
    get_image,
    get_multi_dic_data,
)

# Global random seed; also passed as random_seed to the MCMC runs below.
RANDOM_SEED = 8927
# NOTE(review): `rng` is not used in the visible cells of this notebook.
rng = np.random.default_rng(RANDOM_SEED)

# Load project configuration and open a SQLAlchemy engine on its database URL
config = ConfigManager()
db_engine = create_engine(config.db_url)


SAVE_OUTPUTS = True  # Set to True to save inference results
# NOTE(review): LOAD_EXISTING is not referenced in the visible cells.
LOAD_EXISTING = False  # Set to False to run sampling again

# MCMC parameters
DRAWS = 2000  # Number of MCMC draws
TUNE = 1000  # Number of tuning steps
CHAINS = 4  # Number of MCMC chains
CORES = 4  # Number of CPU cores to use
TARGET_ACCEPT = 0.9  # Target acceptance rate for NUTS sampler

# Data selection parameters.
# Either a single `reference_date` or both `reference_start_date` and
# `reference_end_date` must be set (validated when the data are loaded).
camera_name = "PPCX_Tele"
reference_date = None  # e.g. "2024-08-02"
reference_start_date = "2024-08-02"  # start of the reference-date range
reference_end_date = "2024-08-02"  # end of the reference-date range
dt_min = 72  # Minimum time difference between images in hours
dt_max = 96  # Maximum time difference between images in hours
# Alternative, wider time-difference window:
# dt_min = 24  # 72  # Minimum time difference between images in hours
# dt_max = 200  # 96  # Maximum time difference between images in hours

SUBSAMPLE_FACTOR = 1  # n_subsample for spatial_subsample; only applied when > 1
SUBSAMPLE_METHOD = "random"  # spatial_subsample strategy, e.g. 'random' or 'stratified'

# Parameters set manually here
variables_names = ["V"]  # Feature columns used for clustering
# Keyword arguments forwarded to apply_dic_filters for every DIC dataframe
filter_kwargs = dict(
    filter_outliers=False,
    tails_percentile=0.005,
    min_velocity=0.0,
    apply_2d_median=False,
    median_window_size=5,
    median_threshold_factor=3,
    apply_2d_gaussian=False,
    gaussian_sigma=1.0,
)

# == PRIORS and ROI ==
# Per-sector prior probabilities over the clusters.
# Keys are sector names and must match the polygon labels in SECTOR_PRIOR_FILE;
# each value is [P(Cluster A), P(Cluster B), P(Cluster C), ...] for that sector.
# NOTE(review): all lists should have the same length, which presumably sets
# the number of mixture components — confirm in mcmc.assign_spatial_priors.
# Example with three clusters:
# PRIOR_PROBABILITY = {
#     "A": [1.0, 0.0, 0.0],
#     "B": [0.1, 0.7, 0.2],
#     "C": [0.0, 0.2, 0.8],
# }
SECTOR_PRIOR_FILE = Path("data/priors_4_sectors.xml")
PRIOR_PROBABILITY = {
    "A": [1.0, 0.0, 0.0, 0.0],
    "B": [0.1, 0.7, 0.2, 0.0],
    "C": [0.0, 0.2, 0.8, 0.0],
    "D": [0.0, 0.0, 0.3, 0.7],
}

# CVAT XML with the region-of-interest polygon(s) used to filter the DIC points
roi_path = Path("data/roi.xml")


In [3]:
# Read the region of interest and the sector polygons used for spatial priors
roi = read_polygons_from_cvat(roi_path, image_name=None)
sectors = read_polygons_from_cvat(SECTOR_PRIOR_FILE, image_name=None)

# Check that at least the reference date or an interval of dates is provided
if not (reference_date or (reference_start_date and reference_end_date)):
    raise ValueError(
        "Either reference_date or both reference_start_date and reference_end_date must be provided."
    )

# Fetch DIC ids matching camera, reference date(s) and time-difference window
dic_ids = get_dic_analysis_ids(
    db_engine,
    camera_name=camera_name,
    reference_date=reference_date,
    reference_date_start=reference_start_date,
    reference_date_end=reference_end_date,
    time_difference_min=dt_min,
    time_difference_max=dt_max,
)
if len(dic_ids) < 1:
    raise ValueError("No DIC analyses found for the given criteria")

# Get DIC analysis metadata
dic_analyses = get_dic_analysis_by_ids(db_engine=db_engine, dic_ids=dic_ids)
logger.info("Fetched DIC analysis:")
for _, row in dic_analyses.iterrows():
    print(
        f"DIC ID: {row['dic_id']}, date: {row['reference_date']}, dt (hrs): {row['dt_hours']}, Master: {row['master_timestamp']}, Slave: {row['slave_timestamp']}"
    )
print("Summary of selected DIC analyses:")
print(dic_analyses.describe())


# Output paths. Name them after the full period spanned by ALL selected
# analyses (earliest master image -> latest slave image). Using only the first
# row mislabels runs that stack several DIC analyses; for a single analysis the
# result is unchanged.
date_start = dic_analyses["master_timestamp"].min().strftime("%Y-%m-%d")
date_end = dic_analyses["slave_timestamp"].max().strftime("%Y-%m-%d")
output_dir = Path("output") / f"{camera_name}_{date_end}_mcmc_multiscale"
output_dir.mkdir(parents=True, exist_ok=True)
base_name = f"{date_start}_{date_end}"

# Master image of the first selected analysis, used as plot background
master_image_id = dic_analyses["master_image_id"].iloc[0]
img = get_image(image_id=master_image_id, config=config)

# Fetch DIC data (one dataframe per analysis id)
out = get_multi_dic_data(
    dic_ids,
    stack_results=False,
    config=config,
)
logger.info(f"Found stack of {len(out)} DIC dataframes. Run filtering...")

# Filter each dataframe, then stack them. Filtering is best-effort: a failing
# dataframe is logged and skipped, but at least one must survive.
processed = []
for src_id, df_src in out.items():
    try:
        # Keep only points inside the region-of-interest polygons
        df_src = filter_dataframe_by_polygons(df_src, polygons=roi)

        # Apply the configured DIC filters
        df_src = apply_dic_filters(df_src, **filter_kwargs)

        processed.append(df_src)
    except Exception as exc:
        logger.warning("Filtering failed for %s: %s", src_id, exc)
if not processed:
    raise RuntimeError("No dataframes left after filtering.")
# Stack all processed dataframes
df = pd.concat(processed, ignore_index=True)
logger.info("Data shape after filtering and stacking: %s", df.shape)

# Apply subsampling (only when a factor > 1 is configured)
if SUBSAMPLE_FACTOR > 1:
    df_subsampled = spatial_subsample(
        df, n_subsample=SUBSAMPLE_FACTOR, method=SUBSAMPLE_METHOD
    )
    df = df_subsampled
    logger.info(f"Data shape after subsampling: {df.shape}")


2025-10-01 18:37:43 | [INFO    ] Found 1 DIC analyses matching criteria
2025-10-01 18:37:43 | [INFO    ] Fetched DIC analysis:


DIC ID: 1915, date: 2024-08-02, dt (hrs): 72, Master: 2024-07-30 05:00:18+00:00, Slave: 2024-08-02 05:00:18+00:00
Summary of selected DIC analyses:
       dic_id  master_image_id  slave_image_id  dt_hours
count     1.0              1.0             1.0       1.0
mean   1915.0          34633.0         34670.0      72.0
std       NaN              NaN             NaN       NaN
min    1915.0          34633.0         34670.0      72.0
25%    1915.0          34633.0         34670.0      72.0
50%    1915.0          34633.0         34670.0      72.0
75%    1915.0          34633.0         34670.0      72.0
max    1915.0          34633.0         34670.0      72.0


2025-10-01 18:37:43 | [INFO    ] Fetched DIC data for id 1915 with 3927 points
2025-10-01 18:37:43 | [INFO    ] Found stack of 1 DIC dataframes. Run filtering...
2025-10-01 18:37:43 | [INFO    ] Starting DIC filtering pipeline with 2659 points
2025-10-01 18:37:43 | [INFO    ] Min velocity filtering: 2659 -> 2659 points (removed 0 points below 0.0)
2025-10-01 18:37:43 | [INFO    ] DIC filtering pipeline completed: 2659 -> 2659 points (removed 0 total)
2025-10-01 18:37:43 | [INFO    ] Data shape after filtering and stacking: (2659, 5)


## Run MCMC at several smoothing scales and aggregate the clustering results


In [None]:
# Spatial priors: one probability vector per DIC point, derived from the
# sector polygons and PRIOR_PROBABILITY (method="exponential").
prior_probs = mcmc.assign_spatial_priors(
    x=df["x"].to_numpy(),
    y=df["y"].to_numpy(),
    polygons=sectors,
    prior_probs=PRIOR_PROBABILITY,
    method="exponential",
    method_kws={"decay_rate": 0.001},
)

# Shared save options for the quick-look figures below
quicklook_save_kw = dict(dpi=150, bbox_inches="tight")

# Quick-look 1: spatial priors over the master image
fig, axes = mcmc.plot_spatial_priors(df, prior_probs, img=img)
fig.savefig(output_dir / f"{base_name}_spatial_priors.jpg", **quicklook_save_kw)
plt.close(fig)

# Quick-look 2: velocity magnitude field
fig, ax = plt.subplots(figsize=(6, 6))
mcmc.plot_velocity_magnitude(
    df["x"].to_numpy(), df["y"].to_numpy(), df["V"].to_numpy(), img=img, ax=ax
)
fig.savefig(output_dir / f"{base_name}_velocity_field.jpg", **quicklook_save_kw)
plt.close(fig)

In [None]:
# Helper: run the complete MCMC clustering pipeline (feature scaling, model
# building, sampling, optional MRF prior refinement, plots and metadata) on one
# input dataframe.


def run_mcmc_clustering(
    df_input,
    prior_probs,
    sectors,
    output_dir,
    base_name,
    img=None,
    variables_names=None,
    transform_velocity="none",
    transform_params=None,
    mu_params=None,
    sigma_params=None,
    feature_weights=None,
    sample_args=None,
    mrf_regularization: bool = False,
    mrf_kwargs: dict | None = None,
    second_pass: str = "full",  # "skip" | "short" | "full"
    second_pass_sample_args: dict | None = None,
    random_seed=8927,
):
    """
    Run MCMC-based clustering on velocity data with flexible velocity transformations.

    Parameters:
    -----------
    df_input : pandas.DataFrame
        Input dataframe. Must contain a 'V' column; 'x'/'y' are used for MRF
        regularization and plotting, and any extra columns named in
        ``variables_names`` are used as additional features.
    prior_probs : array-like
        Per-point prior class probabilities (e.g. from mcmc.assign_spatial_priors).
    sectors : polygons
        Sector polygons, forwarded to the model builder and the run metadata.
    output_dir : str or pathlib.Path
        Directory where the scaler, figures and metadata are written.
    base_name : str
        Filename prefix for all outputs of this run.
    img : array-like, optional
        Background image for plots.
    variables_names : list of str, optional
        Feature columns (default ``["V"]``). The (transformed) velocities are
        always the first feature; other listed columns are appended.
    transform_velocity : str, default="none"
        Type of velocity transformation: "power", "exponential", "threshold", "sigmoid", or "none"
    transform_params : dict, optional
        Parameters for velocity transformation (see preprocess_velocity_features for details)
    mu_params, sigma_params : dict, optional
        Prior hyper-parameters for the component means / standard deviations.
    feature_weights : optional
        Forwarded to mcmc.build_marginalized_mixture_model.
    sample_args : dict, optional
        Sampler keyword arguments for the first pass (default: 2000 draws,
        1000 tune, 4 chains, 4 cores, target_accept=0.95, random_seed).
    mrf_regularization : bool, default False
        Refine the priors with mcmc.run_mrf_regularization after the first pass.
    mrf_kwargs : dict, optional
        Overrides for the MRF defaults (n_neighbors=8, length_scale=50,
        beta=2.0, n_iter=5).
    second_pass : {"skip", "short", "full"}, default "full"
        "skip": no re-sampling (with MRF the MRF posterior is used, otherwise
        the first-pass posterior). "short": 600 draws / 400 tune / 2 chains /
        2 cores, warm-started. "full": re-sample with ``sample_args``,
        warm-started.
    second_pass_sample_args : dict, optional
        Overrides applied on top of the second-pass sampler arguments.
    random_seed : int, default 8927
        Seed used by the default ``sample_args``.

    Returns:
    --------
    dict
        Keys "metadata", "idata", "scaler", "convergence_flag",
        "posterior_probs", "cluster_pred", "uncertainty".

    Raises:
    -------
    ValueError
        If ``df_input`` lacks a 'V' column or ``second_pass`` is not one of
        "skip", "short", "full".
    """
    import inspect

    # --- helper: warm-start values from the posterior means of mu/sigma ---
    def _initvals_from_idata(idata_in, n_chains):
        """Return one init dict per chain, or a single dict if n_chains is unknown."""
        mu_mean = idata_in.posterior["mu"].mean(dim=["chain", "draw"]).values
        sigma_mean = idata_in.posterior["sigma"].mean(dim=["chain", "draw"]).values
        init = {"mu": mu_mean, "sigma": sigma_mean}
        if n_chains is None:
            # A single dict is applied to every chain, so it cannot mismatch
            # the sampler's default chain count.
            return init
        return [init for _ in range(n_chains)]

    def _accepts_initvals(func):
        """True if ``func`` can be called with an ``initvals`` keyword."""
        try:
            params = inspect.signature(func).parameters.values()
        except (TypeError, ValueError):
            # Not introspectable: assume the keyword is accepted.
            return True
        return any(
            p.name == "initvals" or p.kind is inspect.Parameter.VAR_KEYWORD
            for p in params
        )

    def _rhat_converged(idata_in, threshold=1.01):
        """R-hat based convergence flag for runs that bypass mcmc.sample_model.

        NOTE(review): the 1.01 threshold may differ from the criterion used
        inside mcmc.sample_model — confirm.
        """
        diag = az.summary(idata_in, var_names=["mu", "sigma"], kind="diagnostics")
        return bool((diag["r_hat"] < threshold).all())

    second_pass = str(second_pass).lower()
    if second_pass not in {"skip", "short", "full"}:
        raise ValueError(
            f"second_pass must be 'skip', 'short' or 'full', got {second_pass!r}"
        )
    output_dir = Path(output_dir)

    logger.info(f"Running MCMC clustering for {base_name}...")

    # Default parameters if not provided
    if mu_params is None:
        mu_params = {"mu": 0, "sigma": 1}
    if sigma_params is None:
        sigma_params = {"sigma": 1}
    if sample_args is None:
        sample_args = dict(
            target_accept=0.95,
            draws=2000,
            tune=1000,
            chains=4,
            cores=4,
            random_seed=random_seed,
        )
    if variables_names is None:
        variables_names = ["V"]

    if "V" not in df_input.columns:
        raise ValueError("Input dataframe must contain 'V' column for velocities.")

    # Preprocess velocity features to enhance high velocities
    velocities, transform_info = preprocess_velocity_features(
        velocities=df_input["V"].to_numpy(),
        velocity_transform=transform_velocity,
        velocity_params=transform_params,
    )

    # Extract data array for clustering
    if len(variables_names) > 1:
        # Velocities first, then the other requested columns
        additional_vars = variables_names.copy()
        if "V" in additional_vars:
            additional_vars.remove("V")
        additional_data = df_input[additional_vars].to_numpy()
        data_array = np.column_stack((velocities, additional_data))
    else:
        # Use only velocities
        data_array = velocities.reshape(-1, 1)

    # Standardize features; the scaler is saved so results can be mapped back
    scaler = StandardScaler()
    scaler.fit(data_array)
    joblib.dump(scaler, output_dir / f"{base_name}_scaler.joblib")
    data_array_scaled = scaler.transform(data_array)

    # Build model
    logger.info(f"Building mixture model for {base_name}...")
    model = mcmc.build_marginalized_mixture_model(
        data_array_scaled,
        prior_probs,
        sectors,
        mu_params=mu_params,
        sigma_params=sigma_params,
        feature_weights=feature_weights,
    )

    # Sample model (1st pass)
    idata, convergence_flag = mcmc.sample_model(
        model, output_dir, base_name, **sample_args
    )
    if not convergence_flag:
        idata_summary = az.summary(idata, var_names=["mu", "sigma"])
        logger.info(f"MCMC did not converge. Summary:\n{idata_summary}")

    # --- MRF regularization of priors (optional) ---
    prior_used = prior_probs
    q_mrf = None
    if mrf_regularization:
        x_pos = df_input["x"].to_numpy()
        y_pos = df_input["y"].to_numpy()
        mkw = dict(n_neighbors=8, length_scale=50, beta=2.0, n_iter=5)
        if mrf_kwargs:
            mkw.update(mrf_kwargs)
        prior_mrf, q_mrf = mcmc.run_mrf_regularization(
            data_array_scaled, idata, prior_probs, x_pos, y_pos, **mkw
        )
        prior_used = prior_mrf

        # visualize refined priors (best-effort)
        try:
            fig, _ = mcmc.plot_spatial_priors(df_input, prior_mrf, img=img)
            fig.savefig(
                output_dir / f"{base_name}_mrf_priors.png", dpi=150, bbox_inches="tight"
            )
            plt.close(fig)
        except Exception as exc:
            logger.warning(f"Could not plot MRF priors: {exc}")

    # --- Second pass ---
    if second_pass == "skip":
        if mrf_regularization:
            # Fastest: no re-sampling; the MRF posterior gives probabilities/labels.
            posterior_probs = q_mrf
            cluster_pred = np.argmax(posterior_probs, axis=1)
            uncertainty = 1.0 - posterior_probs.max(axis=1)
        else:
            # Priors are unchanged, so re-sampling would only repeat the first pass.
            posterior_probs, cluster_pred, uncertainty = (
                mcmc.compute_posterior_assignments(idata, n_posterior_samples=200)
            )
        # keep idata from 1st pass for plots/params
    else:
        # Re-sample (short or full), with refined priors when MRF was run
        if mrf_regularization:
            with model:
                pm.set_data({"prior_w": prior_used})

        sp2_args = dict(sample_args)
        if second_pass == "short":
            # much fewer draws/tune and fewer chains
            sp2_args.update(draws=600, tune=400, chains=2, cores=2)
        if second_pass_sample_args:
            sp2_args.update(second_pass_sample_args)

        # Warm-start from the first-pass posterior means unless the caller
        # supplied explicit initvals (avoids a duplicate-keyword error).
        initvals = sp2_args.pop("initvals", None)
        if initvals is None:
            initvals = _initvals_from_idata(idata, sp2_args.get("chains"))

        run_name = base_name + ("_mrf" if mrf_regularization else "")
        with model:
            # Decide up front instead of catching TypeError, which could also
            # swallow genuine errors raised during sampling.
            if _accepts_initvals(mcmc.sample_model):
                idata, convergence_flag = mcmc.sample_model(
                    model,
                    output_dir,
                    run_name,
                    initvals=initvals,
                    **sp2_args,
                )
            else:
                idata = pm.sample(initvals=initvals, **sp2_args)
                convergence_flag = _rhat_converged(idata)

        # Compute posterior-based assignments
        posterior_probs, cluster_pred, uncertainty = mcmc.compute_posterior_assignments(
            idata, n_posterior_samples=200
        )

    # Generate plots
    fig = mcmc.plot_velocity_clustering(
        df_features=df_input,
        img=img,
        idata=idata,
        cluster_pred=cluster_pred,
        posterior_probs=posterior_probs,
        scaler=scaler,
    )
    fig.savefig(
        output_dir / f"{base_name}_results.png",
        dpi=300,
        bbox_inches="tight",
    )
    plt.close(fig)

    # Trace plots
    fig, axes = plt.subplots(2, 2, figsize=(10, 6))
    az.plot_trace(
        idata, var_names=["mu", "sigma"], axes=axes, compact=True, legend=True
    )
    fig.savefig(output_dir / f"{base_name}_trace_plots.png", dpi=150)
    plt.close(fig)

    # Forest plots
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    az.plot_forest(idata, var_names=["mu", "sigma"], combined=True, ess=True, ax=axes)
    fig.savefig(output_dir / f"{base_name}_forest_plot.png", dpi=150)
    plt.close(fig)

    # Collect and save metadata
    metadata = mcmc.collect_run_metadata(
        idata=idata,
        convergence_flag=convergence_flag,
        data_array_scaled=data_array_scaled,
        variables_names=variables_names,
        sectors=sectors,
        prior_probs=prior_probs,
        sample_args=sample_args,
        frame=locals(),
    )
    mcmc.save_run_metadata(output_dir, base_name, metadata)

    # Return results dictionary
    result = {
        "metadata": metadata,
        "idata": idata,
        "scaler": scaler,
        "convergence_flag": convergence_flag,
        "posterior_probs": posterior_probs,
        "cluster_pred": cluster_pred,
        "uncertainty": uncertainty,
    }

    plt.close("all")
    return result


# Sample the second derivative at each point's y-coordinate and add as new feature
# if "d2v_dy2" in variables_names:
#     df, d2v_sampled = add_second_derivative_feature(
#         df=df,
#         y_values=df["y"].to_numpy(),
#         bin_centers=bin_centers,
#         vel_second_derivative=vel_second_derivative,
#         valid_bins=valid_bins,
#     )


In [None]:
# Gaussian smoothing scales to process. With more than one value, the
# per-scale results are aggregated in the next section.
sigma_values = [2]

# First-pass sampler settings come from the configuration cell, so that
# DRAWS / TUNE / CHAINS / CORES / TARGET_ACCEPT actually control the run.
# Feature columns (`variables_names`) are likewise taken from the config cell.
first_pass_sample_args = dict(
    draws=DRAWS,
    tune=TUNE,
    chains=CHAINS,
    cores=CORES,
    target_accept=TARGET_ACCEPT,
    random_seed=RANDOM_SEED,
)

# Loop through smoothing scales
results = []
for sigma in sigma_values:
    logger.info(f"Processing with Gaussian smoothing sigma={sigma}...")

    # Create scale-specific base name
    scale_base_name = f"{date_start}_{date_end}_sigma{sigma}"

    # Apply Gaussian smoothing if needed (skipped for sigma=0)
    df_run = apply_2d_gaussian_filter(df, sigma=sigma)

    # Tighter priors on the (scaled) mixture parameters for coarser smoothing
    mu_params = {"mu": 0, "sigma": 1 if sigma <= 2 else 0.5}
    sigma_params = {"sigma": 1 if sigma <= 2 else 0.5}

    # Run MCMC clustering with the smoothed data
    result = run_mcmc_clustering(
        df_input=df_run,
        prior_probs=prior_probs,
        sectors=sectors,
        variables_names=variables_names,
        output_dir=output_dir,
        base_name=scale_base_name,
        img=img,
        # transform_velocity="sigmoid",
        # transform_params={"midpoint_percentile": 70, "steepness": 2.0},
        mu_params=mu_params,
        sigma_params=sigma_params,
        sample_args=first_pass_sample_args,
        random_seed=RANDOM_SEED,
        mrf_regularization=True,
        mrf_kwargs=dict(n_neighbors=8, length_scale=50, beta=2.0, n_iter=5),
        # Second-pass speed options:
        #   "short": short re-sampling with warm start (recommended default)
        #   "skip":  fastest, rely only on MRF priors, no re-sampling
        second_pass="short",
        second_pass_sample_args=dict(
            draws=500, tune=300, chains=4, cores=4, target_accept=0.9
        ),
    )

    # Tag the result with its smoothing scale and collect it
    result["sigma"] = sigma
    results.append(result)


In [None]:
# Save all results to a single joblib file
# joblib.dump(
#     results,
#     output_dir
#     / f"{date_start}_{date_end}_all_results_multiscale.joblib",
# )

# Read the results again
# results = joblib.load(
#     output_dir
#     / f"{date_start}_{date_end}_all_results_multiscale.joblib",
# )

In [None]:
# Tag for the final output files: the configured reference-date range, or the
# single reference_date when no range is set (avoids "None_None_..." names).
run_tag = (
    f"{reference_start_date or reference_date}_{reference_end_date or reference_date}"
)

# ===  If a multi-scale approach was used, aggregate the results.
if len(sigma_values) > 1:
    aggregated_results = aggregate_multiscale_clustering(
        results,
        similarity_threshold=0.7,
        overall_threshold=0.8,
        fig_path=output_dir / f"{run_tag}_similarity_heatmap.jpg",
    )

    # Unpack aggregated results
    cluster_pred = aggregated_results["combined_cluster_pred"]
    posterior_probs = aggregated_results["avg_posterior_probs"]
    entropy = aggregated_results["avg_entropy"]
    similarity_matrix = aggregated_results["similarity_matrix"]
    stability_score = aggregated_results["stability_score"]
    valid_scales = aggregated_results["valid_scales"]

else:
    # Otherwise extract the single result
    cluster_pred = results[0]["cluster_pred"]
    posterior_probs = results[0]["posterior_probs"]
    # Shannon entropy of each point's posterior; epsilon avoids log(0)
    entropy = -np.sum(posterior_probs * np.log(posterior_probs + 1e-10), axis=1)
    similarity_matrix = None
    stability_score = None
    valid_scales = None


# ===  Save final clustering results (controlled by SAVE_OUTPUTS)
cluster_aggregation_outs = {
    "cluster_pred": cluster_pred,
    "posterior_probs": posterior_probs,
    "entropy": entropy,
    "similarity_matrix": similarity_matrix,
    "stability_score": stability_score,
    "valid_scales": valid_scales,
}
if SAVE_OUTPUTS:
    joblib.dump(
        cluster_aggregation_outs,
        output_dir / f"{run_tag}_kinematic_clustering_results.joblib",
    )

In [None]:
# Read data again to skip mcmc sampling if already done
# cluster_aggregation_outs = joblib.load(
#     output_dir
#     / f"{reference_start_date}_{reference_end_date}_kinematic_clustering_results.joblib",
# )
# cluster_pred = cluster_aggregation_outs["cluster_pred"]
# posterior_probs = cluster_aggregation_outs["posterior_probs"]
# entropy = cluster_aggregation_outs["entropy"]
# similarity_matrix = cluster_aggregation_outs["similarity_matrix"]
# stability_score = cluster_aggregation_outs["stability_score"]
# valid_scales = cluster_aggregation_outs["valid_scales"]


### POST-PROCESSING OF CLUSTERING RESULTS


In [None]:
# === Do some post-processing on the clustering results
# Rasterise the point labels on a regular grid
X, Y, label_grid = create_2d_grid(
    x=df["x"].to_numpy(), y=df["y"].to_numpy(), labels=cluster_pred, grid_spacing=None
)

# Remove small connected components (holes / speckles)
min_size = 50  # Minimum size of connected components to keep
connectivity = 8  # 4 or 8 for pixel connectivity
label_grid = remove_small_grid_components(
    label_grid, min_size=min_size, connectivity=connectivity
)

# Map the cleaned grid labels back to the DIC points
point_labels_cleaned, x, y = map_grid_to_points(
    X,
    Y,
    label_grid,
    x_points=df["x"].to_numpy(),
    y_points=df["y"].to_numpy(),
    keep_nan=True,
)

# === Agreement between the clustering and the prior sectors
# "Prior class" of each point = index of its highest prior probability
sector_names = list(PRIOR_PROBABILITY.keys())
sector_assignments = np.argmax(np.asarray(prior_probs), axis=1)

ari = adjusted_rand_score(sector_assignments, cluster_pred)
ami = adjusted_mutual_info_score(sector_assignments, cluster_pred)
logger.info(f"Prior agreement: ARI={ari:.3f}, AMI={ami:.3f}")

# === Make final clustering plot after cleaning
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.imshow(img, alpha=0.5, cmap="gray")
colormap = plt.get_cmap("tab10")
# keep_nan=True may leave NaN labels; skip them so they don't create an empty
# "Cluster nan" legend entry (NaN never matches itself in the mask below).
labels_arr = np.asarray(point_labels_cleaned)
for i, label in enumerate(np.unique(labels_arr[pd.notna(labels_arr)])):
    mask = labels_arr == label
    ax.scatter(
        x[mask],
        y[mask],
        color=colormap(i),
        label=f"Cluster {label}",
        s=10,
        alpha=0.7,
    )
ax.legend(loc="upper right", framealpha=0.9, fontsize=10)
ax.set_aspect("equal")

if valid_scales is not None and len(valid_scales) > 1:
    # stability_score may be None: only apply the float format when present
    # (formatting '' with ":.2f" raises ValueError).
    stability_txt = "n/a" if stability_score is None else f"{stability_score:.2f}"
    title = (
        f"Combined Clustering (scales: {valid_scales}, stability: {stability_txt})\n"
        f"Prior Agreement: AMI={ami:.2f}"
    )
else:
    title = f"Clustering (scale: {sigma_values[0]})\nPrior Agreement: AMI={ami:.2f}"

ax.set_title(title)

# File tag: reference-date range, or the single reference_date if no range is set
run_tag = (
    f"{reference_start_date or reference_date}_{reference_end_date or reference_date}"
)
fig.savefig(
    output_dir / f"{run_tag}_kinematic_clustering.png",
    dpi=300,
    bbox_inches="tight",
)

## Find discontinuities in the velocity field


In [None]:
from ppcluster.discontinuity import (
    find_vertical_discontinuities,
    plot_discontinuities,
)

# Coordinates and velocities after apply_2d_gaussian_filter with sigma=0
# (i.e. no smoothing requested) feed the discontinuity search.
df_smooth = apply_2d_gaussian_filter(df, sigma=0)
x, y, v = (df_smooth[column].to_numpy() for column in ("x", "y", "V"))

# Search for vertical discontinuities on a 50 (vertical) x 10 (horizontal) binning
discontinuity_results = find_vertical_discontinuities(
    x=x,
    y=y,
    v=v,
    vertical_bins=50,
    horizontal_bins=10,
    min_points_per_bin_col=20,  # try 5-20 depending on data density
    # Threshold for a significant gradient, as a fraction of the maximum;
    # lower values make the search more sensitive.
    gradient_threshold_factor=0.3,
    smoothing_sigma_1d=1.0,
    min_strength=1e-3,
    cluster_eps_factor=2.0,
    cluster_min_samples=3,
    border=[500, 500, 1000, 500],  # left, right, bottom, top in px units
    sectors=sectors,  # optional: use predefined sectors instead of DBSCAN
)

# Overlay the detected discontinuities on the master image
fig, ax = plt.subplots(figsize=(10, 8))
plot_discontinuities(x, y, v, discontinuities=discontinuity_results, img=img, ax=ax)


## Morpho-kinematic analysis


In [None]:
from ppcluster.mcmc.postproc import (
    create_2d_grid,
    map_grid_to_points,
    split_disconnected_components,
)

# NOTE(review): this re-binds create_2d_grid / map_grid_to_points, which the
# import cell takes from ppcluster.griddata — confirm both modules agree.

# Start from the cleaned kinematic labels of the post-processing section
kinematics_cluster = point_labels_cleaned.copy()
kin_cluster = np.asarray(kinematics_cluster)

# Coordinates / velocities after a light Gaussian filter (sigma=1)
df_smooth = apply_2d_gaussian_filter(df, sigma=1)
x, y, v = (df_smooth[column].to_numpy() for column in ("x", "y", "V"))

# Rasterise the labels, drop components smaller than 100 cells, then give each
# spatially disconnected piece of a cluster its own label.
X, Y, kin_cluster_grid = create_2d_grid(x=x, y=y, labels=kin_cluster)
kin_cluster_grid = remove_small_grid_components(
    kin_cluster_grid, min_size=100, connectivity=8
)
kin_cluster_grid, split_mapping = split_disconnected_components(
    kin_cluster_grid, connectivity=8, start_label=0
)
kin_cluster, x, y = map_grid_to_points(X, Y, kin_cluster_grid, x, y, keep_nan=True)

# Keep only classified points (label >= 0; NaN labels compare False).
# NOTE(review): assumes map_grid_to_points returns points in input order, so
# that `v` stays aligned with x/y — confirm.
valid_mask = kin_cluster >= 0
x, y, v, kin_cluster = [arr[valid_mask] for arr in (x, y, v, kin_cluster)]

# Rank clusters from the bottom of the image upwards (largest median y first)
clusters_ids = np.unique(kin_cluster)
cluster_median_y = {int(c): float(np.median(y[kin_cluster == c])) for c in clusters_ids}
ordered_clusters_ids = sorted(
    clusters_ids, key=lambda c: cluster_median_y[int(c)], reverse=True
)

# Plot clusters coloured by bottom-to-top rank
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.imshow(img, alpha=0.5, cmap="gray")
colormap = plt.get_cmap("tab10")
for rank, cluster_id in enumerate(ordered_clusters_ids):
    in_cluster = kin_cluster == cluster_id
    ax.scatter(
        x[in_cluster],
        y[in_cluster],
        color=colormap(rank),
        label=f"Cluster {cluster_id}",
        s=10,
        alpha=0.7,
    )
ax.legend(loc="upper right", framealpha=0.9, fontsize=10)
ax.set_aspect("equal")

In [None]:
# Manual assignment of kinematic clusters to morpho-kinematic sectors.
# Row k maps the k-th cluster of `ordered_clusters_ids` (bottom-most first) to
# (string label, numeric id). Edit this table after inspecting the plot above.
MK_SECTOR_ASSIGNMENT = [
    ("A", 0),
    ("B", 1),
    # ("B1", 2),  # fast area in sector B (disabled)
    ("C", 2),
    ("D", 3),
    ("D1", 4),  # fast area in sector D
]

# Fail with an actionable message instead of a bare IndexError
if len(ordered_clusters_ids) < len(MK_SECTOR_ASSIGNMENT):
    raise ValueError(
        f"Manual sector assignment expects at least {len(MK_SECTOR_ASSIGNMENT)} "
        f"kinematic clusters, found {len(ordered_clusters_ids)}. "
        "Update MK_SECTOR_ASSIGNMENT."
    )

# Unassigned points keep "" / -1
mk_label_str = np.full_like(kin_cluster, "", dtype=object)
mk_label_id = np.full(kin_cluster.shape, -1, dtype=int)
for cluster_id, (sector_label, sector_id) in zip(
    ordered_clusters_ids, MK_SECTOR_ASSIGNMENT
):
    members = kin_cluster == cluster_id
    mk_label_str[members] = sector_label
    mk_label_id[members] = sector_id

unassigned_clusters = ordered_clusters_ids[len(MK_SECTOR_ASSIGNMENT) :]
if unassigned_clusters:
    logger.warning(
        f"Kinematic clusters left without a sector: {list(unassigned_clusters)}"
    )

unique_mk, counts = np.unique(mk_label_str[mk_label_str != ""], return_counts=True)

In [None]:
# Display colour per morpho-kinematic sector:
# A red, B orange, B1 dark orange, C yellow, D green, D1 dark green
colors = {
    "A": "#b3140b",
    "B": "#ee9c21",
    "B1": "#ff4800",
    "C": "#f1ee30",
    "D": "#5fb61c",
    "D1": "#006837",
}


# Point count per assigned sector
print("Morpho-kinematic assignment summary:")
for sector_label, n_points in zip(unique_mk, counts):
    print(f"  {sector_label}: {n_points} points")

# Quick visual check of the sector assignment
fig, ax = plt.subplots(figsize=(8, 8))
if img is not None:
    ax.imshow(img, alpha=0.5, cmap="gray")

for lab in unique_mk:
    mask = mk_label_str == lab
    ax.scatter(
        x[mask],
        y[mask],
        color=colors[lab],
        label=lab,
        s=10,
        alpha=0.8,
    )

ax.legend(loc="upper right", fontsize=9)
ax.set_title("Morpho-Kinematic Sectors")
ax.set_aspect("equal")
plt.show()


In [None]:
def compute_mk_sector_polygons(
    x,
    y,
    mk_label_str,
    *,
    smooth_iters=2,
    prevent_overlap=True,
    containment_strategy="difference",
):
    """Build one smoothed outline polygon per morpho-kinematic (MK) sector.

    For every non-empty string label in ``mk_label_str``, the outline of the
    labelled points is taken from the boundary edges of their Delaunay
    triangulation (i.e. the convex hull) and smoothed with Chaikin corner
    cutting. When Shapely is available, sectors can also be made
    non-overlapping, and nested sectors can be carved out of the sectors that
    contain them.

    Parameters
    ----------
    x, y : array-like
        Point coordinates, one per entry of ``mk_label_str``.
    mk_label_str : array-like
        Sector label per point; empty strings and non-string entries are ignored.
    smooth_iters : int, default 2
        Number of Chaikin iterations (0 disables smoothing).
    prevent_overlap : bool, default True
        Subtract the union of already-built sectors from each new sector
        (labels are processed in sorted order). Requires Shapely.
    containment_strategy : {"difference", "keep"}, default "difference"
        "difference": if a new sector lies entirely inside an earlier one,
        subtract it from the earlier one. "keep": leave both untouched and do
        not clip the new sector. "difference" requires Shapely.

    Returns
    -------
    PolygonDict
        A ``dict`` mapping label -> (N, 2) float array of vertices (open ring,
        last vertex not repeated). Its ``geometries`` attribute holds the
        Shapely geometries when Shapely was used, otherwise ``{}``.

    Raises
    ------
    ValueError
        If ``containment_strategy`` is not "difference" or "keep".

    Notes
    -----
    If Shapely is required but cannot be imported, a RuntimeWarning is issued
    and plain smoothed polygons are returned. Only exterior rings are returned
    as arrays, so holes created by the "difference" strategy exist only in
    ``geometries``. When a geometry splits into several parts, the part with
    the largest area is kept.
    """
    import warnings
    from collections import defaultdict

    import numpy as np
    from scipy.spatial import Delaunay

    containment_strategy = str(containment_strategy).lower()
    if containment_strategy not in {"difference", "keep"}:
        raise ValueError("containment_strategy must be 'difference' or 'keep'.")

    shapely_available = False
    Polygon = unary_union = None
    require_shapely = prevent_overlap or containment_strategy == "difference"
    if require_shapely:
        try:
            from shapely.geometry import Polygon  # type: ignore
            from shapely.ops import unary_union  # type: ignore

            shapely_available = True
        except ImportError:
            warnings.warn(
                "Shapely is not installed. Falling back to simple polygons without "
                "overlap containment handling.",
                RuntimeWarning,
            )
            prevent_overlap = False
            containment_strategy = "keep"

    def _chaikin(poly, n_iter=2):
        """Chaikin corner-cutting smoothing of a closed ring (open-ring output)."""
        if poly is None or len(poly) < 3 or n_iter <= 0:
            return poly
        P = np.asarray(poly, float)
        for _ in range(n_iter):
            if not np.allclose(P[0], P[-1]):
                P = np.vstack([P, P[0]])
            Q = []
            for i in range(len(P) - 1):
                p0, p1 = P[i], P[i + 1]
                Q.append(0.75 * p0 + 0.25 * p1)
                Q.append(0.25 * p0 + 0.75 * p1)
            P = np.array(Q)
        if np.allclose(P[0], P[-1]):
            P = P[:-1]
        return P

    def _boundary_polygon(points):
        """Ordered ring of Delaunay boundary vertices, or None if degenerate."""
        pts = np.asarray(points, float)
        if pts.shape[0] < 3:
            return None
        try:
            tri = Delaunay(pts)
        except Exception:
            # e.g. all points collinear / duplicated
            return None
        # Boundary edges belong to exactly one triangle
        edge_count = defaultdict(int)
        for simplex in tri.simplices:
            simplex = list(simplex)
            edges = [
                (simplex[i], simplex[j]) for i in range(3) for j in range(i + 1, 3)
            ]
            for i, j in edges:
                if i > j:
                    i, j = j, i
                edge_count[(i, j)] += 1
        boundary_edges = [e for e, c in edge_count.items() if c == 1]
        if not boundary_edges:
            return None
        adj = defaultdict(list)
        for i, j in boundary_edges:
            adj[i].append(j)
            adj[j].append(i)
        # Walk the boundary from an arbitrary start vertex until the ring closes
        start = boundary_edges[0][0]
        poly_idx = [start]
        prev = None
        current = start
        while True:
            neigh = adj[current]
            nxt_candidates = [n for n in neigh if n != prev]
            if not nxt_candidates:
                break
            nxt = nxt_candidates[0]
            if nxt == poly_idx[0] and len(poly_idx) >= 3:
                break
            poly_idx.append(nxt)
            prev, current = current, nxt
            if len(poly_idx) > len(pts) * 2:
                # safety stop for malformed boundaries
                break
        if len(poly_idx) < 3:
            return None
        return pts[poly_idx]

    def _largest_polygon(geom):
        """Largest-area Polygon inside ``geom`` (Polygon, Multi*, collections), or None."""
        if geom is None or geom.is_empty:
            return None
        if isinstance(geom, Polygon):
            return geom
        # MultiPolygon / GeometryCollection: recurse; lines/points yield None
        parts = [
            part
            for part in (_largest_polygon(g) for g in getattr(geom, "geoms", ()))
            if part is not None
        ]
        return max(parts, key=lambda g: g.area) if parts else None

    class PolygonDict(dict):
        """dict of label -> vertex array, plus a ``geometries`` attribute."""

        def __init__(self, *args, geometries=None, **kwargs):
            super().__init__(*args, **kwargs)
            self.geometries = geometries or {}

    labels = np.asarray(mk_label_str)
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    unique_labels = [lab for lab in np.unique(labels) if isinstance(lab, str) and lab]

    polygons = PolygonDict()
    geometries = {}

    for lab in unique_labels:
        mask = labels == lab
        pts = np.column_stack([x[mask], y[mask]])
        poly = _boundary_polygon(pts)
        if poly is None:
            continue
        poly_sm = _chaikin(poly, n_iter=smooth_iters)
        if poly_sm is None or len(poly_sm) < 3:
            continue

        if not shapely_available:
            # Plain mode: no overlap or containment handling
            polygons[lab] = np.asarray(poly_sm, float)
            continue

        poly_shape = Polygon(poly_sm)
        if not poly_shape.is_valid:
            poly_shape = poly_shape.buffer(0)
        if poly_shape.is_empty:
            # Degenerate outline: skip it rather than storing an unvalidated
            # ring that would bypass the overlap handling below.
            continue

        containing_labels = [
            prev_lab
            for prev_lab, prev_shape in geometries.items()
            if prev_shape.contains(poly_shape)
        ]

        if containing_labels and containment_strategy == "difference":
            # Carve the new sector out of every sector that fully contains it
            for prev_lab in containing_labels:
                updated_shape = _largest_polygon(
                    geometries[prev_lab].difference(poly_shape)
                )
                if updated_shape is None:
                    geometries.pop(prev_lab, None)
                    polygons.pop(prev_lab, None)
                    continue
                geometries[prev_lab] = updated_shape
                coords = np.asarray(updated_shape.exterior.coords)
                if coords.shape[0] >= 4:
                    polygons[prev_lab] = coords[:-1]
                else:
                    polygons.pop(prev_lab, None)

        skip_difference = bool(containing_labels) and containment_strategy == "keep"

        if prevent_overlap and not skip_difference and geometries:
            existing_union = unary_union(list(geometries.values()))
            if not existing_union.is_empty:
                poly_shape = poly_shape.difference(existing_union)

        final_shape = _largest_polygon(poly_shape)
        if final_shape is None:
            continue
        coords = np.asarray(final_shape.exterior.coords)
        if coords.shape[0] < 4:
            continue
        # Shapely rings repeat the first vertex at the end; drop it
        polygons[lab] = coords[:-1]
        geometries[lab] = final_shape

    if shapely_available:
        polygons.geometries = geometries

    return polygons


def plot_mk_sectors_smooth_perimeters(
    x,
    y,
    mk_label_str,
    *,
    v=None,
    img=None,
    smooth_iters=2,
    colormap=None,
    palette="tab20",
    fill=False,
    fill_alpha=0.2,
    edge_alpha=0.95,
    scatter_alpha=0.3,
    ax=None,
    polygons=None,
    polygon_kwargs=None,
):
    """Plot MK-sector perimeters, optionally reusing pre-computed polygons.

    Parameters
    ----------
    x, y : array-like
        Point coordinates, used for the optional velocity scatter and, when
        ``polygons`` is None, to build the sector polygons.
    mk_label_str : array-like
        Per-point sector labels; only used when ``polygons`` is None.
    v : array-like, optional
        Per-point values drawn as a background scatter (viridis) with a
        colorbar labelled "Velocity".
    img : array-like, optional
        Background image, shown in grayscale at 50% opacity.
    smooth_iters : int
        Forwarded to ``compute_mk_sector_polygons``; ignored when
        ``polygons`` is provided.
    colormap : dict, optional
        Mapping ``label -> color``. Labels missing from it are drawn gray.
    palette : str
        Matplotlib colormap name used to color sectors when ``colormap`` is None.
    fill : bool
        If True, polygons are filled with ``fill_alpha`` opacity.
    fill_alpha, edge_alpha, scatter_alpha : float
        Opacities of the fill, the perimeter line and the velocity scatter.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 8x8 figure is created when None.
    polygons : mapping, optional
        Pre-computed ``label -> (N, 2) vertex array-like``. Entries that are
        None or have fewer than 3 vertices are skipped.
    polygon_kwargs : dict, optional
        Extra keyword arguments for ``compute_mk_sector_polygons``.

    Returns
    -------
    (fig, ax)
        The figure and axes that were drawn on.
    """
    import numpy as np
    from matplotlib import pyplot as plt

    labels = np.asarray(mk_label_str)
    x = np.asarray(x, float)
    y = np.asarray(y, float)

    if polygons is None:
        polygon_kwargs = polygon_kwargs or {}
        polygons = compute_mk_sector_polygons(
            x,
            y,
            labels,
            smooth_iters=smooth_iters,
            **polygon_kwargs,
        )

    unique_labels = list(polygons.keys())
    gray = (0.5, 0.5, 0.5)
    if colormap is None:
        cmap = plt.get_cmap(palette)
        palette_colors = cmap(np.linspace(0, 1, max(len(unique_labels), 1)))

    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 8))
    else:
        fig = ax.figure

    if img is not None:
        ax.imshow(img, cmap="gray", alpha=0.5)

    if v is not None:
        sc = ax.scatter(
            x,
            y,
            c=np.asarray(v, float),
            cmap="viridis",
            s=4,
            alpha=scatter_alpha,
            zorder=0,
        )
        cbar = fig.colorbar(sc, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("Velocity", rotation=270, labelpad=15)

    n_drawn = 0
    for idx, lab in enumerate(unique_labels):
        poly = polygons[lab]
        if poly is None or len(poly) < 3:
            continue
        # Accept any (N, 2) array-like (e.g. nested lists), not only ndarrays.
        poly = np.asarray(poly, float)
        edge_color = (
            colormap.get(lab, gray) if colormap is not None else palette_colors[idx]
        )

        if fill:
            ax.fill(
                poly[:, 0],
                poly[:, 1],
                color=edge_color,
                alpha=fill_alpha,
                lw=0,
                zorder=1,
            )

        # Repeat the first vertex so the drawn perimeter is closed.
        ax.plot(
            np.r_[poly[:, 0], poly[0, 0]],
            np.r_[poly[:, 1], poly[0, 1]],
            color=edge_color,
            alpha=edge_alpha,
            lw=2,
            label=lab,
            zorder=2,
        )

        # Label placed at the mean of the polygon vertices.
        cx, cy = poly.mean(axis=0)
        ax.text(
            cx,
            cy,
            lab,
            color="k",
            fontsize=10,
            ha="center",
            va="center",
            bbox=dict(facecolor="white", alpha=0.6, edgecolor="none", pad=1.5),
            zorder=3,
        )
        n_drawn += 1

    ax.set_aspect("equal")
    # Only build a legend when this call actually drew a labelled perimeter;
    # otherwise matplotlib warns about an empty legend.
    if n_drawn:
        ax.legend(loc="upper right", fontsize=9, framealpha=0.85)
    ax.set_title("Morpho-kinematic sectors")
    return fig, ax


def compute_mk_sector_stats(
    polygons,
    mk_label_str,
    *,
    x,
    y,
    v=None,
    img_shape=None,
    rasterize=False,
):
    """Summarise geometry and per-point value statistics for each MK-sector polygon.

    Parameters
    ----------
    polygons : mapping
        ``label -> vertices`` where vertices is an (N, 2) array-like of
        (x, y) pairs; the first vertex need not be repeated at the end.
        Entries that are None or have fewer than 3 vertices are skipped.
    mk_label_str : array-like
        Per-point labels. Points whose label equals a polygon key are counted
        in ``n_points`` and feed the ``v_*`` statistics.
    x, y : array-like
        Point coordinates. They are converted to float arrays but no
        statistic currently uses them (kept for interface compatibility).
    v : array-like, optional
        Per-point values aligned with ``mk_label_str``; summarised per sector
        as mean/std/median/min/max. The ``v_*`` columns are NaN when omitted
        or when a sector has no points.
    img_shape : tuple, optional
        Image shape; its first two entries (rows, cols) bound rasterisation.
    rasterize : bool
        When True and ``img_shape`` is given, ``pixel_count`` is the number of
        image pixels covered by the polygon (via ``skimage.draw.polygon``).
        If scikit-image cannot be imported a RuntimeWarning is emitted and
        ``pixel_count`` stays NaN.

    Returns
    -------
    pandas.DataFrame
        One row per polygon, sorted by ``label``.
    """
    import warnings

    import numpy as np
    import pandas as pd

    def _poly_area(poly):
        # Shoelace formula; the absolute value makes it orientation-independent.
        x0, y0 = poly[:, 0], poly[:, 1]
        return 0.5 * np.abs(np.dot(x0, np.roll(y0, -1)) - np.dot(y0, np.roll(x0, -1)))

    def _poly_perimeter(poly):
        # Append the first vertex so the closing edge (last -> first) is counted.
        diffs = np.diff(np.vstack([poly, poly[0]]), axis=0)
        return np.sum(np.hypot(diffs[:, 0], diffs[:, 1]))

    def _poly_centroid(poly):
        # Area-weighted centroid using the signed area, so vertex orientation
        # does not matter; degenerate (zero-area) polygons fall back to the
        # vertex mean.
        x0, y0 = poly[:, 0], poly[:, 1]
        a = np.dot(x0, np.roll(y0, -1)) - np.dot(y0, np.roll(x0, -1))
        A = a / 2.0
        if np.isclose(A, 0):
            return np.array([x0.mean(), y0.mean()])
        cx = (1 / (6 * A)) * np.sum(
            (x0 + np.roll(x0, -1)) * (x0 * np.roll(y0, -1) - y0 * np.roll(x0, -1))
        )
        cy = (1 / (6 * A)) * np.sum(
            (y0 + np.roll(y0, -1)) * (x0 * np.roll(y0, -1) - y0 * np.roll(x0, -1))
        )
        return np.array([cx, cy])

    # scikit-image is optional: only try it when rasterisation is requested,
    # and say so (instead of silently returning NaN) when it is unavailable.
    sk_polygon = None
    if rasterize and img_shape is not None:
        try:
            from skimage.draw import polygon as sk_polygon
        except Exception as exc:  # best-effort: keep NaN pixel counts
            warnings.warn(
                f"Could not import skimage.draw.polygon ({exc!r}); "
                "'pixel_count' will be NaN.",
                RuntimeWarning,
                stacklevel=2,
            )
            sk_polygon = None

    labels = np.asarray(mk_label_str)
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    v = np.asarray(v, float) if v is not None else None
    stats_rows = []
    for lab, poly in polygons.items():
        if poly is None or len(poly) < 3:
            continue
        # Accept any (N, 2) array-like (e.g. nested lists), not only ndarrays.
        poly = np.asarray(poly, float)
        mask = labels == lab
        n_pts = int(np.count_nonzero(mask))
        area = _poly_area(poly)
        perim = _poly_perimeter(poly)
        centroid = _poly_centroid(poly)
        # Isoperimetric quotient: 1 for a circle, smaller for elongated shapes.
        compactness = (4 * np.pi * area) / (perim**2 + 1e-12) if perim > 0 else np.nan
        vel_stats = dict(
            v_mean=np.nan, v_std=np.nan, v_median=np.nan, v_min=np.nan, v_max=np.nan
        )
        if v is not None and n_pts > 0:
            v_sel = v[mask]
            vel_stats = dict(
                v_mean=float(np.mean(v_sel)),
                v_std=float(np.std(v_sel)),
                v_median=float(np.median(v_sel)),
                v_min=float(np.min(v_sel)),
                v_max=float(np.max(v_sel)),
            )
        pixel_count = np.nan
        if sk_polygon is not None:
            h, w = img_shape[:2]
            # skimage expects (row, col) = (y, x).
            rr, cc = sk_polygon(poly[:, 1], poly[:, 0], shape=(h, w))
            pixel_count = int(len(rr))
        stats_rows.append(
            {
                "label": lab,
                "n_points": n_pts,
                "area_px2": float(area),
                "perimeter_px": float(perim),
                "compactness": float(compactness),
                "centroid_x": float(centroid[0]),
                "centroid_y": float(centroid[1]),
                "pixel_count": pixel_count,
                "point_density_pts_per_px2": float(n_pts / area)
                if area > 0
                else np.nan,
                **vel_stats,
            }
        )
    columns = [
        "label",
        "n_points",
        "area_px2",
        "perimeter_px",
        "compactness",
        "centroid_x",
        "centroid_y",
        "pixel_count",
        "point_density_pts_per_px2",
        "v_mean",
        "v_std",
        "v_median",
        "v_min",
        "v_max",
    ]
    return (
        pd.DataFrame(stats_rows, columns=columns)
        .sort_values("label")
        .reset_index(drop=True)
    )


# Fixed display colour for each morpho-kinematic (MK) sector label.
colors = {
    "A": "#b3140b",
    "B": "#ee9c21",
    "B1": "#ff4800",
    "C": "#f1ee30",
    "D": "#5fb61c",
    "D1": "#006837",
}

# Minor sectors are polygonised in a separate pass (without overlap removal),
# so they are not clipped against the major sectors, and are then overlaid
# as outlines only.
minor_clusters = ["B1", "D1"]
mask = np.isin(mk_label_str, minor_clusters)  # True for points in minor sectors

fig, ax = plt.subplots(figsize=(8, 8))

# --- Major sectors: filled, overlap between sectors removed ---
x_major = x[~mask]
y_major = y[~mask]
mk_label_str_major = mk_label_str[~mask]
polygons_major = compute_mk_sector_polygons(
    x_major,
    y_major,
    mk_label_str_major,
    smooth_iters=5,
    prevent_overlap=True,
)
# Smoothing is fixed by the polygons computed above, so no smooth_iters here.
plot_mk_sectors_smooth_perimeters(
    x_major,
    y_major,
    mk_label_str_major,
    img=np.asarray(img) if img is not None else None,
    colormap={k: c for k, c in colors.items() if k not in minor_clusters},
    fill=True,
    fill_alpha=0.1,
    edge_alpha=1,
    scatter_alpha=0.3,
    polygons=polygons_major,
    ax=ax,
)

# --- Minor sectors: outlines only, drawn after (on top of) the major ones ---
x_minor = x[mask]
y_minor = y[mask]
mk_label_str_minor = mk_label_str[mask]
colors_plot = {k: c for k, c in colors.items() if k in minor_clusters}
polygons_minor = compute_mk_sector_polygons(
    x_minor,
    y_minor,
    mk_label_str_minor,
    smooth_iters=4,
    prevent_overlap=False,
)
plot_mk_sectors_smooth_perimeters(
    x_minor,
    y_minor,
    mk_label_str_minor,
    colormap=colors_plot,
    fill=False,
    edge_alpha=1,
    scatter_alpha=0.3,
    polygons=polygons_minor,
    ax=ax,
)

plt.show()

In [None]:
# def compute_mk_sector_polygons(
#     x,
#     y,
#     mk_label_str,
#     *,
#     smooth_iters=2,
#     prevent_overlap=True,
#     containment_strategy="difference",
# ):
#     """Return smoothed MK-sector polygons with optional overlap and containment handling."""
#     import warnings
#     from collections import defaultdict

#     import numpy as np
#     from scipy.spatial import Delaunay

#     containment_strategy = str(containment_strategy).lower()
#     if containment_strategy not in {"difference", "keep"}:
#         raise ValueError("containment_strategy must be 'difference' or 'keep'.")

#     shapely_available = False
#     Polygon = MultiPolygon = unary_union = None
#     require_shapely = prevent_overlap or containment_strategy == "difference"
#     if require_shapely:
#         try:
#             from shapely.geometry import MultiPolygon, Polygon  # type: ignore
#             from shapely.ops import unary_union  # type: ignore

#             shapely_available = True
#         except ImportError:
#             warnings.warn(
#                 "Shapely is not installed. Falling back to simple polygons without "
#                 "overlap/containment handling.",
#                 RuntimeWarning,
#             )
#             prevent_overlap = False
#             containment_strategy = "keep"

#     def _chaikin(poly, n_iter=2):
#         if poly is None or len(poly) < 3 or n_iter <= 0:
#             return poly
#         P = np.asarray(poly, float)
#         for _ in range(n_iter):
#             if not np.allclose(P[0], P[-1]):
#                 P = np.vstack([P, P[0]])
#             Q = []
#             for i in range(len(P) - 1):
#                 p0, p1 = P[i], P[i + 1]
#                 Q.append(0.75 * p0 + 0.25 * p1)
#                 Q.append(0.25 * p0 + 0.75 * p1)
#             P = np.array(Q)
#         if np.allclose(P[0], P[-1]):
#             P = P[:-1]
#         return P

#     def _boundary_polygon(points):
#         pts = np.asarray(points, float)
#         if pts.shape[0] < 3:
#             return None
#         try:
#             tri = Delaunay(pts)
#         except Exception:
#             return None
#         edge_count = defaultdict(int)
#         for simplex in tri.simplices:
#             simplex = list(simplex)
#             edges = [
#                 (simplex[i], simplex[j]) for i in range(3) for j in range(i + 1, 3)
#             ]
#             for i, j in edges:
#                 if i > j:
#                     i, j = j, i
#                 edge_count[(i, j)] += 1
#         boundary_edges = [e for e, c in edge_count.items() if c == 1]
#         if not boundary_edges:
#             return None
#         adj = defaultdict(list)
#         for i, j in boundary_edges:
#             adj[i].append(j)
#             adj[j].append(i)
#         start = boundary_edges[0][0]
#         poly_idx = [start]
#         prev = None
#         current = start
#         while True:
#             neigh = adj[current]
#             nxt_candidates = [n for n in neigh if n != prev]
#             if not nxt_candidates:
#                 break
#             nxt = nxt_candidates[0]
#             if nxt == poly_idx[0] and len(poly_idx) >= 3:
#                 break
#             poly_idx.append(nxt)
#             prev, current = current, nxt
#             if len(poly_idx) > len(pts) * 2:
#                 break
#         if len(poly_idx) < 3:
#             return None
#         return pts[poly_idx]

#     labels = np.asarray(mk_label_str)
#     x = np.asarray(x, float)
#     y = np.asarray(y, float)
#     unique_labels = [lab for lab in np.unique(labels) if isinstance(lab, str) and lab]

#     polygons: dict[str, np.ndarray] = {}
#     geometries: dict[str, "Polygon"] = {}

#     for lab in unique_labels:
#         mask = labels == lab
#         pts = np.column_stack([x[mask], y[mask]])
#         poly = _boundary_polygon(pts)
#         if poly is None:
#             continue
#         poly_sm = _chaikin(poly, n_iter=smooth_iters)
#         if poly_sm is None or len(poly_sm) < 3:
#             continue

#         poly_shape = None
#         if shapely_available:
#             poly_shape = Polygon(poly_sm)
#             if not poly_shape.is_valid:
#                 poly_shape = poly_shape.buffer(0)

#         if shapely_available and poly_shape and not poly_shape.is_empty:
#             containing_labels = [
#                 prev_lab
#                 for prev_lab, prev_shape in geometries.items()
#                 if prev_shape.contains(poly_shape)
#             ]
#             contained_labels = [
#                 prev_lab
#                 for prev_lab, prev_shape in geometries.items()
#                 if poly_shape.contains(prev_shape)
#             ]

#             if containment_strategy == "difference":
#                 for prev_lab in containing_labels:
#                     updated_shape = geometries[prev_lab].difference(poly_shape)
#                     if updated_shape.is_empty:
#                         geometries.pop(prev_lab, None)
#                         polygons.pop(prev_lab, None)
#                         continue
#                     if isinstance(updated_shape, MultiPolygon):
#                         updated_shape = max(updated_shape.geoms, key=lambda g: g.area)
#                     coords_prev = np.asarray(updated_shape.exterior.coords)
#                     if coords_prev.shape[0] < 4:
#                         geometries.pop(prev_lab, None)
#                         polygons.pop(prev_lab, None)
#                     else:
#                         geometries[prev_lab] = updated_shape
#                         polygons[prev_lab] = coords_prev[:-1]

#                 if contained_labels:
#                     subtract_union = unary_union(
#                         [geometries[prev_lab] for prev_lab in contained_labels]
#                     )
#                     poly_shape = poly_shape.difference(subtract_union)
#                     if poly_shape.is_empty:
#                         continue
#                     if isinstance(poly_shape, MultiPolygon):
#                         poly_shape = max(poly_shape.geoms, key=lambda g: g.area)

#             if prevent_overlap:
#                 excluded_labels = set()
#                 if containment_strategy == "keep":
#                     excluded_labels.update(containing_labels)
#                     excluded_labels.update(contained_labels)
#                 overlap_targets = [
#                     prev_shape
#                     for prev_lab, prev_shape in geometries.items()
#                     if prev_lab not in excluded_labels
#                 ]
#                 if overlap_targets:
#                     overlap_union = unary_union(overlap_targets)
#                     poly_shape = poly_shape.difference(overlap_union)
#                     if poly_shape.is_empty:
#                         continue
#                     if isinstance(poly_shape, MultiPolygon):
#                         poly_shape = max(poly_shape.geoms, key=lambda g: g.area)

#             coords = np.asarray(poly_shape.exterior.coords)
#             if coords.shape[0] < 4:
#                 continue
#             polygons[lab] = coords[:-1]
#             geometries[lab] = poly_shape
#         else:
#             polygons[lab] = np.asarray(poly_sm, float)

#     return polygons


# fig, ax = plt.subplots(figsize=(8, 8))
# polygons = compute_mk_sector_polygons(
#     x,
#     y,
#     mk_label_str,
#     smooth_iters=5,
#     prevent_overlap=True,
#     containment_strategy="difference",
# )

# plot_mk_sectors_smooth_perimeters(
#     x,
#     y,
#     mk_label_str,
#     img=np.asarray(img),
#     smooth_iters=4,
#     colormap=colors_plot,
#     fill=True,
#     fill_alpha=0.1,
#     edge_alpha=1,
#     scatter_alpha=0.3,
#     polygons=polygons,
#     ax=ax,
# )


In [None]:
# Per-sector geometry and velocity statistics.
# NOTE(review): only the major sectors (polygons_major) are summarised here;
# the B1/D1 outlines live in polygons_minor — include them if their stats matter.
mk_stats = compute_mk_sector_stats(
    polygons_major,
    mk_label_str,
    rasterize=True,
    img_shape=None if img is None else np.asarray(img).shape,
    x=x,
    y=y,
    v=v,
)

In [None]:
# Show the per-sector statistics table as this cell's output.
mk_stats

## Deprecated code (kept for reference only; the cells below are fully commented out)


In [None]:
# # --- Morpho-kinematic assignment ---


# kinematics_cluster = point_labels_cleaned

# # Ensure arrays are numpy arrays
# x = np.asarray(x)
# y = np.asarray(y)
# kin_cluster = np.asarray(kinematics_cluster)

# # Remove non classified points (-1 label)
# valid_mask = kin_cluster >= 0
# x = x[valid_mask]
# y = y[valid_mask]
# kin_cluster = kin_cluster[valid_mask]

# # Compute per-cluster median y (to order clusters from bottom -> top)
# clusters = np.unique(kin_cluster)
# cluster_median_y = {int(c): float(np.median(y[kin_cluster == c])) for c in clusters}

# # Order clusters by median y descending (bottom = largest y first)
# ordered_clusters = sorted(
#     clusters, key=lambda c: cluster_median_y[int(c)], reverse=True
# )

# # Read discontinuities (clustered boundaries) from detection, if available
# clustered_disc = (
#     discontinuity_results.get("clustered", [])
#     if discontinuity_results is not None
#     else []
# )

# logger.info("Found %d morphological discontinuities", len(clustered_disc))


# # If a discontinuity falls inside a kinematic cluster, split that cluster into two parts.


# # replace the call to _split_clusters_on_discontinuities accordingly where used:
# # previously: kin_cluster, splits = _split_clusters_on_discontinuities(kin_cluster, y, clustered_disc, min_points_side=20)
# # now provide both x and y arrays:


# # Perform splitting using detected discontinuities
# if clustered_disc:
#     kin_cluster, splits = _split_clusters_on_discontinuities(
#         kin_cluster, x, y, clustered_disc, min_points_side=50
#     )

#     if splits:
#         # recompute clusters and medians after splits
#         clusters = np.unique(kin_cluster)
#         cluster_median_y = {
#             int(c): float(np.median(y[kin_cluster == c])) for c in clusters
#         }
#         ordered_clusters = sorted(
#             clusters, key=lambda c: cluster_median_y[int(c)], reverse=True
#         )
#         logger.info("Performed %d cluster split(s) due to discontinuities", len(splits))
#     else:
#         logger.info(
#             "No cluster splits performed (no discontinuity inside single cluster with enough support)"
#         )

# # Read first discontinuity (the one separating bottom A from the rest), if available
# first_discont_y = None
# if clustered_disc:
#     # clustered_boundaries were sorted reverse=True in detection, so first is bottommost discontinuity
#     first_discont_y = float(clustered_disc[0]["position"])

# # Prepare output arrays (string labels like "A1","B1","C1", and numeric ids)
# mk_label_str = np.full_like(kin_cluster, "", dtype=object)
# mk_label_id = -1 * np.ones_like(kin_cluster, dtype=int)

# # Helper to assign a label for a given cluster and mask
# next_mk_id = 0


# def _assign(cluster_val, mask, label_text):
#     global next_mk_id
#     mk_label_str[mask] = label_text
#     mk_label_id[mask] = next_mk_id
#     next_mk_id += 1


# # If we have a discontinuity, use it to split bottom vs above
# if first_discont_y is not None:
#     # Bottom region: y >= first_discont_y (image coords: large y -> bottom)
#     bottom_mask = y >= first_discont_y
#     clusters_in_bottom = (
#         np.unique(kin_cluster[bottom_mask]) if np.any(bottom_mask) else np.array([])
#     )

#     # Order them bottom->up (by median y) and assign A1, A2, ...
#     clusters_in_bottom_ordered = sorted(
#         clusters_in_bottom, key=lambda c: cluster_median_y[int(c)], reverse=True
#     )
#     for i, cl in enumerate(clusters_in_bottom_ordered, start=1):
#         mask = (kin_cluster == cl) & bottom_mask
#         if not np.any(mask):
#             continue
#         _assign(int(cl), mask, f"A{i}")

#     # Above region: y < first_discont_y
#     above_mask = y < first_discont_y
#     clusters_above = (
#         np.unique(kin_cluster[above_mask]) if np.any(above_mask) else np.array([])
#     )

#     if clusters_above.size > 0:
#         # Order by median y (closest to discontinuity first)
#         clusters_above_ordered = sorted(
#             clusters_above, key=lambda c: cluster_median_y[int(c)], reverse=True
#         )

#         # Assign B1 to the cluster closest to the discontinuity (highest median y among above)
#         b_cl = clusters_above_ordered[0]
#         b_mask = (kin_cluster == b_cl) & above_mask
#         _assign(int(b_cl), b_mask, "B1")

#         # Remaining above clusters -> assign to C1, C2, ... (upper part / slowest)
#         for j, cl in enumerate(clusters_above_ordered[1:], start=1):
#             mask = (kin_cluster == cl) & above_mask
#             if not np.any(mask):
#                 continue
#             _assign(int(cl), mask, f"C{j}")
# else:
#     # No discontinuity found: use kinematic ordering to define A / B / C
#     if len(ordered_clusters) == 1:
#         # single cluster -> all A1
#         _assign(int(ordered_clusters[0]), np.ones_like(kin_cluster, dtype=bool), "A1")
#     elif len(ordered_clusters) == 2:
#         # bottom -> A1, top -> B1
#         _assign(int(ordered_clusters[0]), kin_cluster == ordered_clusters[0], "A1")
#         _assign(int(ordered_clusters[1]), kin_cluster == ordered_clusters[1], "B1")
#     else:
#         # >=3 clusters: bottom->A1, next->B1, rest -> C1,C2...
#         _assign(int(ordered_clusters[0]), kin_cluster == ordered_clusters[0], "A1")
#         _assign(int(ordered_clusters[1]), kin_cluster == ordered_clusters[1], "B1")
#         for j, cl in enumerate(ordered_clusters[2:], start=1):
#             _assign(int(cl), kin_cluster == cl, f"C{j}")

# # Summary counts
# unique_mk, counts = np.unique(mk_label_str[mk_label_str != ""], return_counts=True)
# print("Morpho-kinematic assignment summary:")
# for u, c in zip(unique_mk, counts, strict=False):
#     print(f"  {u}: {c} points")

# # Plot the assignment for quick inspection
# fig, ax = plt.subplots(figsize=(8, 8))
# if img is not None:
#     ax.imshow(img, alpha=0.5, cmap="gray")

# # Create color map for labels
# labels_present = list(unique_mk)
# cmap = plt.get_cmap("tab10")
# colors = {lab: cmap(i % 10) for i, lab in enumerate(labels_present)}

# # Plot points colored by mk label
# for lab in labels_present:
#     mask = mk_label_str == lab
#     ax.scatter(x[mask], y[mask], color=colors[lab], label=lab, s=10, alpha=0.8)

# # Plot kinematic cluster boundaries (optional: scatter of kinematic clusters)
# # Also overlay discontinuities if present
# if first_discont_y is not None:
#     ax.axhline(
#         first_discont_y,
#         color="red",
#         linestyle="--",
#         linewidth=2,
#         label="first discontinuity",
#     )

# ax.legend(loc="upper right", fontsize=9)
# ax.set_title("Morpho-Kinematic Assignment")
# ax.set_aspect("equal")
# plt.show()

In [None]:
# # --- Morpho-kinematic assignment ---

# kinematics_cluster = point_labels_cleaned

# # Ensure arrays are numpy arrays
# x = np.asarray(x)
# y = np.asarray(y)
# kin_cluster = np.asarray(kinematics_cluster)

# # Remove non classified points (-1 label)
# valid_mask = kin_cluster >= 0
# x = x[valid_mask]
# y = y[valid_mask]
# kin_cluster = kin_cluster[valid_mask]

# # Compute per-cluster median y (to order clusters from bottom -> top)
# clusters = np.unique(kin_cluster)
# cluster_median_y = {int(c): float(np.median(y[kin_cluster == c])) for c in clusters}

# # Order clusters by median y descending (bottom = largest y first)
# ordered_clusters = sorted(
#     clusters, key=lambda c: cluster_median_y[int(c)], reverse=True
# )

# # Read first discontinuity (the one separating bottom A from the rest), if available
# first_discont_y = None
# clustered_disc = (
#     discontinuity_results.get("clustered", [])
#     if discontinuity_results is not None
#     else []
# )
# if clustered_disc:
#     # clustered_boundaries were sorted reverse=True in detection, so first is bottommost discontinuity
#     first_discont_y = float(clustered_disc[0]["position"])

# # Prepare output arrays (string labels like "A1","B1","C1", and numeric ids)
# mk_label_str = np.full_like(kin_cluster, "", dtype=object)
# mk_label_id = -1 * np.ones_like(kin_cluster, dtype=int)

# # Helper to assign a label for a given cluster and mask
# next_mk_id = 0


# def _assign(cluster_val, mask, label_text):
#     global next_mk_id
#     mk_label_str[mask] = label_text
#     mk_label_id[mask] = next_mk_id
#     next_mk_id += 1


# # If we have a discontinuity, use it to split bottom vs above
# if first_discont_y is not None:
#     # Bottom region: y >= first_discont_y (image coords: large y -> bottom)
#     bottom_mask = y >= first_discont_y
#     clusters_in_bottom = (
#         np.unique(kin_cluster[bottom_mask]) if np.any(bottom_mask) else np.array([])
#     )

#     # Order them bottom->up (by median y) and assign A1, A2, ...
#     clusters_in_bottom_ordered = sorted(
#         clusters_in_bottom, key=lambda c: cluster_median_y[int(c)], reverse=True
#     )
#     for i, cl in enumerate(clusters_in_bottom_ordered, start=1):
#         mask = (kin_cluster == cl) & bottom_mask
#         if not np.any(mask):
#             continue
#         _assign(int(cl), mask, f"A{i}")

#     # Above region: y < first_discont_y
#     above_mask = y < first_discont_y
#     clusters_above = (
#         np.unique(kin_cluster[above_mask]) if np.any(above_mask) else np.array([])
#     )

#     if clusters_above.size > 0:
#         # Order by median y (closest to discontinuity first)
#         clusters_above_ordered = sorted(
#             clusters_above, key=lambda c: cluster_median_y[int(c)], reverse=True
#         )

#         # Assign B1 to the cluster closest to the discontinuity (highest median y among above)
#         b_cl = clusters_above_ordered[0]
#         b_mask = (kin_cluster == b_cl) & above_mask
#         _assign(int(b_cl), b_mask, "B1")

#         # Remaining above clusters -> assign to C1, C2, ... (upper part / slowest)
#         for j, cl in enumerate(clusters_above_ordered[1:], start=1):
#             mask = (kin_cluster == cl) & above_mask
#             if not np.any(mask):
#                 continue
#             _assign(int(cl), mask, f"C{j}")
# else:
#     # No discontinuity found: use kinematic ordering to define A / B / C
#     if len(ordered_clusters) == 1:
#         # single cluster -> all A1
#         _assign(int(ordered_clusters[0]), np.ones_like(kin_cluster, dtype=bool), "A1")
#     elif len(ordered_clusters) == 2:
#         # bottom -> A1, top -> B1
#         _assign(int(ordered_clusters[0]), kin_cluster == ordered_clusters[0], "A1")
#         _assign(int(ordered_clusters[1]), kin_cluster == ordered_clusters[1], "B1")
#     else:
#         # >=3 clusters: bottom->A1, next->B1, rest -> C1,C2...
#         _assign(int(ordered_clusters[0]), kin_cluster == ordered_clusters[0], "A1")
#         _assign(int(ordered_clusters[1]), kin_cluster == ordered_clusters[1], "B1")
#         for j, cl in enumerate(ordered_clusters[2:], start=1):
#             _assign(int(cl), kin_cluster == cl, f"C{j}")

# # Summary counts
# unique_mk, counts = np.unique(mk_label_str[mk_label_str != ""], return_counts=True)
# print("Morpho-kinematic assignment summary:")
# for u, c in zip(unique_mk, counts, strict=False):
#     print(f"  {u}: {c} points")


# # Plot the assignment for quick inspection
# fig, ax = plt.subplots(figsize=(8, 8))
# if img is not None:
#     ax.imshow(img, alpha=0.5, cmap="gray")

# # Create color map for labels
# labels_present = list(unique_mk)
# cmap = plt.get_cmap("tab10")
# colors = {lab: cmap(i % 10) for i, lab in enumerate(labels_present)}

# # Plot points colored by mk label
# for lab in labels_present:
#     mask = mk_label_str == lab
#     ax.scatter(x[mask], y[mask], color=colors[lab], label=lab, s=10, alpha=0.8)

# # Plot kinematic cluster boundaries (optional: scatter of kinematic clusters)
# # Also overlay discontinuities if present
# if first_discont_y is not None:
#     ax.axhline(
#         first_discont_y,
#         color="red",
#         linestyle="--",
#         linewidth=2,
#         label="first discontinuity",
#     )

# ax.legend(loc="upper right", fontsize=9)
# ax.set_title("Morpho-Kinematic Assignment")
# ax.set_aspect("equal")
# plt.show()