## CLASSIFY NEW DATA


In [1]:
import glob
import logging
from pathlib import Path

import arviz as az
import joblib
import numpy as np
import pandas as pd
import pymc as pm
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler
from sqlalchemy import create_engine

from src.clustering import preproc_features
from src.config import ConfigManager
from src.database import (
    get_dic_analysis_by_ids,
    get_dic_analysis_ids,
    get_dic_data,
    get_image,
    get_multi_dic_data,
)
from src.preprocessing import apply_dic_filters, spatial_subsample
from src.roi import PolygonROISelector
import numpy as np
from scipy.stats import norm

%matplotlib widget
# Use the ArviZ dark-grid theme for every figure in this notebook.
az.style.use("arviz-darkgrid")

# Fixed seed so that anything drawing from `rng` is reproducible.
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)

# `force=True` removes handlers already attached to the root logger before
# configuring it. Without it, basicConfig() silently does nothing whenever a
# handler already exists (installed by the kernel, or by a previous run of
# this cell), and the INFO messages logged further down would not be shown
# with this level/format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    force=True,
)

# Load configuration
config = ConfigManager()
db_engine = create_engine(config.db_url)

In [2]:
# --- Which acquisition to classify ------------------------------------------
target_date = "2024-08-23"
camera_name = "PPCX_Tele"

# --- Tunables ---------------------------------------------------------------
SUBSAMPLE_FACTOR = 2  # keep every n-th point
SUBSAMPLE_METHOD = "regular"  # alternatives: 'random', 'stratified'
PRIOR_STRENGTH = 0.4  # in [0, 1]; passed as `prior_strength` when priors are built

# --- Artifacts written by the model-fitting run -----------------------------
output_dir = Path("output") / f"{camera_name}_PyMC"
base_name = f"{camera_name}_{target_date}_PyMC_GMM"
posterior_path = output_dir / f"{base_name}_pooled_posterior.idata.nc"
scaler_path = output_dir / f"{base_name}_scaler.joblib"

# Posterior draws of the fitted mixture model.
idata = az.from_netcdf(posterior_path)
# NOTE: joblib.load unpickles the file, so only load scalers from a trusted source.
scaler = joblib.load(scaler_path)

In [None]:
# Get DIC data: look up the analyses stored for this camera and reference date.
dic_ids = get_dic_analysis_ids(
    db_engine, camera_name=camera_name, reference_date=target_date
)
if len(dic_ids) == 0:
    raise ValueError("No DIC analyses found for the given criteria")
elif len(dic_ids) > 1:
    logging.warning(
        "Multiple DIC analyses found for the given criteria. Using the first one."
    )

# Query metadata for the first analysis only. Reading row 0 of a query over
# *all* ids would only match `dic_ids[0]` if the result happened to come back
# in the same order, so the background image could belong to a different
# analysis than the displacement data loaded below.
dic_analyses = get_dic_analysis_by_ids(db_engine=db_engine, dic_ids=dic_ids[:1])
master_image_id = dic_analyses["master_image_id"].iloc[0]
img = get_image(master_image_id, camera_name=camera_name)

# Stage-named intermediates keep the lineage explicit; `df` is the final,
# fully filtered frame that the following cells consume.
dic_raw = get_dic_data(dic_ids[0])
dic_filtered = apply_dic_filters(
    dic_raw,
    filter_outliers=config.get("dic.filter_outliers"),
    tails_percentile=config.get("dic.tails_percentile"),
)

# Apply ROI filter
selector = PolygonROISelector.from_file(config.get("data.roi_path"))
df = selector.filter_dataframe(dic_filtered, x_col="x", y_col="y")
logging.info(f"Data shape after filtering: {df.shape}")

