# Train Whisper on English Songs

## 1. Setup
Required folders (relative to the notebook):
- `data/songs/` : audio files (`.mp3`, `.wav`, or `.m4a`)
- `data/lyrics/` : one UTF-8 lyrics text file per song, with the same base name (`<basename>.txt`)

In [None]:
# Install required packages
!pip install -q transformers datasets accelerate
!pip install -q torch torchaudio
!pip install -q demucs
!pip install -q librosa soundfile
!pip install -q evaluate jiwer

print("✅ All packages installed successfully!")

In [None]:
# Import necessary libraries
import os
import subprocess
import torch
import librosa
import soundfile as sf
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Any
import json
import re
from tqdm import tqdm

# Transformers and datasets
from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperFeatureExtractor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    TrainerCallback
)
from datasets import Dataset, Audio
import evaluate

# Pick the first CUDA device when PyTorch can see one, otherwise run on CPU,
# and report what was chosen (plus GPU name and memory when on CUDA).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {total_gb:.1f} GB")

## 2. Data Preparation Functions

In [None]:
def separate_vocals_demucs(
    input_song_path: str,
    output_dir: str = "temp_vocals",
    model_name: str = "htdemucs",
) -> str:
    """
    Separate the vocal stem from a song using the Demucs CLI.

    Args:
        input_song_path: Path to the song to separate.
        output_dir: Root folder Demucs writes into (created if missing).
        model_name: Demucs model passed via ``-n``. Demucs writes its stems to
            ``<output_dir>/<model_name>/<song stem>/``, so the same name is used
            to locate the result.

    Returns:
        Path to the separated ``vocals.wav``.

    Raises:
        subprocess.CalledProcessError: Demucs exited with a non-zero status
            (its stderr is printed before re-raising).
        FileNotFoundError: Demucs finished but no vocals file was found.
    """
    import sys  # only needed to locate the interpreter running this notebook

    os.makedirs(output_dir, exist_ok=True)

    print(f"🎵 Separating vocals from: {os.path.basename(input_song_path)}")

    # sys.executable (not a bare "python" from PATH) runs Demucs with the same
    # interpreter/environment this code runs in, i.e. where demucs was installed.
    command = [
        sys.executable, "-m", "demucs",
        "-n", model_name,
        "--two-stems", "vocals",
        "-o", output_dir,
        input_song_path,
    ]
    try:
        subprocess.run(command, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Error in Demucs separation: {e}")
        print(f"Stderr: {e.stderr}")
        raise

    song_name = os.path.splitext(os.path.basename(input_song_path))[0]
    vocal_path = os.path.join(output_dir, model_name, song_name, "vocals.wav")
    if not os.path.exists(vocal_path):
        raise FileNotFoundError(f"Vocal track not found at {vocal_path}")
    return vocal_path


def load_and_preprocess_audio(audio_path: str, target_sample_rate: int = 16000) -> np.ndarray:
    """
    Load an audio file as a mono waveform for Whisper.

    Args:
        audio_path: Path to an audio file readable by librosa.
        target_sample_rate: Rate the audio is resampled to; the rest of this
            notebook feeds Whisper's feature extractor at 16 kHz.

    Returns:
        1-D waveform resampled to ``target_sample_rate``.
    """
    # mono=True makes librosa down-mix multi-channel audio while loading, so the
    # result is always 1-D and no separate to_mono step is needed.
    audio, _ = librosa.load(audio_path, sr=target_sample_rate, mono=True)
    return audio


def normalize_text(text: str) -> str:
    """
    Normalize lyrics text for training.

    Every run of whitespace (spaces, tabs, newlines, ...) becomes a single
    space and leading/trailing whitespace is dropped, so multi-line lyrics
    turn into one clean line.
    """
    # str.split() with no separator splits on any whitespace run and ignores
    # leading/trailing whitespace; joining with " " collapses the runs.
    return " ".join(text.split())

## 3. Dataset Creation

In [None]:
# Set your data paths
DATA_DIR = "data"  # Update this if your data folder is in a different location
SONGS_DIR = os.path.join(DATA_DIR, "songs")
LYRICS_DIR = os.path.join(DATA_DIR, "lyrics")
VOCALS_DIR = "processed_vocals"  # Where we'll store separated vocals

# Verify the input directories before creating anything, so a wrong DATA_DIR
# fails fast without leaving an empty output folder behind. isdir (rather than
# exists) also rejects a regular file that happens to have the expected name,
# which the later directory listing could not read.
if not os.path.isdir(SONGS_DIR):
    raise FileNotFoundError(f"Songs directory not found: {SONGS_DIR}")
if not os.path.isdir(LYRICS_DIR):
    raise FileNotFoundError(f"Lyrics directory not found: {LYRICS_DIR}")

# Create output directory
os.makedirs(VOCALS_DIR, exist_ok=True)

print("✅ Data directories found:")
print(f"  Songs: {SONGS_DIR}")
print(f"  Lyrics: {LYRICS_DIR}")
print(f"  Output vocals: {VOCALS_DIR}")

In [None]:
def create_dataset(audio_extensions=(".mp3", ".wav", ".m4a")):
    """
    Build training records by pairing each song with its lyrics file.

    For every audio file in SONGS_DIR that has a matching ``<basename>.txt``
    in LYRICS_DIR, the lyrics are normalized, the vocals are separated with
    Demucs into VOCALS_DIR, and one record is emitted. Songs with no lyrics
    file, with empty lyrics, or whose processing raises are skipped with a
    printed message.

    Args:
        audio_extensions: File suffixes treated as songs, matched
            case-insensitively (so ``Song.MP3`` is picked up as well).

    Returns:
        List of dicts ``{'audio': <vocals path>, 'text': <lyrics>,
        'song_name': <basename>}``. 'audio' is only a path so the HF Audio
        feature (or a later manual loader) can read the file.
    """
    dataset_records = []
    suffixes = tuple(ext.lower() for ext in audio_extensions)

    # os.listdir returns entries in arbitrary order; the seeded train/validation
    # split later depends on record order, so sort for reproducible splits.
    song_files = sorted(f for f in os.listdir(SONGS_DIR) if f.lower().endswith(suffixes))
    print(f"Found {len(song_files)} song files")

    for song_file in tqdm(song_files, desc="Processing songs"):
        try:
            base_name = os.path.splitext(song_file)[0]
            lyrics_path = os.path.join(LYRICS_DIR, f"{base_name}.txt")

            if not os.path.exists(lyrics_path):
                print(f"⚠️  Skipping {song_file}: No matching lyrics file found")
                continue

            # utf-8-sig reads plain UTF-8 unchanged but also drops a leading
            # byte-order mark, which whitespace normalization would otherwise keep.
            with open(lyrics_path, 'r', encoding='utf-8-sig') as f:
                lyrics = normalize_text(f.read())

            if not lyrics:
                print(f"⚠️  Skipping {song_file}: Empty lyrics file")
                continue

            song_path = os.path.join(SONGS_DIR, song_file)
            vocal_path = separate_vocals_demucs(song_path, VOCALS_DIR)

            # Store only the file path (not raw samples, which serialize as huge
            # Python lists); the audio is loaded later during preprocessing.
            dataset_records.append({
                'audio': vocal_path,
                'text': lyrics,
                'song_name': base_name,
            })
            print(f"✅ Processed: {song_file}")

        except Exception as e:
            # Best-effort per song: report the failure and continue with the rest.
            print(f"❌ Error processing {song_file}: {e}")
            continue

    print(f"\n📊 Successfully processed {len(dataset_records)} songs")
    return dataset_records

# Create the dataset: one record per song that has a non-empty matching lyrics
# file and whose vocals were separated without error.
dataset_records = create_dataset()

In [None]:
# Convert to Hugging Face dataset.
# Nothing downstream can run without at least one song/lyrics pair, so fail fast.
if not dataset_records:
    raise ValueError("No valid song-lyrics pairs found. Please check your data.")

# Build the dataset, then mark 'audio' as an Audio column. decode=False keeps
# each entry as a reference holding the file path (avoids the torchcodec
# dependency); prepare_dataset loads the waveform itself.
dataset = Dataset.from_list(dataset_records).cast_column(
    "audio", Audio(sampling_rate=16000, decode=False)
)

print(f"📁 Dataset created with {len(dataset)} examples")
print(f"Dataset features: {dataset.features}")

# Preview the first example; only the audio path exists until audio is loaded.
if len(dataset):
    preview = dataset[0]
    audio_ref = preview['audio']
    audio_location = audio_ref['path'] if isinstance(audio_ref, dict) else audio_ref
    print("\n📝 Sample:")
    print(f"  Song: {preview['song_name']}")
    print(f"  Text length: {len(preview['text'])} characters")
    print(f"  Audio path: {audio_location}")
    print(f"  Text preview: {preview['text'][:100]}...")

## 4. Model Setup and Data Preprocessing

In [None]:
# Initialize Whisper model and processor
MODEL_NAME = "openai/whisper-small"  # You can change this to base, medium, or large
# The lyrics are English and the goal is transcription (not translation).
# Configuring language/task makes the tokenizer put the matching language/task
# prefix tokens on every label, i.e. the same decoder prompt generate() uses at
# inference, so training and inference agree.
# NOTE: intended for the multilingual checkpoints; the English-only "*.en"
# models are not trained with language/task tokens.
LANGUAGE = "english"
TASK = "transcribe"

print(f"🤖 Loading model: {MODEL_NAME}")

# Load processor and model
processor = WhisperProcessor.from_pretrained(MODEL_NAME, language=LANGUAGE, task=TASK)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)

