# Bi-Encoder Architecture for Lyrics-Description Matching
## Improved architecture with separate encoders (Улучшенная архитектура с раздельными энкодерами)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import pickle
import os
from typing import Dict, List, Tuple

In [None]:
# Log in to the Hugging Face Hub from inside the notebook before downloading datasets/models.
import huggingface_hub

huggingface_hub.notebook_login()

## Data Loading and Preprocessing

In [None]:
# Load the 'train' split of both Hub datasets and report their sizes.
interpretation_ds = load_dataset("jamimulgrave/Song-Interpretation-Dataset")['train']
enrich_ds = load_dataset("seungheondoh/enrich-music4all")['train']

for dataset_label, dataset_obj in (("Interpretation", interpretation_ds), ("Enrich", enrich_ds)):
    print(f"{dataset_label} dataset size: {len(dataset_obj)}")

In [None]:
# Build track_id -> metadata lookups from enrich_ds in a single pass over the rows.
# Later rows overwrite earlier ones for a repeated track_id; a missing 'tag_list' becomes [].
pseudo_map, artist_map, tag_map = {}, {}, {}
for enrich_row in enrich_ds:
    track_key = enrich_row['track_id']
    pseudo_map[track_key] = enrich_row['pseudo_caption']
    artist_map[track_key] = enrich_row['artist_name']
    tag_map[track_key] = enrich_row.get('tag_list', [])

print(f"Mappings created: {len(pseudo_map)} tracks")

In [None]:
# Pull the id, comment (used as the description) and lyrics columns out of the interpretation dataset.
music4all_ids, descriptions, lyrics_list = (
    interpretation_ds[column_name] for column_name in ('music4all_id', 'comment', 'lyrics')
)
num_samples = len(music4all_ids)

print(f"Total samples: {num_samples}")

In [None]:
# Contiguous 80/10/10 split in dataset order (no shuffling):
# rows [0, 80%) -> train, [80%, 90%) -> val, [90%, end) -> test.
# NOTE(review): if the same song (music4all_id) occurs in several rows, identical lyrics
# may land in more than one split; consider grouping by music4all_id — TODO confirm.
train_idx, val_idx = (int(fraction * num_samples) for fraction in (0.8, 0.9))

_split_columns = (music4all_ids, descriptions, lyrics_list)
train_ids, train_descs, train_lyrics = (column[:train_idx] for column in _split_columns)
val_ids, val_descs, val_lyrics = (column[train_idx:val_idx] for column in _split_columns)
test_ids, test_descs, test_lyrics = (column[val_idx:] for column in _split_columns)

print(f"Train: {len(train_ids)}, Val: {len(val_ids)}, Test: {len(test_ids)}")

## Bi-Encoder Model Architecture

