In [None]:
from __future__ import annotations
import os, shutil
from pathlib import Path
import pandas as pd
import numpy as np

# Where the full MusicNet download lives. The reduced copy is written to a
# sibling directory whose name is the source name with "_small" appended.
SRC_ROOT = Path(r"/Users/Phillip/Downloads/musicnet 2")
DST_ROOT = SRC_ROOT.parent / f"{SRC_ROOT.name}_small"

# Fixed evaluation IDs; each is transferred only if both its .wav and .csv
# exist locally (missing ones are reported, not treated as errors).
TEST_KEEP_IDS = {
    1759, 1819, 2106, 2191, 2298,
    2303, 2382, 2416, 2556, 2628,
}

# M1-friendly size -- tune as needed.
TRAIN_KEEP_N = 40             # number of top-scoring train tracks to keep; 40 is a good start
MIN_EVENTS = 300              # train tracks with fewer label rows are never selected
COPY_INSTEAD_OF_MOVE = True   # True: copy (source untouched); False: move files out of SRC_ROOT to free disk

def ensure_structure(root: Path) -> dict[str, Path]:
    """Locate the MusicNet split directories under *root*.

    The four split directories (train_data, train_labels, test_data,
    test_labels) are looked for directly inside the resolved *root* first,
    then inside a nested ``musicnet`` subdirectory.

    Returns:
        A dict with the chosen ``base`` directory plus ``train_wav``,
        ``train_csv``, ``test_wav`` and ``test_csv`` pointing at the data and
        label directories beneath it.

    Raises:
        FileNotFoundError: if neither location contains all four directories.
    """
    root = root.expanduser().resolve()
    expected = ["train_data", "train_labels", "test_data", "test_labels"]

    for candidate in (root, root / "musicnet"):
        if all((candidate / name).exists() for name in expected):
            base = candidate
            break
    else:
        raise FileNotFoundError(f"Can't find {expected} under {root}")

    train_wav, train_csv, test_wav, test_csv = (base / name for name in expected)
    return {
        "base": base,
        "train_wav": train_wav,
        "train_csv": train_csv,
        "test_wav": test_wav,
        "test_csv": test_csv,
    }

def wav_id(p: Path) -> int | None:
    """Return the integer track ID encoded in *p*'s file stem, or None.

    Track files are named ``<id>.wav`` / ``<id>.csv``. Any stem that ``int()``
    cannot parse (e.g. ``notes.wav``) yields None so callers can skip the file.
    Only ``ValueError`` is caught: a bare ``except`` would also hide
    ``KeyboardInterrupt`` and genuine programming errors.
    """
    try:
        return int(p.stem)
    except ValueError:
        return None

def label_stats(csv_path: Path):
    """Summarise one label CSV as ``(duration, event_count, density)``.

    * duration    -- ``max(end_time) - min(start_time)``, in whatever units the
                     CSV's time columns use (hence "units", not seconds).
    * event_count -- number of rows in the CSV.
    * density     -- ``event_count / duration`` (0.0 when duration <= 0).

    Degenerate files never raise and never yield NaN:
    * a zero-byte / header-less file -> ``(0.0, 0, 0.0)``;
    * missing ``start_time`` or ``end_time`` column -> ``(0.0, rows, 0.0)``;
    * no non-null value in either time column -> duration and density 0.0.
    """
    try:
        df = pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        # An empty label file means "no events", not a fatal error for the scan.
        return (0.0, 0, 0.0)

    events = len(df)
    if "start_time" not in df.columns or "end_time" not in df.columns:
        return (0.0, events, 0.0)

    t0 = df["start_time"].min()
    t1 = df["end_time"].max()
    # Both ends must be present: a NaN on either side would propagate into
    # the duration and from there into the downstream ranking score.
    dur = float(t1 - t0) if pd.notna(t0) and pd.notna(t1) else 0.0
    dens = events / dur if dur > 0 else 0.0
    return (dur, events, dens)

def build_table(wav_dir: Path, csv_dir: Path) -> pd.DataFrame:
    """Index every ``<id>.wav`` in *wav_dir* that has a matching label CSV.

    Files whose stem is not an integer ID, and WAVs without a same-stem
    ``.csv`` in *csv_dir*, are skipped. Each kept track contributes one row:

        id, wav, csv, size_mb, dur_units, events, density

    where the last three come from ``label_stats`` on the CSV. The returned
    frame always has these columns, even when no track qualifies, so callers
    can filter on e.g. ``["events"]`` without a KeyError.
    """
    columns = ["id", "wav", "csv", "size_mb", "dur_units", "events", "density"]
    rows = []
    for wav in sorted(wav_dir.glob("*.wav")):
        tid = wav_id(wav)
        if tid is None:
            continue  # stem is not a numeric track ID
        csv = csv_dir / f"{wav.stem}.csv"
        if not csv.exists():
            continue  # audio without labels is unusable for selection
        size_mb = wav.stat().st_size / 1e6
        dur, events, dens = label_stats(csv)
        rows.append((tid, wav, csv, size_mb, dur, events, dens))
    return pd.DataFrame(rows, columns=columns)

