# Whisper Fine-tuning for Lyric Transcription

This notebook fine-tunes Whisper on my song collection to improve lyric transcription accuracy. Demucs first separates the vocals from each track, and the model is then trained on the isolated vocals.

## Data Setup

Place the audio files in `data/songs/` and the matching lyrics in `data/lyrics/`. Each lyrics file must share its song's base filename (e.g., `song1.mp3` → `song1.txt`).

In [None]:
# Install dependencies
!pip install -q transformers datasets accelerate
!pip install -q torch torchaudio
!pip install -q demucs
!pip install -q evaluate jiwer

print("Dependencies installed")

In [None]:
import os
import subprocess
import torch
import torchaudio
import numpy as np
from pathlib import Path
from typing import Dict, List, Any
import re
from tqdm import tqdm

from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from datasets import Dataset, Audio
import evaluate

# Select the compute device: CUDA when a GPU is usable, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using: {device}")

# When running on a GPU, report its name and total memory (in GB).
if device.type == "cuda":
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## Helper Functions

In [None]:
def separate_vocals_demucs(input_song_path: str, output_dir: str = "temp_vocals") -> str:
    """Extract the vocal stem of a song with Demucs and return its path.

    Demucs is run as a subprocess in two-stem (vocals / everything else)
    mode. The vocal track is expected at
    ``<output_dir>/htdemucs/<song name>/vocals.wav``. If that file already
    exists (e.g. from an earlier run of the notebook) it is reused and Demucs
    is not invoked again.

    Args:
        input_song_path: Path to the source audio file.
        output_dir: Directory under which Demucs writes its results; created
            if missing.

    Returns:
        Path to the separated ``vocals.wav`` file.

    Raises:
        subprocess.CalledProcessError: Demucs exited with a non-zero status.
        FileNotFoundError: Demucs succeeded but the vocal track is not at the
            expected location.
    """
    import sys  # local import: only needed to locate the running interpreter

    # Name the model explicitly so the output folder below always matches it.
    model_name = "htdemucs"

    os.makedirs(output_dir, exist_ok=True)

    song_name = os.path.splitext(os.path.basename(input_song_path))[0]
    vocal_path = os.path.join(output_dir, model_name, song_name, "vocals.wav")

    # Separation is slow; skip it when the stem is already on disk.
    if os.path.exists(vocal_path):
        print(f"Using cached vocals: {os.path.basename(input_song_path)}")
        return vocal_path

    print(f"Processing: {os.path.basename(input_song_path)}")

    try:
        # sys.executable guarantees Demucs runs under the same interpreter
        # (and therefore the same installed packages) as this process.
        subprocess.run([
            sys.executable, "-m", "demucs",
            "-n", model_name,
            "--two-stems", "vocals",
            "-o", output_dir,
            input_song_path
        ], capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Demucs error: {e}")
        # The output was captured, so surface stderr or the cause stays hidden.
        if e.stderr:
            print(e.stderr.strip())
        raise

    if not os.path.exists(vocal_path):
        raise FileNotFoundError(f"Vocal track not found at {vocal_path}")
    return vocal_path


def normalize_text(text: str) -> str:
    """Strip bracketed section markers from lyrics and tidy the whitespace.

    Markers such as ``[Intro]``, ``[Verse 1]`` or ``[Chorus]`` are deleted,
    every run of whitespace (newlines included) becomes a single space, and
    leading/trailing whitespace is trimmed.
    """
    # Drop anything enclosed in square brackets on a single line.
    without_markers = re.sub(r'\[.*?\]', '', text)
    # Collapse whitespace runs, including gaps left behind by removed markers.
    single_spaced = re.sub(r'\s+', ' ', without_markers)
    return single_spaced.strip()

## Dataset Creation

In [None]:
# Data paths
DATA_DIR = "data"
SONGS_DIR = os.path.join(DATA_DIR, "songs")
LYRICS_DIR = os.path.join(DATA_DIR, "lyrics")
VOCALS_DIR = "processed_vocals"

os.makedirs(VOCALS_DIR, exist_ok=True)

# Fail fast unless both inputs are real directories. `isdir` is used rather
# than `exists` so that a stray regular file with the same name is rejected
# here instead of failing later when the directory is listed.
if not os.path.isdir(SONGS_DIR):
    raise FileNotFoundError(f"Songs directory not found: {SONGS_DIR}")
if not os.path.isdir(LYRICS_DIR):
    raise FileNotFoundError(f"Lyrics directory not found: {LYRICS_DIR}")

print(f"Songs: {SONGS_DIR}")
print(f"Lyrics: {LYRICS_DIR}")
print(f"Output: {VOCALS_DIR}")

In [None]:
def create_dataset() -> List[Dict[str, str]]:
    """Build training records by pairing each song with its lyrics file.

    Every audio file in ``SONGS_DIR`` with a ``.mp3``, ``.wav`` or ``.m4a``
    extension (matched case-insensitively) is paired with
    ``LYRICS_DIR/<base name>.txt``. Songs with no lyrics file, or whose lyrics
    are empty after normalization, are skipped. For the rest, vocals are
    separated into ``VOCALS_DIR``.

    Returns:
        One dict per usable song with the keys ``audio`` (path to the
        separated vocal track), ``text`` (normalized lyrics) and
        ``song_name`` (file name without extension).
    """
    audio_extensions = (".mp3", ".wav", ".m4a")
    dataset_records = []

    # os.listdir returns entries in arbitrary order; sorting keeps the record
    # order (and therefore any later seeded split) reproducible across runs.
    song_files = sorted(
        f for f in os.listdir(SONGS_DIR) if f.lower().endswith(audio_extensions)
    )
    print(f"Found {len(song_files)} songs")

    for song_file in tqdm(song_files, desc="Processing"):
        try:
            base_name = os.path.splitext(song_file)[0]
            lyrics_file = f"{base_name}.txt"
            lyrics_path = os.path.join(LYRICS_DIR, lyrics_file)

            if not os.path.exists(lyrics_path):
                print(f"Skipping {song_file}: no lyrics file")
                continue

            # Load lyrics
            with open(lyrics_path, 'r', encoding='utf-8') as f:
                lyrics = f.read()

            lyrics = normalize_text(lyrics)

            if not lyrics:
                print(f"Skipping {song_file}: empty lyrics")
                continue

            # Separate vocals
            song_path = os.path.join(SONGS_DIR, song_file)
            vocal_path = separate_vocals_demucs(song_path, VOCALS_DIR)

            # Store path only (HF Audio will load it)
            record = {
                'audio': vocal_path,
                'text': lyrics,
                'song_name': base_name
            }

            dataset_records.append(record)
            print(f"✓ {song_file}")

        except Exception as e:
            # Deliberately best-effort: one unreadable or unprocessable song
            # is reported and skipped instead of aborting the whole run.
            print(f"Error with {song_file}: {e}")
            continue

    print(f"Processed {len(dataset_records)} songs")
    return dataset_records

dataset_records = create_dataset()

In [None]:
# Convert the collected records to a Hugging Face dataset; stop early when
# nothing usable was found.
if not dataset_records:
    raise ValueError("No valid song-lyrics pairs found")

# The audio column is cast with decode=False, i.e. it is not decoded here.
dataset = Dataset.from_list(dataset_records).cast_column(
    "audio", Audio(sampling_rate=16000, decode=False)
)

print(f"Dataset: {len(dataset)} examples")
print(f"Features: {dataset.features}")

# Show the first example as a quick sanity check
if len(dataset) > 0:
    sample = dataset[0]
    sample_text = sample['text']
    print(f"\nSample:")
    print(f"  Song: {sample['song_name']}")
    print(f"  Text: {len(sample_text)} chars")
    print(f"  Preview: {sample_text[:100]}...")

## Load Whisper Model

In [None]:
# Checkpoint to fine-tune: the small model is a good balance of
# performance and resource usage.
MODEL_NAME = "openai/whisper-small"

print(f"Loading {MODEL_NAME}")

# The processor supplies the feature extractor and tokenizer used later on.
processor = WhisperProcessor.from_pretrained(MODEL_NAME)

# Load the pretrained weights and move them to the selected device.
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)