In [None]:
class BiEncoder(nn.Module):
    """Bi-encoder with two independently weighted encoders for descriptions and lyrics.

    Each side runs its own transformer (both initialised from ``model_name``),
    takes the first-token hidden state, passes it through its own
    Linear-ReLU-Dropout-Linear projection head and L2-normalises the result, so
    the dot product of a description embedding and a lyrics embedding equals
    their cosine similarity.

    Args:
        model_name: Hugging Face checkpoint used to initialise both encoders.
        embedding_dim: hidden size of the encoder output; must equal the
            checkpoint's hidden size, since it is the projection input width.
        projection_dim: dimensionality of the final embedding space.
    """

    def __init__(self, model_name='allenai/longformer-base-4096', embedding_dim=768, projection_dim=512):
        super().__init__()

        # Two separate encoders (no weight sharing): each text type gets its own parameters.
        self.desc_encoder = AutoModel.from_pretrained(model_name)
        self.lyrics_encoder = AutoModel.from_pretrained(model_name)

        # Projection heads for contrastive learning (one per side).
        self.desc_projection = self._build_projection(embedding_dim, projection_dim)
        self.lyrics_projection = self._build_projection(embedding_dim, projection_dim)

    @staticmethod
    def _build_projection(embedding_dim, projection_dim):
        """Return the projection head; layer layout (and state-dict keys) match earlier checkpoints."""
        return nn.Sequential(
            nn.Linear(embedding_dim, projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim),
        )

    @staticmethod
    def _encode(encoder, projection, input_ids, attention_mask):
        """Run ``encoder``, pool the first token, project it and L2-normalise along dim 1."""
        encoder_kwargs = {'input_ids': input_ids, 'attention_mask': attention_mask}
        if getattr(encoder.config, 'model_type', None) == 'longformer':
            # Longformer uses only sliding-window attention unless a global_attention_mask
            # is given, so the first token would only see its local window. Give it global
            # attention so the pooled vector can summarise the whole (up to 4096-token) input.
            global_attention_mask = torch.zeros_like(input_ids)
            global_attention_mask[:, 0] = 1
            encoder_kwargs['global_attention_mask'] = global_attention_mask
        outputs = encoder(**encoder_kwargs)
        # First-token ([CLS] / <s>) embedding as the sequence representation.
        pooled = outputs.last_hidden_state[:, 0, :]
        return F.normalize(projection(pooled), p=2, dim=1)

    def encode_description(self, input_ids, attention_mask):
        """Encode description tokens into L2-normalised embeddings."""
        return self._encode(self.desc_encoder, self.desc_projection, input_ids, attention_mask)

    def encode_lyrics(self, input_ids, attention_mask):
        """Encode lyrics tokens into L2-normalised embeddings."""
        return self._encode(self.lyrics_encoder, self.lyrics_projection, input_ids, attention_mask)

    def forward(self, desc_input_ids, desc_attention_mask, lyrics_input_ids, lyrics_attention_mask):
        """Return ``(desc_emb, lyrics_emb)``, both L2-normalised, one row per batch item."""
        desc_emb = self.encode_description(desc_input_ids, desc_attention_mask)
        lyrics_emb = self.encode_lyrics(lyrics_input_ids, lyrics_attention_mask)
        return desc_emb, lyrics_emb

## Contrastive Loss with In-Batch Negatives

