# Beatgrid CRNN MVP Training


## Configuration and Imports


These globals configure audio preprocessing and training defaults. `TARGET_SR`, `N_MELS`, `N_FFT`, and `HOP_LENGTH` define how raw audio is turned into log-mel frames that the network ingests, while `BEAT_TOLERANCE_SEC` specifies how close a frame must be to a ground-truth beat to count as positive.


In [None]:
import json
import random
from pathlib import Path
from typing import List, Tuple

import librosa
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

TARGET_SR = 44100
N_MELS = 128
N_FFT = 2048
HOP_LENGTH = 512  # ~11.6 ms at 44.1k
BEAT_TOLERANCE_SEC = 0.03  # +/- 30 ms


Label JSON files contain a constant BPM, duration, and optional downbeat offset per track. The helpers below simply load that metadata and turn it into a per-beat timeline so later steps can create frame-wise supervision.


## Label Utilities


In [2]:
def load_label_json(label_path: str) -> dict:
    with open(label_path, "r") as f:
        return json.load(f)


def generate_beats_from_constant_bpm(
    bpm: float,
    duration: float,
    downbeat_offset_sec: float = 0.0,
) -> List[float]:
    """Generate a simple beatgrid for a constant-BPM track."""
    if bpm is None or bpm <= 0:
        return []

    period = 60.0 / float(bpm)  # seconds per beat
    beat_times = []

    t = downbeat_offset_sec
    while t < duration:
        beat_times.append(t)
        t += period

    return beat_times


## Audio to Mel Features


The dataset takes a list of `{audio_path, label_path}` pairs, converts the audio into log-mel tensors, and aligns constant-BPM label grids to per-frame 0/1 targets. Each `__getitem__` returns `(mel_tensor[T,F], label_tensor[T])`, so batching simply stacks whole tracks.


In [None]:
def load_audio_to_mel(audio_path: str):
    """Load audio to mono log-mel spectrogram and frame timestamps."""
    y, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True)
    duration = len(y) / sr

    mel = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        n_mels=N_MELS,
        power=2.0,
    )
    mel_db = librosa.power_to_db(mel, ref=np.max).T  # (T, N_MELS)

    frames = np.arange(mel_db.shape[0])
    frame_times = librosa.frames_to_time(
        frames,
        sr=sr,
        hop_length=HOP_LENGTH,
        n_fft=N_FFT,
    )

    return mel_db, frame_times, duration


def beat_times_to_frame_labels(
    beat_times: List[float],
    frame_times: np.ndarray,
    tolerance_sec: float = BEAT_TOLERANCE_SEC,
) -> np.ndarray:
    """Convert beat timestamps to per-frame binary labels."""
    labels = np.zeros_like(frame_times, dtype=np.float32)

    if len(beat_times) == 0:
        return labels

    beat_idx = 0
    num_beats = len(beat_times)

    for i, ft in enumerate(frame_times):
        while beat_idx + 1 < num_beats and beat_times[beat_idx] < ft:
            if abs(beat_times[beat_idx + 1] - ft) < abs(beat_times[beat_idx] - ft):
                beat_idx += 1
            else:
                break
        if abs(beat_times[beat_idx] - ft) <= tolerance_sec:
            labels[i] = 1.0

    return labels


## Dataset


In [None]:
class BeatActivationDataset(Dataset):
    """Dataset returning full-track mel tensors and beat labels."""

    def __init__(self, items: List[dict]):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        audio_path = item["audio_path"]
        label_path = item["label_path"]

        mel_db, frame_times, duration_audio = load_audio_to_mel(audio_path)

        label_json = load_label_json(label_path)
        bpm = float(label_json["bpm"])
        duration_label = float(label_json.get("duration", duration_audio))
        downbeat_offset = float(label_json.get("downbeatOffset", 0))

        beat_times = generate_beats_from_constant_bpm(
            bpm=bpm,
            duration=duration_label,
            downbeat_offset_sec=downbeat_offset,
        )

        labels = beat_times_to_frame_labels(beat_times, frame_times)

        mel_tensor = torch.from_numpy(mel_db).float()
        labels_tensor = torch.from_numpy(labels).float()

        return mel_tensor, labels_tensor


def collate_full_tracks(batch):
    """Simple collate that stacks mels/labels (assumes equal lengths)."""
    mels = [b[0] for b in batch]
    labels = [b[1] for b in batch]

    mel_batch = torch.stack(mels, dim=0)
    label_batch = torch.stack(labels, dim=0)

    return mel_batch, label_batch


