# STT Model Training on Kaggle

**What this notebook does:**
1. Installs required packages
2. Transcribes audio with Whisper
3. Fine-tunes Whisper with LoRA
4. Saves trained model

**GPU Required:** Enable GPU in Settings → Accelerator → GPU P100 or T4

## Step 1: Check GPU

In [None]:
# Report the PyTorch build and whether a CUDA device is visible to it
import torch

cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")

if not cuda_ready:
    print("‚ùå GPU not enabled! Go to Settings ‚Üí Accelerator ‚Üí GPU")
else:
    # Device 0 is the only device inspected
    device_props = torch.cuda.get_device_properties(0)
    total_gb = device_props.total_memory / 1e9
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {total_gb:.1f} GB")

## Step 2: Install Required Packages

In [None]:
# Install packages (takes 2-3 minutes)
# Lines starting with "!" are notebook shell escapes, not Python statements.
# transformers, datasets and peft are imported by later cells.
!pip install -q transformers datasets accelerate peft
# librosa is imported by later cells.
!pip install -q librosa soundfile jiwer tensorboard
# Provides the `whisper` module used for the initial transcription pass.
!pip install -q openai-whisper
# NOTE(review): no visible cell imports bitsandbytes directly — confirm it is still needed.
!pip install -q bitsandbytes

# Printed unconditionally; it does not verify that the pip commands succeeded.
print("‚úÖ All packages installed!")

## Step 3: Setup Paths

**IMPORTANT:** Update the dataset path below to match your uploaded dataset name!

In [None]:
import os

# ============================================
# UPDATE THIS PATH TO YOUR DATASET!
# ============================================
# Format: /kaggle/input/YOUR-DATASET-NAME
DATASET_PATH = "/kaggle/input/stt-training-data"

# Output path (Kaggle working directory)
OUTPUT_PATH = "/kaggle/working"

# Directory that is listed when DATASET_PATH is missing
_input_root = "/kaggle/input"

# Check if dataset exists. isdir (not exists) so that a stray file with the
# dataset's name cannot reach os.listdir and raise NotADirectoryError.
if os.path.isdir(DATASET_PATH):
    print(f"‚úÖ Dataset found at: {DATASET_PATH}")
    print("\nContents:")
    for item in os.listdir(DATASET_PATH):
        print(f"  - {item}")
else:
    print(f"‚ùå Dataset not found at: {DATASET_PATH}")
    print("\nAvailable datasets:")
    # Guard the listing: outside Kaggle (or with no dataset attached) the
    # input root itself may not exist, and listing it would raise.
    if os.path.isdir(_input_root):
        for item in os.listdir(_input_root):
            print(f"  - /kaggle/input/{item}")
    else:
        print(f"  (none: {_input_root} does not exist)")

In [None]:
# Explore dataset structure (run this to see what's inside)
import os

def show_tree(path, prefix="", max_depth=3, current_depth=0):
    """Print a tree view of ``path``, listing at most 10 entries per folder.

    Entries are printed in sorted order. Sub-folders are followed recursively
    until ``current_depth`` reaches ``max_depth``. Any exception raised while
    listing or printing is reported as an ``Error: ...`` line, not raised.
    """
    if current_depth >= max_depth:
        return

    try:
        entries = sorted(os.listdir(path))
        shown = entries[:10]  # Limit to 10 items
        last_index = len(shown) - 1

        for index, name in enumerate(shown):
            full_path = os.path.join(path, name)

            # The final displayed entry gets the closing connector, and its
            # children are indented without a vertical guide line.
            if index == last_index:
                connector, child_pad = "‚îî‚îÄ‚îÄ ", "    "
            else:
                connector, child_pad = "‚îú‚îÄ‚îÄ ", "‚îÇ   "

            if not os.path.isdir(full_path):
                print(f"{prefix}{connector}üìÑ {name}")
                continue

            print(f"{prefix}{connector}üìÅ {name}/")
            show_tree(full_path, prefix + child_pad, max_depth, current_depth + 1)

        hidden = len(entries) - len(shown)
        if hidden > 0:
            print(f"{prefix}    ... and {hidden} more items")
    except Exception as e:
        print(f"Error: {e}")

