Dependencies
-------------------

In [1]:
!pip install librosa
!pip install umap-learn
!pip install plotly
!pip install hdbscan

Collecting librosa
  Downloading librosa-0.11.0-py3-none-any.whl.metadata (8.7 kB)
Collecting audioread>=2.1.9 (from librosa)
  Downloading audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting numba>=0.51.0 (from librosa)
  Downloading numba-0.62.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)
Collecting soundfile>=0.12.1 (from librosa)
  Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)
Collecting pooch>=1.1 (from librosa)
  Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)
Collecting soxr>=0.3.2 (from librosa)
  Downloading soxr-1.0.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.6 kB)
Collecting msgpack>=1.0 (from librosa)
  Downloading msgpack-1.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)
Collecting llvmlite<0.46,>=0.45.0dev0 (from numba>=0.51.0->librosa)
  Downloading llvmlite-0.45.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64

Report generation code
------------------------

In [8]:
# vae_runner.py

import os, json, math, random, base64
from io import BytesIO
from typing import List, Tuple, Sequence, Dict, Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset

import torchaudio
from sklearn.mixture import GaussianMixture
from sklearn.metrics import (
    adjusted_rand_score,
    adjusted_mutual_info_score,
    silhouette_score,
    davies_bouldin_score,
    calinski_harabasz_score,
)
from scipy.optimize import linear_sum_assignment
from scipy.signal import spectrogram, get_window

# Safe optional deps for visualization
try:
    import umap.umap_ as umap
    import plotly.express as px
    import plotly.io as pio
    import matplotlib.pyplot as plt
    _HAVE_VIS = True
except Exception:
    _HAVE_VIS = False

# ----------------------------
# Repro & helpers
# ----------------------------
def _set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

def _atomic_save(state: dict, path: str):
    """Save to a temp file then replace, to avoid partial/corrupt checkpoints."""
    tmp = path + ".tmp"
    torch.save(state, tmp)
    os.replace(tmp, path)

def _load_resume_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: Optional[torch.cuda.amp.GradScaler],
    ckpt_path: str,
    device: torch.device
):
    """Load a training checkpoint and restore model/optimizer/scaler/epoch/best_val and RNG."""
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    if scaler is not None and "scaler_state_dict" in ckpt and scaler.is_enabled():
        scaler.load_state_dict(ckpt["scaler_state_dict"])
    start_epoch = int(ckpt.get("epoch", 0)) + 1
    best_val = float(ckpt.get("best_val", math.inf))
    # Restore RNG (optional)
    try:
        if "rng_state" in ckpt:
            torch.set_rng_state(ckpt["rng_state"]["torch"])
            random.setstate(ckpt["rng_state"]["python"])
            np.random.set_state(ckpt["rng_state"]["numpy"])
    except Exception:
        pass
    return start_epoch, best_val

# ======================================================
# One-click precompute helpers
# ======================================================
def _mel_pack_dir(audio_dir: str) -> str:
    """
    Returns the sibling directory where precomputed log-mels (.pt) live.
    E.g., /data/siteA  ->  /data/siteA__mels64_fft1024_h512
    """
    return f"{audio_dir}__mels64_fft1024_h512"

def precompute_logmels_if_needed(
    audio_dir: str,
    *,
    sr: int = 10000,
    n_mels: int = 64,
    n_fft: int = 1024,
    hop_length: int = 512,
    dtype: torch.dtype = torch.float16,
) -> str:
    """
    If a precomputed mel pack does not exist, build it once.
    Returns the mel_dir path.

    The pack counts as complete only when every ``<n>.npy`` clip in
    ``audio_dir`` has a matching ``<n>.pt`` in the pack and ``labels.npy`` has
    been copied. An interrupted run is resumed: tensors that already exist are
    skipped, and each new tensor is written to a temp file and then renamed so
    a crash cannot leave a truncated ``.pt`` that looks finished.

    Args:
        audio_dir: Directory with integer-named ``.npy`` waveforms and ``labels.npy``.
        sr, n_mels, n_fft, hop_length: Passed to ``torchaudio.transforms.MelSpectrogram``.
            NOTE(review): the pack directory name comes from ``_mel_pack_dir`` and
            does not encode these values, so non-default settings share the
            default pack's folder.
        dtype: Storage dtype of the saved log-mel tensors.

    Raises:
        FileNotFoundError: ``labels.npy`` is missing from ``audio_dir``.
    """
    mel_dir = _mel_pack_dir(audio_dir)
    labels_src = os.path.join(audio_dir, "labels.npy")
    if not os.path.exists(labels_src):
        raise FileNotFoundError(f"labels.npy missing in {audio_dir}")

    os.makedirs(mel_dir, exist_ok=True)
    labels_dst = os.path.join(mel_dir, "labels.npy")

    raw_files = sorted(
        [f for f in os.listdir(audio_dir) if f.endswith(".npy") and f != "labels.npy"],
        key=lambda x: int(os.path.splitext(x)[0])
    )
    # Require an exact match: a "nearly complete" pack would later fail the
    # mel-count == label-count check when the dataset is constructed.
    expected_pt = {os.path.splitext(f)[0] + ".pt" for f in raw_files}
    existing_pt = {f for f in os.listdir(mel_dir) if f.endswith(".pt")}
    if expected_pt <= existing_pt and os.path.exists(labels_dst):
        return mel_dir

    print(f"Precomputing log-mels into: {mel_dir} (this is one-time)")
    np.save(labels_dst, np.load(labels_src))

    mel_tfm = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
    )
    for f in raw_files:
        out_pt = os.path.join(mel_dir, os.path.splitext(f)[0] + ".pt")
        if os.path.exists(out_pt):
            continue
        y = np.load(os.path.join(audio_dir, f)).astype(np.float32)
        # Peak-normalize; the epsilon keeps all-zero clips from dividing by zero.
        m = np.max(np.abs(y)) + 1e-8
        y = y / m
        y_t = torch.from_numpy(y).unsqueeze(0)  # [1, T]
        with torch.no_grad():
            logmel = torch.log(mel_tfm(y_t) + 1e-8).to(dtype)
        # "<n>.pt.tmp" does not end in ".pt", so a leftover temp file is never
        # mistaken for a finished tensor by the completeness check above.
        tmp_pt = out_pt + ".tmp"
        torch.save(logmel.contiguous(), tmp_pt)
        os.replace(tmp_pt, out_pt)

    print("‚úÖ Log-mel precompute complete.")
    return mel_dir

# ----------------------------
# Datasets
# ----------------------------
class SpectrogramDataset(Dataset):
    """
    Dataset over a folder of integer-named 1D ``.npy`` waveforms plus ``labels.npy``.

    Each item is ``(logmel, label)`` where ``logmel`` is the natural log of a
    torchaudio mel spectrogram (n_fft=1024, hop=512) of the peak-normalized
    clip, shaped [1, n_mels, T]. Computed on the fly (slower than the
    precomputed pack); results are optionally memoized per index.
    """
    def __init__(self, audio_dir: str, sr: int = 10000, n_mels: int = 64, cache: bool = True):
        clip_names = [
            name for name in os.listdir(audio_dir)
            if name.endswith(".npy") and name != "labels.npy"
        ]
        # Numeric (not lexicographic) order so position i lines up with labels[i].
        clip_names.sort(key=lambda name: int(os.path.splitext(name)[0]))
        self.audio_paths = [os.path.join(audio_dir, name) for name in clip_names]

        labels_path = os.path.join(audio_dir, "labels.npy")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"labels.npy missing at {labels_path}")
        self.labels = np.load(labels_path)

        n_audio, n_labels = len(self.audio_paths), len(self.labels)
        if n_audio != n_labels:
            raise ValueError(f"{n_audio} audio files != {n_labels}")

        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=1024, hop_length=512, n_mels=n_mels
        )
        self.cache_enabled = cache
        # One slot per clip; None marks "not computed yet".
        self.cache = [None] * len(self.audio_paths) if cache else None

    def __len__(self):
        return len(self.audio_paths)

    def _compute_logmel(self, idx: int) -> torch.Tensor:
        """Load clip ``idx``, peak-normalize it and return its log-mel tensor."""
        waveform = np.load(self.audio_paths[idx]).astype(np.float32)
        peak = np.max(np.abs(waveform)) + 1e-8
        wave_t = torch.from_numpy(waveform / peak).unsqueeze(0)  # [1, T]
        return torch.log(self.mel_transform(wave_t) + 1e-8)      # [1, n_mels, Tm]

    def __getitem__(self, idx):
        logmel = self.cache[idx] if self.cache_enabled else None
        if logmel is None:
            logmel = self._compute_logmel(idx)
            if self.cache_enabled:
                self.cache[idx] = logmel
        return logmel, int(self.labels[idx])

class PrecomputedMelDataset(Dataset):
    """
    Dataset over a mel pack: integer-named ``.pt`` tensors plus ``labels.npy``.

    Items are ``(tensor, label)``; tensors are returned exactly as stored
    (the pack writer in this file saves log-mels, float16 by default).
    """
    def __init__(self, mel_dir: str, cache: bool = False):
        pt_names = [name for name in os.listdir(mel_dir) if name.endswith(".pt")]
        # Numeric order keeps position i aligned with labels[i].
        pt_names.sort(key=lambda name: int(os.path.splitext(name)[0]))
        self.paths = [os.path.join(mel_dir, name) for name in pt_names]

        labels_path = os.path.join(mel_dir, "labels.npy")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"labels.npy missing at {labels_path}")
        self.labels = np.load(labels_path)

        n_mels_files, n_labels = len(self.paths), len(self.labels)
        if n_mels_files != n_labels:
            raise ValueError(f"{n_mels_files} mel files != {n_labels} labels")

        # Optional in-memory memo: one slot per file, None until first access.
        self.cache = [None] * len(self.paths) if cache else None

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        memoize = self.cache is not None
        item = self.cache[idx] if memoize else None
        if item is None:
            item = torch.load(self.paths[idx], map_location="cpu")
            if memoize:
                self.cache[idx] = item
        return item, int(self.labels[idx])

