In [None]:
!apt-get update -qq

In [None]:
!apt-get install -y ffmpeg libsndfile1

In [None]:
%pip install --upgrade pip

In [None]:
%pip install pydantic==2.11.0 rich==13.7.1 pyarrow==19.0.0 --force-reinstall

In [None]:
%pip install transformers datasets accelerate

In [None]:
%pip install audiomentations==0.36.0

In [None]:
%pip install tensorboard numpy scipy datasets

In [None]:
# Quote the version specifiers: unquoted, the shell treats `>=0.12.0` as an output redirect,
# silently installing unpinned packages and writing pip's log to files named "=0.12.0"/"=0.43.3".
%pip install "peft>=0.12.0" "bitsandbytes>=0.43.3"

In [None]:
%pip install tqdm==4.67.1 scikit-learn==1.2.2

In [None]:
# Use the %pip magic explicitly (a bare `pip` line only works when automagic is enabled)
%pip install evaluate jiwer

In [None]:
# pip list

In [None]:
# Report which `datasets` version the kernel actually imported after the install cells above
import datasets
print(datasets.__version__)

In [None]:
import os
import re
import json
import torch
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Union, Any
from tqdm.auto import tqdm

import librosa
import soundfile as sf
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift

from datasets import Dataset, DatasetDict, Audio
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer, 
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import evaluate
from transformers import BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# GPU check: report whether CUDA is visible and, if so, the first device's name and total memory
separator = '=' * 60
cuda_ok = torch.cuda.is_available()
print(f"\n{separator}")
print(f"üñ•Ô∏è  GPU Available: {cuda_ok}")
if cuda_ok:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"üìä GPU: {gpu_name}")
    print(f"üíæ GPU Memory: {gpu_mem_gb:.2f} GB")
print(f"{separator}\n")



# Dataset Loading and Parsing

### Chunking long audio

In [None]:
def chunk_long_audio(audio_path, transcription, max_duration=30.0, overlap=1.0,
                     output_dir="/kaggle/working/chunks"):
    """
    Split long audio files into overlapping chunks of at most max_duration seconds.

    The transcription has no timestamps, so its words are distributed in
    proportion to time. Each chunk receives the words whose estimated position
    falls between its own start and the next chunk's start. The final chunk
    takes everything that is left, so no chunk is left with an empty
    transcription while words remain. The split is an approximation: words near
    a cut may land in the neighbouring chunk.

    Args:
        audio_path: Path to audio file (loaded and resampled to 16 kHz)
        transcription: Full transcription text
        max_duration: Maximum chunk duration in seconds (default 30s for Whisper)
        overlap: Overlap between consecutive chunks in seconds, to avoid cutting
            words. Must be smaller than max_duration.
        output_dir: Directory the chunk files are written to (created if missing)

    Returns:
        List of dicts with 'audio_path' and 'transcription'. Audio that is
        already <= max_duration is returned as one entry for the original file.

    Raises:
        ValueError: if overlap >= max_duration for audio that needs chunking.
    """
    audio, sr = librosa.load(audio_path, sr=16000)
    n_samples = len(audio)
    audio_duration = n_samples / sr

    # If audio is short enough, return as-is
    if audio_duration <= max_duration:
        return [{'audio_path': audio_path, 'transcription': transcription}]

    chunk_samples = int(max_duration * sr)
    overlap_samples = int(overlap * sr)
    step_samples = chunk_samples - overlap_samples
    if step_samples <= 0:
        raise ValueError(
            f"overlap ({overlap}s) must be smaller than max_duration ({max_duration}s)"
        )

    # Chunk start offsets: advance one step at a time until a chunk reaches the end of the audio
    starts = []
    for start in range(0, n_samples, step_samples):
        starts.append(start)
        if start + chunk_samples >= n_samples:
            break

    words = transcription.split()
    n_words = len(words)

    def word_index(sample_pos):
        # Estimated word position at a sample offset (assumes a uniform speaking rate)
        return round(n_words * sample_pos / n_samples)

    os.makedirs(output_dir, exist_ok=True)
    source_name = Path(audio_path).name
    chunks = []
    for chunk_idx, start in enumerate(starts):
        end = min(start + chunk_samples, n_samples)
        is_last = chunk_idx == len(starts) - 1

        # This chunk owns the words in [start, next chunk's start); the last chunk takes the rest.
        # Consecutive ranges share their boundary index, so words are neither dropped nor duplicated.
        word_start = word_index(start)
        word_end = n_words if is_last else word_index(starts[chunk_idx + 1])
        chunk_text = ' '.join(words[word_start:word_end])

        chunk_path = os.path.join(output_dir, f"chunk_{chunk_idx}_{source_name}")
        sf.write(chunk_path, audio[start:end], sr)

        chunks.append({
            'audio_path': chunk_path,
            'transcription': chunk_text.strip()
        })

    print(f"  üìå Split {Path(audio_path).name} ({audio_duration:.1f}s) ‚Üí {len(chunks)} chunks")
    return chunks

