# Audio-specific mask strategy analysis

Compare **AudioSpecificStrategy** synthetic masks with **real anomaly patterns** from the DCASE test set (see `anomaly_spectrogram_study.ipynb`).

Goals:
- Recompute spatial localization of real anomalies (mean |anomaly − mean_normal| per machine type/ID).
- Generate masks with **current** vs **proposed** parameters (one band, several time segments of consecutive frames).
- Visualize masks at spectrogram size (n_mels × T) and resized to **q_shape** (q_top / q_bot spatial dims) used for codebook replacement.
- Recommend parameter ranges so synthetic masks better match the anomalies seen in real spectrograms for different machines.

## 1. Setup and data loading

Reuse the same data loading as `anomaly_spectrogram_study.ipynb`: the train split (for normalization statistics) and the test split, grouped by (machine_id, label).

In [None]:
import sys
from pathlib import Path
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Locate the project root: when the notebook runs from notebooks/, step up one
# level; otherwise treat the working directory as the root. Put it first on
# sys.path so `src.*` imports resolve.
_cwd = Path(".").resolve()
if _cwd.name == "notebooks":
    PROJECT_ROOT = _cwd.parent
else:
    PROJECT_ROOT = _cwd
sys.path.insert(0, str(PROJECT_ROOT))

from src.data.dataset import DCASE2020Task2LogMelDataset, DCASE2020Task2TestDataset
from src.utils.anomalies.anomaly_map import AudioSpecificStrategy

# Prefer <root>/data/...; fall back to a data/ directory that sits next to the root.
_primary_data_path = PROJECT_ROOT / "data" / "dcase2020-task2-dev-dataset"
_fallback_data_path = PROJECT_ROOT / "../data/dcase2020-task2-dev-dataset"
DATA_PATH = _primary_data_path if _primary_data_path.exists() else _fallback_data_path

MACHINE_TYPES = ["fan", "pump"]
n_mels = 128
T = 320
spectrogram_shape = (n_mels, T)

def get_q_shape(n_mels: int, T: int) -> tuple[int, int]:
    """Return the latent spatial shape (H, W) for an (n_mels, T) spectrogram.

    Mirrors a VQ-VAE with 4x downsampling on each axis: each side is
    floor-divided by 4 and clamped to at least 1.
    """
    height, width = (max(1, side // 4) for side in (n_mels, T))
    return (height, width)

# spectrogram_shape is (n_mels, T), so unpack it straight into get_q_shape.
q_shape = get_q_shape(*spectrogram_shape)
print(f"spectrogram_shape={spectrogram_shape}, q_shape={q_shape}")

In [None]:
def load_test_grouped(data_path: Path, machine_type: str):
    """Load the train set (for normalization stats) and the test set for one machine type.

    Test spectrograms are grouped as ``grouped[machine_id][label] -> list[np.ndarray]``,
    where each group starts as ``{0: [], 1: []}``. A leading channel dim is dropped
    when a test item is 3-D.

    Returns:
        (train_ds, test_ds, grouped, nm, Tm), where nm and Tm are the last two dims
        of ``train_ds.data.shape`` (which must be 4-D to unpack).
    """
    root = str(data_path)
    train_ds = DCASE2020Task2LogMelDataset(
        root=root, machine_type=machine_type, normalize=True,
    )
    _n, _c, nm, Tm = train_ds.data.shape
    # The test set reuses the train-set statistics and target length.
    test_ds = DCASE2020Task2TestDataset(
        root=root,
        machine_type=machine_type,
        mean=train_ds.mean,
        std=train_ds.std,
        target_T=train_ds.target_T,
    )
    grouped = defaultdict(lambda: {0: [], 1: []})
    for i in range(len(test_ds)):
        spec, label, machine_id = test_ds[i]
        spec = spec.squeeze(0) if spec.dim() == 3 else spec
        grouped[machine_id][label].append(spec.numpy())
    return train_ds, test_ds, grouped, nm, Tm

def stacked_by_id_label(grouped):
    """Stack each (machine_id, label) list of 2-D spectrograms into one 3-D array.

    Labels 0 and 1 are always present in the output; an empty list becomes an
    empty array of shape (0, 0, 0).
    """
    def _stack(items):
        # np.stack on an empty list would raise, so substitute a 3-D empty array.
        return np.stack(items) if items else np.empty((0, 0, 0))

    return {
        mid: {lab: _stack(by_label[lab]) for lab in (0, 1)}
        for mid, by_label in grouped.items()
    }

# Load every machine type that exists on disk; a missing directory is reported and skipped.
data_by_type = {}
for mt in MACHINE_TYPES:
    try:
        train_ds, test_ds, grouped, nm, Tm = load_test_grouped(DATA_PATH, mt)
    except FileNotFoundError as e:
        print(f"Skip {mt}: {e}")
        continue
    data_by_type[mt] = dict(
        train_ds=train_ds,
        test_ds=test_ds,
        grouped=grouped,
        stacked=stacked_by_id_label(grouped),
        n_mels=nm,
        T=Tm,
    )
    print(f"{mt}: n_mels={nm}, T={Tm}, IDs={sorted(grouped.keys())}")

## 2. Real anomaly spatial localization

For each (machine_type, machine_id): mean normal spectrogram, mean anomalous, and **dev** = mean over anomalous samples of |anomaly − mean_normal|. This shows which (mel, time) regions differ most in real data.

In [None]:
def compute_spatial_localization(stacked: dict, n_mels: int, T: int):
    """Per machine_id, summarize where anomalies deviate from the normal mean.

    For each ID with both normal (label 0) and anomalous (label 1) samples,
    returns ``mean_norm``, ``mean_anom``, ``diff = mean_anom - mean_norm``,
    ``dev = mean_i |anom_i - mean_norm|`` and the two sample counts. IDs
    missing either label are omitted. ``n_mels`` and ``T`` are accepted but
    not used.
    """
    result = {}
    for mid, arrs in stacked.items():
        normals, anomalies = arrs[0], arrs[1]
        if not (normals.size and anomalies.size):
            continue
        mu_norm = normals.mean(axis=0)
        mu_anom = anomalies.mean(axis=0)
        result[mid] = dict(
            mean_norm=mu_norm,
            mean_anom=mu_anom,
            diff=mu_anom - mu_norm,
            # Per-sample absolute deviation, averaged over anomalous samples.
            dev=np.abs(anomalies - mu_norm).mean(axis=0),
            n_normal=normals.shape[0],
            n_anomaly=anomalies.shape[0],
        )
    return result

# Spatial localization for every machine type that was actually loaded.
real_localization = {
    mt: compute_spatial_localization(
        data_by_type[mt]["stacked"], data_by_type[mt]["n_mels"], data_by_type[mt]["T"]
    )
    for mt in MACHINE_TYPES
    if mt in data_by_type
}

def print_deviation_summary(real_localization, n_mels, T):
    """Print where each machine ID's deviation map ``dev`` peaks.

    For every machine type and ID in ``real_localization`` (as returned per
    type by ``compute_spatial_localization``), print:
    - the (mel, time) index of the maximum of ``dev``;
    - that index as a fraction of the map's size;
    - the min/max of ``diff``.

    Args:
        real_localization: ``{machine_type: {machine_id: {"dev": 2-D array,
            "diff": array, ...}}}``.
        n_mels, T: kept for backward compatibility only. Fractions are computed
            from each ``dev``'s own shape, so they stay correct when the
            loaded spectrograms differ from the notebook's global sizes.
    """
    for mt, by_id in real_localization.items():
        print(f"## {mt}")
        for mid, data in by_id.items():
            dev = data["dev"]
            if dev.size == 0:
                # np.argmax raises on empty input; report instead of crashing.
                print(f"  {mid}: empty deviation map, skipped")
                continue
            peak = np.unravel_index(np.argmax(dev), dev.shape)
            mel_peak, time_peak = peak[0], peak[1]
            band_frac = mel_peak / max(1, dev.shape[0])
            time_frac = time_peak / max(1, dev.shape[1])
            diff = data["diff"]
            print(f"  {mid}: peak at mel≈{mel_peak} ({band_frac:.2f}), time≈{time_peak} ({time_frac:.2f}); diff range [{diff.min():.3f}, {diff.max():.3f}]")

print_deviation_summary(real_localization, n_mels=n_mels, T=T)

In [None]:
def plot_real_localization(real_localization, machine_type: str, machine_id: str, n_mels: int, T: int):
    """Plot mean normal, mean anomalous and deviation maps for one machine ID.

    Does nothing (returns None) when the type or ID has no localization data.
    ``n_mels`` and ``T`` are accepted but not used.
    """
    by_id = real_localization.get(machine_type, {}) if machine_type in real_localization else {}
    if machine_id not in by_id:
        return
    data = by_id[machine_id]
    # (image, colormap, title) for each of the three side-by-side panels.
    panels = [
        (data["mean_norm"], "magma", f"Mean normal ({machine_type} / {machine_id})"),
        (data["mean_anom"], "magma", "Mean anomaly"),
        (data["dev"], "viridis", "Spatial localization: mean |anomaly − mean_norm|"),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (image, cmap, title) in zip(axes, panels):
        ax.imshow(image, aspect="auto", origin="lower", cmap=cmap)
        ax.set_title(title)
        ax.set_xlabel("time frame")
    axes[0].set_ylabel("mel bin")
    plt.tight_layout()
    plt.show()

# Plot the first two machine IDs (sorted) of every machine type that has data.
for mt in (t for t in MACHINE_TYPES if t in real_localization):
    first_ids = sorted(real_localization[mt])[:2]
    for mid in first_ids:
        plot_real_localization(real_localization, mt, mid, n_mels, T)

## 3. AudioSpecificStrategy: current vs proposed parameters

- **Current**: one band (5–15% of n_mels), many segments (10–160), segment length up to T//4. This can produce very dense or scattered masks.
- **Proposed (paper-like)**: **one band only**, **fewer time segments** (e.g. 3–20), each segment is a **consecutive block** of frames (random start, random length). The chosen band and segments are marked in M; M is then resized to q_shape for codebook replacement in q_top / q_bot.

We generate masks at **spectrogram resolution** (n_mels × T) for comparison with real localization, then show the **resized** version (q_shape) used in training.

In [None]:
def generate_audio_specific_mask_spectrogram_size(
    n_mels: int, T: int,
    min_band_frac: float, max_band_frac: float,
    min_segments: int, max_segments: int,
    min_seg_len: int, max_seg_len: int,
    seed: int | None = None,
) -> np.ndarray:
    """
    Generate one mask at spectrogram size (n_mels, T).
    One band; several time segments of consecutive frames (random start, random length).
    """
    if seed is not None:
        np.random.seed(seed)
        import random
        random.seed(seed)
    M = np.zeros((n_mels, T), dtype=np.float32)
    band_width = int(n_mels * np.random.uniform(min_band_frac, max_band_frac))
    band_width = max(1, band_width)
    f_low = np.random.randint(0, max(1, n_mels - band_width))
    f_high = min(f_low + band_width, n_mels)
    n_seg = np.random.randint(min_segments, max_segments + 1)
    for _ in range(n_seg):
        seg_len = np.random.randint(min_seg_len, min(max_seg_len + 1, T))
        seg_len = max(1, seg_len)
        t_start = np.random.randint(0, max(1, T - seg_len))
        t_end = min(t_start + seg_len, T)
        M[f_low:f_high, t_start:t_end] = 1.0
    return M

# Parameter presets. Each maps onto generate_audio_specific_mask_spectrogram_size:
# band-width fractions of n_mels, segment-count range, segment-length range (frames).
CURRENT_PARAMS = dict(
    min_band_fraction=0.05,
    max_band_fraction=0.15,
    min_segments=10,
    max_segments=160,
    min_seg_len=1,
    max_seg_len=max(1, T // 4),  # evaluated with the notebook's global T
)
PROPOSED_PARAMS = dict(
    min_band_fraction=0.05,
    max_band_fraction=0.20,
    min_segments=3,
    max_segments=20,
    min_seg_len=5,
    max_seg_len=60,
)
NARROW_BAND_FEW_SEGMENTS = dict(
    min_band_fraction=0.03,
    max_band_fraction=0.10,
    min_segments=2,
    max_segments=12,
    min_seg_len=10,
    max_seg_len=80,
)

In [None]:
num_samples = 6
presets = {
    "current (many segments)": CURRENT_PARAMS,
    "proposed (few segments, consecutive)": PROPOSED_PARAMS,
    "narrow band, few segments": NARROW_BAND_FEW_SEGMENTS,
}

# squeeze=False keeps `axes` 2-D for any number of presets or samples
# (including a single row or a single column).
fig, axes = plt.subplots(
    len(presets), num_samples, figsize=(14, 3 * len(presets)), squeeze=False
)
for i, (name, params) in enumerate(presets.items()):
    for j in range(num_samples):
        # Unseeded: a fresh set of masks on every run of this cell.
        M = generate_audio_specific_mask_spectrogram_size(
            n_mels, T,
            params["min_band_fraction"], params["max_band_fraction"],
            params["min_segments"], params["max_segments"],
            params["min_seg_len"], params["max_seg_len"],
            seed=None,
        )
        ax = axes[i, j]
        ax.imshow(M, aspect="auto", origin="lower", cmap="gray", vmin=0, vmax=1)
        ax.set_xlabel("time")
        if j == 0:
            ax.set_ylabel(f"{name}\nmel")
        if i == 0:
            ax.set_title(f"Sample {j+1}")
plt.suptitle("Audio-specific masks at spectrogram size (n_mels × T)", y=1.02)
plt.tight_layout()
plt.show()

## 4. Same masks resized to q_shape (used in training)

The model resizes M to q_shape (H_q × W_q) before applying codebook replacement to q_bot and q_top. Here we show the same parameter presets after resize to q_shape.

In [None]:
def resize_mask_to_q(M: np.ndarray, q_shape: tuple[int, int]) -> np.ndarray:
    """Nearest-neighbour resize of a mask to ``q_shape``.

    Args:
        M: mask of shape (n_mels, T) or 4-D (B, C, n_mels, T), typically
            (1, 1, n_mels, T).
        q_shape: target spatial size (H_q, W_q).

    Returns:
        The resized mask with leading size-1 dims removed. This gives
        (H_q, W_q) for 2-D or (1, 1, ...) input. The spatial dims are always
        kept, so q_shape entries equal to 1 survive.

    Raises:
        ValueError: if M is neither 2-D nor 4-D.
    """
    t = torch.from_numpy(np.ascontiguousarray(M)).float()
    if t.dim() == 2:
        t = t.unsqueeze(0).unsqueeze(0)
    elif t.dim() != 4:
        raise ValueError(f"expected a 2-D or 4-D mask, got shape {tuple(t.shape)}")
    t = F.interpolate(t, size=tuple(q_shape), mode="nearest")
    # Drop only leading singleton dims. A bare squeeze() would also collapse
    # H_q or W_q when it equals 1.
    lead = [d for d in t.shape[:-2] if d != 1]
    return t.reshape(*lead, *t.shape[-2:]).numpy()

# squeeze=False keeps `axes` 2-D for any number of presets or samples.
fig, axes = plt.subplots(
    len(presets), num_samples, figsize=(14, 3 * len(presets)), squeeze=False
)
for i, (name, params) in enumerate(presets.items()):
    for j in range(num_samples):
        # Same seeds in every row, so the presets are compared on equal RNG draws.
        M = generate_audio_specific_mask_spectrogram_size(
            n_mels, T,
            params["min_band_fraction"], params["max_band_fraction"],
            params["min_segments"], params["max_segments"],
            params["min_seg_len"], params["max_seg_len"],
            seed=j + 42,
        )
        M_q = resize_mask_to_q(M, q_shape)
        ax = axes[i, j]
        ax.imshow(M_q, aspect="auto", origin="lower", cmap="gray", vmin=0, vmax=1)
        ax.set_xlabel("time (q)")
        if j == 0:
            ax.set_ylabel(f"{name}\nfreq (q)")
        if i == 0:
            ax.set_title(f"q_shape {q_shape}")
plt.suptitle("Masks resized to q_shape (q_top / q_bot spatial dims)", y=1.02)
plt.tight_layout()
plt.show()

## 5. Side-by-side: real localization vs synthetic masks

Compare one real spatial localization map (dev) with synthetic masks from the three presets for a chosen machine type / ID.

In [None]:
machine_type = "fan"
machine_id = "id_00"
n_synth = 4  # synthetic masks (columns) per preset
if machine_type in real_localization and machine_id in real_localization[machine_type]:
    dev = real_localization[machine_type][machine_id]["dev"]
    # Draw synthetic masks at the real map's own resolution so both rows share
    # the same (mel, time) axes, even if the data differs from (n_mels, T).
    dev_mels, dev_T = dev.shape
    n_rows = 1 + len(presets)
    fig, axes = plt.subplots(n_rows, n_synth, figsize=(14, 3 * n_rows), squeeze=False)
    axes[0, 0].imshow(dev, aspect="auto", origin="lower", cmap="viridis")
    axes[0, 0].set_title(f"Real: spatial localization\n({machine_type} / {machine_id})")
    axes[0, 0].set_ylabel("mel")
    for k in range(1, n_synth):
        axes[0, k].axis("off")
    for i, (name, params) in enumerate(presets.items()):
        for j in range(n_synth):
            M = generate_audio_specific_mask_spectrogram_size(
                dev_mels, dev_T,
                params["min_band_fraction"], params["max_band_fraction"],
                params["min_segments"], params["max_segments"],
                params["min_seg_len"], params["max_seg_len"],
                seed=j + 100 * (i + 1),
            )
            ax = axes[i + 1, j]
            ax.imshow(M, aspect="auto", origin="lower", cmap="gray", vmin=0, vmax=1)
            ax.set_xlabel("time")
            if j == 0:
                ax.set_ylabel(f"{name}\nmel")
            # Title every column of the first synthetic row, including the first one.
            if i == 0:
                ax.set_title(f"Synthetic sample {j+1}")
    plt.suptitle("Real anomaly localization vs synthetic audio-specific masks", y=1.02)
    plt.tight_layout()
    plt.show()
else:
    print(f"No data for {machine_type} / {machine_id}")

## 6. Using the actual AudioSpecificStrategy class

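Verify that the strategy class (with optional resizing to q_shape) produces masks consistent with the presets' band and segment-count settings. The constructor here is not given segment-length bounds, so segment lengths follow whatever rule the class itself implements. The dataset can pass spectrogram_shape as q_shape so the mask stays at (n_mels, T); the model then resizes to q in `forward_train`.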

In [None]:
device = torch.device("cpu")

# Shapes shared by both strategies.
_strategy_shape_kwargs = dict(
    spectrogram_shape=spectrogram_shape,
    q_shape=q_shape,
    n_mels=n_mels,
    T=T,
)
# Band / segment-count bounds are read from the presets above, so the class is
# exercised with exactly the values the synthetic-mask cells used.
# NOTE(review): segment-length bounds are not passed; confirm whether the
# constructor accepts min_seg_len / max_seg_len.
_strategy_keys = ("min_band_fraction", "max_band_fraction", "min_segments", "max_segments")

strategy_current = AudioSpecificStrategy(
    **_strategy_shape_kwargs,
    **{k: CURRENT_PARAMS[k] for k in _strategy_keys},
)

strategy_proposed = AudioSpecificStrategy(
    **_strategy_shape_kwargs,
    **{k: PROPOSED_PARAMS[k] for k in _strategy_keys},
)

M_current = strategy_current(4, device)
M_proposed = strategy_proposed(4, device)
print(f"Current strategy output shape: {M_current.shape}")  # expected (4, 1, H_q, W_q) — confirm against the class
print(f"Proposed strategy output shape: {M_proposed.shape}")

In [None]:
fig, axes = plt.subplots(2, 4, figsize=(14, 6))
# One row per strategy: (title prefix, y-axis label, batch of masks).
_strategy_rows = (
    ("Current", "current\nfreq (q)", M_current),
    ("Proposed", "proposed\nfreq (q)", M_proposed),
)
for row_axes, (label, ylabel, masks) in zip(axes, _strategy_rows):
    for j, ax in enumerate(row_axes):
        ax.imshow(masks[j, 0].numpy(), aspect="auto", origin="lower", cmap="gray", vmin=0, vmax=1)
        ax.set_title(f"{label} (q_shape) {j+1}")
        ax.set_xlabel("time")
    row_axes[0].set_ylabel(ylabel)
plt.suptitle("AudioSpecificStrategy outputs (resized to q_shape)", y=1.02)
plt.tight_layout()
plt.show()

## 7. Summary and recommendations

- **Real anomalies** (from the study): often localized in a specific frequency band and time region; deviation peak and extent vary by machine type and ID.
- **Current AudioSpecificStrategy**: many segments (10–160), each up to T//4 frames long, can fill the band densely or scatter across it. This may not match the real “one band + a few distinct segments” structure.
- **Proposed**: **one band** (already the case), **fewer segments** (e.g. 3–20), **consecutive frames** per segment with min/max length (e.g. 5–60 or 10–80 frames). This better matches “anomalies in a band at several time intervals.”
- **Mask flow**: M is generated at spectrogram size (n_mels × T), then resized to **q_shape** (H_q × W_q) for codebook replacement in q_bot and q_top. The dataset can return M at spectrogram size; the model resizes to q in `forward_train`.

To align with the paper and real data, consider updating the `AudioSpecificStrategy` defaults to something like `min_segments=3` and `max_segments=20`, with explicit segment-length bounds (e.g. `min_seg_len=5`, `max_seg_len=60`) instead of a single `T//4` cap. This requires `min_seg_len` / `max_seg_len` constructor parameters; add them if the strategy does not already accept them (section 6 does not pass them).