# ----------------------------
# Collate: pad/crop time dim to target_T
# ----------------------------
def _pad_or_crop_time(x: torch.Tensor, target_T: int) -> torch.Tensor:
    """
    x: [B, 1, 64, Tvar] -> [B, 1, 64, target_T]
    """
    B, C, Freq, T = x.shape
    if T == target_T:
        return x
    if T > target_T:
        start = (T - target_T) // 2
        return x[..., start:start+target_T]
    total_pad = target_T - T
    left = total_pad // 2
    right = total_pad - left
    return F.pad(x, (left, right), mode="constant", value=0.0)

def _collate_fixed_T(batch, target_T: int):
    xs, ys = zip(*batch)
    # Fast path: stack if all Ts match target_T
    try:
        Ts = {t.shape[-1] for t in xs}
        if len(Ts) == 1 and list(Ts)[0] == target_T:
            return torch.stack(xs, dim=0), torch.tensor(ys, dtype=torch.long)
    except Exception:
        pass
    max_T = max(t.shape[-1] for t in xs)
    stacked = torch.zeros(len(xs), 1, 64, max_T, dtype=xs[0].dtype)
    for i, t in enumerate(xs):
        T = t.shape[-1]
        stacked[i, :, :, :T] = t
    fixed = _pad_or_crop_time(stacked, target_T)
    labels = torch.tensor(ys, dtype=torch.long)
    return fixed, labels

