# 01d — MTG Jamendo/FMA combo
Combine the MTG-Jamendo and FMA audio, cut it into fixed-length segments, and split the segments into train/validation sets.

In [None]:
import subprocess
import random
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from tqdm import tqdm
import os

# --- Input roots scanned for source WAV files ---
FMA_DIR = Path('/root/workspace/data/fma_large/wav_32k_mono')
MTG_JAMENDO_DIR = Path('/root/workspace/data/mtg_jamendo/wav_32k_mono')

# --- Segmentation and split settings ---
SEGMENT_SECONDS = 60
TARGET_SR = 32000
CHANNELS = 1
TRAIN_RATIO = 0.9
RANDOM_SEED = 42
NUM_SAMPLES_TOTAL = None  # adjust to control how many source tracks to keep

# --- Output layout: segments and manifests both live under OUT_DIR ---
OUT_DIR = Path('/root/workspace/data/all_data')
SEGMENTS_DIR = OUT_DIR / f'segments_{SEGMENT_SECONDS}s'
MANIFEST_DIR = OUT_DIR / "manifests"

# Create the output directories up front (MANIFEST_DIR is created later,
# right before the manifests are written).
for out_path in (OUT_DIR, SEGMENTS_DIR):
    out_path.mkdir(parents=True, exist_ok=True)

# Sanity check: number of top-level entries in each input root.
for dataset_root in (FMA_DIR, MTG_JAMENDO_DIR):
    print(len(os.listdir(dataset_root)))


