## How to Use the WhisperFormer Model: A Step-by-Step Guide

This notebook runs a trained WhisperFormer model on your audio files and saves the detected calls as `.json` files. Optionally, results can be converted to Raven selection tables.

**Always run the cells in order from top to bottom!**

### Step 1: Install Dependencies

You only need to run this cell the first time you use the notebook. The installation can take a while.

**After running this cell, restart the kernel** (Kernel → Restart Kernel) and then continue with Step 2.

In [None]:
!pip install --user --upgrade torch torchvision numpy scipy transformers librosa pandas matplotlib
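After restarting the kernel, you can optionally confirm that the packages are importable before moving on. This is a minimal sketch using only the standard library (`importlib.metadata`). The package list is copied from the `pip install` command above, and the helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names copied from the pip install command above
PACKAGES = ["torch", "torchvision", "numpy", "scipy", "transformers",
            "librosa", "pandas", "matplotlib"]


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:12s} {ver or 'NOT INSTALLED'}")
```

Any package reported as `NOT INSTALLED` means the install cell did not finish or the kernel was not restarted.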

### Step 2: Import Required Modules

In [None]:
import logging
import json
import os
import torch

from utils import infer

### Step 3: Set Paths and Parameters

Before running this step, place your `.wav` files in the `audios/` folder next to this notebook.

You also need:
- **Checkpoint file** (`.pth`): the trained WhisperFormer model (by default expected as `checkpoint.pth` next to this notebook)
- **`whisper_config/` folder**: must contain `config.json` and `preprocessor_config.json` from the Whisper model used during training (copy them from `whisper_models/whisper_base` or `whisper_models/whisper_large` in the main repository)

In [None]:
logging.basicConfig(level=logging.INFO)
PATH = os.getcwd()

# --- Paths (adjust if needed) ---
DATA_DIR = os.path.join(PATH, "audios")
CHECKPOINT_PATH = os.path.join(PATH, "checkpoint.pth")
WHISPER_CONFIG_PATH = os.path.join(PATH, "whisper_config")
OUTPUT_DIR = os.path.join(PATH, "jsons")

# --- Inference parameters ---
THRESHOLD = 0.35          # minimum confidence score to keep a prediction
IOU_THRESHOLD = 0.4       # IoU threshold for non-maximum suppression
NUM_RUNS = 3              # number of offset runs (1 = fast, 3 = more robust)
OVERLAP_TOLERANCE = 0.1   # IoU threshold for consolidating predictions across runs
TOTAL_SPEC_COLUMNS = 3000 # spectrogram frames per window (3000 x 10 ms = 30 s)
BATCH_SIZE = 4            # windows per forward pass; lower it if GPU memory runs out
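The window length follows from `TOTAL_SPEC_COLUMNS`. Assuming Whisper's standard front end (16 kHz audio and a 160-sample hop, i.e. 10 ms per spectrogram column, the same values the plotting code in Step 7 uses) and an encoder output with half as many frames as its input, the arithmetic works out as sketched below. The variable names are illustrative and deliberately do not overwrite the parameters above.

```python
# Assumed values: Whisper's feature extractor works on 16 kHz audio with a
# 160-sample hop (10 ms per spectrogram column), and the encoder output has
# half as many frames as the input spectrogram.
sample_rate = 16000
hop_length = 160
spec_columns = 3000  # same as TOTAL_SPEC_COLUMNS above

window_seconds = spec_columns * hop_length / sample_rate
output_frames = spec_columns // 2
seconds_per_output_frame = window_seconds / output_frames

print(f"{window_seconds} s per window")
print(f"{output_frames} output frames of {seconds_per_output_frame} s each")
```

This is why Step 7 plots 30-second segments and uses 0.02 s per score frame.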

### Step 4: Verify Paths

Run this cell to check that all required paths exist.

In [None]:
all_ok = True

print("Checkpoint:", CHECKPOINT_PATH)
if os.path.exists(CHECKPOINT_PATH):
    size_mb = os.path.getsize(CHECKPOINT_PATH) / 1e6
    print(f"  OK ({size_mb:.0f} MB)")
else:
    print("  MISSING -- please provide a .pth checkpoint file")
    all_ok = False

print("Whisper config:", WHISPER_CONFIG_PATH)
if os.path.isdir(WHISPER_CONFIG_PATH):
    contents = os.listdir(WHISPER_CONFIG_PATH)
    print(f"  OK (files: {contents})")
    for needed in ["config.json", "preprocessor_config.json"]:
        if needed not in contents:
            print(f"  WARNING: {needed} is missing in whisper_config/")
            all_ok = False
else:
    print("  MISSING -- copy from whisper_models/whisper_base or whisper_models/whisper_large")
    all_ok = False

print("Audio folder:", DATA_DIR)
if os.path.isdir(DATA_DIR):
    wav_files = [f for f in os.listdir(DATA_DIR) if f.lower().endswith(".wav")]
    print(f"  OK ({len(wav_files)} WAV file(s) found)")
    if not wav_files:
        print("  WARNING: no .wav files found in audios/")
        all_ok = False
else:
    print("  MISSING -- create an 'audios' folder and place your .wav files there")
    all_ok = False

print()
print("Device:", "cuda" if torch.cuda.is_available() else "cpu")
print()
if all_ok:
    print("Everything looks good! Proceed to Step 5.")
else:
    print("Please fix the issues above before running inference.")
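Optionally, you can also list the sample rate, channel count and duration of each file. This minimal sketch uses Python's built-in `wave` module, which only reads uncompressed PCM WAV headers; files it cannot parse (for example 32-bit float WAV) are reported rather than treated as errors. The helper `describe_wav` is hypothetical, and `Path("audios")` stands in for `DATA_DIR`.

```python
import wave
from pathlib import Path


def describe_wav(path):
    """Return (sample_rate, channels, duration_s) for a PCM WAV file,
    or None if the wave module cannot read its header."""
    try:
        with wave.open(str(path), "rb") as wf:
            sr = wf.getframerate()
            return sr, wf.getnchannels(), wf.getnframes() / sr
    except (wave.Error, EOFError):
        return None


audio_dir = Path("audios")  # same folder as DATA_DIR
if audio_dir.is_dir():
    for wav in sorted(p for p in audio_dir.iterdir() if p.suffix.lower() == ".wav"):
        info = describe_wav(wav)
        if info is None:
            print(f"{wav.name}: header not readable with the wave module")
        else:
            sr, ch, dur = info
            print(f"{wav.name}: {sr} Hz, {ch} channel(s), {dur:.1f} s")
```

A sample rate other than 16 kHz is not necessarily a problem: the visualization code in Step 7 resamples with `librosa.load(..., sr=16000)`.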

### Step 5: Run Inference

This cell processes all `.wav` files in `audios/` and saves the predictions as `.json` files in `jsons/`.

Each file is processed `NUM_RUNS` times, each time with a different time offset, and the predictions from all runs are consolidated (using `OVERLAP_TOLERANCE`) to make the results more robust.
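The offset idea can be illustrated as follows: shifting the grid of 30-second windows means a call that is cut by a window edge in one run falls inside a window in another. This is a minimal sketch under an assumption: the offsets are spread evenly over one window (0, 10 and 20 s for `NUM_RUNS = 3`). The offsets actually used by `infer` may differ, and `window_starts` is a hypothetical helper.

```python
def window_starts(audio_seconds, window_seconds=30.0, num_runs=3):
    """Return one list of window start times (s) per run.

    Hypothetical scheme: run k shifts the window grid by
    k * window_seconds / num_runs. The audio before a run's first
    window is still covered by run 0, which starts at 0 s.
    """
    runs = []
    for k in range(num_runs):
        t = k * window_seconds / num_runs  # offset of this run
        starts = []
        while t < audio_seconds:
            starts.append(t)
            t += window_seconds
        runs.append(starts)
    return runs


for k, starts in enumerate(window_starts(65.0)):
    print(f"run {k}: windows start at {starts}")
```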

In [None]:
results = infer(
    data_dir=DATA_DIR,
    checkpoint_path=CHECKPOINT_PATH,
    whisper_config_path=WHISPER_CONFIG_PATH,
    output_dir=OUTPUT_DIR,
    threshold=THRESHOLD,
    iou_threshold=IOU_THRESHOLD,
    total_spec_columns=TOTAL_SPEC_COLUMNS,
    batch_size=BATCH_SIZE,
    num_runs=NUM_RUNS,
    overlap_tolerance=OVERLAP_TOLERANCE,
)
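`THRESHOLD` and `IOU_THRESHOLD` control which candidate detections survive. For intuition, here is a generic greedy non-maximum suppression (NMS) on 1-D time intervals given as `(onset, offset, score)` tuples. It is a sketch of the standard technique, not the code inside `utils.infer`; in a multi-class detector it would normally be applied separately to each class.

```python
def interval_iou(a, b):
    """IoU of two time intervals given as (onset, offset) in seconds."""
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0


def nms_1d(preds, iou_threshold=0.4, score_threshold=0.35):
    """Greedy NMS on (onset, offset, score) tuples of a single class.

    Drops predictions below score_threshold, then repeatedly keeps the
    highest-scoring one and removes every remaining prediction whose
    IoU with it exceeds iou_threshold.
    """
    remaining = sorted((p for p in preds if p[2] >= score_threshold),
                       key=lambda p: p[2], reverse=True)
    kept = []
    while remaining:
        best = remaining.pop(0)
        kept.append(best)
        remaining = [p for p in remaining
                     if interval_iou(best[:2], p[:2]) <= iou_threshold]
    return kept


candidates = [(1.0, 2.0, 0.9), (1.1, 2.1, 0.6), (5.0, 6.0, 0.5)]
print(nms_1d(candidates))
```

The two overlapping candidates near 1 to 2 s collapse into the higher-scoring one, while the isolated call at 5 to 6 s is kept. Raising `IOU_THRESHOLD` keeps more overlapping detections; lowering it suppresses more.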

### Step 6: Results Summary

Overview of the detected calls per file:

In [None]:
total = 0
for filename, preds in results.items():
    n = len(preds["onset"])
    total += n
    clusters = {}
    for c in preds["cluster"]:
        clusters[c] = clusters.get(c, 0) + 1
    cluster_str = ", ".join(f"{k}: {v}" for k, v in sorted(clusters.items()))
    print(f"  {filename}: {n} predictions ({cluster_str})")

print(f"\nTotal: {total} predictions across {len(results)} file(s)")
print(f"Results saved to: {OUTPUT_DIR}")
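If you also want the total call duration per class, the counting loop above can be written with `collections.Counter`. This minimal sketch assumes each entry of `results` is a dict with `"onset"`, `"offset"` (in seconds) and `"cluster"` lists, as the loop above uses; the helper `summarize` and the example values are hypothetical.

```python
from collections import Counter


def summarize(preds):
    """Count predictions and sum call durations (s) per cluster label."""
    counts = Counter(preds["cluster"])
    durations = Counter()
    for onset, offset, cluster in zip(preds["onset"], preds["offset"],
                                      preds["cluster"]):
        durations[cluster] += offset - onset
    return counts, durations


# Hypothetical example in the same shape as one entry of `results`
example = {"onset": [0.5, 3.0, 7.2], "offset": [1.5, 3.8, 8.0],
           "cluster": ["m", "h", "m"]}
counts, durations = summarize(example)
for label in sorted(counts):
    print(f"{label}: {counts[label]} call(s), {durations[label]:.2f} s total")
```

In the notebook you would call `summarize(preds)` inside the `for filename, preds in results.items()` loop.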

### Step 7: Visualize Predictions

For each audio file, the first few segments (30 s each with the default `TOTAL_SPEC_COLUMNS = 3000`) are plotted, each showing:
1. **Mel spectrogram**: the audio representation fed to the model
2. **Per-class scores**: the raw model output (confidence per frame and class), with the predicted call segments highlighted

Adjust `NUM_SEGMENTS_TO_PLOT` to show more or fewer segments per file.
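To relate the score bars to time, the sketch below finds runs of consecutive frames where one class's score is at or above a threshold and converts them to seconds using the 20 ms frame length of the plotting code (`SEC_PER_COL = 0.02`). This is a deliberate simplification for reading the plots: the highlighted segments come from `infer`, whose decoding is not reproduced here, and `frames_to_segments` is a hypothetical helper.

```python
def frames_to_segments(scores, threshold=0.35, sec_per_frame=0.02):
    """Turn one class's per-frame scores into (onset, offset) pairs in seconds.

    A segment starts at the first frame with score >= threshold and ends
    after the last consecutive frame that still meets the threshold.
    """
    segments = []
    start = None
    for i, s in enumerate(scores):
        if s >= threshold and start is None:
            start = i
        elif s < threshold and start is not None:
            segments.append((start * sec_per_frame, i * sec_per_frame))
            start = None
    if start is not None:  # segment runs to the end of the window
        segments.append((start * sec_per_frame, len(scores) * sec_per_frame))
    return segments


scores = [0.1, 0.5, 0.8, 0.2, 0.0, 0.4, 0.6]
print(frames_to_segments(scores))
```

Applied to a column of `class_scores` from the cell below (e.g. `class_scores[:, 0]`), this gives a rough picture of where the bars cross the dashed threshold line.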

In [None]:
import contextlib
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import librosa
import librosa.display
import numpy as np

from utils import load_trained_whisperformer, get_id_to_cluster
from transformers import WhisperFeatureExtractor

NUM_SEGMENTS_TO_PLOT = 3

# --- Reload model & feature extractor for visualization ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model, num_classes, detected_size = load_trained_whisperformer(
    CHECKPOINT_PATH, device, WHISPER_CONFIG_PATH
)
model.eval()  # inference mode: disables dropout (no-op if already in eval mode)
feature_extractor = WhisperFeatureExtractor.from_pretrained(
    WHISPER_CONFIG_PATH, local_files_only=True
)

id_to_cluster = get_id_to_cluster(num_classes)
cluster_to_id = {v: k for k, v in id_to_cluster.items()}

LABEL_DISPLAY = {"m": "moan", "h": "hmm", "w": "wail"}
COLOR_MAP = {0: "darkorange", 1: "cornflowerblue", 2: "gold", 3: "r"}
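SEC_PER_COL = 0.02  # seconds per model output frame (encoder halves the 10 ms input rate)
SR = 16000          # Whisper expects 16 kHz audio
seg_duration = (TOTAL_SPEC_COLUMNS / 2) * SEC_PER_COL               # 30 s with defaults
num_samples_in_clip = int(round((TOTAL_SPEC_COLUMNS * 0.01) * SR))  # samples per window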

# Mixed precision only on GPU; on CPU use a no-op context manager
use_autocast = device == "cuda"
autocast_ctx = (
    torch.amp.autocast(device_type="cuda", dtype=torch.float16)
    if use_autocast else contextlib.nullcontext()
)


def plot_segment(mel_spec, class_scores, pred_onsets, pred_offsets,
                 pred_classes, title="", threshold=THRESHOLD):
    """Plot mel spectrogram with model scores and predicted segments."""
    T = class_scores.shape[0]
    num_cls = class_scores.shape[1]
    time_axis = np.arange(T) * SEC_PER_COL

    fig, (ax_spec, ax_scores) = plt.subplots(
        2, 1, figsize=(12, 5), height_ratios=[3, 1.2]
    )

    # Mel spectrogram
    librosa.display.specshow(
        mel_spec, cmap=plt.cm.magma, sr=SR, hop_length=160,
        x_axis="time", y_axis="mel", fmin=0, fmax=8000, ax=ax_spec,
    )
    ax_spec.set_ylabel("Frequency (Hz)")
    ax_spec.set_title(title)

    # Per-class score bars
    for c in range(num_cls):
        label = LABEL_DISPLAY.get(
            id_to_cluster.get(c, ""), id_to_cluster.get(c, str(c))
        )
        ax_scores.bar(
            time_axis, class_scores[:, c], width=SEC_PER_COL,
            align="edge", alpha=1, label=label,
            color=COLOR_MAP.get(c, f"C{c}"),
        )

    ax_scores.axhline(
        y=threshold, color="r", linestyle="--",
        label=f"Threshold {threshold}",
    )
    ax_scores.set_ylim(0, 1.1)
    ax_scores.set_xlim(0, T * SEC_PER_COL)
    ax_scores.set_xlabel("Time (s)")
    ax_scores.set_ylabel("Score")
    ax_scores.set_title("WhisperFormer Scores and Predictions")

    # Highlight predicted segments
    for onset, offset, cls in zip(pred_onsets, pred_offsets, pred_classes):
        cid = cluster_to_id.get(cls, 0)
        ax_scores.axvspan(
            onset, offset, color=COLOR_MAP.get(cid, "gray"), alpha=0.3
        )

    # De-duplicated legend
    handles, labels_leg = ax_scores.get_legend_handles_labels()
    by_label = dict(zip(labels_leg, handles))
    ax_scores.legend(by_label.values(), by_label.keys(), loc="upper right")

    plt.tight_layout()
    plt.show()


# --- Loop over audio files and visualize ---
wav_files = sorted(Path(DATA_DIR).rglob("*.[Ww][Aa][Vv]"))

for wav_path in wav_files:
    preds = results.get(wav_path.name)
    if preds is None:
        continue

    audio, _ = librosa.load(wav_path, sr=SR)
    pred_onsets = np.array(preds["onset"])
    pred_offsets = np.array(preds["offset"])
    pred_clusters = np.array(preds["cluster"])

    for seg_i in range(NUM_SEGMENTS_TO_PLOT):
        seg_start_sec = seg_i * seg_duration
        seg_end_sec = (seg_i + 1) * seg_duration
        seg_start_sample = int(seg_start_sec * SR)

        if seg_start_sample >= len(audio):
            break

        clip = audio[seg_start_sample : seg_start_sample + num_samples_in_clip]
        if len(clip) < SR * 0.1:
            break

        clip_padded = np.concatenate(
            [clip, np.zeros(max(0, num_samples_in_clip - len(clip)))]
        ).astype(np.float32)

        feats = feature_extractor(
            clip_padded, sampling_rate=SR, padding="do_not_pad"
        )["input_features"][0]
        mel_spec = np.array(feats)

        # Per-frame model scores
        x = torch.tensor(feats, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad(), autocast_ctx:
            class_preds, _ = model(x)
            class_scores = torch.sigmoid(class_preds).squeeze(0).cpu().numpy()

        # Select predictions falling within this segment
        in_seg = (pred_onsets < seg_end_sec) & (pred_offsets > seg_start_sec)
        seg_pred_onsets = np.clip(pred_onsets[in_seg] - seg_start_sec, 0, None)
        seg_pred_offsets = np.clip(
            pred_offsets[in_seg] - seg_start_sec, None, seg_duration
        )
        seg_pred_clusters = pred_clusters[in_seg]

        plot_segment(
            mel_spec, class_scores,
            seg_pred_onsets, seg_pred_offsets, seg_pred_clusters,
            title=f"{wav_path.name} — Segment {seg_i + 1}",
            threshold=THRESHOLD,
        )

### Step 8: Convert to Raven Selection Tables (Optional)

To visualize the results in Raven, convert the `.json` files to `.txt` selection tables.

In [None]:
from json_to_raven import process_folder

JSON_DIR = OUTPUT_DIR  # the folder the .json files were written to in Step 5
RAVEN_DIR = os.path.join(PATH, "raven")

process_folder(JSON_DIR, RAVEN_DIR)

The `.txt` selection tables can now be found in the `raven/` folder. Open them in Raven Pro alongside the corresponding audio files to visualize the detected calls.
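If you want to see roughly what such a table contains, or write one directly from `results` without `json_to_raven`, the sketch below writes a tab-separated file with Raven's standard selection-table columns. It is not the code in `json_to_raven.process_folder`. The extra `Class` column name, the `Spectrogram 1` view, channel 1 and the 0 to 8000 Hz frequency bounds (the range plotted in Step 7) are assumptions, and `write_selection_table` is a hypothetical helper.

```python
import csv
import io

# Standard Raven selection-table columns plus a custom "Class" column
# (the extra column name is an assumption, not something Raven requires).
COLUMNS = ["Selection", "View", "Channel", "Begin Time (s)", "End Time (s)",
           "Low Freq (Hz)", "High Freq (Hz)", "Class"]


def write_selection_table(preds, out, low_hz=0.0, high_hz=8000.0):
    """Write predictions (dict of onset/offset/cluster lists) as a
    tab-separated Raven selection table to the text stream `out`."""
    writer = csv.writer(out, delimiter="\t", lineterminator="\n")
    writer.writerow(COLUMNS)
    rows = zip(preds["onset"], preds["offset"], preds["cluster"])
    for i, (onset, offset, cluster) in enumerate(rows, start=1):
        writer.writerow([i, "Spectrogram 1", 1, f"{onset:.3f}", f"{offset:.3f}",
                         low_hz, high_hz, cluster])


# Hypothetical predictions for one file
example = {"onset": [0.5, 3.0], "offset": [1.5, 3.8], "cluster": ["m", "h"]}
buf = io.StringIO()
write_selection_table(example, buf)
print(buf.getvalue())
```

To write a real file, pass an open text file instead of the `StringIO` buffer, e.g. `with open(path, "w", newline="") as f: write_selection_table(preds, f)`.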

---
*End of notebook.*