In [None]:
def augment_audio(audio, sr):
    """Run `audio` (sample rate `sr`) through a random audiomentations chain and return the result."""
    # NOTE(review): an identical augment_audio is defined again in the augmentation section
    # further down; that later definition is the one in effect after a full run.
    transforms = [
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
        TimeStretch(min_rate=0.9, max_rate=1.1, p=0.5),
        PitchShift(min_semitones=-2, max_semitones=2, p=0.5),
        Shift(min_shift=-0.5, max_shift=0.5, p=0.3),
    ]
    augmented = Compose(transforms)(samples=audio, sample_rate=sr)
    return augmented

In [None]:
def parse_text_file(text_path):
    """
    Parse a transcription file into (line_number, text) tuples.

    Handles both formats:
    - '1.Text here' or '1. Text here' -> (1, 'Text here')
    - 'Text here'                     -> (None, 'Text here')

    Only a leading integer immediately followed by a period counts as a line
    number. Sentences that merely start with a number ('100 dollars to Ali',
    '3.5 km away') are kept intact. Blank lines are skipped. The file is read
    as UTF-8 and a leading byte-order mark is tolerated.

    Args:
        text_path: Path to the transcription text file.

    Returns:
        list of (int | None, str) tuples, in file order.
    """
    # Integer + '.', not followed by another digit (so decimals such as '3.5' are not split)
    numbered_line = re.compile(r'^(\d+)\.(?!\d)\s*(.+)$')

    parsed = []
    # 'utf-8-sig' drops a BOM if present; otherwise the BOM would hide line 1's number from the regex
    with open(text_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            original = line.strip()
            if not original:
                continue

            match = numbered_line.match(original)
            if match:
                parsed.append((int(match.group(1)), match.group(2).strip()))
            else:
                parsed.append((None, original))

    return parsed


def extract_audio_number(filename):
    """
    Return the clip number embedded in an audio filename, or None if there is none.

    Patterns are tried in priority order and the first hit wins:
    1. number + underscore at the start:     '1_load my wallet.mp3'   -> 1
    2. space/underscore + number + '.mp3':   'Voice 1.mp3', 'Standard recording 1.mp3',
                                             'recording_001.mp3'      -> 1
    3. any digits right before '.mp3':       'text001.mp3'            -> 1
    """
    ordered_patterns = (
        (re.match, r'^(\d+)_'),
        (re.search, r'[\s_](\d+)\.mp3$'),
        (re.search, r'(\d+)\.mp3$'),
    )
    for finder, pattern in ordered_patterns:
        hit = finder(pattern, filename)
        if hit is not None:
            return int(hit.group(1))
    return None


def _pair_audio_with_text(audio_files, transcriptions, category):
    """Match audio files to transcription lines; return (entries, number_of_unmatched_files).

    A file whose filename number exists as a numbered line gets that line. All other
    files take the still-unclaimed lines in file order, so a fallback never reuses a
    line that a number-matched file already owns.
    """
    # line number -> index into `transcriptions` (a later duplicate number wins)
    index_by_number = {}
    for idx, (line_num, _text) in enumerate(transcriptions):
        if line_num is not None:
            index_by_number[line_num] = idx

    # Pass 1: number-based matches claim their lines up front
    numbered_matches = {}
    for audio_file in audio_files:
        audio_num = extract_audio_number(audio_file.name)
        if audio_num is not None and audio_num in index_by_number:
            numbered_matches[audio_file] = index_by_number[audio_num]

    claimed = set(numbered_matches.values())
    unclaimed = iter([idx for idx in range(len(transcriptions)) if idx not in claimed])

    # Pass 2: emit entries in audio-file order, falling back to the next unclaimed line
    entries = []
    mismatches = 0
    for audio_file in audio_files:
        idx = numbered_matches.get(audio_file)
        if idx is None:
            idx = next(unclaimed, None)
        if idx is None:
            print(f"‚ùå No transcription for: {audio_file.name}")
            mismatches += 1
            continue

        entries.append({
            'audio_path': str(audio_file),
            'transcription': transcriptions[idx][1],
            'category': category,
            'filename': audio_file.name
        })
    return entries, mismatches


def _load_single_voice_entry(base_path):
    """Return the manifest entry for the single long recording, or None if its audio/text is missing."""
    single_file = Path(base_path) / 'my_voice_sample.mp3'
    single_text = Path(base_path) / 'voice_text_dataset/voice_text_dataset/my_voice_text_data.txt'
    if not (single_file.exists() and single_text.exists()):
        return None

    text = single_text.read_text(encoding='utf-8-sig')
    # Use the full text as one entry. Collapse newlines/repeated spaces but keep punctuation,
    # so this label is formatted like every other transcription.
    full_text = ' '.join(text.split())
    return {
        'audio_path': str(single_file),
        'transcription': full_text,
        'category': 'single_voice',
        'filename': single_file.name
    }


def create_dataset_manifest(base_path, mappings=None):
    """
    Create audio-text pairs for every audio folder and its transcription file.

    Clips are matched to transcription lines by the number in their filename.
    Clips without a usable number fall back to the first line that no other clip
    has claimed. The single long recording 'my_voice_sample.mp3' (if present) is
    added as one extra entry with category 'single_voice'.

    Args:
        base_path: Root directory of the input dataset.
        mappings: Optional {audio_folder: text_file} dict (paths relative to
            base_path). Defaults to the folders used by this notebook.

    Returns:
        List of dicts with keys 'audio_path', 'transcription', 'category', 'filename'.
    """
    if mappings is None:
        # Mapping: audio_folder -> text_file
        mappings = {
            'Voice_memo': 'voice_text_dataset/voice_text_dataset/voice_memo_text_data.txt',
            'chinese_accent': 'voice_text_dataset/voice_text_dataset/chinese_accent_text_data.txt',
            'voice_record': 'voice_text_dataset/voice_text_dataset/voice_record_text_data.txt',
            # 'my_voice': 'voice_text_dataset/voice_text_dataset/my_voice_text_data.txt'
        }

    dataset_entries = []
    total_mismatches = 0

    for audio_dir, text_file in mappings.items():
        audio_path = Path(base_path) / audio_dir
        text_path = Path(base_path) / text_file

        if not audio_path.exists():
            print(f"‚ö†Ô∏è  Skipping: {audio_dir} (not found)")
            continue

        if not text_path.exists():
            print(f"‚ö†Ô∏è  Skipping: {text_file} (not found)")
            continue

        entries, mismatches = _pair_audio_with_text(
            sorted(audio_path.glob('*.mp3')), parse_text_file(text_path), audio_dir
        )
        dataset_entries.extend(entries)
        total_mismatches += mismatches
        print(f"‚úÖ {audio_dir}: {len(entries)} files matched")

    # Handle single file with multi-sentence transcription
    single_entry = _load_single_voice_entry(base_path)
    if single_entry is not None:
        dataset_entries.append(single_entry)
        print("‚úÖ single_voice: 1 file matched")

    print(f"\n{'='*60}")
    print(f"üìä Total matched: {len(dataset_entries)} samples")
    if total_mismatches > 0:
        print(f"‚ö†Ô∏è  Mismatches: {total_mismatches} files")
    print(f"{'='*60}\n")

    return dataset_entries


# Scan the Kaggle input folder and build the audio/transcription manifest
BASE_PATH = '/kaggle/input/fine-tuning-dataset'
print("üîç Scanning dataset...\n")
dataset_manifest = create_dataset_manifest(BASE_PATH)

# Tabulate it and report how many clips each category contributed
df = pd.DataFrame(dataset_manifest)
category_sizes = df.groupby('category').size()
print("\nüìã Dataset Breakdown:")
print(category_sizes)

# Persist the manifest next to the notebook for later inspection
manifest_csv = 'dataset_manifest.csv'
print(f"\nüíæ Saving manifest...")
df.to_csv(manifest_csv, index=False)
print("‚úÖ Saved to dataset_manifest.csv")

# Peek at the first few rows
preview_cols = ['filename', 'transcription', 'category']
print("\nüìù Sample entries:")
print(df[preview_cols].head(3))


# Audio Validation & Duration Calculation

In [None]:
import subprocess

def get_audio_duration(audio_path):
    """
    Get duration in seconds using ffprobe.

    stdout and stderr are captured separately, so non-fatal decoder messages
    on stderr cannot corrupt the duration that ffprobe prints on stdout.

    Args:
        audio_path: str or Path of the audio file.

    Returns:
        Duration in seconds as a float, or 0 if ffprobe is unavailable, exits
        with an error, or prints something that is not a number (e.g. 'N/A').
        The reason is printed in each failure case.
    """
    try:
        result = subprocess.run(
            ['ffprobe', '-v', 'error', '-show_entries', 'format=duration',
             '-of', 'default=noprint_wrappers=1:nokey=1', str(audio_path)],
            capture_output=True, text=True
        )
    except Exception as e:  # best-effort: e.g. ffprobe not installed
        print(f"‚ùå Error processing {audio_path}: {e}")
        return 0

    if result.returncode != 0:
        print(f"‚ùå FFmpeg error for {audio_path}: {result.stderr.strip()}")
        return 0

    try:
        return float(result.stdout.strip())
    except ValueError as e:
        print(f"‚ùå Error processing {audio_path}: {e}")
        return 0
print("\n‚è±Ô∏è  Calculating total duration...")
df['duration'] = df['audio_path'].apply(get_audio_duration)

# get_audio_duration returns 0 when ffprobe cannot read a file. Left in, such rows would drag
# down the average/min below and would likely fail again when a later cell decodes the audio,
# so report and drop them here.
unreadable = df['duration'] <= 0
if unreadable.any():
    print(f"Dropping {int(unreadable.sum())} unreadable file(s): {df.loc[unreadable, 'filename'].tolist()}")
    df = df.loc[~unreadable].reset_index(drop=True)

total_duration_sec = df['duration'].sum()
total_duration_min = total_duration_sec / 60
total_duration_hr = total_duration_min / 60

print(f"\n{'='*60}")
print(f"üìä Dataset Statistics:")
print(f"   Total Samples: {len(df)}")
print(f"   Total Duration: {total_duration_min:.1f} minutes ({total_duration_hr:.2f} hours)")
print(f"   Avg Duration: {df['duration'].mean():.1f} seconds")
print(f"   Min Duration: {df['duration'].min():.1f} seconds")
print(f"   Max Duration: {df['duration'].max():.1f} seconds")
print(f"{'='*60}\n")

if total_duration_hr < 0.5:
    print("‚ö†Ô∏è  WARNING: Dataset < 0.5 hours")
    print("   Expected: Overfitting likely, limited generalization")
    print("   Recommendation: Collect 10-20 more hours for production use")
    print("   Proceeding with POC training...\n")


# Data Augmentation (Critical for Small Datasets)

In [None]:
# NOTE(review): repeats the apt-get update / ffmpeg install from the setup cells at the top of
# the notebook; only needed if this section is run on its own.
!apt-get update -qq && apt-get install -y ffmpeg

In [None]:
def augment_audio(audio, sr):
    """
    Randomly perturb a waveform to add diversity to a small training set.

    Transforms, each with its own probability of firing:
      * AddGaussianNoise, amplitude 0.001-0.015 (p=0.5)
      * TimeStretch, rate 0.9-1.1 (p=0.5)
      * PitchShift, -2..+2 semitones (p=0.5)
      * Shift, min/max -0.5/0.5 (p=0.3)

    Args:
        audio: waveform samples (loaded with librosa elsewhere in this notebook).
        sr: sample rate of `audio` in Hz.

    Returns:
        The samples produced by audiomentations' Compose.
    """
    noise = AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5)
    stretch = TimeStretch(min_rate=0.9, max_rate=1.1, p=0.5)
    pitch = PitchShift(min_semitones=-2, max_semitones=2, p=0.5)
    shift = Shift(min_shift=-0.5, max_shift=0.5, p=0.3)
    return Compose([noise, stretch, pitch, shift])(samples=audio, sample_rate=sr)


