In [2]:
import logging
import os
import pickle
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

In [5]:
df = pd.read_parquet(r'../cleaned dataset/track_cleaned.parquet')
df

Unnamed: 0,track_idx,track_uri,album_name,artist_name,track_name,danceability,energy,key,loudness,mode,speechiness,acousticness,instrumentalness,liveness,valence,tempo,time_signature,lyrics,duration
61588,0,spotify:track:000DfZJww8KiixTKuk9usJ,The Change I'm Seeking,Mike Love,Earthlings,0.631,0.513,2.0,-6.376,1.0,0.0293,0.366000,0.000004,0.1090,0.307,120.365,4.0,I just can't take no more\n I gotta get out of...,357
125100,1,spotify:track:000GjfnQc7ggBayDiy1sLW,Y las Mariposas,El Poder De Zacatecas,Abeja Miope,0.913,0.748,9.0,-3.274,1.0,0.0428,0.074500,0.000956,0.0403,0.864,114.143,4.0,,140
208027,2,spotify:track:000JCyEkMFumqCZQJAORiQ,Enough Is Enough,Nipsey Hussle,California Water,0.795,0.874,0.0,-4.523,1.0,0.2100,0.064600,0.000000,0.3410,0.483,132.966,4.0,,207
102801,3,spotify:track:000VZqvXwT0YNqKk7iG2GS,Dear Youth,The Ghost Inside,Mercy,0.444,0.991,7.0,-4.167,1.0,0.1330,0.000085,0.000084,0.1200,0.106,124.016,4.0,For whom the bell tolls\n There's a hurricane ...,256
69142,4,spotify:track:000uWezkHfg6DbUPf2eDFO,Dancehall Days,The Beautiful Girls,Me I Disconnect From You,0.714,0.635,1.0,-10.769,1.0,0.0299,0.001940,0.259000,0.0839,0.360,134.007,4.0,,321
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
46540,252231,spotify:track:7zywhdkPljk4fOyWkh3EqI,Live & Solo At the Artists Den,Ben Kweller,Lizzy,0.575,0.269,7.0,-8.274,1.0,0.0365,0.779000,0.000000,0.6670,0.403,110.171,4.0,Sign me up I volunteer\nVotes are in for lifeg...,233
117614,252232,spotify:track:7zzBEZBTJejWeL6EqWmCD9,All This Bad Blood,Bastille,Get Home,0.599,0.525,9.0,-6.745,0.0,0.0397,0.729000,0.000046,0.0909,0.186,115.665,4.0,"\nHow am I gonna get myself back home?\nI, I, ...",191
195994,252233,spotify:track:7zzLt6Z9y7jMvXnEg00n58,The Sunny Album (Deluxe Edition),Hippie Sabotage,Quit Wastin Time,0.744,0.581,8.0,-10.225,0.0,0.1460,0.593000,0.958000,0.2040,0.679,126.910,4.0,No Lyrics,69
189208,252234,spotify:track:7zzbfi8fvHe6hm342GcNYl,Ace,Bob Weir,Black-Throated Wind,0.533,0.547,9.0,-9.290,1.0,0.0326,0.029900,0.011300,0.0723,0.669,72.506,4.0,"Youre bringing me down, Im running aground,\n ...",342


In [None]:
import torch
import pandas as pd
import logging
import pickle
import os
from tqdm import tqdm
from transformers import pipeline

# -------------------------------
# 1️⃣ Setup Logging
# -------------------------------
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)

# -------------------------------
# 2️⃣ Load Dataset & Preprocess Lyrics
# -------------------------------
logging.info("🔹 Loading dataset...")
df = pd.read_parquet("/Users/xavierhua/Documents/GitHub/bt4222grp9/phase2_data_cleaning/cleaned dataset/track_cleaned.parquet")
if "lyrics" not in df.columns:
    raise ValueError("❌ The dataset must contain a 'lyrics' column.")

df = df.dropna(subset=["lyrics"])
df = df[~df["lyrics"].isin(["No Lyrics", "none", "None"])].copy()
df = df[:300]

lyrics_list = df["lyrics"].tolist()
track_idx_list = df["track_idx"].tolist()
logging.info(f"✅ Using {len(df)} valid lyrics for Zero-Shot Classification.")

