# [TEMPLATE] Custom ASR Model Evaluation

This is a **reusable template** for evaluating new ASR models on Vietnamese datasets.

## How to Use This Template:

1. **Duplicate this notebook** and rename it (e.g., `06_newmodel_evaluation.ipynb`)
2. **Update the configuration section** (Cell 4):
   - Change `MODEL_FAMILY` name
   - Update `MODELS_TO_TEST` list with your model IDs
   - Adjust dataset list if needed
3. **Run all cells** to perform evaluation

## What This Template Evaluates:

- **Datasets**: ViMD, BUD500, LSVSC, VLSP2020, VietMed (customizable)
- **Metrics**: WER, CER, MER, WIL, WIP, SER, RTF
- **Outputs**: CSV results, JSON reports, comprehensive visualizations
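For reference, every word-level metric in the list above can be derived from one minimum-edit alignment between a reference and a hypothesis. With hits H, substitutions S, deletions D, insertions I, N reference words and P hypothesis words, the definitions are: WER = (S+D+I)/N, MER = (S+D+I)/(S+D+I+H), WIP = (H/N)(H/P), WIL = 1 − WIP, and SER = the share of utterances with at least one word error. CER applies the same alignment to characters. The sketch below is a stdlib-only illustration, not the project's `ASRMetrics` implementation, which may normalize casing or punctuation differently. The names `align_counts` and `corpus_metrics` are hypothetical. The sketch aggregates counts over the whole corpus and assumes non-empty references. It also NFC-normalizes text, because Vietnamese diacritics can be stored composed or decomposed, and a mismatch would otherwise inflate CER.

```python
import unicodedata
from typing import Dict, List, Sequence, Tuple


def align_counts(ref: Sequence[str], hyp: Sequence[str]) -> Tuple[int, int, int, int]:
    """Return (hits, substitutions, deletions, insertions) of a minimum-edit alignment."""
    n, m = len(ref), len(hyp)
    # dp[i][j] = edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                           dp[i][j - 1] + 1,         # insertion
                           dp[i - 1][j - 1] + cost)  # hit or substitution
    # Walk back from the bottom-right cell, counting each operation type
    hits = subs = dels = ins = 0
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1]):
            if ref[i - 1] == hyp[j - 1]:
                hits += 1
            else:
                subs += 1
            i, j = i - 1, j - 1
        elif i > 0 and dp[i][j] == dp[i - 1][j] + 1:
            dels += 1
            i -= 1
        else:
            ins += 1
            j -= 1
    return hits, subs, dels, ins


def corpus_metrics(references: List[str], hypotheses: List[str]) -> Dict[str, float]:
    """Corpus-level WER, CER, MER, WIL, WIP and SER (assumes non-empty references)."""
    # H, S, D, I follow the standard hits/substitutions/deletions/insertions notation
    H = S = D = I = n_ref = n_hyp = char_edits = n_chars = sent_errors = 0
    for ref, hyp in zip(references, hypotheses):
        # Normalize to NFC so composed and decomposed diacritics compare equal
        ref, hyp = unicodedata.normalize("NFC", ref), unicodedata.normalize("NFC", hyp)
        h, s, d, i = align_counts(ref.split(), hyp.split())
        H, S, D, I = H + h, S + s, D + d, I + i
        n_ref, n_hyp = n_ref + len(ref.split()), n_hyp + len(hyp.split())
        sent_errors += (s + d + i) > 0
        _, cs, cd, ci = align_counts(list(ref), list(hyp))
        char_edits += cs + cd + ci
        n_chars += len(ref)
    errors = S + D + I
    wip = (H / n_ref) * (H / n_hyp) if n_hyp else 0.0
    return {
        "wer": errors / n_ref,
        "cer": char_edits / n_chars,
        "mer": errors / (errors + H),
        "wip": wip,
        "wil": 1.0 - wip,
        "ser": sent_errors / len(references),
    }
```

For example, `corpus_metrics(["xin chào các bạn"], ["xin chào bạn"])` contains one deleted word out of four, so it gives a WER of 0.25.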

**Compatible with**: Local & Google Colab  
**Report output**: `docs/reports/` (relative to the project root)

---

## Quick Start:

```python
# In Cell 4, update these variables:
MODEL_FAMILY = "YourModelFamily"  # e.g., "Seamless-M4T", "MMS"
MODELS_TO_TEST = [
    "your-org/model-name-1",
    "your-org/model-name-2",
]
```

## 1. Environment Setup & Dependencies

In [None]:
# Google Colab Setup - Run this cell FIRST if using Colab
import os

# Detect if running on Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("[INFO] Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("[INFO] Running locally - skipping Colab setup")

if IN_COLAB:
    print("\n[SETUP] Setting up Google Colab environment...")
    
    # Clone repository
    REPO_URL = "https://github.com/quangnt03/vietnamese-asr-benchmark.git"
    REPO_NAME = "vietnamese_asr_benchmark"
    
    if not os.path.exists(REPO_NAME):
        print(f"[SETUP] Cloning repository from {REPO_URL}...")
        !git clone {REPO_URL}
        print("[OK] Repository cloned successfully")
    else:
        print(f"[INFO] Repository already exists at {REPO_NAME}")
    
    # Change to repository directory
    os.chdir(REPO_NAME)
    print(f"[OK] Changed directory to: {os.getcwd()}")
    
    # Install dependencies
    print("\n[SETUP] Installing dependencies...")
    !pip install -q -r requirements.txt
    print("[OK] Dependencies installed")
    
    print("\n[OK] Google Colab setup complete!")
    print("[INFO] You can now run the remaining cells")
else:
    print("[INFO] Local environment - ensure you are in the project root directory")

### Environment Detection and Path Setup

In [None]:
# Cell 1: Environment detection and setup
%load_ext autoreload
%autoreload 3
import sys
from pathlib import Path

# Import notebook utilities
try:
    from src.notebook_utils import (
        detect_environment,
        setup_paths,
        install_dependencies,
        print_environment_info,
        ReportGenerator,
        create_evaluation_summary
    )
except ImportError:
    # Running from the notebooks/ directory: add the project root (its parent) to the path
    notebook_dir = Path.cwd()
    if notebook_dir.name == 'notebooks':
        sys.path.insert(0, str(notebook_dir.parent))
    from src.notebook_utils import (
        detect_environment,
        setup_paths,
        install_dependencies,
        print_environment_info,
        ReportGenerator,
        create_evaluation_summary
    )

# Detect environment
ENV = detect_environment()
print(f"[INFO] Running in: {ENV}")

# Install dependencies if needed (mainly for Colab)
install_dependencies(ENV)

# Setup paths
PATHS = setup_paths()
print(f"\n[OK] Project root: {PATHS['project_root']}")
print(f"[OK] Data directory: {PATHS['data_dir']}")
print(f"[OK] Config file: {PATHS['config_file']}")
print(f"[OK] Reports directory: {PATHS['reports_dir']}")
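Cell 1 locates the project root by checking whether the current directory is named `notebooks`. That check breaks if the notebook is opened from any other subdirectory. A more location-independent option is to walk up the directory tree until a marker directory is found. The helper `find_project_root` and the choice of `src` as the marker are assumptions for illustration. `setup_paths()` in `src/notebook_utils` may already handle this differently.

```python
from pathlib import Path
from typing import Optional


def find_project_root(start: Optional[Path] = None, marker: str = "src") -> Path:
    """Walk up from `start` (default: cwd) until a directory containing `marker` is found."""
    current = (start or Path.cwd()).resolve()
    for candidate in [current, *current.parents]:
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(f"No directory at or above {current} contains {marker!r}")
```

In Cell 1 this could replace the directory-name check with `sys.path.insert(0, str(find_project_root()))`.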

In [None]:
# Cell 2: Print environment info
print_environment_info()

In [None]:
# Cell 3: Import project modules
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
from datetime import datetime
import time
from tqdm.auto import tqdm

# Import project modules with proper src. prefix for IDE type hints
from src.dataset_loader import DatasetManager, AudioSample
from src.model_evaluator import ModelEvaluator, ModelFactory
from src.metrics import ASRMetrics, RTFTimer
from src.visualization import ASRVisualizer

print("[OK] All modules imported successfully")

## 2. Configuration

[CUSTOMIZE THIS SECTION] Update the variables below for your model family.

In [None]:
# Cell 4: Configuration
# [CUSTOMIZE] Model configuration
MODEL_FAMILY = "CustomModel"  # CHANGE THIS: e.g., "Seamless-M4T", "MMS", "Canary"
MODELS_TO_TEST = [
    # [CUSTOMIZE] Add your model IDs here
    # Examples:
    # "facebook/seamless-m4t-large",
    # "facebook/mms-1b-all",
    # "your-org/your-model-name",
]

# Dataset configuration
DATASETS_TO_TEST = [
    "ViMD",
    "BUD500",
    "LSVSC",
    "VLSP2020",
    "VietMed"
]

# Evaluation configuration
MAX_SAMPLES_PER_DATASET = None  # None = all samples, or set to e.g., 50 for quick testing
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15

# Output configuration
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
OUTPUT_DIR = PATHS['output_dir'] / f"{MODEL_FAMILY.lower().replace(' ', '_')}_{TIMESTAMP}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"[CONFIG] Model family: {MODEL_FAMILY}")
print(f"[CONFIG] Models to test: {len(MODELS_TO_TEST)}")
print(f"[CONFIG] Datasets to test: {len(DATASETS_TO_TEST)}")
print(f"[CONFIG] Max samples per dataset: {MAX_SAMPLES_PER_DATASET or 'All'}")
print(f"[CONFIG] Output directory: {OUTPUT_DIR}")

# Validation
if not MODELS_TO_TEST:
    print("\n[WARNING] No models specified in MODELS_TO_TEST!")
    print("[INFO] Please update Cell 4 with your model IDs before running evaluation.")
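
`TRAIN_RATIO`, `VAL_RATIO` and `TEST_RATIO` are not passed to `load_dataset()` in Cell 6. How the splits are actually produced therefore depends on `DatasetManager` and the config file. For illustration only, a reproducible ratio split usually looks like the sketch below. `split_by_ratio` is a hypothetical helper, not the project's implementation. It validates that the ratios sum to 1 and shuffles with a fixed seed so that reruns see the same test set.

```python
import math
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_by_ratio(items: Sequence[T], train: float, val: float, test: float,
                   seed: int = 42) -> Tuple[List[T], List[T], List[T]]:
    """Shuffle with a fixed seed, then cut into train/val/test by ratio."""
    if not math.isclose(train + val + test, 1.0):
        raise ValueError(f"Ratios must sum to 1.0, got {train + val + test}")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_train = round(len(shuffled) * train)
    n_val = round(len(shuffled) * val)
    # Whatever remains after train and val goes to test, so no sample is dropped
    return (shuffled[:n_train],
            shuffled[n_train:n_train + n_val],
            shuffled[n_train + n_val:])
```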

## 3. Load Datasets

In [None]:
# Cell 5: Initialize dataset manager
dataset_manager = DatasetManager(config_file=PATHS['config_file'])
print("[OK] Dataset manager initialized")

In [None]:
# Cell 6: Load all datasets
datasets_loaded = {}
dataset_stats = []

for dataset_name in tqdm(DATASETS_TO_TEST, desc="Loading datasets"):
    try:
        # Load dataset
        samples = dataset_manager.load_dataset(
            dataset_name=dataset_name
        )
        # Get test split
        test_samples = samples['test']
        
        # Limit samples if specified
        if MAX_SAMPLES_PER_DATASET:
            test_samples = test_samples[:MAX_SAMPLES_PER_DATASET]
        
        datasets_loaded[dataset_name] = test_samples
        
        # Collect stats
        dataset_stats.append({
            'Dataset': dataset_name,
            'Total Samples': len(samples['train']) + len(samples['val']) + len(samples['test']),
            'Test Samples': len(samples['test']),
            'Used Samples': len(test_samples)
        })
        
        print(f"[OK] {dataset_name}: {len(test_samples)} test samples loaded")
    except Exception as e:
        print(f"[WARNING] Failed to load {dataset_name}: {e}")
        datasets_loaded[dataset_name] = []

# Display stats
stats_df = pd.DataFrame(dataset_stats)
print("\n[INFO] Dataset Statistics:")
print(stats_df.to_string(index=False))
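`test_samples[:MAX_SAMPLES_PER_DATASET]` keeps the first N items. If a test split is ordered by speaker or recording, that subset may not be representative. A seeded random subset is a common alternative. `subsample` is a hypothetical helper. It keeps the chosen items in their original order and returns everything when the limit is `None` or larger than the split.

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def subsample(samples: Sequence[T], max_samples: Optional[int], seed: int = 0) -> List[T]:
    """Return up to max_samples items chosen at random with a fixed seed."""
    if max_samples is None or max_samples >= len(samples):
        return list(samples)
    rng = random.Random(seed)
    # Pick indices, then sort them so the subset keeps the dataset's original order
    chosen = sorted(rng.sample(range(len(samples)), max_samples))
    return [samples[i] for i in chosen]
```

In Cell 6 this would read `test_samples = subsample(samples['test'], MAX_SAMPLES_PER_DATASET)`.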

## 4. Initialize Models

In [None]:
# Cell 7: Initialize model evaluator
model_evaluator = ModelEvaluator()
metrics_calculator = ASRMetrics()

print("[OK] Model evaluator and metrics calculator initialized")

## 5. Run Evaluation

In [None]:
# Cell 8: Main evaluation loop
results = []
total_start_time = time.time()
total_samples_processed = 0

# Check if models are specified
if not MODELS_TO_TEST:
    print("[ERROR] No models specified in MODELS_TO_TEST!")
    print("[INFO] Please update Cell 4 and re-run from there.")
else:
    # Iterate through each model
    for model_name in MODELS_TO_TEST:
        print(f"\n{'='*60}")
        print(f"[INFO] Evaluating model: {model_name}")
        print(f"{'='*60}")
        
        # Load model
        try:
            model = ModelFactory.create_model(model_name)
            model.load_model()
            print(f"[OK] Model loaded successfully")
        except Exception as e:
            print(f"[ERROR] Failed to load model {model_name}: {e}")
            continue
        
        # Evaluate on each dataset
        for dataset_name, test_samples in datasets_loaded.items():
            if not test_samples:
                print(f"[WARNING] Skipping {dataset_name} - no samples")
                continue
            
            print(f"\n[INFO] Testing on {dataset_name} ({len(test_samples)} samples)...")
            
            # Prepare for evaluation
            references = []
            hypotheses = []
            audio_durations = []
            processing_times = []
            
            # librosa is only needed here, to measure audio duration for RTF
            import librosa

            # Process each sample
            for sample in tqdm(test_samples, desc=dataset_name, leave=False):
                try:
                    # Transcribe with RTF measurement
                    with RTFTimer() as timer:
                        hypothesis = model.transcribe(sample.audio_path)

                    # Get audio duration for RTF calculation
                    duration = librosa.get_duration(path=sample.audio_path)

                    # Store results only after every step has succeeded, so the
                    # four lists stay aligned even if a sample fails midway
                    references.append(sample.transcription)
                    hypotheses.append(hypothesis)
                    audio_durations.append(duration)
                    processing_times.append(timer.elapsed_time)

                    total_samples_processed += 1
                except Exception as e:
                    print(f"[WARNING] Failed to process sample {sample.file_id}: {e}")
                    continue
            
            # Calculate metrics
            if references and hypotheses:
                metrics = metrics_calculator.calculate_all_metrics(
                    references=references,
                    hypotheses=hypotheses
                )
                
                # Calculate RTF
                total_audio_duration = sum(audio_durations)
                total_processing_time = sum(processing_times)
                rtf = total_processing_time / total_audio_duration if total_audio_duration > 0 else 0
                
                # Store results
                result = {
                    'model': model_name,
                    'dataset': dataset_name,
                    'samples_processed': len(references),
                    'WER': metrics['wer'],
                    'CER': metrics['cer'],
                    'MER': metrics['mer'],
                    'WIL': metrics['wil'],
                    'WIP': metrics['wip'],
                    'SER': metrics['ser'],
                    'RTF': rtf,
                    'insertions': metrics['insertions'],
                    'deletions': metrics['deletions'],
                    'substitutions': metrics['substitutions'],
                    'total_audio_duration': total_audio_duration,
                    'total_processing_time': total_processing_time
                }
                results.append(result)
                
                print(f"[OK] WER: {metrics['wer']:.4f} | CER: {metrics['cer']:.4f} | RTF: {rtf:.4f}")
            else:
                print(f"[WARNING] No valid results for {dataset_name}")

    total_evaluation_time = time.time() - total_start_time

    print(f"\n\n{'='*60}")
    print(f"[OK] Evaluation completed!")
    print(f"[INFO] Total time: {total_evaluation_time:.2f}s ({total_evaluation_time/60:.2f} minutes)")
    print(f"[INFO] Total samples processed: {total_samples_processed}")
    print(f"{'='*60}")
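Cell 8 reports RTF as total processing time divided by total audio duration. This ratio of sums weights long clips more heavily than averaging the per-clip RTFs does, and the two can differ noticeably. The sketch below makes the difference concrete. `SimpleTimer` is a hypothetical stand-in for the project's `RTFTimer`, whose internals are not shown here. It measures wall-clock time with `time.perf_counter()`, so the measurement covers everything `model.transcribe()` does, including audio loading.

```python
import time
from typing import Sequence


class SimpleTimer:
    """Context manager recording elapsed wall-clock seconds in `elapsed_time`."""

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.elapsed_time = time.perf_counter() - self._start
        return False  # never swallow exceptions from the timed block


def corpus_rtf(processing_times: Sequence[float], durations: Sequence[float]) -> float:
    """Ratio of sums, as computed in Cell 8."""
    total_audio = sum(durations)
    return sum(processing_times) / total_audio if total_audio > 0 else 0.0


def mean_utterance_rtf(processing_times: Sequence[float], durations: Sequence[float]) -> float:
    """Mean of per-clip ratios; short clips count as much as long ones."""
    ratios = [p / d for p, d in zip(processing_times, durations) if d > 0]
    return sum(ratios) / len(ratios) if ratios else 0.0
```

For two clips of 2 s and 10 s that each take 1 s to process, `corpus_rtf` gives 2/12 ≈ 0.167, while `mean_utterance_rtf` gives 0.3.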

## 6. Results Analysis

In [None]:
# Cell 9: Create results DataFrame
if results:
    results_df = pd.DataFrame(results)

    # Display results
    print("[INFO] Complete Results:")
    print(results_df.to_string(index=False))

    # Save to CSV
    csv_path = OUTPUT_DIR / f"{MODEL_FAMILY.lower().replace(' ', '_')}_results_{TIMESTAMP}.csv"
    results_df.to_csv(csv_path, index=False)
    print(f"\n[OK] Results saved to: {csv_path}")
else:
    print("[WARNING] No results to display. Please check if models were specified and loaded correctly.")

In [None]:
# Cell 10: Summary statistics
if results:
    print("\n[CHART] Average Performance by Model:")
    model_avg = results_df.groupby('model')[['WER', 'CER', 'MER', 'RTF']].mean()
    print(model_avg.to_string())

    print("\n[CHART] Average Performance by Dataset:")
    dataset_avg = results_df.groupby('dataset')[['WER', 'CER', 'MER', 'RTF']].mean()
    print(dataset_avg.to_string())

In [None]:
# Cell 11: Find the best (model, dataset) pair
if results:
    best_wer_idx = results_df['WER'].idxmin()
    best_model_info = results_df.loc[best_wer_idx]

    print("[TARGET] Best Model-Dataset Pair (Lowest WER):")
    print(f"  Model: {best_model_info['model']}")
    print(f"  Dataset: {best_model_info['dataset']}")
    print(f"  WER: {best_model_info['WER']:.4f}")
    print(f"  CER: {best_model_info['CER']:.4f}")
    print(f"  RTF: {best_model_info['RTF']:.4f}")
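Cell 11 reports the single lowest-WER (model, dataset) row, which is not necessarily the model that performs best overall. A fairer summary averages each model's metric over the datasets that *every* model was evaluated on. Without that restriction, a model that failed to load a hard dataset would get an unfairly low mean. `best_model_by_mean` is a hypothetical helper that operates on the `results` list of dicts built in Cell 8.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def best_model_by_mean(results: List[dict], metric: str = "WER") -> Tuple[str, float]:
    """Return (model, mean metric) over datasets shared by all models; lower is better."""
    per_model: Dict[str, Dict[str, float]] = defaultdict(dict)
    for row in results:
        per_model[row["model"]][row["dataset"]] = row[metric]
    if not per_model:
        raise ValueError("No results to compare")
    # Only compare on datasets every model was evaluated on
    shared = set.intersection(*(set(scores) for scores in per_model.values()))
    if not shared:
        raise ValueError("No dataset was evaluated for every model")
    means = {m: sum(scores[ds] for ds in shared) / len(shared)
             for m, scores in per_model.items()}
    best = min(means, key=means.get)
    return best, means[best]
```

Suppose model A scores 0.20 on ViMD and 0.40 on BUD500, and model B scores 0.25 and 0.30. Cell 11 would pick A (the single lowest row), but B has the lower mean WER.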

## 7. Visualizations

In [None]:
# Cell 12: Create visualizations
if results:
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Set style
    sns.set_style("whitegrid")
    plt.rcParams['figure.figsize'] = (12, 6)

    # Create plots directory
    plots_dir = OUTPUT_DIR / "plots"
    plots_dir.mkdir(exist_ok=True)

    # Initialize visualizer
    visualizer = ASRVisualizer(output_dir=str(plots_dir))

    print("[OK] Visualizer initialized")
else:
    print("[WARNING] No results to visualize")

In [None]:
# Cell 13: WER comparison plot
if results:
    plt.figure(figsize=(14, 6))
    pivot_wer = results_df.pivot(index='dataset', columns='model', values='WER')
    pivot_wer.plot(kind='bar', ax=plt.gca())
    plt.title(f'Word Error Rate (WER) Comparison - {MODEL_FAMILY} Models', fontsize=14, fontweight='bold')
    plt.xlabel('Dataset', fontsize=12)
    plt.ylabel('WER (Lower is Better)', fontsize=12)
    plt.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(plots_dir / 'wer_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("[OK] WER comparison plot saved")

In [None]:
# Cell 14: CER comparison plot
if results:
    plt.figure(figsize=(14, 6))
    pivot_cer = results_df.pivot(index='dataset', columns='model', values='CER')
    pivot_cer.plot(kind='bar', ax=plt.gca())
    plt.title(f'Character Error Rate (CER) Comparison - {MODEL_FAMILY} Models', fontsize=14, fontweight='bold')
    plt.xlabel('Dataset', fontsize=12)
    plt.ylabel('CER (Lower is Better)', fontsize=12)
    plt.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(plots_dir / 'cer_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("[OK] CER comparison plot saved")

In [None]:
# Cell 15: RTF comparison plot
if results:
    plt.figure(figsize=(14, 6))
    pivot_rtf = results_df.pivot(index='dataset', columns='model', values='RTF')
    pivot_rtf.plot(kind='bar', ax=plt.gca())
    plt.title(f'Real-Time Factor (RTF) Comparison - {MODEL_FAMILY} Models', fontsize=14, fontweight='bold')
    plt.xlabel('Dataset', fontsize=12)
    plt.ylabel('RTF (Lower is Better, <1.0 = Real-time)', fontsize=12)
    plt.axhline(y=1.0, color='r', linestyle='--', label='Real-time threshold')
    plt.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(plots_dir / 'rtf_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("[OK] RTF comparison plot saved")

In [None]:
# Cell 16: Heatmap of all metrics
if results:
    plt.figure(figsize=(16, 10))
    heatmap_data = results_df.set_index(['model', 'dataset'])[['WER', 'CER', 'MER', 'WIL', 'WIP', 'SER', 'RTF']]
    sns.heatmap(heatmap_data, annot=True, fmt='.4f', cmap='RdYlGn_r', cbar_kws={'label': 'Metric Value'})
    plt.title(f'All Metrics Heatmap - {MODEL_FAMILY} Models', fontsize=14, fontweight='bold')
    plt.xlabel('Metric', fontsize=12)
    plt.ylabel('Model + Dataset', fontsize=12)
    plt.tight_layout()
    plt.savefig(plots_dir / 'metrics_heatmap.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("[OK] Metrics heatmap saved")
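Cell 16 colors every column on one shared `RdYlGn_r` scale, where higher values are redder. That works for the error metrics (WER, CER, MER, WIL, SER), but WIP is higher-is-better, and RTF lives on a different range from the error rates. Raw colors are therefore not comparable across columns. One option is to color by a per-column score where 0 = best and 1 = worst, and to keep the raw numbers as annotations. seaborn's `heatmap` accepts an array of the same shape as the data through `annot=`. `badness_scores` is a hypothetical helper.

```python
from typing import Dict, List, Set


def badness_scores(columns: Dict[str, List[float]],
                   higher_is_better: Set[str]) -> Dict[str, List[float]]:
    """Min-max scale each metric column to [0, 1], where 0 = best and 1 = worst."""
    scaled: Dict[str, List[float]] = {}
    for name, values in columns.items():
        lo, hi = min(values), max(values)
        span = hi - lo
        if span == 0:
            # All rows tie on this metric, so none is worse than another
            scaled[name] = [0.0] * len(values)
            continue
        col = [(v - lo) / span for v in values]
        if name in higher_is_better:
            col = [1.0 - c for c in col]  # flip so larger raw values score as better
        scaled[name] = col
    return scaled
```

With the DataFrame from Cell 16, you would call it with `higher_is_better={"WIP"}`, plot the scaled values as colors, and pass the original `heatmap_data` values as `annot`.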

## 8. Generate Report

In [None]:
# Cell 17: Generate comprehensive report
if results:
    report_generator = ReportGenerator(reports_dir=PATHS['reports_dir'])

    # Prepare report data
    report_data = {
        'models': MODELS_TO_TEST,
        'datasets': DATASETS_TO_TEST,
        'metrics_summary': {i: row.to_dict() for i, row in results_df.iterrows()},
        'best_model': {
            'model_name': best_model_info['model'],
            'dataset': best_model_info['dataset'],
            'WER': best_model_info['WER'],
            'CER': best_model_info['CER'],
            'RTF': best_model_info['RTF']
        },
        'evaluation_time': total_evaluation_time,
        'total_samples': total_samples_processed
    }

    # Generate Markdown report
    report_path = report_generator.generate_model_report(
        model_family=MODEL_FAMILY,
        results=report_data,
        # "Báo_cáo" is Vietnamese for "Report"
        output_filename=f"Báo_cáo_{MODEL_FAMILY.replace(' ', '_')}_{TIMESTAMP}.md"
    )

    # Save JSON results
    json_path = report_generator.save_results_json(
        results=report_data,
        filename=f"{MODEL_FAMILY.lower().replace(' ', '_')}_results_{TIMESTAMP}.json"
    )

    print(f"\n[OK] Markdown report: {report_path}")
    print(f"[OK] JSON results: {json_path}")

In [None]:
# Cell 18: Print evaluation summary
if results:
    summary = create_evaluation_summary(
        model_family=MODEL_FAMILY,
        models_tested=MODELS_TO_TEST,
        datasets_tested=DATASETS_TO_TEST,
        total_samples=total_samples_processed,
        total_time=total_evaluation_time
    )
    print(summary)

## 9. Export & Conclusion

Cell 19, the last cell of this notebook, prints the list of generated output files.

## Template Usage Examples

### Example 1: Evaluate Meta's Seamless M4T

```python
# Cell 4 configuration:
MODEL_FAMILY = "Seamless-M4T"
MODELS_TO_TEST = [
    "facebook/seamless-m4t-large",
    "facebook/seamless-m4t-medium",
]
```

### Example 2: Evaluate Meta's MMS (Massively Multilingual Speech)

```python
# Cell 4 configuration:
MODEL_FAMILY = "MMS"
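MODELS_TO_TEST = [
    "facebook/mms-1b-all",  # CTC model with per-language adapters (Vietnamese code: "vie")
    "facebook/mms-300m",    # pretrained-only checkpoint; fine-tune it for ASR before evaluating
]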
```

### Example 3: Evaluate NVIDIA Canary

```python
# Cell 4 configuration:
MODEL_FAMILY = "Canary"
MODELS_TO_TEST = [
    "nvidia/canary-1b",  # NeMo model; confirm Vietnamese support on its model card before running
]
```

### Example 4: Quick Testing (Limited Samples)

```python
# Cell 4 configuration:
MODEL_FAMILY = "YourModel"
MODELS_TO_TEST = ["your-model-id"]
MAX_SAMPLES_PER_DATASET = 10  # Test with only 10 samples per dataset
DATASETS_TO_TEST = ["ViMD", "VLSP2020"]  # Test on fewer datasets
```

---

## Tips for Using This Template

1. **Start Small**: Test with `MAX_SAMPLES_PER_DATASET = 10` first to verify everything works
2. **Monitor Memory**: Large models may require GPU with sufficient VRAM
3. **Save Your Results**: The results CSV is written only when Cell 9 runs, after the full evaluation loop in Cell 8 finishes, so an interrupted run loses its in-progress results
4. **Compare Results**: Use notebook `05_cross_model_comparison.ipynb` to compare across model families
5. **Check Model IDs**: Verify that the Hugging Face model IDs are correct before running
6. **Custom Models**: Follow the "How to Add Custom Models to the Framework" section below
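
Cell 9 writes the results CSV once, after the whole loop in Cell 8 has finished. If you want each (model, dataset) result to survive a crash or a Colab disconnect, one option is to append every result row to a CSV as soon as it is computed. `append_result_row` is a hypothetical helper. It assumes all rows share the same keys in the same order, which holds for the `result` dict built in Cell 8.

```python
import csv
from pathlib import Path


def append_result_row(csv_path: Path, row: dict) -> None:
    """Append one result row, writing the header only when the file is new or empty."""
    write_header = not csv_path.exists() or csv_path.stat().st_size == 0
    with csv_path.open("a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(row.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(row)
```

In Cell 8 you would call `append_result_row(OUTPUT_DIR / "partial_results.csv", result)` right after `results.append(result)`.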

---

## Next Steps

After evaluating your models:

1. Review results in the generated CSV file
2. Check visualizations in the `plots/` directory
3. Read the comprehensive report in `docs/reports/`
4. Compare with other models using `05_cross_model_comparison.ipynb`
5. Share your findings or contribute back to the project

---

**Happy Evaluating!**

---

## How to Add Custom Models to the Framework

If your model is not in the default list, you'll need to add it to `src/model_evaluator.py`:

### Step 1: Add Model Configuration

```python
# In src/model_evaluator.py, add to MODEL_CONFIGS dict:
'your-model-key': ModelConfig(
    name='YourModelName',
    model_id='your-org/your-model-id',
    model_type='your-model-type'  # e.g., 'whisper', 'wav2vec2', 'custom'
),
```

### Step 2: Create Model Class (if needed)

If your model type is new, create a class:

```python
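import librosa
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

# BaseASRModel is assumed to be defined in src/model_evaluator.py
class YourCustomModel(BaseASRModel):
    def load_model(self):
        # Load the processor (feature extractor + tokenizer) and the model weights
        self.processor = AutoProcessor.from_pretrained(self.config.model_id)
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(self.config.model_id)
        self.model.eval()

    def transcribe(self, audio_path: str) -> str:
        # Load mono audio resampled to 16 kHz, the rate Whisper-style models expect
        audio, sr = librosa.load(audio_path, sr=16000, mono=True)
        inputs = self.processor(audio, sampling_rate=sr, return_tensors="pt")
        # Generate token IDs without tracking gradients, then decode them to text
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs)
        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]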
```

### Step 3: Register in ModelFactory

```python
# In ModelFactory.create_model(), add your model type:
elif config.model_type == 'your-model-type':
    return YourCustomModel(config)
```

### Step 4: Use in This Notebook

```python
MODELS_TO_TEST = [
    "your-model-key",  # Use the key from MODEL_CONFIGS
]
```
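
Steps 1 to 4 together form a config registry plus a factory. The self-contained miniature below shows how the pieces fit, so the wiring can be checked without downloading a model. It is hypothetical: `ModelConfig`, `BaseASRModel` and `create_model` only mirror the names in the snippets above. The real definitions in `src/model_evaluator.py` may have more fields and may handle raw Hugging Face IDs differently. `EchoModel` is a toy class that "transcribes" a file as its name. The sketch maps `model_type` to a class in a dict instead of an `if/elif` chain, so adding a type does not require editing the factory function.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Type


@dataclass
class ModelConfig:
    name: str
    model_id: str
    model_type: str


class BaseASRModel:
    def __init__(self, config: ModelConfig):
        self.config = config

    def load_model(self) -> None:
        raise NotImplementedError

    def transcribe(self, audio_path: str) -> str:
        raise NotImplementedError


class EchoModel(BaseASRModel):
    """Toy model for testing the wiring: returns the audio file's name stem."""

    def load_model(self) -> None:
        self.loaded = True

    def transcribe(self, audio_path: str) -> str:
        return Path(audio_path).stem


# Step 1: registry of model keys -> configs
MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "echo-test": ModelConfig(name="Echo", model_id="local/echo", model_type="echo"),
}

# Step 3: map each model_type to its class
MODEL_CLASSES: Dict[str, Type[BaseASRModel]] = {"echo": EchoModel}


def create_model(key: str) -> BaseASRModel:
    """Look up the config for `key` and instantiate the class for its model_type."""
    if key not in MODEL_CONFIGS:
        raise KeyError(f"Unknown model key {key!r}; known keys: {sorted(MODEL_CONFIGS)}")
    config = MODEL_CONFIGS[key]
    if config.model_type not in MODEL_CLASSES:
        raise ValueError(f"No class registered for model_type {config.model_type!r}")
    return MODEL_CLASSES[config.model_type](config)
```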

In [None]:
# Cell 19: Final summary
if results:
    print(f"[OK] {MODEL_FAMILY} Evaluation Complete!")
    print("\n[INFO] Generated outputs:")
    print(f"  1. Results CSV: {csv_path}")
    print(f"  2. Markdown Report: {report_path}")
    print(f"  3. JSON Results: {json_path}")
    print(f"  4. Visualizations: {plots_dir}/")
    print("\n[NOTE] All files are saved in:")
    print(f"  - Results: {OUTPUT_DIR}")
    print(f"  - Reports: {PATHS['reports_dir']}")
else:
    print("[INFO] No evaluation was performed.")
    print("[INFO] Please update Cell 4 with your model IDs and re-run the notebook.")