# Print a titled listing of DATASET_PATH framed by two 50-character "=" rules.
print("üìÇ Dataset Structure:")
print("=" * 50)
# Called with the defaults: empty prefix and at most 3 levels deep.
show_tree(DATASET_PATH)
print("=" * 50)

import whisper
import json
from pathlib import Path
from tqdm import tqdm

# Load Whisper model for transcription
# `model` is a module-level global: transcribe_manifest() reads it by name,
# so this assignment has to run before any transcription call.
print("Loading Whisper medium model...")
model = whisper.load_model("medium")
print("‚úÖ Model loaded!")

def fix_audio_path(original_path, dataset_path):
    """Map a manifest path (possibly Windows-style) into the Kaggle dataset.

    Backslashes are normalised to forward slashes, then the part after the
    last ``extracted/`` (checked first) or ``combined/`` marker is re-rooted
    under ``dataset_path``. Extracted files are placed under the nested
    ``extracted/extracted/`` folder. Paths containing neither marker are
    returned exactly as given.
    """
    normalized = original_path.replace('\\', '/')

    # (marker searched for in the path, folder it maps to under dataset_path)
    layouts = (
        ('extracted/', 'extracted/extracted'),
        ('combined/', 'combined'),
    )
    for marker, target_dir in layouts:
        if marker in normalized:
            tail = normalized.split(marker)[-1]
            return f"{dataset_path}/{target_dir}/{tail}"

    return original_path

def transcribe_manifest(manifest_path, lang, output_path):
    """Transcribe every audio file listed in a JSONL manifest.

    Each non-blank manifest line is a JSON object whose audio location is
    stored under ``audio_path`` or ``audio_filepath``. Paths are remapped with
    ``fix_audio_path`` against the global ``DATASET_PATH`` and transcribed
    with the global Whisper ``model``.

    Args:
        manifest_path: Path of the JSONL manifest to read.
        lang: Language code passed to ``model.transcribe``.
        output_path: Where the transcribed manifest (JSONL) is written.

    Returns:
        The entries that were transcribed, each with ``text`` set and
        ``audio_filepath`` rewritten to the resolved path. An empty list is
        returned, and nothing is written, when the manifest has no entries
        or the first entry's audio cannot be located.
    """

    def _entry_path(entry):
        # Manifests use either key for the audio location
        return entry.get('audio_path') or entry.get('audio_filepath')

    # Read manifest; blank lines (e.g. a trailing newline) are not entries
    with open(manifest_path, 'r', encoding='utf-8') as f:
        lines = [line for line in f if line.strip()]

    print(f"Found {len(lines)} audio files to transcribe")

    if not lines:
        # Nothing to sample or transcribe; indexing lines[0] would raise
        return []

    # Show sample path conversion
    sample_path = _entry_path(json.loads(lines[0].strip()))
    # An entry without a path cannot be remapped; treat it as not found
    fixed_sample = fix_audio_path(sample_path, DATASET_PATH) if sample_path else ""
    sample_exists = bool(fixed_sample) and os.path.exists(fixed_sample)
    print(f"\nPath conversion example:")
    print(f"  Original: {sample_path}")
    print(f"  Fixed:    {fixed_sample}")
    print(f"  Exists:   {sample_exists}")

    if not sample_exists:
        # The path mapping is probably wrong for the whole manifest: point
        # the user at the first folder that really holds .wav files and stop.
        print("\n‚ö†Ô∏è File not found! Let me search for audio files...")
        for root, _dirs, files in os.walk(DATASET_PATH):
            wav_files = [name for name in files if name.endswith('.wav')]
            if wav_files:
                print(f"Found .wav files in: {root}")
                print(f"Example file: {wav_files[0]}")
                break
        return []

    print("\n")
    results = []
    not_found = 0
    missing_path = 0

    for line in tqdm(lines, desc=f"Transcribing {lang}"):
        entry = json.loads(line.strip())
        original_path = _entry_path(entry)

        if not original_path:
            # No audio key at all: skip instead of crashing in fix_audio_path
            missing_path += 1
            continue

        # Fix path for Kaggle
        audio_path = fix_audio_path(original_path, DATASET_PATH)

        if os.path.exists(audio_path):
            try:
                result = model.transcribe(audio_path, language=lang)
                entry['text'] = result['text'].strip()
                entry['audio_filepath'] = audio_path
                results.append(entry)
            except Exception as e:
                # Best effort: one bad file must not abort the whole run
                print(f"Error transcribing {audio_path}: {e}")
        else:
            not_found += 1
            if not_found <= 3:
                print(f"File not found: {audio_path}")

    if not_found > 3:
        print(f"... and {not_found - 3} more files not found")

    print(f"\n‚úÖ Transcribed: {len(results)} files")
    print(f"‚ùå Not found: {not_found} files")
    if missing_path:
        print(f"Skipped {missing_path} entries without an audio path")

    # Save transcribed manifest
    with open(output_path, 'w', encoding='utf-8') as f:
        for entry in results:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    print(f"‚úÖ Saved to: {output_path}")
    return results