def _augment_segments(segments, augmentation_factor, aug_dir):
    """Return (augmentation_factor - 1) augmented copies of each segment entry, writing their audio to aug_dir."""
    if augmentation_factor <= 1:
        return []

    sample_rate = 16000
    os.makedirs(aug_dir, exist_ok=True)
    # Decode each segment once and reuse the waveform for every augmentation round
    waveforms = [librosa.load(seg['audio_path'], sr=sample_rate)[0] for seg in segments]

    copies = []
    for aug_idx in range(augmentation_factor - 1):
        for seg, audio in zip(segments, waveforms):
            aug_audio = augment_audio(audio, sample_rate)

            aug_filename = f"aug_{aug_idx}_{Path(seg['audio_path']).name}"
            aug_path = os.path.join(aug_dir, aug_filename)
            sf.write(aug_path, aug_audio, sample_rate)

            # Same metadata/transcription as the source segment, pointing at the new audio
            copies.append({
                **seg,
                'audio_path': aug_path,
                'category': f"{seg['category']}_aug",
                'filename': aug_filename,
                'duration': len(aug_audio) / sample_rate
            })
    return copies


def create_augmented_dataset(df, augmentation_factor=2, aug_dir="/kaggle/working/augmented"):
    """
    Create augmented copies of dataset with audio chunking for long files.

    Rows longer than 30 s are first split with chunk_long_audio and each chunk
    becomes its own row. Every resulting segment (chunk or short clip) is kept
    as-is and additionally gets (augmentation_factor - 1) augmented copies.
    Copies have category '<category>_aug' and filename 'aug_<k>_<segment filename>'.

    Args:
        df: manifest DataFrame with 'audio_path', 'transcription', 'category',
            'filename' and 'duration' columns.
        augmentation_factor: total number of versions per segment (1 = no augmentation).
        aug_dir: directory augmented audio files are written to.

    Returns:
        DataFrame of original segments followed by their augmented copies, row by row.
    """
    augmented_entries = []

    print(f"üîÑ Creating {augmentation_factor}x augmented dataset with chunking...")

    for _, row in tqdm(df.iterrows(), total=len(df)):
        row_entry = row.to_dict()

        # Check if audio needs chunking (>30s)
        if row['duration'] > 30.0:
            print(f"\n‚ö†Ô∏è  Long audio detected: {row['filename']} ({row['duration']:.1f}s)")
            chunks = chunk_long_audio(row['audio_path'], row['transcription'], max_duration=30.0)
            segments = [
                {
                    **row_entry,
                    'audio_path': chunk['audio_path'],
                    'transcription': chunk['transcription'],
                    'filename': Path(chunk['audio_path']).name,
                    'duration': librosa.get_duration(path=chunk['audio_path'])
                }
                for chunk in chunks
            ]
        else:
            segments = [row_entry]

        augmented_entries.extend(segments)
        augmented_entries.extend(_augment_segments(segments, augmentation_factor, aug_dir))

    return pd.DataFrame(augmented_entries)