# ----------------------------
# VAE
# ----------------------------
class ConvVAE(nn.Module):
    """
    Convolutional VAE over single-channel spectrogram "images".

    Encoder: three stride-2 4x4 convolutions (1->32->64->128) followed by
    linear heads for ``mu`` and ``logvar``. Decoder: a linear layer back to the
    encoder feature map, three stride-2 transposed convolutions ending in a
    Sigmoid, and a bilinear resize to exactly ``input_shape[1:]``.

    NOTE(review): the final Sigmoid confines reconstructions to (0, 1), while
    the datasets in this file yield raw log-mel values; confirm whether inputs
    are meant to be rescaled before training.
    """
    def __init__(self, latent_dim=64, input_shape=(1, 64, 63)):
        super().__init__()
        self.input_shape = input_shape

        enc_layers = []
        for c_in, c_out in ((1, 32), (32, 64), (64, 128)):
            enc_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            enc_layers.append(nn.ReLU())
        self.encoder_cnn = nn.Sequential(*enc_layers)

        # Probe the conv stack once to learn the feature-map size for this input shape.
        with torch.no_grad():
            probe = self.encoder_cnn(torch.zeros(1, *input_shape))
        self.encoder_output_shape = probe.shape[1:]  # [C,H,W]
        self.flattened_dim = probe.numel()

        self.encoder = nn.Sequential(self.encoder_cnn, nn.Flatten())
        self.fc_mu = nn.Linear(self.flattened_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_dim, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, self.flattened_dim)

        dec_layers = [nn.Unflatten(1, self.encoder_output_shape)]
        for c_in, c_out, activation in ((128, 64, nn.ReLU), (64, 32, nn.ReLU), (32, 1, nn.Sigmoid)):
            dec_layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            dec_layers.append(activation())
        # Transposed convs only produce multiples of the feature-map size; resize to the exact input size.
        dec_layers.append(nn.Upsample(size=self.input_shape[1:], mode='bilinear', align_corners=False))
        self.decoder = nn.Sequential(*dec_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Map latent codes back to ``input_shape``-sized reconstructions."""
        hidden = self.fc_decode(z)
        return self.decoder(hidden)

    def forward(self, x):
        """Return ``(x_recon, mu, logvar)`` for batch ``x``."""
        mu, logvar = self.encode(x)
        x_recon = self.decode(self.reparameterize(mu, logvar))
        return x_recon, mu, logvar

# ----------------------------
# Loss & metrics
# ----------------------------
def _vae_loss(x_recon, x, mu, logvar):
    recon_loss = F.mse_loss(x_recon, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div

def _hungarian_accuracy(y_true, y_pred):
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(len(y_pred)):
        w[y_pred[i], y_true[i]] += 1
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return sum(w[i, j] for i, j in zip(row_ind, col_ind)) / len(y_pred)

def _internal_indices(X: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
    out = {"silhouette": float("nan"),
           "davies_bouldin": float("nan"),
           "calinski_harabasz": float("nan")}
    n = len(labels)
    k = len(np.unique(labels))
    # need at least 2 clusters and fewer than n samples,
    # and at least one cluster with >1 sample for silhouette
    if k <= 1 or k >= n:
        return out
    try:
        out["silhouette"] = float(silhouette_score(X, labels))
    except Exception:
        pass
    try:
        out["davies_bouldin"] = float(davies_bouldin_score(X, labels))
    except Exception:
        pass
    try:
        out["calinski_harabasz"] = float(calinski_harabasz_score(X, labels))
    except Exception:
        pass
    return out

# ======================================================
# Fixed split helper — do the stratified 80/20 split OUTSIDE training
# ======================================================
def make_fixed_train_test_indices(dataset_path: str, *, test_size: float = 0.2, random_state: int = 42) -> Tuple[np.ndarray, np.ndarray]:
    """Stratified, seeded train/test split over the clips in ``dataset_path``.

    Indices refer to the clips in numeric filename order (``0.npy``, ``1.npy``,
    ...), which is the order the dataset classes in this file use, and are
    stratified on ``labels.npy``.

    Returns:
        ``(train_idx, test_idx)`` as NumPy integer arrays.
    """
    clip_names = [
        name for name in os.listdir(dataset_path)
        if name.endswith(".npy") and name != "labels.npy"
    ]
    # Sorting by the integer stem also rejects unexpected non-numeric file names.
    clip_names.sort(key=lambda name: int(os.path.splitext(name)[0]))
    n_clips = len(clip_names)

    labels = np.load(os.path.join(dataset_path, "labels.npy"))
    assert n_clips == len(labels), "Mismatch between files and labels"

    from sklearn.model_selection import train_test_split
    train_idx, test_idx = train_test_split(
        np.arange(n_clips),
        test_size=test_size,
        stratify=labels,
        random_state=random_state,
    )
    return np.asarray(train_idx), np.asarray(test_idx)

# ----------------------------
# Training (uses ONLY provided train indices)
# ----------------------------
def run_training(
    audio_dir: str,
    *,
    train_indices: Sequence[int],
    val_split: float = 0.1,
    epochs: int = 50,
    batch_size: int = 256,
    lr: float = 1e-3,
    latent_dim: int = 64,
    save_path: str = "vae_model.pth",
    seed: int = 42,
    amp: bool = True,
    resume_from: Optional[str] = None,
    checkpoint_path: str = "vae_train.ckpt",
    precompute_mels: bool = True,
    num_workers: Optional[int] = None,
    grad_accum_steps: int = 1,
    early_stop_patience: int = 10,
):
    """
    Train a ConvVAE on the samples selected by ``train_indices``.

    A seeded ``val_split`` fraction of those samples is held out for model
    selection. The best-validation weights are written to ``save_path`` and a
    full resumable checkpoint to ``checkpoint_path`` after every epoch.

    Args:
        audio_dir: Folder with integer-named ``.npy`` clips and ``labels.npy``.
        train_indices: Dataset indices allowed for training/validation.
        val_split: Fraction of ``train_indices`` held out (at least 1 sample).
        epochs: Last epoch number to run (epochs are 1-based).
        batch_size, lr, latent_dim, seed: Usual hyper-parameters.
        save_path: Destination of the best ``state_dict``.
        amp: Use CUDA autocast + GradScaler when a GPU is available.
        resume_from: Optional training checkpoint to continue from.
        checkpoint_path: Where the per-epoch training checkpoint is written.
        precompute_mels: Build/use the precomputed mel pack; falls back to
            on-the-fly spectrograms if that fails.
        num_workers: DataLoader workers (``None`` -> half the CPUs, min 2;
            ``0`` loads in the main process).
        grad_accum_steps: Batches accumulated per optimizer step.
        early_stop_patience: Epochs without validation improvement before stopping.

    Returns:
        Dict with the best validation loss, per-epoch history, paths and run metadata.

    Raises:
        ValueError: ``val_split`` leaves no training samples.
    """
    _set_seed(seed)
    accum = max(1, grad_accum_steps)

    dataset_path_for_training = audio_dir
    use_precomputed = False
    if precompute_mels:
        try:
            dataset_path_for_training = precompute_logmels_if_needed(audio_dir)
            use_precomputed = True
        except Exception as e:
            print(f"‚ö†Ô∏è Precompute failed or skipped, using on-the-fly mel spec: {e}")

    full_dataset = (
        PrecomputedMelDataset(dataset_path_for_training, cache=False)
        if use_precomputed else SpectrogramDataset(audio_dir)
    )

    model = ConvVAE(latent_dim=latent_dim)
    target_T = model.input_shape[2]

    train_base = Subset(full_dataset, list(train_indices))
    val_size = max(1, int(round(val_split * len(train_base))))
    tr_size = len(train_base) - val_size
    if tr_size <= 0:
        raise ValueError(f"val_split too large for train size {len(train_base)}")

    g = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = random_split(train_base, [tr_size, val_size], generator=g)

    if num_workers is None:
        num_workers = max(2, (os.cpu_count() or 4) // 2)
    loader_kwargs = dict(
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=lambda b: _collate_fixed_T(b, target_T),
    )
    if num_workers > 0:
        # DataLoader rejects these two options when loading in the main process.
        loader_kwargs.update(persistent_workers=True, prefetch_factor=4)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)
    val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, **loader_kwargs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fast math
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

    # Make the model channels-last
    model = model.to(device).to(memory_format=torch.channels_last)
    # (keep torch.compile disabled to avoid cudagraph live-storage errors)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4, foreach=True)
    scaler = torch.cuda.amp.GradScaler(enabled=amp and device.type == "cuda")

    # Resume support
    start_epoch = 1
    best_val = math.inf
    if resume_from is not None and os.path.isfile(resume_from):
        try:
            start_epoch, best_val = _load_resume_checkpoint(model, optimizer, scaler, resume_from, device)
            print(f"üîÅ Resumed from '{resume_from}' at epoch {start_epoch} (best_val={best_val:.2f})")
        except Exception as e:
            print(f"‚ö†Ô∏è  Failed to resume from '{resume_from}': {e}. Starting fresh.")

    # Warmup + cosine, indexed by the absolute epoch so that
    #  * the first epoch gets a non-zero LR (a factor of 0 would waste it), and
    #  * a resumed run continues the schedule instead of restarting the warmup.
    warmup_epochs = max(1, epochs // 20)
    completed_epochs = start_epoch - 1
    def lr_lambda(step):
        e = completed_epochs + step  # 0-based index of the epoch about to run
        if e < warmup_epochs:
            return (e + 1) / float(warmup_epochs)
        t = (e - warmup_epochs) / max(1, (epochs - warmup_epochs))
        return 0.5 * (1 + math.cos(math.pi * t))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    history, bad = [], 0
    n_train_batches = len(train_loader)
    # Defined up front so the summary below is valid even when the loop body
    # never runs (e.g. resuming a run that already reached `epochs`).
    epoch = start_epoch - 1

    for epoch in range(start_epoch, epochs + 1):
        model.train()
        total_loss = 0.0
        optimizer.zero_grad(set_to_none=True)

        for i, (x, _) in enumerate(train_loader):
            x = x.to(device, non_blocking=True)
            #  Upcast any fp16 precomputed mels to fp32 to match model params
            if x.dtype != torch.float32:
                x = x.to(torch.float32)
            #  Ensure consistent layout (channels-last)
            x = x.contiguous(memory_format=torch.channels_last)

            with torch.cuda.amp.autocast(enabled=amp and device.type == "cuda"):
                x_recon, mu, logvar = model(x)
                loss = _vae_loss(x_recon, x, mu, logvar) / accum

            if scaler.is_enabled():
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step every `accum` batches, and also on the final batch so a
            # trailing partial accumulation group is applied, not discarded.
            is_last_batch = (i + 1) == n_train_batches
            if (i + 1) % accum == 0 or is_last_batch:
                if scaler.is_enabled():
                    scaler.step(optimizer); scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            total_loss += loss.item() * accum

        # Validation (sampled for speed: only the first max(1, n // 5) + 1 batches)
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            for j, (x, _) in enumerate(val_loader):
                if j > max(1, len(val_loader) // 5):
                    break
                x = x.to(device, non_blocking=True)
                if x.dtype != torch.float32:
                    x = x.to(torch.float32)
                x = x.contiguous(memory_format=torch.channels_last)
                x_recon, mu, logvar = model(x)
                val_loss += _vae_loss(x_recon, x, mu, logvar).item()

        history.append({"epoch": epoch, "train_loss_sum": total_loss, "val_loss_sum": val_loss})

        # Save BEST eval-only weights
        if val_loss < best_val - 1e-6:
            torch.save(model.state_dict(), save_path)
            best_val = val_loss
            bad = 0
            print(f" Epoch {epoch}: new best val {best_val:.2f} ‚Äî saved {save_path}")
        else:
            bad += 1

        # Always save a TRAINING CHECKPOINT (atomic overwrite) so we can resume later
        rng_state = {
            "torch": torch.get_rng_state(),
            "python": random.getstate(),
            "numpy": np.random.get_state(),
        }
        ckpt = {
            "epoch": epoch,
            "best_val": float(best_val),
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scaler_state_dict": scaler.state_dict() if scaler.is_enabled() else None,
            "config": {
                "audio_dir": audio_dir,
                "latent_dim": latent_dim,
                "batch_size": batch_size,
                "lr": lr,
                "seed": seed,
                "amp": amp,
                "train_indices_len": len(train_indices),
                "used_precomputed": bool(use_precomputed),
                "precompute_pack": dataset_path_for_training if use_precomputed else None,
            },
            "rng_state": rng_state,
        }
        _atomic_save(ckpt, checkpoint_path)
        print(f"üíæ Epoch {epoch}: checkpoint saved ‚Üí {checkpoint_path}")

        scheduler.step()

        # Early stopping
        if bad >= early_stop_patience:
            print(f"‚èπÔ∏è Early stopping at epoch {epoch}")
            break

    return {
        "best_val_loss_sum": float(best_val),
        "history": history,
        "save_path": save_path,
        "latent_dim": latent_dim,
        "seed": seed,
        "audio_dir": audio_dir,
        "train_samples": len(train_dataset),
        "val_samples": len(val_dataset),
        "last_checkpoint": checkpoint_path,
        "last_epoch": epoch,
        "used_precomputed": bool(use_precomputed),
        "precompute_pack": dataset_path_for_training if use_precomputed else None,
    }

def _preprocess_latents(latents: np.ndarray, *, l2: bool = True, pca_dim: Optional[int] = None, seed: int = 42):
    """Prepare latent vectors for clustering.

    Pipeline: per-feature standardization, then (if ``l2``) row-wise L2
    normalization, then (if ``pca_dim`` is set and smaller than the feature
    count) a whitened PCA projection to ``pca_dim`` components.
    """
    features = latents.astype(np.float32)
    features = StandardScaler().fit_transform(features)
    if l2:
        features = normalize(features, norm="l2", axis=1)

    wants_pca = pca_dim is not None and features.shape[1] > pca_dim
    if not wants_pca:
        return features

    reducer = PCA(n_components=int(pca_dim), whiten=True, random_state=seed)
    return reducer.fit_transform(features)

from sklearn.preprocessing import StandardScaler, normalize
from sklearn.decomposition import PCA
from sklearn.metrics.cluster import contingency_matrix

def _hungarian_from_contingency(y_true, y_pred) -> float:
    """Accuracy under the best one-to-one matching of predicted clusters to true classes."""
    table = contingency_matrix(y_true, y_pred)  # rows=true, cols=pred
    # The assignment solver minimizes, so turn counts into costs.
    cost = table.max() - table
    true_pos, pred_pos = linear_sum_assignment(cost)
    matched = table[true_pos, pred_pos].sum()
    return float(matched / table.sum())

# ======================================================
# Runner: INFERENCE / EVAL on FIXED TEST 20%
# ======================================================
def run_inference(
    audio_dir: str,
    *,
    test_indices: Sequence[int],
    checkpoint_path: str = "vae_model.pth",
    batch_size: int = 256,
    kmeans_k: Optional[int] = None,
    do_umap: bool = True,
    umap_components: int = 3,
    umap_html_path: str = "vae_test_umap.html",
    seed: int = 42,
    l2_normalize: bool = True,
    pca_dim: Optional[int] = None,   # e.g., 64 to match other pipeline
) -> Dict[str, Any]:
    """
    Load trained VAE, extract mu on FIXED TEST subset, preprocess (Std->L2->PCA optional),
    run KMeans with configurable k, and report ARI/AMI/H-Acc + Sil/DB/CH.

    The latent size is read from the checkpoint weights (``fc_mu.weight``), so
    models trained with a non-default ``latent_dim`` load correctly; the
    filename heuristic and the default of 64 are only fallbacks.

    Args:
        audio_dir: Raw audio folder; its precomputed mel pack is used when present.
        test_indices: Dataset indices to evaluate on.
        checkpoint_path: ``state_dict`` file written by training.
        batch_size: Inference batch size.
        kmeans_k: Number of clusters (default: number of distinct true labels).
        do_umap, umap_components, umap_html_path: Optional 2D/3D UMAP plot
            (skipped silently when the visualization packages are missing).
        seed: Seed for RNGs, KMeans, PCA and UMAP.
        l2_normalize, pca_dim: Latent preprocessing options.

    Returns:
        Dict with sample counts, shapes, metrics and preprocessing settings
        (plus ``umap_html_path`` when a plot was written).

    Raises:
        FileNotFoundError: checkpoint is missing.
        ValueError: ``test_indices`` is empty or ``umap_components`` is invalid.
    """
    _set_seed(seed)
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Fail fast on bad arguments instead of after encoding + clustering.
    want_umap = do_umap and _HAVE_VIS
    if want_umap and umap_components not in (2, 3):
        raise ValueError("umap_components must be 2 or 3")
    if len(test_indices) == 0:
        raise ValueError("test_indices is empty")

    # Prefer precomputed pack if it exists
    mel_dir = _mel_pack_dir(audio_dir)
    if os.path.isdir(mel_dir) and os.path.exists(os.path.join(mel_dir, "labels.npy")):
        full_dataset = PrecomputedMelDataset(mel_dir, cache=False)
        chosen_src = mel_dir
    else:
        full_dataset = SpectrogramDataset(audio_dir, cache=True)
        chosen_src = audio_dir

    # DEBUG: print label distribution for sanity
    y_all = getattr(full_dataset, "labels", None)
    if y_all is None:
        print("‚ùó Dataset has no labels attribute (this would break AMI/ARI).")
    else:
        vals, cnts = np.unique(y_all, return_counts=True)
        print(f"[INFO] Inference source dir: {chosen_src}")
        print("[INFO] Full label dist:", dict(zip(vals.tolist(), cnts.tolist())))
        tvals, tcnts = np.unique(y_all[np.asarray(test_indices)], return_counts=True)
        print("[INFO] TEST label dist:", dict(zip(tvals.tolist(), tcnts.tolist())))

    # Load the weights first: the row count of fc_mu.weight is the latent size
    # the model was trained with.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    fc_mu_weight = state_dict.get("fc_mu.weight") if isinstance(state_dict, dict) else None
    if fc_mu_weight is not None:
        latent_dim = int(fc_mu_weight.shape[0])
    else:
        latent_dim = model_latent_dim_from_ckpt(checkpoint_path) or 64

    # Build loader over fixed TEST indices
    model = ConvVAE(latent_dim=latent_dim)
    target_T = model.input_shape[2]
    test_subset = Subset(full_dataset, list(test_indices))
    num_workers = max(2, (os.cpu_count() or 4) // 2)
    test_loader = DataLoader(
        test_subset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True, persistent_workers=True,
        prefetch_factor=4,
        collate_fn=lambda b: _collate_fixed_T(b, target_T)
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(state_dict)
    model = model.to(device).to(memory_format=torch.channels_last).eval()

    # Extract TEST mu
    latents, true_labels = [], []
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device, non_blocking=True)
            if x.dtype != torch.float32:
                x = x.to(torch.float32)
            x = x.contiguous(memory_format=torch.channels_last)
            mu, _ = model.encode(x)
            latents.append(mu.cpu().numpy())
            true_labels.append(y.numpy())
    latents = np.concatenate(latents, axis=0)
    true_labels = np.concatenate(true_labels, axis=0)

    # === Preprocess mu like the other report ===
    Z = _preprocess_latents(latents, l2=l2_normalize, pca_dim=pca_dim, seed=seed)

    # K for clustering
    n_clusters = int(kmeans_k) if kmeans_k is not None else int(len(np.unique(true_labels)))

    # Cluster on Z (not raw mu)
    from sklearn.cluster import KMeans
    pred_labels = KMeans(n_clusters=n_clusters, n_init=20, random_state=seed).fit_predict(Z)

    # Internal (label-free) metrics; NaN for degenerate partitions.
    extra = _internal_indices(Z, pred_labels)

    # External metrics against the true labels.
    u_true = np.unique(true_labels).size
    u_pred = np.unique(pred_labels).size
    hung = _hungarian_from_contingency(true_labels, pred_labels)
    if u_true < 2 or u_pred < 2:
        # ARI/AMI are not informative when either side has a single group.
        ari = float("nan")
        ami = float("nan")
    else:
        ari = float(adjusted_rand_score(true_labels, pred_labels))
        ami = float(adjusted_mutual_info_score(true_labels, pred_labels))

    results: Dict[str, Any] = {
        "test_samples": int(len(true_labels)),
        "latents_shape": tuple(latents.shape),
        "processed_shape": tuple(Z.shape),
        "kmeans_k": int(n_clusters),
        "metrics": {
            "ari": float(ari),
            "ami": float(ami),
            "hungarian_accuracy": float(hung),
            "silhouette": float(extra["silhouette"]),
            "davies_bouldin": float(extra["davies_bouldin"]),
            "calinski_harabasz": float(extra["calinski_harabasz"]),
        },
        "checkpoint_path": checkpoint_path,
        "preprocess": {
            "l2_normalize": bool(l2_normalize),
            "pca_dim": int(pca_dim) if pca_dim is not None else None
        }
    }

    # Optional UMAP on processed Z
    if want_umap:
        reducer = umap.UMAP(n_components=umap_components, metric="cosine", random_state=seed)
        z_umap = reducer.fit_transform(Z)
        title_bits = [f"k={n_clusters}"]
        if l2_normalize: title_bits.append("L2")
        if pca_dim: title_bits.append(f"PCA{pca_dim}-whiten")
        title = "VAE Test Latents (" + ", ".join(title_bits) + f") | ARI {ari:.4f}, AMI {ami:.4f}, H-Acc {hung:.4f}"
        if umap_components == 3:
            fig = px.scatter_3d(x=z_umap[:,0], y=z_umap[:,1], z=z_umap[:,2],
                                color=true_labels.astype(str), title=title)
        else:
            fig = px.scatter(x=z_umap[:,0], y=z_umap[:,1],
                             color=true_labels.astype(str), title=title)
        fig.write_html(umap_html_path)
        print(f"Saved UMAP visualization to '{umap_html_path}'")
        results["umap_html_path"] = umap_html_path

    return results

def model_latent_dim_from_ckpt(_path: str) -> Optional[int]:
    """Best-effort guess of the latent size from a checkpoint *file name*.

    Looks for ``lat<digits>`` in the base name (e.g. ``vae_lat32.pth`` -> 32).
    Returns None when the pattern is absent or the path cannot be parsed.
    """
    import re

    try:
        found = re.search(r"lat(\d+)", os.path.basename(_path))
        if found is not None:
            return int(found.group(1))
    except Exception:
        pass  # unusable path value: treat as "unknown"
    return None

# ----------------------------
# Report helpers (unchanged behavior; safe UMAP usage)
# ----------------------------
def _ali_spec(x, fs):
    Lframe2 = 1000
    po = 80
    lov = int(np.ceil((po / 100) * Lframe2))
    taper = get_window('hann', Lframe2)
    Nfft = 2 ** (int(np.floor(np.log2(Lframe2))) + 2)
    f, t, s = spectrogram(x, fs=fs, window=taper, noverlap=lov, nfft=Nfft, mode='complex')
    as_ = np.abs(s)
    as_max = np.max(as_)
    sdb = 10 * np.log10(100 * as_ / as_max + 1e-10)
    min_inx = np.argmin(np.abs(f - 0))
    max_inx = np.argmin(np.abs(f - 800))
    return sdb[min_inx:max_inx+1, :], f[min_inx:max_inx+1], t

def _generate_spectrogram_base64(audio, fs=10000, title="Spectrogram"):
    """Render the spectrogram of ``audio`` as an inline ``<img>`` HTML tag.

    The figure is drawn with matplotlib, encoded as a base64 PNG and embedded
    in a data URI. ``title`` is used as the plot title and, HTML-escaped, as
    the image ``alt`` text so quotes or ``<`` in a title cannot break the markup.

    Args:
        audio: 1-D signal passed to ``_ali_spec``.
        fs: Sampling rate in Hz.
        title: Plot title / alt text.

    Returns:
        An ``<img ...>`` tag, or a short ``<p>`` notice when the visualization
        packages are unavailable.
    """
    if not _HAVE_VIS:
        return "<p>Matplotlib not available.</p>"
    from html import escape as _html_escape  # stdlib; local import keeps this block self-contained

    spec, f_axis, t_axis = _ali_spec(audio, fs)
    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.imshow(spec, aspect='auto', origin='lower',
                  extent=[t_axis[0], t_axis[-1], f_axis[0], f_axis[-1]],
                  cmap='hsv')
        ax.set_title(title)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Frequency (Hz)")
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering fails; this is called
        # in loops and open figures accumulate in memory.
        plt.close(fig)
    image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    alt_text = _html_escape(str(title), quote=True)
    return f'<img class="d-block w-100" src="data:image/png;base64,{image_base64}" alt="{alt_text}">'

def _make_carousel(cluster_id, class_id, encoded_imgs):
    carousel_id = f"carousel_class{class_id}_cluster{cluster_id}"
    indicators = "".join([
        f'<button type="button" data-bs-target="#{carousel_id}" data-bs-slide-to="{i}" class="{ "active" if i==0 else "" }" aria-current="{ "true" if i==0 else "false" }" aria-label="Slide {i+1}"></button>'
        for i in range(len(encoded_imgs))
    ])
    slides = "".join([
        f'''<div class="carousel-item {'active' if i==0 else ''}">{img}</div>'''
        for i, img in enumerate(encoded_imgs)
    ])
    return f'''
    <div id="{carousel_id}" class="carousel slide" data-bs-ride="carousel">
      <div class="carousel-indicators">{indicators}</div>
      <div class="carousel-inner">{slides}</div>
      <button class="carousel-control-prev" type="button" data-bs-target="#{carousel_id}" data-bs-slide="prev">
        <span class="carousel-control-prev-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Previous</span>
      </button>
      <button class="carousel-control-next" type="button" data-bs-target="#{carousel_id}" data-bs-slide="next">
        <span class="carousel-control-next-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Next</span>
      </button>
    </div>
    '''

@torch.no_grad()
def _extract_embeddings(vae: ConvVAE, dataloader: DataLoader, device) -> Tuple[np.ndarray, np.ndarray]:
    """Run the encoder over ``dataloader`` and collect latent means with their labels.

    The model is switched to eval mode; gradients are disabled by the decorator.
    Returns ``(embeddings, labels)`` as NumPy arrays, concatenated in loader order.
    """
    vae.eval()
    mu_chunks = []
    label_values = []
    for batch, batch_labels in dataloader:
        batch = batch.to(device, non_blocking=True)
        # Cast only when the loader produced something other than float32.
        if batch.dtype != torch.float32:
            batch = batch.to(torch.float32)
        batch = batch.contiguous(memory_format=torch.channels_last)
        latent_mean, _logvar = vae.encode(batch)
        mu_chunks.append(latent_mean.cpu().numpy())
        label_values.extend(batch_labels.numpy())
    return np.concatenate(mu_chunks), np.array(label_values)

def _analyze_class_with_gmm(vae, class_id, dataset_paths, labels_list, device, gmm_components=10):
    """Build the HTML report section for one class: 3-D UMAP plot plus per-cluster exemplars.

    For every ``(path, location_label)`` pair the dataset is embedded with the
    VAE encoder, samples whose label equals ``class_id`` are kept, and the kept
    embeddings are clustered with a full-covariance Gaussian mixture.

    Parameters
    ----------
    vae : model exposing ``input_shape`` and ``encode`` (used via ``_extract_embeddings``).
    class_id : label value selecting the samples to analyse.
    dataset_paths, labels_list : parallel sequences of dataset directories and
        the location name attached to each of them.
    device : torch device used for encoding.
    gmm_components : requested number of mixture components; it is reduced to
        the number of available samples when the class is smaller than that.

    Returns
    -------
    str
        An HTML ``<div class='section'>`` fragment. A short notice is returned
        instead when the visualization stack is missing or the class is empty.

    Notes
    -----
    NOTE(review): every call re-embeds all datasets, so calling this once per
    class repeats the encoder pass; consider caching embeddings in the caller.
    """
    if not _HAVE_VIS:
        return f"<div class='section'><h2>Analysis for Class {class_id}</h2><p>Visualization dependencies not available.</p></div>"
    # Imported locally under an alias: the caller uses `html` as a variable name.
    from html import escape as _esc

    embeddings_all, loc_labels_all, original_paths = [], [], []
    target_T = vae.input_shape[2]

    for path, loc_label in zip(dataset_paths, labels_list):
        # Prefer the precomputed mel pack when it exists next to the raw directory.
        mel_dir = _mel_pack_dir(path)
        if os.path.isdir(mel_dir) and os.path.exists(os.path.join(mel_dir, "labels.npy")):
            ds = PrecomputedMelDataset(mel_dir, cache=False)
            # No raw audio paths are known for a pack; carousels show a notice instead.
            raw_paths = np.array([None] * len(ds), dtype=object)
        else:
            ds = SpectrogramDataset(path, cache=True)
            # Assumes raw files are named 0.npy .. N-1.npy in dataset order — TODO confirm.
            raw_paths = np.array(
                [os.path.join(path, f"{i}.npy") for i in range(len(ds))], dtype=object
            )

        loader = DataLoader(ds, batch_size=64, shuffle=False,
                            num_workers=max(2, (os.cpu_count() or 4)//2),
                            pin_memory=True, persistent_workers=True, prefetch_factor=4,
                            collate_fn=lambda b: _collate_fixed_T(b, target_T))
        emb, labels = _extract_embeddings(vae, loader, device)

        mask = labels == class_id
        n_hits = int(np.sum(mask))
        if n_hits == 0:
            continue
        embeddings_all.append(emb[mask])
        loc_labels_all.append(np.array([loc_label] * n_hits))
        original_paths.append(raw_paths[mask])

    if len(embeddings_all) == 0:
        return f"<div class='section'><h2>Analysis for Class {class_id}</h2><p>No samples for this class.</p></div>"

    embeddings = np.vstack(embeddings_all)
    location_labels = np.concatenate(loc_labels_all)
    original_paths_flat = np.concatenate(original_paths)
    n_samples = embeddings.shape[0]

    # A mixture cannot have more components than samples; clamp instead of crashing
    # on small classes.
    n_components = max(1, min(int(gmm_components), n_samples))
    if n_samples >= 2:
        gmm = GaussianMixture(n_components=n_components, covariance_type="full", random_state=0)
        cluster_labels = gmm.fit_predict(embeddings)
    else:
        # A single sample cannot be fitted; put it in cluster 0.
        cluster_labels = np.zeros(n_samples, dtype=int)

    # Silhouette is only defined for 2 <= n_clusters <= n_samples - 1.
    n_found = len(np.unique(cluster_labels))
    if 2 <= n_found <= n_samples - 1:
        sil = float(silhouette_score(embeddings, cluster_labels))
    else:
        sil = float("nan")

    try:
        reducer = umap.UMAP(n_components=3, n_neighbors=15, min_dist=0.1, metric="cosine", random_state=42)
        proj_3d = reducer.fit_transform(embeddings)
        umap_fig = px.scatter_3d(
            x=proj_3d[:, 0], y=proj_3d[:, 1], z=proj_3d[:, 2],
            color=[str(int(c)) for c in cluster_labels],
            symbol=list(location_labels),
            title=f"Class {class_id} - GMM Clusters (Silhouette={sil:.2f})",
            opacity=0.85, height=800
        )
        umap_html = pio.to_html(umap_fig, include_plotlyjs="cdn", full_html=False)
    except Exception as e:
        # Very small classes can make the projection fail; keep the rest of the section.
        umap_html = f"<p class='text-danger'>UMAP/Plotly failed: {_esc(str(e))}</p>"

    # Sorted so the per-cluster location lists appear in a stable order between runs.
    known_locations = sorted(set(location_labels), key=str)

    cluster_blocks = []
    for c in range(n_components):
        idxs = np.where(cluster_labels == c)[0]
        if len(idxs) == 0:
            continue
        total = len(idxs)
        cluster_locs = location_labels[idxs]
        cluster_label_html = "<ul>" + "".join(
            f"<li>{_esc(str(loc))}: {count} ({count/total:.2%})</li>"
            for loc, count in ((loc, int(np.sum(cluster_locs == loc))) for loc in known_locations)
        ) + "</ul>"

        # Exemplars: up to 10 members closest to the cluster's mean embedding.
        center = np.mean(embeddings[idxs], axis=0, keepdims=True)
        distances = np.linalg.norm(embeddings[idxs] - center, axis=1)
        sampled_idxs = idxs[np.argsort(distances)[:10]]

        imgs = []
        for i, chosen_idx in enumerate(sampled_idxs):
            try:
                if original_paths_flat[chosen_idx] is None:
                    imgs.append("<p>Precomputed dataset: raw audio path unavailable.</p>")
                else:
                    audio = np.load(original_paths_flat[chosen_idx])
                    label = location_labels[chosen_idx]
                    imgs.append(_generate_spectrogram_base64(audio, title=f"#{i+1} ({label})"))
            except Exception as e:
                # One unreadable sample must not abort the whole report.
                imgs.append(f"<p>Error loading sample {i+1}: {_esc(str(e))}</p>")

        carousel_html = _make_carousel(c, class_id, imgs)
        block = f"<div class='col-md-6 mb-4'><h4>Cluster {c}</h4>{cluster_label_html}{carousel_html}</div>"
        cluster_blocks.append(block)

    # Two cluster cards per Bootstrap row.
    cluster_html = "".join(
        "<div class='row'>" + "".join(cluster_blocks[i:i+2]) + "</div>"
        for i in range(0, len(cluster_blocks), 2)
    )

    return f"""<div class='section'>
        <h2>Analysis for Class {class_id}</h2>
        <div class="container">
            <div class="row justify-content-center mb-4">
                <div class="col-md-12 d-flex justify-content-center">{umap_html}</div>
            </div>
        </div>
        {cluster_html}
    </div>"""

# ----------------------------
# Public API: Report (unchanged interface)
# ----------------------------
def generate_report(
    model_path: str,
    dataset_paths: List[str],
    labels: List[str],
    classes: Sequence[int] = (1,2,3,4,5,6),
    output_html: str = "vae_gmm_report.html",
    latent_dim: int = 64,
) -> str:
    """
    Load a trained VAE and write an interactive UMAP+GMM HTML report.

    Parameters
    ----------
    model_path : file passed to ``torch.load``; its content is loaded as the
        ``ConvVAE`` state dict.
    dataset_paths, labels : parallel lists of dataset directories and the
        location name shown for each of them.
    classes : class ids to report on, one section per id.
    output_html : destination file (written as UTF-8).
    latent_dim : latent size used to construct ``ConvVAE``.

    Returns
    -------
    str
        ``output_html``.

    Raises
    ------
    ValueError
        If ``dataset_paths`` and ``labels`` differ in length.
    """
    if len(dataset_paths) != len(labels):
        raise ValueError("`dataset_paths` and `labels` must have the same length")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = ConvVAE(latent_dim=latent_dim).to(device).to(memory_format=torch.channels_last)
    vae.load_state_dict(torch.load(model_path, map_location=device))
    vae.eval()

    class_names = {
        1: 'Red Hind',
        2: 'Nassau Grouper',
        3: 'Black Grouper',
        4: 'Yellow Fin Grouper',
        5: 'Squirrel Fish',
        6: 'Other Sounds'
    }

    # plotly.js is not loaded here: every plot fragment is produced with
    # include_plotlyjs="cdn" and therefore carries its own script tag. The old
    # "plotly-latest" URL is frozen at an outdated release and only added a
    # second, mismatching copy of the library.
    page_parts = ["""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Latent Space Clustering Report (VAE + GMM)</title>
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
        <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
        <style>
            body { font-family: Arial, sans-serif; padding: 20px; background-color: #f5f5f5; }
            h1 { color: #2c3e50; }
            h2, h4 { color: #34495e; }
            .section { margin-bottom: 60px; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
        </style>
    </head>
    <body>
    <h1 class='mb-4'>VAE + GMM Clustering Report</h1>
    """]

    # Sections embed base64 images and can be large; collect and join once
    # instead of repeatedly concatenating onto one growing string.
    for class_id in classes:
        class_title = class_names.get(class_id, f"Class {class_id}")
        page_parts.append(f"<h2>{class_title} (Class {class_id})</h2>")
        page_parts.append(_analyze_class_with_gmm(vae, class_id, dataset_paths, labels, device))
        page_parts.append("<hr>")

    page_parts.append("</body></html>")
    with open(output_html, "w", encoding="utf-8") as f:
        f.write("".join(page_parts))
    return output_html

In [3]:
# vae_unsup_report.py
# Self-contained helpers + class-agnostic, K-controlled report for VAE μ-embeddings.

import os, math, base64
from io import BytesIO
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

# Optional viz deps
try:
    import umap.umap_ as umap
    import plotly.express as px
    import plotly.io as pio
    import matplotlib.pyplot as plt
    _HAVE_VIZ = True
except Exception:
    _HAVE_VIZ = False

# ----------------------------
# Repro & small utils
# ----------------------------
def _set_seed(seed: int = 42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

def _mel_pack_dir(audio_dir: str) -> str:
    """Companion directory name used by your precompute step."""
    return f"{audio_dir}__mels64_fft1024_h512"

def _choose_subset(total_n: int, fraction: float = 0.2, seed: int = 42, strategy: str = "random"):
    """
    Returns sorted indices for a subset of size ceil(total_n * fraction).
    strategy: 'random' | 'head' | 'tail'
    """
    n_sub = max(1, int(math.ceil(total_n * float(fraction))))
    idx = np.arange(total_n)
    if strategy == "random":
        rng = np.random.RandomState(seed)
        chosen = np.sort(rng.choice(total_n, size=n_sub, replace=False))
    elif strategy == "head":
        chosen = idx[:n_sub]
    elif strategy == "tail":
        chosen = idx[-n_sub:]
    else:
        raise ValueError(f"Unknown subset strategy: {strategy}")
    return chosen

# ----------------------------
# Datasets
# ----------------------------
class SpectrogramDataset(Dataset):
    """
    Dataset over the ``.npy`` files of one directory, returning ``(log_mel, label)``.

    Each file may hold a 1-D waveform (converted to a log-mel spectrogram of
    shape ``[1, n_mels, T]``) or a 2-D / ``[1, F, T]`` array, which is only
    clamped and log-scaled. Files are ordered by the integer value of their
    file name, so every ``.npy`` other than ``labels.npy`` must be named
    ``<int>.npy``.

    ``labels.npy`` is optional: it is used when present, readable and of the
    same length as the file list; otherwise every item gets label ``0``.
    """
    def __init__(self, audio_dir: str, sr: int = 10000, n_mels: int = 64, cache: bool = True):
        self.audio_paths = sorted(
            [os.path.join(audio_dir, f) for f in os.listdir(audio_dir)
             if f.endswith(".npy") and f != "labels.npy"],
            key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
        )
        self.labels = None
        labels_path = os.path.join(audio_dir, "labels.npy")
        if os.path.exists(labels_path):
            try:
                # NOTE(security): allow_pickle=True can execute arbitrary code when
                # the file is malicious; only point this at label files you trust.
                self.labels = np.load(labels_path, allow_pickle=True)
                if len(self.labels) != len(self.audio_paths):
                    self.labels = None  # ignore mismatched labels
            except Exception:
                # Labels are optional; an unreadable file means "unlabelled".
                self.labels = None

        self._sr = sr
        self._n_mels = n_mels
        # Prefer torchaudio's mel transform; fall back to SciPy when it cannot be imported.
        try:
            import torchaudio
            self._have_ta = True
            self.mel_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sr, n_fft=1024, hop_length=512, n_mels=n_mels
            )
        except Exception:
            self._have_ta = False
            from scipy.signal import spectrogram, get_window
            self._sp_spec = (spectrogram, get_window)

        self.cache_enabled = cache
        self.cache = [None] * len(self.audio_paths) if cache else None

    def __len__(self): return len(self.audio_paths)

    def _wave_to_logmel(self, y: np.ndarray) -> torch.Tensor:
        """Convert a 1-D waveform to a log-scaled spectrogram tensor ``[1, n_mels, T]``.

        Non-finite samples are zeroed and the waveform is scaled to unit peak
        before either backend sees it, so both code paths work on the same signal.
        """
        y = np.nan_to_num(y.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        peak = np.max(np.abs(y))
        if not np.isfinite(peak) or peak < 1e-8:
            peak = 1.0  # (near-)silent input: skip scaling instead of dividing by ~0
        y_norm = y / peak
        if self._have_ta:
            mel = self.mel_transform(torch.from_numpy(y_norm).unsqueeze(0))  # input [1, T]
            return torch.log(torch.clamp(mel, min=1e-8))
        # Fallback: SciPy STFT -> log-magnitude, then linear-bin subsample to n_mels.
        # The normalised signal is used here too; previously the raw waveform was
        # passed, making fallback features depend on the recording's amplitude.
        spectrogram, get_window = self._sp_spec
        win = get_window('hann', 1024, fftbins=True)
        f, t, Z = spectrogram(y_norm, fs=self._sr, window=win, nperseg=1024,
                              noverlap=1024-512, mode='magnitude')
        S = np.log(np.maximum(Z, 1e-10)).astype(np.float32)  # (F_lin, T)
        if S.shape[0] != self._n_mels:
            idx = np.linspace(0, S.shape[0]-1, num=self._n_mels).astype(np.int32)
            S = S[idx]
        return torch.from_numpy(S).unsqueeze(0)

    def _to_logmel(self, arr: np.ndarray) -> torch.Tensor:
        """Dispatch on array rank: waveform -> log-mel, 2-D / [1,F,T] -> clamped log.

        Raises ``RuntimeError`` for any other shape.
        """
        if arr.ndim == 1:
            return self._wave_to_logmel(arr)
        if arr.ndim == 2 or (arr.ndim == 3 and arr.shape[0] == 1):
            x = torch.from_numpy(arr.astype(np.float32))
            if arr.ndim == 2:
                x = x.unsqueeze(0)  # add the channel axis
            return torch.clamp(x, min=1e-8).log()
        raise RuntimeError(f"Unexpected array shape {arr.shape}")

    def __getitem__(self, idx):
        if self.cache_enabled and self.cache[idx] is not None:
            x = self.cache[idx]
        else:
            # Memory-map the file; the conversion below copies what it needs.
            arr = np.load(self.audio_paths[idx], mmap_mode="r")
            x = self._to_logmel(arr).contiguous()
            if self.cache_enabled:
                self.cache[idx] = x
        label = 0 if self.labels is None else self.labels[idx]
        return x, label

class PrecomputedMelDataset(Dataset):
    """
    Dataset over ``.pt`` files in ``mel_dir``, ordered by their integer file name.

    Every item is returned as ``(loaded_object, 0)``; the label is a constant
    placeholder. With ``cache=True`` loaded objects are kept in memory.
    """
    def __init__(self, mel_dir: str, cache: bool = False):
        def _numeric_stem(file_path):
            stem, _ext = os.path.splitext(os.path.basename(file_path))
            return int(stem)

        found = [os.path.join(mel_dir, name) for name in os.listdir(mel_dir) if name.endswith(".pt")]
        found.sort(key=_numeric_stem)
        self.paths = found
        self.cache = [None] * len(self.paths) if cache else None

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        caching = self.cache is not None
        if caching and self.cache[idx] is not None:
            return self.cache[idx], 0  # dummy label
        tensor = torch.load(self.paths[idx], map_location="cpu")
        if caching:
            self.cache[idx] = tensor
        return tensor, 0  # dummy label

# ----------------------------
# Collate: pad/crop time to target_T
# ----------------------------
def _pad_or_crop_time(x: torch.Tensor, target_T: int) -> torch.Tensor:
    B, C, Freq, T = x.shape
    if T == target_T:
        return x
    if T > target_T:
        start = (T - target_T) // 2
        return x[..., start:start+target_T]
    total_pad = target_T - T
    left = total_pad // 2
    right = total_pad - left
    return F.pad(x, (left, right), mode="constant", value=0.0)

def _collate_fixed_T(batch, target_T: int):
    xs, ys = zip(*batch)
    try:
        Ts = {t.shape[-1] for t in xs}
        if len(Ts) == 1 and list(Ts)[0] == target_T:
            return torch.stack(xs, dim=0), torch.tensor(ys)
    except Exception:
        pass
    max_T = max(t.shape[-1] for t in xs)
    stacked = torch.zeros(len(xs), 1, 64, max_T, dtype=xs[0].dtype)
    for i, t in enumerate(xs):
        T = t.shape[-1]; stacked[i, :, :, :T] = t
    fixed = _pad_or_crop_time(stacked, target_T)
    return fixed, torch.tensor(ys)

# ----------------------------
# VAE
# ----------------------------
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 2-D inputs.

    The encoder is three stride-2 convolutions followed by two linear heads
    producing the latent mean and log-variance. The decoder maps a latent
    vector back through a linear layer and three transposed convolutions, ends
    in a sigmoid, and is resized to ``input_shape[1:]``.

    Parameters
    ----------
    latent_dim : size of the latent vector.
    input_shape : ``(channels, height, width)`` of one input; used to size the
        linear layers and the final upsampling.
    """
    def __init__(self, latent_dim=64, input_shape=(1, 64, 63)):
        super().__init__()
        self.input_shape = input_shape
        # Each Conv2d uses kernel 4, stride 2, padding 1.
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1), nn.ReLU()
        )
        # Probe the conv stack with a zero input to learn the feature-map shape
        # instead of computing it by hand.
        with torch.no_grad():
            dummy = torch.zeros(1, *input_shape)
            h = self.encoder_cnn(dummy)
            self.encoder_output_shape = h.shape[1:]
            # The probe has batch size 1, so numel() is the per-sample feature count.
            self.flattened_dim = h.numel()
        # `encoder` wraps the same `encoder_cnn` module, so its weights are
        # reachable under both attribute names.
        self.encoder = nn.Sequential(self.encoder_cnn, nn.Flatten())
        self.fc_mu = nn.Linear(self.flattened_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_dim, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, self.flattened_dim)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, self.encoder_output_shape),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1), nn.Sigmoid(),
            # Resize to the exact input height/width after the transposed convolutions.
            nn.Upsample(size=self.input_shape[1:], mode='bilinear', align_corners=False)
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch ``x``."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructions via ``fc_decode`` and ``decoder``."""
        return self.decoder(self.fc_decode(z))

    def forward(self, x):
        """Encode, sample a latent vector and decode; returns ``(x_recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

# ----------------------------
# Embed extraction
# ----------------------------
@torch.no_grad()
def _extract_embeddings(vae: ConvVAE, dataloader: DataLoader, device) -> Tuple[np.ndarray, np.ndarray]:
    """Encode every batch of ``dataloader`` and gather the latent means.

    The model is put in eval mode and gradients are disabled by the decorator.
    Returns ``(embeddings, labels)`` as NumPy arrays in loader order.
    """
    vae.eval()
    mu_chunks = []
    label_values = []
    for batch, batch_labels in dataloader:
        batch = batch.to(device, non_blocking=True)
        # Cast only when the loader produced something other than float32.
        if batch.dtype != torch.float32:
            batch = batch.to(torch.float32)
        batch = batch.contiguous(memory_format=torch.channels_last)
        latent_mean, _logvar = vae.encode(batch)
        mu_chunks.append(latent_mean.cpu().numpy())
        label_values.extend(np.asarray(batch_labels))
    return np.concatenate(mu_chunks), np.array(label_values)

# ----------------------------
# Spectrogram viz helpers
# ----------------------------
def _ali_spec(x, fs=10000):
    from scipy.signal import spectrogram, get_window
    Lframe2 = 1000
    po = 80
    lov = int(np.ceil((po / 100) * Lframe2))
    taper = get_window('hann', Lframe2)
    Nfft = 2 ** (int(np.floor(np.log2(Lframe2))) + 2)
    f, t, s = spectrogram(x, fs=fs, window=taper, noverlap=lov, nfft=Nfft, mode='complex')
    as_ = np.abs(s); as_max = np.max(as_) if np.isfinite(np.max(as_)) and np.max(as_) > 0 else 1.0
    sdb = 10 * np.log10(100 * as_ / as_max + 1e-10)
    min_inx = np.argmin(np.abs(f - 0))
    max_inx = np.argmin(np.abs(f - 800))
    return sdb[min_inx:max_inx+1, :], f[min_inx:max_inx+1], t

def _generate_spectrogram_base64(audio, fs=10000, title="Spectrogram"):
    if not _HAVE_VIZ:
        return "<p>Matplotlib not available.</p>"
    spec, f_axis, t_axis = _ali_spec(audio, fs)
    fig, ax = plt.subplots(figsize=(8, 3))
    ax.imshow(spec, aspect='auto', origin='lower',
              extent=[t_axis[0], t_axis[-1], f_axis[0], f_axis[-1]], cmap='hsv')
    ax.set_title(title); ax.set_xlabel("Time (s)"); ax.set_ylabel("Frequency (Hz)")
    fig.tight_layout()
    buf = BytesIO(); fig.savefig(buf, format='png', dpi=150); plt.close(fig); buf.seek(0)
    image_base64 = base64.b64encode(buf.read()).decode("utf-8")
    return f'<img class="d-block w-100" src="data:image/png;base64,{image_base64}" alt="{title}">'

def _make_carousel(cluster_id: int, class_id: int, encoded_imgs: List[str]) -> str:
    cid = f"carousel_class{class_id}_cluster{cluster_id}"
    inds = "".join([
        f'<button type="button" data-bs-target="#{cid}" data-bs-slide-to="{i}" '
        f'class="{ "active" if i==0 else "" }" aria-current="{ "true" if i==0 else "false" }" '
        f'aria-label="Slide {i+1}"></button>'
        for i in range(len(encoded_imgs))
    ])
    slides = "".join([
        f'''<div class="carousel-item {'active' if i==0 else ''}">
        <div class="d-flex justify-content-center">{img}</div></div>'''
        for i, img in enumerate(encoded_imgs)
    ])
    return f"""
    <div id="{cid}" class="carousel slide" data-bs-interval="false" data-bs-touch="false">
      <div class="carousel-indicators">{inds}</div>
      <div class="carousel-inner">{slides}</div>
      <button class="carousel-control-prev" type="button" data-bs-target="#{cid}" data-bs-slide="prev">
        <span class="carousel-control-prev-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Previous</span>
      </button>
      <button class="carousel-control-next" type="button" data-bs-target="#{cid}" data-bs-slide="next">
        <span class="carousel-control-next-icon" aria-hidden="true"></span>
        <span class="visually-hidden">Next</span>
      </button>
    </div>
    """

# ----------------------------
# Class-agnostic, K-controlled report
# ----------------------------
def generate_report(
    model_path: str,
    dataset_paths: List[str],
    labels: List[str],
    *,
    n_clusters: int = 20,
    cluster_method: str = "gmm",
    output_html: str = "vae_unsup_report.html",
    latent_dim: int = 64,
    seed: int = 42,
    examples_per_cluster: int = 6,
    subset_fraction: float = 0.2,        # fraction of each dataset that is embedded
    subset_strategy: str = "random",     # 'random' | 'head' | 'tail'
) -> str:
    """Cluster VAE latent means without using class labels and write an HTML report.

    A subset of every dataset is embedded with the encoder, the embeddings are
    standardised and clustered, and the report shows a 3-D UMAP scatter plus,
    per cluster, the location mix and the samples closest to the cluster mean.

    Parameters
    ----------
    model_path : file passed to ``torch.load``; loaded as the ``ConvVAE`` state dict.
    dataset_paths, labels : parallel lists of dataset directories and the
        location name attached to each.
    n_clusters : requested number of clusters; reduced (with a warning) when
        fewer samples than clusters are available.
    cluster_method : 'gmm', 'kmeans' or 'agglomerative' (case-insensitive).
    output_html : destination file (written as UTF-8).
    latent_dim : latent size used to construct ``ConvVAE``.
    seed : seed for RNGs, subset selection, clustering and UMAP.
    examples_per_cluster : maximum number of exemplar images per cluster.
    subset_fraction, subset_strategy : forwarded to ``_choose_subset``.

    Returns
    -------
    str
        ``output_html``.

    Raises
    ------
    AssertionError
        If ``dataset_paths`` and ``labels`` differ in length.
    ValueError
        If ``cluster_method`` is unknown, ``n_clusters`` is below 1, or no
        samples could be embedded.
    """
    import warnings
    from collections import Counter
    from html import escape as _esc  # aliased: avoids clashing with HTML string variables
    from sklearn.preprocessing import StandardScaler
    from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score

    assert len(dataset_paths) == len(labels), "`dataset_paths` and `labels` must match"

    # Validate cheap arguments before the expensive embedding pass.
    cm = cluster_method.lower()
    if cm not in ("gmm", "kmeans", "agglomerative"):
        raise ValueError("cluster_method must be one of: 'gmm', 'kmeans', 'agglomerative'")
    if int(n_clusters) < 1:
        raise ValueError("n_clusters must be >= 1")

    _set_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = ConvVAE(latent_dim=latent_dim).to(device).to(memory_format=torch.channels_last).eval()
    vae.load_state_dict(torch.load(model_path, map_location=device))
    target_T = vae.input_shape[2]

    all_Z, all_loc, all_raw = [], [], []

    for ds_path, loc_label in zip(dataset_paths, labels):
        mel_dir = _mel_pack_dir(ds_path)
        if os.path.isdir(mel_dir) and os.path.exists(os.path.join(mel_dir, "labels.npy")):
            # Precomputed pack (.pt): there is no link back to the raw files.
            full_ds = PrecomputedMelDataset(mel_dir, cache=False)
            idxs = _choose_subset(len(full_ds), fraction=subset_fraction, seed=seed, strategy=subset_strategy)
            raw_paths = np.array([None] * len(idxs), dtype=object)
        else:
            # Raw .npy files (waveform or spectrogram). The dataset already holds
            # the ordered file list, so use it rather than listing the directory again.
            full_ds = SpectrogramDataset(ds_path, cache=True)
            idxs = _choose_subset(len(full_ds), fraction=subset_fraction, seed=seed, strategy=subset_strategy)
            raw_paths = np.array([full_ds.audio_paths[int(i)] for i in idxs], dtype=object)

        if len(idxs) == 0:
            warnings.warn(f"No samples selected from {ds_path!r}; skipping it.")
            continue

        ds = Subset(full_ds, list(map(int, idxs)))
        loader = DataLoader(
            ds, batch_size=256, shuffle=False,
            num_workers=max(2, (os.cpu_count() or 4)//2), pin_memory=True,
            persistent_workers=True, prefetch_factor=4,
            collate_fn=lambda b: _collate_fixed_T(b, target_T)
        )
        Z, _ = _extract_embeddings(vae, loader, device)
        all_Z.append(Z)
        all_loc.extend([loc_label] * Z.shape[0])
        all_raw.append(raw_paths)

    if not all_Z:
        raise ValueError("No samples were embedded; check `dataset_paths`.")

    embeddings = np.vstack(all_Z).astype(np.float32)
    location_labels = np.array(all_loc, dtype=object)
    original_paths = np.concatenate(all_raw, axis=0)
    n_samples = embeddings.shape[0]

    # None of the clusterers can produce more clusters than there are samples.
    k = int(n_clusters)
    if k > n_samples:
        warnings.warn(f"n_clusters={k} exceeds the {n_samples} available samples; using {n_samples}.")
        k = n_samples

    Zs = StandardScaler().fit_transform(embeddings)

    if cm == "kmeans":
        from sklearn.cluster import KMeans
        clusterer = KMeans(n_clusters=k, n_init=10, random_state=seed).fit(Zs)
        cluster_labels = clusterer.labels_
    elif cm == "agglomerative":
        from sklearn.cluster import AgglomerativeClustering
        clusterer = AgglomerativeClustering(n_clusters=k).fit(Zs)
        cluster_labels = clusterer.labels_
    else:  # "gmm"
        from sklearn.mixture import GaussianMixture
        clusterer = GaussianMixture(n_components=k, covariance_type="full", random_state=seed).fit(Zs)
        cluster_labels = clusterer.predict(Zs)

    def _safe_metric(fn, X, y):
        """Return the metric as float, or NaN when scikit-learn rejects the input."""
        try:
            return float(fn(X, y))
        except Exception:
            return float("nan")

    uniq = np.unique(cluster_labels)
    # The internal metrics need at least two clusters and fewer clusters than samples.
    valid = (len(uniq) > 1) and (len(cluster_labels) > len(uniq))
    sil = _safe_metric(silhouette_score, Zs, cluster_labels) if valid else float("nan")
    dbi = _safe_metric(davies_bouldin_score, Zs, cluster_labels) if valid else float("nan")
    ch  = _safe_metric(calinski_harabasz_score, Zs, cluster_labels) if valid else float("nan")

    # round() rather than int(): e.g. 0.29 * 100 is 28.999... and would print as 28.
    subset_pct = int(round(subset_fraction * 100))

    if _HAVE_VIZ:
        try:
            reducer = umap.UMAP(n_components=3, n_neighbors=15, min_dist=0.1, metric="cosine", random_state=seed)
            proj_3d = reducer.fit_transform(Zs)
            title = (f"Unsupervised {cm.upper()} (k={k}, subset={subset_pct}%) \u2022 "
                     f"Sil={sil:.3f} | DBI={dbi:.3f} | CH={ch:.1f}")
            fig = px.scatter_3d(
                x=proj_3d[:, 0], y=proj_3d[:, 1], z=proj_3d[:, 2],
                color=[str(int(c)) for c in cluster_labels],
                symbol=location_labels,
                hover_data={"Cluster": cluster_labels},
                title=title, opacity=0.85, height=800
            )
            umap_html = pio.to_html(fig, include_plotlyjs="cdn", full_html=False)
        except Exception as e:
            # The projection is optional; report the failure and keep the cluster cards.
            umap_html = f"<p class='text-danger'>UMAP/Plotly failed: {_esc(str(e))}</p>"
    else:
        umap_html = "<p>UMAP/Plotly not available.</p>"

    # Per-cluster sections with exemplar carousels.
    cluster_blocks = []
    for c in uniq:
        idxs = np.where(cluster_labels == c)[0]
        if idxs.size == 0:
            continue

        # Location distribution inside this cluster.
        loc_counts = Counter(location_labels[idxs])
        meta_html = "<p><strong>Location Distribution:</strong></p><ul>" + "".join(
            f"<li><b>{_esc(str(loc))}</b>: {cnt} ({cnt/len(idxs):.1%})</li>" for loc, cnt in loc_counts.items()
        ) + "</ul>"

        # Exemplars: members nearest to the cluster mean in standardized latent space.
        center = np.mean(Zs[idxs], axis=0, keepdims=True)
        dists = np.linalg.norm(Zs[idxs] - center, axis=1)
        order = np.argsort(dists)
        show = idxs[order[:min(examples_per_cluster, len(order))]]

        imgs = []
        for i, j in enumerate(show):
            p = original_paths[j]
            slide_title = f"#{i+1} \u2022 loc={location_labels[j]} \u2022 C={int(c)}"
            try:
                if p is None:
                    imgs.append("<p>Raw audio not available (precomputed pack).</p>")
                    continue
                arr = np.load(p, mmap_mode="r")
                if arr.ndim == 1:
                    # Waveform: render a spectrogram preview.
                    imgs.append(_generate_spectrogram_base64(arr.astype(np.float32), title=slide_title))
                elif arr.ndim == 2:
                    # Already a (F, T) spectrogram-like array: show it as an image.
                    if not _HAVE_VIZ:
                        imgs.append("<p>Image backend unavailable.</p>")
                    else:
                        buf = BytesIO()
                        fig, ax = plt.subplots(figsize=(8, 3))
                        try:
                            ax.imshow(arr, aspect='auto', origin='lower', cmap='viridis')
                            ax.set_title(slide_title)
                            ax.set_xlabel("Frames")
                            ax.set_ylabel("Mel bins")
                            fig.tight_layout()
                            fig.savefig(buf, format='png', dpi=150)
                        finally:
                            plt.close(fig)  # never leak figures, even on error
                        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
                        imgs.append(f'<img class="d-block w-100" src="data:image/png;base64,{b64}" alt="Spec">')
                else:
                    imgs.append(f"<p>Unsupported array shape {arr.shape}</p>")
            except Exception as e:
                # One unreadable sample must not abort the whole report.
                imgs.append(f"<p class='text-danger'>Error loading sample: {_esc(str(e))}</p>")

        carousel_html = _make_carousel(cluster_id=int(c), class_id=-1, encoded_imgs=imgs)
        block = f"<div class='col-md-6 mb-4'><h4>Cluster {c} \u2022 size={len(idxs)}</h4>{meta_html}{carousel_html}</div>"
        cluster_blocks.append(block)

    # Two cluster cards per Bootstrap row.
    cluster_html = "".join(
        "<div class='row'>" + "".join(cluster_blocks[i:i+2]) + "</div>"
        for i in range(0, len(cluster_blocks), 2)
    )

    # Final page. plotly.js is not loaded in <head>: the plot fragment above is
    # generated with include_plotlyjs="cdn" and brings its own script tag, while
    # the "plotly-latest" URL is frozen at an outdated release.
    page = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Unsupervised Clustering Report (VAE \u03bc)</title>
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
        <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
        <style>
            body {{ font-family: Arial, sans-serif; padding: 20px; background-color: #f5f5f5; }}
            h1 {{ color: #2c3e50; }}
            h2, h4 {{ color: #34495e; }}
            .section {{ margin-bottom: 60px; padding: 20px; background: white;
                        border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
        </style>
    </head>
    <body>
      <h1 class='mb-4'>Unsupervised Latent Clustering (VAE \u03bc)</h1>
      <p><b>Method:</b> {cm.upper()} \u2022 <b>k</b>={k} \u2022 <b>subset</b>={subset_pct}%</p>
      <p><b>Silhouette</b>={sil:.3f} \u2022 <b>DBI</b>={dbi:.3f} \u2022 <b>CH</b>={ch:.1f}</p>
      <hr>
      <div class='section'>
        <div class="container">
          <div class="row justify-content-center mb-4">
            <div class="col-md-12 d-flex justify-content-center">{umap_html}</div>
          </div>
        </div>
        {cluster_html}
      </div>
    </body></html>
    """

    with open(output_html, "w", encoding="utf-8") as f:
        f.write(page)
    return output_html

# ----------------------------
# Example usage
# ----------------------------
# report = generate_report(
#     model_path="vae_best.pth",
#     dataset_paths=["/notebooks/dataset_preprocessed"],
#     labels=["PR_U1137"],
#     n_clusters=60,
#     cluster_method="gmm",
#     subset_fraction=0.20,      # <- 20% of items
#     subset_strategy="random",  # or 'head' / 'tail'
#     output_html="vae_unsup_k60_subset20.html",
#     latent_dim=64,
#     examples_per_cluster=5,
#     seed=45
# )
# print("Saved:", report)

Report and training runner
---------------------------

In [10]:
# Step 1: build the fixed train/test split once, outside of training.
DATASET_DIR = "/notebooks/dataset_preprocessed"
train_idx, test_idx = make_fixed_train_test_indices(
    DATASET_DIR,
    test_size=0.2,
    random_state=42,
)

# Step 2 (disabled): train ONLY on the 80% train indices; a small validation
# split is carved from that 80%. Uncomment to retrain.
# train_res = run_training(
#     audio_dir=DATASET_DIR,
#     train_indices=train_idx,
#     epochs=50,
#     save_path="vae_best.pth",
#     checkpoint_path="vae_train.ckpt",   # overwritten as training continues
#     #resume_from="vae_train.ckpt"        # <- resume from this checkpoint
# )

# Step 3: evaluate ONLY on the same held-out 20%, with a configurable K.
inference_options = dict(
    audio_dir=DATASET_DIR,
    test_indices=test_idx,
    checkpoint_path="vae_best.pth",
    kmeans_k=60,
    do_umap=True,
    umap_components=3,
    umap_html_path="vae_test_umap.html",
    l2_normalize=True,
    pca_dim=64,   # try None vs 64 and compare
)
test_res = run_inference(**inference_options)
print(test_res["metrics"])

[INFO] Inference source dir: /notebooks/dataset_preprocessed__mels64_fft1024_h512
[INFO] Full label dist: {1: 69829, 2: 70932, 3: 70044, 4: 70779, 5: 70983, 6: 60705}
[INFO] TEST label dist: {1: 13966, 2: 14186, 3: 14009, 4: 14156, 5: 14197, 6: 12141}



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



Saved UMAP visualization to 'vae_test_umap.html'
{'ari': 0.021293397892635138, 'ami': 0.10507615485475799, 'hungarian_accuracy': 0.07834976710422842, 'silhouette': 0.07351772487163544, 'davies_bouldin': 3.337626608126009, 'calinski_harabasz': 914.3685113584048}
