In [7]:
from __future__ import annotations
import os, json, argparse, re, string, unicodedata
from pathlib import Path
from functools import partial
from typing import Dict, List, Tuple
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torchaudio
from tqdm import tqdm

In [8]:
# Spectrogram / STFT parameters shared by every processed clip.
SAMPLE_RATE = 22050   # rate (Hz) every clip is resampled to before the mel transform
N_FFT = 1024          # FFT size passed to MelSpectrogram
HOP_LENGTH = 256      # hop length passed to MelSpectrogram
N_MELS = 128          # number of mel bands
F_MIN = 20            # lower edge of the mel filterbank (Hz)
F_MAX = 10000         # upper edge of the mel filterbank (Hz)

# Lower-case file suffixes treated as audio when scanning the dataset directory.
AUDIO_EXTS = {".ogg", ".mp3", ".flac", ".wav"}

# MelSpectrogram transforms reused across calls within one process,
# keyed by "<device>:<sample_rate>".
_taudio_cache: Dict[str, torchaudio.transforms.MelSpectrogram] = dict()

In [9]:
_clean_re = re.compile("[" + re.escape(string.punctuation) + "]")

def clean_text(txt: str) -> str:
    """Normalise a caption string.

    Steps: NFKD-normalise, lower-case, delete ASCII punctuation characters,
    then collapse each whitespace run to a single space and trim both ends.
    """
    decomposed = unicodedata.normalize("NFKD", txt)
    without_punct = _clean_re.sub("", decomposed.lower())
    return " ".join(without_punct.split())

In [10]:
def load_audio_any(path: Path) -> np.ndarray:
    """Load an audio file as a mono float32 waveform at SAMPLE_RATE.

    soundfile is tried first; if it raises, torchaudio is used as a fallback
    (e.g. for MP3 files the installed libsndfile cannot decode).

    Args:
        path: audio file to read.

    Returns:
        1-D, contiguous float32 array of samples at SAMPLE_RATE.

    Raises:
        Whatever ``torchaudio.load`` raises when neither backend can read the file.
    """
    try:
        # soundfile yields (frames,) for mono and (frames, channels) otherwise.
        audio, sr = sf.read(path, dtype="float32")
    except Exception:
        # Fallback to torchaudio (handles MP3 via FFmpeg); it returns
        # (channels, frames), so average over dim 0 to get mono.
        wav, sr = torchaudio.load(path)
        audio = wav.mean(0).numpy()

    # Downmix BEFORE resampling: resample() operates on the last axis, so a
    # (frames, channels) array would otherwise be resampled across channels
    # and come out still at the original rate.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    audio = np.ascontiguousarray(audio, dtype=np.float32)

    if sr != SAMPLE_RATE:
        audio = torchaudio.functional.resample(
            torch.from_numpy(audio), sr, SAMPLE_RATE
        ).numpy()
    return audio


In [11]:
def process_one(job: Tuple[Path, str, str, Path, Path], device: str="cpu") -> Dict:
    """Compute and save the dB-scaled mel spectrogram of one audio file.

    Args:
        job: ``(audio_path, track_id, text_prompt, mel_dir, data_root)``.
        device: torch device the mel transform and its input are moved to.

    Returns:
        On success, a metadata row with ``track_id``, ``audio_path`` (relative
        to ``data_root``), ``mel_path`` (relative to ``mel_dir.parent``) and
        ``text``. On any exception, a dict with ``error``, ``track_id`` and
        ``audio_path`` instead, so one bad file never aborts the worker pool.
    """
    audio_path, track_id, text_prompt, mel_dir, data_root = job
    try:
        audio = load_audio_any(audio_path)

        # Build the transform once per (device, rate) in this process and reuse it.
        key = f"{device}:{SAMPLE_RATE}"
        if key not in _taudio_cache:
            _taudio_cache[key] = torchaudio.transforms.MelSpectrogram(
                sample_rate=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH,
                n_mels=N_MELS, f_min=F_MIN, f_max=F_MAX, power=2.0
            ).to(device)
        mel_spec = _taudio_cache[key](torch.from_numpy(audio).to(device))
        # Power spectrogram -> dB (multiplier=10), with an 80 dB dynamic-range cut-off.
        mel_db   = torchaudio.functional.amplitude_to_DB(
            mel_spec, multiplier=10.0, amin=1e-10, db_multiplier=0.0, top_db=80)
        mel_db   = mel_db.cpu().numpy().astype(np.float32)

        mel_path = mel_dir / f"{track_id}.npy"
        np.save(mel_path, mel_db)
        return {
            "track_id":   track_id,
            "audio_path": str(audio_path.relative_to(data_root)),
            "mel_path":   str(mel_path.relative_to(mel_dir.parent)),
            "text":       text_prompt,
        }
    except Exception as e:
        # str(e) alone is empty for many exceptions and never says which file
        # failed, so record the exception type and the source path too.
        return {
            "error": f"{type(e).__name__}: {e}",
            "track_id": track_id,
            "audio_path": str(audio_path),
        }

