# NPZ Error Analysis

This notebook helps you inspect NPZ segments and find where the model performs poorly.

Workflow:
1. Configure paths and labels.
2. Build train/val splits and dataset stats.
3. Load a checkpoint and run inference.
4. Inspect misclassified and low-confidence samples.
5. Visualize raw audio and sensor channels.


In [None]:
from pathlib import Path
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import torchvision.models as models
import matplotlib.pyplot as plt

from src.ReadSegments import ReadSegments, find_segment_paths, read_labels_from_segments
from util.label_processor import LabelProcessor, DropLabel

try:
    import pandas as pd
except ImportError:
    pd = None


In [None]:
# Root directory that is searched for NPZ segment files.
DATA_ROOT = Path("data/MMDataset_segments_first5")
# Checkpoint to evaluate; the loading cell only prints a message if it is missing.
CKPT_PATH = Path("checkpoints/best_multimodal.pt")

# Sample rate (Hz) handed to ReadSegments and to the model's mel front end.
TARGET_SR = 22050
# Upper frequency bound (Hz) passed as f_max to the mel spectrograms.
AUDIO_FMAX = 4000
# Fraction of recording groups assigned to the training split.
TRAIN_RATIO = 0.6
# Seed for the group-level shuffle so the split is reproducible.
SPLIT_SEED = 42
# Batch size used by both data loaders.
BATCH_SIZE = 32

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)


In [None]:
# Collapse the raw annotation strings into two classes:
# "no_secretion" and "secretion".
# NOTE(review): fail_on_unknown=True presumably makes LabelProcessor raise on
# any raw label missing from this table — confirm in util.label_processor.
label_processor = LabelProcessor(
    raw_to_norm={
        "no secretion": "no_secretion",
        "no secretion sound": "no_secretion",
        "no secretion sound (with hemf)": "no_secretion",
        "3ml secretion": "secretion",
        "3ml secretion m4": "secretion",
        "5ml secretion m4": "secretion",
        "5ml secretion": "secretion",
        "3ml secretion (with hemf)": "secretion",
    },
    fail_on_unknown=True,
)

def group_key_from_path(path: str) -> str:
    """Return the recording-group key encoded in a segment file name.

    The key is the part of the base name before the first ``"_win"`` marker.
    File names without that marker use the base name minus its last extension.
    """
    filename = os.path.basename(path)
    prefix, marker, _ = filename.partition("_win")
    if marker:
        # Windowed segment: every window of one recording shares this prefix.
        return prefix
    stem, _ext = os.path.splitext(filename)
    return stem

def infer_group_label(group_paths, normalizer):
    """Return the single normalized label shared by every file in a group.

    Args:
        group_paths: Iterable of ``.npz`` paths belonging to one group.
        normalizer: Callable mapping a raw label string to its normalized
            form. Exceptions it raises propagate to the caller.

    Returns:
        The normalized label common to all files.

    Raises:
        KeyError: If a file has no ``"label"`` entry.
        ValueError: If files disagree on the normalized label, or if
            ``group_paths`` is empty.
    """
    norm_label = None
    for p in group_paths:
        # NOTE(security): allow_pickle=True can execute arbitrary code when a
        # file comes from an untrusted source; only load trusted segments.
        # The context manager closes the archive so long scans do not leak
        # one file handle per segment.
        with np.load(p, allow_pickle=True) as d:
            if "label" not in d.files:
                raise KeyError(f"Missing 'label' field in {p}")
            lab = d["label"]
        if isinstance(lab, np.ndarray) and lab.shape == ():
            # Labels saved as 0-d arrays are unwrapped to a Python scalar.
            lab = lab.item()
        current = normalizer(str(lab))
        if norm_label is None:
            norm_label = current
        elif norm_label != current:
            raise ValueError(f"Mixed labels inside group {p}: '{norm_label}' vs '{current}'")
    if norm_label is None:
        raise ValueError("Group contained no labels; cannot infer class")
    return norm_label

