Dependencies
-----------------

In [None]:
# supcon_audio.py  — updated

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchaudio
from sklearn.cluster import KMeans
from sklearn.metrics import (
    adjusted_rand_score,
    adjusted_mutual_info_score,
    silhouette_score,
    davies_bouldin_score,
    calinski_harabasz_score,
)
from scipy.optimize import linear_sum_assignment
from torchvision.models import resnet18
from typing import Optional, Dict, Any, Sequence, Tuple

# Safe UMAP import (avoid AttributeError: module 'umap' has no attribute 'UMAP')
import umap.umap_ as umap

Training
--------------

In [23]:
# ----------------------------
# Dataset with Dual Views
# ----------------------------
class SpectrogramDataset(Dataset):
    """Audio clips stored as ``<int>.npy`` files, served as two log-mel views.

    Each item is ``(view_1, view_2, label)``. Both views come from the same
    log-mel spectrogram; when ``augment_fn`` is set it is applied
    independently to a copy for each view, otherwise both views are the
    un-augmented spectrogram.

    Args:
        audio_dir: Directory holding the numerically named ``.npy`` clips
            (``labels.npy`` in the same directory is ignored as audio).
        label_path: Path to the ``.npy`` array of labels, one per clip.
        sr: Sample rate handed to the mel transform.
        n_mels: Number of mel bins.
        augment_fn: Optional callable applied to a cloned log-mel tensor.
    """

    def __init__(self, audio_dir, label_path, sr=10000, n_mels=64, augment_fn=None):
        clip_names = [
            name for name in os.listdir(audio_dir)
            if name.endswith(".npy") and name != "labels.npy"
        ]
        # Numeric ordering ("2.npy" before "10.npy") keeps clip i aligned with labels[i].
        clip_names.sort(key=lambda name: int(os.path.splitext(name)[0]))
        self.audio_paths = [os.path.join(audio_dir, name) for name in clip_names]
        self.labels = np.load(label_path)
        assert len(self.audio_paths) == len(self.labels), "Mismatch: number of audio files and labels"
        self.sr = sr
        self.n_mels = n_mels
        self.augment_fn = augment_fn
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=1024, hop_length=512, n_mels=n_mels
        )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        waveform = np.load(self.audio_paths[idx]).astype(np.float32)
        # Peak-normalise; the epsilon guards against an all-zero clip.
        peak = np.max(np.abs(waveform)) + 1e-8
        waveform = waveform / peak
        mel = self.mel_transform(torch.tensor(waveform).unsqueeze(0))
        logmel = torch.log(mel + 1e-8)
        label = int(self.labels[idx])
        if not self.augment_fn:
            # No augmentation: both views are the very same tensor.
            return logmel, logmel, label
        view_1 = self.augment_fn(logmel.clone())
        view_2 = self.augment_fn(logmel.clone())
        return view_1, view_2, label