In [None]:
import whisper
import json
from pathlib import Path
from tqdm import tqdm

# Load Whisper model for transcription
# This repeats the "medium" load performed by the preceding cell and rebinds
# the same global `model` that transcribe_manifest() reads.
print("Loading Whisper medium model...")
model = whisper.load_model("medium")
print("‚úÖ Model loaded!")

def fix_audio_path(original_path, dataset_path):
    """Convert a Windows manifest path to its location in the Kaggle dataset.

    Handles paths like ``C:\\Users\\...\\data\\stt\\extracted\\...`` by
    normalising separators and re-rooting everything after the last
    ``extracted/`` (checked first) or ``combined/`` marker under
    ``dataset_path``.

    Uploaded archives sometimes unpack one level deeper, as
    ``extracted/extracted/...``. For extracted files the flat location is
    preferred; the nested location is returned only when the flat one does
    not exist on disk and the nested one does.

    Args:
        original_path: Path as stored in the manifest.
        dataset_path: Root folder of the dataset on Kaggle.

    Returns:
        The remapped path, or ``original_path`` unchanged when it contains
        neither marker.
    """
    import os  # local import keeps this function self-contained

    # Normalize path separators
    path = original_path.replace('\\', '/')

    if 'extracted/' in path:
        # Get everything after the last 'extracted/'
        relative = path.split('extracted/')[-1]
        flat = f"{dataset_path}/extracted/{relative}"
        nested = f"{dataset_path}/extracted/extracted/{relative}"
        if not os.path.exists(flat) and os.path.exists(nested):
            return nested
        # Default (also when neither exists) stays the flat layout
        return flat

    if 'combined/' in path:
        relative = path.split('combined/')[-1]
        return f"{dataset_path}/combined/{relative}"

    return original_path

def transcribe_manifest(manifest_path, lang, output_path):
    """Transcribe all audio files in a manifest and save the results.

    The manifest is JSONL: one JSON object per non-blank line, with the audio
    location under ``audio_path`` or ``audio_filepath``. Each path is remapped
    by ``fix_audio_path`` relative to the global ``DATASET_PATH`` and
    transcribed by the global Whisper ``model``. Entries whose file is missing
    or whose transcription raises are left out of the output.

    Args:
        manifest_path: JSONL manifest to read.
        lang: Language code forwarded to ``model.transcribe``.
        output_path: Destination JSONL file; always written, even when empty.

    Returns:
        List of transcribed entries, each with ``text`` added and
        ``audio_filepath`` set to the resolved path.
    """

    def _entry_path(entry):
        # Either key may hold the audio location
        return entry.get('audio_path') or entry.get('audio_filepath')

    # Read manifest, dropping blank lines that json.loads would reject
    with open(manifest_path, 'r', encoding='utf-8') as f:
        lines = [line for line in f if line.strip()]

    print(f"Found {len(lines)} audio files to transcribe")

    # Show sample path conversion (skipped for an empty manifest, where
    # lines[0] would raise IndexError)
    if lines:
        sample_path = _entry_path(json.loads(lines[0].strip()))
        # Without a path there is nothing to remap or check on disk
        fixed_sample = fix_audio_path(sample_path, DATASET_PATH) if sample_path else ""
        sample_exists = bool(fixed_sample) and os.path.exists(fixed_sample)
        print(f"\nPath conversion example:")
        print(f"  Original: {sample_path}")
        print(f"  Fixed:    {fixed_sample}")
        print(f"  Exists:   {sample_exists}\n")

    results = []
    not_found = 0
    missing_path = 0

    for line in tqdm(lines, desc=f"Transcribing {lang}"):
        entry = json.loads(line.strip())
        original_path = _entry_path(entry)

        if not original_path:
            # Entry has no audio key; fix_audio_path would fail on None
            missing_path += 1
            continue

        # Fix path for Kaggle
        audio_path = fix_audio_path(original_path, DATASET_PATH)

        if os.path.exists(audio_path):
            try:
                result = model.transcribe(audio_path, language=lang)
                entry['text'] = result['text'].strip()
                entry['audio_filepath'] = audio_path
                results.append(entry)
            except Exception as e:
                # Keep going: a single unreadable file should not stop the run
                print(f"Error transcribing {audio_path}: {e}")
        else:
            not_found += 1
            if not_found <= 3:  # Only show first 3 errors
                print(f"File not found: {audio_path}")

    if not_found > 3:
        print(f"... and {not_found - 3} more files not found")

    print(f"\n‚úÖ Transcribed: {len(results)} files")
    print(f"‚ùå Not found: {not_found} files")
    if missing_path:
        print(f"Skipped {missing_path} entries without an audio path")

    # Save transcribed manifest
    with open(output_path, 'w', encoding='utf-8') as f:
        for entry in results:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    print(f"‚úÖ Saved to: {output_path}")
    return results

