# Cross-Model Comparison & Best Model Selection

This notebook aggregates the evaluation results from all model families and performs a comprehensive comparison.

## Model Families Compared:
1. **PhoWhisper** (Vietnamese-optimized Whisper variants)
2. **OpenAI Whisper** (Original multilingual models)
3. **Wav2Vec2-XLSR** (Vietnamese fine-tuned CTC models)
4. **Wav2Vn** (Vietnamese ASR model. Note: falls back to mock transcription if the model is not publicly available, in which case its results are not real)

## Analysis Performed:
- Overall best model by WER/CER
- Best model per dataset
- Speed vs accuracy trade-offs (RTF vs WER)
- Statistical significance testing
- Production deployment recommendations

**Prerequisites**: Run notebooks 01, 02, 03, and 04 first  
**Compatible with**: Local & Google Colab  
**Report output**: `/docs/reports/`

## 1. Environment Setup

In [None]:
# Cell 1: Environment detection and setup
%load_ext autoreload
%autoreload 3
import sys
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Import notebook utilities
try:
    from src.notebook_utils import (
        detect_environment,
        setup_paths,
        print_environment_info,
        ReportGenerator
    )
except ImportError:
    notebook_dir = Path.cwd()
    if notebook_dir.name != 'notebooks':
        sys.path.insert(0, str(notebook_dir.parent))
    from src.notebook_utils import (
        detect_environment,
        setup_paths,
        print_environment_info,
        ReportGenerator
    )

# Detect environment
ENV = detect_environment()
print(f"[INFO] Running in: {ENV}")

# Setup paths
PATHS = setup_paths()
print(f"\n[OK] Project root: {PATHS['project_root']}")
print(f"[OK] Results directory: {PATHS['output_dir']}")
print(f"[OK] Reports directory: {PATHS['reports_dir']}")

In [None]:
# Cell 2: Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from glob import glob
import json
from scipy import stats

# Plot defaults applied to every figure drawn later in the notebook
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (14, 8),
    'font.size': 11,
})

print("[OK] All libraries imported successfully")

## 2. Load Results from Previous Evaluations

In [None]:
# Cell 3: Configuration
# Each run writes into its own timestamped directory under the shared
# results directory; plots go into a `plots/` sub-directory of it.
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
OUTPUT_DIR = PATHS['output_dir'] / f"cross_comparison_{TIMESTAMP}"
plots_dir = OUTPUT_DIR / "plots"

# Create the run directory first, then the plots directory inside it.
for run_dir in (OUTPUT_DIR, plots_dir):
    run_dir.mkdir(parents=True, exist_ok=True)

print(f"[CONFIG] Output directory: {OUTPUT_DIR}")

In [None]:
# Cell 4: Find and load all result files
# For every model family, look for `<prefix>_*/<prefix>_results_*.csv` under
# the results directory, load the last match in sorted path order, tag its
# rows with the family name, and concatenate everything into `combined_df`.
results_dir = PATHS['output_dir']

# (display name, directory/file prefix), in the order families are reported and loaded
MODEL_FAMILIES = (
    ('PhoWhisper', 'phowhisper'),
    ('Whisper', 'whisper'),
    ('Wav2Vec2', 'wav2vec2'),
    ('Wav2Vn', 'wav2vn'),
)

# Search for CSV result files
csvs_by_family = {
    family: list(results_dir.glob(f'{prefix}_*/{prefix}_results_*.csv'))
    for family, prefix in MODEL_FAMILIES
}

print("[INFO] Found result files:")
for family, csv_files in csvs_by_family.items():
    print(f"  - {family}: {len(csv_files)} files")

# Load the most recent results for each model family
# ("most recent" means the greatest path in sort order; presumably the
# directory and file names embed a sortable timestamp, so verify the naming).
all_results = []
for family, csv_files in csvs_by_family.items():
    if not csv_files:
        continue
    latest_csv = max(csv_files)
    loaded_df = pd.read_csv(latest_csv)
    loaded_df['model_family'] = family
    all_results.append(loaded_df)
    print(f"[OK] Loaded {family}: {latest_csv.name}")
    if family == 'Wav2Vn':
        print("[NOTE] Wav2Vn uses mock transcription - results are not real")

if all_results:
    # Combine all results
    combined_df = pd.concat(all_results, ignore_index=True)
    print(f"\n[OK] Combined results: {len(combined_df)} rows")
    print(f"[INFO] Model families: {combined_df['model_family'].unique().tolist()}")
    print(f"[INFO] Datasets: {combined_df['dataset'].unique().tolist()}")
