# Inference comparison: Normal vs synthetic vs real anomalous

Load a **normal** spectrogram, a **synthetically generated anomalous** sample (normal + synthetic mask through `forward_train`), and a **real anomalous** sample from the test set; feed them through the trained stage-2 (sDSR) model and plot all inference steps for comparison.

Requires: stage1 checkpoint, stage2 checkpoint (e.g. fan), and DCASE2020 Task 2 dev data.

In [None]:
# Setup and paths
import sys
from pathlib import Path

# Resolve the project root: when the notebook is launched from inside a
# "notebooks" directory, step up one level; otherwise use the working directory.
_cwd = Path(".").resolve()
if _cwd.name == "notebooks":
    PROJECT_ROOT = _cwd.parent
else:
    PROJECT_ROOT = _cwd
# Put the root first on sys.path so project-local packages take precedence.
sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import torch
import matplotlib.pyplot as plt

# Machine type under study; it selects both the per-machine checkpoint
# directories and the per-machine checkpoint file names below.
MACHINE_TYPE = "fan"
DATA_PATH = PROJECT_ROOT / "../data/dcase2020-task2-dev-dataset"  # adjust to your DCASE root
CKPT_DIR = PROJECT_ROOT / "checkpoints"

# Stage1: prefer the multi-machine checkpoint; fall back to the single-machine one.
_MULTI_MACHINE_TAG = "ToyCar+ToyConveyor+fan+pump+slider+valve"
STAGE1_CKPT = CKPT_DIR / "stage1" / _MULTI_MACHINE_TAG / f"stage1_{_MULTI_MACHINE_TAG}_final.pt"
if not STAGE1_CKPT.exists():
    # Derive the file name from MACHINE_TYPE (previously hard-coded to "fan") so
    # that directory and file name stay in sync when MACHINE_TYPE is changed.
    # NOTE(review): assumes checkpoints are named stage1_<machine>_best.pt — confirm.
    STAGE1_CKPT = CKPT_DIR / "stage1" / MACHINE_TYPE / f"stage1_{MACHINE_TYPE}_best.pt"
STAGE2_CKPT = CKPT_DIR / "stage2" / MACHINE_TYPE / f"stage2_{MACHINE_TYPE}_best.pt"

# Use the GPU when one is available; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Report the resolved configuration so missing files are noticed before loading.
print(f"Device: {DEVICE}")
print(f"Stage1 exists: {STAGE1_CKPT.exists()}")
print(f"Stage2 exists: {STAGE2_CKPT.exists()}")
print(f"Data path exists: {DATA_PATH.exists()}")

## 1. Datasets and model loading

The train dataset (fan only) defines the normalization: mean and std are computed per `machine_type`. The test dataset is given these statistics, so the spectrograms returned by `test_ds[i]` are already normalized with fan's statistics — consistent with the evaluator and with training.

In [None]:
from src.data.dataset import DCASE2020Task2LogMelDataset, DCASE2020Task2TestDataset
from src.models.vq_vae.autoencoders import VQ_VAE_2Layer
from src.models.sDSR.s_dsr import sDSR, sDSRConfig

# Arguments shared by the train and test dataset constructors.
_shared_ds_kwargs = {"root": str(DATA_PATH), "machine_type": MACHINE_TYPE}

# Normalized training split for the selected machine type; its mean/std/target_T
# are handed to the test split below.
train_ds = DCASE2020Task2LogMelDataset(normalize=True, **_shared_ds_kwargs)

# train_ds.data is unpacked as 4-D; only the last two sizes are used here.
_dim0, _dim1, n_mels, T = train_ds.data.shape

# test_ds uses fan train mean/std so __getitem__ returns normalized spectrograms
test_ds = DCASE2020Task2TestDataset(
    mean=train_ds.mean,
    std=train_ds.std,
    target_T=train_ds.target_T,
    **_shared_ds_kwargs,
)

