# Anomaly spectrogram study (Part 1)

Characterize **where** and **how** real anomalies appear in the log-mel spectrogram, **per machine type and per machine_id**, to inform the anomaly generation module.

- Load the test set for each machine type (normalized with train-set statistics); group clips by `machine_id` and label (normal vs. anomaly).
- For each machine_id: mean normal spectrogram, mean anomalous spectrogram, difference heatmaps.
- Spatial localization: mean |anomaly − mean_normal| over anomalous set.
- Optional: per-band / per-frame stats and reconstruction residual (Stage 1 VQ-VAE).
- Short summary per machine type for mask-prior design.

## 1. Setup and data loading

Set up paths and imports, then load the train and test sets for each machine type: the train set supplies the normalization statistics, and the test set is grouped by `machine_id` and label.

In [None]:
import sys
from pathlib import Path
from collections import defaultdict

# Resolve the repository root: when the notebook is launched from the
# ``notebooks/`` directory the root is its parent, otherwise the CWD itself.
_cwd = Path(".").resolve()
if _cwd.name == "notebooks":
    PROJECT_ROOT = _cwd.parent
else:
    PROJECT_ROOT = _cwd
# Prepend the root so ``src.*`` imports below resolve against this checkout.
sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import torch
import matplotlib.pyplot as plt

from src.data.dataset import DCASE2020Task2LogMelDataset, DCASE2020Task2TestDataset

MACHINE_TYPES = ["fan", "pump"]

# Candidate dataset locations, in order of preference: a sibling ``data/``
# directory next to the project, then one inside the project. If neither
# exists, the last candidate is kept so the printed path shows where it was
# expected.
_data_candidates = [
    PROJECT_ROOT / "../data/dcase2020-task2-dev-dataset",
    PROJECT_ROOT / "data/dcase2020-task2-dev-dataset",
]
DATA_PATH = next((p for p in _data_candidates if p.exists()), _data_candidates[-1])
print(f"Data path: {DATA_PATH}, exists: {DATA_PATH.exists()}")

In [None]:
def load_test_grouped_by_id_and_label(data_path: Path, machine_type: str):
    """Load the train split (for normalization), then group the test split by ID and label.

    Args:
        data_path: Root directory of the DCASE2020 Task 2 development dataset.
        machine_type: Machine type to load (e.g. ``"fan"``).

    Returns:
        ``(train_ds, test_ds, grouped, n_mels, T)``. ``grouped`` maps
        ``machine_id -> {0: [spec, ...], 1: [spec, ...]}`` with label 0 = normal
        and 1 = anomaly; each spec is a NumPy array with any leading singleton
        channel axis removed. ``n_mels`` and ``T`` are read from
        ``train_ds.data.shape`` (assumed ``(N, C, n_mels, T)`` — TODO confirm).

    Raises:
        KeyError: if a test label is neither 0 nor 1.
        The caller catches ``FileNotFoundError``, presumably raised by the
        dataset constructors when files are missing.
    """
    train_ds = DCASE2020Task2LogMelDataset(
        root=str(data_path),
        machine_type=machine_type,
        normalize=True,
    )
    _, _, n_mels, T = train_ds.data.shape
    # The test set reuses the train statistics and target length so both
    # splits live on the same normalized scale.
    test_ds = DCASE2020Task2TestDataset(
        root=str(data_path),
        machine_type=machine_type,
        mean=train_ds.mean,
        std=train_ds.std,
        target_T=train_ds.target_T,
    )
    grouped = defaultdict(lambda: {0: [], 1: []})
    for idx in range(len(test_ds)):
        spec, label, machine_id = test_ds[idx]
        if spec.dim() == 3:
            spec = spec.squeeze(0)
        # Coerce the label to a plain int: a 0-dim tensor hashes by identity
        # and would never match the 0/1 dict keys.
        grouped[machine_id][int(label)].append(spec.detach().cpu().numpy())
    return train_ds, test_ds, grouped, n_mels, T

def stacked_by_id_label(grouped):
    """Stack each (machine_id, label) list of spectrograms into one array.

    Returns ``{machine_id: {0: (N_norm, n_mels, T), 1: (N_anom, n_mels, T)}}``.
    A label with no samples maps to an empty ``(0, 0, 0)`` array, so callers
    can test ``.size == 0``.
    """
    def _stack_or_empty(specs):
        if not specs:
            return np.empty((0, 0, 0))
        return np.stack(specs)

    return {
        mid: {lbl: _stack_or_empty(by_label[lbl]) for lbl in (0, 1)}
        for mid, by_label in grouped.items()
    }

# Load every machine type; a missing dataset directory just skips that type.
data_by_type = {}
for mt in MACHINE_TYPES:
    try:
        train_ds, test_ds, grouped, n_mels, T = load_test_grouped_by_id_and_label(DATA_PATH, mt)
    except FileNotFoundError as e:
        print(f"Skip {mt}: {e}")
        continue
    data_by_type[mt] = dict(
        train_ds=train_ds,
        test_ds=test_ds,
        grouped=grouped,
        stacked=stacked_by_id_label(grouped),
        n_mels=n_mels,
        T=T,
    )
    print(f"{mt}: n_mels={n_mels}, T={T}, IDs={sorted(grouped.keys())}")
    for mid in sorted(grouped):
        n_n, n_a = (len(grouped[mid][lbl]) for lbl in (0, 1))
        print(f"  {mid}: normal={n_n}, anomaly={n_a}")

## 2. Mean spectrograms and difference heatmaps (per machine_id)

For each machine type and machine_id: mean normal, mean anomalous, and difference (anomalous − normal).

In [None]:
def plot_mean_and_diff(stacked, machine_type: str, machine_id: str, n_mels: int, T: int):
    """Plot mean normal, mean anomalous, and their difference for one machine_id.

    Args:
        stacked: ``{machine_id: {0: normal_array, 1: anomaly_array}}`` as built
            by ``stacked_by_id_label``; arrays are ``(N, n_mels, T)``.
        machine_type: Used in titles only.
        machine_id: Key into ``stacked``.
        n_mels, T: Unused; kept for call-site compatibility.

    Does nothing when either label group is empty.
    """
    norm_arr = stacked[machine_id][0]
    anom_arr = stacked[machine_id][1]
    if norm_arr.size == 0 or anom_arr.size == 0:
        return
    mean_norm = norm_arr.mean(axis=0)
    mean_anom = anom_arr.mean(axis=0)
    diff = mean_anom - mean_norm
    # One shared scale for both means: with per-panel autoscaling a global
    # level shift between normal and anomalous clips would be invisible.
    lo = min(mean_norm.min(), mean_anom.min())
    hi = max(mean_norm.max(), mean_anom.max())
    # Symmetric limits so the diverging colormap's midpoint means "no change".
    v = np.abs(diff).max() or 1
    panels = [
        (mean_norm, f"Mean normal ({machine_type} {machine_id})", "magma", lo, hi),
        (mean_anom, "Mean anomalous", "magma", lo, hi),
        (diff, "Difference (anomalous − normal)", "RdBu_r", -v, v),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    for ax, (img, title, cmap, vmin, vmax) in zip(axes, panels):
        im = ax.imshow(img, aspect="auto", origin="lower", cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title(title)
        ax.set_xlabel("Time")
        ax.set_ylabel("Mel")
        fig.colorbar(im, ax=ax)
    plt.suptitle(f"{machine_type} / {machine_id}")
    plt.tight_layout()
    plt.show()

# Draw the mean / difference panels for every machine_id of every loaded type.
for mt, data in data_by_type.items():
    stacked, n_mels, T = data["stacked"], data["n_mels"], data["T"]
    for mid in sorted(stacked):
        plot_mean_and_diff(stacked, mt, mid, n_mels, T)

## 3. Spatial localization of deviation

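For each (machine_type, machine_id): the mean of |anomaly − mean_normal| over the anomalous clips, showing which (mel, time) regions differ most. Unlike the signed difference of means in Section 2, the absolute deviation is taken per clip before averaging, so deviations of opposite sign across clips do not cancel out.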

In [None]:
def plot_localization(stacked, machine_type: str, machine_id: str, n_mels: int, T: int):
    """Heatmap of the mean absolute deviation of anomalous clips from the mean normal clip.

    The absolute value is taken per clip before averaging, so deviations of
    opposite sign across clips do not cancel. ``n_mels`` and ``T`` are unused
    and kept for call-site compatibility. Does nothing when either label group
    is empty.
    """
    by_label = stacked[machine_id]
    normal, anomalous = by_label[0], by_label[1]
    if not (normal.size and anomalous.size):
        return
    reference = normal.mean(axis=0)
    deviation = np.mean(np.abs(anomalous - reference), axis=0)
    fig, ax = plt.subplots(figsize=(10, 4))
    heat = ax.imshow(deviation, aspect="auto", origin="lower", cmap="hot")
    ax.set(
        title=f"Spatial localization: mean |anomaly − mean_normal| ({machine_type} / {machine_id})",
        xlabel="Time",
        ylabel="Mel",
    )
    plt.colorbar(heat, ax=ax, label="Mean abs diff")
    plt.tight_layout()
    plt.show()

# Draw the localization heatmap for every machine_id of every loaded type.
for mt, data in data_by_type.items():
    stacked, n_mels, T = data["stacked"], data["n_mels"], data["T"]
    for mid in sorted(stacked):
        plot_localization(stacked, mt, mid, n_mels, T)

## 4. Per-band and per-frame statistics (optional)

Mean (normalized log-mel) value per mel bin, averaged over time, and per time frame, averaged over mel bins, for normal vs. anomalous clips; plus the per-band variance over time.

In [None]:
def per_band_per_frame_stats(stacked, machine_id: str):
    """Summarize per-mel-band and per-time-frame profiles for normal vs anomalous clips.

    ``stacked[machine_id]`` holds ``{0: normal, 1: anomalous}`` arrays of shape
    ``(N, n_mels, T)``. Returns ``None`` if either group is empty; otherwise a
    dict with:

    - ``mean_{norm,anom}_band``: ``(n_mels,)`` mean over clips and time;
    - ``mean_{norm,anom}_frame``: ``(T,)`` mean over clips and mel bins;
    - ``var_{norm,anom}_band``: ``(n_mels,)`` per-clip variance over time,
      averaged over clips;
    - ``n_mels``, ``T``: taken from the normal array's shape.
    """
    groups = (("norm", stacked[machine_id][0]), ("anom", stacked[machine_id][1]))
    if any(arr.size == 0 for _, arr in groups):
        return None
    # (key template, reduction) — axis 0 = clip, 1 = mel bin, 2 = time frame.
    reducers = (
        ("mean_{}_band", lambda a: a.mean(axis=(0, 2))),
        ("mean_{}_frame", lambda a: a.mean(axis=(0, 1))),
        ("var_{}_band", lambda a: a.var(axis=2).mean(axis=0)),
    )
    result = {template.format(tag): reduce(arr) for template, reduce in reducers for tag, arr in groups}
    n_mels, T = groups[0][1].shape[1:3]
    result["n_mels"] = n_mels
    result["T"] = T
    return result

def plot_band_frame(stats, machine_type: str, machine_id: str):
    """Plot per-band and per-frame mean profiles (plus per-band variance) for one machine_id.

    Args:
        stats: Dict as returned by ``per_band_per_frame_stats``; ``None`` is a no-op.
        machine_type, machine_id: Used in the title only.

    The per-band variance panel is drawn only when both ``var_norm_band`` and
    ``var_anom_band`` are present in ``stats``.
    """
    if stats is None:
        return
    # (normal key, anomaly key, y label, x label, title) per subplot row.
    rows = [
        ("mean_norm_band", "mean_anom_band", "Mean energy", "Mel bin",
         f"Per-band mean ({machine_type} / {machine_id})"),
        ("mean_norm_frame", "mean_anom_frame", "Mean energy", "Time frame",
         "Per-frame mean"),
    ]
    # The variance over time is computed upstream; show it too instead of
    # silently discarding it.
    if "var_norm_band" in stats and "var_anom_band" in stats:
        rows.append(("var_norm_band", "var_anom_band", "Variance over time", "Mel bin",
                     "Per-band variance over time"))
    fig, axes = plt.subplots(len(rows), 1, figsize=(10, 3 * len(rows)))
    for ax, (norm_key, anom_key, ylabel, xlabel, title) in zip(axes, rows):
        ax.plot(stats[norm_key], label="normal")
        ax.plot(stats[anom_key], label="anomaly")
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)
        ax.set_title(title)
        ax.legend()
    plt.tight_layout()
    plt.show()

# Compute and plot band/frame profiles for every machine_id of every loaded type.
for mt, data in data_by_type.items():
    stacked = data["stacked"]
    for mid in sorted(stacked):
        stats = per_band_per_frame_stats(stacked, mid)
        if stats is not None:
            plot_band_frame(stats, mt, mid)

## 5. Reconstruction residual (optional)

Encode normal and anomalous clips (the first 5 per group) with the trained Stage 1 VQ-VAE, decode them, and plot the mean |input − reconstruction| to see which regions the model fails to represent.

In [None]:
# Root directory for trained checkpoints, and the device models run on.
CKPT_DIR = PROJECT_ROOT / "checkpoints"
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

def load_stage1_if_available(machine_type: str, n_mels: int, T: int):
    """Return the trained Stage 1 VQ-VAE for ``machine_type`` in eval mode, or ``None``.

    Looks for ``CKPT_DIR / "stage1" / machine_type / "stage1_<machine_type>_best.pt"``.
    When that file is missing, returns ``None`` so the caller can skip the
    residual analysis. Previously it returned a randomly initialized model,
    which made the caller's "no checkpoint" branch unreachable.

    Args:
        machine_type: Machine type whose checkpoint to load.
        n_mels, T: Unused; kept for call-site compatibility.

    The architecture hyper-parameters below must match those the checkpoint
    was trained with; otherwise ``load_state_dict`` (strict by default) raises.
    """
    ckpt = CKPT_DIR / "stage1" / machine_type / f"stage1_{machine_type}_best.pt"
    if not ckpt.exists():
        return None
    # Import the model only when there is a checkpoint to load into it.
    from src.models.vq_vae.autoencoders import VQ_VAE_2Layer
    vq = VQ_VAE_2Layer(
        num_hiddens=128,
        num_residual_layers=2,
        num_residual_hiddens=64,
        num_embeddings=(1024, 4096),
        embedding_dim=128,
        commitment_cost=0.25,
        decay=0.99,
    )
    state = torch.load(ckpt, map_location="cpu", weights_only=True)
    vq.load_state_dict(state["model_state_dict"])
    return vq.to(DEVICE).eval()

def plot_reconstruction_residual(vq, stacked, machine_type: str, machine_id: str, n_mels: int, T: int, max_samples: int = 5):
    """Plot mean |input − reconstruction| of the VQ-VAE for normal vs anomalous clips.

    Args:
        vq: Model exposing ``encode(x) -> (q_bot, q_top)`` and
            ``decode_general(q_bot, q_top)``.
        stacked: ``{machine_id: {0: normal_array, 1: anomaly_array}}``;
            arrays are ``(N, n_mels, T)``.
        machine_type, machine_id: Select the data and label the figure.
        n_mels, T: Unused; kept for call-site compatibility.
        max_samples: Only the first ``max_samples`` clips of each group are used.

    Does nothing when either label group is empty. Both panels share one
    color scale so residual magnitudes are directly comparable.
    """
    norm_arr = stacked[machine_id][0]
    anom_arr = stacked[machine_id][1]
    if norm_arr.size == 0 or anom_arr.size == 0:
        return

    def _mean_residual(arr):
        # Add a singleton channel axis: (N, n_mels, T) -> (N, 1, n_mels, T).
        x = torch.from_numpy(arr[:max_samples]).float().unsqueeze(1).to(DEVICE)
        with torch.no_grad():
            q_bot, q_top = vq.encode(x)
            rec = vq.decode_general(q_bot, q_top)
        res = (x - rec).abs().cpu().numpy()
        # Average over clips, then drop the channel axis.
        return res.mean(axis=0).squeeze(0)

    mean_res_n = _mean_residual(norm_arr)
    mean_res_a = _mean_residual(anom_arr)
    # Shared upper limit; fall back to 1.0 for a perfect (all-zero) residual.
    vmax = float(max(mean_res_n.max(), mean_res_a.max())) or 1.0
    panels = [
        (mean_res_n, f"Mean |normal − recon| ({machine_type} / {machine_id})"),
        (mean_res_a, "Mean |anomaly − recon|"),
    ]
    fig, axes = plt.subplots(1, 2, figsize=(11, 4))
    for ax, (img, title) in zip(axes, panels):
        im = ax.imshow(img, aspect="auto", origin="lower", cmap="hot", vmin=0.0, vmax=vmax)
        ax.set_title(title)
        ax.set_xlabel("Time")
        ax.set_ylabel("Mel")
        fig.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()

# Plot reconstruction residuals for each machine type that has a Stage 1 model.
for mt, data in data_by_type.items():
    n_mels, T = data["n_mels"], data["T"]
    vq = load_stage1_if_available(mt, n_mels, T)
    if vq is not None:
        for mid in sorted(data["stacked"]):
            plot_reconstruction_residual(vq, data["stacked"], mt, mid, n_mels, T)
    else:
        print(f"No Stage1 checkpoint for {mt}; skip reconstruction residual.")

## 6. Short summary (for mask-prior design)

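One bullet per machine_id, grouped by machine type. Each bullet gives the (mel bin, time frame) with the largest mean absolute deviation from the mean normal clip (absolute index and as a fraction of the axis), plus the range of the difference of means (anomalous − normal). Use this as a prior for improving AudioSpecificStrategy / Perlin masks in the anomaly generation module.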

In [None]:
# Per machine type, one line per machine_id describing where anomalous clips
# deviate most from the mean normal clip (for mask-prior design).
summaries = {}
for mt, data in data_by_type.items():
    stacked = data["stacked"]
    n_mels, T = data["n_mels"], data["T"]
    summaries[mt] = []
    for mid in sorted(stacked.keys()):
        norm_arr = stacked[mid][0]
        anom_arr = stacked[mid][1]
        if norm_arr.size == 0 or anom_arr.size == 0:
            # Report the counts: usually only one label group is missing.
            summaries[mt].append(f"{mid}: no data (normal={len(norm_arr)}, anomaly={len(anom_arr)})")
            continue
        mean_norm = norm_arr.mean(axis=0)
        diff = anom_arr.mean(axis=0) - mean_norm
        # Per-clip absolute deviation, averaged: opposite-sign deviations don't cancel.
        dev = np.abs(anom_arr - mean_norm).mean(axis=0)
        # Single argmax over the (mel, time) map, unpacked once.
        mel_peak, time_peak = np.unravel_index(np.argmax(dev), dev.shape)
        band_frac = mel_peak / max(1, n_mels)
        time_frac = time_peak / max(1, T)
        summaries[mt].append(
            f"{mid}: deviation peak at mel≈{mel_peak} ({band_frac:.2f}), time≈{time_peak} ({time_frac:.2f}); diff range [{diff.min():.3f}, {diff.max():.3f}]."
        )
for mt in MACHINE_TYPES:
    if mt not in summaries:
        continue
    print(f"## {mt}")
    for s in summaries[mt]:
        print(f"- {s}")
    print()