# -------------------------------
# 3️⃣ Load Zero-Shot Classification Model
# -------------------------------
logging.info("🔹 Loading Zero-Shot Classifier...")
device = 0 if torch.cuda.is_available() else -1
model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
classifier = pipeline(
    "zero-shot-classification",
    model=model_name,
    device=device,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

emotions = ["joy", "calm", "sadness", "fear", "energizing", "dreamy"]

# -------------------------------
# 4️⃣ Checkpoint Setup
# -------------------------------
checkpoint_path = "classification_checkpoint.pkl"
start_idx = 0
emotion_scores = {emotion: [] for emotion in emotions}
track_ids = []

if os.path.exists(checkpoint_path):
    logging.info("🔄 Resuming from last checkpoint...")
    with open(checkpoint_path, "rb") as f:
        saved_data = pickle.load(f)
        start_idx = saved_data["last_processed_idx"]
        emotion_scores = saved_data["emotion_scores"]
        track_ids = saved_data["track_ids"]

logging.info(f"🔹 Resuming from index {start_idx} / {len(lyrics_list)}")

# -------------------------------
# 5️⃣ Batched Zero-Shot Classification
# -------------------------------
batch_size_classify = 32  # Increase batch size if possible
logging.info("🔹 Running Zero-Shot Classification...")

batches_processed = 0
for i in tqdm(range(start_idx, len(lyrics_list), batch_size_classify), desc="Classifying Emotions"):
    batch_lyrics = lyrics_list[i:i + batch_size_classify]
    batch_track_ids = track_idx_list[i:i + batch_size_classify]
    
    batch_results = classifier(batch_lyrics, candidate_labels=emotions, multi_label=True)
    if isinstance(batch_results, dict):
        batch_results = [batch_results]

    for idx, result in enumerate(batch_results):
        track_ids.append(batch_track_ids[idx])
        scores_dict = dict(zip(result["labels"], result["scores"]))
        for emotion in emotions:
            emotion_scores[emotion].append(scores_dict.get(emotion, 0.0))
    
    batches_processed += 1
    if batches_processed % 2 == 0:  # Adjust frequency as needed
        with open(checkpoint_path, "wb") as f:
            pickle.dump({
                "last_processed_idx": i + batch_size_classify,
                "emotion_scores": emotion_scores,
                "track_ids": track_ids
            }, f)
        logging.info(f"💾 Checkpoint saved at index {i + batch_size_classify}")

logging.info("✅ Classification complete. Saving results...")

# -------------------------------
# 6️⃣ Save Results
# -------------------------------
min_length = min(len(track_ids), *[len(scores) for scores in emotion_scores.values()])
track_ids = track_ids[:min_length]
emotion_scores = {k: v[:min_length] for k, v in emotion_scores.items()}

output_df = pd.DataFrame({
    "track_idx": track_ids,
    **emotion_scores
})
output_df = output_df.drop_duplicates(subset=["track_idx"], keep="first")
output_df.to_csv("full_dataset_emotion_scores.csv", index=False)

if os.path.exists(checkpoint_path):
    os.remove(checkpoint_path)

logging.info("✅ Full dataset emotion scores saved as 'full_dataset_emotion_scores.csv'.")
print("✅ Emotion classification complete! Results saved.")

2025-03-18 22:21:44,568 - INFO - 🔹 Loading dataset...
2025-03-18 22:21:45,228 - INFO - ✅ Using 300 valid lyrics for Zero-Shot Classification.
2025-03-18 22:21:45,228 - INFO - 🔹 Loading Zero-Shot Classifier...
Device set to use cpu
2025-03-18 22:21:46,626 - INFO - 🔹 Resuming from index 0 / 300
2025-03-18 22:21:46,626 - INFO - 🔹 Running Zero-Shot Classification...
Classifying Emotions:  40%|████      | 4/10 [02:21<03:32, 35.48s/it]2025-03-18 22:24:42,627 - INFO - 💾 Checkpoint saved at index 160
Classifying Emotions:  90%|█████████ | 9/10 [05:16<00:34, 34.69s/it]2025-03-18 22:27:14,299 - INFO - 💾 Checkpoint saved at index 320
Classifying Emotions: 100%|██████████| 10/10 [05:27<00:00, 32.77s/it]
2025-03-18 22:27:14,300 - INFO - ✅ Classification complete. Saving results...
2025-03-18 22:27:14,306 - INFO - ✅ Full dataset emotion scores saved as 'full_dataset_emotion_scores(token).csv'.


✅ Emotion classification complete! Results saved.
