# In [None]:  (Jupyter cell marker left over from notebook export; commented out so the file is valid Python)
# cai_target_transform_hybrid.py
# Hybrid clustering target transformation for CAI_farm:
# Ward HAC (k=2) -> seeds KMeans (k=2) + silhouette + bootstrap stability + plots

from __future__ import annotations
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt

from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.metrics import silhouette_samples, silhouette_score
from sklearn.utils import resample

from scipy.cluster.hierarchy import linkage, dendrogram
from typing import Tuple, Dict

# ------------------------------- Config --------------------------------
# Data source: the CSV must provide the CAI score column named below.
INPUT = "data/processed/normalised_with_cai.csv"  # must contain column 'CAI_farm'
CAI_COL = "CAI_farm"                              # score column that gets binarised

# Reproducibility / clustering settings
RANDOM_SEED = 42
B_BOOT = 500        # number of bootstrap replicates for the per-row stability score
KM_MAX_ITER = 500   # KMeans max_iter
KM_TOL = 1.0e-4     # KMeans convergence tolerance

# Plot settings
DENDRO_SUBSAMPLE = 1200  # rows drawn at random for the dendrogram (keeps the linkage cheap)
CUT_HEIGHT = 1.0         # dendrogram colour threshold + dashed reference line (illustrative only)
HIST_BINS = 80           # number of histogram bins

# Output directory; created (with parents) at import time if it does not exist.
OUTDIR = Path("data") / "processed"
OUTDIR.mkdir(parents=True, exist_ok=True)

# --------------------------- Helper functions --------------------------
def ensure_binary_high_is_one(labels: np.ndarray, X: np.ndarray) -> np.ndarray:
    """
    Map 2 cluster labels to {0,1} so that label 1 corresponds to the HIGH-CAI cluster.

    Works for any pair of distinct integer labels (e.g. {0, 1}, {1, 2}, {-1, 3}):
    the label whose members have the larger mean of ``X`` becomes 1, the other 0.
    If both clusters have the same mean, the smaller label value becomes 1.

    Parameters
    ----------
    labels : array-like of shape (n,)
        Cluster assignment per observation; cast to int.
    X : np.ndarray of shape (n,) or (n, 1)
        Values used to rank the two clusters (the CAI score).

    Returns
    -------
    np.ndarray of shape (n,)
        0/1 labels where 1 marks the cluster with the higher mean of ``X``.

    Raises
    ------
    ValueError
        If ``labels`` does not contain exactly two distinct values.
    """
    labels = np.asarray(labels).astype(int)
    uniq = np.unique(labels)
    if len(uniq) != 2:
        raise ValueError("Expected exactly 2 clusters.")
    # Mean of X inside each cluster, in the same order as `uniq`; both clusters
    # are non-empty by construction because their labels come from np.unique.
    means = [X[labels == k].mean() for k in uniq]
    # Index `uniq` directly with the argmax of the aligned means. (Mapping label
    # values to positions by hand only works when the labels are exactly {0, 1}.)
    high_label = uniq[int(np.argmax(means))]
    # Recode: high -> 1, other -> 0
    return np.where(labels == high_label, 1, 0)