else:
    # Later cells are guarded by `if all_results`, so they are skipped in this case.
    print("\n[WARNING] No result files found!")
    print("[INFO] Please run notebooks 01-04 first to generate results.")

In [None]:
# Cell 5: Display combined results
# Print the headline columns of every loaded result row as a plain-text table.
if all_results:
    display_cols = ['model_family', 'model', 'dataset', 'WER', 'CER', 'RTF', 'samples_processed']
    print("[INFO] Combined Evaluation Results:")
    results_table = combined_df.loc[:, display_cols].to_string(index=False)
    print(results_table)

## 3. Overall Best Model Analysis

In [None]:
# Cell 6: Find best models overall
# For each of WER, CER and RTF, pick the single result row (one model on one
# dataset) with the lowest value and print its headline metrics.
if all_results:
    headline_metrics = ('WER', 'CER', 'RTF')

    # `best_wer` is read again by the cells that save the summary and the report.
    best_wer, best_cer, best_rtf = (
        combined_df.loc[combined_df[metric].idxmin()]
        for metric in headline_metrics
    )

    divider = "=" * 60
    print("[TARGET] BEST MODELS OVERALL")
    print(divider)

    winners = (
        ("[1] Best WER:", best_wer),
        ("[2] Best CER:", best_cer),
        ("[3] Fastest (Best RTF):", best_rtf),
    )
    for heading, winner in winners:
        print(f"\n{heading}")
        print(f"    Model: {winner['model']}")
        print(f"    Dataset: {winner['dataset']}")
        for metric in headline_metrics:
            print(f"    {metric}: {winner[metric]:.4f}")

    print(divider)

In [None]:
# Cell 7: Average performance by model family
# Mean WER/CER/MER/RTF over all rows of each family, then of each model
# (models listed from lowest to highest mean WER).
if all_results:
    averaged_metrics = ['WER', 'CER', 'MER', 'RTF']

    print("\n[CHART] Average Performance by Model Family:")
    # `family_avg` is reused when the summary statistics JSON is written.
    family_avg = combined_df.groupby('model_family')[averaged_metrics].mean()
    print(family_avg.to_string())

    print("\n[CHART] Average Performance by Model:")
    per_model_mean = combined_df.groupby('model')[averaged_metrics].mean()
    model_avg = per_model_mean.sort_values('WER')
    print(model_avg.to_string())

In [None]:
# Cell 8: Best model per dataset
# For every dataset, report the result row with the lowest WER.
if all_results:
    print("\n[LIST] Best Model per Dataset (by WER):")
    print("="*60)

    for dataset in combined_df['dataset'].unique():
        in_dataset = combined_df['dataset'] == dataset
        # combined_df has a fresh unique index, so the label returned by
        # idxmin() identifies exactly one row.
        winner = combined_df.loc[combined_df.loc[in_dataset, 'WER'].idxmin()]
        metrics_line = " | ".join(
            f"{metric}: {winner[metric]:.4f}" for metric in ('WER', 'CER', 'RTF')
        )

        print(f"\n{dataset}:")
        print(f"  Model: {winner['model']}")
        print(f"  {metrics_line}")

## 4. Comprehensive Visualizations

In [None]:
# Cell 9: WER comparison across all models and datasets
# Grouped bar chart: one group per dataset, one bar per model.
if all_results:
    fig, ax = plt.subplots(figsize=(16, 8))

    # Rows = datasets, columns = models; repeated (dataset, model) pairs are averaged.
    pivot_wer = pd.pivot_table(
        combined_df, values='WER', index='dataset', columns='model', aggfunc='mean'
    )
    pivot_wer.plot.bar(ax=ax, width=0.8)

    ax.set_title('Word Error Rate (WER) - All Models Comparison', fontsize=16, fontweight='bold')
    ax.set_xlabel('Dataset', fontsize=13)
    ax.set_ylabel('WER (Lower is Better)', fontsize=13)
    # Legend sits outside the axes so it does not cover the bars.
    ax.legend(title='Model', bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=9)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    wer_plot_path = plots_dir / 'all_models_wer_comparison.png'
    plt.savefig(wer_plot_path, dpi=300, bbox_inches='tight')
    plt.show()
    print("[OK] WER comparison plot saved")