param_count = sum(p.numel() for p in model.parameters())
print(f"Model loaded: {param_count:,} parameters")

In [None]:
def prepare_dataset(batch):
    """Turn one dataset record into model inputs and labels.

    Loads the record's audio file, downmixes it to mono, resamples it to
    16 kHz, computes input features with the processor's feature extractor and
    tokenizes the lyrics (truncated to 448 tokens) as labels.

    Args:
        batch: A single example with ``audio`` (either a path string or a dict
            holding a ``path`` entry) and ``text``.

    Returns:
        The same mapping with ``input_features`` and ``labels`` added.
    """
    target_rate = 16000

    # The audio column may be a plain path or a dict carrying the path.
    audio_entry = batch["audio"]
    audio_path = audio_entry["path"] if isinstance(audio_entry, dict) else audio_entry

    # Load audio using torchaudio
    waveform, sample_rate = torchaudio.load(audio_path)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample to 16kHz if needed
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, target_rate)
        waveform = resampler(waveform)

    # Drop only the channel axis. A bare squeeze() would also collapse a
    # one-sample clip to a 0-d array, which is not a valid 1-D waveform.
    audio_array = waveform.squeeze(0).numpy()

    # NOTE(review): the Whisper feature extractor is understood to pad or
    # truncate audio to a 30-second window by default - confirm. If so, a
    # full-length song keeps only its first 30 s of audio while the labels
    # below cover the whole lyrics; segmenting songs into aligned chunks
    # before this step would be needed for audio and text to match.
    batch["input_features"] = processor.feature_extractor(
        audio_array,
        sampling_rate=target_rate
    ).input_features[0]

    # Tokenize text using processor
    tokenized = processor.tokenizer(
        batch["text"],
        max_length=448,
        truncation=True,
        return_overflowing_tokens=False
    )
    labels = tokenized.input_ids
    # Guard against an empty label sequence so later padding has one token.
    if len(labels) == 0:
        labels = [processor.tokenizer.pad_token_id]
    batch["labels"] = labels

    return batch