# Reuse the processor's components instead of loading them a second time, so the
# labels built in preprocessing and the padding done by the collator share one
# configuration (including the language/task prefix).
tokenizer = processor.tokenizer
feature_extractor = processor.feature_extractor

# Make generate() (used for evaluation and by the saved model) decode English
# transcription; forced_decoder_ids is cleared because language/task now drive
# the decoder prompt.
model.generation_config.language = LANGUAGE
model.generation_config.task = TASK
model.generation_config.forced_decoder_ids = None

# Move model to device
model = model.to(device)

print("✅ Model loaded successfully")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Preprocessing function 

def prepare_dataset(batch, max_label_length: int = 448):
    """
    Convert one raw example into Whisper model inputs.

    Args:
        batch: A single dataset row with 'audio' (a path, or an
            ``Audio(decode=False)`` dict holding 'path') and 'text' (lyrics).
        max_label_length: Maximum number of label tokens, special tokens
            included. 448 matches the decoder's positional-embedding size in
            the OpenAI Whisper checkpoints (``model.config.max_target_positions``);
            longer label sequences fail during training.

    Returns:
        The same dict with 'input_features' (log-mel spectrogram) and
        'labels' (token ids) added.
    """
    audio = batch["audio"]
    audio_path = audio["path"] if isinstance(audio, dict) else audio

    # Shared loader: mono waveform resampled to 16 kHz.
    audio_array = load_and_preprocess_audio(audio_path, target_sample_rate=16000)

    # NOTE(review): by default the Whisper feature extractor pads/truncates
    # input to a 30 s window, so only the first 30 s of each song reaches the
    # model while 'text' holds the full lyrics. Splitting songs into
    # lyric-aligned segments would give consistent training pairs.
    batch["input_features"] = feature_extractor(
        audio_array,
        sampling_rate=16000
    ).input_features[0]

    # Truncate so unusually long lyrics cannot overflow the decoder's positions.
    batch["labels"] = tokenizer(
        batch["text"], truncation=True, max_length=max_label_length
    ).input_ids

    return batch