def split_paths_by_group(paths, train_ratio=0.6, seed=42, label_normalizer=None):
    """Split segment paths into train/val lists without splitting a group.

    Paths are bucketed by ``group_key_from_path`` and whole groups are assigned
    to one side. With ``label_normalizer`` set, the split is stratified: groups
    are divided per normalized label, and groups whose label raises
    ``DropLabel`` are left out. Whenever a pool has more than one group, both
    sides receive at least one group.

    Returns:
        ``(train_paths, val_paths)``; paths inside each group are sorted.
    """
    by_group = {}
    for path in paths:
        by_group.setdefault(group_key_from_path(path), []).append(path)

    shuffler = random.Random(seed)

    def _shuffle_and_cut(candidates):
        # Sort first so the shuffle depends only on the seed, not input order.
        ordered = sorted(candidates)
        shuffler.shuffle(ordered)
        cut = int(train_ratio * len(ordered))
        if len(ordered) > 1:
            # Keep at least one group on each side of the cut.
            cut = min(max(cut, 1), len(ordered) - 1)
        return ordered[:cut], ordered[cut:]

    if label_normalizer is None:
        train_keys, val_keys = _shuffle_and_cut(by_group)
    else:
        keys_by_label = {}
        for key, members in by_group.items():
            try:
                label = infer_group_label(members, label_normalizer)
            except DropLabel:
                continue
            keys_by_label.setdefault(label, []).append(key)

        train_keys, val_keys = [], []
        for label in sorted(keys_by_label):
            head, tail = _shuffle_and_cut(keys_by_label[label])
            train_keys += head
            val_keys += tail
        # Mix the classes so neither list is ordered by label.
        shuffler.shuffle(train_keys)
        shuffler.shuffle(val_keys)

    train_paths = []
    for key in train_keys:
        train_paths.extend(sorted(by_group[key]))
    val_paths = []
    for key in val_keys:
        val_paths.extend(sorted(by_group[key]))
    return train_paths, val_paths


In [None]:
# Discover every segment file under DATA_ROOT; sorting keeps later steps deterministic.
all_paths = sorted(find_segment_paths(str(DATA_ROOT)))
print("Total npz files:", len(all_paths))
if not all_paths:
    # Stop early: every later cell needs at least one segment.
    raise RuntimeError("No npz files found. Update DATA_ROOT.")

# Print a summary of the raw labels before splitting.
raw_labels = read_labels_from_segments(all_paths)
print(label_processor.summarize_counts(raw_labels))

# Group-aware split, stratified by normalized label because a normalizer is passed.
train_paths, val_paths = split_paths_by_group(
    all_paths, train_ratio=TRAIN_RATIO, seed=SPLIT_SEED, label_normalizer=label_processor
)
print("Train segments:", len(train_paths))
print("Val segments:", len(val_paths))
print("Total groups:", len({group_key_from_path(p) for p in all_paths}))


In [None]:
def collate_fn(batch):
    """Collate dataset items into zero-padded batch tensors.

    Each item is a dict with ``"audio"``, ``"P"``, ``"Q"``, ``"path"`` and
    ``"label_id"``; ``"raw_label"`` and ``"norm_label"`` are optional.
    Audio and sensor sequences are right-padded with zeros to the longest
    sequence in the batch. When ``P`` or ``Q`` is ``None`` the sensor sequence
    is an all-zero ``(audio_length, 2)`` tensor. Otherwise only the first
    column of each of ``P`` and ``Q`` is kept.

    Returns:
        Dict with ``audio``, ``audio_lengths``, ``sensor``, ``sensor_lengths``,
        ``label``, ``paths``, ``raw_label`` and ``norm_label``.
    """
    # ---- audio -----------------------------------------------------------
    waveforms = [item["audio"] for item in batch]
    audio_lengths = [wave.shape[0] for wave in waveforms]
    longest_audio = max(audio_lengths)
    n_items = len(batch)

    audio = torch.zeros(n_items, longest_audio, dtype=torch.float32)
    for row, (wave, n_samples) in enumerate(zip(waveforms, audio_lengths)):
        audio[row, :n_samples] = wave

    # ---- sensors ---------------------------------------------------------
    def _first_column(channel):
        # Promote 1-D signals to (T, 1), then keep only the leading column.
        as_2d = channel.unsqueeze(-1) if channel.dim() == 1 else channel
        return as_2d[:, 0:1]

    def _sensor_for(item):
        pressure, flow = item["P"], item["Q"]
        if pressure is None or flow is None:
            # Missing sensor data: substitute zeros spanning the audio length.
            return torch.zeros(item["audio"].shape[0], 2, dtype=torch.float32)
        return torch.cat([_first_column(pressure), _first_column(flow)], dim=-1)

    sensor_seqs = [_sensor_for(item) for item in batch]
    lengths = [seq.shape[0] for seq in sensor_seqs]

    sensor_padded = torch.zeros(n_items, max(lengths), 2, dtype=torch.float32)
    for row, (seq, n_steps) in enumerate(zip(sensor_seqs, lengths)):
        sensor_padded[row, :n_steps, :] = seq

    # ---- metadata and labels ----------------------------------------------
    collated = {
        "audio": audio,
        "audio_lengths": torch.tensor(audio_lengths, dtype=torch.long),
        "sensor": sensor_padded,
        "sensor_lengths": torch.tensor(lengths, dtype=torch.long),
        "label": torch.stack([item["label_id"] for item in batch], dim=0),
        "paths": [item["path"] for item in batch],
        "raw_label": [item.get("raw_label") for item in batch],
        "norm_label": [item.get("norm_label") for item in batch],
    }
    return collated