print("Preprocessing dataset...")
dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names, desc="Processing")

print("Dataset preprocessed")
print(f"Features: {dataset.features}")

## Train/Val Split

In [None]:
# Split dataset (80/20) with a fixed seed
if len(dataset) <= 1:
    # A single song cannot be split: reuse it for both training and
    # validation (not ideal but works).
    train_dataset = eval_dataset = dataset
    print("Warning: Only one song, using for both train/val")
else:
    splits = dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset, eval_dataset = splits["train"], splits["test"]

print(f"Train: {len(train_dataset)} samples")
print(f"Val: {len(eval_dataset)} samples")

## Data Collator

In [None]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate preprocessed speech examples into one padded training batch.

    Input features and label token ids are padded separately (by the
    processor's feature extractor and tokenizer respectively). Padded label
    positions are set to -100, and a start token shared by every label
    sequence is removed.
    """
    # Object exposing `.feature_extractor` and `.tokenizer`, each with `pad`.
    processor: Any
    # Id of the token that opens every label sequence. When left as None it
    # is looked up from the tokenizer (see `_leading_token_id`).
    decoder_start_token_id: Union[int, None] = None

    def _leading_token_id(self) -> Union[int, None]:
        """Return the id of the start token to strip from the labels."""
        if self.decoder_start_token_id is not None:
            return self.decoder_start_token_id

        tokenizer = self.processor.tokenizer
        # NOTE(review): Whisper labels are understood to begin with
        # <|startoftranscript|>, while the tokenizer's bos token is a
        # different token, so comparing against bos_token_id never matched
        # and the start token ended up duplicated - confirm.
        lookup = getattr(tokenizer, "convert_tokens_to_ids", None)
        start_id = lookup("<|startoftranscript|>") if lookup is not None else None
        if start_id is None or start_id == getattr(tokenizer, "unk_token_id", None):
            # Tokenizer without that token: keep the original BOS-based check.
            return tokenizer.bos_token_id
        return start_id

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` into a batch.

        Args:
            features: Examples carrying ``input_features`` and ``labels``.

        Returns:
            The padded feature batch with a ``labels`` tensor added.
        """
        # Split inputs and labels: they need different padding methods.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Pad input features
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad labels
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding positions with -100
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # If every label sequence already begins with the start token, cut it
        # here; it is expected to be prepended again later.
        start_id = self._leading_token_id()
        if (
            start_id is not None
            and labels.shape[1] > 0
            and (labels[:, 0] == start_id).all().cpu().item()
        ):
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

# Build the collator instance around the loaded processor
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)