# Run prepare_dataset on every row. remove_columns drops the original columns
# (audio, text, song_name), so only the fields prepare_dataset adds remain:
# 'input_features' and 'labels'.
print("🔄 Preprocessing dataset (simple)...")
dataset = dataset.map(
    prepare_dataset,
    remove_columns=dataset.column_names,
    desc="Preprocessing"
)

print("✅ Dataset preprocessed")
print(f"Features: {dataset.features}")

## 5. Train-Validation Split

In [None]:
# Split dataset into training and validation sets.
if len(dataset) <= 1:
    # A single song cannot be split, so it doubles as the validation set.
    train_dataset = eval_dataset = dataset
    print("⚠️  Only one song found. Using the same data for training and validation.")
else:
    # Hold out 20% of the songs; the fixed seed keeps the split reproducible.
    split = dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset, eval_dataset = split["train"], split["test"]

print("📊 Dataset split:")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(eval_dataset)}")

## 6. Data Collator

In [None]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Pad a batch of Whisper examples into model-ready tensors.

    Each feature is a dict with ``input_features`` (spectrogram) and
    ``labels`` (token ids). The returned batch holds the padded
    ``input_features`` plus a ``labels`` tensor in which padding positions are
    ``-100`` (ignored by the loss).

    Attributes:
        processor: Object exposing ``feature_extractor.pad`` and
            ``tokenizer.pad`` (e.g. a ``WhisperProcessor``).
        decoder_start_token_id: Token the model prepends to the decoder input
            when building it from ``labels``. When None it is looked up as
            ``<|startoftranscript|>`` in the processor's tokenizer.
    """
    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels need different padding, so split them first.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Padding positions become -100 so the cross-entropy loss ignores them.
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )

        # Whisper labels begin with <|startoftranscript|>, and the model prepends
        # decoder_start_token_id itself when shifting labels into decoder inputs.
        # Drop the leading copy so it is not duplicated. (Comparing against
        # tokenizer.bos_token_id never matched: for Whisper that is <|endoftext|>.)
        start_id = self._resolve_decoder_start_token_id()
        if labels.size(1) > 0 and bool((labels[:, 0] == start_id).all()):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

    def _resolve_decoder_start_token_id(self) -> int:
        """Return the configured start token id, or look it up in the tokenizer."""
        if self.decoder_start_token_id is not None:
            return self.decoder_start_token_id
        return self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

# Initialize data collator.
# It pads each batch and masks label padding with -100, using the processor's
# feature_extractor and tokenizer.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

print("✅ Data collator initialized")

## 7. Evaluation Metrics

In [None]:
# Load evaluation metrics
metric_wer = evaluate.load("wer")
metric_bleu = evaluate.load("bleu")

def compute_metrics(eval_pred):
    """
    Compute word error rate and BLEU for the Trainer's evaluation loop.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair (an ``EvalPrediction``).
            ``predictions`` are generated token ids when the trainer runs with
            ``predict_with_generate=True``; logits shaped (batch, seq, vocab)
            are also accepted and reduced with arg-max.

    Returns:
        ``{"wer": WER in percent (lower is better), "bleu": BLEU score}``.
    """
    pred_ids, label_ids = eval_pred

    # Some models return (ids_or_logits, extra outputs...); keep the first entry.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    pred_ids = np.asarray(pred_ids)
    if pred_ids.ndim == 3:
        # Logits rather than ids (predict_with_generate disabled): greedy arg-max.
        pred_ids = pred_ids.argmax(axis=-1)

    # -100 marks ignored/padded positions: the collator sets it in labels, and the
    # Trainer uses it when concatenating eval batches of different lengths. It is
    # not a valid token id, so map it to the pad token before decoding.
    # np.where builds new arrays instead of mutating the Trainer's buffers.
    pad_id = tokenizer.pad_token_id
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    label_ids = np.asarray(label_ids)
    label_ids = np.where(label_ids == -100, pad_id, label_ids)

    # Decode predictions and labels
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Compute WER (as a percentage)
    wer = 100 * metric_wer.compute(predictions=pred_str, references=label_str)

    # BLEU's brevity penalty divides by the total prediction length, so a batch in
    # which every prediction is empty (common early in training) would raise.
    if any(p.strip() for p in pred_str):
        bleu = metric_bleu.compute(
            predictions=pred_str,
            references=[[ref] for ref in label_str]
        )["bleu"]
    else:
        bleu = 0.0

    return {
        "wer": wer,
        "bleu": bleu
    }

print("✅ Metrics initialized")

## 8. Training Configuration

In [None]:
# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-lyrics-model",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,  # effective train batch per device: 8 * 2 = 16
    warmup_steps=50,
    num_train_epochs=10,
    learning_rate=1e-5,
    # Mixed precision needs a CUDA GPU; requesting it on a CPU-only machine can
    # make the arguments/trainer fail, so enable it only when CUDA is present.
    fp16=torch.cuda.is_available(),
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    logging_steps=10,
    # compute_metrics decodes token ids into text, so evaluation must run
    # generate() rather than returning raw logits.
    predict_with_generate=True,
    generation_max_length=225,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    dataloader_num_workers=0,
    report_to=[]
)

print("📋 Training configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Output directory: {training_args.output_dir}")

## 9. Initialize Trainer

In [None]:
# Initialize trainer.
# Wires together the model, the train/validation splits, the padding collator
# and the WER/BLEU metric function using the training arguments defined earlier.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # The feature extractor is passed as "tokenizer", presumably so it is saved
    # alongside checkpoints (as in the Hugging Face Whisper fine-tuning recipe).
    # NOTE(review): newer transformers releases deprecate `tokenizer=` in favour
    # of `processing_class=` — confirm against the installed version.
    tokenizer=processor.feature_extractor,
)

print("✅ Trainer initialized successfully")

## 10. Start Training

In [None]:
# Start training.
# Runs the fine-tuning loop configured by training_args, which evaluates and
# checkpoints every 100 steps and tracks WER as the best-model metric.
print("🚀 Starting training...")
print("This may take a while depending on your dataset size and hardware.")

trainer.train()

print("✅ Training completed!")

## 11. Save the Trained Model

In [None]:
# Save the final model.
# trainer.save_model writes the trainer's model; the processor (tokenizer +
# feature extractor) is saved into the same directory so it can be loaded on its
# own, e.g. by the pipeline in the next section.
model_save_path = "./whisper-lyrics-final"

trainer.save_model(model_save_path)
processor.save_pretrained(model_save_path)

print(f"💾 Model saved to: {model_save_path}")

# Create a zip file for easy download.
# `!` runs a shell command via IPython and needs the `zip` utility on the host.
# NOTE(review): the folder name is hard-coded here and must match model_save_path.
!zip -r whisper-lyrics-model.zip whisper-lyrics-final/
print("📦 Model packaged as whisper-lyrics-model.zip")

## 12. Test the Trained Model

In [None]:
# Load the trained model for testing
from transformers import pipeline

# Create an ASR pipeline with your trained model.
# model, tokenizer and feature_extractor all point at the directory written when
# saving, so the fine-tuned weights are paired with the saved processor.
# device=0 selects the first GPU; -1 runs on CPU.
trained_asr = pipeline(
    "automatic-speech-recognition",
    model=model_save_path,
    tokenizer=model_save_path,
    feature_extractor=model_save_path,
    device=0 if torch.cuda.is_available() else -1
)

print("✅ Trained model loaded for testing")

In [None]:
# Test on a sample from your dataset.
# The preprocessed dataset no longer holds raw audio, only the spectrogram
# 'input_features' and tokenized 'labels'. The fine-tuned model's generate() is
# therefore run on the stored features directly instead of going through the
# audio pipeline.
if len(eval_dataset) > 0:
    test_sample = eval_dataset[0]
    asr_model = trained_asr.model
    asr_tokenizer = trained_asr.tokenizer

    print("🎤 Testing trained model...")

    # Add a batch dimension and match the model's dtype/device.
    features = torch.tensor(
        [test_sample["input_features"]],
        dtype=asr_model.dtype,
        device=asr_model.device,
    )
    with torch.no_grad():
        generated_ids = asr_model.generate(
            input_features=features, language="en", task="transcribe"
        )

    prediction = asr_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    reference = asr_tokenizer.decode(test_sample["labels"], skip_special_tokens=True)

    print(f"  Reference (first 200 chars): {reference[:200]}")
    print(f"  Prediction (first 200 chars): {prediction[:200]}")

    print("📝 Model is ready for testing!")
    print("To test with new audio files, use the trained_asr pipeline with audio files.")

else:
    print("⚠️  No evaluation data available for testing")

## 13. Model Usage Instructions

In [None]:
# Instructions for using the trained model.
# The generated script must run on its own, so it imports everything it uses and
# carries its own copy of the Demucs helper. The example only runs when the file
# is executed directly, not when it is imported.
usage_code = '''
# How to use your trained model:

import os
import subprocess
import sys

import torch
from transformers import pipeline

# Load your trained model
asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model="./whisper-lyrics-final",
    device=0 if torch.cuda.is_available() else -1
)


def separate_vocals_demucs(song_path, output_dir="temp_vocals"):
    # Same Demucs call as the training notebook: split off the vocal stem.
    subprocess.run(
        [sys.executable, "-m", "demucs", "-n", "htdemucs",
         "--two-stems", "vocals", "-o", output_dir, song_path],
        check=True,
    )
    song_name = os.path.splitext(os.path.basename(song_path))[0]
    vocal_path = os.path.join(output_dir, "htdemucs", song_name, "vocals.wav")
    if not os.path.exists(vocal_path):
        raise FileNotFoundError(f"Vocal track not found at {vocal_path}")
    return vocal_path


# Use with Demucs vocal separation (same as your original code)
def transcribe_with_trained_model(song_path):
    # 1. Separate vocals with Demucs
    vocal_path = separate_vocals_demucs(song_path)

    # 2. Transcribe with your trained model
    result = asr_pipeline(
        vocal_path,
        return_timestamps=True,
        generate_kwargs={"language": "en"}
    )

    return result["text"]


if __name__ == "__main__":
    # Example usage
    lyrics = transcribe_with_trained_model("path/to/your/song.mp3")
    print(lyrics)
'''

print("📋 Usage Instructions:")
print(usage_code)

# Save usage instructions to file
with open("model_usage.py", "w", encoding="utf-8") as f:
    f.write(usage_code)

print("💾 Usage instructions saved to model_usage.py")

## 14. Training Summary

In [None]:
# Display training summary: what was trained, where it was saved, and next steps.
divider = "=" * 50

print("🎉 Training Complete!")
print(divider)
summary_lines = (
    f"📊 Dataset size: {len(dataset)} songs",
    f"🤖 Base model: {MODEL_NAME}",
    f"💾 Saved model: {model_save_path}",
    "📦 Zip file: whisper-lyrics-model.zip",
)
for summary_line in summary_lines:
    print(summary_line)
print(divider)

print("\n📋 Next Steps:")
next_steps = (
    "1. Download whisper-lyrics-model.zip",
    "2. Extract the model files",
    "3. Use the model with your original pipeline",
    "4. Test on new songs to evaluate performance",
)
for step in next_steps:
    print(step)

print("\n💡 Tips for better results:")
tips = (
    "- Use more training data for better performance",
    "- Ensure lyrics are high quality and properly formatted",
    "- Consider using a larger base model (medium/large) if you have GPU memory",
    "- Fine-tune hyperparameters based on your specific use case",
)
for tip in tips:
    print(tip)