In [3]:
# 01_embed_panns.ipynb — extract PANNs CNN14 embeddings for all previews
import os, time, numpy as np, pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import librosa
from panns_inference import AudioTagging

# ---- paths ----
# The notebook is expected to run from 'web/model-training/notebooks', so two
# levels above the working directory is the 'web' project root.
ROOT = Path.cwd().parents[1]
CSV = ROOT.joinpath("data-pipeline", "input", "songs_dataset_6k.csv")    # song metadata table
AUDIO_DIR = ROOT.joinpath("data-pipeline", "output", "previews")         # preview audio clips
OUT_DIR = ROOT.joinpath("model-training", "output")
OUT_DIR.mkdir(parents=True, exist_ok=True)  # create if needed; no error when it already exists
EMB_PARQUET = OUT_DIR.joinpath("embeddings_panns.parquet")               # file written by this cell

# ---- audio config ----
SR = 32_000      # sample rate (Hz) requested from librosa.load
CROP_SEC = 10.0  # length in seconds of the clip kept from each file
def center_crop(y, need):
    """Return exactly `need` samples centred on the 1-D signal `y`.

    A signal that is long enough is trimmed equally from both ends (when the
    surplus is odd, the extra sample is dropped from the end). A signal that
    is too short is padded with np.pad's default zeros on both sides (when
    the deficit is odd, the extra sample goes at the end).
    """
    surplus = len(y) - need
    if surplus >= 0:
        start = surplus // 2
        return y[start:start + need]

    # Too short: split the missing samples between the left and right edges.
    deficit = -surplus
    left = deficit // 2
    return np.pad(y, (left, deficit - left))

def load_audio_32k(path):
    """Load `path` as mono audio at SR Hz and return a CROP_SEC-second centred clip.

    The decoded signal is passed through center_crop, so the result always
    has int(SR*CROP_SEC) samples regardless of the file's duration.
    """
    samples, _sr = librosa.load(path, sr=SR, mono=True)
    target_len = int(SR*CROP_SEC)
    return center_crop(samples, target_len)

# ---- load CSV ----
# Read the song table; every row must name its preview clip in 'audio_file'.
df = pd.read_csv(CSV)
assert "audio_file" in df.columns, "CSV must have 'audio_file' filenames"

# Resolved path (as a string) of each preview inside AUDIO_DIR.
df["audio_path"] = df["audio_file"].map(
    lambda name: str((AUDIO_DIR / str(name)).resolve())
)

# Keep only rows whose preview is present on disk, renumbered from 0.
has_audio = df["audio_path"].map(lambda candidate: Path(candidate).exists())
df = df[has_audio].reset_index(drop=True)
print("Rows with audio:", len(df))

# ---- backbone ----
# No explicit checkpoint path is given, so the library chooses which weights to load.
model = AudioTagging(checkpoint_path=None, device="cpu")  # CPU is fine/stable

# ---- embed ----
# emb_list[k] is the embedding for the row at position idx_list[k] of df.
# Files that cannot be loaded or embedded are left out of both lists and
# recorded in `failed`, so skips are reported instead of disappearing silently.
emb_list, idx_list, failed = [], [], []
t0 = time.time()
# enumerate() yields true row positions, which is what the positional
# df.iloc[idx_list] lookup after this loop requires; only the path column
# is needed here, so whole rows are not materialised.
for i, audio_path in tqdm(enumerate(df["audio_path"]), total=len(df),
                          desc="Embedding (PANNs CNN14)"):
    try:
        y = load_audio_32k(audio_path)
        _, emb = model.inference(y[None, :])      # (1, 2048)
    except Exception as exc:
        # Best-effort: a corrupted/unreadable file must not abort the whole
        # run, but keep a record of what was skipped and why.
        failed.append((audio_path, f"{type(exc).__name__}: {exc}"))
    else:
        # Drop the batch axis and store as float32.
        emb_list.append(np.asarray(emb).squeeze().astype("float32"))
        idx_list.append(i)