In [None]:
class ContrastiveLoss(nn.Module):
    """Symmetric InfoNCE (NT-Xent) loss using in-batch negatives.

    Row ``i`` of the description batch is the positive for row ``i`` of the
    lyrics batch; every other row in the batch acts as a negative. The loss is
    the mean of the description->lyrics and lyrics->description cross-entropies
    over temperature-scaled dot products.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, desc_embeddings, lyrics_embeddings):
        """
        Args:
            desc_embeddings: [batch_size, embedding_dim]
            lyrics_embeddings: [batch_size, embedding_dim]

        Returns:
            Scalar loss tensor.
        """
        # logits[i, j] = <desc_i, lyrics_j> / temperature, shape [batch_size, batch_size]
        logits = (desc_embeddings @ lyrics_embeddings.T) / self.temperature
        # The matching pair for row i sits in column i.
        targets = torch.arange(desc_embeddings.shape[0], device=desc_embeddings.device)

        desc_to_lyrics = self.criterion(logits, targets)
        lyrics_to_desc = self.criterion(logits.T, targets)
        return (desc_to_lyrics + lyrics_to_desc) / 2

## Dataset for Bi-Encoder

In [None]:
class BiEncoderDataset(Dataset):
    """Map-style dataset yielding independently tokenized (description, lyrics) pairs.

    Each item is a dict with ``desc_input_ids``/``desc_attention_mask`` (padded or
    truncated to ``max_desc_length``) and ``lyrics_input_ids``/``lyrics_attention_mask``
    (padded or truncated to ``max_lyrics_length``). Fixed-length padding lets the default
    DataLoader collate stack items. The leading batch dimension of size 1 is squeezed
    away; this assumes the tokenizer returns shape (1, length) for a single text.

    Args:
        descriptions: description texts; ``descriptions[i]`` pairs with ``lyrics[i]``.
        lyrics: lyrics texts, same length as ``descriptions``.
        tokenizer: Hugging Face-style tokenizer callable returning PyTorch tensors.
        max_desc_length: token budget for descriptions.
        max_lyrics_length: token budget for lyrics.

    Raises:
        ValueError: if ``descriptions`` and ``lyrics`` differ in length.
    """

    def __init__(self, descriptions: List[str], lyrics: List[str], tokenizer,
                 max_desc_length=512, max_lyrics_length=4096):
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if len(descriptions) != len(lyrics):
            raise ValueError(
                "Descriptions and lyrics must have same length "
                f"(got {len(descriptions)} and {len(lyrics)})"
            )
        self.descriptions = descriptions
        self.lyrics = lyrics
        self.tokenizer = tokenizer
        self.max_desc_length = max_desc_length
        self.max_lyrics_length = max_lyrics_length

    def __len__(self):
        return len(self.descriptions)

    def _tokenize(self, text, max_length):
        """Tokenize one text to exactly ``max_length`` tokens; return (input_ids, attention_mask)."""
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            padding='max_length',
            return_tensors='pt'
        )
        return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        desc_ids, desc_mask = self._tokenize(self.descriptions[idx], self.max_desc_length)
        lyrics_ids, lyrics_mask = self._tokenize(self.lyrics[idx], self.max_lyrics_length)
        return {
            'desc_input_ids': desc_ids,
            'desc_attention_mask': desc_mask,
            'lyrics_input_ids': lyrics_ids,
            'lyrics_attention_mask': lyrics_mask
        }

## Initialize Model and Data

In [None]:
# Tokenizer and both encoders come from the same checkpoint; keep the name in one place.
_checkpoint_name = 'allenai/longformer-base-4096'
tokenizer = AutoTokenizer.from_pretrained(_checkpoint_name)
model = BiEncoder(model_name=_checkpoint_name, embedding_dim=768, projection_dim=512)

# Prefer the GPU when one is visible.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

_param_count = sum(param.numel() for param in model.parameters())
print(f"Model loaded on {device}")
print(f"Total parameters: {_param_count/1e6:.2f}M")

In [None]:
# One BiEncoderDataset per split; descriptions capped at 512 tokens, lyrics at 4096.
_length_limits = {'max_desc_length': 512, 'max_lyrics_length': 4096}
train_dataset = BiEncoderDataset(train_descs, train_lyrics, tokenizer, **_length_limits)
val_dataset = BiEncoderDataset(val_descs, val_lyrics, tokenizer, **_length_limits)
test_dataset = BiEncoderDataset(test_descs, test_lyrics, tokenizer, **_length_limits)

for _split_label, _split_ds in (("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)):
    print(f"{_split_label} dataset: {len(_split_ds)}")

In [None]:
# Create dataloaders
batch_size = 8  # Adjust based on GPU memory

# The contrastive loss uses the other pairs in a batch as negatives. A trailing training
# batch holding a single pair has no negatives and a trivially zero loss, so drop the
# incomplete last batch. Keep it only when the whole training set is smaller than one
# batch, otherwise the loader would be empty.
_drop_last_train = len(train_dataset) >= batch_size
# Pinned host memory speeds up host->GPU copies; it only helps when CUDA is present.
_pin_memory = torch.cuda.is_available()

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2,
                          drop_last=_drop_last_train, pin_memory=_pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2,
                        pin_memory=_pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2,
                         pin_memory=_pin_memory)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

## Training Setup

In [None]:
# Training hyperparameters
num_epochs = 10
learning_rate = 2e-5
warmup_steps = 500
temperature = 0.07

# Initialize optimizer and loss
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
criterion = ContrastiveLoss(temperature=temperature)

# Learning rate scheduler with warmup
from torch.optim.lr_scheduler import OneCycleLR

# OneCycleLR expresses warm-up as a fraction of all optimizer steps (pct_start), so convert
# warmup_steps into that fraction instead of ignoring it. Cap at 0.5 so the annealing phase
# still gets at least half of the run when training is short.
total_steps = num_epochs * len(train_loader)
pct_start = min(warmup_steps / max(total_steps, 1), 0.5)

scheduler = OneCycleLR(
    optimizer,
    max_lr=learning_rate,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=pct_start,
    anneal_strategy='cos'
)

print("Training setup complete")

## Training Loop

In [None]:
# Training state
train_losses = []
val_losses = []
best_val_loss = float('inf')
checkpoint_dir = 'persistent_volume'
os.makedirs(checkpoint_dir, exist_ok=True)

# Load checkpoint if exists
checkpoint_path = os.path.join(checkpoint_dir, 'bi_encoder_checkpoint.pth')
start_epoch = 0

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Restore the LR-schedule position too. Otherwise OneCycleLR restarts its warm-up
    # from step 0 while the weights and optimizer resume mid-training.
    # Checkpoints written before this key existed simply lack it.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    else:
        print("Checkpoint has no scheduler state; learning-rate schedule restarts from step 0")
    start_epoch = checkpoint['epoch'] + 1
    train_losses = checkpoint.get('train_losses', [])
    val_losses = checkpoint.get('val_losses', [])
    best_val_loss = checkpoint.get('best_val_loss', float('inf'))
    print(f"Resuming from epoch {start_epoch}")

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, scheduler, device):
    """Run one optimisation pass over ``dataloader``.

    For every batch: forward both encoders, compute the contrastive loss,
    backpropagate, clip the global gradient norm to 1.0, then step the
    optimizer and the scheduler (the scheduler advances once per batch).

    Returns:
        The mean of the per-batch losses for this epoch.
    """
    model.train()
    running_loss = 0

    input_keys = ('desc_input_ids', 'desc_attention_mask', 'lyrics_input_ids', 'lyrics_attention_mask')
    bar = tqdm(dataloader, desc="Training")
    for batch in bar:
        # Move the four token tensors to the target device, in forward()'s argument order.
        model_inputs = [batch[key].to(device) for key in input_keys]
        desc_vecs, lyrics_vecs = model(*model_inputs)

        loss = criterion(desc_vecs, lyrics_vecs)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        step_loss = loss.item()
        running_loss += step_loss
        bar.set_postfix({'loss': step_loss})

    return running_loss / len(dataloader)


def evaluate(model, dataloader, criterion, device):
    """Return the mean per-batch contrastive loss over ``dataloader``.

    The model is switched to eval mode and gradients are disabled, so no
    weights are updated.
    """
    model.eval()
    loss_sum = 0
    input_keys = ('desc_input_ids', 'desc_attention_mask', 'lyrics_input_ids', 'lyrics_attention_mask')

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            model_inputs = [batch[key].to(device) for key in input_keys]
            desc_vecs, lyrics_vecs = model(*model_inputs)
            loss_sum += criterion(desc_vecs, lyrics_vecs).item()

    return loss_sum / len(dataloader)

In [None]:
# Main training loop: train, validate, keep the best weights and write a resumable checkpoint.
best_model_path = os.path.join(checkpoint_dir, 'bi_encoder_best.pth')
for epoch in range(start_epoch, num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("="*50)

    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)
    train_losses.append(train_loss)
    print(f"Train Loss: {train_loss:.4f}")

    # Validate
    val_loss = evaluate(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    print(f"Val Loss: {val_loss:.4f}")

    # Update the best model BEFORE writing the checkpoint so the checkpoint stores the
    # current best_val_loss. With a stale value, a resumed run would compare against an
    # older, higher loss and could overwrite bi_encoder_best.pth with a worse model.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"✓ New best model saved! Val Loss: {val_loss:.4f}")

    # Save checkpoint (scheduler state included so resuming continues the LR schedule)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'best_val_loss': best_val_loss
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")

## Evaluation Metrics

In [None]:
def compute_embeddings(model, dataloader, device):
    """Embed every batch of ``dataloader`` with the model in eval mode.

    Returns:
        ``(desc_embeddings, lyrics_embeddings)``: CPU tensors that concatenate the
        per-batch model outputs along dim 0, in dataloader order.
    """
    model.eval()
    desc_chunks, lyrics_chunks = [], []
    input_keys = ('desc_input_ids', 'desc_attention_mask', 'lyrics_input_ids', 'lyrics_attention_mask')

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Computing embeddings"):
            model_inputs = [batch[key].to(device) for key in input_keys]
            desc_out, lyrics_out = model(*model_inputs)
            # Move results off the device right away so GPU memory does not grow with the dataset.
            desc_chunks.append(desc_out.cpu())
            lyrics_chunks.append(lyrics_out.cpu())

    return torch.cat(desc_chunks, dim=0), torch.cat(lyrics_chunks, dim=0)


def compute_retrieval_metrics(desc_embeddings, lyrics_embeddings, k_values=(1, 5, 10)):
    """Compute description->lyrics retrieval metrics (MRR and Recall@K).

    Query ``i`` (row ``i`` of ``desc_embeddings``) is correct when it retrieves
    row ``i`` of ``lyrics_embeddings``. Similarity is the plain dot product,
    which equals cosine similarity for L2-normalised inputs. The rank of the
    correct item is ``1 + (number of lyrics scored strictly higher)``. Exact
    ties, such as identical lyrics duplicated in the set, therefore do not push
    the correct item down, and the result does not depend on sort order.

    Args:
        desc_embeddings: [num_queries, dim] tensor.
        lyrics_embeddings: [num_queries, dim] tensor, aligned with ``desc_embeddings``.
        k_values: cut-offs for Recall@K.

    Returns:
        ``(mrr, recall_at_k)``: a float and a dict mapping each K to a float.

    Raises:
        ValueError: if the inputs are empty or have different numbers of rows.
    """
    num_queries = desc_embeddings.shape[0]
    if num_queries != lyrics_embeddings.shape[0]:
        raise ValueError(
            f"Expected aligned inputs, got {num_queries} descriptions and "
            f"{lyrics_embeddings.shape[0]} lyrics"
        )
    if num_queries == 0:
        raise ValueError("Cannot compute retrieval metrics on empty inputs")

    # [num_queries, num_queries]; entry (i, j) scores description i against lyrics j.
    similarity_matrix = torch.matmul(desc_embeddings, lyrics_embeddings.T)
    correct_scores = similarity_matrix.diagonal().unsqueeze(1)
    ranks = (similarity_matrix > correct_scores).sum(dim=1) + 1

    mrr = (1.0 / ranks.double()).mean().item()
    recall_at_k = {k: (ranks <= k).double().mean().item() for k in k_values}
    return mrr, recall_at_k

In [None]:
# Load best model for evaluation
best_model_path = os.path.join(checkpoint_dir, 'bi_encoder_best.pth')
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    print("Best model loaded for evaluation")
else:
    # Make the fallback visible: the evaluation below uses whatever weights `model` holds now.
    print(f"No best model found at {best_model_path}; evaluating current model weights")

In [None]:
# Embed the whole test split (descriptions and lyrics) with the current model.
test_desc_emb, test_lyrics_emb = compute_embeddings(model, test_loader, device)

for _emb_label, _emb in (("Description", test_desc_emb), ("Lyrics", test_lyrics_emb)):
    print(f"{_emb_label} embeddings shape: {_emb.shape}")

In [None]:
# Score test-set retrieval and print a small report.
mrr, recall_at_k = compute_retrieval_metrics(test_desc_emb, test_lyrics_emb, k_values=[1, 5, 10, 20])

_rule = "=" * 50
print("\n" + _rule)
print("RETRIEVAL METRICS (Test Set)")
print(_rule)
print(f"MRR: {mrr:.4f}")
for _cutoff, _recall in recall_at_k.items():
    print(f"Recall@{_cutoff}: {_recall:.4f}")

## Visualization

In [None]:
# Plot training curves (left) and test Recall@K (right), then save the figure.
fig, (loss_ax, recall_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Number epochs from 1 to match the "Epoch N/M" training log.
loss_ax.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss', marker='o')
loss_ax.plot(range(1, len(val_losses) + 1), val_losses, label='Val Loss', marker='s')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
loss_ax.grid(True)

# String labels give evenly spaced categorical bars; numeric K (1, 5, 10, 20)
# would place the bars at uneven x positions.
k_labels = [str(k) for k in recall_at_k]
recall_ax.bar(k_labels, list(recall_at_k.values()), color='skyblue')
recall_ax.set_xlabel('K')
recall_ax.set_ylabel('Recall@K')
recall_ax.set_ylim(0, 1)  # recall is a fraction
recall_ax.set_title('Retrieval Performance')
recall_ax.grid(True, axis='y')

fig.tight_layout()
fig.savefig(os.path.join(checkpoint_dir, 'training_results.png'), dpi=300)
plt.show()

## Save Final Model

In [None]:
# Persist the final weights with the loss history, test metrics and key hyperparameters.
final_model_path = os.path.join(checkpoint_dir, 'bi_encoder_final.pth')

_hyperparameters = dict(
    learning_rate=learning_rate,
    batch_size=batch_size,
    temperature=temperature,
    num_epochs=num_epochs,
)
_final_payload = dict(
    model_state_dict=model.state_dict(),
    train_losses=train_losses,
    val_losses=val_losses,
    test_mrr=mrr,
    test_recall=recall_at_k,
    hyperparameters=_hyperparameters,
)
torch.save(_final_payload, final_model_path)

print(f"Final model saved to {final_model_path}")