print("✅ Data collator initialized")

## Evaluation Metrics

In [None]:
# Track WER during training
metric_wer = evaluate.load("wer")

def compute_metrics(eval_pred):
    """Compute the word error rate (as a percentage) of an evaluation pass.

    Args:
        eval_pred: A pair ``(pred_ids, label_ids)`` of token-id arrays.

    Returns:
        ``{"wer": <percentage>}``.
    """
    pred_ids, label_ids = eval_pred
    pad_id = processor.tokenizer.pad_token_id

    # -100 marks ignored positions and cannot be decoded, so it is swapped for
    # the pad token (dropped again by skip_special_tokens). np.where builds
    # new arrays, leaving the caller's arrays unmodified.
    label_ids = np.where(label_ids == -100, pad_id, label_ids)
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric_wer.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

print("Metrics ready")

## Training Config

In [None]:
# Training settings - reduce batch size if OOM
# NOTE(review): evaluation and checkpointing happen every 100 optimizer
# steps. With a small song collection the whole run may take fewer steps than
# that, in which case no evaluation or checkpoint ever occurs - consider
# epoch-based evaluation/saving for tiny datasets.
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-lyrics-model",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    warmup_steps=50,
    num_train_epochs=10,
    learning_rate=1e-5,
    # Half precision is only enabled when a CUDA device is present, so the
    # CPU fallback selected earlier in the notebook still works.
    fp16=torch.cuda.is_available(),
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    # compute_metrics decodes token ids, so evaluation must generate
    # sequences instead of handing back raw logits.
    predict_with_generate=True,
    push_to_hub=False,
    dataloader_num_workers=0,
    report_to=[]
)

print(f"Epochs: {training_args.num_train_epochs}")
print(f"Batch size: {training_args.per_device_train_batch_size}")
print(f"LR: {training_args.learning_rate}")

## Initialize Trainer

In [None]:
# Wire the model, datasets, collator and metric function into the trainer.
# NOTE(review): newer transformers releases are understood to rename the
# `tokenizer` argument to `processing_class` - confirm against the installed
# version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

print("Trainer ready")

## Train Model

In [None]:
# Launch fine-tuning with the previously built trainer, printing a status
# line before and after the call.
print("Starting training...")
trainer.train()
print("Training complete")

## Save Model

In [None]:
import shutil  # stdlib; builds the archive without needing a `zip` binary

model_save_path = "./whisper-lyrics-final"

# Save the model through the trainer and the processor alongside it, so the
# directory holds everything needed to reload the model.
trainer.save_model(model_save_path)
processor.save_pretrained(model_save_path)

print(f"Model saved to: {model_save_path}")

# Create zip for download. The archive is written to
# ./whisper-lyrics-model.zip and its entries keep the
# `whisper-lyrics-final/` prefix, as `zip -r` would produce.
shutil.make_archive(
    "whisper-lyrics-model",
    "zip",
    root_dir=".",
    base_dir=os.path.normpath(model_save_path),
)
print("Backup created: whisper-lyrics-model.zip")

## Done

In [None]:
# Final recap: dataset size, base checkpoint and output location
summary_lines = (
    "Training complete!",
    f"Dataset: {len(dataset)} songs",
    f"Model: {MODEL_NAME}",
    f"Saved: {model_save_path}",
)
for summary_line in summary_lines:
    print(summary_line)