def _load_state(module, ckpt_path):
    """Load the ``model_state_dict`` entry of *ckpt_path* into *module*.

    The checkpoint is read onto the CPU with ``weights_only=True``; the loaded
    checkpoint object is returned so the caller can keep a reference to it.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    module.load_state_dict(ckpt["model_state_dict"])
    return ckpt


# Stage-1 VQ-VAE hyper-parameters.
# NOTE(review): presumably these must match the values used to train the checkpoint — confirm.
_vq_vae_kwargs = {
    "num_hiddens": 128,
    "num_residual_layers": 2,
    "num_residual_hiddens": 64,
    "num_embeddings": (1024, 4096),
    "embedding_dim": 128,
    "commitment_cost": 0.25,
    "decay": 0.99,
}
vq_vae = VQ_VAE_2Layer(**_vq_vae_kwargs)
stage1 = _load_state(vq_vae, STAGE1_CKPT)

# Stage-2 model wraps the VQ-VAE; its checkpoint is loaded into the full model.
cfg = sDSRConfig(embedding_dim=128, num_hiddens=128, n_mels=n_mels, T=T)
model = sDSR(vq_vae, cfg)
stage2 = _load_state(model, STAGE2_CKPT)

# Move to the target device and switch to evaluation mode.
model = model.to(DEVICE)
model.eval()

print(f"n_mels={n_mels}, T={T}")
print(f"Test set: {len(test_ds)} clips")

## 2. Sample selection

Pick one **normal** and one **real anomalous** clip from the test set (the first of each label). The **synthetic anomalous** run reuses the same normal clip as input; the synthetic mask is applied to it via `forward_train`.

In [None]:
# Find the first normal (label 0) and the first anomalous (label 1) clip by
# scanning the sample metadata only, stopping as soon as both are known.
normal_idx = None
anomalous_idx = None
for i in range(len(test_ds)):
    # Each metadata entry is a 3-tuple; only its second element (the label) is used here.
    _, label, _ = test_ds.samples[i]
    if label == 0 and normal_idx is None:
        normal_idx = i
    if label == 1 and anomalous_idx is None:
        anomalous_idx = i
    if normal_idx is not None and anomalous_idx is not None:
        break

# Fail with a clear message instead of an obscure indexing error on test_ds[None].
if normal_idx is None or anomalous_idx is None:
    raise ValueError(
        "Test set must contain at least one normal (label 0) and one anomalous (label 1) clip; "
        f"found normal_idx={normal_idx}, anomalous_idx={anomalous_idx}"
    )

spec_normal, label_n, mid_n = test_ds[normal_idx]
spec_anomalous, label_a, mid_a = test_ds[anomalous_idx]

# Batch of 1: (1, 1, n_mels, T); test returns (1, n_mels, T) so one unsqueeze(0)
x_normal = spec_normal.unsqueeze(0).to(DEVICE)
x_anomalous = spec_anomalous.unsqueeze(0).to(DEVICE)

# For synthetic: same normal input + synthetic mask (generated below)
x_synthetic_input = x_normal.clone()

print(f"Normal: idx={normal_idx}, label={label_n}, machine_id={mid_n}, shape={x_normal.shape}")
print(f"Real anomalous: idx={anomalous_idx}, label={label_a}, machine_id={mid_a}, shape={x_anomalous.shape}")

## 3. Synthetic anomaly mask

Generate one synthetic anomaly mask (same strategy as stage-2 training) and format it as `(1, 1, n_mels, T)` for `forward_train`.

In [None]:
from src.utils.anomalies import AnomalyMapGenerator

# The spectrogram shape and the q_shape are both given as (n_mels, T).
_map_shape = (n_mels, T)
_mask_gen_kwargs = {
    "strategy": "both",
    "spectrogram_shape": _map_shape,
    "q_shape": _map_shape,
    "n_mels": n_mels,
    "T": T,
    "zero_mask_prob": 0.0,
}
mask_gen = AnomalyMapGenerator(**_mask_gen_kwargs)

# Generate a single mask on the CPU, then move it to the model's device.
M_synth = mask_gen.generate(1, device="cpu", force_anomaly=True).to(DEVICE)  # (1, 1, n_mels, T)
print(f"Synthetic mask shape: {M_synth.shape}")

## 4. Run inference

- **Normal** and **real anomalous**: `forward(x, return_intermediates=True)` → M_out, X_G, X_S.
- **Synthetic anomalous**: `forward_train(x_normal, M_synth)` → same outputs with codebook-replaced codes.

In [None]:
# Gradients are not needed for visualisation, so all three passes run under no_grad.
with torch.no_grad():
    # Inference path on the normal clip.
    m_out_n, x_g_n, x_s_n = model(x_normal, return_intermediates=True)

    # Identical inference path on the real anomalous clip.
    m_out_a, x_g_a, x_s_a = model(x_anomalous, return_intermediates=True)

    # Training path with the synthetic mask (the model itself stays in eval mode).
    out_synth = model.forward_train(x_synthetic_input, M_synth)

# forward_train returns a mapping; pull out the same three outputs as above.
m_out_s, x_g_s, x_s_s = (out_synth[name] for name in ("m_out", "x_g", "x_s"))

# Helper used to prepare tensors for plotting
def to_np(t):
    """Return tensor *t* as a NumPy array on the CPU with all size-1 axes removed."""
    on_cpu = t.detach().cpu()
    return on_cpu.squeeze().numpy()

# Sample labels; every dict below is keyed by these, in this order.
_sample_labels = ("Normal", "Synthetic anomalous", "Real anomalous")

# NumPy copies of the model inputs and outputs, one entry per sample.
inputs = dict(zip(_sample_labels, map(to_np, (x_normal, x_synthetic_input, x_anomalous))))
x_g = dict(zip(_sample_labels, map(to_np, (x_g_n, x_g_s, x_g_a))))
x_s = dict(zip(_sample_labels, map(to_np, (x_s_n, x_s_s, x_s_a))))
m_out = dict(zip(_sample_labels, map(to_np, (m_out_n, m_out_s, m_out_a))))

# Anomaly logit is channel 1 (taken only when the squeezed map is still 3-D)
anomaly_logit = {
    name: (logits[1] if logits.ndim == 3 else logits)
    for name, logits in m_out.items()
}
# Element-wise absolute difference between the two reconstructions.
diff_xg_xs = {name: np.abs(recon_g - x_s[name]) for name, recon_g in x_g.items()}

## 5. Figures: compare all steps

For each sample type (normal, synthetic anomalous, real anomalous), plot: **Input**, **X_G** (general decoder), **X_S** (object-specific decoder), **|X_G − X_S|**, and **anomaly score map** (M_out channel 1).

In [None]:
# Columns are the sample types; their titles are the dict keys themselves, so
# titles and data cannot drift apart.
keys = list(inputs.keys())
titles = keys

# (row label, per-sample arrays, colormap) for each figure row, top to bottom.
rows = [
    ("Input X", inputs, "magma"),
    ("X_G", x_g, "magma"),
    ("X_S", x_s, "magma"),
    ("|X_G − X_S|", diff_xg_xs, "hot"),
    ("Anomaly logit", anomaly_logit, "viridis"),
]

fig, axes = plt.subplots(len(rows), len(keys), figsize=(12, 14))
# The title follows MACHINE_TYPE instead of a hard-coded machine name.
fig.suptitle(f"Inference steps: Normal vs Synthetic vs Real anomalous ({MACHINE_TYPE})", fontsize=12)

for row, (row_label, arrays, cmap) in enumerate(rows):
    # Share one colour range across the row: with per-panel autoscaling the same
    # colour would mean different values in each column, which defeats a
    # side-by-side comparison. NaNs are ignored when computing the range.
    vmin = min(float(np.nanmin(arrays[key])) for key in keys)
    vmax = max(float(np.nanmax(arrays[key])) for key in keys)
    for col, key in enumerate(keys):
        ax = axes[row, col]
        ax.imshow(arrays[key], aspect="auto", origin="lower", cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_ylabel(row_label)
        if row == 0:
            # Only the top row carries the column titles.
            ax.set_title(titles[col])
        # Tick values are not informative here; hide them on every panel.
        ax.set_xticks([])
        ax.set_yticks([])

plt.tight_layout()
plt.show()

## 6. Optional: plot synthetic mask

Visualize the synthetic anomaly mask used for the middle column.

In [None]:
# to_np already removes the size-1 axes, so the result is plotted as-is.
mask_img = to_np(M_synth)

# Single wide panel for the mask that produced the middle column.
fig, ax = plt.subplots(figsize=(8, 3))
ax.imshow(mask_img, aspect="auto", origin="lower", cmap="gray")
ax.set_xlabel("Time")
ax.set_ylabel("Mel")
ax.set_title("Synthetic anomaly mask M (used for forward_train)")
plt.tight_layout()
plt.show()