In [None]:
# Chunk long recordings and add one augmented copy per segment (augmentation_factor=2)
df_augmented = create_augmented_dataset(df, augmentation_factor=2)

# base_category drops the '_aug' marker so augmented rows are grouped with their source category
df_augmented['base_category'] = df_augmented['category'].str.replace('_aug', '')

# Sanity check before the stratified split: every base category should have at least 2 rows
base_counts = df_augmented['base_category'].value_counts()
print("üìä Base Category Counts:")
print(base_counts)
print("\nBase Categories with <2 samples:", base_counts[base_counts < 2].index.tolist())

In [None]:
# Number of rows after chunking + augmentation (shown as the cell output)
len(df_augmented)

# Train / validation / test split

In [None]:
from sklearn.model_selection import train_test_split

# 80% train, 10% val, 10% test, split by *source clip* rather than by row.
# create_augmented_dataset names each augmented copy 'aug_<k>_<source filename>'. A row-level
# split can therefore put a clip in val/test and its augmented copy (same transcription,
# near-identical audio) in train, which makes val/test WER optimistic. Stripping that prefix
# (and prefixing the base category, in case two folders share a filename) gives one id per
# clean clip, and all rows of a clip land in the same split.
source_ids = (
    df_augmented['base_category'] + '/'
    + df_augmented['filename'].str.replace(r'^aug_\d+_', '', regex=True)
)
sources = pd.DataFrame({
    'source_id': source_ids,
    'base_category': df_augmented['base_category'],
}).drop_duplicates('source_id')


