In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import LEDTokenizer, LEDForSequenceClassification
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import pickle
import os

In [2]:
# Authenticate against the Hugging Face Hub; in a notebook this renders an interactive login widget.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Step 1: Load and Preprocess Datasets
# Fetch both Hub datasets, then keep only the 'train' split of each.
interpretation_splits = load_dataset("jamimulgrave/Song-Interpretation-Dataset")
enrich_splits = load_dataset("seungheondoh/enrich-music4all")

interpretation_ds = interpretation_splits['train']
enrich_ds = enrich_splits['train']

In [4]:
# Create mappings
# All three lookups are keyed by track_id. They are filled in a single pass
# over enrich_ds instead of three separate full iterations; the resulting
# dicts are the same (for a repeated track_id the last row still wins).
pseudo_map = {}
artist_map = {}
tag_map = {}
for row in enrich_ds:
    track_id = row['track_id']
    pseudo_map[track_id] = row['pseudo_caption']
    artist_map[track_id] = row['artist_name']
    # Rows without a 'tag_list' entry fall back to an empty list.
    tag_map[track_id] = row.get('tag_list', [])

In [5]:
# Extract from interpretation_ds
# Pull the three columns used downstream (song id, free-text comment, lyrics).
music4all_ids, descriptions, lyrics_list = (
    interpretation_ds[column_name]
    for column_name in ('music4all_id', 'comment', 'lyrics')
)
num_samples = len(music4all_ids)

In [6]:
# Display the number of rows extracted from interpretation_ds.
num_samples

310315

In [7]:
# Sequential 80% / 10% / 10% split by row position (no shuffling).
# NOTE(review): if several rows share the same song, a positional split may
# place the same lyrics in more than one partition — confirm against the dataset.
train_idx = int(0.8 * num_samples)
val_idx = int(0.9 * num_samples)

# Slice each column into its (train, val, test) parts.
train_ids, val_ids, test_ids = (
    music4all_ids[:train_idx], music4all_ids[train_idx:val_idx], music4all_ids[val_idx:]
)
train_descs, val_descs, test_descs = (
    descriptions[:train_idx], descriptions[train_idx:val_idx], descriptions[val_idx:]
)
train_lyrics, val_lyrics, test_lyrics = (
    lyrics_list[:train_idx], lyrics_list[train_idx:val_idx], lyrics_list[val_idx:]
)