## Dataset Discovery & Splits
Automatically pair each label JSON with its audio file so the training run always reflects every annotated track, then optionally hold out a few items for validation.


In [None]:
def discover_label_items(label_dir: str, audio_dir: str) -> List[dict]:
    """Pair each label json with its corresponding audio file."""
    label_dir_path = Path(label_dir)
    audio_dir_path = Path(audio_dir)

    if not label_dir_path.exists():
        raise FileNotFoundError(f"Label directory not found: {label_dir_path}")
    if not audio_dir_path.exists():
        raise FileNotFoundError(f"Audio directory not found: {audio_dir_path}")

    label_files = sorted(label_dir_path.glob("*.labels.json"))
    if not label_files:
        raise FileNotFoundError(f"No *.labels.json files found in {label_dir_path}")

    items: List[dict] = []
    missing_audio = []
    common_exts = ["", ".mp3", ".m4a", ".wav", ".flac", ".ogg", ".aif", ".aiff"]

    for label_path in label_files:
        label_data = load_label_json(str(label_path))
        file_name = label_data.get("fileName")

        candidates = []
        if file_name:
            candidates.append(audio_dir_path / file_name)

        base_name = label_path.stem
        if base_name.endswith(".labels"):
            base_name = base_name[:-len(".labels")]

        for ext in common_exts:
            candidate_path = audio_dir_path / f"{base_name}{ext}"
            if candidate_path not in candidates:
                candidates.append(candidate_path)

        audio_path = next((path for path in candidates if path.exists()), None)

        if audio_path is None:
            missing_audio.append((label_path.name, file_name or f"{base_name}.*"))
            continue

        items.append({"audio_path": str(audio_path), "label_path": str(label_path)})

    if missing_audio:
        print("Warning: Skipping labels with missing audio files:")
        for label_name, expected in missing_audio:
            print(f" - {label_name} (expected audio similar to {expected})")

    if not items:
        raise RuntimeError("No valid label/audio pairs were found.")

    return items


def split_train_val_items(items: List[dict], val_count: int) -> Tuple[List[dict], List[dict]]:
    """Split dataset into train/val using the end of the list as validation."""
    if val_count <= 0:
        return items, []

    if val_count >= len(items):
        print(
            f"Requested val_count={val_count} but only {len(items)} tracks available. "
            "Reducing validation set so at least one training track remains."
        )
        val_count = max(0, len(items) - 1)

    if val_count == 0:
        return items, []

    train_items = items[:-val_count]
    val_items = items[-val_count:]
    return train_items, val_items



`BeatCRNN` first compresses the mel time–frequency map with a small CNN, then runs a bidirectional LSTM so each frame sees past/future context before emitting a beat logit. The rest of the notebook only relies on the logits; applying `torch.sigmoid` converts them to beat activation probabilities.


## Model