In [None]:
def assign_spatial_priors_dic(df, selectors, prior_strength=0.8):
    """Assign spatial prior probabilities based on polygon sectors.

    Every point starts with a uniform prior over the ``k = len(selectors)``
    clusters. A point inside sector ``idx`` gets ``prior_strength`` for
    cluster ``idx`` and shares the remaining mass equally among the other
    clusters. If a point lies inside several sectors, the last matching
    sector wins.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the point coordinates in columns "x" and "y".
    selectors : sequence
        Objects exposing ``contains_points(x, y)`` that return a boolean mask
        with one entry per point; one selector per cluster.
    prior_strength : float, default 0.8
        Prior probability of cluster ``idx`` for points inside sector ``idx``.
        Must lie in [0, 1].

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(df), k)`` whose rows sum to 1 (shape
        ``(len(df), 0)`` when no selectors are given).

    Raises
    ------
    ValueError
        If ``prior_strength`` is outside [0, 1].
    """
    if not 0.0 <= prior_strength <= 1.0:
        raise ValueError(
            f"prior_strength must be between 0 and 1, got {prior_strength!r}"
        )

    ndata = len(df)
    k = len(selectors)
    if k == 0:
        # Nothing to assign; keep the (ndata, 0) shape callers would index into.
        return np.empty((ndata, 0))

    prior_probs = np.full((ndata, k), 1.0 / k)  # default uniform

    if k == 1:
        # A single cluster must carry all the mass. The general formula would
        # divide by k - 1 == 0 here.
        inside_prob, outside_prob = 1.0, 0.0
    else:
        inside_prob = prior_strength
        outside_prob = (1.0 - prior_strength) / (k - 1)

    # Extract the coordinates once instead of once per sector.
    x = df["x"].to_numpy()
    y = df["y"].to_numpy()

    for idx, selector in enumerate(selectors):
        mask = np.asarray(selector.contains_points(x, y), dtype=bool)
        prior_probs[mask] = outside_prob  # small prob for other clusters
        prior_probs[mask, idx] = inside_prob  # high prob for this cluster
        logging.info(f"Sector {idx}: {mask.sum()} points with strong prior")

    return prior_probs


# Prepare new data: build the feature matrix and scale it with the scaler
# that was saved together with the fitted model.
variables_names = config.get("clustering.variables_names")
df_features = preproc_features(df)
X = df_features[variables_names].to_numpy()
X_scaled = scaler.transform(X)


# Assign spatial priors
sector_files = sorted(glob.glob(config.get("data.sector_prior_pattern")))
if not sector_files:
    # An empty glob would otherwise yield k = 0 and an unusable (n, 0) prior.
    raise ValueError(
        "No sector prior polygons matched the 'data.sector_prior_pattern' setting"
    )
sector_selectors = [PolygonROISelector.from_file(f) for f in sector_files]
k = len(sector_selectors)  # number of clusters = number of sectors
logging.info(f"Found {k} sector polygons for priors")

# Build the priors from `df_features` (not `df`) so that the rows of
# `prior_probs` are guaranteed to line up with the rows of `X_scaled`, even if
# preprocessing drops or reorders rows.
# NOTE(review): relies on `df_features` still carrying the "x"/"y" columns,
# which the plotting code further down also reads from it.
prior_probs = assign_spatial_priors_dic(
    df_features, sector_selectors, prior_strength=PRIOR_STRENGTH
)

In [None]:
# Extract posterior samples
mu_samples = idata.posterior["μ"].values  # (chains, draws, k, n_features)
sigma_samples = idata.posterior["σ"].values

# Merge the chain and draw axes into a single sample axis.
mu_flat = mu_samples.reshape(-1, *mu_samples.shape[-2:])  # (n_samples, k, n_features)
sigma_flat = sigma_samples.reshape(-1, *sigma_samples.shape[-2:])

n_samples = mu_flat.shape[0]
n_new = X_scaled.shape[0]
k = mu_flat.shape[1]  # from here on, k is the number of mixture components

# Fail early with a clear message instead of an IndexError (or silently
# ignored prior columns) inside the responsibility loop.
if X_scaled.shape[1] != mu_flat.shape[2]:
    raise ValueError(
        f"Feature mismatch: data has {X_scaled.shape[1]} features, "
        f"model has {mu_flat.shape[2]}"
    )
if prior_probs.shape != (n_new, k):
    raise ValueError(
        f"Prior shape {prior_probs.shape} does not match "
        f"(n_points, n_components) = {(n_new, k)}"
    )

In [None]:
# Compute assignment probabilities
# The log-prior does not depend on the posterior draw, so compute it once.
# The small constant avoids log(0) for zero prior probabilities.
log_prior = np.log(prior_probs + 1e-12)  # (n_new, k)
# Extra axis so one logpdf call broadcasts over all components of a draw.
X_by_component = X_scaled[:, None, :]  # (n_new, 1, n_features)

log_resp = np.empty((n_samples, n_new, k))
for s in range(n_samples):
    # Per-feature log densities for every point/component pair of draw `s`;
    # features are treated as independent, hence the sum over the last axis.
    lp = norm.logpdf(X_by_component, loc=mu_flat[s], scale=sigma_flat[s])
    log_resp[s] = log_prior + lp.sum(axis=-1)

# Softmax over components, shifted by the row maximum for numerical stability.
resp_samples = np.exp(log_resp - log_resp.max(axis=2, keepdims=True))
resp_samples /= resp_samples.sum(axis=2, keepdims=True)  # (n_samples, n_new, k)