In [None]:
# Transcribe Hindi data
# Candidate manifest locations, in priority order
hi_manifest_options = [
    f"{DATASET_PATH}/combined/hi_train.jsonl",
    f"{DATASET_PATH}/hi_train.jsonl",
]

# First candidate present on disk, or None when neither exists
hi_manifest = next(
    (candidate for candidate in hi_manifest_options if os.path.exists(candidate)),
    None,
)

if not hi_manifest:
    print("‚ùå Hindi manifest not found!")
    print("Checked locations:")
    for path in hi_manifest_options:
        print(f"  - {path}")
else:
    print(f"Found Hindi manifest at: {hi_manifest}")
    hi_output = f"{OUTPUT_PATH}/hi_train_transcribed.jsonl"
    hi_results = transcribe_manifest(hi_manifest, lang="hi", output_path=hi_output)

In [None]:
# Transcribe English data
# Candidate manifest locations, in priority order
en_manifest_options = [
    f"{DATASET_PATH}/combined/en_train.jsonl",
    f"{DATASET_PATH}/en_train.jsonl",
]

# First candidate present on disk, or None when neither exists
en_manifest = next(
    (candidate for candidate in en_manifest_options if os.path.exists(candidate)),
    None,
)

if not en_manifest:
    print("‚ùå English manifest not found!")
    print("Checked locations:")
    for path in en_manifest_options:
        print(f"  - {path}")
else:
    print(f"Found English manifest at: {en_manifest}")
    en_output = f"{OUTPUT_PATH}/en_train_transcribed.jsonl"
    en_results = transcribe_manifest(en_manifest, lang="en", output_path=en_output)

## Step 5: Prepare Dataset for Training

In [None]:
import json
import librosa
import numpy as np
from datasets import Dataset, Audio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Load processor
# Built from the "openai/whisper-small" checkpoint, i.e. the checkpoint the
# fine-tuning cell loads, not the "medium" model used to create transcripts.
print("Loading Whisper processor...")
processor = WhisperProcessor.from_pretrained("openai/whisper-small")