# ----------------------------
# Encoder with ResNet18
# ----------------------------
class ConvSupConEncoder(nn.Module):
    """ResNet-18 backbone plus an MLP projection head for contrastive training.

    The backbone's first convolution is replaced so it accepts 1-channel
    spectrograms, and its classification layer is replaced by an identity so
    it emits the 512-d pooled feature. The projector maps that feature to
    ``output_dim`` and the result is L2-normalised.

    Args:
        output_dim: Size of the final (normalised) embedding.
    """

    def __init__(self, output_dim=64):
        super().__init__()
        # `weights=None` (random init) replaces the deprecated `pretrained=False`,
        # which emits a deprecation warning on torchvision >= 0.13.
        self.backbone = resnet18(weights=None)
        # Single input channel instead of RGB.
        self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.backbone.fc = nn.Identity()
        self.projector = nn.Sequential(
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        """Return unit-norm embeddings for a batch of spectrograms."""
        x = self.backbone(x)
        x = self.projector(x)
        return F.normalize(x, dim=1)

# ----------------------------
# SupCon Loss
# ----------------------------
class SupConLoss(nn.Module):
    """Supervised contrastive loss over a batch of L2-normalised embeddings.

    For every anchor, all other samples with the same label are positives and
    every other sample in the batch forms the softmax denominator. The loss is
    the negative mean log-probability of the positives, averaged over anchors
    (anchors without any positive contribute zero).

    Args:
        temperature: Softmax temperature applied to the cosine similarities.
    """

    def __init__(self, temperature=0.03):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: ``(N, D)`` embeddings, expected to be unit-norm.
            labels: ``N`` class labels (any shape that flattens to ``N``).

        Returns:
            Scalar loss tensor.
        """
        batch_size = features.shape[0]
        if batch_size < 2:
            # No contrast pair exists; return a zero that stays on the graph.
            return features.sum() * 0.0

        device = features.device
        labels = labels.view(-1, 1)
        self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
        positive_mask = torch.eq(labels, labels.T).to(device) & ~self_mask

        contrast = torch.matmul(features, features.T)
        if contrast.dtype in (torch.float16, torch.bfloat16):
            # Under autocast the matmul is half precision; do the softmax in fp32.
            contrast = contrast.float()
        contrast = contrast / self.temperature

        # Log of the denominator over all *other* samples. Masking the
        # self-similarity before logsumexp matters: the diagonal is always the
        # row maximum (1/T), so stabilising with it pushes every other term to
        # ~exp(-1/T) and lets an additive epsilon swamp the denominator at low
        # temperatures.
        log_denominator = torch.logsumexp(
            contrast.masked_fill(self_mask, float("-inf")), dim=1, keepdim=True
        )
        log_prob = contrast - log_denominator

        positives = positive_mask.to(log_prob.dtype)
        # clamp avoids 0/0 for anchors that have no positive in the batch.
        mean_log_prob_pos = (positives * log_prob).sum(1) / positives.sum(1).clamp(min=1.0)
        return -mean_log_prob_pos.mean()

# ----------------------------
# Eval helpers
# ----------------------------
def hungarian_accuracy(y_true, y_pred):
    """Clustering accuracy under the best one-to-one cluster/class matching.

    Builds the contingency table between predicted clusters and true classes
    and solves the assignment problem that maximises the number of agreeing
    samples.

    Args:
        y_true: Ground-truth labels (any hashable/sortable dtype).
        y_pred: Predicted cluster ids, same length as ``y_true``.

    Returns:
        Fraction of samples covered by the optimal matching, as a float.

    Raises:
        ValueError: If the inputs are empty or differ in length.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError(
            f"y_true and y_pred must have the same length, got {y_true.shape[0]} and {y_pred.shape[0]}"
        )
    if y_pred.shape[0] == 0:
        raise ValueError("hungarian_accuracy requires at least one sample")

    # Re-encode labels as 0..K-1 so negative, sparse or non-integer labels
    # cannot alias or overflow the contingency table.
    _, true_codes = np.unique(y_true, return_inverse=True)
    _, pred_codes = np.unique(y_pred, return_inverse=True)

    w = np.zeros((pred_codes.max() + 1, true_codes.max() + 1), dtype=np.int64)
    np.add.at(w, (pred_codes, true_codes), 1)

    # Rectangular tables are fine: unmatched rows/columns simply add nothing.
    row_ind, col_ind = linear_sum_assignment(w, maximize=True)
    return float(w[row_ind, col_ind].sum()) / len(y_pred)

def _internal_indices(X: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
    """Silhouette / DB / CH with guards."""
    out = {"silhouette": float("nan"),
           "davies_bouldin": float("nan"),
           "calinski_harabasz": float("nan")}
    uniq = np.unique(labels)
    if len(uniq) <= 1 or len(uniq) >= len(labels):
        return out
    try:
        out["silhouette"] = float(silhouette_score(X, labels))
    except Exception:
        pass
    try:
        out["davies_bouldin"] = float(davies_bouldin_score(X, labels))
    except Exception:
        pass
    try:
        out["calinski_harabasz"] = float(calinski_harabasz_score(X, labels))
    except Exception:
        pass
    return out

def _kmeans_metrics(emb: np.ndarray, y_true: np.ndarray, kmeans_k: Optional[int] = None, random_state: int = 42) -> Dict[str, float]:
    """Cluster ``emb`` with KMeans and score the result.

    Args:
        emb: Embedding matrix, one row per sample.
        y_true: Ground-truth labels, one per row of ``emb``.
        kmeans_k: Number of clusters; defaults to the number of distinct
            values in ``y_true``.
        random_state: Seed passed to KMeans.

    Returns:
        Dict with external scores (``ari``, ``ami``, ``hungarian_accuracy``),
        internal scores (``silhouette``, ``davies_bouldin``,
        ``calinski_harabasz``) and the ``kmeans_k`` actually used.
    """
    if kmeans_k is None:
        n_clusters = len(np.unique(y_true))
    else:
        n_clusters = int(kmeans_k)

    clusterer = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state)
    preds = clusterer.fit_predict(emb)

    report = {
        "ari": float(adjusted_rand_score(y_true, preds)),
        "ami": float(adjusted_mutual_info_score(y_true, preds)),
        "hungarian_accuracy": float(hungarian_accuracy(y_true, preds)),
    }
    # Internal indices are computed against the predicted partition.
    internal = _internal_indices(emb, preds)
    for key in ("silhouette", "davies_bouldin", "calinski_harabasz"):
        report[key] = float(internal[key])
    report["kmeans_k"] = int(n_clusters)
    return report

def evaluate_embeddings(encoder, loader, device, *, kmeans_k: Optional[int] = None, random_state: int = 42):
    """Embed every batch of ``loader`` and return KMeans clustering metrics.

    The encoder is switched to eval mode (and left there). Only the first
    view of each ``(view_1, view_2, labels)`` batch is embedded.
    """
    encoder.eval()
    embedding_batches = []
    label_batches = []
    with torch.no_grad():
        for first_view, _, batch_labels in loader:
            batch_emb = encoder(first_view.to(device))
            embedding_batches.append(batch_emb.cpu().numpy())
            label_batches.append(batch_labels.cpu().numpy())
    embeddings = np.concatenate(embedding_batches, axis=0)
    targets = np.concatenate(label_batches, axis=0)
    return _kmeans_metrics(embeddings, targets, kmeans_k=kmeans_k, random_state=random_state)

# ----------------------------
# SpecAugment
# ----------------------------
def augment_audio(x):
    """SpecAugment-style masking: zero one frequency band and one time band.

    Accepts a single spectrogram ``[1, F, T]`` or a batch ``[B, 1, F, T]``
    (for a batch the same two bands are used for every item). Each band
    covers 20% of its axis (at least one bin) at a random offset. The input
    is not modified.

    Both masks are applied to ONE copy of the input. Summing a
    frequency-masked copy and a time-masked copy would double every unmasked
    value and only halve the singly-masked regions, which changes the input
    scale relative to un-augmented inference.

    Args:
        x: 3-D or 4-D spectrogram tensor with frequency and time as the last
            two axes.

    Returns:
        Masked tensor with the same shape and scale as ``x``.

    Raises:
        ValueError: If ``x`` is neither 3-D nor 4-D.
    """
    if x.dim() not in (3, 4):
        raise ValueError(f"Unexpected input shape: {x.shape}")

    out = x.clone()
    # Frequency and time are the trailing axes in both supported layouts.
    F_dim, T_dim = out.size(-2), out.size(-1)

    freq_mask = max(1, int(0.2 * F_dim))
    freq_start = int(torch.randint(0, max(1, F_dim - freq_mask), (1,)))
    out[..., freq_start:freq_start + freq_mask, :] = 0

    time_mask = max(1, int(0.2 * T_dim))
    time_start = int(torch.randint(0, max(1, T_dim - time_mask), (1,)))
    out[..., time_start:time_start + time_mask] = 0

    # NOTE(review): 0 is used as the mask value; for log-mel inputs that is
    # not the spectrogram floor — confirm this is the intended fill value.
    return out

# ======================================================
# Split helper — do the stratified 80/20 split OUTSIDE training
# ======================================================
def make_fixed_train_test_indices(dataset_path: str, *, test_size: float = 0.2, random_state: int = 42) -> Tuple[np.ndarray, np.ndarray]:
    """Produce a reproducible, label-stratified train/test index split.

    Files are enumerated with the same rule SpectrogramDataset uses (numeric
    ``.npy`` stems, ``labels.npy`` excluded), so the returned indices address
    that dataset directly.

    Args:
        dataset_path: Directory containing the clips and ``labels.npy``.
        test_size: Fraction of samples placed in the test split.
        random_state: Seed for the split.

    Returns:
        ``(train_idx, test_idx)`` as NumPy arrays.
    """
    from sklearn.model_selection import train_test_split

    clip_names = [
        name for name in os.listdir(dataset_path)
        if name.endswith(".npy") and name != "labels.npy"
    ]
    clip_names.sort(key=lambda name: int(os.path.splitext(name)[0]))
    audio_paths = [os.path.join(dataset_path, name) for name in clip_names]

    labels = np.load(os.path.join(dataset_path, "labels.npy"))
    assert len(audio_paths) == len(labels), "Mismatch between files and labels"

    all_positions = np.arange(len(audio_paths))
    train_idx, test_idx = train_test_split(
        all_positions, test_size=test_size, stratify=labels, random_state=random_state
    )
    return np.asarray(train_idx), np.asarray(test_idx)

# ======================================================
# Runner 1: TRAINING (uses ONLY provided train indices)
# ======================================================
def run_training(
    dataset_path: str,
    *,
    train_indices: Sequence[int],          # REQUIRED: indices from make_fixed_train_test_indices (80%)
    val_split: float = 0.1,                # taken from the 80% train portion
    batch_size: int = 256,
    num_workers: int = 4,
    epochs: int = 30,
    lr: float = 2e-3,
    weight_decay: float = 1e-4,
    temperature: float = 0.03,
    eval_every: int = 10,
    patience: int = 50,
    checkpoint_in: Optional[str] = "best_encoder_pretrain.pth",
    checkpoint_out: str = "best_encoder_pretrain.pth",
    seed: int = 42,
    val_kmeans_k: Optional[int] = None,    # optional: fix k during periodic validation
) -> Dict[str, Any]:
    """
    Pretrains encoder with SupCon on dual-views spectrograms.
    Periodically evaluates with KMeans on a VALIDATION split carved from the provided TRAIN indices.
    The held-out TEST set is NOT touched here (created outside; pass to run_inference).

    Args:
        dataset_path: Directory with the ``.npy`` clips and ``labels.npy``.
        train_indices: Dataset indices this run may use (train + validation).
        val_split: Fraction of ``train_indices`` held out for validation
            (at least one sample).
        batch_size, num_workers: DataLoader settings.
        epochs, lr, weight_decay, temperature: Optimisation settings.
        eval_every: Validate every this many epochs; values <= 0 disable
            periodic validation.
        patience: Stop once this many epochs pass without a better
            validation Hungarian accuracy.
        checkpoint_in: Optional state dict to warm-start from (a failed load
            is reported and ignored).
        checkpoint_out: Where the best (or, failing that, final) encoder
            state dict is written.
        seed: Seed for Python, NumPy, torch and the validation split.
        val_kmeans_k: Fixed k for validation KMeans (else inferred from labels).

    Returns:
        Dict with ``best_hungarian``, ``checkpoint_out``, ``epochs_run``,
        ``train_samples``, ``val_samples`` and ``device``.

    Raises:
        FileNotFoundError: If the dataset directory or label file is missing.
        ValueError: If the validation split leaves no training samples.
    """
    # Repro
    import random
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

    audio_dir = dataset_path
    label_path = os.path.join(audio_dir, "labels.npy")
    if not os.path.isdir(audio_dir):
        raise FileNotFoundError(f"Dataset directory not found: {audio_dir}")
    if not os.path.isfile(label_path):
        raise FileNotFoundError(f"Label file not found: {label_path}")

    # Base dataset; restrict to TRAIN indices only
    full_dataset = SpectrogramDataset(audio_dir, label_path, augment_fn=augment_audio)
    train_base = Subset(full_dataset, list(train_indices))

    # Carve a validation split from the train_base
    val_size = max(1, int(round(val_split * len(train_base))))
    train_size = len(train_base) - val_size
    if train_size <= 0:
        raise ValueError(f"val_split too large for train size {len(train_base)}")

    train_dataset, val_part = random_split(train_base, [train_size, val_size], generator=torch.Generator().manual_seed(seed))

    # Validation must embed clean spectrograms, exactly as run_inference does.
    # Re-using the augmented dataset would score randomly masked inputs and
    # make checkpoint selection noisy, so the same samples are re-wrapped in a
    # dataset without augmentation.
    eval_dataset = SpectrogramDataset(audio_dir, label_path, augment_fn=None)
    val_dataset = Subset(eval_dataset, [train_base.indices[i] for i in val_part.indices])

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
        drop_last=True, pin_memory=True, persistent_workers=num_workers > 0
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
        drop_last=False, pin_memory=True, persistent_workers=num_workers > 0
    )

    # Model/optim
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = ConvSupConEncoder().to(device)
    criterion = SupConLoss(temperature=temperature)
    optimizer = torch.optim.Adam(encoder.parameters(), lr=lr, weight_decay=weight_decay)

    # Optional warm-start
    if checkpoint_in and os.path.isfile(checkpoint_in):
        try:
            encoder.load_state_dict(torch.load(checkpoint_in, map_location=device))
            print(f"Loaded checkpoint from {checkpoint_in}")
        except Exception as e:
            print(f"Warning: failed to load checkpoint '{checkpoint_in}': {e}")

    # Mixed precision only on CUDA.
    scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None

    def _batch_loss(x1, x2, labels):
        # Encode each view separately, then contrast all 2B embeddings at once.
        z1 = encoder(x1)
        z2 = encoder(x2)
        features = torch.cat([z1, z2], dim=0)
        labels_dual = torch.cat([labels, labels], dim=0)
        return criterion(features, labels_dual)

    best_hungarian, no_improve_epochs = 0.0, 0
    epochs_run = 0            # stays 0 (instead of being unbound) when epochs == 0
    saved_this_run = False    # distinguishes a fresh checkpoint from a stale file on disk
    for epoch in range(1, epochs + 1):
        epochs_run = epoch
        encoder.train()
        total_loss = 0.0
        for x1, x2, labels in train_loader:
            x1, x2, labels = x1.to(device, non_blocking=True), x2.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = _batch_loss(x1, x2, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss = _batch_loss(x1, x2, labels)
                loss.backward()
                optimizer.step()
            total_loss += loss.item()

        # Reported value is the sum of per-batch losses over the epoch.
        print(f"[Pretrain] Epoch {epoch}/{epochs}  Train Loss: {total_loss:.4f}")

        # Periodic eval on VAL (not test)
        if eval_every > 0 and epoch % eval_every == 0:
            val_metrics = evaluate_embeddings(encoder, val_loader, device, kmeans_k=val_kmeans_k, random_state=seed)
            print(f"[Pretrain] Validation — "
                  f"ARI: {val_metrics['ari']:.4f}, AMI: {val_metrics['ami']:.4f}, "
                  f"Hungarian: {val_metrics['hungarian_accuracy']:.4f}")

            if val_metrics["hungarian_accuracy"] > best_hungarian:
                best_hungarian = val_metrics["hungarian_accuracy"]
                no_improve_epochs = 0
                torch.save(encoder.state_dict(), checkpoint_out)
                saved_this_run = True
                print(f"✅ Epoch {epoch} — New best saved to '{checkpoint_out}' (Hungarian: {best_hungarian:.4f})")
            else:
                no_improve_epochs += eval_every
                if no_improve_epochs >= patience:
                    print("⏹️ Early stopping triggered.")
                    break

    # If no validation round produced a checkpoint, persist the final weights.
    # Checking only for file existence would keep a stale file from an earlier
    # run (checkpoint_in and checkpoint_out share a default path) and report
    # it as this run's result.
    if not saved_this_run:
        torch.save(encoder.state_dict(), checkpoint_out)
        print(f"Saved final model to '{checkpoint_out}'")

    return {
        "best_hungarian": float(best_hungarian),
        "checkpoint_out": checkpoint_out,
        "epochs_run": epochs_run,
        "train_samples": len(train_dataset),
        "val_samples": len(val_dataset),
        "device": str(device),
    }

# ======================================================
# Runner 2: INFERENCE / EVAL / UMAP  (uses ONLY provided test indices)
# ======================================================
def run_inference(
    dataset_path: str,
    *,
    test_indices: Sequence[int],           # REQUIRED: the same 20% held out
    checkpoint_path: str = "best_encoder_pretrain.pth",
    batch_size: int = 384,
    num_workers: int = 4,
    do_clustering: bool = True,
    kmeans_k: Optional[int] = None,        # choose k (else infer from y_test)
    do_umap: bool = True,
    umap_components: int = 3,
    umap_html_path: str = "umap_embeddings.html",
    seed: int = 42,
) -> Dict[str, Any]:
    """
    Loads encoder, extracts embeddings on the FIXED TEST subset (the provided 20%),
    computes KMeans metrics (ARI/AMI/H-Acc + Silhouette/DB/CH), and optionally saves UMAP HTML.

    Args:
        dataset_path: Directory with the ``.npy`` clips and ``labels.npy``.
        test_indices: Dataset indices to evaluate on.
        checkpoint_path: Encoder state dict to load.
        batch_size, num_workers: DataLoader settings.
        do_clustering: Compute and print the KMeans metrics.
        kmeans_k: Number of clusters (else inferred from the test labels).
        do_umap: Fit UMAP and write an interactive HTML scatter plot.
        umap_components: 2 or 3.
        umap_html_path: Output path of the UMAP HTML.
        seed: Seed for Python, NumPy, torch, KMeans and UMAP.

    Returns:
        Dict with ``embeddings_shape``, ``labels_shape``, ``checkpoint`` and,
        when requested, ``metrics`` and ``umap_html_path``.

    Raises:
        FileNotFoundError: If the dataset directory, labels or checkpoint is missing.
        ValueError: If ``test_indices`` is empty, or ``do_umap`` is set and
            ``umap_components`` is not 2 or 3.
    """
    # Validate cheap arguments first so a bad value fails before embedding
    # extraction and the UMAP fit, not after them.
    if do_umap and umap_components not in (2, 3):
        raise ValueError("umap_components must be 2 or 3")
    test_indices = list(test_indices)
    if not test_indices:
        raise ValueError("test_indices is empty: nothing to evaluate")

    # Repro
    import random
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

    audio_dir = dataset_path
    label_path = os.path.join(audio_dir, "labels.npy")
    if not os.path.isdir(audio_dir):
        raise FileNotFoundError(f"Dataset directory not found: {audio_dir}")
    if not os.path.isfile(label_path):
        raise FileNotFoundError(f"Label file not found: {label_path}")
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # No augmentation at inference: both returned views are the clean spectrogram.
    full_dataset = SpectrogramDataset(audio_dir, label_path, augment_fn=None)
    test_subset = Subset(full_dataset, test_indices)
    loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = ConvSupConEncoder().to(device)
    encoder.load_state_dict(torch.load(checkpoint_path, map_location=device))
    encoder.eval()

    # Extract embeddings for TEST subset
    all_embeddings, all_labels = [], []
    with torch.no_grad():
        for x1, _, labels in loader:
            x1 = x1.to(device, non_blocking=True)
            emb = encoder(x1).cpu().numpy()
            all_embeddings.append(emb)
            all_labels.append(labels.numpy())
    all_embeddings = np.concatenate(all_embeddings, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    results: Dict[str, Any] = {
        "embeddings_shape": tuple(all_embeddings.shape),
        "labels_shape": tuple(all_labels.shape),
        "checkpoint": checkpoint_path
    }

    # Clustering metrics on TEST embeddings
    if do_clustering:
        metrics = _kmeans_metrics(all_embeddings, all_labels, kmeans_k=kmeans_k, random_state=seed)
        print(f"[Test] KMeans (k={metrics['kmeans_k']}) — "
              f"ARI: {metrics['ari']:.4f}, AMI: {metrics['ami']:.4f}, "
              f"Hungarian: {metrics['hungarian_accuracy']:.4f}")
        print(f"[Test] Silhouette: {metrics['silhouette']:.4f} | "
              f"Davies–Bouldin: {metrics['davies_bouldin']:.4f} | "
              f"Calinski–Harabasz: {metrics['calinski_harabasz']:.2f}")
        results["metrics"] = metrics

    # Optional UMAP on TEST embeddings
    if do_umap:
        reducer = umap.UMAP(n_components=umap_components, random_state=seed)
        emb_umap = reducer.fit_transform(all_embeddings)
        if umap_components == 3:
            fig = plotly_3d(emb_umap, all_labels, title="3D UMAP of TEST Embeddings")
        else:
            fig = plotly_2d(emb_umap, all_labels, title="2D UMAP of TEST Embeddings")
        fig.write_html(umap_html_path)
        print(f"Saved UMAP visualization to '{umap_html_path}'")
        results["umap_html_path"] = umap_html_path

    return results

# Small helpers for plotting
def plotly_3d(emb, labels, title="3D UMAP"):
    """Build a 3-D Plotly scatter of ``emb`` (first three columns), coloured by label.

    Labels are converted to strings so Plotly treats them as categories.
    Returns the figure without displaying or saving it.
    """
    import plotly.express as px

    axis_names = {"x": "UMAP-1", "y": "UMAP-2", "z": "UMAP-3"}
    xs, ys, zs = emb[:, 0], emb[:, 1], emb[:, 2]
    categories = labels.astype(str)
    figure = px.scatter_3d(
        x=xs,
        y=ys,
        z=zs,
        color=categories,
        title=title,
        labels=axis_names,
        opacity=0.7,
    )
    figure.update_traces(marker={"size": 3})
    figure.update_layout(legend_title_text="True Label", width=800, height=700)
    return figure

def plotly_2d(emb, labels, title="2D UMAP"):
    """Build a 2-D Plotly scatter of ``emb`` (first two columns), coloured by label.

    Labels are converted to strings so Plotly treats them as categories.
    Returns the figure without displaying or saving it.
    """
    import plotly.express as px

    axis_names = {"x": "UMAP-1", "y": "UMAP-2"}
    xs, ys = emb[:, 0], emb[:, 1]
    categories = labels.astype(str)
    figure = px.scatter(
        x=xs,
        y=ys,
        color=categories,
        title=title,
        labels=axis_names,
        opacity=0.7,
    )
    figure.update_traces(marker={"size": 4})
    figure.update_layout(legend_title_text="True Label", width=800, height=650)
    return figure

Train/test runner with metrics
-------------------------

In [25]:
# 1) Create the fixed, stratified 80/20 split ONCE (outside training).
#    The same seed reproduces the same indices on every run of this cell.
train_idx, test_idx = make_fixed_train_test_indices("/notebooks/dataset_preprocessed", test_size=0.2, random_state=42)

# 2) Train ONLY on the 80% (a small validation split is carved from that 80%).
#    Currently disabled: as written, this cell only evaluates an existing checkpoint.
#    NOTE(review): this call would write "CNNbest_encoder_pretrain.pth", while step 3
#    below loads "best_encoder_pretrain.pth" — confirm which checkpoint is intended.
# train_res = run_training(
#     dataset_path="/notebooks/dataset_preprocessed",
#     train_indices=train_idx,
#     epochs=30,
#     eval_every=10,
#     checkpoint_out="CNNbest_encoder_pretrain.pth",
#     val_kmeans_k=60,     # optional during val
# )

# 3) Evaluate ONLY on the held-out 20% (same indices)
test_res = run_inference(
    dataset_path="/notebooks/dataset_preprocessed",
    test_indices=test_idx,       # <- the exact same 20%
    checkpoint_path="best_encoder_pretrain.pth",
    do_clustering=True,
    kmeans_k=6,                 # choose k; else inferred from y_test
    do_umap=True,
    umap_components=3,
    umap_html_path="SupCon_k6.html",
)

# "metrics" is only present when do_clustering=True; .get() prints None otherwise.
print(test_res.get("metrics"))


The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.


Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.



[Test] KMeans (k=6) — ARI: 0.3715, AMI: 0.3956, Hungarian: 0.6456
[Test] Silhouette: 0.4459 | Davies–Bouldin: 1.0093 | Calinski–Harabasz: 70264.94



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Saved UMAP visualization to 'SupCon_k6.html'
{'ari': 0.3715453842629094, 'ami': 0.39558826152268695, 'hungarian_accuracy': 0.6455628818583268, 'silhouette': 0.44588062167167664, 'davies_bouldin': 1.0093408090839737, 'calinski_harabasz': 70264.93860863186, 'kmeans_k': 6}


Extended Reports
-------------------

In [None]:
# ===== SUPCON REPORT (SimCLR-style) =====
# Add this block to the end of supcon_audio.py (after your existing code)

import os
import math
import time
import base64
from io import BytesIO
from collections import Counter
from typing import List, Optional, Dict

import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.io as pio
import torch
from torch.utils.data import DataLoader, Subset

# safe UMAP import (works if you already did `import umap.umap_ as umap` earlier)
try:
    import umap.umap_ as umap
except Exception:
    import umap  # type: ignore

from scipy.stats import entropy
from scipy.signal import spectrogram, get_window
from sklearn.cluster import KMeans
from sklearn.metrics import (
    silhouette_score,
    davies_bouldin_score,
    calinski_harabasz_score,
)
from sklearn.metrics.pairwise import cosine_similarity


# ---------- File utilities ----------
def _list_audio_npy_files(folder: str) -> List[str]:
    files = []
    for f in os.listdir(folder):
        if f.endswith(".npy") and f != "labels.npy":
            stem = os.path.splitext(f)[0]
            try:
                key = int(stem)
            except ValueError:
                key = stem
            files.append((key, os.path.join(folder, f)))
    files.sort(key=lambda t: t[0])
    return [p for _, p in files]


def _load_class_labels_if_any(folder: str, count: int) -> np.ndarray:
    lbl_path = os.path.join(folder, "labels.npy")
    if os.path.isfile(lbl_path):
        try:
            arr = np.load(lbl_path, allow_pickle=True)
            if len(arr) < count:
                pad = np.array(["Unknown"] * (count - len(arr)), dtype=object)
                arr = np.concatenate([arr, pad], axis=0)
            elif len(arr) > count:
                arr = arr[:count]
            return arr
        except Exception:
            pass
    return np.array(["Unknown"] * count, dtype=object)


# ---------- Spectrogram thumbnail helpers ----------
def _ali_spec(x: np.ndarray, fs: int = 10000):
    Lframe2 = 1000
    po = 80
    lov = int(np.ceil((po / 100) * Lframe2))
    taper = get_window("hann", Lframe2)
    Nfft = 2 ** (int(np.floor(np.log2(Lframe2))) + 2)
    f, t, s = spectrogram(
        x, fs=fs, window=taper, noverlap=lov, nfft=Nfft, mode="complex"
    )
    as_ = np.abs(s)
    as_max = np.max(as_) if np.isfinite(np.max(as_)) and np.max(as_) > 0 else 1.0
    sdb = 10 * np.log10(100 * as_ / as_max + 1e-10)
    min_inx = np.argmin(np.abs(f - 0))
    max_inx = np.argmin(np.abs(f - 800))
    return sdb[min_inx : max_inx + 1, :], f[min_inx : max_inx + 1], t


def _spec_img_base64(
    audio_or_spec: np.ndarray, title: str, fs: int = 10000
) -> str:
    """Render a spectrogram thumbnail and return it as an inline ``<img>`` tag.

    A 1-D input is treated as raw audio and converted with ``_ali_spec``
    (axes in seconds / Hz); any other input is squeezed and shown as an
    already-computed spectrogram (axes in frames / mel bins).

    Args:
        audio_or_spec: 1-D audio signal or a precomputed spectrogram array.
        title: Plot title, also used (HTML-escaped) as the image ``alt`` text.
        fs: Sample rate used when ``audio_or_spec`` is raw audio.

    Returns:
        An ``<img>`` element whose ``src`` is a base64-encoded PNG data URI.
    """
    import html

    if audio_or_spec.ndim == 1:
        spec, f_axis, t_axis = _ali_spec(audio_or_spec.astype(np.float32), fs)
        fig, ax = plt.subplots(figsize=(8, 3))
        ax.imshow(
            spec,
            aspect="auto",
            origin="lower",
            extent=[t_axis[0], t_axis[-1], f_axis[0], f_axis[-1]],
            cmap="hsv",
        )
        ax.set_title(title)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Hz")
    else:
        S = audio_or_spec.squeeze()
        fig, ax = plt.subplots(figsize=(8, 3))
        ax.imshow(S, aspect="auto", origin="lower", cmap="viridis")
        ax.set_title(title)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Mel bins")

    buf = BytesIO()
    try:
        fig.tight_layout()
        fig.savefig(buf, format="png", dpi=120)
    finally:
        # Always release the figure, even if rendering fails, so repeated
        # calls cannot accumulate open matplotlib figures.
        plt.close(fig)
    buf.seek(0)
    image_base64 = base64.b64encode(buf.read()).decode("utf-8")
    # Escape the title: a quote or angle bracket would otherwise break the
    # attribute (and the surrounding report markup).
    alt_text = html.escape(str(title), quote=True)
    return (
        f'<img class="d-block w-100" src="data:image/png;base64,{image_base64}" '
        f'alt="{alt_text}">'
    )


def _carousel_html(cluster_id: int, scope: str, imgs: List[str]) -> str:
    cid = f"carousel_{scope}_{cluster_id}"
    indicators = "".join(
        f'<button type="button" data-bs-target="#{cid}" data-bs-slide-to="{i}" '
        f'{"class=active" if i==0 else ""} aria-current="true" '
        f'aria-label="Slide {i+1}"></button>'
        for i in range(len(imgs))
    )
    items = "".join(
        f'<div class="carousel-item {"active" if i==0 else ""}">'
        f'<div class="d-flex justify-content-center">{img}</div>'
        f"</div>"
        for i, img in enumerate(imgs)
    )
    return f"""
    <div id="{cid}" class="carousel slide" data-bs-interval="false" data-bs-touch="false">
      <div class="carousel-indicators">{indicators}</div>
      <div class="carousel-inner">{items}</div>
      <button class="carousel-control-prev" type="button" data-bs-target="#{cid}" data-bs-slide="prev">
        <span class="carousel-control-prev-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Previous</span>
      </button>
      <button class="carousel-control-next" type="button" data-bs-target="#{cid}" data-bs-slide="next">
        <span class="carousel-control-next-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Next</span>
      </button>
    </div>
    """


# ---------- Per-cluster metrics (same as SimCLR report) ----------
def _evaluate_cluster_metrics(
    embeddings: np.ndarray,
    idxs: np.ndarray,
    location_labels: np.ndarray,
    location_entropy_base: Optional[int] = None,
) -> Dict[str, float]:
    X = embeddings[idxs]
    if X.shape[0] == 0:
        return {
            "variance": 0.0,
            "mean_sim": 1.0,
            "entropy": 0.0,
            "quality": 0.0,
            "novelty": 0.0,
        }
    center = np.mean(X, axis=0, keepdims=True)
    variance = float(np.mean(np.sum((X - center) ** 2, axis=1)))
    if len(X) > 1:
        cos_sim = cosine_similarity(X)
        iu = np.triu_indices_from(cos_sim, k=1)
        mean_sim = float(np.mean(cos_sim[iu])) if iu[0].size > 0 else 1.0
    else:
        mean_sim = 1.0
    loc_counts = Counter(location_labels[idxs])
    loc_probs = np.array(list(loc_counts.values()), dtype=np.float32)
    loc_probs /= max(loc_probs.sum(), 1e-8)
    base = int(location_entropy_base or len(set(location_labels)))
    loc_entropy = float(entropy(loc_probs, base=base)) if base > 1 else 0.0
    max_ent = np.log2(base) if base > 1 else 1.0
    entropy_score = 1.0 - (loc_entropy / max_ent) if base > 1 else 1.0
    quality = float((mean_sim / (variance + 1e-8)) * entropy_score)
    novelty = float((loc_entropy / max_ent) * variance) if base > 1 else 0.0
    return {
        "variance": variance,
        "mean_sim": mean_sim,
        "entropy": loc_entropy,
        "quality": quality,
        "novelty": novelty,
    }


# ---------- Embed a subset with your SupCon encoder ----------
def _load_supcon_encoder(checkpoint_path: str, device: torch.device):
    """Instantiate ``ConvSupConEncoder`` and load weights from a checkpoint.

    The checkpoint may be a bare state dict or a dict holding one under
    ``"state_dict"``. Loading is non-strict, so mismatched keys do not raise;
    any missing or unexpected keys are reported, because a silently partial
    load would leave layers at random initialisation.

    Args:
        checkpoint_path: Path to the saved checkpoint.
        device: Device the encoder is moved to and the checkpoint mapped onto.

    Returns:
        The encoder in eval mode.
    """
    enc = ConvSupConEncoder().to(device)  # uses your class defined earlier
    sd = torch.load(checkpoint_path, map_location=device)
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    load_result = enc.load_state_dict(sd, strict=False)
    missing = list(getattr(load_result, "missing_keys", []))
    unexpected = list(getattr(load_result, "unexpected_keys", []))
    if missing or unexpected:
        print(
            f"[WARN] Partial checkpoint load from {checkpoint_path}: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}"
        )
    enc.eval()
    print(f"[INFO] SupCon encoder loaded from {checkpoint_path} on {device}")
    return enc


def _extract_embeds_from_dir(
    dataset_path: str,
    location_tag: str,
    checkpoint_path: str,
    *,
    subset_fraction: float = 0.02,
    subset_seed: int = 42,
    batch_size: int = 256,
    num_workers: int = 0,
):
    """Embed a random subset of one dataset directory with the SupCon encoder.

    Args:
        dataset_path: Directory with the ``.npy`` clips and ``labels.npy``.
        location_tag: Tag attached to every returned sample.
        checkpoint_path: Encoder checkpoint to load.
        subset_fraction: Fraction of files to sample without replacement.
            At least one file is used; fractions above 1 use every file.
        subset_seed: Seed for the subset selection.
        batch_size, num_workers: DataLoader settings.

    Returns:
        ``(embeddings, class_labels, location_labels, file_paths)`` with one
        row/entry per sampled file, in ascending dataset-index order.

    Raises:
        RuntimeError: If the directory contains no audio ``.npy`` files.
    """
    all_files = _list_audio_npy_files(dataset_path)
    if len(all_files) == 0:
        raise RuntimeError(f"No .npy files found in {dataset_path}")

    # Clamp to the number of files: sampling without replacement cannot draw
    # more items than exist, so an over-large fraction means "use everything".
    n_sub = min(len(all_files), max(1, int(math.ceil(len(all_files) * subset_fraction))))
    rng = np.random.RandomState(subset_seed)
    chosen_idx = np.sort(rng.choice(len(all_files), size=n_sub, replace=False))
    chosen_files = [all_files[i] for i in chosen_idx]

    # dataset & subset aligned to SpectrogramDataset's ordering
    ds = SpectrogramDataset(
        dataset_path, os.path.join(dataset_path, "labels.npy"), augment_fn=None
    )
    subset = Subset(ds, chosen_idx.tolist())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    enc = _load_supcon_encoder(checkpoint_path, device)

    loader = DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )

    feats, cls_labels = [], []
    t0 = time.perf_counter()
    seen = 0
    with torch.no_grad():
        for step, (x1, _, y) in enumerate(loader):
            x1 = x1.to(device, non_blocking=True)
            h = enc(x1).cpu().numpy().astype("float32")
            feats.append(h)
            cls_labels.append(y.numpy())
            seen += x1.size(0)
            # Progress on the first batch and every 10th batch after that;
            # the rate shown is samples per second.
            if (step + 1) % 10 == 0 or step == 0:
                dt = time.perf_counter() - t0
                print(f"[Embed] {seen}/{n_sub} • {seen / max(dt, 1e-9):.1f} it/s")

    H = np.vstack(feats).astype("float32")
    cls = np.concatenate(cls_labels, axis=0)
    loc = np.array([location_tag] * H.shape[0], dtype=object)
    print(f"[INFO] {dataset_path} → {H.shape} embeds")
    return H, cls, loc, np.array(chosen_files, dtype=object)


# ---------- Main analysis to HTML ----------
def analyze_supcon_to_html(
    checkpoint_path: str,
    dataset_paths: List[str],
    location_tags: List[str],
    *,
    cluster_method: str = "kmeans",  # 'kmeans' | 'gmm' | 'agglomerative'
    n_clusters: int = 60,
    subset_fraction: float = 0.02,
    subset_seed: int = 42,
    samples_per_cluster: int = 4,
    top_n_clusters: Optional[int] = None,  # rank by size; None = show all
) -> str:
    """Embed, cluster and render one HTML section of the SupCon report.

    Steps: embed a subset of every dataset directory, project the embeddings
    to 3-D with UMAP for plotting, cluster in the original embedding space,
    compute global and per-cluster metrics, and render the closest-to-centroid
    samples of each cluster as spectrogram images.

    Args:
        checkpoint_path: Encoder checkpoint used for every directory.
        dataset_paths: One dataset directory per location.
        location_tags: Display name of each directory (same length).
        cluster_method: ``'kmeans'``, ``'gmm'`` or ``'agglomerative'``.
        n_clusters: Requested number of clusters; reduced to the number of
            embedded samples when larger.
        subset_fraction: Fraction of each directory to embed.
        subset_seed: Seed for the subset selection.
        samples_per_cluster: Number of exemplar images per cluster.
        top_n_clusters: Keep only the N largest clusters; ``None`` keeps all.

    Returns:
        An HTML fragment (a ``<div class='section'>``) as a string.

    Raises:
        AssertionError: If ``dataset_paths`` and ``location_tags`` differ in
            length.
        ValueError: If ``cluster_method`` is not supported.
    """
    assert len(dataset_paths) == len(location_tags), (
        "dataset_paths and location_tags must match"
    )
    # Reject a bad method name before the expensive embedding / UMAP steps.
    if cluster_method not in ("kmeans", "agglomerative", "gmm"):
        raise ValueError(f"Unsupported algorithm: {cluster_method}")

    # Collect embeddings from each location
    embeds_all, class_all, loc_all, files_all = [], [], [], []
    for path, tag in zip(dataset_paths, location_tags):
        H, cls, loc, files = _extract_embeds_from_dir(
            path, tag, checkpoint_path,
            subset_fraction=subset_fraction,
            subset_seed=subset_seed,
        )
        embeds_all.append(H)
        class_all.append(cls)
        loc_all.append(loc)
        files_all.append(files)

    embeddings = np.vstack(embeds_all).astype("float32")
    class_labels = np.concatenate(class_all, axis=0).astype(object)
    location_labels = np.concatenate(loc_all, axis=0).astype(object)
    original_paths = np.concatenate(files_all, axis=0).astype(object)
    n_samples = embeddings.shape[0]

    # UMAP (cosine) on embeddings -- used for the 3-D plot only
    print(f"[INFO] UMAP on {embeddings.shape[0]}×{embeddings.shape[1]}")
    reducer = umap.UMAP(
        n_components=3, n_neighbors=15, min_dist=0.1, metric="cosine", random_state=42
    )
    proj_3d = reducer.fit_transform(embeddings)

    # Clustering in embedding space
    def _fit_labels(algorithm: str, X: np.ndarray, k: int) -> np.ndarray:
        """Fit the requested clusterer and return hard labels only."""
        if algorithm == "kmeans":
            return KMeans(n_clusters=k, n_init="auto", random_state=42).fit(X).labels_
        if algorithm == "agglomerative":
            from sklearn.cluster import AgglomerativeClustering

            return AgglomerativeClustering(n_clusters=k).fit(X).labels_
        if algorithm == "gmm":
            from sklearn.mixture import GaussianMixture

            model = GaussianMixture(
                n_components=k, covariance_type="full", random_state=42
            ).fit(X)
            return model.predict(X)
        raise ValueError(f"Unsupported algorithm: {algorithm}")

    # A small subset can hold fewer samples than the requested cluster count.
    k = min(int(n_clusters), n_samples)
    if k < int(n_clusters):
        print(f"[WARN] n_clusters={n_clusters} exceeds {n_samples} samples; using k={k}")
    print(f"[INFO] Clustering with {cluster_method} (k={k})")
    cluster_labels = _fit_labels(cluster_method, embeddings, k)

    # Global metrics on embeddings, not UMAP. They are skipped (nan) unless
    # there is more than one cluster and more samples than clusters.
    unique_clusters = np.unique(cluster_labels)
    valid = (len(unique_clusters) > 1) and (n_samples > len(unique_clusters))

    def _safe_metric(fn, X, y):
        try:
            return float(fn(X, y)) if valid else float("nan")
        except Exception:
            return float("nan")

    sil = _safe_metric(silhouette_score, embeddings, cluster_labels)
    dbi = _safe_metric(davies_bouldin_score, embeddings, cluster_labels)
    ch = _safe_metric(calinski_harabasz_score, embeddings, cluster_labels)

    title_txt = f"{cluster_method.capitalize()} (Sil={sil:.3f} | DBI={dbi:.3f} | CH={ch:.1f})"
    fig = px.scatter_3d(
        x=proj_3d[:, 0],
        y=proj_3d[:, 1],
        z=proj_3d[:, 2],
        color=[str(c) for c in cluster_labels],
        symbol=location_labels,
        hover_data={"Cluster": cluster_labels, "Class": class_labels},
        title=title_txt,
        opacity=0.85,
        height=800,
    )
    umap_html = pio.to_html(fig, include_plotlyjs="cdn", full_html=False)

    # Per-cluster sections: largest clusters first, ties broken by cluster id
    # so the report order is reproducible.
    base_for_entropy = len(set(location_labels))
    sizes = Counter(cluster_labels.tolist())
    cluster_ids = sorted(sizes, key=lambda c: (-sizes[c], c))
    if top_n_clusters is not None:
        cluster_ids = cluster_ids[: int(top_n_clusters)]

    blocks = []
    for cid in cluster_ids:
        idxs = np.where(cluster_labels == cid)[0]
        sz = len(idxs)
        loc_counts = Counter(location_labels[idxs])
        cls_counts = Counter(class_labels[idxs])
        metrics = _evaluate_cluster_metrics(
            embeddings, idxs, location_labels, location_entropy_base=base_for_entropy
        )

        meta_html = (
            "<p><strong>Location Distribution:</strong></p><ul>"
            + "".join(
                f"<li><b>{loc}</b>: {count} ({count/sz:.1%})</li>"
                for loc, count in loc_counts.items()
            )
            + "</ul>"
        )
        meta_html += (
            "<p><strong>Class Distribution:</strong></p><ul>"
            + "".join(f"<li>{cls}: {count}</li>" for cls, count in cls_counts.items())
            + "</ul>"
        )
        meta_html += f"""
        <p><strong>Cluster Metrics:</strong></p>
        <ul>
          <li>Size: {sz}</li>
          <li>Intra-Cluster Variance: {metrics['variance']:.4f}</li>
          <li>Mean Cosine Similarity: {metrics['mean_sim']:.4f}</li>
          <li>Location Entropy: {metrics['entropy']:.3f}</li>
          <li>Composite Quality Score: {metrics['quality']:.4f}</li>
          <li><strong>Novelty Score:</strong> {metrics['novelty']:.4f}</li>
        </ul>
        """

        # Exemplars: the samples closest (Euclidean) to the cluster centroid.
        center = np.mean(embeddings[idxs], axis=0, keepdims=True)
        dists = np.linalg.norm(embeddings[idxs] - center, axis=1)
        pick = idxs[np.argsort(dists)[:samples_per_cluster]]

        thumbs = []
        for i, p in enumerate(pick):
            try:
                arr = np.load(original_paths[p], mmap_mode="r")
                title = f"#{i+1} | {location_labels[p]} | Class {class_labels[p]}"
                thumbs.append(_spec_img_base64(arr, title))
            except Exception as e:
                # Best effort: one unreadable file must not abort the report.
                thumbs.append(f"<p class='text-danger'>Error: {e}</p>")

        blocks.append(
            f"<div class='col-md-6 mb-4'><h4>Cluster {cid}</h4>"
            f"{meta_html}{_carousel_html(cid, 'supcon', thumbs)}</div>"
        )

    # Two cluster cards per Bootstrap row.
    cluster_html = "".join(
        "<div class='row'>" + "".join(blocks[i : i + 2]) + "</div>"
        for i in range(0, len(blocks), 2)
    )

    # ':g' avoids float truncation (int(0.29 * 100) == 28) and keeps
    # sub-percent fractions such as 0.5% visible.
    subset_pct = f"{subset_fraction * 100:g}"

    return f"""
    <div class='section'>
      <h2>{cluster_method.capitalize()} Clustering Analysis (SupCon • subset {subset_pct}%)</h2>
      <div class="container">
        <div class="row justify-content-center mb-4">
          <div class="col-md-12 d-flex justify-content-center">{umap_html}</div>
        </div>
      </div>
      {cluster_html}
    </div>
    """


# ---------- Full HTML wrapper ----------
def generate_supcon_full_report(
    checkpoint_path: str,
    dataset_paths: List[str],
    location_tags: List[str],
    *,
    out_html: str = "supcon_unsupervised_report.html",
    cluster_method: str = "kmeans",
    n_clusters: int = 60,
    subset_fraction: float = 0.02,
    subset_seed: int = 42,
    samples_per_cluster: int = 4,
    top_n_clusters: Optional[int] = None,
) -> str:
    """Build the full standalone HTML report and write it to ``out_html``.

    Wraps the section produced by ``analyze_supcon_to_html`` in a Bootstrap
    page with a header listing the locations, subset percentage and seed.

    Args:
        checkpoint_path: Encoder checkpoint used for embedding.
        dataset_paths: One dataset directory per location.
        location_tags: Display name of each directory.
        out_html: Destination file, written as UTF-8.
        cluster_method: Forwarded to ``analyze_supcon_to_html``.
        n_clusters: Forwarded to ``analyze_supcon_to_html``.
        subset_fraction: Fraction of each directory to embed.
        subset_seed: Seed for the subset selection.
        samples_per_cluster: Exemplar images per cluster.
        top_n_clusters: Keep only the N largest clusters; ``None`` keeps all.

    Returns:
        ``out_html``, the path that was written.
    """
    section = analyze_supcon_to_html(
        checkpoint_path=checkpoint_path,
        dataset_paths=list(dataset_paths),
        location_tags=list(location_tags),
        cluster_method=cluster_method,
        n_clusters=int(n_clusters),
        subset_fraction=subset_fraction,
        subset_seed=subset_seed,
        samples_per_cluster=samples_per_cluster,
        top_n_clusters=top_n_clusters,
    )

    # ':g' avoids float truncation (int(0.29 * 100) == 28) and keeps
    # sub-percent fractions such as 0.5% visible.
    subset_pct = f"{subset_fraction * 100:g}"
    locations = ", ".join(str(tag) for tag in location_tags)

    # No separate Plotly <script> in <head>: the section already embeds the
    # Plotly build matching the figure (include_plotlyjs="cdn"), and the
    # "plotly-latest" CDN alias is no longer updated, so loading it as well
    # pulls in a second, outdated copy of the library.
    html = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Unsupervised Clustering Report (SupCon)</title>
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
        <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
        <style>
            body {{ font-family: Arial, sans-serif; padding: 20px; background-color: #f5f5f5; }}
            h1 {{ color: #2c3e50; }}
            h2, h4 {{ color: #34495e; }}
            hr {{ border-top: 2px solid #bbb; margin-top: 40px; margin-bottom: 40px; }}
            .section {{ margin-bottom: 60px; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
        </style>
    </head>
    <body>
      <h1 class='mb-4'>Unsupervised Latent Clustering Report (SupCon)</h1>
      <p><strong>Locations:</strong> {locations}</p>
      <p><strong>Subset:</strong> {subset_pct}% • Seed: {subset_seed}</p>
      <hr>
      {section}
    </body></html>
    """
    with open(out_html, "w", encoding="utf-8") as f:
        f.write(html)
    print(f"✅ Report saved to: {out_html}")
    return out_html

In [22]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Unsupervised Latent Clustering Report for SupCon (and compatible encoders)

Generates an HTML report including:
- UMAP visualization of embeddings (color by cluster, symbol by location)
- Global clustering metrics (Silhouette, DBI, CH)
- Per-cluster metrics (variance, mean cosine, entropy, quality, novelty)
- Spectrogram-space consistency metrics:
    * Mean spectrogram cosine similarity (↑ better)
    * Mean DTW distance on time courses (↓ better; optional via fastdtw)

Usage (example):
    python supcon_full_report.py \
        --checkpoint /path/to/supcon_encoder.ckpt \
        --dataset /data/siteA /data/siteB \
        --tags SiteA SiteB \
        --out supcon_unsupervised_report.html \
        --cluster gmm --k 60 --subset 0.02 --seed 42 \
        --samples-per-cluster 4 --top-n 40
"""

import os
import math
import time
import base64
from io import BytesIO
from collections import Counter, defaultdict
from typing import List, Optional, Dict, Tuple

# ----- Core libs -----
import numpy as np

# Plotting
import matplotlib
matplotlib.use("Agg")  # headless
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.io as pio

# ML / metrics
import umap
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import entropy
from scipy.signal import spectrogram, get_window

# ----- Project-specific imports (edit paths/names if needed) -----
# Expect these to be available in your codebase:
# - ConvSupConEncoder: your SupCon-compatible encoder
# - SpectrogramDataset: yields (view1, view2, label) triples; only the first view is embedded here


# ============================================================
# File helpers
# ============================================================

def _list_audio_npy_files(folder: str) -> List[str]:
    """Return the sample ``.npy`` paths in ``folder`` in a stable order.

    ``labels.npy`` is excluded. Files whose stem parses as an integer come
    first, in numeric order; the remaining files follow in alphabetical order
    of their stem. Keeping the two groups apart avoids comparing ``int`` with
    ``str`` keys, which raises ``TypeError`` on Python 3.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        Full paths (``folder`` joined with each file name).
    """

    def _sort_key(name: str):
        stem = os.path.splitext(name)[0]
        try:
            return (0, int(stem), "")
        except ValueError:
            return (1, 0, stem)

    names = [
        name
        for name in os.listdir(folder)
        if name.endswith(".npy") and name != "labels.npy"
    ]
    names.sort(key=_sort_key)
    return [os.path.join(folder, name) for name in names]


def _load_class_labels_if_any(folder: str, count: int) -> np.ndarray:
    """Load ``labels.npy`` from ``folder`` and fit it to ``count`` entries.

    A shorter array is padded with ``"Unknown"``; a longer one is truncated.
    If the file is absent, or anything goes wrong while reading or resizing
    it, an all-``"Unknown"`` object array of length ``count`` is returned.
    """

    def _unknown(n: int) -> np.ndarray:
        return np.array(["Unknown"] * n, dtype=object)

    label_file = os.path.join(folder, "labels.npy")
    if not os.path.isfile(label_file):
        return _unknown(count)

    try:
        labels = np.load(label_file, allow_pickle=True)
        shortfall = count - len(labels)
        if shortfall > 0:
            labels = np.concatenate([labels, _unknown(shortfall)], axis=0)
        elif shortfall < 0:
            labels = labels[:count]
        return labels
    except Exception:
        # Best effort: an unreadable label file degrades to placeholders.
        return _unknown(count)


# ============================================================
# Spectrogram helpers
# ============================================================

def _ali_spec(x: np.ndarray, fs: int = 10000) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute a Hann-windowed spectrogram restricted to roughly 0-800 Hz.

    The complex spectrogram uses 1000-sample frames with 80% overlap and a
    2048-point FFT. Magnitudes are divided by their maximum (1.0 is used when
    the maximum is not finite or not positive) and log-scaled.

    Returns:
        ``(spec, freqs, times)`` where ``spec`` and ``freqs`` cover the bins
        from the one nearest 0 Hz to the one nearest 800 Hz, inclusive.
    """
    frame_len = 1000
    overlap_pct = 80
    n_overlap = int(math.ceil((overlap_pct / 100) * frame_len))
    window = get_window('hann', frame_len)
    # Next power of two above the frame length, times two.
    n_fft = 2 ** (int(math.floor(math.log2(frame_len))) + 2)

    freqs, times, stft = spectrogram(
        x, fs=fs, window=window, noverlap=n_overlap, nfft=n_fft, mode='complex'
    )

    magnitude = np.abs(stft)
    peak = np.max(magnitude)
    if not (np.isfinite(peak) and peak > 0):
        peak = 1.0  # silent or invalid input: avoid dividing by zero / nan
    scaled = 10 * np.log10(100 * magnitude / peak + 1e-10)

    first_bin = np.argmin(np.abs(freqs - 0))
    last_bin = np.argmin(np.abs(freqs - 800))
    band = slice(first_bin, last_bin + 1)
    return scaled[band, :], freqs[band], times


def _spec_img_base64(audio_or_spec: np.ndarray, title: str, fs: int = 10000) -> str:
    """Render a waveform or spectrogram as an inline base64 PNG ``<img>`` tag.

    A 1-D input is treated as a waveform and converted with ``_ali_spec``
    (axes in seconds / Hz). Any other input is squeezed and drawn as-is
    (axes in frames / bins).

    Args:
        audio_or_spec: 1-D waveform or an array that squeezes to an image.
        title: Plot title; also used, HTML-escaped, as the ``alt`` text.
        fs: Sampling rate passed to ``_ali_spec`` for waveform input.

    Returns:
        An ``<img>`` element whose ``src`` is a ``data:image/png`` URI.
    """
    from html import escape  # stdlib; kept local so the block is self-contained

    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        if audio_or_spec.ndim == 1:
            spec, f_axis, t_axis = _ali_spec(audio_or_spec.astype(np.float32), fs)
            ax.imshow(spec, aspect='auto', origin='lower',
                      extent=[t_axis[0], t_axis[-1], f_axis[0], f_axis[-1]])
            ax.set_xlabel("Time (s)")
            ax.set_ylabel("Hz")
        else:
            ax.imshow(audio_or_spec.squeeze(), aspect='auto', origin='lower')
            ax.set_xlabel("Frames")
            ax.set_ylabel("Bins")
        ax.set_title(title)
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=120)
    finally:
        # Always release the figure: pyplot keeps every open figure alive,
        # so a failed render would otherwise leak it.
        plt.close(fig)

    image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    # Escape the title so quotes or angle brackets cannot break the attribute.
    alt_text = escape(title, quote=True)
    return f'<img class="d-block w-100" src="data:image/png;base64,{image_base64}" alt="{alt_text}">'


def _standardize_spec(S: np.ndarray, target_shape=(128, 256)) -> np.ndarray:
    """
    Z-normalize a 2D (F,T) spectrogram, then centre-pad or centre-crop it to target_shape.

    Normalization uses the mean/std of the whole input, so padded cells (zeros)
    sit at the normalized mean. When padding is odd, the extra cell goes to the
    end of the axis.
    """
    S = np.asarray(S, dtype=np.float32)
    S = (S - np.mean(S)) / (np.std(S) + 1e-8)

    n_rows, n_cols = S.shape  # enforces a 2-D input
    want_rows, want_cols = target_shape

    for axis, (have, want) in enumerate(((n_rows, want_rows), (n_cols, want_cols))):
        if have == want:
            continue
        if have < want:
            before = (want - have) // 2
            after = want - have - before
            widths = [(0, 0), (0, 0)]
            widths[axis] = (before, after)
            S = np.pad(S, widths, mode="constant")
        else:
            start = (have - want) // 2
            window = [slice(None), slice(None)]
            window[axis] = slice(start, start + want)
            S = S[tuple(window)]
    return S


def _load_spec_for_index(original_paths: np.ndarray, idx: int, fs: int = 10000,
                         target_shape=(128, 256)) -> Optional[np.ndarray]:
    """
    Read the .npy at original_paths[idx] and return it as a standardized 2D spectrogram.

    1D data is treated as a waveform and converted with _ali_spec; 2D data is
    used directly. Any other rank, and any error along the way, yields None.
    """
    try:
        raw = np.load(original_paths[idx], mmap_mode="r")
        rank = raw.ndim
        if rank not in (1, 2):
            return None
        if rank == 1:
            spec, _freqs, _times = _ali_spec(raw.astype(np.float32), fs=fs)
        else:
            spec = raw.astype(np.float32)
        return _standardize_spec(spec, target_shape=target_shape)
    except Exception:
        # Best effort: callers skip samples that cannot be loaded.
        return None


def _avg_intra_cluster_spec_cosine(original_paths: np.ndarray, idxs: np.ndarray,
                                   *, fs: int = 10000, max_samples: int = 50,
                                   target_shape=(128, 256)) -> float:
    """
    Mean pairwise cosine similarity between the flattened, standardized
    spectrograms of a cluster's members.

    At most max_samples members are compared (drawn with a fixed seed) to bound
    the quadratic cost. Returns nan when fewer than two spectrograms are usable.
    """
    not_available = float("nan")
    if len(idxs) < 2:
        return not_available
    if len(idxs) > max_samples:
        idxs = np.random.RandomState(42).choice(idxs, size=max_samples, replace=False)

    loaded = (
        _load_spec_for_index(original_paths, int(i), fs=fs, target_shape=target_shape)
        for i in idxs
    )
    flattened = [spec.reshape(-1) for spec in loaded if spec is not None]
    if len(flattened) < 2:
        return not_available

    similarity = cosine_similarity(np.stack(flattened, axis=0))
    # Strict upper triangle: each unordered pair once, no self-similarity.
    rows, cols = np.triu_indices_from(similarity, k=1)
    return float(np.mean(similarity[rows, cols]))


def _avg_intra_cluster_spec_dtw(original_paths: np.ndarray, idxs: np.ndarray,
                                *, fs: int = 10000, max_samples: int = 25,
                                target_shape=(128, 256)) -> float:
    """
    Mean pairwise DTW distance between the time traces of a cluster's members,
    where a trace is the standardized spectrogram averaged over frequency.

    Lower means more similar. Needs the optional fastdtw package; returns nan
    when it is missing or when fewer than two traces are usable.
    """
    try:
        from fastdtw import fastdtw
        from scipy.spatial.distance import euclidean
    except Exception:
        return float("nan")

    n_members = len(idxs)
    if n_members < 2:
        return float("nan")
    if n_members > max_samples:
        idxs = np.random.RandomState(42).choice(idxs, size=max_samples, replace=False)

    traces = []
    for i in idxs:
        spec = _load_spec_for_index(original_paths, int(i), fs=fs, target_shape=target_shape)
        if spec is None:
            continue
        traces.append(np.mean(spec, axis=0))  # collapse freq
    if len(traces) < 2:
        return float("nan")

    distance_sum, pair_count = 0.0, 0
    for pos, first in enumerate(traces):
        for second in traces[pos + 1:]:
            distance, _path = fastdtw(first, second, dist=euclidean)
            distance_sum += distance
            pair_count += 1
    return float(distance_sum / max(pair_count, 1))


# ============================================================
# Embedding & metrics
# ============================================================

def _evaluate_cluster_metrics(embeddings: np.ndarray, idxs: np.ndarray,
                              location_labels: np.ndarray,
                              location_entropy_base: Optional[int] = None) -> Dict[str, float]:
    """Compute compactness and location-mix metrics for one cluster.

    Args:
        embeddings: All embeddings, one row per sample.
        idxs: Row indices of the cluster's members.
        location_labels: Location of every sample (indexed like ``embeddings``).
        location_entropy_base: Logarithm base for the location entropy;
            defaults to the number of distinct values in ``location_labels``.

    Returns:
        A dict with:
        ``variance``: mean squared Euclidean distance to the centroid;
        ``mean_sim``: mean pairwise cosine similarity (1.0 for one member);
        ``entropy``: location entropy in base ``base`` (0.0 when base <= 1);
        ``quality``: ``mean_sim / variance`` scaled by ``1 - entropy``;
        ``novelty``: ``entropy * variance``.
        An empty cluster returns fixed defaults.
    """
    X = embeddings[idxs]
    if X.shape[0] == 0:
        return {'variance': 0.0, 'mean_sim': 1.0, 'entropy': 0.0, 'quality': 0.0, 'novelty': 0.0}

    center = np.mean(X, axis=0, keepdims=True)
    variance = float(np.mean(np.sum((X - center) ** 2, axis=1)))

    if len(X) > 1:
        cos_sim = cosine_similarity(X)
        # Strict upper triangle: each unordered pair once, no self-similarity.
        iu = np.triu_indices_from(cos_sim, k=1)
        mean_sim = float(np.mean(cos_sim[iu])) if iu[0].size > 0 else 1.0
    else:
        mean_sim = 1.0

    loc_counts = Counter(location_labels[idxs])
    loc_probs = np.array(list(loc_counts.values()), dtype=np.float32)
    loc_probs /= max(loc_probs.sum(), 1e-8)

    base = int(location_entropy_base or len(set(location_labels)))
    if base > 1:
        # With the logarithm taken in base `base`, a uniform mix over `base`
        # locations already scores exactly 1, so the value is normalized as
        # it stands. Dividing it again by log2(base) would shrink it for any
        # base above 2 and distort quality and novelty.
        loc_entropy = float(entropy(loc_probs, base=base))
        # Clamp in case the cluster holds more locations than `base`.
        norm_entropy = min(max(loc_entropy, 0.0), 1.0)
    else:
        loc_entropy = 0.0
        norm_entropy = 0.0

    entropy_score = 1.0 - norm_entropy
    quality = float((mean_sim / (variance + 1e-8)) * entropy_score)
    novelty = float(norm_entropy * variance)
    return {'variance': variance, 'mean_sim': mean_sim,
            'entropy': loc_entropy, 'quality': quality, 'novelty': novelty}


def _load_supcon_encoder(checkpoint_path: str, device: torch.device) -> ConvSupConEncoder:
    """Build a ``ConvSupConEncoder``, load its weights and switch to eval mode.

    The checkpoint may be a bare state dict or a dict wrapping one under the
    ``"state_dict"`` key. Weights are loaded non-strictly, so a checkpoint
    that only partly matches the encoder still loads; every key that could
    not be matched is reported, because an encoder left with untrained
    weights would otherwise produce meaningless embeddings without notice.

    Args:
        checkpoint_path: Path given to ``torch.load``.
        device: Device the encoder and the checkpoint tensors are mapped to.

    Returns:
        The encoder, on ``device``, in eval mode.
    """
    enc = ConvSupConEncoder().to(device)
    sd = torch.load(checkpoint_path, map_location=device)
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]

    result = enc.load_state_dict(sd, strict=False)
    missing = list(getattr(result, "missing_keys", []))
    unexpected = list(getattr(result, "unexpected_keys", []))
    if missing or unexpected:
        print(
            f"[WARN] Checkpoint only partly matches the encoder: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}"
        )

    enc.eval()
    print(f"[INFO] SupCon encoder loaded from {checkpoint_path} on {device}")
    return enc


def _extract_embeds_from_dir(
    dataset_path: str,
    location_tag: str,
    checkpoint_path: str,
    *,
    subset_fraction: float = 0.02,
    subset_seed: int = 42,
    batch_size: int = 256,
    num_workers: int = 0
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Embed a seeded random subset of one directory with the SupCon encoder.

    Args:
        dataset_path: Folder holding the ``.npy`` clips and ``labels.npy``.
        location_tag: Label repeated once per embedded sample.
        checkpoint_path: Encoder checkpoint passed to ``_load_supcon_encoder``.
        subset_fraction: Fraction of the files to sample. The resulting count
            is clamped to ``[1, number of files]``.
        subset_seed: Seed of the ``RandomState`` that draws the subset.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.

    Returns:
        ``(H, cls, loc, files)``: stacked float32 encoder outputs, the labels
        yielded by the loader, an object array filled with ``location_tag``,
        and an object array with the chosen file paths.

    Raises:
        RuntimeError: If the directory contains no usable ``.npy`` file.
    """
    all_files = _list_audio_npy_files(dataset_path)
    if not all_files:
        raise RuntimeError(f"No .npy files found in {dataset_path}")

    # Sampling without replacement raises when the requested size exceeds the
    # population, so keep the subset size within [1, len(all_files)].
    n_total = len(all_files)
    n_sub = min(n_total, max(1, int(math.ceil(n_total * subset_fraction))))
    rng = np.random.RandomState(subset_seed)
    # Sorted indices keep the embeddings in file order.
    chosen_idx = np.sort(rng.choice(n_total, size=n_sub, replace=False))
    chosen_files = [all_files[i] for i in chosen_idx]

    # NOTE(review): assumes SpectrogramDataset orders its files exactly like
    # _list_audio_npy_files, otherwise chosen_files and the embeddings drift.
    ds = SpectrogramDataset(dataset_path, os.path.join(dataset_path, "labels.npy"), augment_fn=None)
    subset = Subset(ds, chosen_idx.tolist())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    enc = _load_supcon_encoder(checkpoint_path, device)

    # shuffle=False preserves the order of chosen_files.
    loader = DataLoader(subset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, drop_last=False)

    feats, cls_labels = [], []
    t0 = time.perf_counter()
    seen = 0
    with torch.no_grad():
        # Each batch is a 3-tuple; only the first element is embedded.
        for step, (x1, _, y) in enumerate(loader):
            x1 = x1.to(device, non_blocking=True)
            h = enc(x1).cpu().numpy().astype("float32")
            feats.append(h)
            cls_labels.append(y.numpy())
            seen += x1.size(0)
            if (step+1) % 10 == 0 or step == 0:
                dt = time.perf_counter() - t0
                print(f"[Embed] {seen}/{n_sub} • {seen/max(dt,1e-9):.1f} it/s")

    H = np.vstack(feats).astype("float32")
    cls = np.concatenate(cls_labels, axis=0)
    loc = np.array([location_tag] * H.shape[0], dtype=object)
    print(f"[INFO] {dataset_path} → {H.shape} embeds")
    return H, cls, loc, np.array(chosen_files, dtype=object)


# ============================================================
# Report generator
# ============================================================

def _carousel_html(cluster_id: int, scope: str, imgs: List[str]) -> str:
    """Build a Bootstrap 5 carousel holding the given HTML snippets.

    Args:
        cluster_id: Cluster identifier, used in the element id.
        scope: Prefix that keeps ids unique across report sections.
        imgs: HTML snippets, one per slide; the first slide starts active.

    Returns:
        The carousel markup. Only the first indicator carries
        ``class="active"`` and ``aria-current="true"``, so assistive
        technology sees exactly one current slide.
    """
    cid = f"carousel_{scope}_{cluster_id}"

    indicator_parts = []
    item_parts = []
    for i, img in enumerate(imgs):
        is_first = i == 0
        active_attrs = ' class="active" aria-current="true"' if is_first else ""
        indicator_parts.append(
            f'<button type="button" data-bs-target="#{cid}" data-bs-slide-to="{i}"'
            f'{active_attrs} aria-label="Slide {i+1}"></button>'
        )
        item_class = "carousel-item active" if is_first else "carousel-item"
        item_parts.append(
            f'<div class="{item_class}"><div class="d-flex justify-content-center">{img}</div></div>'
        )
    indicators = "".join(indicator_parts)
    items = "".join(item_parts)

    return f"""
    <div id="{cid}" class="carousel slide" data-bs-interval="false" data-bs-touch="false">
      <div class="carousel-indicators">{indicators}</div>
      <div class="carousel-inner">{items}</div>
      <button class="carousel-control-prev" type="button" data-bs-target="#{cid}" data-bs-slide="prev">
        <span class="carousel-control-prev-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Previous</span>
      </button>
      <button class="carousel-control-next" type="button" data-bs-target="#{cid}" data-bs-slide="next">
        <span class="carousel-control-next-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Next</span>
      </button>
    </div>
    """


def analyze_supcon_to_html(
    checkpoint_path: str,
    dataset_paths: List[str],
    location_tags: List[str],
    *,
    cluster_method: str = "kmeans",   # 'kmeans' | 'gmm' | 'agglomerative'
    n_clusters: int = 60,
    subset_fraction: float = 0.02,
    subset_seed: int = 42,
    samples_per_cluster: int = 4,
    top_n_clusters: Optional[int] = None   # rank by size; None = show all
) -> str:
    """Embed, cluster and render one HTML section of the SupCon report.

    Steps: embed a subset of every dataset directory, project the embeddings
    to 3-D with UMAP for plotting, cluster in the original embedding space,
    compute global and per-cluster metrics (embedding space and spectrogram
    space), and render the closest-to-centroid samples of each cluster.

    Args:
        checkpoint_path: Encoder checkpoint used for every directory.
        dataset_paths: One dataset directory per location.
        location_tags: Display name of each directory (same length).
        cluster_method: ``'kmeans'``, ``'gmm'`` or ``'agglomerative'``.
        n_clusters: Requested number of clusters; reduced to the number of
            embedded samples when larger.
        subset_fraction: Fraction of each directory to embed.
        subset_seed: Seed for the subset selection.
        samples_per_cluster: Number of exemplar images per cluster.
        top_n_clusters: Keep only the N largest clusters; ``None`` keeps all.

    Returns:
        An HTML fragment (a ``<div class='section'>``) as a string.

    Raises:
        AssertionError: If ``dataset_paths`` and ``location_tags`` differ in
            length.
        ValueError: If ``cluster_method`` is not supported.
    """
    assert len(dataset_paths) == len(location_tags), "dataset_paths and location_tags must match"
    # Reject a bad method name before the expensive embedding / UMAP steps.
    if cluster_method not in ("kmeans", "agglomerative", "gmm"):
        raise ValueError(f"Unsupported algorithm: {cluster_method}")

    # Collect embeddings from each location
    embeds_all, class_all, loc_all, files_all = [], [], [], []
    for path, tag in zip(dataset_paths, location_tags):
        H, cls, loc, files = _extract_embeds_from_dir(
            path, tag, checkpoint_path,
            subset_fraction=subset_fraction, subset_seed=subset_seed
        )
        embeds_all.append(H)
        class_all.append(cls)
        loc_all.append(loc)
        files_all.append(files)

    embeddings = np.vstack(embeds_all).astype("float32")
    class_labels = np.concatenate(class_all, axis=0).astype(object)
    location_labels = np.concatenate(loc_all, axis=0).astype(object)
    original_paths = np.concatenate(files_all, axis=0).astype(object)
    n_samples = embeddings.shape[0]

    # UMAP (cosine) on embeddings -- used for the 3-D plot only
    print(f"[INFO] UMAP on {embeddings.shape[0]}×{embeddings.shape[1]}")
    reducer = umap.UMAP(n_components=3, n_neighbors=15, min_dist=0.1,
                        metric="cosine", random_state=42)
    proj_3d = reducer.fit_transform(embeddings)

    # Clustering in embedding space
    def _fit_labels(algorithm: str, X: np.ndarray, k: int) -> np.ndarray:
        """Fit the requested clusterer and return hard labels only."""
        if algorithm == "kmeans":
            return KMeans(n_clusters=k, n_init='auto', random_state=42).fit(X).labels_
        if algorithm == "agglomerative":
            from sklearn.cluster import AgglomerativeClustering
            return AgglomerativeClustering(n_clusters=k).fit(X).labels_
        if algorithm == "gmm":
            from sklearn.mixture import GaussianMixture
            model = GaussianMixture(n_components=k, covariance_type='full', random_state=42).fit(X)
            return model.predict(X)
        raise ValueError(f"Unsupported algorithm: {algorithm}")

    # A small subset can hold fewer samples than the requested cluster count.
    k = min(int(n_clusters), n_samples)
    if k < int(n_clusters):
        print(f"[WARN] n_clusters={n_clusters} exceeds {n_samples} samples; using k={k}")
    print(f"[INFO] Clustering with {cluster_method} (k={k})")
    cluster_labels = _fit_labels(cluster_method, embeddings, k)

    # Global metrics (embedding space, not UMAP). They are skipped (nan)
    # unless there is more than one cluster and more samples than clusters.
    unique_clusters = np.unique(cluster_labels)
    valid = (len(unique_clusters) > 1) and (n_samples > len(unique_clusters))

    def _safe_metric(fn, X, y):
        try:
            return float(fn(X, y)) if valid else float("nan")
        except Exception:
            return float("nan")

    def _fmt(value: float, spec: str) -> str:
        """Format a score, printing every non-finite value as 'nan'."""
        return format(value if np.isfinite(value) else float("nan"), spec)

    sil = _safe_metric(silhouette_score, embeddings, cluster_labels)
    dbi = _safe_metric(davies_bouldin_score, embeddings, cluster_labels)
    ch = _safe_metric(calinski_harabasz_score, embeddings, cluster_labels)

    title_txt = f"{cluster_method.capitalize()} (Sil={sil:.3f} | DBI={dbi:.3f} | CH={ch:.1f})"
    fig = px.scatter_3d(
        x=proj_3d[:, 0], y=proj_3d[:, 1], z=proj_3d[:, 2],
        color=[str(c) for c in cluster_labels],
        symbol=location_labels,
        hover_data={"Cluster": cluster_labels, "Class": class_labels},
        title=title_txt, opacity=0.85, height=800
    )
    umap_html = pio.to_html(fig, include_plotlyjs="cdn", full_html=False)

    # Per-cluster sections: largest clusters first, ties broken by cluster id
    # so the report order is reproducible.
    base_for_entropy = len(set(location_labels))
    sizes = Counter(cluster_labels.tolist())
    cluster_ids = sorted(sizes, key=lambda c: (-sizes[c], c))
    if top_n_clusters is not None:
        cluster_ids = cluster_ids[:int(top_n_clusters)]

    # Accumulators for global spectrogram consistency summary
    spec_cos_scores, spec_dtw_scores = [], []

    blocks = []
    for cid in cluster_ids:
        idxs = np.where(cluster_labels == cid)[0]
        sz = len(idxs)
        loc_counts = Counter(location_labels[idxs])
        cls_counts = Counter(class_labels[idxs])
        metrics = _evaluate_cluster_metrics(embeddings, idxs, location_labels, location_entropy_base=base_for_entropy)

        # Spectrogram-space intra-cluster similarity
        spec_cos = _avg_intra_cluster_spec_cosine(original_paths, idxs, fs=10000, max_samples=50)
        spec_dtw = _avg_intra_cluster_spec_dtw(original_paths, idxs, fs=10000, max_samples=25)

        # Only finite scores contribute to the global averages.
        if np.isfinite(spec_cos):
            spec_cos_scores.append(spec_cos)
        if np.isfinite(spec_dtw):
            spec_dtw_scores.append(spec_dtw)

        meta_html = "<p><strong>Location Distribution:</strong></p><ul>" + "".join(
            f"<li><b>{loc}</b>: {count} ({count/sz:.1%})</li>" for loc, count in loc_counts.items()
        ) + "</ul>"
        meta_html += "<p><strong>Class Distribution:</strong></p><ul>" + "".join(
            f"<li>{cls}: {count}</li>" for cls, count in cls_counts.items()
        ) + "</ul>"
        meta_html += f"""
        <p><strong>Cluster Metrics (Embedding Space):</strong></p>
        <ul>
          <li>Size: {sz}</li>
          <li>Intra-Cluster Variance: {metrics['variance']:.4f}</li>
          <li>Mean Cosine Similarity (Embeddings): {metrics['mean_sim']:.4f}</li>
          <li>Location Entropy: {metrics['entropy']:.3f}</li>
          <li>Composite Quality Score: {metrics['quality']:.4f}</li>
          <li><strong>Novelty Score:</strong> {metrics['novelty']:.4f}</li>
        </ul>
        <p><strong>Spectrogram Consistency (Signal Space):</strong></p>
        <ul>
          <li>Mean Spectrogram Cosine Similarity (↑ better): {_fmt(spec_cos, '.4f')}</li>
          <li>Mean Spectrogram DTW Distance (↓ better): {_fmt(spec_dtw, '.2f')}</li>
        </ul>
        """

        # Exemplars: the samples closest (Euclidean) to the cluster centroid.
        center = np.mean(embeddings[idxs], axis=0, keepdims=True)
        dists = np.linalg.norm(embeddings[idxs] - center, axis=1)
        pick = idxs[np.argsort(dists)[:samples_per_cluster]]

        thumbs = []
        for i, p in enumerate(pick):
            try:
                arr = np.load(original_paths[p], mmap_mode="r")
                title = f"#{i+1} | {location_labels[p]} | Class {class_labels[p]}"
                thumbs.append(_spec_img_base64(arr, title))
            except Exception as e:
                # Best effort: one unreadable file must not abort the report.
                thumbs.append(f"<p class='text-danger'>Error: {e}</p>")

        blocks.append(f"<div class='col-md-6 mb-4'><h4>Cluster {cid}</h4>{meta_html}{_carousel_html(cid, 'supcon', thumbs)}</div>")

    # Two cluster cards per Bootstrap row.
    cluster_html = "".join(
        "<div class='row'>" + "".join(blocks[i:i+2]) + "</div>"
        for i in range(0, len(blocks), 2)
    )

    # Global spectrogram-consistency summary
    global_spec_cos = float(np.mean(spec_cos_scores)) if spec_cos_scores else float("nan")
    global_spec_dtw = float(np.mean(spec_dtw_scores)) if spec_dtw_scores else float("nan")

    summary_card = f"""
    <div class='row mb-4'>
      <div class='col-md-12'>
        <div class='alert alert-info' role='alert'>
          <h5 class='mb-2'>Spectrogram Consistency (Cluster Averages)</h5>
          <ul class='mb-0'>
            <li><strong>Mean Spectrogram Cosine Similarity</strong>: {_fmt(global_spec_cos, '.4f')} (↑ better)</li>
            <li><strong>Mean Spectrogram DTW Distance</strong>: {_fmt(global_spec_dtw, '.2f')} (↓ better)</li>
          </ul>
          <small>Note: Cosine computed on standardized spectrograms (z-norm, padded/cropped to a common shape). DTW is optional and subsampled for speed.</small>
        </div>
      </div>
    </div>
    """

    # ':g' avoids float truncation (int(0.29 * 100) == 28) and keeps
    # sub-percent fractions such as 0.5% visible.
    subset_pct = f"{subset_fraction * 100:g}"

    return f"""
    <div class='section'>
      <h2>{cluster_method.capitalize()} Clustering Analysis (SupCon • subset {subset_pct}%)</h2>
      <div class="container">
        <div class="row justify-content-center mb-4">
          <div class="col-md-12 d-flex justify-content-center">{umap_html}</div>
        </div>
        {summary_card}
      </div>
      {cluster_html}
    </div>
    """


def generate_supcon_full_report(
    checkpoint_path: str,
    dataset_paths: List[str],
    location_tags: List[str],
    *,
    out_html: str = "supcon_unsupervised_report.html",
    cluster_method: str = "kmeans",
    n_clusters: int = 60,
    subset_fraction: float = 0.02,
    subset_seed: int = 42,
    samples_per_cluster: int = 4,
    top_n_clusters: Optional[int] = None,
) -> str:
    """Build the SupCon clustering report page and write it to ``out_html``.

    The analysis section is produced by ``analyze_supcon_to_html`` and then
    wrapped in a full HTML document (Bootstrap and Plotly are loaded from
    CDNs) with a short header listing the run settings.

    Args:
        checkpoint_path: Forwarded unchanged to ``analyze_supcon_to_html``.
        dataset_paths: Dataset locations; copied into a list before forwarding.
        location_tags: Tags shown in the report header and forwarded to the
            analysis; copied into a list once so any iterable works.
        out_html: Destination file for the rendered page (UTF-8).
        cluster_method: Clusterer name, forwarded and shown in the header.
        n_clusters: Cluster count; coerced with ``int`` before use.
        subset_fraction: Fraction forwarded to the analysis; shown in the
            header as a whole-number percentage.
        subset_seed: Seed forwarded to the analysis and shown in the header.
        samples_per_cluster: Forwarded unchanged to the analysis.
        top_n_clusters: Forwarded unchanged to the analysis.

    Returns:
        ``out_html``, the path the report was written to.
    """
    # Imported locally so the block is self-contained; the page text is kept in
    # ``page`` rather than ``html`` to avoid shadowing the stdlib module name.
    from html import escape

    # Materialize once: the same values feed both the analysis and the header.
    paths = list(dataset_paths)
    tags = list(location_tags)
    k = int(n_clusters)

    section = analyze_supcon_to_html(
        checkpoint_path=checkpoint_path,
        dataset_paths=paths,
        location_tags=tags,
        cluster_method=cluster_method,
        n_clusters=k,
        subset_fraction=subset_fraction,
        subset_seed=subset_seed,
        samples_per_cluster=samples_per_cluster,
        top_n_clusters=top_n_clusters,
    )

    # round() instead of int(): int(0.29 * 100) is 28 because of float noise.
    subset_pct = round(subset_fraction * 100)
    # Caller-supplied text is escaped so characters such as '&' or '<' cannot
    # break the markup; ``section`` is already HTML and is inserted verbatim.
    tags_text = escape(", ".join(str(t) for t in tags))
    method_text = escape(str(cluster_method))

    # NOTE(review): the "plotly-latest" CDN bundle is reportedly frozen at an
    # old 1.x release -- confirm it is compatible with the embedded figure.
    page = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Unsupervised Clustering Report (SupCon)</title>
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
        <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
        <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
        <style>
            body {{ font-family: Arial, sans-serif; padding: 20px; background-color: #f5f5f5; }}
            h1 {{ color: #2c3e50; }}
            h2, h4 {{ color: #34495e; }}
            hr {{ border-top: 2px solid #bbb; margin-top: 40px; margin-bottom: 40px; }}
            .section {{ margin-bottom: 60px; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
        </style>
    </head>
    <body>
      <h1 class='mb-4'>Unsupervised Latent Clustering Report (SupCon)</h1>
      <p><strong>Locations:</strong> {tags_text}</p>
      <p><strong>Subset:</strong> {subset_pct}% • Seed: {subset_seed}</p>
      <p><strong>Clusterer:</strong> {method_text} • <strong>k</strong> = {k}</p>
      <hr>
      {section}
    </body></html>
    """
    with open(out_html, "w", encoding="utf-8") as f:
        f.write(page)
    print(f"✅ Report saved to: {out_html}")
    return out_html


# ============================================================
# Example run: generate the report with fixed settings (no argument parsing)
# ============================================================

# Render the SupCon clustering report with fixed settings; the call returns
# the path of the HTML file it wrote.
out_path = generate_supcon_full_report(
    # --- inputs ---
    checkpoint_path="best_encoder_pretrain.pth",
    dataset_paths=["/notebooks/dataset_preprocessed"],
    location_tags=["PR_U1137"],  # one tag per entry in dataset_paths
    # --- output ---
    out_html="SupCon_unsupervised_clustering_report.html",
    # --- clustering ---
    cluster_method="kmeans",  # author also lists "gmm" and "agglomerative"
    n_clusters=60,
    # --- sampling / display ---
    subset_fraction=0.2,
    subset_seed=42,
    samples_per_cluster=4,  # same as the function default
    top_n_clusters=40,  # author's note: omit to show every cluster
)

[INFO] SupCon encoder loaded from best_encoder_pretrain.pth on cuda
[Embed] 256/82655 • 202.0 it/s
[Embed] 2560/82655 • 140.9 it/s
[Embed] 5120/82655 • 117.4 it/s
[Embed] 7680/82655 • 95.7 it/s
[Embed] 10240/82655 • 86.5 it/s
[Embed] 12800/82655 • 80.1 it/s
[Embed] 15360/82655 • 81.8 it/s
[Embed] 17920/82655 • 79.9 it/s
[Embed] 20480/82655 • 72.5 it/s
[Embed] 23040/82655 • 72.7 it/s
[Embed] 25600/82655 • 73.8 it/s
[Embed] 28160/82655 • 75.0 it/s
[Embed] 30720/82655 • 73.5 it/s
[Embed] 33280/82655 • 74.4 it/s
[Embed] 35840/82655 • 75.1 it/s
[Embed] 38400/82655 • 72.4 it/s
[Embed] 40960/82655 • 72.1 it/s
[Embed] 43520/82655 • 72.1 it/s
[Embed] 46080/82655 • 72.1 it/s
[Embed] 48640/82655 • 73.0 it/s
[Embed] 51200/82655 • 73.4 it/s
[Embed] 53760/82655 • 74.5 it/s
[Embed] 56320/82655 • 75.3 it/s
[Embed] 58880/82655 • 75.6 it/s
[Embed] 61440/82655 • 74.9 it/s
[Embed] 64000/82655 • 75.4 it/s
[Embed] 66560/82655 • 75.0 it/s
[Embed] 69120/82655 • 74.8 it/s
[Embed] 71680/82655 • 75.0 it/s
[Embed


n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



[INFO] Clustering with kmeans (k=60)
✅ Report saved to: SupCon_unsupervised_clustering_report.html