# Average over posterior samples
posterior_probs = resp_samples.mean(axis=0)  # (n_new, k)
cluster_pred = posterior_probs.argmax(axis=1)

In [None]:
# For 1D velocity clustering
from matplotlib.colors import Normalize
from scipy.stats import norm as scipy_norm


def plot_1d_velocity_clustering(
    df_features,
    img,
    *,
    idata,
    scaler=None,
    cluster_pred=None,
):
    """Draw a 2x2 summary figure for a velocity-only (1D) clustering.

    Panels: cluster labels scattered over the image, per-cluster velocity
    histograms with the fitted normal densities overlaid, a quiver plot of
    the displacement field coloured by magnitude, and a text box with
    per-cluster statistics.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Must provide the columns "x", "y", "u", "v" and "V"; its rows must
        line up with ``cluster_pred``.
    img : array-like or None
        Background image for the two spatial panels; skipped when None.
    idata : InferenceData-like
        ``idata.posterior`` must contain "μ" and "σ" with "chain" and "draw"
        dimensions. "z" is only required when ``cluster_pred`` is None.
    scaler : fitted scaler or None
        When given, model means and standard deviations are mapped back to
        the original scale via ``inverse_transform`` and ``scale_[0]``.
        NOTE(review): this assumes velocity is the scaler's first and only
        feature — confirm before using with multi-feature scalers.
    cluster_pred : array-like of int or None
        Cluster label per row. Labels are also used as indices into the
        flattened posterior means of "μ"/"σ". When None, the most frequent
        value of each point's "z" samples is used.

    Returns
    -------
    tuple
        ``(cluster_pred, fig)``: the labels that were plotted and the figure.

    Raises
    ------
    ValueError
        If ``cluster_pred`` is None and "z" is missing from
        ``idata.posterior``.
    """
    # Use provided cluster assignments when available; otherwise take the
    # per-point mode of the discrete z samples.
    if cluster_pred is None:
        if "z" not in idata.posterior:
            raise ValueError(
                "cluster_pred not provided and 'z' not found in idata.posterior"
            )
        z_posterior = idata.posterior["z"]
        n_points = z_posterior.shape[-1]
        z_samples = z_posterior.values.reshape(-1, n_points)
        cluster_pred = np.zeros(n_points, dtype=int)
        for i in range(n_points):
            values, counts = np.unique(z_samples[:, i], return_counts=True)
            cluster_pred[i] = values[np.argmax(counts)]

    # Array view used for masking, so list-like labels work as well; the
    # caller's object is returned untouched.
    labels_arr = np.asarray(cluster_pred)

    # Posterior-mean model parameters, one entry per cluster
    mu_posterior = idata.posterior["μ"].mean(dim=["chain", "draw"]).values.flatten()
    sigma_posterior = idata.posterior["σ"].mean(dim=["chain", "draw"]).values.flatten()

    def _model_params(label):
        """Return (mean, std) of cluster `label`, unscaled if a scaler is given."""
        mu = mu_posterior[label]
        sigma = sigma_posterior[label]
        if scaler is not None:
            mu = scaler.inverse_transform([[mu]])[0, 0]
            sigma = sigma * scaler.scale_[0]
        return mu, sigma

    # Distinct colors; cycle through the palette so that more clusters than
    # palette entries do not raise an IndexError.
    unique_labels = np.unique(labels_arr)
    palette = ["#E31A1C", "#1F78B4", "#33A02C", "#FF7F00", "#6A3D9A"]
    color_map = {
        label: palette[i % len(palette)] for i, label in enumerate(unique_labels)
    }
    masks = {label: labels_arr == label for label in unique_labels}

    # Create figure with custom layout
    fig, axes = plt.subplots(2, 2, figsize=(8, 12))

    # Panel (0, 0): spatial clusters
    ax_spatial = axes[0, 0]
    ax_spatial.set_title("Velocity-Based Spatial Clustering", fontsize=14, pad=10)

    if img is not None:
        ax_spatial.imshow(img, alpha=0.3, cmap="gray")

    for label in unique_labels:
        mask = masks[label]
        if np.any(mask):
            ax_spatial.scatter(
                df_features.loc[mask, "x"],
                df_features.loc[mask, "y"],
                c=color_map[label],
                s=8,
                alpha=0.8,
                label=f"Velocity Cluster {label}",
                edgecolors="none",
            )
    ax_spatial.legend(loc="upper right", framealpha=0.9, fontsize=10)
    ax_spatial.set_aspect("equal")
    ax_spatial.set_xticks([])
    ax_spatial.set_yticks([])

    # Panel (0, 1): velocity distribution per cluster
    ax_hist = axes[0, 1]
    ax_hist.set_title("Velocity Distribution by Cluster", fontsize=14, pad=10)

    velocity = df_features["V"].values
    for label in unique_labels:
        mask = masks[label]
        if np.any(mask):
            ax_hist.hist(
                velocity[mask],
                bins=35,
                alpha=0.7,
                density=True,
                color=color_map[label],
                label=f"Cluster {label}",
                edgecolor="white",
                linewidth=0.5,
            )

    # Overlay the fitted normal density of each cluster
    v_range = np.linspace(velocity.min(), velocity.max(), 200)
    for label in unique_labels:
        mu_orig, sigma_orig = _model_params(label)
        model_dist = scipy_norm.pdf(v_range, mu_orig, sigma_orig)
        ax_hist.plot(
            v_range,
            model_dist,
            "--",
            color=color_map[label],
            linewidth=2.5,
            alpha=0.9,
            label=f"Model {label}",
        )
    ax_hist.set_xlabel("Velocity Magnitude", fontsize=12)
    ax_hist.set_ylabel("Density", fontsize=12)
    ax_hist.grid(True, alpha=0.3)
    ax_hist.legend(fontsize=10, framealpha=0.9)

    # Panel (1, 0): velocity field as a quiver plot
    ax_quiver = axes[1, 0]
    ax_quiver.set_title("Velocity Vector Field", fontsize=14, pad=10)

    if img is not None:
        ax_quiver.imshow(img, alpha=0.7, cmap="gray")

    magnitudes = df_features["V"].to_numpy()
    # Named `color_norm` so it does not shadow the scipy `norm` distribution
    # that the notebook imports at module level.
    color_norm = Normalize(vmin=0.0, vmax=np.max(magnitudes))
    q = ax_quiver.quiver(
        df_features["x"].to_numpy(),
        df_features["y"].to_numpy(),
        df_features["u"].to_numpy(),
        df_features["v"].to_numpy(),
        magnitudes,
        scale=None,
        scale_units="xy",
        angles="xy",
        cmap="viridis",
        norm=color_norm,
        width=0.003,
        headwidth=2.5,
        alpha=1.0,
    )

    # Add colorbar
    cbar = fig.colorbar(q, ax=ax_quiver, shrink=0.8, aspect=20, pad=0.02)
    cbar.set_label("Velocity Magnitude", rotation=270, labelpad=15)
    ax_quiver.set_aspect("equal")
    ax_quiver.set_xticks([])
    ax_quiver.set_yticks([])
    ax_quiver.grid(False)

    # Panel (1, 1): statistics text box
    ax_stats = axes[1, 1]
    ax_stats.axis("off")
    stats_text = "VELOCITY CLUSTERING STATISTICS\n" + "=" * 30 + "\n"
    for label in unique_labels:
        mask = masks[label]
        count = mask.sum()

        if count == 0:
            continue

        cluster_velocity = velocity[mask]
        v_mean = cluster_velocity.mean()
        v_std = cluster_velocity.std()
        v_median = np.median(cluster_velocity)
        # Normalised median absolute deviation (1.4826 scales the MAD so it
        # matches the standard deviation for normally distributed data).
        nmad = np.median(np.abs(cluster_velocity - v_median)) * 1.4826

        # Model parameters (in original scale)
        model_mu, model_sigma = _model_params(label)

        stats_text += f"VELOCITY CLUSTER {label} (pts: {count})\n"
        stats_text += f"├─ Velocity: {v_mean:.4f} ± {v_std:.4f}\n"
        stats_text += f"├─ Median/NMAD: {v_median:.4f}/{nmad:.4f}\n"
        stats_text += f"├─ Model μ/σ: {model_mu:.4f}/{model_sigma:.4f}\n\n"

    ax_stats.text(
        0.05,
        0.95,
        stats_text,
        transform=ax_stats.transAxes,
        fontsize=8,
        verticalalignment="top",
        fontfamily="monospace",
        bbox=dict(
            boxstyle="round,pad=0.4", facecolor="lightblue", alpha=0.8, edgecolor="navy"
        ),
    )

    return cluster_pred, fig


# Render the four-panel summary for the newly classified points. The labels
# in `cluster_pred` are passed explicitly, so the function does not need "z"
# samples in `idata`; `scaler` lets it report model parameters in the
# original velocity scale.
cluster_pred_1d, fig_1d = plot_1d_velocity_clustering(
    df_features, img, idata=idata, scaler=scaler, cluster_pred=cluster_pred
)
# Uncomment to save the figure alongside the other model outputs:
# fig_1d.savefig(
#     output_dir / f"{base_name}_velocity_only_results.png", dpi=300, bbox_inches="tight"
# )