def _split_sources(frame, test_size, seed=42):
    """Split source ids stratified by base_category, falling back to a random split if a category is too small."""
    try:
        return train_test_split(frame, test_size=test_size, random_state=seed,
                                stratify=frame['base_category'])
    except ValueError as err:
        print(f"   Stratified split not possible ({err}); using a random split instead.")
        return train_test_split(frame, test_size=test_size, random_state=seed)


train_src, temp_src = _split_sources(sources, test_size=0.2)
val_src, test_src = _split_sources(temp_src, test_size=0.5)

train_df = df_augmented[source_ids.isin(train_src['source_id'])]
val_df = df_augmented[source_ids.isin(val_src['source_id'])]
test_df = df_augmented[source_ids.isin(test_src['source_id'])]
assert len(train_df) + len(val_df) + len(test_df) == len(df_augmented)

print(f"üìä Dataset Split:")
print(f"   Train: {len(train_df)} samples ({len(train_df)/len(df_augmented)*100:.1f}%)")
print(f"   Val:   {len(val_df)} samples ({len(val_df)/len(df_augmented)*100:.1f}%)")
print(f"   Test:  {len(test_df)} samples ({len(test_df)/len(df_augmented)*100:.1f}%)\n")

# Loading the model

In [None]:
print("üì• Loading Whisper Large-v3...\n")

