In [None]:
import torch
import pandas as pd
import logging
import pickle
import os
from tqdm import tqdm
from transformers import pipeline

# -------------------------------
# 1️⃣ Setup Logging
# -------------------------------
# Root logger at INFO with timestamps, so progress of a long run is traceable.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# -------------------------------
# 2️⃣ Load Dataset & Preprocess Lyrics
# -------------------------------
# Cleaned track table. The default is an absolute path on one developer's
# machine; set TRACK_PARQUET_PATH to use a copy stored elsewhere.
dataset_path = os.environ.get(
    "TRACK_PARQUET_PATH",
    "/Users/xavierhua/Documents/GitHub/bt4222grp9/phase2_data_cleaning/"
    "cleaned dataset/track_cleaned.parquet",
)

logging.info("🔹 Loading dataset...")
df = pd.read_parquet(dataset_path)
if "lyrics" not in df.columns:
    raise ValueError("❌ The dataset must contain a 'lyrics' column.")
# track_idx identifies each row in the output; fail early with a clear message
# instead of a bare KeyError after the filtering work below.
if "track_idx" not in df.columns:
    raise ValueError("❌ The dataset must contain a 'track_idx' column.")

# Keep only rows with usable lyrics. Drop nulls, known placeholder strings, and
# empty / whitespace-only text. An empty string must be removed because the
# zero-shot pipeline raises ValueError on an empty sequence, which would abort
# the run mid-way. Whitespace-only text carries no signal to classify.
# NOTE: changing this filter shifts row positions, so a classification
# checkpoint written by a run with a different filter no longer lines up.
df = df.dropna(subset=["lyrics"])
is_placeholder = df["lyrics"].isin(["No Lyrics", "none", "None"])
is_blank = df["lyrics"].astype(str).str.strip() == ""
df = df[~(is_placeholder | is_blank)].copy()

# Parallel lists: lyrics_list[k] is the text for track_idx_list[k].
lyrics_list = df["lyrics"].tolist()
track_idx_list = df["track_idx"].tolist()
logging.info(f"✅ Using {len(df)} valid lyrics for Zero-Shot Classification.")

# -------------------------------
# 3️⃣ Load Zero-Shot Classification Model
# -------------------------------
# Emotion labels passed to the classifier as candidate_labels.
emotions = ["joy", "calm", "sadness", "fear", "energizing", "dreamy"]

logging.info("🔹 Loading Zero-Shot Classifier...")
model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"

if torch.cuda.is_available():
    # First CUDA device, loaded in half precision.
    device, model_dtype = 0, torch.float16
else:
    # -1 selects the CPU in the pipeline API; keep full precision there.
    device, model_dtype = -1, torch.float32

classifier = pipeline(
    "zero-shot-classification",
    model=model_name,
    device=device,
    torch_dtype=model_dtype,
)

# -------------------------------
# 4️⃣ Checkpoint Setup
# -------------------------------
# Progress is persisted in checkpoint_path so an interrupted run can resume.
# Resume state:
#   - start_idx: position in lyrics_list to continue from.
#   - emotion_scores: one score list per label.
#   - track_ids: the matching track_idx values.
# The score lists and track_ids are always appended in lockstep.
checkpoint_path = "classification_checkpoint.pkl"
start_idx = 0
emotion_scores = {emotion: [] for emotion in emotions}
track_ids = []

if os.path.exists(checkpoint_path):
    logging.info("🔄 Resuming from last checkpoint...")
    # pickle is acceptable only because this file is produced by this script;
    # never load a checkpoint from an untrusted source.
    with open(checkpoint_path, "rb") as f:
        saved_data = pickle.load(f)

    saved_ids = saved_data["track_ids"]
    saved_scores = saved_data["emotion_scores"]
    n_done = len(saved_ids)

    # The checkpoint is tied to positions in lyrics_list. It is only valid if
    # the current (filtered) dataset still starts with exactly the tracks that
    # were already scored, and the label set is unchanged. Otherwise resuming
    # would silently attach scores to the wrong tracks.
    checkpoint_matches = (
        saved_ids == track_idx_list[:n_done]
        and set(saved_scores) == set(emotions)
        and all(len(scores) == n_done for scores in saved_scores.values())
    )
    if not checkpoint_matches:
        raise RuntimeError(
            f"Checkpoint '{checkpoint_path}' does not match the current dataset "
            "or emotion labels; delete it to restart classification from scratch."
        )

    # Resume right after the last scored row. This equals the saved
    # "last_processed_idx" except when the last saved batch was the final,
    # partial one. In that case the index overshoots the rows actually scored,
    # and using it would skip rows if the dataset has since grown.
    start_idx = n_done
    emotion_scores = saved_scores
    track_ids = saved_ids

