Dependencies
---------------

In [1]:
!pip install librosa
!pip install umap-learn
!pip install plotly
!pip install hdbscan
!pip install --upgrade pip
!pip install -U soundfile datasets[audio] librosa
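# torchaudio is required by the trainer cell below (it runs `import torchaudio as ta`);
# audioread is already pulled in as a librosa dependency, so -U only upgrades it: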
!pip install -U torchaudio audioread

Collecting librosa
  Downloading librosa-0.11.0-py3-none-any.whl.metadata (8.7 kB)
Collecting audioread>=2.1.9 (from librosa)
  Downloading audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting numba>=0.51.0 (from librosa)
  Downloading numba-0.62.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)
Collecting soundfile>=0.12.1 (from librosa)
  Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)
Collecting pooch>=1.1 (from librosa)
  Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)
Collecting soxr>=0.3.2 (from librosa)
  Downloading soxr-1.0.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.6 kB)
Collecting msgpack>=1.0 (from librosa)
  Downloading msgpack-1.1.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (8.1 kB)
Collecting llvmlite<0.46,>=0.45.0dev0 (from numba>=0.51.0->librosa)
  Downloading llvmlite-0.45.1-cp311-cp311-manylinux2014_x86_64

SimCLR trainer 
------------------

In [2]:
# === SimCLR (reef acoustics) with Feature Bank + SimSiam + VICReg ============================
# NaN-proof + faster:
#  - Contrastive logits never use -inf; use large negatives and clamp
#  - Explicit sanitization (nan/inf -> 0) on all embeddings
#  - Loss computed in fp32 (AMP only for forwards)
#  - VICReg made numerically safer
#  - SimSiam & VICReg warmup ramp (first 5 epochs)
#  - bank_size=8192 (per paper), bank_sample default 2048/4096 for speed
#  - Cached MelSpectrograms, channels_last, cuDNN autotune, optional torch.compile
# ==============================================================================================

import os, math, random, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchaudio as ta
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import autocast, GradScaler
from typing import Optional, Tuple, Dict

# ----------------------------
# Speed/Determinism toggles
# ----------------------------
# Enable the cuDNN autotuner (the "cuDNN autotune" speed-up named in the cell header).
# NOTE(review): autotuning is generally at odds with bit-exact reproducibility — confirm
# that is acceptable given set_seed() below.
torch.backends.cudnn.benchmark = True
# Guarded with hasattr so the cell still runs on torch builds without this function.
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

def set_seed(s=42):
    """Seed every RNG this file draws from with the same integer ``s``.

    Covers Python's ``random``, NumPy's legacy global RNG, torch's CPU generator
    and the generators of all CUDA devices. Returns ``None``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(s)

# ----------------------------
# Backbone: ResNet18 (1-channel)
# ----------------------------
from torchvision.models import resnet18


def _resnet18_1ch():
    """Build a randomly initialised ResNet-18 that accepts 1-channel input.

    The stock first convolution is replaced by a 1-input-channel convolution with
    the same kernel/stride/padding, and the classification head is replaced by
    ``nn.Identity`` so the model returns the pooled backbone features.

    The torchvision API difference (``weights=`` vs. the older ``pretrained=``
    keyword) only surfaces when the constructor is *called*, so the fallback has to
    live here rather than around the import.
    """
    try:
        m = resnet18(weights=None)
    except TypeError:
        # Older torchvision: the constructor does not know the ``weights`` keyword.
        m = resnet18(pretrained=False)
    m.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    m.fc = nn.Identity()
    return m

# ----------------------------
# Encoder + Projector + Predictor (for SimSiam)
# ----------------------------
class EncoderProj(nn.Module):
    """1-channel ResNet-18 backbone + MLP projector + optional predictor head.

    Args:
        proj_dim: width of the projector (and predictor) output.
        use_predictor: when False, ``self.predictor`` is ``None`` and
            ``for_predictor=True`` is silently ignored by ``forward``.
    """

    def __init__(self, proj_dim=128, use_predictor=True):
        super().__init__()
        # Sub-modules are registered backbone -> projector -> predictor; code that
        # walks parameters()/buffers() positionally sees them in that order.
        self.backbone = _resnet18_1ch()

        projector_layers = [
            nn.Linear(512, 512), nn.BatchNorm1d(512), nn.ReLU(inplace=True),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(inplace=True),
            nn.Linear(256, proj_dim),
        ]
        self.projector = nn.Sequential(*projector_layers)

        if use_predictor:
            predictor_layers = [
                nn.Linear(proj_dim, 256), nn.BatchNorm1d(256), nn.ReLU(inplace=True),
                nn.Linear(256, proj_dim),
            ]
            self.predictor = nn.Sequential(*predictor_layers)
        else:
            self.predictor = None

    def forward(self, x, return_backbone=False, for_predictor=False):
        """Embed ``x`` and return L2-normalised outputs.

        Returns, depending on the flags (a predictor output is only produced when
        ``for_predictor`` is set *and* a predictor exists):

        * ``z``                    -- default
        * ``(z, p)``               -- predictor requested
        * ``(h, z)``               -- ``return_backbone``
        * ``(h, z, p)``            -- both

        where ``h`` is the normalised backbone output, ``z`` the normalised
        projection and ``p`` the normalised prediction computed from ``z``.
        """
        feats = self.backbone(x)              # (B, 512)
        proj = F.normalize(self.projector(feats), dim=1)

        outputs = [proj]
        wants_prediction = for_predictor and (self.predictor is not None)
        if wants_prediction:
            outputs.append(F.normalize(self.predictor(proj), dim=1))
        if return_backbone:
            outputs.insert(0, F.normalize(feats, dim=1))

        if len(outputs) == 1:
            return proj
        return tuple(outputs)

# ----------------------------
# EMA Teacher
# ----------------------------
class EMATeacher(nn.Module):
    """Frozen, predictor-free ``EncoderProj`` that follows the online network by EMA.

    Args:
        online: network whose weights seed the teacher.
        m: default EMA momentum used by ``update`` when none is passed.
        proj_dim: projector width of the teacher network.
    """

    def __init__(self, online: EncoderProj, m=0.99, proj_dim=128):
        super().__init__()
        self.teacher = EncoderProj(proj_dim=proj_dim, use_predictor=False)
        self.m = m
        self._init_from(online)
        # The teacher is never optimised: no gradients, and kept in eval mode.
        for weight in self.teacher.parameters():
            weight.requires_grad_(False)
        self.teacher.eval()

    @torch.no_grad()
    def _init_from(self, online):
        """Copy every non-predictor entry of ``online``'s state dict into the teacher."""
        shared = {}
        for name, tensor in online.state_dict().items():
            if name.startswith("predictor."):
                continue
            shared[name] = tensor
        # strict=False: keys that do not line up are ignored instead of raising.
        self.teacher.load_state_dict(shared, strict=False)

    @torch.no_grad()
    def update(self, online, m=None):
        """In-place EMA step: ``teacher <- m * teacher + (1 - m) * online``.

        Tensors are paired positionally via ``zip``, so iteration stops at the
        teacher's last tensor; this relies on ``online`` yielding its extra
        (predictor) tensors after the shared ones. Non-floating-point buffers are
        left untouched.
        """
        decay = float(m) if m is not None else float(self.m)
        step = 1.0 - decay

        pairs = list(zip(self.teacher.parameters(), online.parameters()))
        pairs.extend(
            (t_buf, o_buf)
            for t_buf, o_buf in zip(self.teacher.buffers(), online.buffers())
            if t_buf.dtype.is_floating_point
        )
        for dst, src in pairs:
            dst.data.mul_(decay).add_(src.data, alpha=step)

    @torch.no_grad()
    def forward(self, x, return_backbone=False):
        """Run the teacher network on ``x`` without tracking gradients."""
        return self.teacher(x, return_backbone=return_backbone)

# ----------------------------
# Utilities
# ----------------------------
def _sanitize_(*tensors):
    """Replace NaN/Inf with 0.0 on every tensor; returns sanitized copies."""
    out = []
    for t in tensors:
        if t is None:
            out.append(None)
        else:
            out.append(torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0))
    return out

# ----------------------------
# FIFO Feature Bank (teacher embeddings, detached)
# ----------------------------
class FeatureBank:
    """Fixed-size FIFO ring buffer of L2-normalised feature rows.

    Args:
        dim: feature dimensionality (number of columns).
        size: maximum number of rows kept.
        device: where the buffer lives; defaults to CUDA when available, else CPU.

    Attributes:
        bank: ``(size, dim)`` storage tensor, zero-initialised.
        ptr: index of the next row to overwrite.
    """

    def __init__(self, dim, size=8192, device=None):
        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.bank = torch.zeros(size, dim, device=device)
        self.ptr = 0
        self.size = size
        self._filled = 0  # number of valid rows, saturates at ``size``

    @torch.no_grad()
    def enqueue(self, feats):  # feats: (N, dim), detached
        """Append ``feats`` (normalised and NaN/Inf-sanitised), overwriting the oldest rows.

        ``None`` or empty input is a no-op. When more rows arrive than the bank can
        hold, only the newest ``size`` rows are stored, and they land exactly where
        a row-by-row FIFO insert would have put them.
        """
        if feats is None or feats.numel() == 0:
            return
        feats = F.normalize(feats, dim=1)
        feats = torch.nan_to_num(feats, nan=0.0, posinf=0.0, neginf=0.0)
        n = feats.size(0)

        if n > self.size:
            # The first ``skipped`` rows would be overwritten within this very call;
            # drop them and advance the write pointer as if they had been written.
            skipped = n - self.size
            feats = feats[skipped:]
            self.ptr = (self.ptr + skipped) % self.size
            n = self.size

        end = self.ptr + n
        if end <= self.size:
            self.bank[self.ptr:end] = feats
        else:
            # Wrap around: fill the tail of the buffer, then continue at the head.
            first = self.size - self.ptr
            self.bank[self.ptr:] = feats[:first]
            self.bank[:n - first] = feats[first:]
        self.ptr = end % self.size
        self._filled = min(self.size, self._filled + n)

    @torch.no_grad()
    def get(self):
        """Return the valid rows (a detached view of the storage; empty if nothing stored)."""
        if self._filled == 0:
            return self.bank[:0]
        return self.bank[:self._filled].detach()

# ----------------------------
# Multi-Positive NT-Xent (weighted by teacher) + Memory Bank (NaN-proof)
# ----------------------------
class MultiPosNTXentSoft(nn.Module):
    """NT-Xent with soft, teacher-derived multi-positive targets and extra bank negatives.

    Args:
        temperature: softmax temperature applied to cosine logits.
        pos_topk: maximum number of teacher neighbours per row promoted to positives.
        pos_thr: minimum teacher cosine similarity for a neighbour to count as positive.
    """

    def __init__(self, temperature=0.1, pos_topk=5, pos_thr=0.7):
        super().__init__()
        self.tau = float(temperature)
        self.pos_topk = int(pos_topk)
        self.pos_thr = float(pos_thr)

    @torch.no_grad()
    def _weights_from_teacher(self, t1, t2):
        """Build the row-normalised ``(2B, 2B)`` positive-weight matrix.

        Every row always keeps its paired view (i <-> i+B) with raw weight 1.0.
        In addition, up to ``pos_topk`` other samples whose teacher similarity is
        at least ``pos_thr`` get raw weight ``(sim + 1) / 2``. The diagonal is zero
        and each row is normalised to sum to 1.
        """
        B = t1.size(0)
        T = F.normalize(torch.cat([t1, t2], dim=0), dim=1)
        T = torch.nan_to_num(T, nan=0.0, posinf=0.0, neginf=0.0)
        S = T @ T.T
        eye = torch.eye(2 * B, device=T.device, dtype=torch.bool)
        S = S.masked_fill(eye, -1e9)  # large negative, not -inf

        W = torch.zeros_like(S)
        idx = torch.arange(B, device=T.device)
        W[idx, idx + B] = 1.0
        W[idx + B, idx] = 1.0

        S_thr = torch.where(S >= self.pos_thr, S, torch.full_like(S, -1e9))
        k = min(self.pos_topk, 2 * B - 2)
        if k > 0:
            vals, inds = torch.topk(S_thr, k=k, dim=1)
            rows = torch.arange(2 * B, device=T.device).unsqueeze(1).expand_as(inds)
            # The rejected entries carry the finite sentinel -1e9, so an isfinite()
            # test would accept them and write weight 0 over arbitrary slots --
            # including the guaranteed paired-view positive. Compare against the
            # threshold instead so only genuine neighbours are written.
            keep = vals >= self.pos_thr
            weights = torch.clamp((vals + 1.0) / 2.0, min=0.0, max=1.0)
            W[rows[keep], inds[keep]] = weights[keep]

        W = torch.where(eye, torch.zeros_like(W), W)
        row_sum = W.sum(dim=1, keepdim=True).clamp_min(1e-6)
        W = W / row_sum
        return W

    def forward_with_bank(self, z1, z2, t1, t2, bank_feats):
        """Soft-target contrastive loss for two views, with optional bank negatives.

        Args:
            z1, z2: online embeddings of the two views, each ``(B, d)``; these are
                the queries and the only tensors receiving gradient.
            t1, t2: teacher embeddings of the same views, used only to build targets.
            bank_feats: ``(M, d)`` extra keys or ``None``/empty; they always get
                target weight 0, i.e. act purely as negatives.

        Returns:
            Scalar loss; a non-finite value is replaced by 0.
        """
        B = z1.size(0)
        Q = F.normalize(torch.cat([z1, z2], dim=0), dim=1).float()
        Q = torch.nan_to_num(Q, nan=0.0, posinf=0.0, neginf=0.0)

        # Keys from the current batch are detached: gradient flows through Q only.
        cur_keys = F.normalize(torch.cat([z1.detach(), z2.detach()], dim=0), dim=1)
        cur_keys = torch.nan_to_num(cur_keys, nan=0.0, posinf=0.0, neginf=0.0)

        if bank_feats is not None and bank_feats.numel() > 0:
            bank_feats = torch.nan_to_num(bank_feats, nan=0.0, posinf=0.0, neginf=0.0)
            K = torch.cat([cur_keys, bank_feats], dim=0)
        else:
            K = cur_keys

        logits = (Q @ K.T) / self.tau
        # Mask each query's own key (first 2B columns only; bank columns are never "self").
        eye = torch.eye(2 * B, device=Q.device, dtype=torch.bool)
        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask[:, :2 * B] = eye
        logits = logits.masked_fill(mask, -1e9)    # never -inf
        logits = logits.clamp(min=-30.0, max=30.0) # keep exp stable

        log_denom = torch.logsumexp(logits, dim=1, keepdim=True)
        log_probs = logits - log_denom

        with torch.no_grad():
            W = self._weights_from_teacher(t1, t2)
        if K.size(0) > 2 * B:
            # Zero target weight on all bank columns.
            pad_zeros = torch.zeros(Q.size(0), K.size(0) - 2 * B, device=Q.device, dtype=W.dtype)
            W_full = torch.cat([W, pad_zeros], dim=1)
        else:
            W_full = W

        loss = -(W_full * log_probs).sum(dim=1).mean()
        if not torch.isfinite(loss):
            loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=0.0)
        return loss

# ----------------------------
# SimSiam & VICReg (safer)
# ----------------------------
def simsiam_loss(p, z_target):
    """Negative mean cosine similarity between predictions ``p`` and targets ``z_target``.

    Both inputs are L2-normalised along dim 1 and sanitised (NaN/Inf -> 0) first;
    the scalar result is sanitised as well. Detaching the target is left to the caller.
    """
    def _unit_rows(v):
        return torch.nan_to_num(F.normalize(v, dim=1), nan=0.0, posinf=0.0, neginf=0.0)

    cosine = (_unit_rows(p) * _unit_rows(z_target)).sum(dim=1)
    neg_mean_cos = -cosine.mean()
    return torch.nan_to_num(neg_mean_cos, nan=0.0, posinf=0.0, neginf=0.0)

def vicreg_loss(z_a, z_b, sim_coeff=25.0, var_coeff=25.0, cov_coeff=1.0, eps=1e-4, gamma=1.0):
    """Weighted sum of invariance, variance and covariance terms for two embedding batches.

    * invariance: MSE between the (sanitised, un-centred) inputs;
    * variance: mean hinge ``relu(gamma - std)`` per dimension, summed over both views,
      with ``std = sqrt(biased_var + eps)``;
    * covariance: squared off-diagonal covariance entries divided by the feature
      dimension, summed over both views.

    NaN/Inf in the inputs and in the final scalar are replaced by 0.
    """
    clean = dict(nan=0.0, posinf=0.0, neginf=0.0)
    z_a = torch.nan_to_num(z_a, **clean)
    z_b = torch.nan_to_num(z_b, **clean)

    invariance = F.mse_loss(z_a, z_b)

    def _std_hinge(z):
        per_dim_std = torch.sqrt(z.var(dim=0, unbiased=False).clamp_min(0.0) + eps)
        return F.relu(gamma - per_dim_std).mean()

    def _offdiag_energy(z):
        n_rows, n_dims = z.size()
        centred = z - z.mean(dim=0, keepdim=True)
        cov = (centred.T @ centred) / max(1, n_rows - 1)
        off_diag = cov - torch.diag(torch.diag(cov))
        return (off_diag ** 2).sum() / max(1, n_dims)

    # Centre both views before the variance/covariance terms.
    a_centred = z_a - z_a.mean(dim=0, keepdim=True)
    b_centred = z_b - z_b.mean(dim=0, keepdim=True)

    variance = _std_hinge(a_centred) + _std_hinge(b_centred)
    covariance = _offdiag_energy(a_centred) + _offdiag_energy(b_centred)

    total = sim_coeff * invariance + var_coeff * variance + cov_coeff * covariance
    return torch.nan_to_num(total, **clean)

# ----------------------------
# Augmentations (spectral notch + others)
# ----------------------------
class SpectroAug(nn.Module):
    """Random spectrogram augmentation pipeline for a single ``(1, F, T)`` tensor.

    Each stage fires independently with its own probability, in this order:
    time crop -> time mask -> frequency mask -> spectral notch -> circular time
    shift -> additive Gaussian noise. With ``keep_length=True`` the output is
    zero-padded on the right (or truncated) back to the input's time length.

    Args:
        time_mask_param / freq_mask_param: maximum mask widths passed to torchaudio.
        p_time_mask, p_freq_mask, p_notch, p_shift, p_noise, p_time_crop: stage probabilities.
        notch_width: number of consecutive frequency bins zeroed by the notch.
        max_shift_frac: maximum roll, as a fraction of the current time length.
        noise_std: standard deviation of the added noise.
        crop_frac_range: (low, high) fraction of the time axis kept by the crop.
        keep_length: restore the original time length at the end.
    """

    def __init__(
        self,
        time_mask_param=32,
        freq_mask_param=12,
        p_time_mask=0.8,
        p_freq_mask=0.8,
        p_notch=0.3, notch_width=6,
        p_shift=0.8,
        p_noise=0.5,
        max_shift_frac=0.12,
        noise_std=0.01,
        p_time_crop=0.5,
        crop_frac_range=(0.7, 1.0),
        keep_length=True,
    ):
        super().__init__()
        self.tm = ta.transforms.TimeMasking(time_mask_param=time_mask_param)
        self.fm = ta.transforms.FrequencyMasking(freq_mask_param=freq_mask_param)
        self.p_time_mask = p_time_mask
        self.p_freq_mask = p_freq_mask
        self.p_notch = p_notch
        self.notch_width = notch_width
        self.p_shift = p_shift
        self.p_noise = p_noise
        self.max_shift_frac = max_shift_frac
        self.noise_std = noise_std
        self.p_time_crop = p_time_crop
        self.crop_lo, self.crop_hi = crop_frac_range
        self.keep_length = keep_length

    def forward(self, x):  # x: (1, F, T0)
        """Return an augmented copy of ``x``; the caller's tensor is never modified."""
        T0 = x.size(2)

        # Random time crop (never below 16 frames).
        if torch.rand(1) < self.p_time_crop and T0 > 16:
            frac = float(torch.empty(1).uniform_(self.crop_lo, self.crop_hi))
            new_T = max(16, int(T0 * frac))
            if new_T < T0:
                start = int(torch.randint(0, T0 - new_T + 1, (1,)))
                x = x[:, :, start:start + new_T]

        if torch.rand(1) < self.p_time_mask: x = self.tm(x)
        if torch.rand(1) < self.p_freq_mask: x = self.fm(x)

        if torch.rand(1) < self.p_notch:
            Fdim = x.size(1)
            width = min(self.notch_width, Fdim)
            start = int(torch.randint(0, max(1, Fdim - width + 1), (1,)))
            # At this point x may still be the caller's tensor (or a slice of it);
            # copy before the in-place write so the input is not silently altered.
            x = x.clone()
            x[:, start:start+width, :] = 0.0

        if torch.rand(1) < self.p_shift:
            T_dim2 = x.size(2)
            max_shift = max(1, int(self.max_shift_frac * T_dim2))
            shift = int(torch.randint(-max_shift, max_shift + 1, (1,)))
            x = torch.roll(x, shifts=shift, dims=2)

        if torch.rand(1) < self.p_noise:
            x = x + self.noise_std * torch.randn_like(x)

        if self.keep_length and x.size(2) != T0:
            pad = T0 - x.size(2)
            if pad > 0: x = F.pad(x, (0, pad), mode='constant', value=0.0)
            else:       x = x[:, :, :T0]
        return x.contiguous()

# ----------------------------
# Multi-Resolution + Local/Global Multi-Crop Dataset (with cached mels)
# ----------------------------
class MultiResCropDataset(Dataset):
    """
    Returns list of views per sample: [G1, G2, L1, L2, ...], each (1, F, T).
    Caches MelSpectrogram ops per (n_mels, n_fft, hop) for speed.

    Waveforms are read from ``<audio_dir>/<integer>.npy`` files (``labels.npy`` is
    skipped) and ordered by that integer. For every item a reference log-mel is
    computed to find the most energetic frame; each view then picks a random mel
    configuration from ``mel_choices``, is cropped in time around that frame,
    optionally augmented and optionally resized to ``resize_to``.

    Args:
        audio_dir: directory holding the ``.npy`` waveforms.
        sr: sample rate handed to the mel transforms.
        mel_choices: iterable of ``(n_mels, n_fft, hop)`` configurations.
        n_global / global_T: number and time length (frames) of global crops.
        n_local / local_T: number and time length (frames) of local crops.
        event_jitter_T: maximum random offset (frames) applied to the crop start.
        per_sample_norm: standardise each log-mel by its own mean/std.
        augment_fn: optional callable applied to every crop.
        resize_to: optional ``(F, T)`` every view is bilinearly resized to.
    """

    # Mel configuration of the reference spectrogram used to locate the event.
    _BASE_MEL = (64, 1024, 512)

    def __init__(
        self,
        audio_dir: str,
        sr: int = 10000,
        mel_choices: Tuple[Tuple[int,int,int], ...] = ((64,1024,512),(128,1024,256),(128,2048,512)),
        n_global: int = 2, global_T: int = 256,
        n_local: int = 2,  local_T: int = 96,
        event_jitter_T: int = 24,
        per_sample_norm: bool = True,
        augment_fn: Optional[nn.Module] = None,
        resize_to: Optional[Tuple[int,int]] = (128, 256),
    ):
        self.audio_paths = sorted(
            [os.path.join(audio_dir, f) for f in os.listdir(audio_dir)
             if f.endswith(".npy") and f != "labels.npy"],
            key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
        )
        self.sr = sr
        # Normalise to tuples so lists (e.g. a config restored from JSON) work and
        # every entry is usable as a cache key.
        self.mel_choices = [tuple(cfg) for cfg in mel_choices]
        self.n_global, self.global_T = n_global, global_T
        self.n_local,  self.local_T  = n_local,  local_T
        self.event_jitter_T = event_jitter_T
        self.per_sample_norm = per_sample_norm
        self.augment_fn = augment_fn
        self.resize_to = resize_to

        # cache Mel ops: (n_mels, n_fft, hop) -> MelSpectrogram transform
        self._mel_cache = {}
        for n_mels, n_fft, hop in set(self.mel_choices) | {self._BASE_MEL}:
            self._mel_cache[(n_mels, n_fft, hop)] = ta.transforms.MelSpectrogram(
                sample_rate=self.sr, n_fft=n_fft, hop_length=hop, n_mels=n_mels
            )

    def __len__(self): return len(self.audio_paths)

    def _wave_to_logmel(self, y, n_mels, n_fft, hop):
        """Peak-normalise ``y`` and return its (optionally standardised) log-mel, shape (1, F, T)."""
        y = np.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
        # Guard against silent / degenerate clips: fall back to a divisor of 1.
        mx = float(np.max(np.abs(y)));  mx = 1.0 if (not np.isfinite(mx) or mx < 1e-8) else mx
        y = (y / mx).astype(np.float32)
        y_t = torch.from_numpy(y).unsqueeze(0)  # (1,T)
        mel = self._mel_cache[(n_mels, n_fft, hop)](y_t)
        mel = torch.nan_to_num(mel, nan=0.0, posinf=0.0, neginf=0.0)
        logmel = torch.log(torch.clamp(mel, min=1e-5))
        logmel = torch.nan_to_num(logmel, nan=0.0, posinf=0.0, neginf=0.0)
        if self.per_sample_norm:
            m, s = logmel.mean(), logmel.std()
            if (not torch.isfinite(s)) or (s < 1e-6): s = torch.tensor(1e-6, device=logmel.device)
            logmel = (logmel - m) / s
        return logmel  # (1, F, T)

    def _energy_center(self, logmel):
        """Index of the frame with the highest mean value over the mel axis."""
        energy = logmel.mean(dim=1).squeeze(0)  # (T,)
        return int(torch.argmax(energy).item())

    def _crop_T(self, x: torch.Tensor, T_win: int, center: Optional[int], jitter_T: int) -> torch.Tensor:
        """Crop ``T_win`` frames from ``x``; ``x`` is returned unchanged if it is not longer.

        With ``center=None`` the window position is uniform random; otherwise it is
        centred on ``center``, jittered by up to ``jitter_T`` frames and clamped to
        stay inside ``x``.
        """
        T = x.size(2)
        if T <= T_win:
            return x
        if center is None:
            start = int(torch.randint(0, T - T_win + 1, (1,)))
        else:
            start = max(0, min(center - T_win // 2, T - T_win))
            j = int(torch.randint(-jitter_T, jitter_T + 1, (1,)))
            start = max(0, min(start + j, T - T_win))
        return x[:, :, start:start + T_win]

    def _maybe_resize(self, x: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``x`` to ``self.resize_to`` (no-op when it is ``None``)."""
        if self.resize_to is None:
            return x
        Ft, Tt = self.resize_to
        x4 = x.unsqueeze(0)  # (1,1,F,T)
        x4 = F.interpolate(x4, size=(Ft, Tt), mode='bilinear', align_corners=False)
        return x4.squeeze(0)

    def __getitem__(self, idx):
        y = np.load(self.audio_paths[idx]).astype(np.float32)

        base_cfg = self._BASE_MEL
        base_hop = base_cfg[2]
        # Log-mels computed for this item, keyed by configuration, so a configuration
        # drawn by several views is only transformed once.
        logmels = {base_cfg: self._wave_to_logmel(y, *base_cfg)}
        base_center = self._energy_center(logmels[base_cfg])

        def _make_view(T_win):
            cfg = tuple(random.choice(self.mel_choices))
            if cfg not in logmels:
                logmels[cfg] = self._wave_to_logmel(y, *cfg)
            lm = logmels[cfg]
            # The event frame was located on the base-hop time grid; convert it to
            # this view's grid, otherwise crops at a different hop miss the event.
            center = int(round(base_center * base_hop / cfg[2]))
            crop = self._crop_T(lm, T_win, center=center, jitter_T=self.event_jitter_T)
            if self.augment_fn is not None:
                # Hand the augmentation its own copy: the cached log-mel is shared
                # between views and must not be altered by in-place augmentations.
                crop = self.augment_fn(crop.clone())
            return self._maybe_resize(crop)

        views = [_make_view(self.global_T) for _ in range(self.n_global)]
        views += [_make_view(self.local_T) for _ in range(self.n_local)]

        return views  # list[(1,Ft,Tt), ...]

# ----------------------------
# Collate: multi-crop -> list of (B,1,F,T)
# ----------------------------
def multicrop_collate(batch):
    """Turn a batch of per-sample view lists into a list of per-view batch tensors.

    ``batch[i][v]`` is view ``v`` of sample ``i``; the result has one stacked tensor
    per view (stacked along a new leading dim). The number of views is taken from
    the first sample.
    """
    view_count = len(batch[0])
    return [
        torch.stack([item[view_idx] for item in batch], dim=0)
        for view_idx in range(view_count)
    ]

# ----------------------------
# kNN-like sanity (multi-crop loaders)
# ----------------------------
@torch.no_grad()
def knn_sim_eval_multicrop(encoder, loader, device, use_backbone=True, K=20):
    """Mean cosine similarity of every sample to its ``K`` nearest neighbours.

    Puts ``encoder`` in eval mode (and leaves it there), embeds the first view of
    each batch from ``loader``, and compares all embeddings pairwise with
    self-matches excluded. ``use_backbone`` selects the backbone features instead of
    the projection. ``K`` is capped at ``N - 1``. Returns a Python float.
    """
    encoder.eval()

    collected = []
    for views in loader:
        first_view = views[0].to(device, non_blocking=True)
        backbone_feat, proj_feat = encoder(first_view, return_backbone=True)
        collected.append(backbone_feat if use_backbone else proj_feat)

    emb = F.normalize(torch.cat(collected, dim=0), dim=1)
    cos = emb @ emb.T
    cos.fill_diagonal_(-1)  # exclude each sample's similarity to itself

    n_neighbors = min(K, cos.size(1) - 1)
    top_vals = torch.topk(cos, k=n_neighbors, dim=1).values
    return float(top_vals.mean().item())

# ----------------------------
# Schedulers
# ----------------------------
def warmup_cosine_lambda(step, warmup_steps, total_steps):
    """LR multiplier: linear warm-up to 1.0, then half-cosine decay towards 0.

    During warm-up the factor is ``(step + 1) / warmup_steps``; afterwards it follows
    ``0.5 * (1 + cos(pi * progress))`` with ``progress`` measured over the remaining
    ``total_steps - warmup_steps`` steps. Both divisors are floored at 1.
    """
    in_warmup = step < warmup_steps
    if in_warmup:
        return (step + 1) / max(1, warmup_steps)

    decay_span = max(1, total_steps - warmup_steps)
    progress = (step - warmup_steps) / decay_span
    return 0.5 * (1 + math.cos(math.pi * progress))

def ema_momentum_at(step, total_steps, m_base=0.99, m_final=0.9995):
    """Cosine schedule for the EMA momentum: ``m_base`` at step 0, ``m_final`` at ``total_steps``."""
    phase = math.pi * step / max(1, total_steps)
    cos_t = 0.5 * (1 + math.cos(phase))   # 1.0 at the start, 0.0 at the end
    span = m_final - m_base
    return m_final - span * cos_t

# ----------------------------
# Training
# ----------------------------
def train_simclr_multipos(
    audio_dir,
    epochs=200,
    batch_size=256,
    lr=3e-4,
    weight_decay=1e-4,
    temperature=0.1,              # hotter by default
    accum_steps=2,
    ema_m=0.99,
    pos_topk=5,
    pos_thr=0.7,
    save_path="simclr_multipos_latest.pth",
    warmup_epochs=10,
    num_workers=4,
    compile_model=False,
    clip_grad_norm=1.0,
    seed=42,
    resume_path=None,
    # dataset knobs
    mel_choices=((64,1024,512),(128,1024,256),(128,2048,512)),
    n_global=2, global_T=256,
    n_local=2,  local_T=96,
    resize_to=(128,256),
    # bank knobs
    bank_size=8192,
    bank_sample=2048,             # 2048–4096 is good; lower = faster
    # loss weights
    alpha_contrast=1.0,
    beta_siam=0.1,
    gamma_vicreg=0.1,
    vic_sim=25.0, vic_var=25.0, vic_cov=1.0,
):
    """
    SimCLR with:
      - EMA teacher + soft multi-positive NT-Xent (two global views),
      - FIFO Feature Bank (teacher embeddings) with per-step sampling,
      - SimSiam invariance,
      - VICReg regularization.

    The data in ``audio_dir`` is split 90/10 (seeded) into train/validation. After
    every epoch a kNN-similarity score is computed on the validation split; the
    online weights are written to ``save_path`` whenever that score improves, and a
    full resumable checkpoint is written to ``<save_path stem>_ckpt.pth``. Training
    resumes automatically from ``resume_path`` (default: that checkpoint) if the
    file exists. The SimSiam and VICReg weights are ramped linearly over the first
    five epochs.

    Args:
        audio_dir: directory of ``.npy`` waveforms.
        epochs, batch_size, lr, weight_decay, warmup_epochs: optimisation schedule.
        temperature, pos_topk, pos_thr: contrastive-loss settings.
        accum_steps: micro-batches accumulated per optimiser step.
        ema_m: initial teacher momentum (annealed towards 0.9995).
        save_path / resume_path: output and resume locations.
        num_workers, compile_model, clip_grad_norm, seed: runtime settings.
        mel_choices, n_global, global_T, n_local, local_T, resize_to: dataset settings.
        bank_size, bank_sample: feature-bank capacity and per-step negative sample size.
        alpha_contrast, beta_siam, gamma_vicreg: loss weights.
        vic_sim, vic_var, vic_cov: coefficients inside the VICReg term.

    Returns:
        dict with ``best_encoder`` and ``last_encoder`` (file paths) and
        ``best_knn_sim`` (float).
    """
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------- data -------
    aug = SpectroAug(keep_length=True)

    full_train_ds = MultiResCropDataset(
        audio_dir,
        mel_choices=mel_choices,
        n_global=n_global, global_T=global_T,
        n_local=n_local,   local_T=local_T,
        per_sample_norm=True,
        augment_fn=aug,
        resize_to=resize_to,
    )

    full_val_ds = MultiResCropDataset(
        audio_dir,
        mel_choices=mel_choices,
        n_global=1, global_T=global_T,
        n_local=0,  local_T=local_T,
        per_sample_norm=True,
        augment_fn=aug,  # set None for absolutely clean val
        resize_to=resize_to,
    )

    # Split indices once (seeded) and apply the same split to both dataset variants.
    n_total = len(full_train_ds)
    train_size = int(0.9 * n_total)
    val_size   = n_total - train_size
    idx_all = list(range(n_total))
    g = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(idx_all, [train_size, val_size], generator=g)
    train_ds = Subset(full_train_ds, train_subset.indices)
    val_ds   = Subset(full_val_ds,   val_subset.indices)

    pin = (device.type == "cuda")
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, drop_last=True, pin_memory=pin,
        collate_fn=multicrop_collate, persistent_workers=False
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=max(1, num_workers // 2), drop_last=False, pin_memory=pin,
        collate_fn=multicrop_collate, persistent_workers=False
    )

    # ------- models -------
    online = EncoderProj(proj_dim=128, use_predictor=True).to(device).to(memory_format=torch.channels_last)
    teacher = EMATeacher(online, m=ema_m, proj_dim=128).to(device).to(memory_format=torch.channels_last)
    if compile_model and hasattr(torch, "compile"):
        online = torch.compile(online)
        teacher.teacher = torch.compile(teacher.teacher)

    # ------- objective & optimizer -------
    ntxent = MultiPosNTXentSoft(temperature=temperature, pos_topk=pos_topk, pos_thr=pos_thr)
    optimizer = AdamW(online.parameters(), lr=lr, weight_decay=weight_decay)

    # The scheduler is stepped once per optimiser step, not per micro-batch.
    steps_per_epoch = max(1, math.ceil(len(train_loader) / max(1, accum_steps)))
    total_steps = max(1, epochs * steps_per_epoch)
    warmup_steps = max(1, warmup_epochs * steps_per_epoch)
    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: warmup_cosine_lambda(step, warmup_steps, total_steps))

    # AMP scaler: we keep AMP on for forwards; loss is computed in fp32
    # (loss scaling is only active for fp16; with bf16 or on CPU the scaler is a pass-through)
    use_cuda = (device.type == "cuda")
    amp_dtype = torch.bfloat16 if (use_cuda and torch.cuda.is_bf16_supported()) else torch.float16
    scaler = GradScaler(enabled=(use_cuda and amp_dtype == torch.float16))

    # ------- feature bank (teacher embeddings) -------
    bank = FeatureBank(dim=128, size=bank_size, device=device)

    best_sim = -1e9
    global_step = 0
    start_epoch = 1

    # ------- resume -------
    if resume_path is None:
        resume_path = save_path.replace(".pth", "_ckpt.pth")
    if resume_path and os.path.isfile(resume_path):
        ckpt = torch.load(resume_path, map_location=device)
        if isinstance(ckpt, dict) and "online" in ckpt and "optimizer" in ckpt:
            try:
                online.load_state_dict(ckpt["online"], strict=True)
                teacher.teacher.load_state_dict(ckpt["teacher"], strict=False)
            except Exception as e:
                print(f"[Resume] strict load failed: {e} — retrying with strict=False")
                online.load_state_dict(ckpt["online"], strict=False)
                teacher.teacher.load_state_dict(ckpt["teacher"], strict=False)
            try: optimizer.load_state_dict(ckpt["optimizer"])
            except Exception as e: print(f"[Resume] optimizer not loaded: {e}")
            try: scheduler.load_state_dict(ckpt["scheduler"])
            except Exception as e: print(f"[Resume] scheduler not loaded: {e}")
            if scaler.is_enabled() and ckpt.get("scaler") is not None:
                try: scaler.load_state_dict(ckpt["scaler"])
                except Exception as e: print(f"[Resume] scaler not loaded: {e}")
            global_step = int(ckpt.get("global_step", 0))
            start_epoch = int(ckpt.get("epoch", 0)) + 1
            best_sim = float(ckpt.get("best_sim", best_sim))
            print(f"[Resume] FULL checkpoint loaded (epoch={start_epoch-1}, best_sim={best_sim:.4f})")
        else:
            print(f"[Resume] Weights-only file detected; loading encoder only.")
            try:
                online.load_state_dict(ckpt, strict=True)
            except Exception as e:
                print(f"[Resume] strict load failed: {e} — using strict=False")
                online.load_state_dict(ckpt, strict=False)
            with torch.no_grad():
                teacher._init_from(online)

    # ------- train loop -------
    for epoch in range(start_epoch, epochs + 1):
        online.train()
        running = 0.0
        optimizer.zero_grad(set_to_none=True)

        # Warmup ramp for auxiliary losses (prevents early explosions)
        aux_ramp = min(1.0, epoch / 5.0)
        eff_beta_siam    = beta_siam    * aux_ramp
        eff_gamma_vicreg = gamma_vicreg * aux_ramp

        for it, views in enumerate(train_loader):
            assert len(views) >= 2, "Need at least two global views"
            v_g1 = views[0].to(device, non_blocking=True).to(memory_format=torch.channels_last)
            v_g2 = views[1].to(device, non_blocking=True).to(memory_format=torch.channels_last)
            # NOTE(review): only the first local view (views[2]) is consumed here;
            # any further local views produced by the dataset are unused.
            v_l  = views[2].to(device, non_blocking=True).to(memory_format=torch.channels_last) if len(views) > 2 else None

            # Online forward under AMP.
            # One pass per view yields both the projection z and the prediction p;
            # running the network a second time only to read z again would double the
            # forward cost and update the BatchNorm running stats twice per batch.
            with autocast(enabled=use_cuda, dtype=amp_dtype):
                z_g1, z_g1_p = online(v_g1, for_predictor=True)   # (B,d), (B,d)
                z_g2, z_g2_p = online(v_g2, for_predictor=True)   # (B,d), (B,d)
                if v_l is not None:
                    z_l, z_l_p = online(v_l, for_predictor=True)
                else:
                    z_l_p, z_l = None, None

            # Teacher in fp32
            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=False):
                    t_g1 = teacher(v_g1)
                    t_g2 = teacher(v_g2)
                    t_l  = teacher(v_l) if v_l is not None else None

            # Sanitize all embeddings to kill NaNs/Infs
            z_g1, z_g2, z_g1_p, z_g2_p, z_l, z_l_p = _sanitize_(z_g1, z_g2, z_g1_p, z_g2_p, z_l, z_l_p)
            t_g1, t_g2, t_l = _sanitize_(t_g1, t_g2, t_l)

            # ----- Losses computed in fp32 for stability -----
            with torch.cuda.amp.autocast(enabled=False):
                # sample a subset of negatives for speed
                bank_feats = bank.get()
                if bank_feats is not None and bank_feats.numel() > 0 and bank_feats.size(0) > bank_sample:
                    idx = torch.randperm(bank_feats.size(0), device=bank_feats.device)[:bank_sample]
                    bank_feats = bank_feats[idx]

                loss_ctr = ntxent.forward_with_bank(
                    z_g1.float(), z_g2.float(), t_g1.float(), t_g2.float(),
                    bank_feats.float() if bank_feats is not None else None
                )
                loss_total = alpha_contrast * loss_ctr

                # SimSiam: predictor output of one view vs. stop-gradient projection
                # of the other (global<->local when a local view exists).
                if eff_beta_siam > 0.0:
                    if v_l is not None:
                        loss_siam = 0.5 * (simsiam_loss(z_g1_p.float(), z_l.detach().float()) +
                                           simsiam_loss(z_l_p.float(),  z_g1.detach().float()))
                    else:
                        loss_siam = 0.5 * (simsiam_loss(z_g1_p.float(), z_g2.detach().float()) +
                                           simsiam_loss(z_g2_p.float(), z_g1.detach().float()))
                    loss_total = loss_total + eff_beta_siam * loss_siam

                if eff_gamma_vicreg > 0.0:
                    if v_l is not None:
                        loss_vcr = vicreg_loss(z_g1.float(), z_l.float(), vic_sim, vic_var, vic_cov)
                    else:
                        loss_vcr = vicreg_loss(z_g1.float(), z_g2.float(), vic_sim, vic_var, vic_cov)
                    loss_total = loss_total + eff_gamma_vicreg * loss_vcr

                loss_total = loss_total / max(1, accum_steps)
                if not torch.isfinite(loss_total):
                    loss_total = torch.nan_to_num(loss_total, nan=0.0, posinf=0.0, neginf=0.0)

            scaler.scale(loss_total).backward()
            # Undo the accumulation division so the logged value is a per-batch loss.
            running += float(loss_total.item()) * max(1, accum_steps)

            if (it + 1) % max(1, accum_steps) == 0:
                if scaler.is_enabled(): scaler.unscale_(optimizer)
                grads_finite = True
                if clip_grad_norm and clip_grad_norm > 0:
                    total_norm = nn.utils.clip_grad_norm_(online.parameters(), max_norm=clip_grad_norm)
                    grads_finite = bool(torch.isfinite(total_norm))
                # An enabled GradScaler already skips the step on inf/NaN gradients.
                # With bf16 or on CPU the scaler is a pass-through, so guard explicitly:
                # stepping on NaN gradients would poison every weight.
                if scaler.is_enabled() or grads_finite:
                    scaler.step(optimizer)
                else:
                    print(f"[Warn] non-finite gradient norm at step {global_step}; optimizer step skipped")
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

                m_now = ema_momentum_at(global_step, total_steps, m_base=ema_m, m_final=0.9995)
                with torch.no_grad():
                    teacher.update(online, m=m_now)

                # NOTE(review): the bank is only fed on optimiser-step boundaries, so
                # with accum_steps > 1 only the last micro-batch of each group is
                # enqueued — confirm this is intended.
                with torch.no_grad():
                    to_enqueue = [t_g1, t_g2] + ([t_l] if (t_l is not None) else [])
                    cat = torch.cat([t for t in to_enqueue if t is not None], dim=0)
                    bank.enqueue(cat)

                scheduler.step()
                global_step += 1

        with torch.no_grad():
            knn_sim = knn_sim_eval_multicrop(online, val_loader, device, use_backbone=True, K=20)

        cur_lr = optimizer.param_groups[0]['lr']
        print(f"[Epoch {epoch:03d}] loss={running/len(train_loader):.4f} | kNN-sim={knn_sim:.4f} | lr={cur_lr:.2e}")

        if knn_sim > best_sim:
            best_sim = knn_sim
            torch.save(online.state_dict(), save_path)
            print(f"  ✔ Saved BEST encoder weights to {save_path} (kNN-sim={best_sim:.4f})")

        # Full resumable checkpoint, overwritten every epoch.
        ckpt_path = save_path.replace(".pth", "_ckpt.pth")
        torch.save({
            "online": online.state_dict(),
            "teacher": teacher.teacher.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "scaler": scaler.state_dict() if scaler.is_enabled() else None,
            "epoch": epoch,
            "global_step": global_step,
            "best_sim": best_sim,
            "config": {
                "audio_dir": audio_dir, "epochs": epochs, "batch_size": batch_size, "lr": lr,
                "weight_decay": weight_decay, "temperature": temperature, "accum_steps": accum_steps,
                "ema_m": ema_m, "pos_topk": pos_topk, "pos_thr": pos_thr, "warmup_epochs": warmup_epochs,
                "num_workers": num_workers, "compile_model": compile_model, "clip_grad_norm": clip_grad_norm,
                "seed": seed, "mel_choices": mel_choices, "n_global": n_global, "global_T": global_T,
                "n_local": n_local, "local_T": local_T, "resize_to": resize_to,
                "bank_size": bank_size, "bank_sample": bank_sample,
                "alpha_contrast": alpha_contrast, "beta_siam": beta_siam, "gamma_vicreg": gamma_vicreg,
                "vic_sim": vic_sim, "vic_var": vic_var, "vic_cov": vic_cov
            }
        }, ckpt_path)

    last_path = save_path.replace(".pth", "_last.pth")
    torch.save(online.state_dict(), last_path)
    print(f"Final encoder weights saved to {last_path}")
    return {"best_encoder": save_path, "last_encoder": last_path, "best_knn_sim": best_sim}

# ----------------------------
# Optional: embeddings extraction (one global view)
# ----------------------------
@torch.no_grad()
def extract_embeddings(encoder_ckpt_path, audio_dir, batch_size=128, n_workers=4, device_str=None,
                       resize_to=(128,256), mel_choice=(128,1024,256), T_win=256):
    """
    Embed every waveform ``.npy`` file in ``audio_dir`` using one global view per file.

    Args:
        encoder_ckpt_path: Path to a bare ``state_dict`` file; it is loaded into
            ``EncoderProj(proj_dim=128, use_predictor=True)`` with ``strict=True``.
        audio_dir: Folder of ``.npy`` files (``labels.npy`` is ignored).
        batch_size: DataLoader batch size.
        n_workers: Number of DataLoader worker processes.
        device_str: Torch device string; defaults to CUDA when available, else CPU.
        resize_to: ``(F, T)`` size for a bilinear resize of each crop, or ``None`` to skip it.
        mel_choice: ``(n_mels, n_fft, hop_length)`` for the mel spectrogram.
        T_win: Width (in frames) of the crop centred on the highest-energy frame.

    Returns:
        np.ndarray (float32): the first output of ``enc(x, return_backbone=True)`` for every
        file, concatenated in file order.
    """
    device = torch.device(device_str or ("cuda" if torch.cuda.is_available() else "cpu"))
    enc = EncoderProj(proj_dim=128, use_predictor=True).to(device).to(memory_format=torch.channels_last)
    enc.load_state_dict(torch.load(encoder_ckpt_path, map_location=device), strict=True)
    enc.eval()

    def _stem_key(path):
        # Numeric stems sort numerically and come first; other stems sort alphabetically after
        # them, so a stray non-numeric file name no longer aborts the whole run.
        stem = os.path.splitext(os.path.basename(path))[0]
        try:
            return (0, int(stem), "")
        except ValueError:
            return (1, 0, stem)

    class _EvalDS(Dataset):
        """Waveform file -> z-normalised log-mel crop of shape (1, F, T)."""

        def __init__(self, audio_dir):
            self.paths = sorted(
                (os.path.join(audio_dir, f) for f in os.listdir(audio_dir)
                 if f.endswith(".npy") and f != "labels.npy"),
                key=_stem_key,
            )
            self.sr = 10000
            self.resize_to = resize_to
            self.mel_choice = mel_choice
            self.T_win = T_win
            n_mels, n_fft, hop = self.mel_choice
            # Build the transform once instead of once per sample.
            self.mel = ta.transforms.MelSpectrogram(
                sample_rate=self.sr, n_fft=n_fft, hop_length=hop, n_mels=n_mels
            )

        def __len__(self):
            return len(self.paths)

        def _lm(self, y):
            """Peak-normalise the waveform and return its z-normalised log-mel spectrogram."""
            y = np.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
            mx = float(np.max(np.abs(y)))
            if not np.isfinite(mx) or mx < 1e-8:
                mx = 1.0  # silent / degenerate clip: skip peak normalisation
            y_t = torch.from_numpy((y / mx).astype(np.float32)).unsqueeze(0)
            mel = torch.nan_to_num(self.mel(y_t), nan=0.0, posinf=0.0, neginf=0.0)
            lm = torch.log(torch.clamp(mel, min=1e-5))
            lm = torch.nan_to_num(lm, nan=0.0, posinf=0.0, neginf=0.0)
            m, s = lm.mean(), lm.std()
            if not (torch.isfinite(s) and s >= 1e-6):
                s = torch.tensor(1e-6, device=lm.device)
            return (lm - m) / s

        def _center(self, lm):
            """Crop T_win frames around the frame with the highest mean value (no-op if short)."""
            energy = lm.mean(dim=1).squeeze(0)
            c = int(torch.argmax(energy).item())
            T = lm.size(2)
            start = max(0, min(c - self.T_win // 2, T - self.T_win))
            return lm[:, :, start:start + self.T_win] if T > self.T_win else lm

        def __getitem__(self, idx):
            y = np.load(self.paths[idx]).astype(np.float32)
            # The original also computed a (64,1024,512) spectrogram here and discarded it;
            # that pure, unused computation has been removed.
            crop = self._center(self._lm(y))
            if self.resize_to is not None:
                Ft, Tt = self.resize_to
                crop = F.interpolate(crop.unsqueeze(0), size=(Ft, Tt), mode='bilinear',
                                     align_corners=False).squeeze(0)
            return crop

    ds = _EvalDS(audio_dir)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=n_workers, pin_memory=(device.type=="cuda"))

    feats = []
    for x in loader:
        x = x.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        h, _ = enc(x, return_backbone=True)
        feats.append(h.cpu())
    H = torch.cat(feats, dim=0).numpy().astype(np.float32)
    return H

# ----------------------------
# Example
# ----------------------------
# if __name__ == "__main__":
#     # Tip for notebooks: set num_workers=0 to avoid teardown spam; in scripts you can use >0.
#     train_simclr_multipos(
#         audio_dir="/notebooks/dataset_preprocessed",
#         epochs=200,
#         batch_size=256,
#         lr=3e-4,
#         weight_decay=1e-4,
#         temperature=0.1,
#         accum_steps=2,
#         ema_m=0.99,
#         pos_topk=5,
#         pos_thr=0.7,
#         save_path="simclr_ne.pth",
#         warmup_epochs=10,
#         num_workers=4,              # set 0 in notebooks if you see cleanup warnings
#         compile_model=False,        # True on PyTorch 2.x if desired (first iter slower)
#         clip_grad_norm=1.0,
#         seed=42,
#         resume_path=None,
#         mel_choices=((64,1024,512),(128,1024,256),(128,2048,512)),
#         n_global=2, global_T=256,
#         n_local=2,  local_T=96,
#         resize_to=(128,256),
#         bank_size=8192,
#         bank_sample=2048,           # try 2048 first; raise to 4096 if still fast
#         alpha_contrast=1.0,
#         beta_siam=0.1,
#         gamma_vicreg=0.1,
#         vic_sim=25.0, vic_var=25.0, vic_cov=1.0,
#     )


Report
------------

In [3]:
# ===== Unsupervised Report (2% subset, safe mode) =====
# Works with your improved SimCLR EncoderProj(proj_dim=256).
# Deterministic 2% subset, safe single-process DataLoader to avoid hangs,
# progress logs, and dataset sanity checks.
#
# Adds:
# - Spectrogram-space intra-cluster consistency:
#     * Mean spectrogram cosine similarity (↑ better)
#     * Optional DTW distance on time traces (↓ better; requires fastdtw)
# - Global summary card aggregating these per-cluster metrics.

import os, math, contextlib, base64, time
import numpy as np
from collections import Counter, OrderedDict
from io import BytesIO
from typing import Optional, Dict  # NEW

from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score
from scipy.optimize import linear_sum_assignment

# plotting / viz
from umap import UMAP
import plotly.express as px
import plotly.io as pio
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.mixture import GaussianMixture
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import entropy
from scipy.signal import spectrogram, get_window
import matplotlib.pyplot as plt
from sklearn.metrics import (
    silhouette_score,
    davies_bouldin_score,
    calinski_harabasz_score,
)

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


def hungarian_accuracy(y_true, y_pred):
    """
    Maximum bipartite matching accuracy between predicted cluster IDs and true labels.

    Each predicted cluster is matched to at most one true label (and vice versa) so that the
    number of agreeing samples is maximal; the result is that count divided by the number of
    samples. Works with string labels too.

    Args:
        y_true: 1-D array-like of ground-truth labels.
        y_pred: 1-D array-like of predicted cluster IDs, same length as ``y_true``.

    Returns:
        float in [0, 1]; ``0.0`` for empty input.

    Raises:
        ValueError: if the two inputs have different lengths.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred must have the same length, got {y_true.size} and {y_pred.size}"
        )
    if y_pred.size == 0:
        return 0.0  # nothing to match; avoids a crash and a division by zero

    # Map arbitrary labels to dense integer codes.
    true_vals, true_codes = np.unique(y_true, return_inverse=True)
    pred_vals, pred_codes = np.unique(y_pred, return_inverse=True)

    # Contingency table: rows = predicted clusters, columns = true labels.
    W = np.zeros((pred_vals.size, true_vals.size), dtype=np.int64)
    np.add.at(W, (pred_codes.ravel(), true_codes.ravel()), 1)

    row_ind, col_ind = linear_sum_assignment(W, maximize=True)
    return float(W[row_ind, col_ind].sum() / y_pred.size)


# ------------------------------------------------------------------------------------
# Import your model + (optional) pad_collate from training code
# ------------------------------------------------------------------------------------
# !!! IMPORTANT: change 'your_training_file' to your actual module name !!!
try:
    # Preferred path: take the model class and the collate function from the training module.
    from your_training_file import EncoderProj, pad_collate  # noqa
    _HAS_PAD_COL = True
    print("[INFO] Imported EncoderProj and pad_collate from your_training_file.")
except Exception as e:
    print(f"[WARN] Could not import from your_training_file: {e}")
    # Fallback: rely on names that already exist in this namespace (e.g. defined by an
    # earlier notebook cell). _HAS_PAD_COL stays False, so extract_embeddings_from_specdir
    # does not pass a collate_fn and the DataLoader uses its default collate.
    _HAS_PAD_COL = False
    try:
        # Bare name lookup: raises NameError when EncoderProj has not been defined anywhere.
        EncoderProj  # type: ignore  # noqa
    except NameError:
        raise RuntimeError(
            "EncoderProj is not imported. Please do:\n"
            "from <your_module> import EncoderProj  # and optionally pad_collate"
        )

    # Make sure the name exists so later `pad_collate is not None` checks cannot NameError.
    if 'pad_collate' not in globals():
        pad_collate = None  # type: ignore


# ------------------------------------------------------------------------------------
# Dataset for precomputed spectrograms or raw waveforms (.npy files)
# ------------------------------------------------------------------------------------
def _wave_to_spec_1xFxT(y_np: np.ndarray, sr: int = 10000, n_mels: int = 64,
                        n_fft: int = 1024, hop_length: int = 512) -> torch.Tensor:
    y_np = np.nan_to_num(y_np.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
    mx = np.max(np.abs(y_np)); mx = 1.0 if (not np.isfinite(mx) or mx < 1e-8) else mx
    y_t = torch.from_numpy(y_np / mx).unsqueeze(0)  # (1, T)

    try:
        import torchaudio as ta
        mel = ta.transforms.MelSpectrogram(sample_rate=sr, n_fft=n_fft,
                                           hop_length=hop_length, n_mels=n_mels)
        S = mel(y_t)                                # (1, F, T)
        S = torch.clamp(S, min=1e-10).log()
        return S.contiguous().float()
    except Exception:
        from scipy.signal import spectrogram, get_window
        win = get_window('hann', n_fft, fftbins=True)
        f, t, Z = spectrogram(y_np, fs=sr, window=win, nperseg=n_fft,
                              noverlap=n_fft - hop_length, mode='magnitude')
        S = np.log(np.maximum(Z, 1e-10)).astype(np.float32)  # (F_lin, T)
        if S.shape[0] != n_mels:
            idx = np.linspace(0, S.shape[0]-1, num=n_mels).astype(np.int32)
            S = S[idx]
        S = torch.from_numpy(S).unsqueeze(0)
        return S.contiguous().float()


class TwoViewPrecomputed(Dataset):
    """
    Dataset over ``.npy`` files that hold either spectrograms (2-D / 3-D) or raw 1-D waveforms.

    Each item is a pair ``(x1, x2)`` of identical ``(1, F, T)`` float tensors (``x2`` is a clone
    of ``x1``); no augmentation is applied.

    Args:
        spec_dir: Folder scanned for ``.npy`` files when ``file_paths`` is None
            (``labels.npy`` is ignored).
        T: Stored on the instance as ``self.T``; not used by any method of this class.
        per_sample_norm: If True, z-normalise every item with its own mean and std.
        sr, n_mels, n_fft, hop: Parameters forwarded to ``_wave_to_spec_1xFxT`` for 1-D inputs.
        wave_policy: ``"convert"`` turns 1-D waveforms into spectrograms on the fly;
            ``"skip"`` removes 1-D files from the dataset at construction time.
        file_paths: Optional explicit list of files; used as given (no sorting).
    """

    def __init__(self, spec_dir, T=256, per_sample_norm=True,
                 sr=10000, n_mels=64, n_fft=1024, hop=512,
                 wave_policy="convert",   # "convert" or "skip"
                 file_paths=None):
        if file_paths is not None:
            self.paths = list(file_paths)
        else:
            self.paths = sorted(
                (os.path.join(spec_dir, f) for f in os.listdir(spec_dir)
                 if f.endswith(".npy") and f != "labels.npy"),
                key=self._stem_sort_key,
            )

        self.T = T
        self.per_sample_norm = per_sample_norm
        self.sr, self.n_mels, self.n_fft, self.hop = sr, n_mels, n_fft, hop
        self.wave_policy = wave_policy

        if self.wave_policy == "skip":
            # Memory-map each file just to read its rank; drop raw waveforms.
            self.paths = [p for p in self.paths if np.load(p, mmap_mode="r").ndim != 1]

    @staticmethod
    def _stem_sort_key(path):
        """
        Sort key: numeric stems first in numeric order, then other stems alphabetically.

        Comparing bare ``int`` and ``str`` keys raises TypeError on Python 3, so a folder that
        mixes e.g. ``12.npy`` and ``readme.npy`` used to crash; tagging the key avoids that
        while keeping the order unchanged for all-numeric or all-text folders.
        """
        stem = os.path.splitext(os.path.basename(path))[0]
        try:
            return (0, int(stem), "")
        except ValueError:
            return (1, 0, stem)

    def _to_1xFxT(self, arr: np.ndarray) -> torch.Tensor:
        """Convert a loaded array (1-D wave, 2-D or 3-D spectrogram) to a (1, F, T) float tensor."""
        arr = np.array(arr, copy=True)  # detach from the read-only memory map
        if arr.ndim == 1:
            if self.wave_policy == "skip":
                raise RuntimeError("Waveform encountered but wave_policy='skip'.")
            return _wave_to_spec_1xFxT(arr, sr=self.sr, n_mels=self.n_mels,
                                       n_fft=self.n_fft, hop_length=self.hop)
        x = torch.from_numpy(arr)
        if x.ndim == 2:
            x = x.unsqueeze(0)
        elif x.ndim == 3:
            if x.shape[0] != 1:
                # Fold the leading axis into the frequency axis: (C, F, T) -> (1, C*F, T).
                x = x.reshape(1, x.shape[0]*x.shape[1], x.shape[2])
        else:
            raise RuntimeError(f"Unexpected array shape {tuple(x.shape)}")
        return x.contiguous().float()

    def __len__(self): return len(self.paths)

    def __getitem__(self, idx):
        feat_np = np.load(self.paths[idx], mmap_mode="r")
        feat = self._to_1xFxT(feat_np)
        if self.per_sample_norm:
            m, s = feat.mean(), feat.std()
            # Guard against a zero / NaN std (constant or single-element input).
            s = s if (torch.isfinite(s) and s >= 1e-6) else torch.tensor(1e-6, device=feat.device)
            feat = (feat - m) / s
        x1 = feat
        x2 = feat.clone()
        return x1, x2


# ------------------------------------------------------------------------------------
# File utilities and viz helpers
# ------------------------------------------------------------------------------------
def list_audio_npy_files(folder):
    """
    Return the full paths of all ``.npy`` files in ``folder`` except ``labels.npy``.

    Order: files whose stem parses as an integer come first, in numeric order; all other files
    follow in alphabetical order of their stem. (Sorting bare ``int`` and ``str`` keys together
    raises TypeError on Python 3, so a folder mixing both used to crash. The order for
    all-numeric or all-text folders is unchanged.)
    """
    def _stem_sort_key(name):
        stem = os.path.splitext(name)[0]
        try:
            return (0, int(stem), "")
        except ValueError:
            return (1, 0, stem)

    names = [f for f in os.listdir(folder) if f.endswith(".npy") and f != "labels.npy"]
    names.sort(key=_stem_sort_key)
    return [os.path.join(folder, f) for f in names]

def load_class_labels_if_any(folder, count):
    """
    Load ``<folder>/labels.npy`` and fit it to exactly ``count`` entries.

    Shorter label arrays are padded with ``"Unknown"``; longer ones are truncated. If the file
    is missing, an all-``"Unknown"`` object array is returned. If the file exists but cannot be
    read or reshaped, the same fallback is returned and a ``RuntimeWarning`` is emitted so that
    a corrupt label file does not silently disable the label-based metrics.

    Args:
        folder: Directory expected to contain ``labels.npy``.
        count: Required number of labels.

    Returns:
        Array of length ``count``.
    """
    import warnings

    lbl_path = os.path.join(folder, "labels.npy")
    if os.path.isfile(lbl_path):
        try:
            # NOTE(security): allow_pickle=True can execute arbitrary code when loading a
            # malicious file. Only point this at label files you created yourself.
            arr = np.load(lbl_path, allow_pickle=True)
            if len(arr) < count:
                pad = np.array(["Unknown"] * (count - len(arr)), dtype=object)
                arr = np.concatenate([arr, pad], axis=0)
            elif len(arr) > count:
                arr = arr[:count]
            return arr
        except Exception as exc:
            warnings.warn(
                f"Could not use labels from {lbl_path} ({type(exc).__name__}: {exc}); "
                "falling back to 'Unknown' for every item.",
                RuntimeWarning,
                stacklevel=2,
            )
    return np.array(["Unknown"] * count, dtype=object)

def ali_spec(x: np.ndarray, fs: int):
    """
    Compute a dB-scaled spectrogram of ``x`` restricted to roughly 0-800 Hz.

    Uses a 1000-sample Hann window with 80 % overlap and an FFT length of
    ``2 ** (floor(log2(1000)) + 2)``. Magnitudes are scaled by the global peak before the
    ``10 * log10(100 * |S| / peak + 1e-10)`` conversion.

    Returns:
        (sdb, f, t): the band-limited dB spectrogram, its frequency axis, and the time axis.
    """
    frame_len = 1000
    overlap_pct = 80
    n_overlap = int(np.ceil((overlap_pct / 100) * frame_len))
    window = get_window('hann', frame_len)
    n_fft = 2 ** (int(np.floor(np.log2(frame_len))) + 2)

    freqs, times, stft = spectrogram(x, fs=fs, window=window, noverlap=n_overlap,
                                     nfft=n_fft, mode='complex')
    magnitude = np.abs(stft)

    # Normalise by the peak magnitude; fall back to 1.0 for silent or non-finite input.
    peak = np.max(magnitude)
    if not (np.isfinite(peak) and peak > 0):
        peak = 1.0
    sdb = 10 * np.log10(100 * magnitude / peak + 1e-10)

    # Keep the bins closest to 0 Hz and 800 Hz, inclusive.
    lo = np.argmin(np.abs(freqs - 0))
    hi = np.argmin(np.abs(freqs - 800))
    band = slice(lo, hi + 1)
    return sdb[band, :], freqs[band], times

def generate_spectrogram_base64(audio, fs=10000, title="Spectrogram"):
    """
    Render the ``ali_spec`` spectrogram of ``audio`` and return it as an inline HTML ``<img>``.

    Args:
        audio: 1-D waveform passed to ``ali_spec``.
        fs: Sampling rate in Hz.
        title: Plot title; also used (HTML-escaped) as the image ``alt`` text.

    Returns:
        str: an ``<img>`` tag whose ``src`` is a base64-encoded PNG data URI.
    """
    from html import escape

    spec, f_axis, t_axis = ali_spec(audio, fs)
    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.imshow(spec, aspect='auto', origin='lower',
                  extent=[t_axis[0], t_axis[-1], f_axis[0], f_axis[-1]], cmap='hsv')
        ax.set_title(title); ax.set_xlabel("Time (s)"); ax.set_ylabel("Frequency (Hz)")
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if plotting or saving fails, so that generating
        # many images in a loop cannot accumulate open figures.
        plt.close(fig)
    image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    # Escape the title so quotes or angle brackets cannot break the attribute / markup.
    alt_text = escape(str(title), quote=True)
    return f'<img class="d-block w-100" src="data:image/png;base64,{image_base64}" alt="{alt_text}">'

def make_carousel(cluster_id, scope, imgs):
    """
    Build the HTML for a Bootstrap 5 carousel showing ``imgs``.

    Args:
        cluster_id: Cluster identifier; combined with ``scope`` into the element id
            ``carousel_<scope>_<cluster_id>``.
        scope: Namespace string that keeps ids unique across report sections.
        imgs: Iterable-with-length of HTML snippets (e.g. ``<img>`` tags), one per slide.

    Returns:
        str: carousel markup with indicators, slides and prev/next controls. Only the first
        indicator and the first slide are marked active, and only that indicator carries
        ``aria-current="true"``.
    """
    cid = f"carousel_{scope}_{cluster_id}"

    def _indicator(i):
        # Only the active indicator is flagged as current for assistive technology.
        state = ' class="active" aria-current="true"' if i == 0 else ""
        return (
            f'<button type="button" data-bs-target="#{cid}" data-bs-slide-to="{i}"'
            f'{state} aria-label="Slide {i+1}"></button>'
        )

    indicators = "".join(_indicator(i) for i in range(len(imgs)))
    items = "".join(
        f'<div class="carousel-item {"active" if i==0 else ""}"><div class="d-flex justify-content-center">{img}</div></div>'
        for i, img in enumerate(imgs)
    )
    return f"""
    <div id="{cid}" class="carousel slide" data-bs-interval="false" data-bs-touch="false">
      <div class="carousel-indicators">{indicators}</div>
      <div class="carousel-inner">{items}</div>
      <button class="carousel-control-prev" type="button" data-bs-target="#{cid}" data-bs-slide="prev">
        <span class="carousel-control-prev-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Previous</span>
      </button>
      <button class="carousel-control-next" type="button" data-bs-target="#{cid}" data-bs-slide="next">
        <span class="carousel-control-next-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Next</span>
      </button>
    </div>
    """


# ------------------------------------------------------------------------------------
# NEW: spectrogram standardization & intra-cluster consistency metrics
# ------------------------------------------------------------------------------------
def _standardize_spec(S: np.ndarray, target_shape=(128, 256)) -> np.ndarray:
    """
    Pad/crop 2D (F,T) spectrogram to target_shape. Per-spec z-norm for cosine comparability.
    """
    S = np.asarray(S, dtype=np.float32)
    S = (S - np.mean(S)) / (np.std(S) + 1e-8)
    F, T = S.shape
    Ft, Tt = target_shape

    # F dimension
    if F < Ft:
        pad_top = (Ft - F) // 2
        pad_bot = Ft - F - pad_top
        S = np.pad(S, ((pad_top, pad_bot), (0, 0)), mode="constant")
    elif F > Ft:
        start = (F - Ft) // 2
        S = S[start:start+Ft, :]

    # T dimension
    F, T = S.shape
    if T < Tt:
        pad_left = (Tt - T) // 2
        pad_right = Tt - T - pad_left
        S = np.pad(S, ((0, 0), (pad_left, pad_right)), mode="constant")
    elif T > Tt:
        start = (T - Tt) // 2
        S = S[:, start:start+Tt]
    return S

def _load_spec_for_index(original_paths: np.ndarray, idx: int, fs: int = 10000,
                         target_shape=(128, 256)) -> Optional[np.ndarray]:
    """
    Load .npy (1D waveform or 2D spectrogram) and return standardized 2D spectrogram.
    """
    try:
        arr = np.load(original_paths[idx], mmap_mode="r")
        if arr.ndim == 1:
            S, _, _ = ali_spec(arr.astype(np.float32), fs=fs)
        elif arr.ndim == 2:
            S = arr.astype(np.float32)
        else:
            return None
        return _standardize_spec(S, target_shape=target_shape)
    except Exception:
        return None

def _avg_intra_cluster_spec_cosine(original_paths: np.ndarray, idxs: np.ndarray,
                                   *, fs: int = 10000, max_samples: int = 50,
                                   target_shape=(128, 256)) -> float:
    """
    Average pairwise cosine similarity of standardized spectrograms within a cluster.
    Subsamples up to max_samples items for O(n^2) stability.
    """
    if len(idxs) < 2:
        return float("nan")
    if len(idxs) > max_samples:
        rng = np.random.RandomState(42)
        idxs = rng.choice(idxs, size=max_samples, replace=False)

    specs = []
    for i in idxs:
        S = _load_spec_for_index(original_paths, int(i), fs=fs, target_shape=target_shape)
        if S is not None:
            specs.append(S.reshape(-1))
    if len(specs) < 2:
        return float("nan")
    X = np.stack(specs, axis=0)
    sim = cosine_similarity(X)
    iu = np.triu_indices_from(sim, k=1)
    return float(np.mean(sim[iu]))

def _avg_intra_cluster_spec_dtw(original_paths: np.ndarray, idxs: np.ndarray,
                                *, fs: int = 10000, max_samples: int = 25,
                                target_shape=(128, 256)) -> float:
    """
    Average pairwise DTW distance between standardized spectrogram time-traces.
    Lower = more similar. Optional: requires fastdtw.
    """
    try:
        from fastdtw import fastdtw
        from scipy.spatial.distance import euclidean
    except Exception:
        return float("nan")

    if len(idxs) < 2:
        return float("nan")
    if len(idxs) > max_samples:
        rng = np.random.RandomState(42)
        idxs = rng.choice(idxs, size=max_samples, replace=False)

    series = []
    for i in idxs:
        S = _load_spec_for_index(original_paths, int(i), fs=fs, target_shape=target_shape)
        if S is not None:
            series.append(np.mean(S, axis=0))  # collapse frequency to get a 1D time trace
    if len(series) < 2:
        return float("nan")

    n = len(series)
    total, pairs = 0.0, 0
    for a in range(n):
        for b in range(a+1, n):
            d, _ = fastdtw(series[a], series[b], dist=euclidean)
            total += d
            pairs += 1
    return float(total / max(pairs, 1))


# ------------------------------------------------------------------------------------
# Clustering and metrics
# ------------------------------------------------------------------------------------
def choose_clusterer(algorithm: str, embeddings: np.ndarray, n_clusters: int, random_state: int = 42):
    """
    Cluster ``embeddings`` with the requested algorithm.

    Args:
        algorithm: ``'kmeans'``, ``'agglomerative'`` or ``'gmm'``.
        embeddings: (N, D) feature matrix.
        n_clusters: Number of clusters / mixture components.
        random_state: Seed for the stochastic algorithms (KMeans and GMM), so repeated runs
            of the report give the same clusters. KMeans was previously unseeded while GMM
            was fixed at 42.

    Returns:
        (labels, probs): ``labels`` is the per-sample cluster assignment; ``probs`` is the
        (N, n_clusters) posterior matrix for ``'gmm'`` and ``None`` otherwise.

    Raises:
        ValueError: for an unknown ``algorithm``.
    """
    if algorithm == 'kmeans':
        clusterer = KMeans(n_clusters=n_clusters, n_init='auto', random_state=random_state).fit(embeddings)
        return clusterer.labels_, None
    elif algorithm == 'agglomerative':
        # Agglomerative clustering is deterministic; no seed needed.
        clusterer = AgglomerativeClustering(n_clusters=n_clusters).fit(embeddings)
        return clusterer.labels_, None
    elif algorithm == 'gmm':
        clusterer = GaussianMixture(n_components=n_clusters, covariance_type='full',
                                    random_state=random_state).fit(embeddings)
        return clusterer.predict(embeddings), clusterer.predict_proba(embeddings)
    else:
        raise ValueError(f"Unsupported clustering algorithm: {algorithm}")

def evaluate_cluster_metrics(embeddings, idxs, location_labels, location_entropy_base=None):
    """
    Compute compactness and location-mixing statistics for one cluster.

    Args:
        embeddings: (N, D) array of all embeddings.
        idxs: Indices (into ``embeddings`` / ``location_labels``) of the cluster members.
        location_labels: Array of per-sample location labels (must support fancy indexing).
        location_entropy_base: Logarithm base for the location entropy; defaults to the number
            of distinct values in ``location_labels``.

    Returns:
        dict with
          - ``variance``: mean squared Euclidean distance to the cluster centroid,
          - ``mean_sim``: mean pairwise cosine similarity (1.0 for a single member),
          - ``entropy``: location entropy in base ``base``, i.e. already normalised so that a
            uniform mix over ``base`` locations gives 1.0,
          - ``quality``: ``mean_sim / variance * (1 - entropy)`` (high = tight, single-location),
          - ``novelty``: ``entropy * variance``.

    Note:
        The entropy is computed in base ``base`` and is therefore already in [0, 1]. The
        previous code divided it by ``log2(base)`` a second time, which understated the mixing
        whenever more than two locations were present (e.g. a perfectly uniform cluster over
        4 locations scored 0.5 instead of 1.0).
    """
    X = embeddings[idxs]
    if X.shape[0] == 0:
        return {'variance': 0.0, 'mean_sim': 1.0, 'entropy': 0.0, 'quality': 0.0, 'novelty': 0.0}

    center = np.mean(X, axis=0, keepdims=True)
    variance = float(np.mean(np.sum((X - center) ** 2, axis=1)))

    if len(X) > 1:
        cos_sim = cosine_similarity(X)
        iu = np.triu_indices_from(cos_sim, k=1)
        mean_sim = float(np.mean(cos_sim[iu])) if iu[0].size > 0 else 1.0
    else:
        mean_sim = 1.0

    loc_counts = Counter(location_labels[idxs])
    loc_probs = np.array(list(loc_counts.values()), dtype=np.float32)
    loc_probs /= max(loc_probs.sum(), 1e-8)

    base = int(location_entropy_base or len(set(location_labels)))
    if base > 1:
        loc_entropy = float(entropy(loc_probs, base=base))
        # Clamp against float round-off and against a caller-supplied base that is smaller
        # than the number of locations actually present.
        normalized_entropy = min(max(loc_entropy, 0.0), 1.0)
    else:
        # A single possible location: no mixing can be measured.
        loc_entropy = 0.0
        normalized_entropy = 0.0

    entropy_score = 1.0 - normalized_entropy
    quality = float((mean_sim / (variance + 1e-8)) * entropy_score)
    novelty = float(normalized_entropy * variance)
    return {'variance': variance, 'mean_sim': mean_sim, 'entropy': loc_entropy, 'quality': quality, 'novelty': novelty}


# ------------------------------------------------------------------------------------
# Embedding extraction
# ------------------------------------------------------------------------------------
def _load_encoder(encoder_ckpt_path, device):
    """
    Build ``EncoderProj(proj_dim=256)`` on ``device`` and load weights from a checkpoint.

    Accepted checkpoint layouts: a full training checkpoint (weights under ``"online"``), a
    dict with a ``"state_dict"`` entry, or a bare state dict.

    Key handling:
      - ``"_orig_mod."`` (added by ``torch.compile``) and ``"module."`` (added by
        DataParallel/DDP) are removed.
      - As before, ``"encoder."`` is also removed -- but only if that gives at least as many
        matches with the model's own parameter names as keeping it. This prevents the strip
        from silently un-matching every key when the model itself has an ``encoder`` attribute.
        NOTE(review): EncoderProj is not visible here; confirm which layout it uses.

    Loading stays non-strict, but any missing / unexpected keys are now reported instead of
    being ignored silently.

    Returns:
        The encoder in eval mode.
    """
    enc = EncoderProj(proj_dim=256).to(device).to(memory_format=torch.channels_last)
    ckpt = torch.load(encoder_ckpt_path, map_location=device)
    if isinstance(ckpt, dict) and "online" in ckpt:
        state_dict = ckpt["online"]
    elif isinstance(ckpt, dict) and "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    else:
        state_dict = ckpt

    def _renamed(strip_encoder):
        renamed = OrderedDict()
        for k, v in state_dict.items():
            new_key = k.replace("_orig_mod.", "")
            if strip_encoder:
                new_key = new_key.replace("encoder.", "")
            new_key = new_key.replace("module.", "")
            renamed[new_key] = v
        return renamed

    expected_keys = set(enc.state_dict().keys())
    legacy_state = _renamed(strip_encoder=True)      # original behaviour (preferred on ties)
    wrapper_only_state = _renamed(strip_encoder=False)
    legacy_hits = len(expected_keys.intersection(legacy_state))
    wrapper_hits = len(expected_keys.intersection(wrapper_only_state))
    new_state = wrapper_only_state if wrapper_hits > legacy_hits else legacy_state

    result = enc.load_state_dict(new_state, strict=False)
    missing, unexpected = list(result.missing_keys), list(result.unexpected_keys)
    if missing or unexpected:
        print(
            f"[WARN] Non-strict load: {len(missing)} missing and {len(unexpected)} unexpected "
            f"keys (missing e.g. {missing[:5]}; unexpected e.g. {unexpected[:5]})"
        )
    enc.eval()
    print(f"[INFO] Encoder loaded on {device} • CUDA available: {torch.cuda.is_available()}")
    return enc

def extract_embeddings_from_specdir(
    encoder_ckpt_path,
    spec_dir,
    batch_size=128,
    n_workers=0,          # SAFE: 0 workers to avoid hangs for first run
    device_str=None,
    subset_fraction=0.02, # 2% subset
    subset_seed=42,
    subset_strategy="random",  # "random" | "tail" | "head"
    wave_policy="convert"      # "convert" works for 1-D waveforms; use "skip" if you *know* files are (1,F,T)
):
    """
    Embed a deterministic subset of the ``.npy`` files in ``spec_dir``.

    Args:
        encoder_ckpt_path: Checkpoint passed to ``_load_encoder``.
        spec_dir: Folder of ``.npy`` files (``labels.npy`` is ignored).
        batch_size: DataLoader batch size.
        n_workers: DataLoader worker processes (0 = load in the main process).
        device_str: Torch device string; defaults to CUDA when available, else CPU.
        subset_fraction: Fraction of files to use; at least one file and at most all files
            are selected.
        subset_seed: Seed for the ``"random"`` strategy.
        subset_strategy: ``"random"`` (seeded sample, kept in sorted order), ``"tail"`` (last
            files) or ``"head"`` (first files).
        wave_policy: Forwarded to ``TwoViewPrecomputed``; ``"skip"`` drops 1-D waveform files.

    Returns:
        H: (N_subset, D) embeddings as float32
        chosen_files: list[str] of file paths actually embedded, length N_subset. With
            ``wave_policy="skip"`` this excludes the dropped waveform files, so it always
            lines up row-for-row with ``H``.

    Raises:
        RuntimeError: if no ``.npy`` files are found or the dataset ends up empty.
        ValueError: for an unknown ``subset_strategy``.
    """
    print(f"[INFO] Scanning: {spec_dir}")
    all_files = list_audio_npy_files(spec_dir)
    print(f"[INFO] Found {len(all_files)} candidate .npy files")
    if len(all_files) == 0:
        raise RuntimeError(f"No .npy files found in {spec_dir}")

    # Clamp to [1, len(all_files)]: a fraction above 1 would otherwise make rng.choice raise
    # and produce negative indices for the "tail" strategy.
    n_sub = min(len(all_files), max(1, int(math.ceil(len(all_files) * subset_fraction))))
    if subset_strategy == "random":
        rng = np.random.RandomState(subset_seed)
        chosen_idx = np.sort(rng.choice(len(all_files), size=n_sub, replace=False))
    elif subset_strategy == "tail":
        chosen_idx = np.arange(len(all_files) - n_sub, len(all_files))
    elif subset_strategy == "head":
        chosen_idx = np.arange(0, n_sub)
    else:
        raise ValueError(f"Unknown subset_strategy: {subset_strategy}")

    chosen_files = [all_files[i] for i in chosen_idx]
    print(f"[INFO] Subset size: {len(chosen_files)} (fraction={subset_fraction})")

    device = torch.device(device_str or ("cuda" if torch.cuda.is_available() else "cpu"))
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = False

    enc = _load_encoder(encoder_ckpt_path, device)

    # ---- SAFE DataLoader (no multiprocessing, no persistence) ----
    ds = TwoViewPrecomputed(
        spec_dir,
        T=256,
        per_sample_norm=True,
        sr=10000, n_mels=64, n_fft=1024, hop=512,
        wave_policy=wave_policy,
        file_paths=chosen_files
    )
    if len(ds) == 0:
        raise RuntimeError(
            "No samples in dataset. Likely cause: wave_policy='skip' but files are 1-D waveforms. "
            "Use wave_policy='convert' or precompute spectrograms."
        )
    # The dataset may have dropped files (wave_policy="skip"); report what it actually serves
    # so the returned paths stay aligned with the rows of H.
    chosen_files = list(ds.paths)

    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=n_workers,
        pin_memory=(device.type == "cuda" and n_workers == 0),  # conservative
        persistent_workers=False
    )
    if _HAS_PAD_COL and pad_collate is not None:
        loader_kwargs["collate_fn"] = pad_collate

    loader = DataLoader(ds, **loader_kwargs)
    print(f"[INFO] DataLoader ready: {len(ds)} samples • batch_size={batch_size} • workers={n_workers}")

    feats = []
    use_cuda = (device.type == "cuda")
    autocast_ctx = torch.cuda.amp.autocast if use_cuda else contextlib.nullcontext

    t0 = time.perf_counter()
    n_seen = 0
    print(f"[INFO] Starting embedding pass on {device} (AMP={use_cuda})")
    with torch.inference_mode():
        with autocast_ctx():
            for step, (x1, _) in enumerate(loader):
                n_seen += x1.size(0)
                x = x1.to(device, non_blocking=True).to(memory_format=torch.channels_last)
                h, _ = enc(x, return_backbone=True)
                # Cast back to fp32 on the CPU regardless of the autocast dtype.
                feats.append(h.detach().to("cpu", non_blocking=True).float())
                if use_cuda:
                    torch.cuda.synchronize()
                if (step + 1) % 10 == 0 or (step == 0):
                    dt = time.perf_counter() - t0
                    print(f"[INFO] Batches: {step+1} • Seen: {n_seen} • {n_seen/max(dt,1e-9):.1f} items/s", flush=True)
                del x, h
        if use_cuda:
            torch.cuda.empty_cache()
    emb_time = time.perf_counter() - t0
    print(f"[INFO] Embedding done: {n_seen} items in {emb_time:.2f}s")

    H = torch.cat(feats, dim=0).numpy().astype("float32")
    return H, chosen_files


# ------------------------------------------------------------------------------------
# Main analysis (2% subset) — now with spectrogram consistency
# ------------------------------------------------------------------------------------
def analyze_all_unsupervised_to_html(
    encoder_ckpt_path,
    dataset_paths,
    labels_list,
    cluster_method='kmeans',
    n_clusters=20,
    subset_fraction=0.02,
    subset_seed=42,
    subset_strategy="random",
    wave_policy="convert",
    # ---- NEW knobs for signal-space consistency ----
    signal_consistency_mode: str = "global",   # "off" | "global" | "per_cluster"
    global_consistency_cluster_sample: int | None = 20,  # None / non-positive = all clusters
    spec_cos_max_samples: int = 30,
    spec_dtw_max_samples: int = 12,
    compute_dtw: bool = False,
):
    """
    Builds an HTML section with:
      - UMAP 3D scatter
      - Global internal/external clustering metrics
      - (Optional) Global spectrogram-consistency summary across clusters
      - Per-cluster cards (NO signal-space metrics unless signal_consistency_mode == 'per_cluster')

    The heavy spectrogram pairwise work is skipped for per-cluster cards unless
    signal_consistency_mode='per_cluster'. If 'global', a lightweight pass samples clusters
    (and items per cluster) to compute an overall mean spec-cos (and optional DTW).

    Args:
        encoder_ckpt_path: Checkpoint path forwarded to extract_embeddings_from_specdir.
        dataset_paths: Directories of .npy files, one per dataset.
        labels_list: Location label for each entry of dataset_paths (paired with zip).
        cluster_method: Name forwarded to choose_clusterer.
        n_clusters: Cluster count forwarded to choose_clusterer.
        subset_fraction, subset_seed, subset_strategy, wave_policy: Forwarded
            unchanged to extract_embeddings_from_specdir.
        signal_consistency_mode: "off", "global" or "per_cluster".
        global_consistency_cluster_sample: In "global" mode, how many clusters to
            sample; None or a non-positive value uses every cluster.
        spec_cos_max_samples: Per-cluster item cap passed to the cosine metric.
        spec_dtw_max_samples: Per-cluster item cap passed to the DTW metric.
        compute_dtw: When False the DTW metric is reported as NaN.

    Returns:
        str: An HTML fragment (not a full document) for this analysis.

    Raises:
        ValueError: If zip(dataset_paths, labels_list) yields no datasets.
    """
    # ---- Gather embeddings from each dataset (2% subset by default) ----
    embeddings_all, loc_labels_all, class_labels_all, file_paths_all = [], [], [], []

    for path, loc_label in zip(dataset_paths, labels_list):
        H, chosen_files = extract_embeddings_from_specdir(
            encoder_ckpt_path, path,
            batch_size=128, n_workers=0,
            subset_fraction=subset_fraction,
            subset_seed=subset_seed,
            subset_strategy=subset_strategy,
            wave_policy=wave_policy
        )
        embeddings_all.append(H)
        file_paths_all.extend(chosen_files)
        loc_labels_all.extend([loc_label] * H.shape[0])

        # Class labels are looked up by file name so they stay aligned with the subset.
        full_files_sorted = list_audio_npy_files(path)
        cls_full = load_class_labels_if_any(path, count=len(full_files_sorted))
        name_to_label = {os.path.basename(p): cls_full[i] for i, p in enumerate(full_files_sorted)}
        class_labels_all.extend([name_to_label[os.path.basename(p)] for p in chosen_files])

    if not embeddings_all:
        # np.vstack([]) would raise a far less helpful ValueError.
        raise ValueError("No datasets to analyze: dataset_paths/labels_list produced no pairs.")

    embeddings = np.vstack(embeddings_all).astype(np.float32)
    location_labels = np.array(loc_labels_all, dtype=object)
    class_labels = np.array(class_labels_all, dtype=object)
    original_paths = np.array(file_paths_all, dtype=object)

    # ---- UMAP ----
    print(f"[INFO] Starting UMAP on {embeddings.shape[0]} items, dim={embeddings.shape[1]}")
    reducer = UMAP(n_components=3, n_neighbors=15, min_dist=0.1, metric="cosine", random_state=42)
    t_umap0 = time.perf_counter()
    proj_3d = reducer.fit_transform(embeddings)
    t_umap1 = time.perf_counter()
    print(f"[INFO] UMAP done in {t_umap1 - t_umap0:.2f}s")

    # ---- Clustering ----
    print(f"[INFO] Clustering with {cluster_method}, k={n_clusters}")
    t_cl0 = time.perf_counter()
    cluster_labels, cluster_probs = choose_clusterer(cluster_method, embeddings, n_clusters)
    t_cl1 = time.perf_counter()
    print(f"[INFO] Clustering done in {t_cl1 - t_cl0:.2f}s")

    # ---- Global clustering quality metrics (internal) ----
    unique_clusters = np.unique(cluster_labels)
    valid_for_metrics = (len(unique_clusters) > 1) and (embeddings.shape[0] > len(unique_clusters))

    def _safe_metric(fn, X, y):
        # Metrics are best-effort: any failure is reported as NaN instead of aborting the report.
        try:
            return float(fn(X, y)) if valid_for_metrics else float("nan")
        except Exception:
            return float("nan")

    from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
    sil = _safe_metric(silhouette_score, embeddings, cluster_labels)         # ↑ better
    dbi = _safe_metric(davies_bouldin_score, embeddings, cluster_labels)     # ↓ better
    ch  = _safe_metric(calinski_harabasz_score, embeddings, cluster_labels)  # ↑ better

    # ---- External metrics vs labels.npy (if available) ----
    def _safe_external(fn, y_true, y_pred):
        try:
            return float(fn(y_true, y_pred)) if valid_for_metrics else float("nan")
        except Exception:
            return float("nan")

    from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score
    ari  = _safe_external(adjusted_rand_score, class_labels, cluster_labels)         # ↑ better
    ami  = _safe_external(adjusted_mutual_info_score, class_labels, cluster_labels)  # ↑ better

    # Hungarian accuracy (already defined above)
    hacc = _safe_external(hungarian_accuracy, class_labels, cluster_labels)          # ↑ better

    # `v == v` is False only for NaN, so NaN is printed as the literal 'nan'.
    print(f"[INFO] Silhouette (↑): {sil if sil==sil else 'nan'}")
    print(f"[INFO] Davies–Bouldin (↓): {dbi if dbi==dbi else 'nan'}")
    print(f"[INFO] Calinski–Harabasz (↑): {ch if ch==ch else 'nan'}")
    print(f"[INFO] ARI (↑): {ari if ari==ari else 'nan'}")
    print(f"[INFO] AMI (↑): {ami if ami==ami else 'nan'}")
    print(f"[INFO] Hungarian Accuracy (↑): {hacc if hacc==hacc else 'nan'}")

    title_txt = (
        f"Clustering with {cluster_method} "
        f"(Sil={sil:.3f} | DBI={dbi:.3f} | CH={ch:.1f} | "
        f"ARI={ari:.3f} | AMI={ami:.3f} | H-Acc={hacc:.3f})"
    )

    umap_fig = px.scatter_3d(
        x=proj_3d[:, 0], y=proj_3d[:, 1], z=proj_3d[:, 2],
        color=[str(c) for c in cluster_labels],
        symbol=location_labels,
        hover_data={"Cluster": cluster_labels, "Class": class_labels},
        title=title_txt,
        opacity=0.85, height=800
    )
    umap_html = pio.to_html(umap_fig, include_plotlyjs="cdn", full_html=False)

    # ---- Shared helpers for signal-space consistency ----
    def _consistency_for_idxs(idxs):
        # Returns (mean spectrogram cosine, mean DTW distance); DTW is NaN unless compute_dtw.
        cos = _avg_intra_cluster_spec_cosine(
            original_paths, idxs, fs=10000, max_samples=spec_cos_max_samples
        )
        dtw = (
            _avg_intra_cluster_spec_dtw(original_paths, idxs, fs=10000, max_samples=spec_dtw_max_samples)
            if compute_dtw else float("nan")
        )
        return cos, dtw

    def _fmt_or_nan(value, spec):
        # Any non-finite value (nan or inf) is rendered as "nan".
        return format(value if np.isfinite(value) else float("nan"), spec)

    # ---- Per-cluster cards (embedding-space metrics only) ----
    cluster_html = ""
    cluster_blocks = []
    base_for_entropy = len(set(location_labels))

    # accumulators ONLY when showing per-cluster consistency
    spec_cos_scores, spec_dtw_scores = [], []

    for c in sorted(set(cluster_labels)):
        idxs = np.where(cluster_labels == c)[0]
        if idxs.size == 0:
            continue

        cluster_counts = Counter(location_labels[idxs])
        class_counts = Counter(class_labels[idxs])

        metrics = evaluate_cluster_metrics(
            embeddings, idxs, location_labels, location_entropy_base=base_for_entropy
        )

        # ---- OPTIONAL per-cluster signal-space metrics ----
        if signal_consistency_mode == "per_cluster":
            spec_cos, spec_dtw = _consistency_for_idxs(idxs)
            if np.isfinite(spec_cos): spec_cos_scores.append(spec_cos)
            if np.isfinite(spec_dtw): spec_dtw_scores.append(spec_dtw)
        else:
            spec_cos, spec_dtw = float("nan"), float("nan")

        # ---- Build the HTML for this cluster ----
        meta_html = "<p><strong>Location Distribution:</strong></p><ul>" + "".join(
            f"<li><b>{loc}</b>: {count} ({count/len(idxs):.1%})</li>" for loc, count in cluster_counts.items()
        ) + "</ul>"

        meta_html += "<p><strong>Class Distribution:</strong></p><ul>" + "".join(
            f"<li>{cls}: {count}</li>" for cls, count in class_counts.items()
        ) + "</ul>"

        meta_html += f"""
        <p><strong>Cluster Metrics (Embedding Space):</strong></p>
        <ul>
            <li>Size: {len(idxs)}</li>
            <li>Intra-Cluster Variance: {metrics['variance']:.4f}</li>
            <li>Mean Cosine Similarity (Embeddings): {metrics['mean_sim']:.4f}</li>
            <li>Location Entropy: {metrics['entropy']:.3f}</li>
            <li>Composite Quality Score: {metrics['quality']:.4f}</li>
            <li><strong>Novelty Score:</strong> {metrics['novelty']:.4f}</li>
        </ul>
        """

        if signal_consistency_mode == "per_cluster":
            cos_txt = _fmt_or_nan(spec_cos, ".4f")
            dtw_txt = _fmt_or_nan(spec_dtw, ".2f")
            meta_html += (
                "<p><strong>Spectrogram Consistency (Signal Space):</strong></p>"
                "<ul>"
                f"<li>Mean Spectrogram Cosine Similarity (↑ better): {cos_txt}</li>"
                f"<li>Mean Spectrogram DTW Distance (↓ better): {dtw_txt}</li>"
                "</ul>"
            )

        # Nearest-to-center exemplars (images)
        center = np.mean(embeddings[idxs], axis=0, keepdims=True)
        distances = np.linalg.norm(embeddings[idxs] - center, axis=1)
        sorted_indices = np.argsort(distances)
        sampled_idxs = idxs[sorted_indices[:min(5, len(sorted_indices))]]

        imgs = []
        for i, chosen_idx in enumerate(sampled_idxs):
            try:
                x = np.load(original_paths[chosen_idx], mmap_mode="r")
                title = f"#{i+1} | {location_labels[chosen_idx]} | Class {class_labels[chosen_idx]}"
                if x.ndim == 1:
                    imgs.append(generate_spectrogram_base64(x.astype(np.float32), title=title))
                else:
                    # Accept both (F, T) and (C, F, T) files. The previous
                    # `x.squeeze(0)` raised ValueError for any (F, T) array with
                    # F != 1, so 2-D spectrograms never rendered. 3-D inputs are
                    # flattened to (C*F, T), matching TwoViewPrecomputed.
                    S = np.asarray(x, dtype=np.float32)
                    if S.ndim == 3:
                        S = S.reshape(-1, S.shape[-1])
                    fig, ax = plt.subplots(figsize=(8, 3))
                    try:
                        ax.imshow(S, aspect='auto', origin='lower', cmap='viridis')
                        ax.set_title(title)
                        ax.set_xlabel("Frames"); ax.set_ylabel("Mel bins")
                        fig.tight_layout()
                        buf = BytesIO()
                        fig.savefig(buf, format='png')
                    finally:
                        # Always release the figure, even when plotting fails.
                        plt.close(fig)
                    buf.seek(0)
                    image_base64 = base64.b64encode(buf.read()).decode("utf-8")
                    imgs.append(f'<img class="d-block w-100" src="data:image/png;base64,{image_base64}" alt="Spec">')
            except Exception as e:
                imgs.append(f"<p class='text-danger'>Error loading sample {i+1}: {e}</p>")

        carousel_html = make_carousel(c, "all", imgs)
        block = f"<div class='col-md-6 mb-4'><h4>Cluster {c}</h4>{meta_html}{carousel_html}</div>"
        cluster_blocks.append(block)

    # Two cluster cards per Bootstrap row.
    for i in range(0, len(cluster_blocks), 2):
        cluster_html += "<div class='row'>" + "".join(cluster_blocks[i:i+2]) + "</div>"

    # ---- Global spectrogram-consistency summary across clusters ----
    if signal_consistency_mode == "per_cluster":
        global_spec_cos = float(np.nanmean(spec_cos_scores)) if len(spec_cos_scores) else float("nan")
        global_spec_dtw = float(np.nanmean(spec_dtw_scores)) if len(spec_dtw_scores) else float("nan")
    elif signal_consistency_mode == "global":
        rng = np.random.RandomState(42)
        all_clusters = sorted(set(cluster_labels))
        if isinstance(global_consistency_cluster_sample, int) and global_consistency_cluster_sample > 0:
            sampled_clusters = rng.choice(
                all_clusters,
                size=min(global_consistency_cluster_sample, len(all_clusters)),
                replace=False
            )
        else:
            sampled_clusters = all_clusters

        cos_list, dtw_list = [], []
        for c in sampled_clusters:
            idxs = np.where(cluster_labels == c)[0]
            if len(idxs) < 2:
                continue
            cos, dtw = _consistency_for_idxs(idxs)
            if np.isfinite(cos): cos_list.append(cos)
            if np.isfinite(dtw): dtw_list.append(dtw)
        global_spec_cos = float(np.nanmean(cos_list)) if len(cos_list) else float("nan")
        global_spec_dtw = float(np.nanmean(dtw_list)) if len(dtw_list) else float("nan")
    else:
        global_spec_cos, global_spec_dtw = float("nan"), float("nan")

    global_cos_txt = _fmt_or_nan(global_spec_cos, ".4f")
    global_dtw_txt = _fmt_or_nan(global_spec_dtw, ".2f")
    summary_card = f"""
    <div class='row mb-4'>
      <div class='col-md-12'>
        <div class='alert alert-info' role='alert'>
          <h5 class='mb-2'>Spectrogram Consistency (Cluster Averages)</h5>
          <ul class='mb-0'>
            <li><strong>Mean Spectrogram Cosine Similarity</strong>: {global_cos_txt} (↑ better)</li>
            <li><strong>Mean Spectrogram DTW Distance</strong>: {global_dtw_txt} (↓ better)</li>
          </ul>
          <small>Note: Cosine on standardized spectrograms (z-norm, padded/cropped to common shape). DTW optional and subsampled. Per-cluster consistency is omitted unless mode='per_cluster'.</small>
        </div>
      </div>
    </div>
    """
    summary_block = summary_card if signal_consistency_mode in ("global", "per_cluster") else ""

    # `:g` avoids the truncation of int(x*100), e.g. 0.29*100 == 28.999... -> "29".
    pct_txt = f"{subset_fraction * 100:g}"

    # ---- Final section HTML ----
    return f"""
    <div class='section'>
      <h2>{cluster_method.capitalize()} Clustering Analysis ({pct_txt}% subset)</h2>
      <div class="container">
        <div class="row justify-content-center mb-4">
          <div class="col-md-12 d-flex justify-content-center">{umap_html}</div>
        </div>
        {summary_block}
      </div>
      {cluster_html}
    </div>
    """

# ------------------------------------------------------------------------------------
# Report generator (UPDATED)
# ------------------------------------------------------------------------------------
def generate_full_report(
    encoder_ckpt_path,
    dataset_paths,
    labels_list,
    cluster_method='kmeans',
    n_clusters=20,
    out_prefix="unsup_report",
    subset_fraction=0.02,
    subset_seed=42,
    subset_strategy="random",
    wave_policy="convert",
    # --- NEW passthrough args for signal-space consistency ---
    signal_consistency_mode: str = "global",       # "off" | "global" | "per_cluster"
    global_consistency_cluster_sample: int | None = 20,  # None = use ALL clusters
    spec_cos_max_samples: int = 30,                # items per cluster for cosine
    spec_dtw_max_samples: int = 12,                # items per cluster for DTW
    compute_dtw: bool = False,                     # set True to include DTW
):
    section_html = analyze_all_unsupervised_to_html(
        encoder_ckpt_path=encoder_ckpt_path,
        dataset_paths=list(dataset_paths),
        labels_list=list(labels_list),
        cluster_method=cluster_method,
        n_clusters=int(n_clusters),
        subset_fraction=subset_fraction,
        subset_seed=subset_seed,
        subset_strategy=subset_strategy,
        wave_policy=wave_policy,
        # --- forward the new args ---
        signal_consistency_mode=signal_consistency_mode,
        global_consistency_cluster_sample=global_consistency_cluster_sample,
        spec_cos_max_samples=spec_cos_max_samples,
        spec_dtw_max_samples=spec_dtw_max_samples,
        compute_dtw=compute_dtw,
    )
    html = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Unsupervised Clustering Report ({int(subset_fraction*100)}% subset)</title>
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
        <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
        <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
        <style>
            body {{ font-family: Arial, sans-serif; padding: 20px; background-color: #f5f5f5; }}
            h1 {{ color: #2c3e50; }}
            h2, h4 {{ color: #34495e; }}
            hr {{ border-top: 2px solid #bbb; margin-top: 40px; margin-bottom: 40px; }}
            .section {{ margin-bottom: 60px; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
        </style>
    </head>
    <body>
      <h1 class='mb-4'>Unsupervised Latent Clustering Report ({int(subset_fraction*100)}% subset)</h1>
      <p><strong>Compared Locations:</strong> {', '.join(labels_list)}</p>
      <p><strong>Subset:</strong> {int(subset_fraction*100)}% • Strategy: {subset_strategy} • Seed: {subset_seed}</p>
      <hr>
      {section_html}
    </body></html>
    """
    output_path = f"{out_prefix}_{cluster_method}_subset{int(subset_fraction*100)}.html"
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(html)
    return output_path


# ------------------------------------------------------------------------------------
# Example usage
# ------------------------------------------------------------------------------------
if __name__ == "__main__":
    # !!! IMPORTANT: Ensure EncoderProj is importable (see the try/except at top) !!!
    encoder_ckpt_path = "reef_ssl_precomp_ckpt.pth"

    # One directory per dataset; labels_list holds the matching location label.
    # (Alternative dataset used previously: "/notebooks/2017_mona_elbow_U1276")
    dataset_paths = ["/notebooks/MX_PA_2022_preprocessed"]
    labels_list = ["MX_PA_2022_preprocessed"]

    # Clustering / subset configuration.
    report_options = dict(
        cluster_method='kmeans',
        n_clusters=60,
        out_prefix="PR_k60",
        subset_fraction=1,
        subset_seed=45,
        subset_strategy="random",
        wave_policy="convert",
    )

    # Signal-space consistency: only the global summary card is produced.
    consistency_options = dict(
        signal_consistency_mode="global",        # global card, no per-cluster section
        global_consistency_cluster_sample=None,  # use ALL clusters (no sampling)
        spec_cos_max_samples=50,                 # items per cluster for cosine
        spec_dtw_max_samples=25,                 # items per cluster for DTW
        compute_dtw=True,                        # set False to speed up
    )

    report_path = generate_full_report(
        encoder_ckpt_path=encoder_ckpt_path,
        dataset_paths=dataset_paths,
        labels_list=labels_list,
        **report_options,
        **consistency_options,
    )

    print("✅ Report saved to:", report_path)

2025-09-17 22:07:56.607681: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-09-17 22:07:56.607742: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-09-17 22:07:56.608941: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-09-17 22:07:56.615586: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


[WARN] Could not import from your_training_file: No module named 'your_training_file'
[INFO] Scanning: /notebooks/MX_PA_2022_preprocessed
[INFO] Found 211770 candidate .npy files
[INFO] Subset size: 211770 (fraction=1)
[INFO] Encoder loaded on cuda • CUDA available: True
[INFO] DataLoader ready: 211770 samples • batch_size=128 • workers=0
[INFO] Starting embedding pass on cuda (AMP=True)
[INFO] Batches: 1 • Seen: 128 • 71.0 items/s
[INFO] Batches: 10 • Seen: 1280 • 76.5 items/s
[INFO] Batches: 20 • Seen: 2560 • 87.7 items/s
[INFO] Batches: 30 • Seen: 3840 • 72.8 items/s
[INFO] Batches: 40 • Seen: 5120 • 65.9 items/s
[INFO] Batches: 50 • Seen: 6400 • 69.5 items/s
[INFO] Batches: 60 • Seen: 7680 • 74.9 items/s
[INFO] Batches: 70 • Seen: 8960 • 77.7 items/s
[INFO] Batches: 80 • Seen: 10240 • 78.3 items/s
[INFO] Batches: 90 • Seen: 11520 • 70.6 items/s
[INFO] Batches: 100 • Seen: 12800 • 64.3 items/s
[INFO] Batches: 110 • Seen: 14080 • 61.3 items/s
[INFO] Batches: 120 • Seen: 15360 • 63.2 

  warn(


[INFO] UMAP done in 206.89s
[INFO] Clustering with kmeans, k=60
[INFO] Clustering done in 42.97s
[INFO] Silhouette (↑): 0.11465297639369965
[INFO] Davies–Bouldin (↓): 1.6620144583289869
[INFO] Calinski–Harabasz (↑): 25601.323936378536
[INFO] ARI (↑): 0.0
[INFO] AMI (↑): 0.0
[INFO] Hungarian Accuracy (↑): 0.07170515181564906
✅ Report saved to: PR_k60_kmeans_subset100.html


Extensive Report
----------

In [4]:
# ===== Unsupervised Report (2% subset, safe mode, enhanced) =====
# Adds: L2-normalized clustering, UMAP-by-cluster & by-GT,
# k-scan (Sil/DBI/CH), stability vs seed, retrieval@k,
# per-cluster silhouette & centroid-distance violins,
# and your spectrogram-consistency metrics + exemplar carousels.
#
# Usage:
#   output_path = generate_full_report(
#       encoder_ckpt_path="...pth",
#       dataset_paths=["/path/to/spec_or_wave_npy_dir"],
#       labels_list=["SITE1"],    # label per dataset_paths entry
#       cluster_method="kmeans",  # 'kmeans'|'gmm'|'agglomerative'
#       n_clusters=60,
#       out_prefix="unsup_report",
#       subset_fraction=0.02, subset_seed=42, subset_strategy="random",
#       wave_policy="convert",    # 'convert' if files are waveforms; 'skip' if already (F,T)
#   )

import os, math, contextlib, base64, time, io
import numpy as np
from collections import Counter, OrderedDict
from io import BytesIO
from typing import Optional, Dict, Sequence

from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score, fowlkes_mallows_score, silhouette_samples
from sklearn.preprocessing import normalize
from sklearn.decomposition import PCA
from scipy.optimize import linear_sum_assignment

# plotting / viz
from umap import UMAP
import plotly.express as px
import plotly.io as pio
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.mixture import GaussianMixture
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import entropy
from scipy.signal import spectrogram, get_window
import matplotlib.pyplot as plt
from sklearn.metrics import (
    silhouette_score,
    davies_bouldin_score,
    calinski_harabasz_score,
)

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ------------------------- small helpers -------------------------
def _fig_to_img_html(fig) -> str:
    """Convert a Matplotlib figure to an inline <img> (base64 PNG)."""
    buf = io.BytesIO()
    fig.tight_layout()
    fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    b64 = base64.b64encode(buf.read()).decode("utf-8")
    return f'<img class="img-fluid" src="data:image/png;base64,{b64}" alt="plot">'


def hungarian_accuracy(y_true, y_pred):
    """Max bipartite matching accuracy between predicted cluster IDs and true labels.

    A contingency table (clusters x classes) is built, and the one-to-one
    cluster-to-class assignment that maximises the matched count is found with
    the Hungarian algorithm. Unmatched clusters contribute nothing.

    Args:
        y_true: 1-D sequence of ground-truth labels (any sortable values).
        y_pred: 1-D sequence of predicted cluster IDs, same length as y_true.

    Returns:
        float: Fraction of samples covered by the optimal assignment, or NaN for
        empty input.

    Raises:
        ValueError: If y_true and y_pred have different lengths.
    """
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred must have the same length, got {y_true.size} and {y_pred.size}"
        )
    if y_pred.size == 0:
        # No samples: accuracy is undefined.
        return float("nan")

    # return_inverse yields each sample's dense class / cluster index directly.
    true_vals, true_idx = np.unique(y_true, return_inverse=True)
    pred_vals, pred_idx = np.unique(y_pred, return_inverse=True)

    W = np.zeros((pred_vals.size, true_vals.size), dtype=np.int64)
    # Unbuffered add so repeated (cluster, class) pairs accumulate correctly.
    np.add.at(W, (pred_idx.reshape(-1), true_idx.reshape(-1)), 1)

    # linear_sum_assignment minimises cost, so flip counts into costs.
    row_ind, col_ind = linear_sum_assignment(W.max() - W)
    return float(W[row_ind, col_ind].sum() / y_pred.size)


# ------------------------------------------------------------------------------------
# Import your model + (optional) pad_collate from training code
# ------------------------------------------------------------------------------------
# !!! IMPORTANT: change 'your_training_file' to your actual module name !!!
# Try the training module first; it provides both the encoder class and the
# optional padding collate function.
try:
    from your_training_file import EncoderProj, pad_collate  # noqa
    _HAS_PAD_COL = True
    print("[INFO] Imported EncoderProj and pad_collate from your_training_file.")
except Exception as e:
    # The import failed (the module name above is a placeholder). Record that no
    # collate function was imported and check whether EncoderProj is already
    # defined in this namespace.
    print(f"[WARN] Could not import from your_training_file: {e}")
    _HAS_PAD_COL = False
    try:
        # Bare name lookup: raises NameError if EncoderProj has not been defined.
        EncoderProj  # type: ignore
    except NameError:
        # Without an encoder class nothing downstream can run, so fail loudly.
        raise RuntimeError(
            "EncoderProj is not importable. Provide it via:\n"
            "from <your_module> import EncoderProj  # (and optionally pad_collate)"
        )
    # Make sure the name `pad_collate` exists so later references do not raise.
    if 'pad_collate' not in globals():
        pad_collate = None  # type: ignore


# ------------------------------------------------------------------------------------
# Dataset for precomputed spectrograms or raw waveforms (.npy files)
# ------------------------------------------------------------------------------------
def _wave_to_spec_1xFxT(y_np: np.ndarray, sr: int = 10000, n_mels: int = 64,
                        n_fft: int = 1024, hop_length: int = 512) -> torch.Tensor:
    y_np = np.nan_to_num(y_np.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
    mx = np.max(np.abs(y_np)); mx = 1.0 if (not np.isfinite(mx) or mx < 1e-8) else mx
    y_t = torch.from_numpy(y_np / mx).unsqueeze(0)  # (1, T)
    try:
        import torchaudio as ta
        mel = ta.transforms.MelSpectrogram(sample_rate=sr, n_fft=n_fft,
                                           hop_length=hop_length, n_mels=n_mels)
        S = mel(y_t)                                # (1, F, T)
        S = torch.clamp(S, min=1e-10).log()
        return S.contiguous().float()
    except Exception:
        from scipy.signal import spectrogram, get_window
        win = get_window('hann', n_fft, fftbins=True)
        f, t, Z = spectrogram(y_np, fs=sr, window=win, nperseg=n_fft,
                              noverlap=n_fft - hop_length, mode='magnitude')
        S = np.log(np.maximum(Z, 1e-10)).astype(np.float32)  # (F_lin, T)
        if S.shape[0] != n_mels:
            idx = np.linspace(0, S.shape[0]-1, num=n_mels).astype(np.int32)
            S = S[idx]
        S = torch.from_numpy(S).unsqueeze(0)
        return S.contiguous().float()


class TwoViewPrecomputed(Dataset):
    """Dataset over .npy files holding precomputed spectrograms or raw waveforms.

    Each item is returned as two identical views `(x1, x2)` of a float tensor
    shaped (1, F, T). 1-D arrays are treated as waveforms and converted with
    `_wave_to_spec_1xFxT` unless `wave_policy="skip"`, in which case they are
    filtered out at construction time.

    Args:
        spec_dir: Directory scanned for "*.npy" (except "labels.npy") when
            `file_paths` is not given.
        T: Stored on the instance; not used by the methods of this class.
        per_sample_norm: If True, each item is standardised to zero mean / unit std.
        sr, n_mels, n_fft, hop: Spectrogram parameters for waveform conversion.
        wave_policy: "convert" to convert waveforms, "skip" to drop them.
        file_paths: Optional explicit list of files; used as-is (no sorting).
    """

    def __init__(self, spec_dir, T=256, per_sample_norm=True,
                 sr=10000, n_mels=64, n_fft=1024, hop=512,
                 wave_policy="convert",   # "convert" or "skip"
                 file_paths=None):
        if file_paths is not None:
            self.paths = list(file_paths)
        else:
            names = [f for f in os.listdir(spec_dir)
                     if f.endswith(".npy") and f != "labels.npy"]
            # Numeric stems sort by value, then non-numeric stems alphabetically.
            # Comparing raw int and str keys raised TypeError for mixed folders.
            names.sort(key=lambda f: self._natural_key(os.path.splitext(f)[0]))
            self.paths = [os.path.join(spec_dir, f) for f in names]

        self.T = T
        self.per_sample_norm = per_sample_norm
        self.sr, self.n_mels, self.n_fft, self.hop = sr, n_mels, n_fft, hop
        self.wave_policy = wave_policy

        if self.wave_policy == "skip":
            # Drop 1-D (waveform) files up front; only the header is needed, so mmap.
            kept = []
            for p in self.paths:
                arr = np.load(p, mmap_mode="r")
                if arr.ndim == 1:
                    continue
                kept.append(p)
            self.paths = kept

    @staticmethod
    def _natural_key(stem):
        """Sort key that orders integer-like stems before all other stems."""
        try:
            return (0, int(stem), stem)
        except ValueError:
            return (1, 0, stem)

    def _to_1xFxT(self, arr: np.ndarray) -> torch.Tensor:
        """Coerce a loaded array into a contiguous float tensor of shape (1, F, T).

        Raises:
            RuntimeError: For a waveform when wave_policy is "skip", or for arrays
                with more than three dimensions.
        """
        # Copy out of the read-only memmap before handing the data to torch.
        arr = np.array(arr, copy=True)
        if arr.ndim == 1:
            if self.wave_policy == "skip":
                raise RuntimeError("Waveform encountered but wave_policy='skip'.")
            return _wave_to_spec_1xFxT(arr, sr=self.sr, n_mels=self.n_mels,
                                       n_fft=self.n_fft, hop_length=self.hop)
        x = torch.from_numpy(arr)
        if x.ndim == 2:
            x = x.unsqueeze(0)
        elif x.ndim == 3:
            if x.shape[0] != 1:
                # Fold multiple channels into the frequency axis: (C, F, T) -> (1, C*F, T).
                x = x.reshape(1, x.shape[0]*x.shape[1], x.shape[2])
        else:
            raise RuntimeError(f"Unexpected array shape {tuple(x.shape)}")
        return x.contiguous().float()

    def __len__(self): return len(self.paths)

    def __getitem__(self, idx):
        """Load item `idx` and return two identical views `(x1, x2)`."""
        feat_np = np.load(self.paths[idx], mmap_mode="r")
        feat = self._to_1xFxT(feat_np)
        if self.per_sample_norm:
            m, s = feat.mean(), feat.std()
            # Guard against NaN / near-zero std (constant or single-element input).
            s = s if (torch.isfinite(s) and s >= 1e-6) else torch.tensor(1e-6, device=feat.device)
            feat = (feat - m) / s
        x1 = feat
        x2 = feat.clone()
        return x1, x2


# ------------------------------------------------------------------------------------
# File utilities and viz helpers
# ------------------------------------------------------------------------------------
def list_audio_npy_files(folder):
    """List the data .npy files in `folder` in a stable, natural order.

    "labels.npy" is excluded. Files whose stem parses as an integer come first,
    ordered by numeric value (so "2.npy" precedes "10.npy"); all other files
    follow in alphabetical order. Mixed numeric / non-numeric folders previously
    raised TypeError because int and str sort keys were compared directly.

    Args:
        folder: Directory to scan (non-recursive).

    Returns:
        list[str]: Full paths (folder joined with file name), sorted as above.
    """
    def _natural_key(name):
        stem = os.path.splitext(name)[0]
        try:
            return (0, int(stem), stem)
        except ValueError:
            return (1, 0, stem)

    names = [f for f in os.listdir(folder) if f.endswith(".npy") and f != "labels.npy"]
    names.sort(key=_natural_key)
    return [os.path.join(folder, f) for f in names]

def load_class_labels_if_any(folder, count):
    """Load per-file class labels from `folder/labels.npy`, if present.

    The result always has exactly `count` entries: a shorter label array is
    padded with "Unknown", a longer one is truncated. If the file is missing or
    cannot be used, every label is "Unknown". Problems are reported with a
    [WARN] message instead of being silently ignored.

    Args:
        folder: Directory expected to contain "labels.npy".
        count: Number of labels required.

    Returns:
        np.ndarray: Array of `count` labels.
    """
    lbl_path = os.path.join(folder, "labels.npy")
    if os.path.isfile(lbl_path):
        try:
            # NOTE(security): allow_pickle=True can execute arbitrary code when
            # loading a malicious file; only use label files from trusted sources.
            arr = np.load(lbl_path, allow_pickle=True)
            n_labels = len(arr)
            if n_labels != count:
                print(f"[WARN] {lbl_path} has {n_labels} labels but {count} files; "
                      "padding/truncating to match.")
            if n_labels < count:
                pad = np.array(["Unknown"] * (count - n_labels), dtype=object)
                arr = np.concatenate([arr, pad], axis=0)
            elif n_labels > count:
                arr = arr[:count]
            return arr
        except Exception as e:
            # Labels are optional: report the problem and fall back to "Unknown".
            print(f"[WARN] Could not use labels from {lbl_path}: {e}")
    return np.array(["Unknown"] * count, dtype=object)

def ali_spec(x: np.ndarray, fs: int):
    """Compute a dB-scaled spectrogram restricted to the 0-800 Hz band.

    Uses a 1000-sample Hann window with 80% overlap and a zero-padded FFT whose
    length is four times the largest power of two not exceeding the window.
    Magnitudes are scaled so the global peak maps to 100 before the log.

    Args:
        x: Input signal.
        fs: Sample rate in Hz.

    Returns:
        tuple: (sdb_band, freqs_band, times), where the first two cover the
        frequency bins closest to 0 Hz through 800 Hz inclusive.
    """
    frame_len = 1000
    overlap_pct = 80
    n_overlap = int(np.ceil((overlap_pct / 100) * frame_len))
    hann = get_window('hann', frame_len)
    n_fft = 2 ** (int(np.floor(np.log2(frame_len))) + 2)

    freqs, times, stft = spectrogram(
        x, fs=fs, window=hann, noverlap=n_overlap, nfft=n_fft, mode='complex'
    )

    magnitude = np.abs(stft)
    peak = np.max(magnitude)
    # Fall back to 1.0 when the peak is zero or non-finite to avoid dividing by it.
    reference = peak if (np.isfinite(peak) and peak > 0) else 1.0
    sdb = 10 * np.log10(100 * magnitude / reference + 1e-10)

    lo_bin = np.argmin(np.abs(freqs - 0))
    hi_bin = np.argmin(np.abs(freqs - 800))
    band = slice(lo_bin, hi_bin + 1)
    return sdb[band, :], freqs[band], times

def generate_spectrogram_base64(audio, fs=10000, title="Spectrogram"):
    """Render the band-limited spectrogram of `audio` as an inline <img> tag.

    Args:
        audio: Waveform passed to ali_spec.
        fs: Sample rate in Hz.
        title: Plot title.

    Returns:
        str: HTML produced by _fig_to_img_html for the rendered figure.
    """
    spec_db, freqs, times = ali_spec(audio, fs)
    # Axis extent is [t_first, t_last, f_first, f_last] so ticks show seconds / Hz.
    bounds = [times[0], times[-1], freqs[0], freqs[-1]]

    fig, ax = plt.subplots(figsize=(8, 3))
    ax.imshow(spec_db, aspect='auto', origin='lower', extent=bounds, cmap='hsv')
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")
    return _fig_to_img_html(fig)

def make_carousel(cluster_id, scope, imgs):
    """Build a Bootstrap 5 carousel containing the given HTML snippets.

    The first slide is marked active. Only the active indicator carries
    `aria-current="true"` (it used to be emitted on every indicator), and the
    indicator's class attribute is quoted.

    Args:
        cluster_id: Cluster identifier, used in the element id.
        scope: Extra tag used in the element id to keep ids unique per section.
        imgs: Sequence of HTML strings, one per slide.

    Returns:
        str: HTML for the carousel; its element id is "carousel_{scope}_{cluster_id}".
    """
    cid = f"carousel_{scope}_{cluster_id}"

    indicator_parts = []
    item_parts = []
    for i, img in enumerate(imgs):
        is_first = (i == 0)
        active_attrs = ' class="active" aria-current="true"' if is_first else ''
        indicator_parts.append(
            f'<button type="button" data-bs-target="#{cid}" data-bs-slide-to="{i}"'
            f'{active_attrs} aria-label="Slide {i+1}"></button>'
        )
        item_class = "carousel-item active" if is_first else "carousel-item"
        item_parts.append(
            f'<div class="{item_class}"><div class="d-flex justify-content-center">{img}</div></div>'
        )

    indicators = "".join(indicator_parts)
    items = "".join(item_parts)
    return f"""
    <div id="{cid}" class="carousel slide" data-bs-interval="false" data-bs-touch="false">
      <div class="carousel-indicators">{indicators}</div>
      <div class="carousel-inner">{items}</div>
      <button class="carousel-control-prev" type="button" data-bs-target="#{cid}" data-bs-slide="prev">
        <span class="carousel-control-prev-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Previous</span>
      </button>
      <button class="carousel-control-next" type="button" data-bs-target="#{cid}" data-bs-slide="next">
        <span class="carousel-control-next-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Next</span>
      </button>
    </div>
    """


# ------------------------------------------------------------------------------------
# Spectrogram standardization & intra-cluster consistency metrics
# ------------------------------------------------------------------------------------
def _standardize_spec(S: np.ndarray, target_shape=(128, 256)) -> np.ndarray:
    S = np.asarray(S, dtype=np.float32)
    S = (S - np.mean(S)) / (np.std(S) + 1e-8)
    F, T = S.shape
    Ft, Tt = target_shape
    if F < Ft:
        pad_top = (Ft - F) // 2
        pad_bot = Ft - F - pad_top
        S = np.pad(S, ((pad_top, pad_bot), (0, 0)), mode="constant")
    elif F > Ft:
        start = (F - Ft) // 2
        S = S[start:start+Ft, :]
    F, T = S.shape
    if T < Tt:
        pad_left = (Tt - T) // 2
        pad_right = Tt - T - pad_left
        S = np.pad(S, ((0, 0), (pad_left, pad_right)), mode="constant")
    elif T > Tt:
        start = (T - Tt) // 2
        S = S[:, start:start+Tt]
    return S

def _load_spec_for_index(original_paths: np.ndarray, idx: int, fs: int = 10000,
                         target_shape=(128, 256)) -> Optional[np.ndarray]:
    try:
        arr = np.load(original_paths[idx], mmap_mode="r")
        if arr.ndim == 1:
            S, _, _ = ali_spec(arr.astype(np.float32), fs=fs)
        elif arr.ndim == 2:
            S = arr.astype(np.float32)
        else:
            return None
        return _standardize_spec(S, target_shape=target_shape)
    except Exception:
        return None

def _avg_intra_cluster_spec_cosine(original_paths: np.ndarray, idxs: np.ndarray,
                                   *, fs: int = 10000, max_samples: int = 50,
                                   target_shape=(128, 256)) -> float:
    if len(idxs) < 2:
        return float("nan")
    if len(idxs) > max_samples:
        rng = np.random.RandomState(42)
        idxs = rng.choice(idxs, size=max_samples, replace=False)
    specs = []
    for i in idxs:
        S = _load_spec_for_index(original_paths, int(i), fs=fs, target_shape=target_shape)
        if S is not None:
            specs.append(S.reshape(-1))
    if len(specs) < 2:
        return float("nan")
    X = np.stack(specs, axis=0)
    sim = cosine_similarity(X)
    iu = np.triu_indices_from(sim, k=1)
    return float(np.mean(sim[iu]))

def _avg_intra_cluster_spec_dtw(original_paths: np.ndarray, idxs: np.ndarray,
                                *, fs: int = 10000, max_samples: int = 25,
                                target_shape=(128, 256)) -> float:
    try:
        from fastdtw import fastdtw
        from scipy.spatial.distance import euclidean
    except Exception:
        return float("nan")
    if len(idxs) < 2:
        return float("nan")
    if len(idxs) > max_samples:
        rng = np.random.RandomState(42)
        idxs = rng.choice(idxs, size=max_samples, replace=False)
    series = []
    for i in idxs:
        S = _load_spec_for_index(original_paths, int(i), fs=fs, target_shape=target_shape)
        if S is not None:
            series.append(np.mean(S, axis=0))
    if len(series) < 2:
        return float("nan")
    n = len(series); total, pairs = 0.0, 0
    for a in range(n):
        for b in range(a+1, n):
            d, _ = fastdtw(series[a], series[b], dist=euclidean)
            total += d; pairs += 1
    return float(total / max(pairs, 1))


# ------------------------------------------------------------------------------------
# Clustering and metrics
# ------------------------------------------------------------------------------------
def choose_clusterer(algorithm: str, embeddings: np.ndarray, n_clusters: int, seed: int = 42):
    """Fit the requested clustering model and return ``(labels, probabilities)``.

    Supported ``algorithm`` values:
      * ``'kmeans'``: KMeans with 20 initialisations, seeded with ``seed``.
      * ``'agglomerative'``: AgglomerativeClustering (``seed`` is not used).
      * ``'gmm'``: diagonal-covariance GaussianMixture, seeded with ``seed``.

    Returns:
        A pair whose first item is the per-sample cluster label array and whose
        second item is the per-sample membership probabilities for ``'gmm'``
        and ``None`` for the other algorithms.

    Raises:
        ValueError: if ``algorithm`` is not one of the names above.
    """
    if algorithm == 'kmeans':
        km = KMeans(n_clusters=n_clusters, n_init=20, random_state=seed)
        km.fit(embeddings)
        return km.labels_, None

    if algorithm == 'agglomerative':
        agg = AgglomerativeClustering(n_clusters=n_clusters)
        agg.fit(embeddings)
        return agg.labels_, None

    if algorithm == 'gmm':
        gmm = GaussianMixture(
            n_components=n_clusters,
            covariance_type='diag',
            random_state=seed,
            max_iter=300,
        )
        gmm.fit(embeddings)
        hard = gmm.predict(embeddings)
        soft = gmm.predict_proba(embeddings)
        return hard, soft

    raise ValueError(f"Unsupported clustering algorithm: {algorithm}")

def evaluate_cluster_metrics(embeddings, idxs, location_labels, location_entropy_base=None):
    """Summarise one cluster in embedding space and by location mix.

    Args:
        embeddings: Array of all embeddings, indexable by ``idxs``.
        idxs: Indices of the cluster's members.
        location_labels: Array of per-sample location labels (same indexing).
        location_entropy_base: Logarithm base for the location entropy.
            Defaults to the number of distinct values in ``location_labels``.

    Returns:
        dict with keys:
          * ``variance``: mean squared distance of members to their centroid.
          * ``mean_sim``: mean pairwise cosine similarity (1.0 for one member).
          * ``entropy``: location entropy in base ``base``. With ``base`` equal
            to the number of locations this lies in [0, 1], where 1 means a
            uniform mix.
          * ``quality``: ``mean_sim / variance * (1 - entropy)``.
          * ``novelty``: ``entropy * variance``.
        An empty cluster yields the fixed dict returned by the first branch.

    Note:
        ``entropy(..., base=base)`` (assumed to be ``scipy.stats.entropy``) is
        already expressed in units of ``log(base)``. It is therefore used as
        the normalised entropy directly. Dividing it again by ``log2(base)``
        would under-penalise mixed clusters whenever there are more than two
        locations.
    """
    X = embeddings[idxs]
    if X.shape[0] == 0:
        return {'variance': 0.0, 'mean_sim': 1.0, 'entropy': 0.0, 'quality': 0.0, 'novelty': 0.0}
    center = np.mean(X, axis=0, keepdims=True)
    variance = float(np.mean(np.sum((X - center) ** 2, axis=1)))
    if len(X) > 1:
        cos_sim = cosine_similarity(X)
        iu = np.triu_indices_from(cos_sim, k=1)
        mean_sim = float(np.mean(cos_sim[iu])) if iu[0].size > 0 else 1.0
    else:
        mean_sim = 1.0
    loc_counts = Counter(location_labels[idxs])
    loc_probs = np.array(list(loc_counts.values()), dtype=np.float32)
    loc_probs /= max(loc_probs.sum(), 1e-8)
    base = int(location_entropy_base or len(set(location_labels)))
    if base > 1:
        loc_entropy = float(entropy(loc_probs, base=base))
        # Already normalised by log(base); see the docstring note.
        norm_entropy = loc_entropy
        entropy_score = 1.0 - norm_entropy
        novelty = float(norm_entropy * variance)
    else:
        # A single location carries no mixing information.
        loc_entropy = 0.0
        entropy_score = 1.0
        novelty = 0.0
    quality = float((mean_sim / (variance + 1e-8)) * entropy_score)
    return {'variance': variance, 'mean_sim': mean_sim, 'entropy': loc_entropy, 'quality': quality, 'novelty': novelty}

def _safe_metric(fn, *a, **k) -> float:
    try:
        return float(fn(*a, **k))
    except Exception:
        return float("nan")

def _coassign_jaccard(y1: np.ndarray, y2: np.ndarray) -> float:
    """Jaccard of same-cluster co-assignment (on the same index set)."""
    n = len(y1)
    inter = 0; union = 0
    for i in range(n):
        yi = y1[i]
        for j in range(i+1, n):
            a = (yi == y1[j]); b = (y2[i] == y2[j])
            inter += int(a and b); union += int(a or b)
    return float(inter / max(union, 1))


# ------------------------------------------------------------------------------------
# Embedding extraction
# ------------------------------------------------------------------------------------
def _load_encoder(encoder_ckpt_path, device):
    """Instantiate ``EncoderProj`` and load weights from a checkpoint file.

    The checkpoint may be a raw state dict, or a dict holding the state dict
    under the key ``"online"`` or ``"state_dict"`` (checked in that order).
    The substrings ``"encoder."`` and ``"module."`` are stripped from every
    parameter name before loading.

    Loading is non-strict, so parameters that do not match are skipped rather
    than raising. Any missing or unexpected keys are printed, because a
    silently unmatched checkpoint would leave the encoder with randomly
    initialised weights.

    Args:
        encoder_ckpt_path: Path passed to ``torch.load``.
        device: Device the model and checkpoint tensors are mapped to.

    Returns:
        The encoder in eval mode, in channels-last memory format.
    """
    enc = EncoderProj(proj_dim=256).to(device).to(memory_format=torch.channels_last)
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(encoder_ckpt_path, map_location=device)
    if isinstance(ckpt, dict) and "online" in ckpt:
        state_dict = ckpt["online"]
    elif isinstance(ckpt, dict) and "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    else:
        state_dict = ckpt
    new_state = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace("encoder.", "").replace("module.", "")
        new_state[new_key] = v
    load_result = enc.load_state_dict(new_state, strict=False)
    missing = list(getattr(load_result, "missing_keys", []))
    unexpected = list(getattr(load_result, "unexpected_keys", []))
    if missing or unexpected:
        print(f"[WARN] Checkpoint/model key mismatch • missing={len(missing)} unexpected={len(unexpected)}")
        if missing:
            print(f"[WARN]   first missing keys: {missing[:5]}")
        if unexpected:
            print(f"[WARN]   first unexpected keys: {unexpected[:5]}")
    enc.eval()
    print(f"[INFO] Encoder loaded on {device} • CUDA available: {torch.cuda.is_available()}")
    return enc

def extract_embeddings_from_specdir(
    encoder_ckpt_path,
    spec_dir,
    batch_size=128,
    n_workers=0,          # SAFE: 0 workers to avoid hangs for first run
    device_str=None,
    subset_fraction=0.02, # 2% subset
    subset_seed=42,
    subset_strategy="random",  # "random" | "tail" | "head"
    wave_policy="convert"      # "convert" for waveforms; "skip" if files are (F,T) specs
):
    """
    Embed a subset of the ``.npy`` files found under ``spec_dir``.

    Args:
        encoder_ckpt_path: Checkpoint path forwarded to ``_load_encoder``.
        spec_dir: Directory scanned with ``list_audio_npy_files``.
        batch_size: DataLoader batch size.
        n_workers: DataLoader worker count.
        device_str: Torch device string; defaults to CUDA when available.
        subset_fraction: Fraction of files to embed. At least one file and at
            most all files are selected.
        subset_seed: Seed for the ``"random"`` strategy.
        subset_strategy: ``"random"`` (seeded draw without replacement, kept in
            file order), ``"tail"`` (last files) or ``"head"`` (first files).
        wave_policy: Forwarded to ``TwoViewPrecomputed``.

    Returns:
        H: (N_subset, D) embeddings as float32
        chosen_files: list[str] of file paths used, length N_subset

    Raises:
        RuntimeError: if no files are found or the dataset ends up empty.
        ValueError: if ``subset_strategy`` is unknown.
    """
    print(f"[INFO] Scanning: {spec_dir}")
    all_files = list_audio_npy_files(spec_dir)
    print(f"[INFO] Found {len(all_files)} candidate .npy files")
    if len(all_files) == 0:
        raise RuntimeError(f"No .npy files found in {spec_dir}")

    # Clamp to [1, len(all_files)]: a fraction above 1.0 would otherwise ask
    # rng.choice for more items than exist (ValueError) or make the "tail"
    # range start at a negative index.
    n_sub = min(len(all_files), max(1, int(math.ceil(len(all_files) * subset_fraction))))
    if subset_strategy == "random":
        rng = np.random.RandomState(subset_seed)
        chosen_idx = np.sort(rng.choice(len(all_files), size=n_sub, replace=False))
    elif subset_strategy == "tail":
        chosen_idx = np.arange(len(all_files) - n_sub, len(all_files))
    elif subset_strategy == "head":
        chosen_idx = np.arange(0, n_sub)
    else:
        raise ValueError(f"Unknown subset_strategy: {subset_strategy}")

    chosen_files = [all_files[i] for i in chosen_idx]
    print(f"[INFO] Subset size: {len(chosen_files)} (fraction={subset_fraction})")

    device = torch.device(device_str or ("cuda" if torch.cuda.is_available() else "cpu"))
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = False

    enc = _load_encoder(encoder_ckpt_path, device)

    ds = TwoViewPrecomputed(
        spec_dir,
        T=256,
        per_sample_norm=True,
        sr=10000, n_mels=64, n_fft=1024, hop=512,
        wave_policy=wave_policy,
        file_paths=chosen_files
    )
    if len(ds) == 0:
        raise RuntimeError(
            "No samples in dataset. Likely cause: wave_policy='skip' but files are 1-D waveforms. "
            "Use wave_policy='convert' or precompute spectrograms."
        )

    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=n_workers,
        pin_memory=(device.type == "cuda" and n_workers == 0),
        persistent_workers=False
    )
    if _HAS_PAD_COL and pad_collate is not None:
        loader_kwargs["collate_fn"] = pad_collate

    loader = DataLoader(ds, **loader_kwargs)
    print(f"[INFO] DataLoader ready: {len(ds)} samples • batch_size={batch_size} • workers={n_workers}")

    feats = []
    use_cuda = (device.type == "cuda")
    autocast_ctx = torch.cuda.amp.autocast if use_cuda else contextlib.nullcontext

    t0 = time.perf_counter()
    n_seen = 0
    print(f"[INFO] Starting embedding pass on {device} (AMP={use_cuda})")
    with torch.inference_mode():
        with autocast_ctx():
            for step, (x1, _) in enumerate(loader):
                n_seen += x1.size(0)
                x = x1.to(device, non_blocking=True).to(memory_format=torch.channels_last)
                h, _ = enc(x, return_backbone=True)
                # Cast on the device, then copy synchronously. A non-blocking
                # device-to-host copy followed by a host-side .float() could
                # read the host buffer before the transfer has finished. The
                # blocking copy also makes the per-step timing accurate.
                feats.append(h.detach().float().cpu())
                if (step + 1) % 10 == 0 or (step == 0):
                    dt = time.perf_counter() - t0
                    print(f"[INFO] Batches: {step+1} • Seen: {n_seen} • {n_seen/max(dt,1e-9):.1f} items/s", flush=True)
                del x, h
        if use_cuda:
            torch.cuda.empty_cache()
    emb_time = time.perf_counter() - t0
    print(f"[INFO] Embedding done: {n_seen} items in {emb_time:.2f}s")

    H = torch.cat(feats, dim=0).numpy().astype("float32")
    if H.shape[0] != len(chosen_files):
        # Callers pair rows of H with chosen_files by position.
        print(f"[WARN] Got {H.shape[0]} embeddings for {len(chosen_files)} chosen files; "
              f"rows may not align with the returned paths")
    return H, chosen_files


# ------------------------------------------------------------------------------------
# New helpers for k-scan, stability, retrieval, violins
# ------------------------------------------------------------------------------------
def _k_scan_curves(Xc: np.ndarray, method: str, ks=(20,30,40,50,60,80,100)):
    """Cluster ``Xc`` for each k in ``ks`` and plot three internal metrics.

    For every k the data are clustered with ``choose_clusterer`` (seed 42) and
    scored with silhouette, Davies–Bouldin and Calinski–Harabasz; a metric
    that fails is recorded as NaN. Values of k outside ``1..len(Xc)`` cannot
    be fitted at all, so they are recorded as an all-NaN row with a warning
    instead of aborting the scan.

    Returns:
        arr: float array of shape ``(len(ks), 4)`` with columns
            ``[k, silhouette, davies_bouldin, calinski_harabasz]``.
        imgs: list of three values returned by ``_fig_to_img_html``, one per
            metric, in the column order above.
    """
    n_samples = len(Xc)
    nan = float("nan")
    rows = []
    for kk in ks:
        if not (1 <= kk <= n_samples):
            print(f"[WARN] k-scan: skipping k={kk} (only {n_samples} samples available)")
            rows.append([kk, nan, nan, nan])
            continue
        yk, _ = choose_clusterer(method, Xc, kk, seed=42)
        rows.append([
            kk,
            _safe_metric(silhouette_score, Xc, yk, metric='euclidean'),
            _safe_metric(davies_bouldin_score, Xc, yk),
            _safe_metric(calinski_harabasz_score, Xc, yk)
        ])
    # reshape keeps the array 2-D (0 x 4) even when ks is empty, so the
    # column indexing below stays valid.
    arr = np.array(rows, dtype=float).reshape(-1, 4)
    # Make three small figures
    figs = []
    for col, ylabel, title in [(1,"Silhouette (↑)","k-scan Silhouette"),
                               (2,"Davies–Bouldin (↓)","k-scan DBI"),
                               (3,"Calinski–Harabasz (↑)","k-scan CH")]:
        fig = plt.figure(figsize=(6.5,4.2))
        plt.plot(arr[:,0], arr[:,col], marker='o')
        plt.xlabel("k"); plt.ylabel(ylabel); plt.title(title)
        figs.append(fig)
    return arr, [_fig_to_img_html(f) for f in figs]

def _stability_vs_seed(Xc: np.ndarray, method: str, k: int, seeds=(0,1,2,3,4), subsample=2000):
    """Measure how much the clustering changes when only the seed changes.

    A fixed subsample of at most ``subsample`` rows (selection seed 123) is
    clustered once per entry of ``seeds``. Every pair of runs is compared with
    the co-assignment Jaccard and the Fowlkes–Mallows score, and the two
    averages are drawn as a bar chart.

    Returns:
        stats: ``{"jaccard": mean Jaccard, "fmi": mean Fowlkes–Mallows}``.
        img: the value returned by ``_fig_to_img_html`` for the bar chart.
    """
    n_total = len(Xc)
    n_used = min(n_total, subsample)
    picker = np.random.RandomState(123)
    chosen = np.sort(picker.choice(n_total, size=n_used, replace=False))
    X_sub = Xc[chosen]

    # One labelling per seed, all on the same subsample.
    runs = [choose_clusterer(method, X_sub, k, seed=int(s))[0] for s in seeds]

    n_runs = len(seeds)
    run_pairs = [(p, q) for p in range(n_runs) for q in range(p + 1, n_runs)]
    jaccard_scores = [_coassign_jaccard(runs[p], runs[q]) for p, q in run_pairs]
    fmi_scores = [fowlkes_mallows_score(runs[p], runs[q]) for p, q in run_pairs]
    mean_jaccard = float(np.mean(jaccard_scores))
    mean_fmi = float(np.mean(fmi_scores))

    # Two-bar summary figure.
    fig = plt.figure(figsize=(5.5,3.8))
    plt.bar([0,1], [mean_jaccard, mean_fmi])
    plt.xticks([0,1], ["Jaccard\n(co-assign.)","Fowlkes–\nMallows"])
    plt.ylim(0,1)
    plt.title(f"Stability @ k={k} • subsample={n_used}")
    stats = {"jaccard": mean_jaccard, "fmi": mean_fmi}
    return stats, _fig_to_img_html(fig)

def _retrieval_at_k(Xn: np.ndarray, y_true: np.ndarray, ks=(1,5,10)):
    mask = np.array([(t is not None) and (str(t).lower() != "unknown") for t in y_true])
    idxs = np.where(mask)[0]
    if idxs.size < 2:
        return {k: float("nan") for k in ks}, None
    Xq = Xn[idxs]; yq = y_true[idxs].astype(object)
    S = Xq @ Xq.T
    np.fill_diagonal(S, -np.inf)
    out = {}
    for k in ks:
        topk = np.argpartition(-S, kth=min(k, S.shape[1]-1), axis=1)[:, :k]
        hits = 0
        for i in range(len(idxs)):
            hits += int(np.sum(yq[topk[i]] == yq[i]))
        out[k] = float(hits / (len(idxs) * k))
    # bar
    fig = plt.figure(figsize=(5.5,3.8))
    xs = list(out.keys()); ys = [out[t] for t in xs]
    plt.bar(range(len(xs)), ys)
    plt.xticks(range(len(xs)), [f"@{t}" for t in xs])
    plt.ylim(0,1)
    plt.title("Label Transfer Precision@k")
    return out, _fig_to_img_html(fig)

def _per_cluster_violins(Xc: np.ndarray, y_hat: np.ndarray):
    """Draw per-cluster violin plots of silhouette and centroid distance.

    Returns:
        sil_samples: per-sample silhouette values (all NaN if the silhouette
            computation raised).
        dists: per-sample Euclidean distance to the sample's cluster centroid.
        v1, v2: values returned by ``_fig_to_img_html`` for the silhouette
            figure and the centroid-distance figure, respectively.
    """
    # Per-sample silhouette, with an all-NaN fallback when it cannot be computed.
    try:
        sil_samples = silhouette_samples(Xc, y_hat, metric="euclidean")
    except Exception:
        sil_samples = np.full(len(Xc), np.nan, dtype=np.float32)

    # One membership mask per cluster, in sorted cluster order.
    clusters = sorted(np.unique(y_hat))
    member_masks = [y_hat == c for c in clusters]

    # Distance of every sample to the centroid of its own cluster.
    centers = np.vstack([Xc[members].mean(axis=0) for members in member_masks])
    dists = np.zeros(len(Xc), dtype=np.float32)
    for members, centroid in zip(member_masks, centers):
        dists[members] = np.linalg.norm(Xc[members] - centroid, axis=1)

    sil_groups = [sil_samples[members] for members in member_masks]
    dist_groups = [dists[members] for members in member_masks]
    fig_width = min(12, 0.18*len(clusters)+6)

    def _violin_html(groups, ylabel, title):
        # One violin per cluster, drawn on a fresh figure.
        fig = plt.figure(figsize=(fig_width, 4.6))
        plt.violinplot(groups, showmeans=True, showextrema=False)
        plt.xlabel("Cluster index"); plt.ylabel(ylabel)
        plt.title(title)
        return _fig_to_img_html(fig)

    v1 = _violin_html(sil_groups, "Silhouette (per-sample)", "Per-cluster Silhouette")
    v2 = _violin_html(dist_groups, "Distance to centroid", "Per-cluster Centroid Distance")

    return sil_samples, dists, v1, v2


# ------------------------------------------------------------------------------------
# Main analysis (2% subset) — enhanced with new plots
# ------------------------------------------------------------------------------------
def analyze_all_unsupervised_to_html(encoder_ckpt_path, dataset_paths, labels_list,
                                     cluster_method='kmeans', n_clusters=20,
                                     subset_fraction=0.02, subset_seed=42,
                                     subset_strategy="random", wave_policy="convert",
                                     ks_scan=(20,30,40,50,60,80,100),
                                     stability_seeds=(0,1,2,3,4), stability_subsample=2000,
                                     sr_signal=10000, pca_dim=None):
    """Embed, cluster and summarise several datasets as one HTML section.

    Pipeline:
      1. Embed a subset of every directory in ``dataset_paths`` and tag each
         embedding with the matching entry of ``labels_list`` (its location).
      2. L2-normalise the embeddings; optionally reduce them with PCA.
      3. Project to 3-D with UMAP for the two scatter plots.
      4. Cluster, then compute internal metrics (silhouette, DBI, CH) and
         external metrics against the per-file class labels (ARI, AMI,
         Hungarian accuracy).
      5. Add the k-scan, seed-stability, retrieval@k and violin plots.
      6. Emit one block per cluster: location/class mix, embedding metrics,
         spectrogram-consistency scores and up to four exemplars closest to
         the centroid.

    Args:
        encoder_ckpt_path: Encoder checkpoint used for embedding.
        dataset_paths: Directories containing ``.npy`` files.
        labels_list: One location label per entry of ``dataset_paths``.
        cluster_method: Algorithm name understood by ``choose_clusterer``.
        n_clusters: Number of clusters for the main clustering.
        subset_fraction, subset_seed, subset_strategy, wave_policy: Forwarded
            to ``extract_embeddings_from_specdir``.
        ks_scan: Values of k for the k-scan plots.
        stability_seeds, stability_subsample: Settings for the stability plot.
        sr_signal: Sampling rate used when rendering waveform spectrograms.
        pca_dim: If set and smaller than the embedding width, cluster on a PCA
            projection of this size instead of the raw normalised embeddings.

    Returns:
        str: HTML for one ``<div class='section'>`` block.

    Raises:
        RuntimeError: if a dataset yields a different number of embeddings
            than file paths, since every later step pairs them by position.
    """
    embeddings_all, loc_labels_all, class_labels_all, file_paths_all = [], [], [], []

    for path, loc_label in zip(dataset_paths, labels_list):
        H, chosen_files = extract_embeddings_from_specdir(
            encoder_ckpt_path, path,
            batch_size=128, n_workers=0,
            subset_fraction=subset_fraction,
            subset_seed=subset_seed,
            subset_strategy=subset_strategy,
            wave_policy=wave_policy
        )
        if H.shape[0] != len(chosen_files):
            # Embeddings, paths, location labels and class labels are all
            # matched by row index below; a mismatch would mislabel samples.
            raise RuntimeError(
                f"{path}: got {H.shape[0]} embeddings for {len(chosen_files)} files; "
                "cannot align embeddings with file paths"
            )
        embeddings_all.append(H)
        file_paths_all.extend(chosen_files)
        loc_labels_all.extend([loc_label] * H.shape[0])

        full_files_sorted = list_audio_npy_files(path)
        cls_full = load_class_labels_if_any(path, count=len(full_files_sorted))
        name_to_label = {os.path.basename(p): cls_full[i] for i, p in enumerate(full_files_sorted)}
        class_labels_all.extend([name_to_label[os.path.basename(p)] for p in chosen_files])

    # Stack and normalize (Euclidean on L2-normalized ≈ cosine)
    X = np.vstack(embeddings_all).astype(np.float32)
    Xn = normalize(X, norm="l2", axis=1)
    Xc = Xn
    if (pca_dim is not None) and (Xn.shape[1] > pca_dim):
        Xc = PCA(n_components=int(pca_dim), random_state=42).fit_transform(Xn)

    location_labels = np.array(loc_labels_all, dtype=object)
    class_labels = np.array(class_labels_all, dtype=object)
    original_paths = np.array(file_paths_all, dtype=object)

    # ---- UMAP (cosine) for viz (by clusters & by GT)
    reducer = UMAP(n_components=3, n_neighbors=15, min_dist=0.1, metric="cosine", random_state=42)
    t_umap0 = time.perf_counter()
    proj_3d = reducer.fit_transform(Xn)
    t_umap1 = time.perf_counter()
    print(f"[INFO] UMAP done in {t_umap1 - t_umap0:.2f}s")

    # ---- clustering on Xc
    print(f"[INFO] Clustering with {cluster_method}, k={n_clusters}")
    cluster_labels, cluster_probs = choose_clusterer(cluster_method, Xc, n_clusters, seed=42)

    # ---- Global internal metrics
    uniq = np.unique(cluster_labels)
    valid_for_metrics = (len(uniq) > 1) and (Xc.shape[0] > len(uniq))

    sil = _safe_metric(silhouette_score, Xc, cluster_labels, metric="euclidean") if valid_for_metrics else float("nan")
    dbi = _safe_metric(davies_bouldin_score, Xc, cluster_labels) if valid_for_metrics else float("nan")
    ch  = _safe_metric(calinski_harabasz_score, Xc, cluster_labels) if valid_for_metrics else float("nan")

    # ---- External metrics vs any labels.npy
    def _safe_external(fn, y_true, y_pred):
        try:
            return float(fn(y_true, y_pred)) if valid_for_metrics else float("nan")
        except Exception:
            return float("nan")
    ari  = _safe_external(adjusted_rand_score, class_labels, cluster_labels)
    ami  = _safe_external(adjusted_mutual_info_score, class_labels, cluster_labels)
    hacc = _safe_external(hungarian_accuracy, class_labels, cluster_labels)

    print(f"[INFO] Silhouette (↑): {sil if sil==sil else 'nan'}")
    print(f"[INFO] Davies–Bouldin (↓): {dbi if dbi==dbi else 'nan'}")
    print(f"[INFO] Calinski–Harabasz (↑): {ch if ch==ch else 'nan'}")
    print(f"[INFO] ARI (↑): {ari if ari==ari else 'nan'}")
    print(f"[INFO] AMI (↑): {ami if ami==ami else 'nan'}")
    print(f"[INFO] Hungarian Accuracy (↑): {hacc if hacc==hacc else 'nan'}")

    title_txt = (
        f"{cluster_method.capitalize()} @ k={n_clusters} "
        f"(Sil={sil:.3f} | DBI={dbi:.3f} | CH={ch:.1f} | "
        f"ARI={ari:.3f} | AMI={ami:.3f} | H-Acc={hacc:.3f})"
    )

    # UMAP colored by cluster
    umap_by_cluster = px.scatter_3d(
        x=proj_3d[:, 0], y=proj_3d[:, 1], z=proj_3d[:, 2],
        color=[str(c) for c in cluster_labels],
        symbol=location_labels,
        hover_data={"Cluster": cluster_labels, "GT": class_labels},
        title=title_txt + " • colored by cluster",
        opacity=0.85, height=700
    )
    umap_cluster_html = pio.to_html(umap_by_cluster, include_plotlyjs="cdn", full_html=False)

    # UMAP colored by GT labels (object-safe)
    gt_colors = class_labels.astype(str)
    umap_by_gt = px.scatter_3d(
        x=proj_3d[:, 0], y=proj_3d[:, 1], z=proj_3d[:, 2],
        color=gt_colors,
        symbol=location_labels,
        hover_data={"Cluster": cluster_labels, "GT": class_labels},
        title=title_txt + " • colored by GT",
        opacity=0.85, height=700
    )
    umap_gt_html = pio.to_html(umap_by_gt, include_plotlyjs=False, full_html=False)

    # ---- k-scan
    kscan_arr, kscan_imgs = _k_scan_curves(Xc, cluster_method, ks=tuple(ks_scan))
    kscan_html = "".join(f'<div class="col-md-4 mb-3">{img}</div>' for img in kscan_imgs)

    # ---- stability vs seed
    stability_stats, stability_img = _stability_vs_seed(Xc, cluster_method, n_clusters,
                                                        seeds=tuple(stability_seeds),
                                                        subsample=int(stability_subsample))

    # ---- retrieval@k (if any labels)
    retrieval_stats, retrieval_img = _retrieval_at_k(Xn, class_labels, ks=(1,5,10))

    # ---- per-cluster violins
    sil_samples, dists, violin_sil_html, violin_dist_html = _per_cluster_violins(Xc, cluster_labels)

    # ---- Per-cluster sections + spectrogram consistency
    cluster_html = ""
    cluster_blocks = []
    base_for_entropy = len(set(location_labels))

    spec_cos_scores, spec_dtw_scores = [], []

    for c in sorted(set(cluster_labels)):
        idxs = np.where(cluster_labels == c)[0]
        if idxs.size == 0:
            continue

        cluster_counts = Counter(location_labels[idxs])
        class_counts = Counter(class_labels[idxs])

        metrics = evaluate_cluster_metrics(Xc, idxs, location_labels, location_entropy_base=base_for_entropy)

        spec_cos = _avg_intra_cluster_spec_cosine(original_paths, idxs, fs=sr_signal, max_samples=50)
        spec_dtw = _avg_intra_cluster_spec_dtw(original_paths, idxs, fs=sr_signal, max_samples=25)

        if np.isfinite(spec_cos): spec_cos_scores.append(spec_cos)
        if np.isfinite(spec_dtw): spec_dtw_scores.append(spec_dtw)

        meta_html = "<p><strong>Location Distribution:</strong></p><ul>" + "".join(
            f"<li><b>{loc}</b>: {count} ({count/len(idxs):.1%})</li>" for loc, count in cluster_counts.items()
        ) + "</ul>"

        meta_html += "<p><strong>Class Distribution:</strong></p><ul>" + "".join(
            f"<li>{cls}: {count}</li>" for cls, count in class_counts.items()
        ) + "</ul>"

        meta_html += f"""
        <p><strong>Cluster Metrics (Embedding Space):</strong></p>
        <ul>
            <li>Size: {len(idxs)}</li>
            <li>Intra-Cluster Variance: {metrics['variance']:.4f}</li>
            <li>Mean Cosine Similarity (Embeddings): {metrics['mean_sim']:.4f}</li>
            <li>Location Entropy: {metrics['entropy']:.3f}</li>
            <li>Composite Quality Score: {metrics['quality']:.4f}</li>
            <li><strong>Novelty Score:</strong> {metrics['novelty']:.4f}</li>
        </ul>
        <p><strong>Spectrogram Consistency (Signal Space):</strong></p>
        <ul>
            <li>Mean Spectrogram Cosine (↑): {spec_cos if np.isfinite(spec_cos) else float('nan'):.4f}</li>
            <li>Mean Spectrogram DTW (↓): {spec_dtw if np.isfinite(spec_dtw) else float('nan'):.2f}</li>
        </ul>
        """

        # Nearest-to-center exemplars (4 images)
        center = np.mean(Xc[idxs], axis=0, keepdims=True)
        distances = np.linalg.norm(Xc[idxs] - center, axis=1)
        sorted_indices = np.argsort(distances)
        sampled_idxs = idxs[sorted_indices[:min(4, len(sorted_indices))]]

        imgs = []
        for i, chosen_idx in enumerate(sampled_idxs):
            try:
                x = np.load(original_paths[chosen_idx], mmap_mode="r")
                title = f"#{i+1} | {location_labels[chosen_idx]} | Class {class_labels[chosen_idx]}"
                if x.ndim == 1:
                    imgs.append(generate_spectrogram_base64(x.astype(np.float32), fs=sr_signal, title=title))
                else:
                    S = np.asarray(x, dtype=np.float32)
                    # Accept plain (F, T) spectrograms as well as (1, F, T).
                    # An unconditional squeeze(0) fails on (F, T) input.
                    if S.ndim == 3 and S.shape[0] == 1:
                        S = S[0]
                    if S.ndim != 2:
                        raise ValueError(f"unsupported array shape {x.shape}")
                    fig, ax = plt.subplots(figsize=(8, 3))
                    ax.imshow(S, aspect='auto', origin='lower', cmap='viridis')
                    ax.set_title(title); ax.set_xlabel("Frames"); ax.set_ylabel("Mel bins")
                    imgs.append(_fig_to_img_html(fig))
            except Exception as e:
                imgs.append(f"<p class='text-danger'>Error loading sample {i+1}: {e}</p>")

        carousel_html = make_carousel(c, "all", imgs)
        block = f"<div class='col-md-6 mb-4'><h4>Cluster {c}</h4>{meta_html}{carousel_html}</div>"
        cluster_blocks.append(block)

    for i in range(0, len(cluster_blocks), 2):
        cluster_html += "<div class='row'>" + "".join(cluster_blocks[i:i+2]) + "</div>"

    # Global spectrogram-consistency summary across clusters
    global_spec_cos = float(np.mean(spec_cos_scores)) if len(spec_cos_scores) else float("nan")
    global_spec_dtw = float(np.mean(spec_dtw_scores)) if len(spec_dtw_scores) else float("nan")

    summary_card = f"""
    <div class='row mb-4'>
      <div class='col-md-12'>
        <div class='alert alert-info' role='alert'>
          <h5 class='mb-2'>Spectrogram Consistency (Cluster Averages)</h5>
          <ul class='mb-0'>
            <li><strong>Mean Spectrogram Cosine</strong>: {global_spec_cos if np.isfinite(global_spec_cos) else float('nan'):.4f} (↑ better)</li>
            <li><strong>Mean Spectrogram DTW</strong>: {global_spec_dtw if np.isfinite(global_spec_dtw) else float('nan'):.2f} (↓ better)</li>
          </ul>
          <small>Cosine on standardized spectrograms (z-norm, padded/cropped). DTW optional & subsampled.</small>
        </div>
      </div>
    </div>
    """

    # Overview row with UMAPs and key plots
    overview_row = f"""
    <div class="row mb-4">
      <div class="col-md-6">{umap_cluster_html}</div>
      <div class="col-md-6">{umap_gt_html}</div>
    </div>
    <div class="row mb-3">{kscan_html}</div>
    <div class="row mb-3">
      <div class="col-md-6">{stability_img}</div>
      <div class="col-md-6">{retrieval_img if retrieval_img else "<p>No labeled subset found for retrieval@k.</p>"}</div>
    </div>
    <div class="row mb-3">
      <div class="col-md-12">{violin_sil_html}</div>
    </div>
    <div class="row mb-4">
      <div class="col-md-12">{violin_dist_html}</div>
    </div>
    """

    # Metrics mini-table
    metrics_table = f"""
    <div class="table-responsive mb-4">
      <table class="table table-sm table-bordered">
        <thead class="table-light"><tr>
          <th>Silhouette (↑)</th><th>Davies–Bouldin (↓)</th><th>Calinski–Harabasz (↑)</th>
          <th>ARI (↑)</th><th>AMI (↑)</th><th>Hungarian Acc. (↑)</th>
          <th>Stability Jaccard</th><th>Stability FMI</th>
        </tr></thead>
        <tbody><tr>
          <td>{sil:.3f}</td><td>{dbi:.3f}</td><td>{ch:.1f}</td>
          <td>{ari:.3f}</td><td>{ami:.3f}</td><td>{hacc:.3f}</td>
          <td>{stability_stats['jaccard']:.3f}</td><td>{stability_stats['fmi']:.3f}</td>
        </tr></tbody>
      </table>
    </div>
    """

    # Heading reflects the subset actually used (reads "2%" for the default 0.02).
    subset_pct = f"{subset_fraction * 100:g}%"

    return f"""
    <div class='section'>
      <h2>{cluster_method.capitalize()} Clustering Analysis ({subset_pct} subset)</h2>
      <p class="text-muted">Embeddings L2-normalized prior to clustering; Euclidean distance used (≈ cosine).</p>
      {metrics_table}
      {overview_row}
      {summary_card}
      {cluster_html}
    </div>
    """


# ------------------------------------------------------------------------------------
# Report generator
# ------------------------------------------------------------------------------------
def generate_full_report(encoder_ckpt_path, dataset_paths, labels_list,
                         cluster_method='kmeans', n_clusters=20, out_prefix="unsup_report",
                         subset_fraction=0.02, subset_seed=42, subset_strategy="random",
                         wave_policy="convert", ks_scan=(20,30,40,50,60,80,100),
                         stability_seeds=(0,1,2,3,4), stability_subsample=2000,
                         sr_signal=10000, pca_dim=None):
    """Build the standalone HTML clustering report and write it to disk.

    The analysis section is produced by ``analyze_all_unsupervised_to_html``;
    this function wraps it in a full HTML page (Bootstrap + Plotly loaded from
    CDNs) and saves it as ``{out_prefix}_{cluster_method}_subset{P}.html``,
    where ``P`` is the subset size as a whole-number percentage.

    Parameters
    ----------
    encoder_ckpt_path, wave_policy, ks_scan, stability_seeds,
    stability_subsample, sr_signal, pca_dim :
        Forwarded unchanged to ``analyze_all_unsupervised_to_html``.
    dataset_paths, labels_list : iterable
        Materialised into lists once and forwarded; ``labels_list`` is also
        shown in the page header ("Compared Locations").
    cluster_method : str
        Forwarded to the analysis and used in the output file name.
    n_clusters : int-like
        Coerced with ``int()`` before being forwarded.
    out_prefix : str
        Prefix of the output file name.
    subset_fraction : float
        Forwarded to the analysis; also rendered as a percentage in the page
        title, heading, header line and file name.
    subset_seed, subset_strategy :
        Forwarded to the analysis and echoed in the page header.

    Returns
    -------
    str
        Path of the HTML file that was written.
    """
    # Local import under an alias: the rendered page is a plain string, and the
    # module name `html` is deliberately not bound in this scope.
    from html import escape as _escape

    # Materialise once so one-shot iterables (generators) are still available
    # for the page header after being handed to the analysis.
    dataset_paths = list(dataset_paths)
    labels_list = list(labels_list)

    # Percentage of the data used. Rounding first removes binary-float noise
    # (0.29 * 100 == 28.999999999999996) that made int() report 28 instead of 29.
    subset_pct = round(float(subset_fraction) * 100.0, 6)
    subset_pct_label = f"{subset_pct:g}"   # human-readable, keeps e.g. "0.5"
    subset_pct_file = int(subset_pct)      # whole number used in the file name

    section_html = analyze_all_unsupervised_to_html(
        encoder_ckpt_path=encoder_ckpt_path,
        dataset_paths=dataset_paths,
        labels_list=labels_list,
        cluster_method=cluster_method,
        n_clusters=int(n_clusters),
        subset_fraction=subset_fraction,
        subset_seed=subset_seed,
        subset_strategy=subset_strategy,
        wave_policy=wave_policy,
        ks_scan=ks_scan,
        stability_seeds=stability_seeds,
        stability_subsample=stability_subsample,
        sr_signal=sr_signal,
        pca_dim=pca_dim
    )

    # Labels/strategy are free text: stringify and escape before embedding in HTML.
    locations_html = ', '.join(_escape(str(label)) for label in labels_list)
    strategy_html = _escape(str(subset_strategy))

    # NOTE(review): the `plotly-latest` CDN alias is documented by Plotly as no
    # longer updated; if the embedded figures rely on newer plotly.js features,
    # pin the script URL to the plotly.js version that generated them.
    page_html = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Unsupervised Clustering Report ({subset_pct_label}% subset)</title>
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
        <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
        <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
        <style>
            body {{ font-family: Arial, sans-serif; padding: 20px; background-color: #f5f5f5; }}
            h1 {{ color: #2c3e50; }}
            h2, h4 {{ color: #34495e; }}
            hr {{ border-top: 2px solid #bbb; margin-top: 24px; margin-bottom: 24px; }}
            .section {{ margin-bottom: 60px; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
            img {{ max-width: 100%; height: auto; }}
        </style>
    </head>
    <body>
      <h1 class='mb-3'>Unsupervised Latent Clustering Report ({subset_pct_label}% subset)</h1>
      <p><strong>Compared Locations:</strong> {locations_html}</p>
      <p><strong>Subset:</strong> {subset_pct_label}% • Strategy: {strategy_html} • Seed: {subset_seed}</p>
      <hr>
      {section_html}
    </body></html>
    """
    output_path = f"{out_prefix}_{cluster_method}_subset{subset_pct_file}.html"
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(page_html)
    return output_path


# ------------------------------------------------------------------------------------
# Example usage (runs when this cell/script is executed as __main__)
# ------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Example run: embed a 2% random subset of the preprocessed dataset with the
    # saved encoder checkpoint, cluster it, and write the HTML report.
    encoder_ckpt_path = "reef_ssl_precomp_ckpt.pth"
    dataset_paths = ["/notebooks/dataset_preprocessed"]
    labels_list = ["Full"]
    cluster_method = 'kmeans'     # 'kmeans' | 'gmm' | 'agglomerative'
    n_clusters = 40
    report_path = generate_full_report(
        encoder_ckpt_path=encoder_ckpt_path,
        dataset_paths=dataset_paths,
        labels_list=labels_list,
        cluster_method=cluster_method,
        n_clusters=n_clusters,
        # Derive the prefix from the settings actually used so the file name
        # cannot drift from the configuration (it used to say "k60" while the
        # run clustered with k=40).
        out_prefix=f"SimCLR_{cluster_method}_k{n_clusters}",
        subset_fraction=0.02,
        subset_seed=45,
        subset_strategy="random",
        wave_policy="convert",
        ks_scan=(2,5,10,20,30,40,50,60,80,100),
        stability_seeds=(0,1,2,3,4),
        stability_subsample=2000,
        sr_signal=10000,
        pca_dim=None,  # e.g., 64 to reduce dimensionality before clustering
    )
    print("✅ Report saved to:", report_path)

[WARN] Could not import from your_training_file: No module named 'your_training_file'
[INFO] Scanning: /notebooks/dataset_preprocessed
[INFO] Found 413272 candidate .npy files
[INFO] Subset size: 8266 (fraction=0.02)
[INFO] Encoder loaded on cuda • CUDA available: True
[INFO] DataLoader ready: 8266 samples • batch_size=128 • workers=0
[INFO] Starting embedding pass on cuda (AMP=True)
[INFO] Batches: 1 • Seen: 128 • 16.9 items/s
[INFO] Batches: 10 • Seen: 1280 • 27.2 items/s
[INFO] Batches: 20 • Seen: 2560 • 39.0 items/s
[INFO] Batches: 30 • Seen: 3840 • 47.0 items/s
[INFO] Batches: 40 • Seen: 5120 • 55.8 items/s
[INFO] Batches: 50 • Seen: 6400 • 62.2 items/s
[INFO] Batches: 60 • Seen: 7680 • 63.7 items/s
[INFO] Embedding done: 8266 items in 127.34s



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



[INFO] UMAP done in 9.83s
[INFO] Clustering with kmeans, k=40
[INFO] Silhouette (↑): 0.2531849145889282
[INFO] Davies–Bouldin (↓): 1.229338827427351
[INFO] Calinski–Harabasz (↑): 1968.1967613655538
[INFO] ARI (↑): 0.06729750079531047
[INFO] AMI (↑): 0.17398020326790878
[INFO] Hungarian Accuracy (↑): 0.15073796273893056
✅ Report saved to: SimCLR_kmeans_k60_kmeans_subset2.html