# Training-split dataset; its label2id mapping is reused for validation below.
# NOTE(review): the exact effect of group_normalize is defined in
# src.ReadSegments — confirm there.
train_ds = ReadSegments(
    train_paths,
    target_sample_rate=TARGET_SR,
    label_normalizer=label_processor,
    group_normalize=True,
)
# Validation dataset shares the training label ids so class indices line up.
# NOTE(review): label_map_paths presumably tells ReadSegments which files
# define the label map — verify against ReadSegments.
val_ds = ReadSegments(
    val_paths,
    target_sample_rate=TARGET_SR,
    label_normalizer=label_processor,
    label2id=train_ds.label2id,
    label_map_paths=train_paths,
    group_normalize=True,
)

# shuffle=False on both loaders: this notebook only evaluates, and a stable
# order keeps per-sample results reproducible across runs.
train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)
val_loader = torch.utils.data.DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

print("Label map:", train_ds.label2id)


In [None]:
def _replace_bn_with_gn(module, num_groups=8):
    for name, child in module.named_children():
        if isinstance(child, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            num_channels = child.num_features
            gn = nn.GroupNorm(
                num_groups=min(num_groups, num_channels),
                num_channels=num_channels,
                affine=True,
            )
            setattr(module, name, gn)
        else:
            _replace_bn_with_gn(child, num_groups=num_groups)

class AudioResNetEncoder(nn.Module):
    """Encode raw audio via a log-mel spectrogram and a ResNet-18 backbone.

    The backbone has its BatchNorm layers swapped for GroupNorm, its first
    convolution replaced by a single-input-channel one, and its classifier
    head replaced by a linear projection to ``out_dim``.
    """

    def __init__(
        self,
        sample_rate=48000,
        n_mels=64,
        n_fft=1024,
        hop_length=512,
        out_dim=128,
        f_min=0.0,
        f_max=None,
        use_pretrained=False,
    ):
        super().__init__()
        # Without an explicit f_max, use half the sample rate (Nyquist).
        upper_hz = f_max if f_max is not None else sample_rate / 2
        self.melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            f_min=f_min,
            f_max=upper_hz,
        )
        self.db = torchaudio.transforms.AmplitudeToDB()

        if use_pretrained:
            pretrained = models.ResNet18_Weights.IMAGENET1K_V1
        else:
            pretrained = None
        resnet = models.resnet18(weights=pretrained)
        _replace_bn_with_gn(resnet, num_groups=8)
        # The spectrogram has one channel, so the stem convolution is rebuilt;
        # any pretrained weights for that layer are discarded.
        resnet.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        resnet.fc = nn.Linear(resnet.fc.in_features, out_dim)
        self.backbone = resnet

    def forward(self, audio):
        """Return backbone features for a batch of waveforms."""
        log_mel = self.db(self.melspec(audio))
        # Insert the channel axis the 2-D convolutional backbone expects.
        return self.backbone(log_mel.unsqueeze(1))