In [None]:
# Cell 10: Model family comparison boxplot
# 2x2 grid of boxplots (WER, CER, RTF, WIP), one box per model family.
if all_results:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # (metric column, y-axis label), laid out row by row across the grid
    panels = (
        ('WER', 'WER'),
        ('CER', 'CER'),
        ('RTF', 'RTF'),
        ('WIP', 'WIP (Higher is Better)'),
    )

    for panel_ax, (metric, y_label) in zip(axes.flat, panels):
        sns.boxplot(data=combined_df, x='model_family', y=metric, ax=panel_ax)
        panel_ax.set_title(f'{metric} Distribution by Model Family', fontweight='bold')
        panel_ax.set_ylabel(y_label)
        panel_ax.set_xlabel('Model Family')

        if metric == 'RTF':
            # Reference line at RTF = 1.0, labelled 'Real-time' in the legend.
            panel_ax.axhline(y=1.0, color='r', linestyle='--', linewidth=1, label='Real-time')
            panel_ax.legend()

    plt.tight_layout()
    plt.savefig(plots_dir / 'model_family_distributions.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("[OK] Model family distribution plot saved")

In [None]:
# Cell 11: Speed vs Accuracy trade-off (RTF vs WER)
# Scatter plot with RTF on the x-axis and WER on the y-axis; one colour per
# model family, one point per result row, each point labelled with its model.
if all_results:
    fig, ax = plt.subplots(figsize=(14, 8))

    # Scatter plot with model family colors
    for family in combined_df['model_family'].unique():
        family_df = combined_df[combined_df['model_family'] == family]
        ax.scatter(family_df['RTF'], family_df['WER'], label=family, s=100, alpha=0.6)

    # Add model labels: the part after the last '/', truncated to 15 characters.
    # str() keeps a missing or non-string model value from raising here.
    for model_name, rtf, wer in zip(combined_df['model'], combined_df['RTF'], combined_df['WER']):
        ax.annotate(str(model_name).split('/')[-1][:15],
                   (rtf, wer),
                   fontsize=8, alpha=0.7,
                   xytext=(5, 5), textcoords='offset points')

    ax.axvline(x=1.0, color='red', linestyle='--', linewidth=2, alpha=0.5, label='Real-time threshold')
    ax.set_xlabel('Real-Time Factor (RTF) - Lower is Faster', fontsize=13)
    ax.set_ylabel('Word Error Rate (WER) - Lower is Better', fontsize=13)
    ax.set_title('Speed vs Accuracy Trade-off: RTF vs WER', fontsize=16, fontweight='bold')
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3)

    # Annotate ideal region. Both axes are "lower is better", so fast AND
    # accurate is the bottom-left corner (axes-fraction coordinates); the
    # top-left corner would mark fast but inaccurate models.
    ax.text(0.05, 0.05, 'IDEAL\n(Fast + Accurate)',
           transform=ax.transAxes, fontsize=12,
           verticalalignment='bottom',
           bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.3))

    plt.tight_layout()
    plt.savefig(plots_dir / 'speed_accuracy_tradeoff.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("[OK] Speed vs accuracy plot saved")

In [None]:
# Cell 12: Comprehensive metrics heatmap
# One heatmap row per (model, dataset) pair, one column per metric, with the
# numeric value printed in every cell.
if all_results:
    fig, ax = plt.subplots(figsize=(18, 12))

    heatmap_metrics = ['WER', 'CER', 'MER', 'WIL', 'WIP', 'SER', 'RTF']
    heatmap_data = combined_df.set_index(['model', 'dataset']).loc[:, heatmap_metrics]

    # NOTE(review): every metric shares one colour scale. The boxplot cell
    # labels WIP as "Higher is Better", so its colouring may read inverted
    # relative to the error-rate columns; confirm this is intended.
    heatmap_options = dict(
        annot=True,
        fmt='.3f',
        cmap='RdYlGn_r',
        cbar_kws={'label': 'Metric Value'},
        linewidths=0.5,
    )
    sns.heatmap(heatmap_data, ax=ax, **heatmap_options)

    ax.set_title('Comprehensive Metrics Heatmap - All Models & Datasets', fontsize=16, fontweight='bold')
    ax.set_xlabel('Metric', fontsize=13)
    ax.set_ylabel('Model + Dataset', fontsize=13)

    plt.tight_layout()
    plt.savefig(plots_dir / 'comprehensive_metrics_heatmap.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("[OK] Comprehensive heatmap saved")

In [None]:
# Cell 13: Error type breakdown
# Stacked bars of the mean insertion / deletion / substitution counts per
# model. Skipped when the results carry no 'insertions' column.
if all_results and 'insertions' in combined_df.columns:
    fig, ax = plt.subplots(figsize=(14, 8))

    error_cols = ['insertions', 'deletions', 'substitutions']
    bar_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']  # same order as error_cols

    # Average error types by model
    error_avg = combined_df.groupby('model')[error_cols].mean()
    error_avg.plot.bar(stacked=True, ax=ax, color=bar_colors)

    ax.set_title('Error Type Breakdown by Model', fontsize=16, fontweight='bold')
    ax.set_xlabel('Model', fontsize=13)
    ax.set_ylabel('Average Error Count', fontsize=13)
    ax.legend(title='Error Type', labels=['Insertions', 'Deletions', 'Substitutions'])
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.savefig(plots_dir / 'error_breakdown.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("[OK] Error breakdown plot saved")

## 5. Statistical Analysis

In [None]:
# Cell 14: Statistical significance testing
# One-way ANOVA of WER across model families, followed by the pairwise
# correlation matrix of all metrics. Runs only when at least two families
# were loaded.
if all_results and len(combined_df['model_family'].unique()) >= 2:
    print("[INFO] Statistical Significance Testing (ANOVA)")
    print("="*60)

    # ANOVA for WER across model families.
    # Missing WER values are dropped first: a single NaN in any group would
    # turn the whole statistic into NaN.
    # NOTE(review): ANOVA assumes independent observations, but the families
    # are presumably scored on the same datasets; confirm the test is appropriate.
    wer_groups = [
        group['WER'].dropna().to_numpy()
        for _, group in combined_df.groupby('model_family')
    ]
    wer_groups = [values for values in wer_groups if len(values) > 0]
    n_observations = sum(len(values) for values in wer_groups)

    print(f"\nWER across Model Families:")
    # The F-test needs at least two groups and more observations than groups
    # (otherwise there is no within-group variance to compare against).
    if len(wer_groups) < 2 or n_observations <= len(wer_groups):
        print(f"  Result: Not enough WER observations per model family to run ANOVA")
    else:
        f_stat, p_value = stats.f_oneway(*wer_groups)
        print(f"  F-statistic: {f_stat:.4f}")
        print(f"  P-value: {p_value:.6f}")

        if np.isnan(p_value):
            # NaN compares False against 0.05 and used to be reported as
            # "No significant difference"; say explicitly that no conclusion
            # can be drawn instead.
            print(f"  Result: Test undefined for this data (p-value is NaN)")
        elif p_value < 0.05:
            print(f"  Result: Statistically significant difference (p < 0.05)")
        else:
            print(f"  Result: No significant difference (p >= 0.05)")

    # Correlation analysis
    print(f"\n\n[CHART] Metric Correlations:")
    metric_cols = ['WER', 'CER', 'MER', 'WIL', 'WIP', 'SER', 'RTF']
    correlation = combined_df[metric_cols].corr()
    print(correlation.to_string())

## 6. Production Recommendations

In [None]:
# Cell 15: Production deployment recommendations
# Ranks models by their mean WER, mean RTF and a combined score, then lists
# the lowest-WER model for each dataset. `best_accuracy`, `best_speed` and
# `best_balanced` are read again when the summary statistics are saved.
if all_results:
    print("[TARGET] PRODUCTION DEPLOYMENT RECOMMENDATIONS")
    print("="*60)

    # Best for accuracy (average WER)
    model_wer_avg = combined_df.groupby('model')['WER'].mean().sort_values()
    best_accuracy = model_wer_avg.index[0]
    best_accuracy_wer = model_wer_avg.iloc[0]

    # Best for speed (average RTF)
    model_rtf_avg = combined_df.groupby('model')['RTF'].mean().sort_values()
    best_speed = model_rtf_avg.index[0]
    best_speed_rtf = model_rtf_avg.iloc[0]

    # Best balanced: heuristic score = mean WER * mean RTF, lower is better
    model_scores = combined_df.groupby('model').agg({'WER': 'mean', 'RTF': 'mean'})
    model_scores['balance_score'] = model_scores['WER'] * model_scores['RTF']
    best_balanced = model_scores['balance_score'].idxmin()
    balanced_wer = model_scores.loc[best_balanced, 'WER']
    balanced_rtf = model_scores.loc[best_balanced, 'RTF']

    print(f"\n1. BEST FOR ACCURACY (Lowest WER):")
    print(f"   Recommendation: {best_accuracy}")
    print(f"   Average WER: {best_accuracy_wer:.4f}")
    print(f"   Use case: Offline transcription, high accuracy requirements")

    print(f"\n2. BEST FOR SPEED (Lowest RTF):")
    print(f"   Recommendation: {best_speed}")
    print(f"   Average RTF: {best_speed_rtf:.4f}")
    print(f"   Use case: Real-time transcription, latency-sensitive applications")

    print(f"\n3. BEST BALANCED (Speed + Accuracy):")
    print(f"   Recommendation: {best_balanced}")
    print(f"   Average WER: {balanced_wer:.4f}")
    print(f"   Average RTF: {balanced_rtf:.4f}")
    print(f"   Use case: General-purpose transcription")

    # Dataset-specific recommendations
    print(f"\n4. DATASET-SPECIFIC RECOMMENDATIONS:")
    for dataset in combined_df['dataset'].unique():
        dataset_df = combined_df[combined_df['dataset'] == dataset]
        # Deliberately NOT named `best_wer`: that global holds the overall
        # best row from the "best models overall" cell and is indexed by
        # column name when the summary JSON and the report are built.
        # Rebinding it to a float here broke those cells.
        dataset_best = dataset_df.loc[dataset_df['WER'].idxmin()]
        print(f"   {dataset}: {dataset_best['model']} (WER: {dataset_best['WER']:.4f})")

    print("\n" + "="*60)

## 7. Generate Final Report

In [None]:
# Cell 16: Save combined results
# Writes the combined table as CSV and a JSON summary (overall best-WER row,
# per-family averages, recommended models) into this run's output directory.
if all_results:
    # Save combined CSV
    csv_path = OUTPUT_DIR / f"combined_results_{TIMESTAMP}.csv"
    combined_df.to_csv(csv_path, index=False)
    print(f"[OK] Combined results saved: {csv_path}")

    # Save summary statistics
    overall_best = {field: best_wer[field] for field in ('model', 'dataset')}
    # Metrics are converted to plain floats before being handed to json.
    overall_best.update(
        {metric.lower(): float(best_wer[metric]) for metric in ('WER', 'CER', 'RTF')}
    )
    recommendations = dict(
        best_accuracy=best_accuracy,
        best_speed=best_speed,
        best_balanced=best_balanced,
    )
    summary_stats = {
        'overall_best_wer': overall_best,
        'model_family_averages': family_avg.to_dict(),
        'recommendations': recommendations,
    }

    json_path = OUTPUT_DIR / f"summary_statistics_{TIMESTAMP}.json"
    # ensure_ascii=False keeps non-ASCII text readable in the file.
    with json_path.open('w', encoding='utf-8') as json_file:
        json.dump(summary_stats, json_file, indent=2, ensure_ascii=False)
    print(f"[OK] Summary statistics saved: {json_path}")

In [None]:
# Cell 17: Generate comprehensive markdown report
# Packs the model and dataset lists, every result row and the overall
# best-WER row into one dict and hands it to the project's ReportGenerator.
if all_results:
    report_generator = ReportGenerator(reports_dir=PATHS['reports_dir'])

    model_names = combined_df['model'].unique().tolist()
    dataset_names = combined_df['dataset'].unique().tolist()

    # One entry per result row, keyed by the row's index label.
    metrics_by_row = {}
    for row_label, row in combined_df.iterrows():
        metrics_by_row[row_label] = row.to_dict()

    best_model_entry = {'model_name': best_wer['model'], 'dataset': best_wer['dataset']}
    for metric in ('WER', 'CER', 'RTF'):
        best_model_entry[metric] = best_wer[metric]

    report_data = {
        'models': model_names,
        'datasets': dataset_names,
        'metrics_summary': metrics_by_row,
        'best_model': best_model_entry,
    }

    report_filename = f"Báo_cáo_Tổng_hợp_{TIMESTAMP}.md"
    report_path = report_generator.generate_model_report(
        model_family="Cross-Model_Comparison",
        results=report_data,
        output_filename=report_filename,
    )

    print(f"\n[OK] Comprehensive report generated: {report_path}")

## 8. Summary

In [None]:
# Cell 18: Final summary
# Lists every artefact produced by this run, or reminds the user to run the
# evaluation notebooks when no results were loaded.
if not all_results:
    print("\n[WARNING] Please run notebooks 01-04 first to generate evaluation results.")
else:
    divider = "=" * 60
    summary_lines = (
        "\n" + divider,
        "[OK] CROSS-MODEL COMPARISON COMPLETE",
        divider,
        "\n[INFO] Generated outputs:",
        f"  1. Combined results CSV: {csv_path}",
        f"  2. Summary statistics JSON: {json_path}",
        f"  3. Comprehensive report: {report_path}",
        f"  4. Visualizations: {plots_dir}/",
        f"\n[NOTE] All files saved in: {OUTPUT_DIR}",
        f"[NOTE] Reports saved in: {PATHS['reports_dir']}",
        "\n" + divider,
    )
    for line in summary_lines:
        print(line)