def load_manifest(manifest_path):
    """Load a transcribed JSONL manifest as a list of training records.

    Blank lines are ignored. An entry is kept only when it has both a
    non-empty ``text`` and an audio location (``audio_filepath`` preferred,
    ``audio_path`` as fallback), so no record is produced with a missing
    audio path.

    Args:
        manifest_path: Path of the JSONL manifest to read.

    Returns:
        A list of ``{'audio': <path>, 'text': <transcript>}`` dicts in
        manifest order.
    """
    entries = []
    with open(manifest_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # e.g. a trailing newline; json.loads('') would raise
                continue
            entry = json.loads(line)
            audio = entry.get('audio_filepath') or entry.get('audio_path')
            text = entry.get('text')
            if audio and text:  # Only include usable transcribed entries
                entries.append({'audio': audio, 'text': text})
    return entries

# Locations written by the transcription cells
hi_transcribed = f"{OUTPUT_PATH}/hi_train_transcribed.jsonl"
en_transcribed = f"{OUTPUT_PATH}/en_train_transcribed.jsonl"

# Combined training pool: Hindi first, then English; either may be absent
train_data = []

if os.path.exists(hi_transcribed):
    train_data += load_manifest(hi_transcribed)
    # Only Hindi has been added so far, so the pool size is the Hindi count
    print(f"Loaded {len(train_data)} Hindi samples")

if os.path.exists(en_transcribed):
    en_data = load_manifest(en_transcribed)
    train_data += en_data
    print(f"Loaded {len(en_data)} English samples")

print(f"\nTotal training samples: {len(train_data)}")

In [None]:
# Create HuggingFace dataset
from datasets import Dataset, Audio

# Fail fast with a clear message: with no samples there is nothing to train
# on, and the failure would otherwise surface later in a less obvious place.
if not train_data:
    raise ValueError(
        "No training samples were loaded; run the transcription cells first "
        "so the *_transcribed.jsonl manifests exist."
    )

# Each record is {'audio': <path>, 'text': <transcript>}
dataset = Dataset.from_list(train_data)
# Treat the "audio" column as audio at a 16000 Hz sampling rate
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

print(f"Dataset created with {len(dataset)} samples")
print(dataset)

In [None]:
def prepare_dataset(batch):
    """Turn one dataset row into model inputs.

    Reads the decoded clip from ``batch["audio"]`` and the transcript from
    ``batch["text"]``, then adds ``input_features`` (first item of the
    processor output for the clip) and ``labels`` (token ids of the
    transcript) using the global ``processor``. Returns the same mapping.
    """
    clip = batch["audio"]
    waveform, rate = clip["array"], clip["sampling_rate"]

    # Audio -> features; the processor output is batched, so take item 0
    extracted = processor(waveform, sampling_rate=rate, return_tensors="pt")
    batch["input_features"] = extracted.input_features[0]

    # Transcript -> token ids
    tokenized = processor.tokenizer(batch["text"])
    batch["labels"] = tokenized.input_ids

    return batch

# Process dataset
# Every original column is removed after mapping, so the result keeps only
# what prepare_dataset adds ("input_features" and "labels").
print("Processing dataset (this may take a while)...")
processed_dataset = dataset.map(
    prepare_dataset, 
    remove_columns=dataset.column_names,
    num_proc=1
)
print("‚úÖ Dataset processed!")

## Step 6: Setup Model with LoRA

In [None]:
from transformers import WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Load base model
# NOTE: this rebinds the global `model`, replacing the openai-whisper model
# that transcribe_manifest() reads; calling transcribe_manifest() after this
# cell would invoke .transcribe() on this object instead.
print("Loading Whisper small model...")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
# Clear the forced decoder ids and the suppressed-token list on the config.
# NOTE(review): presumably so no language/task prompt is forced during
# training — confirm against the transformers Whisper documentation.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# Configure LoRA
# Adapters are attached only to modules named "q_proj" and "v_proj".
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)

# Apply LoRA
# `model` is rebound again, now to the PEFT-wrapped model.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

print("‚úÖ Model ready for training!")