model_name = "openai/whisper-large-v3"

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
tokenizer = WhisperTokenizer.from_pretrained(model_name, language="English", task="transcribe")
processor = WhisperProcessor.from_pretrained(model_name, language="English", task="transcribe")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True  # Optional: nested quantization for extra memory savings
)

model = WhisperForConditionalGeneration.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto"
)


model.generation_config.language = "english"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

print("‚úÖ Model loaded!\n")

# Applying the LoRA configuration

In [None]:
print("üîß Applying LoRA configuration...\n")
# Prepare model for LoRA training
model = prepare_model_for_kbit_training(model)
# Optionally freeze the encoder to focus LoRA on decoder only (for efficiency)
model.model.encoder.requires_grad_(False)
# LoRA Configuration for Whisper
lora_config = LoraConfig(
    r=32,                          # LoRA rank (higher = more parameters, better quality)
    lora_alpha=64,                 # LoRA scaling factor
    target_modules=[               # Use simple suffixes to match all relevant layers (applies to both encoder/decoder)
        "q_proj",
        "v_proj",
        "k_proj",
        "out_proj",
        "fc1",
        "fc2"
    ],
    lora_dropout=0.05,             # Dropout for LoRA layers
    bias="none",                   # Don't train biases
    # task_type="SEQ_2_SEQ_LM"       # Critical fix: Use seq2seq for Whisper
)
# Apply LoRA to model
model = get_peft_model(model, lora_config)
# Print trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"{'='*60}")
print(f"üìä LoRA Model Statistics:")
print(f"   Trainable params: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")
print(f"   Total params: {total_params:,}")
print(f"   Memory reduction: ~{100 - (trainable_params/total_params*100):.1f}%")
print(f"{'='*60}\n")

# Prepare Dataset for Training


In [None]:
def prepare_dataset_entry(batch, max_label_length=448):
    """
    Turn one manifest row into Whisper training features.

    Loads the clip at 16 kHz and computes `input_features` with the processor's
    feature extractor. It then tokenizes the transcription into `labels`. Labels
    longer than `max_label_length` tokens (448, Whisper's max token length) are
    truncated, and a warning reports the original, untruncated length.

    Args:
        batch: a single example dict with 'audio_path' and 'transcription'
            (datasets.map is called without batched=True).
        max_label_length: maximum number of label tokens to keep.

    Returns:
        The same dict with 'input_features' and 'labels' added.
    """
    audio, _ = librosa.load(batch["audio_path"], sr=16000)

    batch["input_features"] = processor.feature_extractor(
        audio, sampling_rate=16000
    ).input_features[0]

    # Tokenize without truncation first, so the warning can report the real length
    labels = tokenizer(batch["transcription"]).input_ids
    if len(labels) > max_label_length:
        print(f"‚ö†Ô∏è  Truncated labels for: {batch.get('audio_path', 'unknown')} (originally {len(labels)} tokens)")
        # Re-tokenize with truncation (rather than slicing) so the tokenizer keeps its special tokens
        labels = tokenizer(
            batch["transcription"],
            truncation=True,
            max_length=max_label_length
        ).input_ids

    batch["labels"] = labels
    return batch