In [39]:
class BeatCRNN(nn.Module):
    """Minimal CRNN mapping log-mels to beat logits."""

    def __init__(self, n_mels=N_MELS, hidden_size=128, num_layers=2):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2)),
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2)),
        )

        freq_out = n_mels // 4
        cnn_channels = 64
        rnn_input_size = cnn_channels * freq_out

        self.rnn = nn.LSTM(
            input_size=rnn_input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

        self.output = nn.Linear(hidden_size * 2, 1)

    def forward(self, x):
        """Forward pass producing beat logits."""
        B, T, F = x.shape
        x = x.unsqueeze(1)
        x = self.cnn(x)
        B, C, T_new, F_new = x.shape
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(B, T_new, C * F_new)
        x, _ = self.rnn(x)
        logits = self.output(x).squeeze(-1)
        return logits


def train_mvp(
    train_items: List[dict],
    val_items: List[dict] = None,
    num_epochs: int = 20,
    lr: float = 1e-3,
    device: str = None,
):
    """Train the Beat CRNN on full tracks."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    train_ds = BeatActivationDataset(train_items)
    train_loader = DataLoader(
        train_ds,
        batch_size=1,
        shuffle=True,
        collate_fn=collate_full_tracks,
    )

    if val_items is not None and len(val_items) > 0:
        val_ds = BeatActivationDataset(val_items)
        val_loader = DataLoader(
            val_ds,
            batch_size=1,
            shuffle=False,
            collate_fn=collate_full_tracks,
        )
    else:
        val_loader = None

    model = BeatCRNN().to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0.0

        for mel_batch, label_batch in train_loader:
            mel_batch = mel_batch.to(device)
            label_batch = label_batch.to(device)

            optim.zero_grad()
            logits = model(mel_batch)

            T_pred = logits.shape[1]
            T_true = label_batch.shape[1]
            T_min = min(T_pred, T_true)

            loss = criterion(
                logits[:, :T_min],
                label_batch[:, :T_min],
            )
            loss.backward()
            optim.step()

            total_loss += loss.item()

        avg_loss = total_loss / max(1, len(train_loader))
        print(f"Epoch {epoch}/{num_epochs} - Train loss: {avg_loss:.4f}")

        if val_loader is not None:
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for mel_batch, label_batch in val_loader:
                    mel_batch = mel_batch.to(device)
                    label_batch = label_batch.to(device)

                    logits = model(mel_batch)
                    T_pred = logits.shape[1]
                    T_true = label_batch.shape[1]
                    T_min = min(T_pred, T_true)

                    loss = criterion(
                        logits[:, :T_min],
                        label_batch[:, :T_min],
                    )
                    val_loss += loss.item()

            avg_val_loss = val_loss / max(1, len(val_loader))
            print(f"           Val loss:   {avg_val_loss:.4f}")

    return model



## Training
Automatically gather every labeled track, optionally shuffle/hold out validation songs, train the CRNN, and save a checkpoint for inference helpers.


In [None]:
# Resolve project/data directories whether the notebook is launched from
# repo root or the notebooks/ subdirectory.
PROJECT_ROOT = Path.cwd().resolve()
if not (PROJECT_ROOT / "data").exists() and (PROJECT_ROOT.parent / "data").exists():
    PROJECT_ROOT = PROJECT_ROOT.parent

data_root = PROJECT_ROOT / "data"
audio_dir = data_root / "audio"
label_dir = data_root / "labels"

# Training hyperparameters / knobs
VAL_COUNT = 2          # hold out this many tracks for validation (taken from end)
NUM_EPOCHS = 30
LEARNING_RATE = 1e-3
SHUFFLE_SEED = None    # e.g. 42 to randomize order deterministically
DEVICE = auto          # None => auto (cuda if available)
CHECKPOINT_PATH = PROJECT_ROOT / "beat_crnn_mvp.pth"

all_items = discover_label_items(str(label_dir), str(audio_dir))

if SHUFFLE_SEED is not None:
    random.seed(SHUFFLE_SEED)
    random.shuffle(all_items)

train_items, val_items = split_train_val_items(all_items, VAL_COUNT)

print(f"Project root: {PROJECT_ROOT}")
print(f"Found {len(all_items)} labeled tracks.")
print(f"Training set size: {len(train_items)}")
print(
    f"Validation set size: {len(val_items)}"
    if val_items
    else "Validation set size: 0 (validation disabled)"
)

model = train_mvp(
    train_items=train_items,
    val_items=val_items,
    num_epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    device=DEVICE,
)

CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), CHECKPOINT_PATH)
print(f"Saved checkpoint to {CHECKPOINT_PATH}")


=== Overfitting on 1 track ===
Using device: cpu
Epoch 1/30 - Train loss: 0.7084
Epoch 2/30 - Train loss: 0.4957
Epoch 3/30 - Train loss: 0.4409
Epoch 4/30 - Train loss: 0.4497
Epoch 5/30 - Train loss: 0.4548
Epoch 6/30 - Train loss: 0.4481
Epoch 7/30 - Train loss: 0.4410
Epoch 8/30 - Train loss: 0.4381
Epoch 9/30 - Train loss: 0.4394
Epoch 10/30 - Train loss: 0.4417
Epoch 11/30 - Train loss: 0.4425
Epoch 12/30 - Train loss: 0.4414
Epoch 13/30 - Train loss: 0.4395
Epoch 14/30 - Train loss: 0.4379
Epoch 15/30 - Train loss: 0.4369
Epoch 16/30 - Train loss: 0.4367
Epoch 17/30 - Train loss: 0.4370
Epoch 18/30 - Train loss: 0.4372
Epoch 19/30 - Train loss: 0.4369
Epoch 20/30 - Train loss: 0.4362
Epoch 21/30 - Train loss: 0.4352
Epoch 22/30 - Train loss: 0.4342
Epoch 23/30 - Train loss: 0.4333
Epoch 24/30 - Train loss: 0.4324
Epoch 25/30 - Train loss: 0.4313
Epoch 26/30 - Train loss: 0.4300
Epoch 27/30 - Train loss: 0.4284
Epoch 28/30 - Train loss: 0.4263
Epoch 29/30 - Train loss: 0.4240
Epo

## Building Beat Grid
This section converts beat activations into musical timing information. Given a track’s activation curve we (1) estimate the BPM, (2) refine the downbeat offset, and (3) output a constant-BPM beatgrid annotated with per-beat confidence samples.

In [52]:
def find_best_offset_from_activation(
    frame_times,
    probs,
    bpm,
    duration,
    search_window=None,
    n_offsets=200,
):
    """
    Search over candidate offsets in [0, period) and pick the one whose
    ideal grid best aligns with high activations.

    frame_times: (T,) seconds
    probs: (T,) beat probabilities
    bpm: float
    duration: float (seconds)
    search_window: (t_start, t_end) or None
        Only use this region for scoring (e.g. skip very quiet intro).
    n_offsets: number of offset candidates to test.

    Returns:
        best_offset: float in [0, period)
        best_score: float mean activation score for that offset.
    """
    period = 60.0 / float(bpm)

    # Restrict to region for scoring (optional)
    ft = np.asarray(frame_times)
    pb = np.asarray(probs)

    if search_window is not None:
        t_start, t_end = search_window
        mask = (ft >= t_start) & (ft <= t_end)
        ft_score = ft[mask]
        pb_score = pb[mask]
        dur_score = min(t_end, duration) - t_start
    else:
        ft_score = ft
        pb_score = pb
        dur_score = duration

    if dur_score <= 0 or len(ft_score) == 0:
        # fallback
        return 0.0, -np.inf

    # Candidate offsets from [0, period)
    offsets = np.linspace(0.0, period, num=n_offsets, endpoint=False)
    best_score = -np.inf
    best_offset = 0.0

    for off in offsets:
        # grid times within score window
        if search_window is not None:
            t0 = max(off, search_window[0])
        else:
            t0 = off
        if t0 > duration:
            continue

        grid_times = np.arange(t0, duration, period)

        # restrict to scoring window
        if search_window is not None:
            grid_times = grid_times[
                (grid_times >= search_window[0]) & (grid_times <= search_window[1])
            ]

        if len(grid_times) == 0:
            continue

        # sample activation at grid times
        grid_probs = sample_probs_at_times(ft_score, pb_score, grid_times)

        # score: mean prob (you can also use sum or top-k mean)
        score = grid_probs.mean()

        if score > best_score:
            best_score = score
            best_offset = off

    return best_offset, best_score


def sample_probs_at_times(frame_times, probs, query_times):
    """
    Given frame_times (T,), probs (T,) and query_times (K,),
    return probs at nearest frame for each query time.
    """
    frame_times = np.asarray(frame_times)
    probs = np.asarray(probs)
    query_times = np.asarray(query_times)

    indices = np.searchsorted(frame_times, query_times, side="left")
    indices = np.clip(indices, 0, len(frame_times) - 1)

    # Optional: for middle-between-frames, pick closer of left/right
    left_indices = np.clip(indices - 1, 0, len(frame_times) - 1)
    right_indices = indices
    left_d = np.abs(frame_times[left_indices] - query_times)
    right_d = np.abs(frame_times[right_indices] - query_times)
    use_left = left_d < right_d
    final_indices = np.where(use_left, left_indices, right_indices)

    return probs[final_indices]


def estimate_bpm_from_activation(
    frame_times,
    probs,
    duration,
    bpm_min=70,
    bpm_max=200,
    bpm_step=0.5,
    search_window=None,
    n_offsets=120,
):
    """
    Grid-search BPM by evaluating how well a constant beatgrid aligns with the
    activation curve. Returns (best_bpm, score).
    """
    candidate_bpms = np.arange(bpm_min, bpm_max + 1e-6, bpm_step)
    best_bpm = None
    best_score = -np.inf

    for bpm in candidate_bpms:
        if bpm <= 0:
            continue
        _, score = find_best_offset_from_activation(
            frame_times,
            probs,
            bpm,
            duration,
            search_window=search_window,
            n_offsets=n_offsets,
        )

        if score > best_score:
            best_score = score
            best_bpm = float(bpm)

    if best_bpm is None:
        best_bpm = float(bpm_min)

    return best_bpm, best_score


def build_beatgrid_from_activation(
    frame_times,
    probs,
    bpm,
    duration,
    search_window=None,
    best_offset=None,
):
    """
    Given beat activation, BPM and duration, produce a beatgrid.

    Returns:
        beat_grid: [t0, t1, ...] grid times (seconds)
        info_per_beat: list of dicts with optional confidence etc.
    """
    period = 60.0 / float(bpm)

    if best_offset is None:
        best_offset, _ = find_best_offset_from_activation(
            frame_times, probs, bpm, duration, search_window=search_window
        )

    # Generate full grid
    t = best_offset
    beat_grid = []
    while t < duration + 1e-6:
        beat_grid.append(float(t))
        t += period

    beat_grid = np.array(beat_grid)

    # For each grid beat, find nearest activation frame + prob
    probs_at_grid = sample_probs_at_times(frame_times, probs, beat_grid)

    info_per_beat = []
    for t_grid, p in zip(beat_grid, probs_at_grid):
        info_per_beat.append(
            {
                "time": float(t_grid),
                "prob_at_grid": float(p),
            }
        )

    return beat_grid, info_per_beat, best_offset


def compute_activation_for_track(model, audio_path):
    """
    Run model on a track and return (frame_times, probs, duration).
    """
    model.eval()
    device = next(model.parameters()).device

    mel_db, frame_times, duration = load_audio_to_mel(audio_path)

    with torch.no_grad():
        x = torch.from_numpy(mel_db).float().unsqueeze(0).to(device)
        logits = model(x)
        probs = torch.sigmoid(logits).cpu().numpy()[0]

    # match frame_times length to probs (due to CNN pooling)
    frame_times = frame_times[: len(probs)]

    return frame_times, probs, duration


def infer_bpm_offset_and_grid(
    model,
    audio_path,
    bpm_min=70,
    bpm_max=200,
    bpm_step=0.5,
    search_window=None,
):
    """
    Convenience wrapper that runs the model, estimates BPM + downbeat offset,
    and returns the full beatgrid plus raw activation.
    """
    frame_times, probs, duration = compute_activation_for_track(model, audio_path)

    best_bpm, bpm_score = estimate_bpm_from_activation(
        frame_times,
        probs,
        duration,
        bpm_min=bpm_min,
        bpm_max=bpm_max,
        bpm_step=bpm_step,
        search_window=search_window,
    )

    best_offset, _ = find_best_offset_from_activation(
        frame_times,
        probs,
        best_bpm,
        duration,
        search_window=search_window,
    )

    beat_grid, info_per_beat, _ = build_beatgrid_from_activation(
        frame_times,
        probs,
        best_bpm,
        duration,
        search_window=search_window,
        best_offset=best_offset,
    )

    return {
        "bpm": best_bpm,
        "bpm_score": bpm_score,
        "offset": best_offset,
        "beat_grid": beat_grid,
        "info_per_beat": info_per_beat,
        "frame_times": frame_times,
        "probs": probs,
        "duration": duration,
    }


## Running The Beatgrid Pipeline
`infer_bpm_offset_and_grid` glues everything together: it runs the CRNN, estimates BPM from the activation curve, finds the downbeat offset, and samples the final beatgrid. The returned dict also exposes the raw activation so you can build your own visualizations or compare against label JSONs.

In [70]:
model.load_state_dict(torch.load("beat_crnn_mvp.pth", map_location="cpu"))
audio_path = "data/audio/Open Your Mind.m4a"

results = infer_bpm_offset_and_grid(
    model,
    audio_path,
    bpm_min=120,
    bpm_max=190,
    bpm_step=0.5,
)

print(f"Estimated BPM: {results['bpm']:.2f}")
print(f"Estimated downbeat offset: {results['offset']:.3f}s")
print(f"Num beats: {len(results['beat_grid'])}")
results["beat_grid"][:10]


  y, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


Estimated BPM: 150.00
Estimated downbeat offset: 0.056s
Num beats: 436


array([0.056, 0.456, 0.856, 1.256, 1.656, 2.056, 2.456, 2.856, 3.256,
       3.656])