## Step 7: Training

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate preprocessed rows into one padded batch.

    Audio features are padded with the processor's feature extractor and
    label ids with its tokenizer. Padded label positions are replaced by
    -100 (presumably the value the loss ignores — confirm against the
    model's loss). When every label sequence starts with the tokenizer's
    BOS token, that leading column is removed.
    """

    # Object exposing `.feature_extractor` and `.tokenizer`
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        feature_extractor = self.processor.feature_extractor
        tokenizer = self.processor.tokenizer

        # Audio and labels differ in length, so they are padded separately
        audio_inputs = []
        label_inputs = []
        for feature in features:
            audio_inputs.append({"input_features": feature["input_features"]})
            label_inputs.append({"input_ids": feature["labels"]})

        batch = feature_extractor.pad(audio_inputs, return_tensors="pt")
        padded_labels = tokenizer.pad(label_inputs, return_tensors="pt")

        # Positions where the attention mask is not 1 are padding
        padding_mask = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(padding_mask, -100)

        starts_with_bos = (labels[:, 0] == tokenizer.bos_token_id).all().cpu().item()
        if starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Collator instance passed to the trainer; it reuses the global `processor`.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [None]:
# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=f"{OUTPUT_PATH}/whisper-finetuned",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # 8 x 2 = 16 samples per optimizer step per device
    learning_rate=1e-4,
    warmup_steps=50,
    num_train_epochs=3,
    fp16=True,
    logging_steps=25,
    save_steps=500,
    save_total_limit=2,
    predict_with_generate=True,
    generation_max_length=225,
    report_to=["tensorboard"],
    push_to_hub=False,
    # `model` is PEFT-wrapped. The wrapper's forward() does not expose the
    # base model's argument names, so on some library versions the Trainer's
    # automatic column pruning drops "input_features"/"labels". Keep every
    # column (the dataset holds only those two) and name the label column.
    remove_unused_columns=False,
    label_names=["labels"],
)

# Create trainer
# No eval_dataset is supplied, so only training is configured here.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=processed_dataset,
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
)

print("‚úÖ Trainer ready!")

In [None]:
# Start training!
print("üöÄ Starting training...")
print("This will take 1-2 hours depending on your data size.")
print("-" * 50)

# Runs the training configured on `trainer`; the lines below execute only
# after this call returns.
trainer.train()

print("-" * 50)
print("‚úÖ Training complete!")

## Step 8: Save Model

In [None]:
# Destination folder for the fine-tuned artifacts
model_save_path = f"{OUTPUT_PATH}/whisper-stt-finetuned"

# Write the LoRA weights first, then the processor files, into the same folder
for artifact in (model, processor):
    artifact.save_pretrained(model_save_path)

print(f"‚úÖ Model saved to: {model_save_path}")
print("\nFiles saved:")
for saved_name in os.listdir(model_save_path):
    print(f"  - {saved_name}")

## Step 9: Test the Model

In [None]:
# Test with a sample audio
import torch

def transcribe_audio(audio_path):
    """Transcribe one audio file with the fine-tuned model.

    Uses the global ``processor`` and ``model``. The model is switched to
    evaluation mode for generation, so dropout (including the LoRA dropout
    configured for training) does not perturb the output, and its previous
    train/eval mode is restored afterwards.

    Args:
        audio_path: Path of the audio file; it is loaded at 16000 Hz.

    Returns:
        The decoded transcription string, with special tokens removed.
    """
    # Load audio; the returned sample rate equals the requested one
    audio, _ = librosa.load(audio_path, sr=16000)

    # Process
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features.to(model.device)

    # Generate in eval mode: after trainer.train() the model is still in
    # training mode, which would leave dropout active during inference.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predicted_ids = model.generate(input_features, max_length=225)
    finally:
        model.train(was_training)

    # Decode
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

# Smoke test: run the first training sample through the fine-tuned model
if train_data:
    first_sample = train_data[0]
    test_audio = first_sample['audio']
    print(f"Testing with: {test_audio}")
    print(f"Original text: {first_sample['text']}")
    print(f"Model output: {transcribe_audio(test_audio)}")

## Step 10: Download Model

1. Click on the folder icon on the left sidebar
2. Navigate to `/kaggle/working/whisper-stt-finetuned`
3. Right-click → Download

Or create a zip file:

In [None]:
# Create zip for easy download
import shutil

# Archive name without extension; make_archive adds ".zip" for the 'zip' format
archive_base = f"{OUTPUT_PATH}/whisper-stt-model"
shutil.make_archive(archive_base, 'zip', model_save_path)

print(f"‚úÖ Model zipped: {OUTPUT_PATH}/whisper-stt-model.zip")
print("\nDownload this file from the Output section on the right panel!")

---

## Done!

Your trained model is saved in `/kaggle/working/whisper-stt-finetuned`

To download:
1. Go to **Output** tab on right panel
2. Download `whisper-stt-model.zip`

To use the model later, load with:
```python
from peft import PeftModel
from transformers import WhisperForConditionalGeneration, WhisperProcessor

base_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model = PeftModel.from_pretrained(base_model, "path/to/whisper-stt-finetuned")
processor = WhisperProcessor.from_pretrained("path/to/whisper-stt-finetuned")
```