# Wav2Vec2-XLSR Model Evaluation - Google Colab Standalone

This notebook evaluates **Vietnamese Wav2Vec2 CTC models** (one XLSR-53 fine-tune and one base model) on Vietnamese datasets.

## Models Evaluated:
- anuragshas/wav2vec2-large-xlsr-53-vietnamese
- nguyenvulebinh/wav2vec2-base-vietnamese-250h

## Features:
- Complete standalone execution (no external files needed)
- Downloads datasets from HuggingFace automatically
- Evaluates on each dataset's test split (full data by default; VLSP2020, which publishes only a train split, is re-split 70/15/15 with a fixed seed)
- Calculates comprehensive metrics (WER, CER, MER, WIL, WIP, SER, RTF)
- Generates visualizations
- Exports results as CSV for cross-model comparison

**Runtime**: GPU recommended (T4 or better)

## Step 1: Install Dependencies

In [None]:
# Install the packages used by this notebook (`!pip` is IPython/Colab shell syntax).
print('[SETUP] Installing required packages...')
!pip install -q transformers==4.57.1 torch torchcodec torchaudio librosa soundfile jiwer datasets accelerate pandas matplotlib seaborn scipy tqdm
# NOTE(review): printed unconditionally - this cell does not check pip's exit status.
print('[OK] All packages installed successfully!')