class InceptionBlock1D(nn.Module):
    """One Inception-style 1-D block: parallel convolutions plus a pooled branch.

    The input optionally passes through a 1x1 bottleneck and then through one
    convolution per entry in ``kernel_sizes``. A separate max-pool + 1x1
    convolution branch reads the un-bottlenecked input. All branches are
    concatenated on the channel axis, group-normalized and passed through ReLU.

    Args:
        in_channels: Channels of the input ``(batch, channels, time)`` tensor.
        n_filters: Output channels of every branch.
        kernel_sizes: Kernel width of each convolution branch. Padding is
            ``k // 2``, so only odd widths preserve the sequence length and
            can be concatenated with the pooled branch.
        bottleneck_channels: Width of the 1x1 bottleneck. The bottleneck is
            skipped when this is ``<= 0`` or ``in_channels <= 1``.
    """

    def __init__(self, in_channels, n_filters, kernel_sizes, bottleneck_channels):
        super().__init__()
        use_bottleneck = bottleneck_channels > 0 and in_channels > 1
        self.bottleneck = (
            nn.Conv1d(in_channels, bottleneck_channels, kernel_size=1, bias=False)
            if use_bottleneck
            else None
        )
        conv_in = bottleneck_channels if use_bottleneck else in_channels
        self.convs = nn.ModuleList(
            [
                nn.Conv1d(
                    conv_in,
                    n_filters,
                    kernel_size=k,
                    padding=k // 2,
                    bias=False,
                )
                for k in kernel_sizes
            ]
        )
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
        self.pool_conv = nn.Conv1d(in_channels, n_filters, kernel_size=1, bias=False)
        out_channels = n_filters * (len(kernel_sizes) + 1)
        # nn.GroupNorm needs num_channels % num_groups == 0. Step down from
        # min(8, out_channels) to the nearest divisor so widths such as 12
        # no longer raise. Widths divisible by 8 still get 8 groups.
        num_groups = min(8, out_channels)
        while out_channels % num_groups:
            num_groups -= 1
        self.bn = nn.GroupNorm(
            num_groups=num_groups,
            num_channels=out_channels,
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the block to ``x`` and return the activated concatenation."""
        x_in = x
        if self.bottleneck is not None:
            x = self.bottleneck(x)
        outs = [conv(x) for conv in self.convs]
        # The pooled branch deliberately reads the original input.
        outs.append(self.pool_conv(self.maxpool(x_in)))
        x = torch.cat(outs, dim=1)
        x = self.bn(x)
        return self.relu(x)

class InceptionTimeEncoder(nn.Module):
    """Stack of ``InceptionBlock1D`` layers with a residual shortcut every third block.

    Args:
        input_dim: Number of sensor channels in the input.
        out_dim: Size of the returned feature vector.
        n_filters: Filters per branch in every Inception block.
        kernel_sizes: Kernel widths of the convolution branches.
        bottleneck_channels: Bottleneck width passed to each block.
        n_blocks: Number of Inception blocks.
        use_residual: Whether to add a projected shortcut after every third
            block.
    """

    def __init__(
        self,
        input_dim=2,
        out_dim=128,
        n_filters=32,
        kernel_sizes=(9, 19, 39),
        bottleneck_channels=32,
        n_blocks=6,
        use_residual=True,
    ):
        super().__init__()
        self.use_residual = use_residual
        self.blocks = nn.ModuleList()
        self.residuals = nn.ModuleDict()

        in_channels = input_dim
        res_in_channels = input_dim
        for i in range(n_blocks):
            block = InceptionBlock1D(
                in_channels=in_channels,
                n_filters=n_filters,
                kernel_sizes=kernel_sizes,
                bottleneck_channels=bottleneck_channels,
            )
            self.blocks.append(block)
            out_channels = n_filters * (len(kernel_sizes) + 1)
            if self.use_residual and (i + 1) % 3 == 0:
                # nn.GroupNorm needs num_channels % num_groups == 0. Pick the
                # largest divisor not above 8 so unusual widths do not raise.
                # Widths divisible by 8 still get 8 groups.
                num_groups = min(8, out_channels)
                while out_channels % num_groups:
                    num_groups -= 1
                # Shortcut keyed by the index of the block it follows.
                self.residuals[str(i)] = nn.Sequential(
                    nn.Conv1d(res_in_channels, out_channels, kernel_size=1, bias=False),
                    nn.GroupNorm(num_groups=num_groups, num_channels=out_channels),
                )
                res_in_channels = out_channels
            in_channels = out_channels

        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Linear(in_channels, out_dim)
        self.res_relu = nn.ReLU(inplace=True)

    def forward(self, x, lengths):
        """Encode ``x`` of shape ``(batch, time, channels)`` into ``(batch, out_dim)``.

        ``lengths`` is accepted for interface compatibility but is not used.
        NOTE(review): because of that, zero-padded timesteps take part in the
        global average pool. This is left as is so outputs keep matching
        checkpoints trained with this behaviour.
        """
        _ = lengths
        # Conv1d expects (batch, channels, time).
        x = x.transpose(1, 2)
        for i, block in enumerate(self.blocks):
            if self.use_residual and i % 3 == 0:
                # Remember the input of each three-block group for the shortcut.
                res_input = x
            x = block(x)
            if self.use_residual and (i + 1) % 3 == 0:
                res = self.residuals[str(i)](res_input)
                x = self.res_relu(x + res)
        feats = self.global_pool(x).squeeze(-1)
        return self.proj(feats)

class MultiModalLateFusionNet(nn.Module):
    """Late-fusion classifier over separate audio and sensor encoders.

    Each modality is encoded on its own. The two feature vectors are
    concatenated and classified by a small MLP head. Either modality can be
    zeroed at call time for ablation.
    """

    def __init__(
        self,
        audio_feat_dim=128,
        sensor_feat_dim=128,
        num_classes=2,
        sample_rate=48000,
        f_min=0.0,
        f_max=None,
        use_pretrained=False,
    ):
        super().__init__()
        self.audio_encoder = AudioResNetEncoder(
            sample_rate=sample_rate,
            out_dim=audio_feat_dim,
            f_min=f_min,
            f_max=f_max,
            use_pretrained=use_pretrained,
        )
        self.sensor_encoder = InceptionTimeEncoder(input_dim=2, out_dim=sensor_feat_dim)

        fused_width = audio_feat_dim + sensor_feat_dim
        head_layers = [
            nn.Linear(fused_width, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.fusion = nn.Sequential(*head_layers)

    def forward(self, audio, sensor, sensor_lengths, drop_audio=False, drop_sensor=False):
        """Return class logits; ``drop_*`` flags zero that modality's features."""
        audio_feat = self.audio_encoder(audio)
        sensor_feat = self.sensor_encoder(sensor, sensor_lengths)

        # Both encoders always run; dropping only replaces their output.
        audio_feat = torch.zeros_like(audio_feat) if drop_audio else audio_feat
        sensor_feat = torch.zeros_like(sensor_feat) if drop_sensor else sensor_feat

        return self.fusion(torch.cat([audio_feat, sensor_feat], dim=-1))


In [None]:
# Later cells test `model is None` to decide whether inference can run.
model = None
id2label = None

if not CKPT_PATH.exists():
    print("Checkpoint not found:", CKPT_PATH)
else:
    # Load on CPU first; the model is moved to DEVICE after the weights load.
    ckpt = torch.load(CKPT_PATH, map_location="cpu")
    # Prefer the label map stored with the checkpoint; otherwise use the
    # mapping from the training dataset built in this notebook.
    label2id = ckpt.get("label2id", train_ds.label2id)
    id2label = {v: k for k, v in label2id.items()}
    model = MultiModalLateFusionNet(
        audio_feat_dim=128,
        sensor_feat_dim=128,
        num_classes=len(label2id),
        sample_rate=TARGET_SR,
        f_max=AUDIO_FMAX,
        use_pretrained=False,
    )
    try:
        # strict=True: an architecture mismatch is reported, not silently ignored.
        model.load_state_dict(ckpt["model_state"], strict=True)
    except RuntimeError as exc:
        print("Checkpoint load failed:", exc)
        model = None
    if model is not None:
        model.to(DEVICE)
        model.eval()
        print("Loaded checkpoint with labels:", id2label)


In [None]:
@torch.no_grad()
def run_inference(model, loader, device, id2label):
    """Run the model over ``loader`` and collect one result record per sample.

    Each record holds the path, group key, true and predicted ids and labels,
    the softmax probability of the true and of the predicted class, the
    per-sample cross-entropy loss, a 0/1 ``correct`` flag and the raw and
    normalized labels passed through from the batch.

    Returns:
        A ``pandas.DataFrame`` when pandas is importable, otherwise a list of
        dicts.
    """
    model.eval()
    per_sample_ce = nn.CrossEntropyLoss(reduction="none")
    records = []
    for batch in loader:
        paths = batch["paths"]
        raw_labels = batch.get("raw_label", [None] * len(paths))
        norm_labels = batch.get("norm_label", [None] * len(paths))

        targets = batch["label"].to(device)
        logits = model(
            batch["audio"].to(device),
            batch["sensor"].to(device),
            batch["sensor_lengths"].to(device),
            drop_audio=False,
            drop_sensor=False,
        )

        # Convert each tensor to Python scalars once per batch.
        prob_rows = torch.softmax(logits, dim=1).tolist()
        losses = per_sample_ce(logits, targets).tolist()
        true_ids = targets.tolist()
        pred_ids = logits.argmax(dim=1).tolist()

        for i, path in enumerate(paths):
            true_id = int(true_ids[i])
            pred_id = int(pred_ids[i])
            records.append({
                "path": path,
                "group": group_key_from_path(path),
                "true_id": true_id,
                "pred_id": pred_id,
                "true_label": id2label.get(true_id, str(true_id)),
                "pred_label": id2label.get(pred_id, str(pred_id)),
                "prob_true": float(prob_rows[i][true_id]),
                "prob_pred": float(prob_rows[i][pred_id]),
                "loss": float(losses[i]),
                "correct": int(true_id == pred_id),
                "raw_label": raw_labels[i],
                "norm_label": norm_labels[i],
            })

    return records if pd is None else pd.DataFrame(records)


In [None]:
# Split to evaluate. Only the exact string "val" selects the validation loader;
# any other value selects the training loader.
EVAL_SPLIT = "val"  # 'train' or 'val'
eval_loader = val_loader if EVAL_SPLIT == "val" else train_loader

if model is None:
    print("Load a checkpoint before running inference.")
else:
    # df is a DataFrame when pandas is available, otherwise a list of dicts.
    df = run_inference(model, eval_loader, DEVICE, id2label)
    if pd is not None:
        display(df.head())
    else:
        print(df[:5])


In [None]:
def compute_confusion(labels, preds, num_classes):
    """Return a ``num_classes x num_classes`` confusion matrix of counts.

    Rows index the true class and columns the predicted class. ``labels`` and
    ``preds`` are paired element-wise; extra items in the longer one are
    ignored.
    """
    cm = np.zeros((num_classes, num_classes), dtype=int)
    pairs = [(int(t), int(p)) for t, p in zip(labels, preds)]
    if pairs:
        true_idx, pred_idx = (np.array(col, dtype=np.intp) for col in zip(*pairs))
        # Unbuffered add, so repeated (true, pred) pairs accumulate correctly.
        np.add.at(cm, (true_idx, pred_idx), 1)
    return cm

def compute_precision_recall(cm):
    """Return per-class ``(precision, recall)`` arrays from a confusion matrix.

    Rows of ``cm`` are true classes and columns are predicted classes.
    Denominators are floored at 1, so a class with no predictions (or no
    samples) gets 0 instead of a division by zero.
    """
    hits = np.diag(cm)
    predicted_totals = cm.sum(axis=0)
    actual_totals = cm.sum(axis=1)
    precision = hits / np.maximum(predicted_totals, 1)
    recall = hits / np.maximum(actual_totals, 1)
    return precision, recall

if model is not None:
    # `df` is a DataFrame when pandas is available, otherwise a list of dicts.
    if pd is not None:
        labels = df["true_id"].to_numpy()
        preds = df["pred_id"].to_numpy()
    else:
        labels = [row["true_id"] for row in df]
        preds = [row["pred_id"] for row in df]

    num_classes = len(id2label)
    cm = compute_confusion(labels, preds, num_classes=num_classes)
    precision, recall = compute_precision_recall(cm)

    print("Accuracy:", float(np.mean(np.array(labels) == np.array(preds))))
    if pd is not None:
        # Class names in id order, shared by both tables below.
        class_names = [id2label[i] for i in range(num_classes)]
        cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
        display(cm_df)
        pr_df = pd.DataFrame({
            "precision": precision,
            "recall": recall,
        }, index=class_names)
        display(pr_df)
    else:
        # The newline must be an escape sequence: a raw line break inside a
        # single-quoted string literal is a SyntaxError.
        print("Confusion matrix:\n", cm)
        print("Precision:", precision)
        print("Recall:", recall)


In [None]:
if model is not None:
    if pd is not None:
        # Worst errors first: wrong predictions ordered by per-sample loss.
        print("Misclassified samples (top 20 by loss):")
        display(df[df["correct"] == 0].sort_values("loss", ascending=False).head(20))
        # Correct predictions the model was least sure of.
        print("Hard correct samples (low prob_true):")
        display(df[df["correct"] == 1].sort_values("prob_true", ascending=True).head(20))
    else:
        # Without pandas, `df` is a list of dicts and only the misclassified
        # samples are printed.
        mis = [row for row in df if row["correct"] == 0]
        mis = sorted(mis, key=lambda r: r["loss"], reverse=True)[:20]
        print(mis)


In [None]:
# Per-group accuracy and mean loss, worst groups first. Needs pandas.
if model is not None and pd is not None:
    group_summary = (
        df.groupby("group")
          .agg(total=("path", "count"),
               correct=("correct", "sum"),
               acc=("correct", "mean"),
               mean_loss=("loss", "mean"))
          .sort_values("acc", ascending=True)
    )
    display(group_summary.head(20))


In [None]:
def plot_npz(path, pred_row=None, target_sr=TARGET_SR):
    """Plot the waveform, mel spectrogram and sensor channels of one segment.

    Args:
        path: Path of the ``.npz`` segment. It must contain ``audio``,
            ``sensor_values`` and ``sensor_cols``. ``audio_rate_hz``
            (default 48000) and ``label`` are optional.
        pred_row: Optional mapping-like prediction record (dict or pandas
            Series). Its ``pred_label`` and, when present, ``prob_true`` are
            appended to the figure title.
        target_sr: Sample rate the audio is resampled to before plotting.
    """
    # NOTE(security): allow_pickle=True can execute arbitrary code when the
    # file comes from an untrusted source; only open trusted segments.
    # Arrays are read lazily, so everything is pulled out inside the `with`
    # block, which then closes the archive.
    with np.load(path, allow_pickle=True) as d:
        # `files` is used for the optional keys because NpzFile only offers
        # the Mapping-style `.get` on newer NumPy releases.
        available = set(d.files)
        audio = d["audio"].astype(np.float32)
        sr = int(d["audio_rate_hz"]) if "audio_rate_hz" in available else 48000
        label = d["label"] if "label" in available else None
        sensor_all = d["sensor_values"].astype(np.float32)
        sensor_cols = [str(c) for c in d["sensor_cols"]]
    if isinstance(label, np.ndarray) and label.shape == ():
        label = label.item()

    audio_t = torch.from_numpy(audio)
    if sr != target_sr:
        audio_t = torchaudio.functional.resample(audio_t, sr, target_sr)
        sr = target_sr

    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_mels=64,
        n_fft=1024,
        hop_length=512,
        f_max=AUDIO_FMAX,
    )(audio_t.unsqueeze(0))
    mel_db = torchaudio.transforms.AmplitudeToDB()(mel).squeeze(0).numpy()

    # The first column whose name contains "P_" is plotted as P, and the
    # first containing "F_" as Q.
    p_idx = [i for i, name in enumerate(sensor_cols) if "P_" in name]
    q_idx = [i for i, name in enumerate(sensor_cols) if "F_" in name]

    P = sensor_all[:, p_idx[0]] if p_idx else None
    Q = sensor_all[:, q_idx[0]] if q_idx else None

    title = f"label={label}"
    if pred_row is not None:
        title += f" | pred={pred_row.get('pred_label')}"
        prob_true = pred_row.get("prob_true")
        if prob_true is not None:
            # Only format the probability when the record carries one, so a
            # partial record no longer raises a TypeError.
            title += f" | prob_true={prob_true:.3f}"

    fig, axes = plt.subplots(3, 1, figsize=(12, 8))
    axes[0].plot(audio_t.numpy())
    axes[0].set_title("Audio waveform")
    axes[1].imshow(mel_db, aspect="auto", origin="lower")
    axes[1].set_title("Mel spectrogram (dB)")
    if P is not None or Q is not None:
        if P is not None:
            axes[2].plot(P, label="P")
        if Q is not None:
            axes[2].plot(Q, label="Q")
        axes[2].legend()
    axes[2].set_title("Sensor channels")
    fig.suptitle(title)
    plt.tight_layout()
    plt.show()


In [None]:
# Plot the sample with the highest loss. Needs pandas and a non-empty result table.
if model is not None and pd is not None and len(df):
    sample = df.sort_values("loss", ascending=False).iloc[0]
    plot_npz(sample["path"], sample)