if failed:
    print(f"Skipped {len(failed)} of {len(df)} file(s) that could not be embedded:")
    for bad_path, reason in failed[:10]:
        print(f"  {bad_path} -> {reason}")
    if len(failed) > 10:
        print(f"  ... and {len(failed) - 10} more")

assert len(emb_list) > 0, "No embeddings extracted"
emb = np.stack(emb_list, axis=0)  # one row per successfully embedded file

# Metadata/label columns carried into the output next to the embeddings
keep_cols = [
    "genre", "artist_name", "track_name", "track_id", "audio_file",
    "acousticness", "danceability", "duration_ms", "energy",
    "instrumentalness", "key", "liveness", "loudness", "mode",
    "speechiness", "tempo", "time_signature", "valence",
    # "popularity"  # optional; we won't train it
]
# Rows that were embedded successfully, in embedding order, renumbered from 0
# so they line up index-for-index with the embedding frame below.
meta = df[keep_cols].iloc[idx_list].reset_index(drop=True)

# Write a tidy parquet: meta + one column per embedding dimension (e0, e1, ...)
emb_cols = [f"e{k}" for k in range(emb.shape[1])]
emb_df = pd.DataFrame(emb, columns=emb_cols)
final = pd.concat([meta, emb_df], axis="columns")
final.to_parquet(EMB_PARQUET, engine="pyarrow", compression="zstd", index=False)

# Elapsed time counts from the start of the embedding loop (t0).
dt = time.time() - t0
print(f"✅ Saved: {EMB_PARQUET} | shape={final.shape} | took {dt/60:.1f} min")
final.head()

Rows with audio: 6000
Checkpoint path: /Users/prajeetdarda/panns_data/Cnn14_mAP=0.431.pth
Using CPU.


Embedding (PANNs CNN14):   0%|          | 0/6000 [00:00<?, ?it/s]



✅ Saved: /Users/prajeetdarda/Desktop/All_Coding/AI-Project/web/model-training/output/embeddings_panns.parquet | shape=(6000, 2066) | took 17.0 min


Unnamed: 0,genre,artist_name,track_name,track_id,audio_file,acousticness,danceability,duration_ms,energy,instrumentalness,...,e2038,e2039,e2040,e2041,e2042,e2043,e2044,e2045,e2046,e2047
0,Country,A Thousand Horses,My Time's Comin',16zol4GvHyTER5irYODUk0,a_thousand_horses_my_time_s_comin_0.mp3,0.00192,0.327,194107,0.835,0.00015,...,0.0,0.0,0.0,0.0,0.0,0.0,0.584037,0.289752,0.0,0.0
1,Soundtrack,Mark Mothersbaugh,House Tour,6ac5gUfGTckpdGQCyWsdh2,mark_mothersbaugh_house_tour_1.mp3,0.932,0.253,102920,0.0798,0.568,...,0.0,0.0,0.0,0.0,0.0,0.0,0.66464,0.246523,0.051477,0.0
2,Reggae,Unified Highway,We Can't Fall (Remix) [feat. J. Patz],09Yz6koF1Y15n1012t1UX6,unified_highway_we_can_t_fall_remix_feat_j_pat...,0.0331,0.821,225437,0.737,0.0134,...,0.0,0.0,0.0,0.0,0.0,0.0,0.656714,0.418959,0.411713,0.0
3,Electronic,Stooki Sound,Endz - Original Mix,3dzEZARDL4ZwICMKVta7Xn,stooki_sound_endz_original_mix_3.mp3,0.00428,0.745,225400,0.772,0.114,...,0.0,0.0,0.0,0.0,0.0,0.0,0.466213,0.255734,0.0,0.0
4,Comedy,Bill Hicks,I Love My Job (Live),39Z1G5384UgGa5vmW6WyxC,bill_hicks_i_love_my_job_live_4.mp3,0.965,0.502,287973,0.804,9.6e-05,...,0.0,0.01251,0.0,0.0,0.0,0.0,0.457702,0.411279,0.694587,0.0