def segment_wav(input_path: Path, output_pattern: Path,
                segment_seconds=SEGMENT_SECONDS, target_sr=TARGET_SR, channels=CHANNELS):
    """Split one audio file into fixed-length segments with ffmpeg.

    Args:
        input_path: Audio file to read.
        output_pattern: Output path containing a printf-style index
            placeholder (e.g. ``name_%05d.wav``), as required by ffmpeg's
            segment muxer.
        segment_seconds: Target length of each segment, in seconds.
        target_sr: Output sample rate in Hz (ffmpeg ``-ar``).
        channels: Output channel count (ffmpeg ``-ac``).

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    cmd = [
        "ffmpeg",
        # Never read from stdin: this function is called from many worker
        # threads at once, and all ffmpeg children would otherwise share
        # (and may consume or block on) the parent's standard input.
        "-nostdin",
        "-hide_banner", "-loglevel", "error",
        "-i", str(input_path),
        "-ar", str(target_sr),
        "-ac", str(channels),
        "-f", "segment",
        "-segment_time", str(segment_seconds),
        # Restart timestamps at zero in every segment.
        "-reset_timestamps", "1",
        str(output_pattern),
    ]
    subprocess.run(cmd, check=True)

# Helper Functions

In [None]:
from pathlib import Path
import subprocess



def _tag(src: Path) -> str:
    """Return the dataset label for a source file.

    Returns "fma" when *src* lies under FMA_DIR, "mtg" when it lies under
    MTG_JAMENDO_DIR, and "unk" otherwise. The check compares whole path
    components (purely lexically, no filesystem access), so a sibling
    directory that merely shares a name prefix with a dataset root
    (e.g. ``<root>_v2``) is not mistaken for that dataset.
    """
    path = Path(src)
    for root, label in ((FMA_DIR, "fma"), (MTG_JAMENDO_DIR, "mtg")):
        try:
            path.relative_to(root)
        except ValueError:
            # Not located under this root; try the next one.
            continue
        return label
    return "unk"

def segment_one(src: Path):
    """Segment one source WAV into SEGMENTS_DIR.

    Output files are named ``<tag>__<stem>_<index>.wav``, where ``<tag>``
    comes from ``_tag(src)`` and ``<index>`` is filled in by ffmpeg.
    """
    # "%05d" is left for ffmpeg's segment muxer to expand per segment.
    segment_name = "{}__{}_%05d.wav".format(_tag(src), src.stem)
    segment_wav(src, SEGMENTS_DIR / segment_name)

# Filter for selected durations

In [8]:
import contextlib
import wave
import subprocess
from pathlib import Path
from tqdm import tqdm


def tag(src: Path) -> str:
    """Return the dataset label for a source file.

    Returns "fma" when *src* lies under FMA_DIR, "mtg" when it lies under
    MTG_JAMENDO_DIR, and "unk" otherwise. The check compares whole path
    components (purely lexically, no filesystem access), so a sibling
    directory that merely shares a name prefix with a dataset root
    (e.g. ``<root>_v2``) is not mistaken for that dataset.
    """
    path = Path(src)
    for root, label in ((FMA_DIR, "fma"), (MTG_JAMENDO_DIR, "mtg")):
        try:
            path.relative_to(root)
        except ValueError:
            # Not located under this root; try the next one.
            continue
        return label
    return "unk"

# 1) Scan SEGMENTS_DIR once and collect the sources that already have output.
# Segment files are expected to be named like "fma__<stem>_00000.wav";
# dropping everything after the last "_" yields the key "fma__<stem>".
done_keys = {
    segment_path.stem.rsplit("_", 1)[0]
    for segment_path in SEGMENTS_DIR.glob("*.wav")
}

print(f"Existing segment files: {len(done_keys)} unique sources already done")

def get_duration_seconds_fast(p: Path) -> float:
    """Return the duration of a WAV file in seconds, using only its header.

    Any error raised by the ``wave`` module (or by opening the file)
    propagates to the caller.
    """
    with wave.open(str(p), "rb") as wav_reader:
        frame_count = wav_reader.getnframes()
        sample_rate = wav_reader.getframerate()
    return frame_count / float(sample_rate)

def should_keep(p: Path, min_seconds: float) -> bool:
    """Return True when the file at *p* lasts at least *min_seconds*.

    The duration is first read via ``get_duration_seconds_fast``; if that
    fails for any reason, ffprobe is asked instead. If ffprobe fails too,
    the file is rejected (False).
    """
    try:
        return get_duration_seconds_fast(p) >= min_seconds
    except Exception:
        # Deliberate best-effort: fall through to the ffprobe fallback below.
        pass

    try:
        probe_cmd = [
            "ffprobe", "-v", "error",
            "-show_entries", "format=duration",
            "-of", "default=noprint_wrappers=1:nokey=1",
            str(p),
        ]
        probe = subprocess.run(probe_cmd, capture_output=True, text=True, check=True)
        return float(probe.stdout.strip()) >= min_seconds
    except Exception:
        # Neither method produced a usable duration: drop the file.
        return False

# Discover candidates: each dataset is listed in sorted order, FMA first.
wav_files = [*sorted(FMA_DIR.rglob("*.wav")), *sorted(MTG_JAMENDO_DIR.rglob("*.wav"))]
print(f"Total WAV discovered: {len(wav_files)}")

kept = []
skipped_done = 0

progress = tqdm(wav_files, desc=f"Filtering < {SEGMENT_SECONDS}s (skip done)", unit=" files")
for p in progress:
    if f"{tag(p)}__{p.stem}" in done_keys:
        # Segments for this source already exist; no duration check needed.
        skipped_done += 1
    elif should_keep(p, SEGMENT_SECONDS):
        kept.append(p)

# From here on wav_files holds only the files that still need segmenting.
wav_files = kept
print(f"Already segmented (skipped): {skipped_done}")
print(f"WAV after duration filter (to process): {len(wav_files)}")

Existing segment files: 41874 unique sources already done
Total WAV discovered: 162114


Filtering < 60s (skip done): 100%|██████████| 162114/162114 [00:20<00:00, 7746.04 files/s] 

Already segmented (skipped): 41874
WAV after duration filter (to process): 12226





In [9]:
import subprocess
import random
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import os

from concurrent.futures import as_completed

max_workers = min(os.cpu_count() or 4, 32)
print(f"Max workers: {max_workers}")

# Submit every file and drain results as they finish, so that one ffmpeg
# failure no longer aborts the progress loop or hides which file caused it.
# Failures are collected and reported once all files have been attempted.
failed = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    pending = {executor.submit(segment_one, src): src for src in wav_files}
    for future in tqdm(as_completed(pending),
                       total=len(pending),
                       desc="Segmenting WAVs",
                       unit=" files"):
        try:
            future.result()
        except subprocess.CalledProcessError as exc:
            failed.append((pending[future], exc.returncode))

if failed:
    # NOTE(review): a failed source may leave partial segment files in
    # SEGMENTS_DIR; the skip-done scan treats any existing segment as
    # "done", so remove those files before re-running.
    for failed_src, returncode in failed[:20]:
        print(f"ffmpeg failed (exit {returncode}): {failed_src}")
    raise RuntimeError(f"{len(failed)} of {len(pending)} files failed to segment")

Max workers: 32


Segmenting WAVs: 100%|██████████| 12226/12226 [03:30<00:00, 57.97 files/s]


In [None]:
import json
import random
from concurrent.futures import ThreadPoolExecutor

segments = sorted(SEGMENTS_DIR.glob("*.wav"))

# Group segments by source track ("<tag>__<stem>", i.e. the file stem
# without its trailing "_<index>"). Splitting per track instead of per
# segment keeps every segment of a track in the same split, so the
# validation set never contains audio from a track seen in training.
segments_by_source = {}
for seg in segments:
    source_key = seg.stem.rsplit("_", 1)[0]
    segments_by_source.setdefault(source_key, []).append(seg)

# Sort before shuffling so the result depends only on RANDOM_SEED.
source_keys = sorted(segments_by_source)
random.seed(RANDOM_SEED)
random.shuffle(source_keys)

# TRAIN_RATIO is applied to the number of source tracks; the ratio of
# segments is therefore only approximately TRAIN_RATIO.
split_idx = int(len(source_keys) * TRAIN_RATIO)
train_keys = source_keys[:split_idx]
# Keep the validation split non-empty whenever any track exists (this
# fallback reuses the last training track, e.g. when TRAIN_RATIO is 1.0).
valid_keys = source_keys[split_idx:] or source_keys[-1:]

train_files = [seg for key in train_keys for seg in segments_by_source[key]]
valid_files = [seg for key in valid_keys for seg in segments_by_source[key]]

# Mix segments of different tracks within each split.
random.shuffle(train_files)
random.shuffle(valid_files)

MANIFEST_DIR.mkdir(parents=True, exist_ok=True)

def write_manifest(split_name, files):
    """Write ``<split_name>.jsonl`` into MANIFEST_DIR, one JSON object per file.

    Each line has the form ``{"path": "<file path>"}``. Returns a tuple of
    the split name and the number of files.
    """
    target = MANIFEST_DIR / f"{split_name}.jsonl"
    with open(target, "w") as out:
        out.writelines(json.dumps({"path": str(item)}) + "\n" for item in files)
    return split_name, len(files)

# Write the train and valid manifests concurrently, one worker per split.
with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [
        executor.submit(write_manifest, split_name, split_files)
        for split_name, split_files in (("train", train_files), ("valid", valid_files))
    ]
    # Waiting on each result also re-raises any error from the worker.
    results = [f.result() for f in futures]

print("Train/valid counts:", len(train_files), len(valid_files))
print("Sample manifest line:")
train_manifest_lines = (MANIFEST_DIR / "train.jsonl").read_text().splitlines()
print(train_manifest_lines[:1])