def df_to_dataset(df):
    """Wrap the manifest columns needed for feature extraction in a Hugging Face `Dataset`."""
    wanted_columns = ("audio_path", "transcription")
    return Dataset.from_dict({column: df[column].tolist() for column in wanted_columns})

print("üîÑ Processing dataset...")
train_dataset = df_to_dataset(train_df).map(prepare_dataset_entry, remove_columns=["audio_path", "transcription"])
val_dataset = df_to_dataset(val_df).map(prepare_dataset_entry, remove_columns=["audio_path", "transcription"])
test_dataset = df_to_dataset(test_df).map(prepare_dataset_entry, remove_columns=["audio_path", "transcription"])

print("‚úÖ Dataset ready!\n")


# Data collator

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Batch processed Whisper examples for the Seq2SeqTrainer.

    Audio features and label ids are padded separately by the processor's
    feature extractor and tokenizer. Label padding is replaced with -100 so the
    loss ignores it. If every label sequence in the batch starts with
    `decoder_start_token_id`, that leading column is dropped (presumably because
    the model re-adds it when building decoder inputs).
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        padded = self.processor.feature_extractor.pad(
            [{"input_features": example["input_features"]} for example in features],
            return_tensors="pt",
        )

        label_batch = self.processor.tokenizer.pad(
            [{"input_ids": example["labels"]} for example in features],
            return_tensors="pt",
        )
        is_padding = label_batch.attention_mask.ne(1)
        label_ids = label_batch["input_ids"].masked_fill(is_padding, -100)

        all_start_with_bos = (label_ids[:, 0] == self.decoder_start_token_id).all().cpu().item()
        if all_start_with_bos:
            label_ids = label_ids[:, 1:]

        padded["labels"] = label_ids
        return padded

# Collator instance used by the Trainer; decoder_start_token_id comes from the model config so
# the collator can detect and strip a leading start token from the labels.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)


# Evaluation metrics

In [None]:
# Metric computation
metric = evaluate.load("wer")

def compute_metrics(pred):
    """
    Compute word error rate (in %) for a Seq2SeqTrainer evaluation.

    `pred.predictions` holds generated token ids and `pred.label_ids` the
    reference ids. Both may contain -100: the collator masks label padding with
    -100, and generated ids can also carry -100 padding when the Trainer
    concatenates eval batches of different lengths. -100 is mapped to the pad
    token before decoding, without modifying the arrays passed in.

    Pairs whose reference decodes to an empty string are skipped, because WER is
    undefined for an empty reference (jiwer rejects it). If nothing remains,
    {"wer": nan} is returned.
    """
    pad_id = tokenizer.pad_token_id
    # np.where builds new arrays, so the Trainer's prediction/label arrays are left untouched
    pred_ids = np.where(pred.predictions != -100, pred.predictions, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    pairs = [(hyp, ref) for hyp, ref in zip(pred_str, label_str) if ref.strip()]
    if not pairs:
        return {"wer": float("nan")}
    predictions, references = map(list, zip(*pairs))

    wer = 100 * metric.compute(predictions=predictions, references=references)

    return {"wer": wer}

# Training Configuration (OPTIMIZED FOR LORA)


In [None]:

# Enable gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable()
# Disable the decoder KV cache while training with gradient checkpointing
model.config.use_cache = False

# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-large-v3-malaysian-lora",
    per_device_train_batch_size=4,      # Reduce if OOM
    gradient_accumulation_steps=4,       # Effective batch = 4 x 4 = 16 per device
    learning_rate=1e-3,
    warmup_steps=50,
    max_steps=500,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=4,
    predict_with_generate=True,          # compute_metrics decodes generated token ids
    generation_max_length=225,
    save_steps=100,                      # Kept equal to eval_steps for load_best_model_at_end
    eval_steps=100,
    logging_steps=10,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",         # Reads 'eval_wer' produced by compute_metrics
    greater_is_better=False,             # Lower WER is better
    push_to_hub=False,
    remove_unused_columns=False,         # PeftModel.forward takes *args/**kwargs, so the Trainer can't infer the used columns
    label_names=["labels"],              # Same reason: name the label column explicitly. Otherwise the Trainer may treat
                                         # batches as unlabeled and never call compute_metrics, leaving no 'eval_wer'
                                         # for metric_for_best_model.
    dataloader_num_workers=2,
)