In [8]:
def generate_pairs(ids, real_descs, lyrics, all_lyrics, artist_map, tag_map, pseudo_map, num_neg=4, max_neg_pool=500, max_attempts=10):
    """Build shuffled (description, lyric, label) training pairs for one split.

    For every sample a positive pair (label 1) is emitted for the real
    description and, when ``pseudo_map`` has an entry for the id, a second
    positive for the pseudo caption. Up to ``num_neg`` negatives (label 0) are
    then added per sample, tried in this order: one lyric by the same artist,
    a lyric of any tagged track, and finally a random lyric.

    Args:
        ids: Sequence of string ids, aligned with ``real_descs`` and ``lyrics``.
        real_descs: Descriptions, one per id.
        lyrics: Lyrics, one per id (the positive lyric of each sample).
        all_lyrics: Lyrics used for the random fallback; only the first
            ``max_neg_pool`` entries are sampled.
        artist_map: Mapping id -> artist name.
        tag_map: Mapping id -> tag list (truthy when the track has tags).
        pseudo_map: Mapping id -> pseudo caption.
        num_neg: Negatives requested per sample.
        max_neg_pool: Size of the id/lyric pool that negatives are drawn from.
        max_attempts: Consecutive failed structured attempts tolerated.

    Returns:
        A shuffled list of ``(description, lyric, label)`` tuples; an empty
        list when ``ids`` is empty.

    Side effects:
        Reads/writes a pickle cache under ``persistent_volume/``.
    """
    import hashlib  # local import: only needed to build the cache key

    if len(ids) == 0:
        return []

    # Read and write the cache at the SAME location. The key covers both
    # boundary ids, the split length and the sampling parameters, so two
    # splits whose ids merely share a 4-character prefix cannot collide.
    cache_dir = "persistent_volume"
    cache_key = f"{ids[0]}|{ids[-1]}|{len(ids)}|{num_neg}|{max_neg_pool}"
    digest = hashlib.sha256(cache_key.encode("utf-8")).hexdigest()[:12]
    cache_file = os.path.join(cache_dir, f"pairs_{ids[0][:4]}_to_{ids[-1][:4]}_{digest}.pkl")
    if os.path.exists(cache_file):
        # NOTE: unpickling runs arbitrary code; only load caches this notebook wrote.
        with open(cache_file, 'rb') as f:
            return pickle.load(f)

    positives = []
    negatives = []

    # Structured negatives come from the head of THIS split, so every pool id
    # is paired with its own lyric (ids and lyrics are aligned).
    pool_ids = list(ids[:max_neg_pool])
    id_lyric_map = dict(zip(pool_ids, lyrics[:max_neg_pool]))

    # Random fallback negatives come from a bounded slice of all_lyrics.
    all_lyrics_subset = all_lyrics[:min(max_neg_pool, len(all_lyrics))]

    # Loop-invariant candidate indexes, built once instead of once per sample.
    ids_by_artist = {}
    for id_ in pool_ids:
        pool_artist = artist_map.get(id_)
        if pool_artist:
            ids_by_artist.setdefault(pool_artist, []).append(id_)
    # NOTE(review): as before, any tagged track qualifies here; tag overlap
    # with the positive sample is not checked — confirm that this is intended.
    tagged_ids = [id_ for id_ in pool_ids if tag_map.get(id_, [])]
    tagged_id_set = set(tagged_ids)

    def add_structured_negative(desc, neg_lyric, pos_lyric, used):
        """Append a negative unless it equals the positive or repeats for this sample."""
        if neg_lyric == pos_lyric or neg_lyric in used:
            return False
        negatives.append((desc, neg_lyric, 0))
        used.add(neg_lyric)
        return True

    for sid, real_desc, pos_lyric in tqdm(zip(ids, real_descs, lyrics), total=len(ids), desc="Generating pairs"):
        # Positive pairs: Real description + pseudo if available
        positives.append((real_desc, pos_lyric, 1))
        if sid in pseudo_map:
            positives.append((pseudo_map[sid], pos_lyric, 1))

        neg_count = 0
        attempts = 0
        artist = artist_map.get(sid, '')
        pos_tags = tag_map.get(sid, [])
        # Duplicates are tracked per sample; a scan over every negative ever
        # emitted is quadratic and exhausts the small pool after a few samples.
        used = set()
        # Without this guard the rejection loop below could spin forever.
        can_fall_back = any(candidate != pos_lyric for candidate in all_lyrics_subset)

        while neg_count < num_neg and attempts < max_attempts:
            # Same artist negative (at most one per sample)
            if artist and neg_count < 1:
                same_artist_ids = [id_ for id_ in ids_by_artist.get(artist, ()) if id_ != sid]
                if same_artist_ids and add_structured_negative(
                        real_desc, id_lyric_map[random.choice(same_artist_ids)], pos_lyric, used):
                    neg_count += 1
                    attempts = 0  # Reset attempts on success
                else:
                    attempts += 1

            # Tag-based negative
            if pos_tags and neg_count < num_neg:
                if sid in tagged_id_set:
                    candidate_ids = [id_ for id_ in tagged_ids if id_ != sid]
                else:
                    candidate_ids = tagged_ids
                if candidate_ids and add_structured_negative(
                        real_desc, id_lyric_map[random.choice(candidate_ids)], pos_lyric, used):
                    neg_count += 1
                    attempts = 0  # Reset attempts on success
                else:
                    attempts += 1

            # Fall back to random negative (may repeat a lyric, as before)
            if neg_count < num_neg:
                if not can_fall_back:
                    break
                neg_lyric = random.choice(all_lyrics_subset)
                while neg_lyric == pos_lyric:
                    neg_lyric = random.choice(all_lyrics_subset)
                negatives.append((real_desc, neg_lyric, 0))
                used.add(neg_lyric)
                neg_count += 1
                attempts = 0  # Reset attempts on success

        if neg_count < num_neg:
            print(f"Warning: Only {neg_count} negatives found for sid {sid}, expected {num_neg}")

    pairs = positives + negatives
    random.shuffle(pairs)

    # Save to cache (same path that is checked at the top of the function)
    os.makedirs(cache_dir, exist_ok=True)
    with open(cache_file, 'wb') as f:
        pickle.dump(pairs, f)

    return pairs

In [9]:
# Every split samples its random negatives from the same lyric list and
# shares the same artist / tag / pseudo-caption lookups.
all_lyrics = lyrics_list
lookup_maps = (artist_map, tag_map, pseudo_map)

train_pairs = generate_pairs(train_ids, train_descs, train_lyrics, all_lyrics, *lookup_maps)
val_pairs = generate_pairs(val_ids, val_descs, val_lyrics, all_lyrics, *lookup_maps)
test_pairs = generate_pairs(test_ids, test_descs, test_lyrics, all_lyrics, *lookup_maps)