def transfer(src: Path, dst: Path, copy: bool):
    """Put *src* at *dst*, creating any missing parent directories of *dst*.

    With ``copy=True`` the file is duplicated via ``shutil.copy2`` and *src*
    is left in place; with ``copy=False`` it is relocated via ``shutil.move``
    and *src* no longer exists afterwards.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    if not copy:
        shutil.move(str(src), str(dst))
        return
    shutil.copy2(src, dst)

# Locate the dataset on disk and index every training track that has labels.
dirs = ensure_structure(SRC_ROOT)
print("Using base:", dirs["base"])

train_tbl = build_table(dirs["train_wav"], dirs["train_csv"])
print("Train tracks:", len(train_tbl))

# Rank candidates. Tracks with fewer than MIN_EVENTS label rows are dropped;
# the rest are scored as
#     log1p(density) + 0.03 * log1p(events) - 0.15 * log1p(dur_units)
# so more events help slightly and longer tracks are penalised (monotonically,
# not only at the extremes).
# NOTE(review): log1p(x) ~= x for small x, so if density is a tiny fraction
# (e.g. events per sample rather than per second) the density term barely
# moves the score and the duration penalty dominates -- confirm the units.
tbl = train_tbl.loc[train_tbl["events"] >= MIN_EVENTS].copy()
tbl["score"] = (
    np.log1p(tbl["density"])
    + 0.03 * np.log1p(tbl["events"])
    - 0.15 * np.log1p(tbl["dur_units"])
)
tbl = tbl.sort_values(by="score", ascending=False)

# Best-scoring IDs first, as plain Python ints.
keep_train = [int(track_id) for track_id in tbl["id"].head(TRAIN_KEEP_N)]

print("Keeping train:", len(keep_train))
print("Top 15 train IDs:", keep_train[:15])

# Mirror the source's four-directory layout under DST_ROOT up front, so every
# split directory exists even if nothing ends up being transferred into it.
for sub in ["train_data", "train_labels", "test_data", "test_labels"]:
    (DST_ROOT / sub).mkdir(parents=True, exist_ok=True)

# Train: transfer the audio + label pair for every selected ID.
for tid in keep_train:
    wav = dirs["train_wav"] / f"{tid}.wav"
    csv = dirs["train_csv"] / f"{tid}.csv"
    if wav.exists() and csv.exists():
        transfer(wav, DST_ROOT/"train_data"/wav.name, COPY_INSTEAD_OF_MOVE)
        transfer(csv, DST_ROOT/"train_labels"/csv.name, COPY_INSTEAD_OF_MOVE)

# DST_ROOT is never cleared, so rerunning with a smaller TRAIN_KEEP_N (or a
# different MIN_EVENTS) leaves earlier picks behind and the reduced train split
# silently exceeds this run's selection. Warn rather than delete anything.
selected_train_wavs = {f"{t}.wav" for t in keep_train}
stale_train = sorted(
    p.name
    for p in (DST_ROOT / "train_data").glob("*.wav")
    if p.name not in selected_train_wavs
)
if stale_train:
    print("Note: train_data also holds files not selected in this run "
          "(e.g. left from an earlier run):", stale_train)

# Test split: each fixed ID is transferred only when both its audio and label
# files exist locally; absent IDs are collected and reported below.
missing_test = []
for tid in sorted(TEST_KEEP_IDS):
    wav = dirs["test_wav"] / f"{tid}.wav"
    csv = dirs["test_csv"] / f"{tid}.csv"
    if not (wav.exists() and csv.exists()):
        missing_test.append(tid)
        continue
    transfer(wav, DST_ROOT / "test_data" / wav.name, COPY_INSTEAD_OF_MOVE)
    transfer(csv, DST_ROOT / "test_labels" / csv.name, COPY_INSTEAD_OF_MOVE)

print("Reduced dataset at:", DST_ROOT)
if missing_test:
    print("Note: missing some requested test IDs locally:", missing_test)


Using base: /Users/Phillip/Downloads/musicnet 2
Train tracks: 320
Keeping train: 40
Top 15 train IDs: [2310, 2234, 2305, 2238, 2196, 2247, 2232, 2207, 2230, 2240, 2292, 2228, 2214, 2224, 2302]
Reduced dataset at: /Users/Phillip/Downloads/musicnet 2_small