logging.info(f"🔹 Resuming from index {start_idx} / {len(lyrics_list)}")

# -------------------------------
# 5️⃣ Batched Zero-Shot Classification
# -------------------------------
# batch_size_classify sets how many lyrics go into each classifier call, which
# is also the checkpoint granularity. It is forwarded to the pipeline as
# batch_size. Pipelines do not batch by default, so without it every input
# would run through the model one at a time.
batch_size_classify = 32  # Increase batch size if possible
checkpoint_every = 2  # write a checkpoint after this many batches
logging.info("🔹 Running Zero-Shot Classification...")

batches_processed = 0
for i in tqdm(range(start_idx, len(lyrics_list), batch_size_classify), desc="Classifying Emotions"):
    batch_lyrics = lyrics_list[i:i + batch_size_classify]
    batch_track_ids = track_idx_list[i:i + batch_size_classify]

    # multi_label=True scores every emotion independently of the others.
    batch_results = classifier(
        batch_lyrics,
        candidate_labels=emotions,
        multi_label=True,
        batch_size=batch_size_classify,
    )
    # Defensive: normalize a bare dict (single-result case) into a list.
    if isinstance(batch_results, dict):
        batch_results = [batch_results]

    for track_id, result in zip(batch_track_ids, batch_results):
        track_ids.append(track_id)
        # Result labels come back ordered by score, not in `emotions` order.
        # Re-key them so each per-emotion list stays aligned with track_ids.
        scores_dict = dict(zip(result["labels"], result["scores"]))
        for emotion in emotions:
            emotion_scores[emotion].append(scores_dict.get(emotion, 0.0))

    batches_processed += 1
    if batches_processed % checkpoint_every == 0:
        # Write to a temp file, then atomically swap it into place. An
        # interruption mid-write then leaves the previous checkpoint intact
        # instead of a truncated, unloadable pickle.
        tmp_checkpoint_path = checkpoint_path + ".tmp"
        with open(tmp_checkpoint_path, "wb") as f:
            pickle.dump({
                "last_processed_idx": i + batch_size_classify,
                "emotion_scores": emotion_scores,
                "track_ids": track_ids
            }, f)
        os.replace(tmp_checkpoint_path, checkpoint_path)
        logging.info(f"💾 Checkpoint saved at index {i + batch_size_classify}")

logging.info("✅ Classification complete. Saving results...")

# -------------------------------
# 6️⃣ Save Results
# -------------------------------
# Trim track_ids and every score list to the shortest length so all
# DataFrame columns have equal length.
min_length = min(len(track_ids), *map(len, emotion_scores.values()))
track_ids = track_ids[:min_length]
emotion_scores = {label: values[:min_length] for label, values in emotion_scores.items()}

# Columns: track_idx, then one score column per emotion.
# Repeated track_idx values keep only their first row.
output_df = pd.DataFrame({"track_idx": track_ids, **emotion_scores}).drop_duplicates(
    subset=["track_idx"], keep="first"
)
output_df.to_csv("full_dataset_emotion_scores.csv", index=False)

# Results are on disk, so the resume checkpoint is no longer needed.
checkpoint_exists = os.path.exists(checkpoint_path)
if checkpoint_exists:
    os.remove(checkpoint_path)

logging.info("✅ Full dataset emotion scores saved as 'full_dataset_emotion_scores.csv'.")
print("✅ Emotion classification complete! Results saved.")