def fit_hac_get_centers(X: np.ndarray, random_state: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit Ward hierarchical clustering (k=2) on all of ``X`` and summarise it.

    Parameters
    ----------
    X : np.ndarray of shape (n, 1)
        Values to cluster (the CAI score as a column vector).
    random_state : int
        Unused. Ward agglomerative clustering has no random component. The
        parameter is kept so call sites can pass a seed uniformly.

    Returns
    -------
    labels01 : np.ndarray of shape (n,)
        Ward cluster labels recoded so that 1 = higher-mean (HIGH) cluster.
    centers : np.ndarray of shape (2, 1), float
        Cluster means in ascending order (row 0 = low, row 1 = high). This is
        suitable as an explicit ``init`` for KMeans.

    NOTE(review): sklearn's agglomerative clustering is expensive for large n
    (roughly quadratic). This helper is also called once per bootstrap replicate.
    """
    # Ward linkage only supports Euclidean distance, which is the default metric.
    # Do not pass `affinity=`: it was deprecated in scikit-learn 1.2 and removed
    # in 1.4 (renamed to `metric`), where passing it raises TypeError.
    hac = AgglomerativeClustering(n_clusters=2, linkage="ward")
    hac_labels = hac.fit_predict(X)
    # Map to 0/1 with 1 being 'high'
    labels01 = ensure_binary_high_is_one(hac_labels, X.ravel())
    # Compute centers (means) by cluster
    c0 = X[labels01 == 0].mean()
    c1 = X[labels01 == 1].mean()
    centers = np.array([[min(c0, c1)], [max(c0, c1)]], dtype=float)
    return labels01, centers

def fit_kmeans_with_init(X: np.ndarray, init_centers: np.ndarray, random_state: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Run a single 2-cluster KMeans from fixed starting centroids, then relabel.

    Parameters
    ----------
    X : np.ndarray of shape (n, 1)
        Values to cluster.
    init_centers : np.ndarray of shape (2, 1)
        Starting centroids, used as-is (one initialisation only).
    random_state : int
        Forwarded to ``KMeans``.

    Returns
    -------
    labels01 : np.ndarray of shape (n,)
        KMeans labels remapped so 0 = cluster with the smaller centroid and
        1 = cluster with the larger centroid. The dtype is that of ``KMeans.labels_``.
    centers : np.ndarray of shape (2, 1)
        Fitted centroids in ascending order.
    """
    model = KMeans(
        n_clusters=2,
        init=init_centers,
        n_init=1,  # one run, started exactly from init_centers
        max_iter=KM_MAX_ITER,
        tol=KM_TOL,
        random_state=random_state,
    )
    model.fit(X)
    raw_labels = model.labels_
    flat_centers = model.cluster_centers_.reshape(-1)

    # ascending[i] is the KMeans label of the i-th smallest centroid ...
    ascending = np.argsort(flat_centers)
    # ... so its inverse permutation gives, per KMeans label, its rank (0=low, 1=high).
    rank_of_label = np.argsort(ascending)
    relabelled = rank_of_label[raw_labels].astype(raw_labels.dtype)

    return relabelled, flat_centers[ascending].reshape(-1, 1)

def silhouettes(X: np.ndarray, labels01: np.ndarray) -> Tuple[float, float]:
    """
    Compute mean and median per-sample silhouette scores for a labelling of ``X``.

    The silhouette is only defined when 2 <= n_labels <= n_samples - 1;
    sklearn raises ValueError outside that range. In those cases this function
    returns ``(nan, nan)`` instead: for example a single cluster, or n = 2
    points in two clusters.

    Parameters
    ----------
    X : np.ndarray of shape (n, 1)
        Clustered values.
    labels01 : array-like of shape (n,)
        Cluster label per observation.

    Returns
    -------
    (mean, median) : tuple of float
        NaN-ignoring mean and median of the per-sample silhouettes.

    NOTE(review): the silhouette is pairwise-distance based, so its cost grows
    roughly quadratically with n.
    """
    labels = np.asarray(labels01)
    n_labels = np.unique(labels).size
    if not 2 <= n_labels <= labels.shape[0] - 1:
        return float("nan"), float("nan")
    sil = silhouette_samples(X, labels, metric="euclidean")
    return float(np.nanmean(sil)), float(np.nanmedian(sil))

def bootstrap_stability(
    X: np.ndarray,
    labels_ref: np.ndarray,
    random_state: int,
    B: int = 500
) -> np.ndarray:
    """
    Per-observation bootstrap stability of the HAC-seeded KMeans labelling.

    For each of ``B`` bootstrap replicates (n rows of ``X`` drawn with replacement):
      1. fit Ward HAC (k=2) on the replicate to obtain seed centroids;
      2. fit KMeans (k=2) on the replicate starting from those seeds;
      3. assign ALL original rows of ``X`` with ``km.predict``, relabelled so
         that 1 = the cluster with the larger centroid;
      4. count, per row, whether that label equals ``labels_ref``.

    Parameters
    ----------
    X : np.ndarray of shape (n, 1)
        Values that were clustered.
    labels_ref : array-like of n 0/1 labels
        Reference labelling with 1 = HIGH, e.g. the full-data KMeans labels.
        Shapes (n,) and (n, 1) are both accepted.
    random_state : int
        Seed for the bootstrap index draws. ``random_state + b + 1`` is also
        passed to the per-replicate fitters.
    B : int
        Number of bootstrap replicates; must be >= 1.

    Returns
    -------
    np.ndarray of shape (n,)
        Fraction of replicates (in [0, 1]) in which each observation received
        the same label as in ``labels_ref``.

    Raises
    ------
    ValueError
        If ``B < 1`` (would otherwise divide by zero), or if ``labels_ref`` does
        not have one entry per row of ``X``.
    """
    if B < 1:
        raise ValueError(f"B must be >= 1, got {B}")
    n = X.shape[0]
    # Flatten so an (n, 1) reference does not broadcast against the (n,) predictions.
    labels_ref = np.asarray(labels_ref).ravel()
    if labels_ref.shape[0] != n:
        raise ValueError(
            f"labels_ref has {labels_ref.shape[0]} entries but X has {n} rows"
        )

    rng = np.random.RandomState(random_state)
    matches = np.zeros(n, dtype=float)

    for b in range(B):
        seed_b = random_state + b + 1
        idx = rng.randint(0, n, size=n)  # bootstrap indices
        Xb = X[idx]
        # HAC on the replicate seeds KMeans (same pipeline as on the full data)
        _, centers_b = fit_hac_get_centers(Xb, seed_b)
        km_b = KMeans(n_clusters=2, init=centers_b, n_init=1, max_iter=KM_MAX_ITER,
                      tol=KM_TOL, random_state=seed_b)
        km_b.fit(Xb)
        # Assign all original points with the replicate's centroids
        pred_b = km_b.predict(X)
        # order[1] is the KMeans label of the larger centroid -> HIGH (1); the other -> 0
        order = np.argsort(km_b.cluster_centers_.reshape(-1))
        pred_b_aligned = (pred_b == order[1]).astype(int)
        matches += pred_b_aligned == labels_ref

    return matches / B

# --------------------------------- Main --------------------------------
if __name__ == "__main__":
    # 1) Load data with CAI_farm
    df = pd.read_csv(INPUT)
    if CAI_COL not in df.columns:
        raise KeyError(f"Column '{CAI_COL}' not found in {INPUT}")

    X = df[CAI_COL].astype(float).values.reshape(-1, 1)
    n = X.shape[0]
    # The clustering steps below cannot handle missing values; fail early with
    # an actionable message instead of an error from deep inside the fitters.
    n_missing = int(np.isnan(X).sum())
    if n_missing:
        raise ValueError(
            f"Column '{CAI_COL}' contains {n_missing} missing value(s); "
            "drop or impute them before clustering"
        )
    print(f"Loaded {n:,} rows with CAI column: {CAI_COL}")

    # 2) Ward HAC (k=2) on full data -> labels + centers
    hac_labels01, hac_centers = fit_hac_get_centers(X, RANDOM_SEED)
    sil_hac_mean, sil_hac_median = silhouettes(X, hac_labels01)
    print(f"Ward HAC (k=2) Silhouette: mean={sil_hac_mean:.3f}, median={sil_hac_median:.3f}")

    # 3) KMeans (k=2) seeded by HAC centers
    km_labels01, km_centers = fit_kmeans_with_init(X, hac_centers, RANDOM_SEED)
    sil_km_mean, sil_km_median = silhouettes(X, km_labels01)
    print(f"KMeans (k=2, HAC-seeded) Silhouette: mean={sil_km_mean:.3f}, median={sil_km_median:.3f}")

    # 4) Bootstrap stability
    print(f"Bootstrapping stability with B={B_BOOT} ...")
    stability = bootstrap_stability(X, km_labels01, RANDOM_SEED, B=B_BOOT)
    print(f"Stability: median={np.median(stability):.3f}, mean={np.mean(stability):.3f}")

    # 5) Silhouette table
    sil_table = pd.DataFrame([
        {"Method": "K-means (k=2)", "Silhouette (mean)": round(sil_km_mean, 3),
         "Silhouette (median)": round(sil_km_median, 3), "n": n},
        {"Method": "Hierarchical/Ward (k=2)", "Silhouette (mean)": round(sil_hac_mean, 3),
         "Silhouette (median)": round(sil_hac_median, 3), "n": n},
    ])
    sil_table.to_csv(OUTDIR / "silhouette_summary.csv", index=False)

    # 6) Attach cluster outputs to dataframe
    # 0 = Low, 1 = High (enforced by fit_kmeans_with_init); add names & centroids
    centroid_low, centroid_high = float(km_centers[0, 0]), float(km_centers[1, 0])
    # Column name must match the one read on the next line (and the output file name).
    df["binary_cai"] = km_labels01
    df["cluster_name"] = np.where(df["binary_cai"] == 1, "High", "Low")
    df["stability_score"] = stability
    df["centroid_low"] = centroid_low
    df["centroid_high"] = centroid_high

    df.to_csv(OUTDIR / "normalised_with_binary_cai.csv", index=False)

    # 7) Dendrogram plot (on a random subsample for readability and cost)
    n_sub = min(DENDRO_SUBSAMPLE, n)
    rng = np.random.RandomState(RANDOM_SEED)
    idx_sub = rng.choice(n, size=n_sub, replace=False)
    X_sub = X[idx_sub, :]
    Z = linkage(X_sub, method="ward", metric="euclidean")
    plt.figure(figsize=(7, 5))
    dendrogram(Z, no_labels=True, color_threshold=CUT_HEIGHT, truncate_mode="lastp", p=30)
    plt.axhline(CUT_HEIGHT, ls="--", color="tab:blue", alpha=0.7)
    plt.title("HAC (Ward) on Composite CAI — sample (cross-sectional)")
    plt.xlabel("Merged clusters (truncated)")
    plt.ylabel("Distance")
    plt.tight_layout()
    plt.savefig(OUTDIR / "hac_dendrogram_sample.png", dpi=200)
    plt.close()

    # 8) Histogram plot of K-means clusters with centroid markers
    plt.figure(figsize=(7.6, 4.6))
    mask_low = (km_labels01 == 0)
    mask_high = (km_labels01 == 1)
    # One shared set of bin edges over the full CAI range. Otherwise each cluster
    # gets its own HIST_BINS bins of different widths, and the bar heights are
    # not comparable.
    bin_edges = np.histogram_bin_edges(X, bins=HIST_BINS)
    plt.hist(X[mask_low].ravel(), bins=bin_edges, alpha=0.7, color="tab:blue",
             label="Cluster: Low")
    plt.hist(X[mask_high].ravel(), bins=bin_edges, alpha=0.7, color="tab:orange",
             label="Cluster: High")
    # centroid markers, colour-matched to their cluster
    plt.axvline(centroid_low, ls="--", alpha=0.9, color="tab:blue")
    plt.axvline(centroid_high, ls="--", alpha=0.9, color="tab:orange")
    plt.title("Composite CAI — clusters via HAC→KMeans (cross-sectional)")
    plt.xlabel("CAI")
    plt.ylabel("Count")
    plt.legend()
    plt.tight_layout()
    plt.savefig(OUTDIR / "kmeans_histogram.png", dpi=200)
    plt.close()

    # 9) Console summary
    print("\nSilhouette summary:")
    print(sil_table.to_string(index=False))
    print(f"\nCentroids (KMeans): low={centroid_low:.4f}, high={centroid_high:.4f}")
    print(f"Artifacts saved in: {OUTDIR.resolve()}")