# Check GPU availability
import torch
print(f'\n[INFO] CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'[INFO] GPU: {torch.cuda.get_device_name(0)}')
else:
    print('[WARNING] Running on CPU - evaluation will be slower')

## Step 2: Embedded Helper Functions

In [None]:
# Embedded Vietnamese ASR Evaluation Code

import time
import warnings
import numpy as np
import pandas as pd
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Optional
from tqdm.auto import tqdm

# Silence every Python warning for the rest of the session to keep notebook
# output readable. NOTE(review): this also hides library deprecation notices.
warnings.filterwarnings('ignore')

# METRICS MODULE
from jiwer import wer, cer, mer, wil, wip, process_words

class ASRMetrics:
    """Corpus-level ASR error metrics computed with jiwer."""

    @staticmethod
    def calculate_all_metrics(references: List[str], hypotheses: List[str]) -> Dict:
        """Compute corpus-level error metrics for paired transcripts.

        Each (reference, hypothesis) pair is aligned on its own and the edit
        counts are summed over the corpus. Concatenating all utterances into
        one string would let the aligner match words across utterance
        boundaries (which can only under-count errors) and would add one
        separator character per utterance to the CER denominator.

        Args:
            references: Ground-truth transcripts, one per utterance.
            hypotheses: Model transcripts, index-aligned with ``references``.

        Returns:
            Dict with ``wer``, ``cer``, ``mer``, ``wil``, ``wip``, ``ser``
            (fraction of utterances whose hypothesis differs from the
            reference) and the summed ``insertions``, ``deletions`` and
            ``substitutions`` word counts.

        Raises:
            ValueError: if the lists are empty, differ in length, or every
                reference is blank.
        """
        if len(references) != len(hypotheses):
            raise ValueError(
                f"got {len(references)} references but {len(hypotheses)} hypotheses")
        if not references:
            raise ValueError("cannot compute metrics for an empty corpus")

        # Sentence error rate is defined over every utterance.
        ser = sum(r != h for r, h in zip(references, hypotheses)) / len(references)

        # A blank reference has no words to score, and some jiwer versions
        # reject empty reference strings outright, so such utterances are
        # left out of the alignment-based metrics.
        scored = [(r, h) for r, h in zip(references, hypotheses) if r.strip()]
        if not scored:
            raise ValueError("all references are empty; nothing to score")
        refs = [r for r, _ in scored]
        hyps = [h for _, h in scored]

        # A single word alignment yields WER/MER/WIL/WIP and the edit counts.
        words = process_words(refs, hyps)
        return {
            'wer': words.wer,
            'cer': cer(refs, hyps),
            'mer': words.mer,
            'wil': words.wil,
            'wip': words.wip,
            'ser': ser,
            'insertions': words.insertions,
            'deletions': words.deletions,
            'substitutions': words.substitutions,
        }

class RTFTimer:
    """Context manager that measures how long its ``with`` body takes.

    ``elapsed_time`` (seconds) is set when the block exits, including when it
    exits through an exception; exceptions are never suppressed.
    """

    def __init__(self):
        self.start_time = None    # perf_counter() reading taken on entry
        self.elapsed_time = None  # seconds spent inside the with-block

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the system clock, which can jump and is coarser on some platforms,
        # distorting short per-utterance timings.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed_time = time.perf_counter() - self.start_time
        return False  # let any exception from the timed block propagate

# DATASET LOADER MODULE
from datasets import load_dataset, Audio
import soundfile as sf
import tempfile

@dataclass
class AudioSample:
    """One utterance stored as a local audio file together with its reference text."""

    audio_path: str  # path of the WAV file (written by load_huggingface_dataset in this notebook)
    transcription: str  # reference text; the loader strips and lower-cases it
    duration: float = 0.0  # audio length in seconds (loader: samples / sampling rate)
    sample_rate: int = 16000  # sampling rate of the stored audio, in Hz
    dataset: str = ''  # short dataset name, e.g. 'ViMD'
    split: str = ''  # split name as published on the Hub (e.g. 'valid', 'dev'), not the canonical key

# Hub location, split names and column names of every supported dataset.
_DATASET_CONFIGS = {
    'ViMD': {'id': 'nguyendv02/ViMD_Dataset', 'splits': ['train', 'test', 'valid'],
             'audio_col': 'audio', 'text_col': 'text'},
    'BUD500': {'id': 'linhtran92/viet_bud500', 'splits': ['train', 'validation', 'test'],
               'audio_col': 'audio', 'text_col': 'transcription'},
    'LSVSC': {'id': 'doof-ferb/LSVSC', 'splits': ['train', 'validation', 'test'],
              'audio_col': 'audio', 'text_col': 'transcription'},
    'VLSP2020': {'id': 'doof-ferb/vlsp2020_vinai_100h', 'splits': ['train'],
                 'audio_col': 'audio', 'text_col': 'transcription'},
    'VietMed': {'id': 'leduckhai/VietMed', 'splits': ['train', 'test', 'dev'],
                'audio_col': 'audio', 'text_col': 'text'},
}

# Hub split names mapped onto the canonical keys returned by the loader.
_SPLIT_ALIASES = {
    'train': 'train', 'training': 'train',
    'val': 'val', 'validation': 'val', 'dev': 'val', 'valid': 'val',
    'test': 'test', 'testing': 'test',
}


def load_huggingface_dataset(dataset_name: str, max_samples: Optional[int] = None) -> Dict[str, List['AudioSample']]:
    """Download a Vietnamese ASR dataset and export it as local WAV files.

    Every Hub split listed for the dataset is loaded, resampled to 16 kHz,
    written to ``<tempdir>/asr_audio/<dataset_name>/<split>_<idx>.wav`` and
    wrapped in an AudioSample. Hub split names are mapped onto the canonical
    ``train``/``val``/``test`` keys. VLSP2020 publishes only a train split, so
    it is re-split 70/15/15 with a fixed seed.

    Args:
        dataset_name: A key of ``_DATASET_CONFIGS`` (e.g. 'ViMD').
        max_samples: If set, keep only the first ``max_samples`` rows of each
            Hub split (for quick runs); ``None`` uses the full split.

    Returns:
        ``{'train': [...], 'val': [...], 'test': [...]}`` lists of AudioSample.
        A split that fails to load is reported and contributes no samples.

    Raises:
        KeyError: if ``dataset_name`` is not a known dataset.
    """
    try:
        config = _DATASET_CONFIGS[dataset_name]
    except KeyError:
        raise KeyError(
            f"unknown dataset {dataset_name!r}; expected one of {sorted(_DATASET_CONFIGS)}"
        ) from None
    print(f"[INFO] Loading {dataset_name} from HuggingFace Hub...")

    samples_by_split = {'train': [], 'val': [], 'test': []}
    temp_dir = Path(tempfile.gettempdir()) / 'asr_audio' / dataset_name
    temp_dir.mkdir(parents=True, exist_ok=True)

    all_samples = []
    for split in config['splits']:
        try:
            print(f"  Loading {split} split...")
            samples = _export_split(dataset_name, config, split, temp_dir, max_samples)
        except Exception as e:
            # Best effort: a missing or broken split must not abort the others.
            print(f"  [WARNING] Failed to load {split}: {e}")
            continue

        canonical = _SPLIT_ALIASES.get(split)
        if canonical is not None:
            samples_by_split[canonical].extend(samples)
        all_samples.extend(samples)
        print(f"  [OK] Loaded {len(samples)} samples from {split}")

    # VLSP2020 has only a train split, so carve val/test out of it.
    if dataset_name == 'VLSP2020' and all_samples:
        print(f"  [INFO] VLSP2020 only has train split - creating train/val/test splits (70/15/15)")
        samples_by_split = _random_split(all_samples, seed=42)
        print(f"  [OK] Split into: train={len(samples_by_split['train'])}, "
              f"val={len(samples_by_split['val'])}, test={len(samples_by_split['test'])}")

    return samples_by_split


def _export_split(dataset_name: str, config: Dict, split: str, temp_dir: Path,
                  max_samples: Optional[int]) -> List['AudioSample']:
    """Load one Hub split, write each row to a WAV file and return its AudioSamples."""
    dataset = load_dataset(config['id'], split=split, trust_remote_code=True)

    if config['audio_col'] in dataset.column_names:
        dataset = dataset.cast_column(config['audio_col'], Audio(sampling_rate=16000))

    # Limit only if specified (for quick testing)
    if max_samples and len(dataset) > max_samples:
        dataset = dataset.select(range(max_samples))

    samples = []
    skipped_no_text = 0
    skipped_errors = 0
    first_error = None
    for idx, item in enumerate(tqdm(dataset, desc=f"{split}", leave=False)):
        try:
            raw_text = item[config['text_col']]
            # None would otherwise become the literal reference "none", and a
            # blank reference cannot be scored, so both rows are skipped.
            text = '' if raw_text is None else str(raw_text).strip().lower()
            if not text:
                skipped_no_text += 1
                continue

            audio_data = item[config['audio_col']]
            # Read each field once; decoded audio objects may re-decode on access.
            array = audio_data['array']
            sampling_rate = audio_data['sampling_rate']
            audio_path = str(temp_dir / f"{split}_{idx}.wav")
            sf.write(audio_path, array, sampling_rate)

            samples.append(AudioSample(
                audio_path=audio_path,
                transcription=text,
                duration=len(array) / sampling_rate,
                sample_rate=sampling_rate,
                dataset=dataset_name,
                split=split,
            ))
        except Exception as e:  # best effort: one bad row must not sink the split
            skipped_errors += 1
            if first_error is None:
                first_error = e

    if skipped_no_text:
        print(f"  [WARNING] Skipped {skipped_no_text} {split} rows with an empty transcription")
    if skipped_errors:
        print(f"  [WARNING] Skipped {skipped_errors} {split} rows that failed to export "
              f"(first error: {first_error!r})")
    return samples


def _random_split(samples: List['AudioSample'], seed: int = 42, train_cut: float = 0.7,
                  val_cut: float = 0.85) -> Dict[str, List['AudioSample']]:
    """Shuffle samples reproducibly and cut them into train/val/test lists."""
    # A local RandomState(seed) produces the same permutation as
    # np.random.seed(seed); np.random.permutation(...), without reseeding the
    # global NumPy RNG as a side effect.
    order = np.random.RandomState(seed).permutation(len(samples))
    train_end = int(len(samples) * train_cut)
    val_end = int(len(samples) * val_cut)
    return {
        'train': [samples[i] for i in order[:train_end]],
        'val': [samples[i] for i in order[train_end:val_end]],
        'test': [samples[i] for i in order[val_end:]],
    }

# MODEL EVALUATOR MODULE
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import librosa

class Wav2Vec2Model:
    """Greedy (argmax) CTC transcriber wrapping a Hugging Face Wav2Vec2ForCTC checkpoint."""

    # Audio is resampled to this rate (Hz) before being fed to the processor.
    SAMPLING_RATE = 16000

    def __init__(self, model_id: str, device: Optional[str] = None):
        """Store the configuration; weights are loaded by ``load_model``.

        Args:
            model_id: Hub id (or local path) of a Wav2Vec2ForCTC checkpoint.
            device: Torch device string. Defaults to 'cuda' when available,
                otherwise 'cpu'.
        """
        self.model_id = model_id
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.processor = None
        self.model = None

    def load_model(self):
        """Download/load the processor and model and move the model to ``self.device``.

        Raises:
            Exception: whatever ``from_pretrained`` raises, after logging it.
        """
        print(f"[INFO] Loading {self.model_id}...")
        try:
            self.processor = Wav2Vec2Processor.from_pretrained(self.model_id)
            self.model = Wav2Vec2ForCTC.from_pretrained(self.model_id)
            self.model.to(self.device)
            self.model.eval()
            print(f"[OK] Model loaded on {self.device}")
        except Exception as e:
            print(f"[ERROR] Failed to load model: {e}")
            raise

    def transcribe(self, audio_path: str) -> str:
        """Transcribe one audio file with greedy CTC decoding.

        Args:
            audio_path: Path of an audio file readable by librosa.

        Returns:
            The stripped, lower-cased transcription, or "" if reading or
            decoding this file fails (the error is printed).

        Raises:
            Exception: if the model has not been loaded and loading it fails.
        """
        if self.model is None or self.processor is None:
            # Load lazily so a missing load_model() call fails loudly instead
            # of silently returning "" for every file.
            self.load_model()
        try:
            audio, _ = librosa.load(audio_path, sr=self.SAMPLING_RATE)
            inputs = self.processor(audio, sampling_rate=self.SAMPLING_RATE, return_tensors="pt")
            input_values = inputs.input_values.to(self.device)

            with torch.inference_mode():
                logits = self.model(input_values).logits
                predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.processor.batch_decode(predicted_ids)[0]
            return transcription.strip().lower()
        except Exception as e:
            # Best effort per file: a failure counts as an empty hypothesis.
            print(f"[ERROR] Transcription failed for {audio_path}: {e}")
            return ""

# Confirmation message; only reached when every definition in this cell succeeded.
print("[OK] All helper functions loaded successfully!")

## Step 3: Configuration

In [None]:
from datetime import datetime

MODELS_TO_TEST = [
    'anuragshas/wav2vec2-large-xlsr-53-vietnamese',
    'nguyenvulebinh/wav2vec2-base-vietnamese-250h',
]

# Uncomment datasets to include them in the run.
DATASETS_TO_TEST = [
    'ViMD',
    # 'BUD500',
    # 'LSVSC',
    # 'VLSP2020',
    # 'VietMed',
]

# Set to None to use FULL datasets, or set a number for quick testing
MAX_SAMPLES_PER_SPLIT = None  # None = use all data

TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
# Write into Colab's working directory when it exists, otherwise into the
# current directory, so saving the results also works outside Colab.
_OUTPUT_DIR = Path('/content') if Path('/content').is_dir() else Path.cwd()
OUTPUT_CSV = str(_OUTPUT_DIR / f"wav2vec2_results_{TIMESTAMP}.csv")

print(f'[CONFIG] Models: {len(MODELS_TO_TEST)}')
print(f'[CONFIG] Datasets: {len(DATASETS_TO_TEST)}')
print(f'[CONFIG] Samples per split: {MAX_SAMPLES_PER_SPLIT or "ALL (full dataset)"}')
print(f'[CONFIG] Output: {OUTPUT_CSV}')

## Step 4: Load Datasets

In [None]:
# Load every configured dataset. A dataset that fails still gets an entry with
# empty splits, so later cells see every configured name.
datasets_loaded = {}

for dataset_name in DATASETS_TO_TEST:
    print(f"\n{'='*60}")
    print(f"Loading dataset: {dataset_name}")
    print(f"{'='*60}")

    try:
        splits = load_huggingface_dataset(dataset_name, max_samples=MAX_SAMPLES_PER_SPLIT)
    except Exception as e:
        print(f"[ERROR] Failed to load {dataset_name}: {e}")
        datasets_loaded[dataset_name] = {'train': [], 'val': [], 'test': []}
        continue

    datasets_loaded[dataset_name] = splits
    print(f"\n[OK] {dataset_name} loaded:")
    print(f"  - Train: {len(splits['train'])} samples")
    print(f"  - Val: {len(splits['val'])} samples")
    print(f"  - Test: {len(splits['test'])} samples")

# Count only datasets that actually yielded samples; failed ones are stored as
# empty placeholders and must not be reported as loaded.
n_loaded = sum(1 for s in datasets_loaded.values() if any(s.values()))
no_test = [name for name, s in datasets_loaded.items() if not s['test']]
print(f"\n[OK] Successfully loaded {n_loaded}/{len(DATASETS_TO_TEST)} datasets")
if no_test:
    print(f"[WARNING] No test samples (will be skipped in evaluation): {', '.join(no_test)}")

## Step 5: Run Evaluation

In [None]:
# Evaluate every model on every loaded dataset's test split, collecting one
# result row per (model, dataset) pair in `results`.
results = []
metrics_calculator = ASRMetrics()
total_start_time = time.time()

for model_id in MODELS_TO_TEST:
    print(f"\n\n{'='*70}")
    print(f"EVALUATING MODEL: {model_id}")
    print(f"{'='*70}\n")

    model = None
    try:
        model = Wav2Vec2Model(model_id)
        model.load_model()
    except Exception as e:
        print(f"[ERROR] Skipping {model_id}: {e}")
        model = None  # drop any partially loaded weights before the next model
        torch.cuda.empty_cache()
        continue

    for dataset_name, splits in datasets_loaded.items():
        test_samples = splits['test']

        if not test_samples:
            print(f"[WARNING] No test samples for {dataset_name}, skipping...")
            continue

        print(f"\n[INFO] Evaluating on {dataset_name} ({len(test_samples)} samples)...")

        references = []
        hypotheses = []
        audio_durations = []
        processing_times = []
        n_failed = 0

        for sample in tqdm(test_samples, desc=f"{dataset_name}", leave=False):
            # transcribe() returns "" for files it cannot decode (scored as
            # all deletions); anything that still escapes it skips the sample.
            try:
                with RTFTimer() as timer:
                    hypothesis = model.transcribe(sample.audio_path)
            except Exception:
                n_failed += 1
                continue

            references.append(sample.transcription)
            hypotheses.append(hypothesis)
            audio_durations.append(sample.duration)
            processing_times.append(timer.elapsed_time)

        if n_failed:
            print(f"  [WARNING] {n_failed} samples raised during transcription and were skipped")
        if not references:
            print(f"  [WARNING] No samples transcribed for {dataset_name}, no metrics recorded")
            continue

        # A metric failure for one dataset must not discard the whole run.
        try:
            metrics = metrics_calculator.calculate_all_metrics(references, hypotheses)
        except Exception as e:
            print(f"  [ERROR] Metric computation failed for {dataset_name}: {e}")
            continue

        total_audio_duration = sum(audio_durations)
        total_processing_time = sum(processing_times)
        # Real-time factor: processing seconds per second of audio.
        rtf = total_processing_time / total_audio_duration if total_audio_duration > 0 else 0

        result = {
            'model': model_id,
            'dataset': dataset_name,
            'samples_processed': len(references),
            'WER': metrics['wer'],
            'CER': metrics['cer'],
            'MER': metrics['mer'],
            'WIL': metrics['wil'],
            'WIP': metrics['wip'],
            'SER': metrics['ser'],
            'RTF': rtf,
            'insertions': metrics['insertions'],
            'deletions': metrics['deletions'],
            'substitutions': metrics['substitutions'],
            'total_audio_duration': total_audio_duration,
            'total_processing_time': total_processing_time
        }
        results.append(result)
        print(f"  [OK] WER: {metrics['wer']:.4f} | CER: {metrics['cer']:.4f} | RTF: {rtf:.4f}")

    # Free this model's memory before loading the next one.
    del model
    torch.cuda.empty_cache()

total_time = time.time() - total_start_time
print(f"\n\n{'='*70}")
print(f"EVALUATION COMPLETE! Time: {total_time/60:.2f} min")
print(f"{'='*70}")

## Step 6-8: Results, Visualizations & Export

(Summary table and CSV export, following the same procedure as notebooks 01–02.)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Summarise, export and plot the evaluation results.
results_df = pd.DataFrame(results)

if results_df.empty:
    # An empty frame has none of the metric columns, so the table and CSV
    # steps below would fail or produce nothing useful.
    print("[WARNING] No results to report - check the evaluation log above.")
else:
    print(results_df[['model', 'dataset', 'WER', 'CER', 'RTF']].to_string(index=False))

    # Save and download (before plotting, so a plotting error cannot lose results)
    results_df.to_csv(OUTPUT_CSV, index=False)
    print(f"\n[OK] Results saved to: {OUTPUT_CSV}")

    try:
        from google.colab import files
    except ImportError:
        print("[INFO] File saved locally")
    else:
        try:
            files.download(OUTPUT_CSV)
        except Exception as e:  # the CSV is already on disk; report and move on
            print(f"[WARNING] Download failed ({e}); the file is still at {OUTPUT_CSV}")

    # Error rates per model, one bar per dataset.
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    for ax, metric in zip(axes, ['WER', 'CER']):
        sns.barplot(data=results_df, x='model', y=metric, hue='dataset', ax=ax)
        ax.set_title(metric)
        ax.tick_params(axis='x', labelrotation=20)
    plt.tight_layout()
    plt.show()