# Initialize trainer

In [None]:
# Sanity-check one processed training example before launching training
print("Checking dataset structure:")
print("Train sample keys:", train_dataset[0].keys())
print("Input features shape:", np.array(train_dataset[0]["input_features"]).shape)
print("Labels sample:", train_dataset[0]["labels"][:10])

# Run the collator on the first two training examples to check the padded shapes
# (assumes train_dataset has at least 2 rows)
print("\nTesting data collator:")
sample_batch = [train_dataset[0], train_dataset[1]]

collated = data_collator(sample_batch)
print("Collated batch keys:", collated.keys())
print("Input features shape:", collated["input_features"].shape)
print("Labels shape:", collated["labels"].shape)

# The feature extractor (not the tokenizer) is deliberately passed as the Trainer's `tokenizer`.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of `processing_class=`;
# transformers is installed unpinned at the top of the notebook, so confirm against the installed version.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

print("üöÄ Starting LoRA training...\n")
print("="*60)

# Release cached, unused GPU memory before training
torch.cuda.empty_cache()

trainer.train()

print("\n" + "="*60)
print("‚úÖ Training complete!")
print("="*60)

# Evaluation on the test set

In [None]:
print("\nüß™ Evaluating on test set...")
test_results = trainer.evaluate(test_dataset)

print(f"\n{'='*60}")
print(f"üìä Final Test Results:")
print(f"   WER: {test_results['eval_wer']:.2f}%")
print(f"{'='*60}\n")


In [None]:
# STEP 14: Save LoRA Adapters (IMPORTANT!)
# ============================================================================

print("üíæ Saving LoRA adapters...")

# `model` is the PeftModel returned by get_peft_model, so this saves the LoRA adapter
# weights/config rather than the full base model.
model.save_pretrained("./whisper-large-v3-lora-adapters")
# Store the processor alongside the adapters so the directory is self-contained for inference
processor.save_pretrained("./whisper-large-v3-lora-adapters")

# NOTE(review): with r=32 on q/k/v/out_proj/fc1/fc2 in both encoder and decoder, the adapter
# may be considerably larger than the size quoted in this message; check the actual file size.
print("‚úÖ LoRA adapters saved! (~10-50MB instead of 3GB)")


In [None]:
# STEP 15: Merge LoRA weights with base model (Optional)
# ============================================================================

print("\nüîÄ Merging LoRA weights into full model...")

from peft import PeftModel

# The trained `model` sits on a 4-bit quantized base. Merging into those weights would
# re-quantize the result (adding rounding error) and save a bitsandbytes-quantized checkpoint.
# merge_and_unload() would also strip the adapters from the live `model` in place.
# Instead, reload the base in fp16 (on CPU, leaving GPU memory to the trained model), attach
# the adapters written by the adapter-saving cell, and merge there.
base_model = WhisperForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
)
merged_model = PeftModel.from_pretrained(
    base_model, "./whisper-large-v3-lora-adapters"
).merge_and_unload()

# Carry over the generation settings used during training; the freshly loaded base does not have them
merged_model.generation_config.language = "english"
merged_model.generation_config.task = "transcribe"
merged_model.generation_config.forced_decoder_ids = None

merged_model.save_pretrained("./whisper-large-v3-malaysian-merged")
processor.save_pretrained("./whisper-large-v3-malaysian-merged")

print("‚úÖ Merged model saved!")


In [None]:
print("\nüéØ Testing predictions...\n")

from transformers import pipeline

pipe = pipeline(
    "automatic-speech-recognition",
    model="./whisper-large-v3-malaysian-merged",
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=0 if torch.cuda.is_available() else -1
)

test_samples = test_df.sample(min(5, len(test_df)))

for idx, row in test_samples.iterrows():
    prediction = pipe(row['audio_path'])["text"]
    print(f"File: {row['filename']}")
    print(f"Expected: {row['transcription']}")
    print(f"Predicted: {prediction}")
    print(f"-" * 40)

print("\n‚úÖ LoRA fine-tuning complete!")
print("\nüì¶ You now have:")
print("   1. LoRA adapters: ./whisper-large-v3-lora-adapters (~10-50MB)")
print("   2. Merged model: ./whisper-large-v3-malaysian-merged (~3GB)")