In [10]:
# Custom Dataset Class
class LyricsMatcherDataset(Dataset):
    """Map-style dataset that tokenizes (description, lyric, label) pairs on access.

    Args:
        pairs: Sequence of ``(description, lyric, label)`` tuples.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=...,
            max_length=..., padding=..., return_tensors='pt')`` that returns a
            mapping with ``'input_ids'`` and ``'attention_mask'`` tensors
            carrying a leading batch dimension of 1.
        max_length: Length every example is truncated / padded to.
    """

    def __init__(self, pairs, tokenizer, max_length=1024):
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Tokenize pair ``idx`` and return its input ids, attention mask and float label."""
        desc, lyric, label = self.pairs[idx]
        # NOTE(review): "[CLS]" / "[SEP]" are written as plain text here; whether
        # the LED tokenizer treats them as special tokens is not visible in this
        # file — confirm. The format is kept because the ranking evaluation
        # builds its inputs the same way.
        input_text = f"[CLS] {desc} [SEP] {lyric}"
        encoding = self.tokenizer(input_text, truncation=True, max_length=self.max_length, padding='max_length', return_tensors='pt')
        return {
            # Drop only the leading batch dimension; a bare squeeze() would
            # also collapse the sequence dimension when max_length == 1.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.float)
        }

In [11]:
# Step 3: Setup Model and Tokenizer
# Tokenizer and classifier are loaded from the same pretrained checkpoint;
# num_labels=1 gives a single-logit head.
led_checkpoint_name = 'allenai/led-base-16384'
tokenizer = LEDTokenizer.from_pretrained(led_checkpoint_name)
model = LEDForSequenceClassification.from_pretrained(led_checkpoint_name, num_labels=1)

# Use the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)

  return self.fget.__get__(instance, owner)()
Some weights of LEDForSequenceClassification were not initialized from the model checkpoint at allenai/led-base-16384 and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


LEDForSequenceClassification(
  (led): LEDModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): LEDEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): LEDLearnedPositionalEmbedding(16384, 768)
      (layers): ModuleList(
        (0-5): 6 x LEDEncoderLayer(
          (self_attn): LEDEncoderAttention(
            (longformer_self_attn): LEDEncoderSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (query_global): Linear(in_features=768, out_features=768, bias=True)
              (key_global): Linear(in_features=768, out_features=768, bias=True)
              (value_global): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): Linear(in_features=768, out_features=768, bias=True)
          )
    

In [12]:
# DataLoaders
# Wrap each split's pairs in a tokenizing dataset.
train_dataset, val_dataset, test_dataset = (
    LyricsMatcherDataset(split_pairs, tokenizer)
    for split_pairs in (train_pairs, val_pairs, test_pairs)
)

loader_batch_size = 4  # Reduced batch size for memory
# Only the training loader shuffles; validation and test keep their order.
train_loader = DataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=loader_batch_size)
test_loader = DataLoader(test_dataset, batch_size=loader_batch_size)

In [13]:
# Step 4: Training
# Binary cross-entropy on the raw logit, optimized with AdamW.
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=2e-5)

# Epoch budget and per-epoch loss history.
epochs, start_epoch = 5, 0
train_losses, val_losses = [], []

In [14]:
# Load last checkpoint if exists
checkpoint_path = "persistent_volume/last_checkpoint.pth"
start_epoch = 0
if os.path.exists(checkpoint_path):
    # map_location remaps the saved tensors onto the current device, so a
    # checkpoint written on a GPU can still be resumed on a CPU-only kernel.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The stored epoch is the last one completed, so resume with the next.
    start_epoch = checkpoint['epoch'] + 1
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']
    print(f"Resuming training from epoch {start_epoch}")

In [None]:
# Create the checkpoint directory up front so that torch.save cannot fail on
# a missing folder after an entire epoch of training.
checkpoint_dir = os.path.dirname(checkpoint_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(start_epoch, epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    completed_batches = 0  # batches whose optimizer step actually succeeded
    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Labels come out of the loader as (batch,); the single-logit head
        # yields (batch, 1), so add a trailing dimension to match.
        labels = batch['labels'].to(device).unsqueeze(1)

        try:
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask)
            loss = loss_fn(outputs.logits, labels)
            loss.backward()
            optimizer.step()
        except RuntimeError as e:
            # Best effort: report the failing batch and keep training.
            print(f"Error in batch: {e}")
            continue

        # Only batches that finished the full step contribute to the average.
        train_loss += loss.item()
        completed_batches += 1

    # Average over completed batches so skipped ones do not deflate the mean.
    train_losses.append(train_loss / max(completed_batches, 1))

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device).unsqueeze(1)

            outputs = model(input_ids, attention_mask=attention_mask)
            loss = loss_fn(outputs.logits, labels)
            val_loss += loss.item()

    val_losses.append(val_loss / max(len(val_loader), 1))
    print(f"Epoch {epoch+1}: Train Loss {train_losses[-1]:.4f}, Val Loss {val_losses[-1]:.4f}")

    # Save checkpoint (overwritten every epoch; 'epoch' is the one just finished)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses
    }, checkpoint_path)

  0%|          | 41/357183 [00:59<143:36:56,  1.45s/it]

In [None]:
# Step 5: Testing and Metrics
def compute_ranking_metrics(model, descs, lyrics, all_lyrics, max_length=1024):
    """Rank candidate lyrics for each description and report MRR / Recall@k.

    Every description is scored against every candidate lyric with ``model``;
    the rank of the description's own lyric feeds the metrics. Uses the
    notebook-level ``tokenizer`` and ``device``.

    Args:
        model: Classifier returning ``.logits`` with a single value per input.
        descs: Query descriptions.
        lyrics: Ground-truth lyric for each description (aligned with ``descs``).
        all_lyrics: Candidate lyrics to rank. When a query's true lyric is not
            among them, it is appended for that query so a rank always exists.
        max_length: Token length inputs are truncated / padded to. Defaults to
            1024, the default length of ``LyricsMatcherDataset``.

    Returns:
        Dict with keys ``'MRR'``, ``'Recall@5'`` and ``'Recall@10'`` averaged
        over the queries (all 0.0 when there are no queries).
    """
    model.eval()
    num_queries = len(descs)
    if num_queries == 0:
        return {'MRR': 0.0, 'Recall@5': 0.0, 'Recall@10': 0.0}

    candidates = list(all_lyrics)
    # First-occurrence index of each candidate (same answer as list.index,
    # without a linear scan per query).
    first_index = {}
    for position, candidate in enumerate(candidates):
        first_index.setdefault(candidate, position)

    mrr = 0
    recall_at_5 = 0
    recall_at_10 = 0

    with torch.no_grad():
        for i, desc in enumerate(tqdm(descs)):
            target = lyrics[i]
            if target in first_index:
                pool = candidates
                target_idx = first_index[target]
            else:
                # The positive is missing from the candidates: rank it
                # alongside them instead of raising ValueError.
                pool = candidates + [target]
                target_idx = len(candidates)

            scores = []
            for lyric in pool:
                input_text = f"[CLS] {desc} [SEP] {lyric}"
                encoding = tokenizer(input_text, truncation=True, max_length=max_length, padding='max_length', return_tensors='pt').to(device)
                output = model(**encoding)
                score = torch.sigmoid(output.logits).item()
                scores.append(score)

            # Highest score first; the rank is 1-based.
            ranked_indices = np.argsort(scores)[::-1]
            pos_rank = int(np.where(ranked_indices == target_idx)[0][0]) + 1

            mrr += 1 / pos_rank
            recall_at_5 += 1 if pos_rank <= 5 else 0
            recall_at_10 += 1 if pos_rank <= 10 else 0

    return {
        'MRR': mrr / num_queries,
        'Recall@5': recall_at_5 / num_queries,
        'Recall@10': recall_at_10 / num_queries
    }

# Rank a small demo subset of the test split. Queries and candidates are cut
# from the same test rows, so every query's true lyric is among the
# candidates (the head of all_lyrics holds training lyrics).
ranking_subset = 100  # Subset for demo
test_metrics = compute_ranking_metrics(
    model,
    test_descs[:ranking_subset],
    test_lyrics[:ranking_subset],
    list(test_lyrics[:ranking_subset]),
)
print("Test Metrics:", test_metrics)

# Pairwise accuracy over the test pairs at a 0.5 probability threshold.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask)
        # Drop only the trailing logit dimension so a final batch of size 1
        # still lines up with labels of shape (batch,).
        preds = torch.sigmoid(outputs.logits).squeeze(-1) > 0.5
        correct += (preds == labels.bool()).sum().item()
        total += labels.size(0)
print(f"Test Accuracy: {correct / total if total else float('nan'):.4f}")

# Step 6: Plots
# The x-axis follows the number of losses actually recorded, so a run that
# was interrupted or resumed part-way still plots without a length mismatch.
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.savefig('loss_curve.png')
plt.show()