In [12]:
def _collect_jobs(data_root: Path, mel_dir: Path) -> List[Tuple]:
    """Build one ``process_one`` job per audio file under ``data_root``, in sorted path order."""
    meta_path = data_root / "metadata.json"
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    captions: Dict[str, str] = (
        json.loads(meta_path.read_text(encoding="utf-8")) if meta_path.exists() else {}
    )

    jobs: List[Tuple] = []
    # sorted() makes job order (and thus metadata.csv row order) reproducible;
    # is_file() skips directories whose names happen to end in an audio suffix.
    for p in sorted(data_root.rglob("*")):
        if p.is_file() and p.suffix.lower() in AUDIO_EXTS:
            tid = p.stem
            txt = clean_text(captions.get(tid, "unknown track"))
            jobs.append((p, tid, txt, mel_dir, data_root))
    return jobs


def preprocess_dataset(data_root: Path, output_root: Path, workers: int, device: str):
    """Convert every audio file under ``data_root`` into a saved mel spectrogram.

    Writes ``output_root/mel/<track_id>.npy`` per successfully processed file,
    ``output_root/metadata.csv`` with one row per success, and
    ``output_root/errors.log`` listing the failures (if any).

    Args:
        data_root: directory searched recursively for files whose suffix is in
            AUDIO_EXTS. An optional ``metadata.json`` there maps a track id
            (the file stem) to its caption text.
        output_root: directory receiving the spectrograms and metadata.
        workers: number of worker processes.
        device: torch device forwarded to ``process_one``.

    Raises:
        FileNotFoundError: if ``data_root`` is not an existing directory
            (previously this silently produced an empty dataset).
    """
    if not data_root.is_dir():
        raise FileNotFoundError(f"data_root is not a directory: {data_root}")

    output_root.mkdir(parents=True, exist_ok=True)
    mel_dir = output_root / "mel"
    mel_dir.mkdir(exist_ok=True)

    jobs = _collect_jobs(data_root, mel_dir)

    with ProcessPoolExecutor(max_workers=workers) as pool:
        results = list(tqdm(pool.map(partial(process_one, device=device), jobs, chunksize=8),
                            total=len(jobs), desc="Preprocessing"))

    successes = [r for r in results if "error" not in r]
    errors = [r for r in results if "error" in r]

    # Explicit columns keep the CSV header stable even when every job failed.
    df = pd.DataFrame(successes, columns=["track_id", "audio_path", "mel_path", "text"])
    df.to_csv(output_root / "metadata.csv", index=False)

    errors_path = output_root / "errors.log"
    if errors:
        errors_path.write_text("\n".join(map(str, errors)), encoding="utf-8")
        print(f"⚠️  {len(errors)} failures – see errors.log")
    elif errors_path.exists():
        # Don't leave a stale log from an earlier run beside a clean result.
        errors_path.unlink()
    print(f"✔︎  Saved {len(df)} spectrograms → {output_root}")

In [14]:
# Input dataset and output location, relative to the notebook's working directory.
data_root = Path("./data/raw/fma_small")
output_root = Path("./data/processed")

# A single worker process; the mel transform runs on this torch device.
# NOTE(review): CUDA inside ProcessPoolExecutor workers assumes the parent
# kernel has not initialised CUDA before the pool starts — confirm if this
# cell is run after other GPU work in the same kernel.
workers = 1
device = "cuda"

preprocess_dataset(
    data_root=data_root,
    output_root=output_root,
    workers=workers,
    device=device,
)

Preprocessing:  28%|██▊       | 2209/8000 [02:50<05:59, 16.12it/s][src/libmpg123/layer3.c:INT123_do_layer3():1804] error: dequantization failed!
Preprocessing:  40%|████      | 3209/8000 [04:02<04:46, 16.72it/s]Note: Illegal Audio-MPEG-Header 0x00000000 at offset 33361.
Note: Trying to resync...
Note: Skipped 1024 bytes in input.
[src/libmpg123/parse.c:wetwork():1349] error: Giving up resync after 1024 bytes - your stream is not nice... (maybe increasing resync limit could help).
Note: Illegal Audio-MPEG-Header 0x00000000 at offset 33361.
Note: Trying to resync...
Note: Skipped 1024 bytes in input.
[src/libmpg123/parse.c:wetwork():1349] error: Giving up resync after 1024 bytes - your stream is not nice... (maybe increasing resync limit could help).
Preprocessing:  40%|████      | 3217/8000 [04:03<05:47, 13.77it/s]Note: Illegal Audio-MPEG-Header 0x00000000 at offset 22401.
Note: Trying to resync...
Note: Skipped 1024 bytes in input.
[src/libmpg123/parse.c:wetwork():1349] error: Giving u

⚠️  6 failures – see errors.log
✔︎  Saved 7994 spectrograms → data/processed



