# 🇻🇳 Vietnamese GEC with Contrastive Learning - Google Colab

**Clean & Simple**: Clone the repository and run the training pipeline for Vietnamese Grammatical Error Correction (GEC) with BARTpho/ViT5 and contrastive learning.

## 📋 Pipeline Overview:
1. **Setup & Clone Repository** - Install dependencies and clone the source code
2. **Data Preparation** - Check system readiness, then load and preprocess the viGEC dataset
3. **Model Selection** - Load BARTpho or ViT5 and test the tokenizer
4. **Base Model Training** - Fine-tune BARTpho/ViT5, optionally with hyperparameter optimization
5. **Negative Sample Generation** - Generate negative samples for contrastive learning
6. **Contrastive Learning Training** - Train with contrastive loss + R-Drop
7. **Inference & Evaluation** - Test and evaluate the model

⏰ **Estimated Total Time**: 4-9 hours (depending on GPU)  
🚀 **Ready to Run**: All import issues fixed, clean codebase

## 🚀 Step 1: Setup and Clone Repository

In [None]:
# Check GPU availability
import torch
print(f"🔥 CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU available - training will be very slow!")

In [None]:
# Install required packages (including the evaluation-metric libraries)
print("📦 Installing dependencies...")
!pip install numpy
# Colab already ships a CUDA build of torch. torchtext's final releases pin specific torch
# versions, so watch the pip output for a torch downgrade and restart the runtime if torch changes.
!pip3 install torch torchaudio torchvision torchtext torchdata
!pip install transformers datasets accelerate
!pip install optuna wandb lightning
!pip install sentencepiece tokenizers nltk sacrebleu evaluate rouge-score
!pip install pandas scikit-learn tqdm rich omegaconf hydra-core
!pip install underthesea pyvi ipywidgets matplotlib seaborn
!pip install -U datasets huggingface_hub fsspec
# Quote the package so the shell does not treat the brackets as a glob pattern
!pip install "optuna-integration[pytorch_lightning]"

print("✅ Installation finished (check the pip output above for errors)")
print("🎯 Metrics available for evaluation:")
print("   • F0.5, Precision, Recall (Edit-level)")
print("   • BLEU, GLEU (n-gram overlap; GLEU is tailored to GEC)")
print("   • ROUGE-1, ROUGE-2, ROUGE-L (Token overlap)")
print("   • Input-preserving Edit Ratio")

In [None]:
# Clone the repository (replace with your actual GitHub repository URL)
import os

# Change this to your actual repository URL
REPO_URL = "https://github.com/YOUR_USERNAME/CL_GEC.git"  # Update this!
PROJECT_DIR = "/content/CL_GEC"

# Clone or update repository
if not os.path.exists(PROJECT_DIR):
    print(f"📥 Cloning repository from {REPO_URL}...")
    !git clone {REPO_URL} {PROJECT_DIR}
else:
    print("📁 Repository already exists, pulling latest changes...")
    %cd {PROJECT_DIR}
    !git pull

# Change to project directory
%cd {PROJECT_DIR}
print(f"📂 Working directory: {os.getcwd()}")

# List files to verify
print("\n📋 Project files:")
!ls -la *.py

## 📊 Step 2: Data Preparation and System Check

### 🔧 BaseTrainer Configuration Options

1. **📊 Dataset Configuration**:
   - `dataset_name`: Dataset version to load (e.g., "phuhuy-se1/viGEC-v2")
   - `train_subset_ratio`: Fraction of the training split to use (0.0 to 1.0)
   - `validation_subset_ratio`: Fraction of the validation split to use
   - `test_subset_ratio`: Fraction of the test split to use

2. **🔍 Customizable Search Space**:
   - Define learning rate ranges
   - Configure weight decay options
   - Set batch size choices
   - Customize warmup ratios

3. **⚡ Flexible Training Modes**:
   - Hyperparameter optimization only
   - Training with specific parameters
   - Combined optimization + training

### 💡 Benefits:
- **Faster experimentation** with data subsets
- **Better hyperparameter control** 
- **Dataset version management**
- **Memory-efficient training** for limited resources
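The subset ratios above can be read as "keep this fraction of each split". The sketch below shows one way to do that, assuming a reproducible random sample (fixed seed) rather than the first N rows. `take_subset` is a hypothetical helper, not part of `data_utils`, and `load_vigec_dataset` may select rows differently.

```python
import random
from typing import Sequence, TypeVar

T = TypeVar("T")


def take_subset(rows: Sequence[T], ratio: float, seed: int = 42) -> list[T]:
    """Return a reproducible random fraction of `rows` (hypothetical helper)."""
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"ratio must be in [0, 1], got {ratio}")
    if ratio == 1.0:
        return list(rows)  # Keep the full split unchanged
    k = int(len(rows) * ratio)  # Round down so the subset never exceeds the ratio
    rng = random.Random(seed)   # Fixed seed -> the same subset on every run
    indices = sorted(rng.sample(range(len(rows)), k))  # Preserve the original order
    return [rows[i] for i in indices]


# Example: 10% of a toy training split with 1,000 source/target pairs
train = [{"source": f"câu {i}", "target": f"Câu {i}"} for i in range(1000)]
subset = take_subset(train, 0.1)
print(len(subset))  # 100
```

A fixed seed matters when comparing runs: two experiments with the same ratio then train on exactly the same examples.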

In [None]:
# Test imports and system readiness
from data_utils import (
    load_vigec_dataset, 
    get_model_and_tokenizer, 
    can_train_base_model,
    check_dataset_format
)
from rich.console import Console

console = Console()

# Check system readiness
console.print("[bold blue]🔍 System Readiness Check[/bold blue]")
system_ready = can_train_base_model()

# Check dataset format
console.print("\n[bold blue]📋 Dataset Format Check[/bold blue]")
dataset_ready = check_dataset_format()

if system_ready and dataset_ready:
    console.print("\n[bold green]✅ All checks passed! Ready to proceed.[/bold green]")
else:
    console.print("\n[bold red]❌ System not ready. Please check requirements.[/bold red]")

In [None]:
# Multi-GPU setup check
console.print("[bold blue]🚀 Multi-GPU Setup Check[/bold blue]")

# Check multi-GPU availability
device_count = torch.cuda.device_count()
if device_count > 1:
    console.print(f"[green]✅ {device_count} GPUs detected![/green]")
    for i in range(device_count):
        gpu_name = torch.cuda.get_device_name(i)
        memory_gb = torch.cuda.get_device_properties(i).total_memory / (1024**3)
        console.print(f"  GPU {i}: {gpu_name} ({memory_gb:.1f}GB)")
    
    # Rough estimate assuming ~85% scaling efficiency; the real speedup depends on
    # model size, batch size and inter-GPU communication
    console.print(f"[yellow]💡 Multi-GPU training may give roughly ~{device_count * 0.85:.1f}x speedup[/yellow]")
    console.print(f"[blue]📊 The training cell splits the batch size across {device_count} GPUs[/blue]")
    
    # This only checks that the helper can be imported; it does not run a multi-GPU test
    try:
        from base_trainer import test_multi_gpu  # noqa: F401
        console.print("[green]✅ Multi-GPU helper available in base_trainer[/green]")
    except ImportError:
        console.print("[yellow]⚠️  test_multi_gpu not found in base_trainer; skipping this check[/yellow]")
        
elif device_count == 1:
    console.print(f"[blue]ℹ️  Single GPU training: {torch.cuda.get_device_name(0)}[/blue]")
else:
    console.print("[red]⚠️  No CUDA GPU detected; training on CPU will be very slow[/red]")

console.print("\n[green]🔧 Hardware check complete[/green]")

In [None]:
# Load and prepare dataset with configurable parameters
console.print("[bold blue]📊 Loading viGEC Dataset[/bold blue]")

# Dataset configuration - modify these as needed
DATASET_CONFIG = {
    "dataset_name": "phuhuy-se1/viGEC-v2",  # Use "phuhuy-se1/viGEC" for version 1
    "train_subset_ratio": 1.0,  # Fraction of training data (1.0 = all; e.g. 0.1 for quick Colab runs)
    "validation_subset_ratio": 1.0,  # Fraction of validation data (1.0 = all)
    "test_subset_ratio": 0.5   # Use 50% of the test data for faster evaluation
}

console.print(f"[yellow]📋 Dataset Configuration:[/yellow]")
for key, value in DATASET_CONFIG.items():
    console.print(f"  {key}: {value}")

# Load dataset with configurable parameters
data = load_vigec_dataset(
    dataset_name=DATASET_CONFIG["dataset_name"],
    train_subset_ratio=DATASET_CONFIG["train_subset_ratio"],
    validation_subset_ratio=DATASET_CONFIG["validation_subset_ratio"],
    test_subset_ratio=DATASET_CONFIG["test_subset_ratio"]
)

console.print(f"\n[green]Dataset loaded successfully![/green]")
for split, split_data in data.items():
    console.print(f"  {split}: {len(split_data)} samples")
    
# Show subset ratios effect
console.print(f"\n[blue]📊 Subset Effects:[/blue]")
console.print(f"  Training samples: ~{len(data['train'])} (subset ratio: {DATASET_CONFIG['train_subset_ratio']})")
console.print(f"  Validation samples: ~{len(data['validation'])} (subset ratio: {DATASET_CONFIG['validation_subset_ratio']})")
console.print(f"  Test samples: ~{len(data['test'])} (subset ratio: {DATASET_CONFIG['test_subset_ratio']})")

# Save processed data
from data_utils import save_processed_data
save_processed_data(data, "./data/processed")
console.print("\n[blue]✅ Data saved to ./data/processed/[/blue]")

## 🤖 Step 3: Model Selection and Testing

In [None]:
# Choose your model - uncomment one of these:
MODEL_NAME = "vinai/bartpho-syllable"  # Recommended for Vietnamese
# MODEL_NAME = "VietAI/vit5-base"     # Alternative option
# MODEL_NAME = "VietAI/vit5-large"    # Larger model (requires more GPU memory)

console.print(f"[bold blue]🤖 Loading Model: {MODEL_NAME}[/bold blue]")

# Load model and tokenizer
model, tokenizer = get_model_and_tokenizer(MODEL_NAME)

console.print(f"[green]✅ Model loaded successfully![/green]")
console.print(f"  Model: {model.__class__.__name__}")
console.print(f"  Tokenizer: {tokenizer.__class__.__name__}")
console.print(f"  Vocabulary size: {len(tokenizer)}")

# Test tokenization
test_text = "Tôi đang học tiếng việt."
tokens = tokenizer(test_text, return_tensors="pt")
console.print(f"\n[blue]🧪 Tokenization Test:[/blue]")
console.print(f"  Input: {test_text}")
console.print(f"  Tokens: {tokens['input_ids'].shape}")

## 🏋️ Step 4: Base Model Training

In [None]:
# Configure training parameters
TRAINING_CONFIG = {
    "model_name": MODEL_NAME,
    "output_dir": "./models/base_model",
    "max_epochs": 3,  # Reduced for Colab
    "batch_size": 8,  # Adjust based on GPU memory
    "use_wandb": True,  # Set to False if you don't want to use Weights & Biases
    "run_optimization": False,  # Set to True for hyperparameter optimization (takes longer)
    
    # Dataset parameters passed to BaseTrainer (keep dataset_name in sync with DATASET_CONFIG)
    "dataset_name": DATASET_CONFIG["dataset_name"],  # "phuhuy-se1/viGEC" (v1) or "phuhuy-se1/viGEC-v2" (v2)
    "train_subset_ratio": 0.1,  # Use 10% of training data for faster training in Colab
    "validation_subset_ratio": 0.2,  # Use 20% of validation data
    "test_subset_ratio": 0.05,  # Use 5% of test data
    
    # Custom search space for hyperparameter optimization (if enabled)
    "search_space": {
        'learning_rate': {'low': 1e-5, 'high': 5e-4, 'log': True},
        'weight_decay': {'low': 0.001, 'high': 0.05, 'log': True},
        'label_smoothing': {'low': 0.0, 'high': 0.2},
        'batch_size': [8, 16, 24],  # Smaller batch sizes for Colab
        'warmup_ratio': {'low': 0.05, 'high': 0.15}
    }
}

console.print("[bold blue]🏋️ Base Model Training Configuration:[/bold blue]")
for key, value in TRAINING_CONFIG.items():
    if key != "search_space":  # Don't print the search space dict for brevity
        console.print(f"  {key}: {value}")

if TRAINING_CONFIG["run_optimization"]:
    console.print("\n[yellow]⚠️ Hyperparameter optimization enabled - this will take longer but may improve results[/yellow]")
    console.print(f"[blue]Search space configured with {len(TRAINING_CONFIG['search_space'])} parameters[/blue]")
else:
    console.print("\n[blue]ℹ️ Using default parameters for faster training[/blue]")
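The `search_space` dictionary in `TRAINING_CONFIG` uses two shapes: a `{'low', 'high', 'log'}` range for continuous values and a plain list for categorical choices. The sketch below shows how such a dictionary could be translated into Optuna suggestions. How `BaseTrainer` actually consumes it is an assumption; `suggest_from_space`, `LowerBoundTrial` and `train_and_evaluate` are hypothetical names. The helper relies only on Optuna's `trial.suggest_float(name, low, high, log=...)` and `trial.suggest_categorical(name, choices)`.

```python
from typing import Any


def suggest_from_space(trial: Any, space: dict[str, Any]) -> dict[str, Any]:
    """Map a search-space dict onto trial suggestions (hypothetical helper).

    - list values           -> trial.suggest_categorical(name, choices)
    - {'low', 'high', ...}  -> trial.suggest_float(name, low, high, log=...)
    """
    params = {}
    for name, spec in space.items():
        if isinstance(spec, list):
            params[name] = trial.suggest_categorical(name, spec)
        elif isinstance(spec, dict) and "low" in spec and "high" in spec:
            params[name] = trial.suggest_float(
                name, spec["low"], spec["high"], log=spec.get("log", False)
            )
        else:
            raise ValueError(f"Unsupported search-space entry for {name!r}: {spec!r}")
    return params


class LowerBoundTrial:
    """Stand-in for an Optuna trial (dry run only): picks the lower bound / first choice."""

    def suggest_float(self, name, low, high, *, log=False):
        return low

    def suggest_categorical(self, name, choices):
        return choices[0]


space = {
    "learning_rate": {"low": 1e-5, "high": 5e-4, "log": True},
    "label_smoothing": {"low": 0.0, "high": 0.2},
    "batch_size": [8, 16, 24],
}
print(suggest_from_space(LowerBoundTrial(), space))
# {'learning_rate': 1e-05, 'label_smoothing': 0.0, 'batch_size': 8}

# With Optuna installed, the same helper plugs into an objective:
# import optuna
# def objective(trial):
#     params = suggest_from_space(trial, TRAINING_CONFIG["search_space"])
#     return train_and_evaluate(params)  # hypothetical: returns validation F0.5
# study = optuna.create_study(direction="maximize")
# study.optimize(objective, n_trials=10)
```

Note that Optuna requires `low > 0` when `log=True`, which holds for the learning-rate and weight-decay ranges above.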

In [None]:
# Start base model training with improved error handling
from base_trainer import BaseTrainer

console.print("[bold green]🚀 Starting Base Model Training...[/bold green]")

# Check and adjust batch size for multi-GPU
original_batch_size = TRAINING_CONFIG["batch_size"]
device_count = torch.cuda.device_count()

if device_count > 1:
    # For multi-GPU, adjust batch size
    adjusted_batch_size = max(1, original_batch_size // device_count)
    console.print(f"[yellow]📊 Multi-GPU batch size adjustment:[/yellow]")
    console.print(f"  Original batch size: {original_batch_size}")
    console.print(f"  Per-GPU batch size: {adjusted_batch_size}")
    console.print(f"  Total effective batch size: {adjusted_batch_size * device_count}")
    final_batch_size = adjusted_batch_size
else:
    final_batch_size = original_batch_size
    console.print(f"[blue]📊 Single GPU batch size: {final_batch_size}[/blue]")

# Create base trainer with enhanced parameters
try:
    base_trainer = BaseTrainer(
        model_name=TRAINING_CONFIG["model_name"],
        data_dir="./data/processed",  # Use the processed data directory
        output_dir=TRAINING_CONFIG["output_dir"],
        hyperopt=TRAINING_CONFIG["run_optimization"],  # Enable/disable hyperopt
        use_wandb=TRAINING_CONFIG["use_wandb"],
        
        # New dataset parameters
        dataset_name=TRAINING_CONFIG["dataset_name"],
        train_subset_ratio=TRAINING_CONFIG["train_subset_ratio"],
        validation_subset_ratio=TRAINING_CONFIG["validation_subset_ratio"],
        test_subset_ratio=TRAINING_CONFIG["test_subset_ratio"]
    )
    
    console.print("[green]✅ BaseTrainer initialized successfully[/green]")
    
except Exception as e:
    console.print(f"[red]❌ Error initializing BaseTrainer: {e}[/red]")
    console.print("[yellow]💡 Trying with fallback configuration...[/yellow]")
    
    # Fallback configuration
    base_trainer = BaseTrainer(
        model_name=TRAINING_CONFIG["model_name"],
        output_dir=TRAINING_CONFIG["output_dir"],
        hyperopt=False,  # Disable hyperopt for fallback
        use_wandb=False  # Disable wandb for fallback
    )

# Train the model with error handling
ENABLE_HYPEROPT = TRAINING_CONFIG["run_optimization"]  # Hyperparameter-search switch from TRAINING_CONFIG
N_TRIALS = 10  # Number of Optuna trials (assumed default; increase for a broader search)

try:
    console.print(f"[blue]🖥️  Available GPUs: {device_count}[/blue]")
    # final_batch_size is already the per-GPU batch size computed above;
    # dividing it by device_count again would shrink the effective batch size twice

    if ENABLE_HYPEROPT:
        console.print("[yellow]🔬 Running hyperparameter optimization (this will take longer)...[/yellow]")
        study = base_trainer.optimize_hyperparameters(
            n_trials=N_TRIALS,
            max_epochs=TRAINING_CONFIG["max_epochs"],
            base_batch_size=final_batch_size
        )
        
        if study is not None:
            console.print(f"[green]✅ Best parameters: {study.best_params}[/green]")
            console.print(f"[green]✅ Best F0.5 score: {study.best_value:.4f}[/green]")
            
            # Train final model with best parameters  
            console.print("[blue]🏃 Training final model with best parameters...[/blue]")
            trained_model = base_trainer.train_with_params(
                params=study.best_params,
                max_epochs=TRAINING_CONFIG["max_epochs"],
                batch_size=study.best_params.get('batch_size', final_batch_size)
            )
        else:
            console.print("[yellow]⚠️  Hyperparameter optimization failed, using default training[/yellow]")
            trained_model = base_trainer.train(
                max_epochs=TRAINING_CONFIG["max_epochs"],
                batch_size=final_batch_size
            )
    else:
        console.print("[blue]🏃 Training with default parameters...[/blue]")
        
        # Train the model (hyperopt is controlled by the hyperopt parameter in constructor)
        trained_model = base_trainer.train(
            max_epochs=TRAINING_CONFIG["max_epochs"],
            batch_size=final_batch_size,
            search_space=None  # No search space needed for default training
        )

    console.print("[bold green]✅ Base model training completed![/bold green]")
    
except Exception as e:
    error_msg = str(e)
    console.print(f"[red]❌ Training failed: {error_msg}[/red]")
    
    # Specific error handling for common issues
    if "find_unused_parameters" in error_msg:
        console.print("[yellow]🔧 Trainer argument error detected - this has been fixed in latest code[/yellow]")
        console.print("[blue]💡 The error is caused by DDP strategy arguments being passed to Trainer[/blue]")
        console.print("[green]✅ Latest code filters these arguments automatically[/green]")
    elif "CUDA out of memory" in error_msg:
        console.print("[yellow]💾 GPU memory error - try reducing batch size[/yellow]")
        console.print(f"[blue]Current batch size: {final_batch_size}[/blue]")
        console.print("[blue]💡 Try setting final_batch_size = 2 or 1[/blue]")
    elif "No module named" in error_msg:
        console.print("[yellow]📦 Missing dependency - check package installation[/yellow]")
        console.print("[blue]💡 Try re-running the pip install cell above[/blue]")
    else:
        console.print("[yellow]💡 This might be due to memory constraints or configuration issues[/yellow]")
        console.print("[blue]🔧 Try reducing batch size or disabling hyperopt[/blue]")
    
    # Show the full traceback for debugging
    console.print_exception()
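The multi-GPU logic in the training cell splits a global batch across devices. The sketch below isolates that arithmetic, assuming data-parallel training where each GPU processes `per_device` examples per step; `split_batch` is a hypothetical helper. It shows why integer division can shrink the effective batch and why the split should be applied only once.

```python
def split_batch(global_batch: int, num_devices: int) -> tuple[int, int]:
    """Return (per-device batch, effective global batch) for data-parallel training."""
    num_devices = max(1, num_devices)  # CPU-only or single GPU -> no split
    per_device = max(1, global_batch // num_devices)
    return per_device, per_device * num_devices


for gpus in (1, 2, 3):
    per_dev, effective = split_batch(8, gpus)
    print(f"{gpus} GPU(s): per-device={per_dev}, effective={effective}")
# 1 GPU(s): per-device=8, effective=8
# 2 GPU(s): per-device=4, effective=8
# 3 GPU(s): per-device=2, effective=6
```

Applying the split twice on two GPUs (8 → 4 → 2) would leave an effective batch of 4 instead of 8, which also changes the optimal learning rate.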

## 🎯 Step 5: Negative Sample Generation

In [None]:
# Generate negative samples for contrastive learning
from negative_sampler import NegativeSampler
import os

console.print("[bold blue]🎯 Generating Negative Samples...[/bold blue]")

# Create negative sampler (use the final model from training)
base_model_path = os.path.join(TRAINING_CONFIG["output_dir"], "final_model")

# Check if trained model exists
if os.path.exists(base_model_path):
    console.print(f"[green]✅ Using trained model from {base_model_path}[/green]")
    model_path = base_model_path
else:
    console.print(f"[yellow]⚠️ Trained model not found, using base model {MODEL_NAME}[/yellow]")
    model_path = MODEL_NAME

negative_sampler = NegativeSampler(
    model_path=model_path,
    model_name=MODEL_NAME
)

# Generate negative samples for training data
# Use smaller subset for Colab to avoid memory issues
train_subset = data['train'][:1000] if len(data['train']) > 1000 else data['train']

contrastive_data = negative_sampler.generate_contrastive_dataset(
    data=train_subset,
    num_negatives=3,  # Generate 3 negative samples per positive
    output_file="./data/contrastive_train.json"
)

console.print(f"[green]✅ Generated {len(contrastive_data)} contrastive samples![/green]")
console.print("[blue]💾 Saved to ./data/contrastive_train.json[/blue]")
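The contrastive data pairs each source with its gold correction (`positive`) and a list of `negatives`, the same layout used for the validation file in Step 6. The sketch below assembles one such record, assuming negatives come from candidate corrections (for example, beam-search outputs of the base model) that differ from the gold target. `build_contrastive_record` is hypothetical, and `NegativeSampler` may select negatives differently.

```python
import json


def build_contrastive_record(source: str, target: str, candidates: list[str],
                             num_negatives: int = 3) -> dict:
    """Assemble one {'source', 'positive', 'negatives'} record (hypothetical helper)."""
    negatives: list[str] = []
    for cand in candidates:
        if len(negatives) >= num_negatives:
            break
        cand = cand.strip()
        # Skip empty strings, exact copies of the gold correction, and duplicates
        if cand and cand != target and cand not in negatives:
            negatives.append(cand)
    # Fall back to the uncorrected source, a valid negative whenever it differs from the target
    if len(negatives) < num_negatives and source != target and source not in negatives:
        negatives.append(source)
    return {"source": source, "positive": target, "negatives": negatives}


record = build_contrastive_record(
    source="Tôi đang học tiếng việt",
    target="Tôi đang học tiếng Việt",
    candidates=[
        "Tôi đang học tiếng Việt",   # identical to the target -> dropped
        "Tôi đang học Tiếng việt",
        "Tôi đang học tiếng việt.",
    ],
)
print(json.dumps(record, ensure_ascii=False, indent=2))
```

Filtering out exact copies of the target matters: a "negative" identical to the positive would push the model away from the correct answer.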

## 🔥 Step 6: Contrastive Learning Training

In [None]:
# Contrastive learning training
from contrastive_trainer import ContrastiveTrainer
import os
import json
import shutil

console.print("[bold blue]🔥 Starting Contrastive Learning Training...[/bold blue]")

# First, we need to prepare the contrastive data in the expected format
contrastive_data_dir = "./data/contrastive"
os.makedirs(contrastive_data_dir, exist_ok=True)

# Build a small validation set in the contrastive format (source / positive / negatives)
validation_contrastive = []
for item in data['validation'][:200]:  # Use the first 200 validation examples
    validation_contrastive.append({
        'source': item['source'],
        'positive': item['target'],
        'negatives': [item['source']]  # The uncorrected source sentence serves as a simple negative
    })

# Save validation data
with open(os.path.join(contrastive_data_dir, "validation_contrastive.json"), "w", encoding="utf-8") as f:
    json.dump(validation_contrastive, f, indent=2, ensure_ascii=False)

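# Copy the training contrastive data generated in Step 5 to the expected location
if os.path.exists("./data/contrastive_train.json"):
    shutil.copy("./data/contrastive_train.json", 
                os.path.join(contrastive_data_dir, "train_contrastive.json"))
else:
    console.print("[red]❌ ./data/contrastive_train.json not found; run the negative sample generation step first[/red]")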

# Create contrastive trainer
contrastive_trainer = ContrastiveTrainer(
    base_model_path=os.path.join(TRAINING_CONFIG["output_dir"], "final_model"),
    contrastive_data_dir=contrastive_data_dir,
    output_dir="./models/contrastive_model",
    hyperopt=False  # Disable hyperopt for faster training in Colab
)

# Train with contrastive learning
contrastive_trainer.train()

console.print("[bold green]✅ Contrastive learning training completed![/bold green]")
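Step 6 trains with a contrastive loss plus R-Drop. As a rough illustration of those two terms (not necessarily what `contrastive_trainer.py` implements), the sketch below uses an InfoNCE-style loss over sequence scores, where the positive should score higher than the negatives, and the R-Drop regularizer, a symmetric KL divergence between the output distributions of two dropout-perturbed forward passes. The temperature, the scoring function, the loss weights and the toy numbers are all assumptions.

```python
import math


def info_nce(pos_score: float, neg_scores: list[float], temperature: float = 1.0) -> float:
    """InfoNCE-style loss: -log of the softmax probability assigned to the positive."""
    logits = [s / temperature for s in [pos_score, *neg_scores]]
    m = max(logits)  # Subtract the max before exp() for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[0]


def symmetric_kl(p: list[float], q: list[float], eps: float = 1e-12) -> float:
    """R-Drop term: 0.5 * (KL(p || q) + KL(q || p)) between two output distributions."""
    kl_pq = sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))
    kl_qp = sum(qi * math.log((qi + eps) / (pi + eps)) for pi, qi in zip(p, q))
    return 0.5 * (kl_pq + kl_qp)


# Toy numbers: scores could be length-normalized log-likelihoods of each candidate correction
loss_cl = info_nce(pos_score=-0.2, neg_scores=[-1.5, -2.0, -0.9], temperature=0.5)
# Two dropout-perturbed passes over the same input give slightly different token distributions
loss_rdrop = symmetric_kl([0.7, 0.2, 0.1], [0.6, 0.3, 0.1])

alpha, beta = 1.0, 1.0  # Assumed weights; these terms would be added to the usual cross-entropy loss
aux_loss = alpha * loss_cl + beta * loss_rdrop
print(f"contrastive={loss_cl:.4f}  r-drop={loss_rdrop:.4f}  auxiliary total={aux_loss:.4f}")
```

A lower temperature sharpens the softmax, so the loss penalizes negatives that score close to the positive more strongly.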

## 🧪 Step 7: Inference and Evaluation

In [None]:
# Load the best model for inference with improved handling
from inference import GECInference
import os

console.print("[bold blue]🧪 Setting up Inference...[/bold blue]")

# Determine which model to use for inference
contrastive_model_path = "./models/contrastive_model"
base_model_path = TRAINING_CONFIG["output_dir"]

# Check for final_model subdirectory
final_model_path = os.path.join(base_model_path, "final_model")

if os.path.exists(contrastive_model_path) and os.listdir(contrastive_model_path):
    model_path = contrastive_model_path
    console.print(f"[green]✅ Using contrastive model from {model_path}[/green]")
elif os.path.exists(final_model_path) and os.listdir(final_model_path):
    model_path = final_model_path
    console.print(f"[yellow]⚠️ Using base model (final) from {model_path}[/yellow]")
elif os.path.exists(base_model_path) and os.listdir(base_model_path):
    model_path = base_model_path
    console.print(f"[yellow]⚠️ Using base model from {model_path}[/yellow]")
else:
    model_path = MODEL_NAME
    console.print(f"[blue]ℹ️ Using original model {model_path}[/blue]")

# Debug: Check what's in the model directory
if model_path != MODEL_NAME:
    console.print(f"[dim]📁 Model directory contents:[/dim]")
    try:
        for item in os.listdir(model_path):
            console.print(f"  - {item}")
    except OSError:
        console.print("  [red]Could not list directory[/red]")

# Create inference engine with proper error handling
try:
    console.print(f"[yellow]🔄 Loading inference engine from {model_path}...[/yellow]")
    
    # GECInference takes model_path only (it has no model_name parameter)
    gec_inference = GECInference(
        model_path=model_path,
        use_contrastive_search=False,  # Disabled for now to match the decoding used during training
        device="auto"
    )
    
    console.print("[green]✅ Inference engine ready![/green]")
    
    # Set up the task prefix (ViT5/mT5 models) before testing, so the test uses the same input format.
    # This only has an effect if GECInference reads tokenizer.task_prefix.
    if hasattr(gec_inference.tokenizer, 'task_prefix'):
        console.print(f"[blue]🏷️  Task prefix detected: '{gec_inference.tokenizer.task_prefix}'[/blue]")
    elif 'vit5' in MODEL_NAME.lower() or 'mt5' in MODEL_NAME.lower():
        # Manually set task prefix for ViT5/mT5 models
        gec_inference.tokenizer.task_prefix = "grammar: "
        console.print(f"[yellow]🏷️  Manually set task prefix: '{gec_inference.tokenizer.task_prefix}'[/yellow]")
    else:
        console.print("[blue]ℹ️  No task prefix needed for this model[/blue]")
    
    # Test the inference with a simple example
    test_input = "Tôi đang học tiếng việt."
    console.print(f"[blue]🧪 Testing inference:[/blue]")
    console.print(f"  Input: {test_input}")
    
    try:
        test_output = gec_inference.correct_text(test_input)
        console.print(f"  Output: {test_output}")
        console.print("[green]✅ Inference test successful![/green]")
    except Exception as e:
        console.print(f"[red]❌ Inference test failed: {e}[/red]")
        console.print("[yellow]💡 This might indicate a model loading issue[/yellow]")

except Exception as e:
    console.print(f"[red]❌ Failed to load inference engine: {e}[/red]")
    console.print("[yellow]💡 Falling back to original model...[/yellow]")
    
    # Fallback: retry with the original pretrained model (model_path only, no model_name)
    try:
        gec_inference = GECInference(
            model_path=MODEL_NAME,
            use_contrastive_search=False,
            device="auto"
        )
        console.print("[yellow]⚠️ Using fallback original model[/yellow]")
    except Exception as fallback_error:
        console.print(f"[red]❌ Fallback also failed: {fallback_error}[/red]")
        console.print("[red]💥 Cannot proceed with inference. Check model and dependencies.[/red]")
        
        # Show the full traceback for debugging
        console.print_exception()
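ViT5-style models are usually fine-tuned with a text prefix, and that prefix must be applied identically at training and inference time. Below is a minimal sketch of an idempotent helper, assuming the `"grammar: "` prefix and the `vit5`/`mt5` model-name check used in the cell above. `add_task_prefix` is hypothetical and not part of the repository.

```python
def add_task_prefix(text: str, model_name: str, prefix: str = "grammar: ") -> str:
    """Prepend the task prefix for T5-family models without doubling it (hypothetical helper)."""
    name = model_name.lower()
    if "vit5" not in name and "mt5" not in name:
        return text  # Other models (e.g. BARTpho) get no prefix in this notebook
    return text if text.startswith(prefix) else prefix + text


print(add_task_prefix("Tôi đang học tiếng việt.", "VietAI/vit5-base"))
# grammar: Tôi đang học tiếng việt.
print(add_task_prefix("Tôi đang học tiếng việt.", "vinai/bartpho-syllable"))
# Tôi đang học tiếng việt.
```

Routing both the training collator and the inference call through one such function removes the train/inference prefix mismatch discussed in the debugging cell.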

In [None]:
# Debug: Investigate training vs evaluation discrepancy
console.print("[bold red]🔍 DEBUGGING: Training vs Evaluation Discrepancy[/bold red]")

console.print("[yellow]📊 Analysis of the 91% vs 34% discrepancy:[/yellow]")
console.print("1. Training loss ≠ F0.5 evaluation metric")
console.print("2. Different preprocessing during training vs inference")
console.print("3. Task prefix issues with ViT5/mT5 models")
console.print("4. Tokenization differences")

# Test 1: Check training data format vs inference format
console.print("\n[blue]🧪 Test 1: Data Format Comparison[/blue]")
if 'data' in locals():
    sample_source = data['validation'][0]['source']
    sample_target = data['validation'][0]['target']
    
    console.print(f"Training source: '{sample_source}'")
    console.print(f"Training target: '{sample_target}'")
    
    # Test how inference processes this
    inference_result = gec_inference.correct_text(sample_source)
    console.print(f"Inference result: '{inference_result}'")
    
    # Check if task prefix was used during training
    if hasattr(gec_inference.tokenizer, 'task_prefix'):
        prefix = gec_inference.tokenizer.task_prefix
        console.print(f"Task prefix: '{prefix}'")
        
        # Test with manual prefix
        manual_prefix_input = prefix + sample_source
        console.print(f"Manual prefix input: '{manual_prefix_input}'")
        
        # Check tokenization
        tokens_without_prefix = gec_inference.tokenizer(sample_source, return_tensors="pt")
        tokens_with_prefix = gec_inference.tokenizer(manual_prefix_input, return_tensors="pt")
        
        console.print(f"Tokens without prefix length: {tokens_without_prefix['input_ids'].shape}")
        console.print(f"Tokens with prefix length: {tokens_with_prefix['input_ids'].shape}")

# Test 2: Check evaluation metric calculation
console.print("\n[blue]🧪 Test 2: Evaluation Metric Check[/blue]")
from evaluator import F05Evaluator

# Create evaluator
eval_test = F05Evaluator()

# Test on a simple example where we know the answer
test_source = "Tôi đang học tiếng việt"
test_target = "Tôi đang học tiếng Việt"  # Simple capitalization fix
test_prediction = "Tôi đang học tiếng Việt"  # Perfect prediction

f05_perfect = eval_test.calculate_f05(test_source, test_prediction, test_target)
console.print(f"Perfect prediction F0.5: {f05_perfect:.4f}")

# Test with a wrong prediction
test_prediction_wrong = "Tôi đang học tiếng anh"  # Wrong prediction
f05_wrong = eval_test.calculate_f05(test_source, test_prediction_wrong, test_target)
console.print(f"Wrong prediction F0.5: {f05_wrong:.4f}")

# Test with no change (common issue)
test_prediction_no_change = test_source  # No change
f05_no_change = eval_test.calculate_f05(test_source, test_prediction_no_change, test_target)
console.print(f"No change prediction F0.5: {f05_no_change:.4f}")

console.print("\n[yellow]💡 Common issues that cause this discrepancy:[/yellow]")
console.print("1. Model not making enough changes (conservative predictions)")
console.print("2. Task prefix missing during inference")
console.print("3. Different loss function vs evaluation metric")
console.print("4. Model trained on different data format than inference expects")

# Test 3: Loss vs F0.5 relationship
console.print("\n[blue]🧪 Test 3: Understanding the Loss-F0.5 Gap[/blue]")
console.print("Training loss is cross-entropy loss on token predictions")
console.print("F0.5 measures edit-level precision/recall at word level")
console.print("A model can have low token-level loss but poor edit-level performance")

console.print("\n[red]🎯 RECOMMENDATIONS TO FIX:[/red]")
console.print("1. Ensure task prefix is used consistently")
console.print("2. Check model is actually making corrections (not just copying)")
console.print("3. Verify training data preprocessing matches inference")
console.print("4. Monitor F0.5 on validation for checkpoint selection (it is not differentiable, so it cannot replace the cross-entropy loss)")

# Evaluate on test set with consistency checking
from evaluator import F05Evaluator
import numpy as np

console.print("[bold blue]📊 Evaluating on Test Set...[/bold blue]")

# Important note about validation vs evaluation consistency
console.print("[bold yellow]⚠️  Important Note: Validation Consistency[/bold yellow]")
console.print("[blue]If training validation shows high metrics (e.g., 91%) but evaluation shows low metrics (e.g., 34%),[/blue]")
console.print("[blue]this is usually due to differences in:[/blue]")
console.print("[blue]  1. Task prefix handling (ViT5/mT5 models need 'grammar: ' prefix)[/blue]")
console.print("[blue]  2. Generation parameters (num_beams, etc.)[/blue]")
console.print("[blue]  3. Validation subset vs full evaluation[/blue]")
console.print("[green]✅ Latest code fixes these inconsistencies in base_trainer.py[/green]")

# Create evaluator; fall back if F05Evaluator does not accept a tokenizer argument
try:
    evaluator = F05Evaluator(tokenizer=gec_inference.tokenizer)
except TypeError:
    # Fallback to no tokenizer
    evaluator = F05Evaluator()

# Evaluate on test set (using subset for faster evaluation)
test_data_subset = data['test'][:100]  # Use 100 samples for evaluation
sources = [item['source'] for item in test_data_subset]
references = [item['target'] for item in test_data_subset]

# Generate predictions
console.print("[yellow]🔮 Generating predictions...[/yellow]")
predictions = []
for i, source in enumerate(sources):
    if i % 20 == 0:  # Progress indicator
        console.print(f"[blue]Processing {i+1}/{len(sources)}...[/blue]")
    pred = gec_inference.correct_text(source)
    predictions.append(pred)

# Calculate metrics
console.print("[yellow]📈 Calculating metrics...[/yellow]")
try:
    # Try batch evaluation first
    results = evaluator.evaluate_batch(predictions, references, sources)
except AttributeError:
    # Fallback to individual evaluation
    f05_scores = []
    for pred, ref, src in zip(predictions, references, sources):
        f05 = evaluator.calculate_f05(src, pred, ref)
        f05_scores.append(f05)
    
    results = {
        "f05_score": np.mean(f05_scores),
        "num_samples": len(f05_scores)
    }

console.print("\n[bold green]📈 Evaluation Results:[/bold green]")
for metric, value in results.items():
    if isinstance(value, float):
        console.print(f"  {metric}: {value:.4f}")
    else:
        console.print(f"  {metric}: {value}")

# Validation consistency check
avg_f05 = results.get("f05_score", 0.0)
console.print(f"\n[bold cyan]🔍 Consistency Analysis:[/bold cyan]")
console.print(f"[blue]Evaluation F0.5 Score: {avg_f05:.4f} ({avg_f05*100:.1f}%)[/blue]")

if avg_f05 < 0.5:  # Less than 50%
    console.print("[yellow]⚠️  Low evaluation score detected![/yellow]")
    console.print("[blue]💡 If training validation showed much higher scores, this suggests:[/blue]")
    console.print("   1. Task prefix was missing during training validation")
    console.print("   2. Generation parameters differed between training and inference") 
    console.print("   3. Training validation used subset, evaluation uses different data")
    console.print("[green]✅ Use the updated base_trainer.py for consistent metrics[/green]")
elif avg_f05 > 0.8:  # Greater than 80%
    console.print("[green]✅ Good evaluation score! Training validation was likely accurate.[/green]")
else:
    console.print("[yellow]📊 Moderate evaluation score. Check if this matches training validation.[/yellow]")

# Show some examples
console.print("\n[bold cyan]🔍 Sample Results:[/bold cyan]")
for i in range(min(5, len(sources))):
    f05_single = evaluator.calculate_f05(sources[i], predictions[i], references[i])
    console.print(f"\n{i+1}. Source: {sources[i]}")
    console.print(f"   Target: {references[i]}")
    console.print(f"   Prediction: {predictions[i]}")
    console.print(f"   F0.5 Score: {f05_single:.4f}")

# Additional debugging option
console.print("\n[bold blue]🔧 Debug Option:[/bold blue]")
console.print("[yellow]To check validation consistency, run:[/yellow]")
console.print("[dim]python validation_consistency_checker.py[/dim]")
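The F0.5 checks above (perfect, wrong and unchanged predictions) follow from the definition F_β = (1 + β²)·P·R / (β²·P + R) with β = 0.5, which weights precision more heavily than recall. The self-contained sketch below assumes edits are extracted with a simple word-level `difflib` diff rather than the alignment `F05Evaluator` actually uses, and treats precision as 1.0 when no edits are proposed (scorers differ on this convention). `word_edits` and `f_beta` are hypothetical helpers.

```python
from difflib import SequenceMatcher


def word_edits(source: str, corrected: str) -> set[tuple[int, int, str]]:
    """Word-level edits as (start, end, replacement) spans over the source tokens."""
    src, cor = source.split(), corrected.split()
    matcher = SequenceMatcher(a=src, b=cor, autojunk=False)
    return {
        (i1, i2, " ".join(cor[j1:j2]))
        for tag, i1, i2, j1, j2 in matcher.get_opcodes()
        if tag != "equal"
    }


def f_beta(source: str, prediction: str, target: str, beta: float = 0.5) -> float:
    pred_edits = word_edits(source, prediction)
    gold_edits = word_edits(source, target)
    tp = len(pred_edits & gold_edits)  # Proposed edits that match the reference
    fp = len(pred_edits - gold_edits)  # Proposed edits not in the reference
    fn = len(gold_edits - pred_edits)  # Reference edits the model missed
    precision = tp / (tp + fp) if tp + fp else 1.0
    recall = tp / (tp + fn) if tp + fn else 1.0
    if precision + recall == 0:
        return 0.0
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


src, ref = "Tôi đang học tiếng việt", "Tôi đang học tiếng Việt"
print(f_beta(src, ref, ref))                       # perfect correction -> 1.0
print(f_beta(src, "Tôi đang học tiếng anh", ref))  # wrong edit -> 0.0
print(f_beta(src, src, ref))                       # copy of the input -> 0.0
```

A copy of the input scores 0 because recall is 0, even though copying is cheap under token-level cross-entropy. This is why a conservative model can combine a low training loss with a low F0.5.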

In [None]:
# Interactive testing
console.print("[bold blue]🎮 Interactive Testing[/bold blue]")

# Test samples
test_sentences = [
    "Tôi đang học tiếng việt ở trường đại học.",
    "Hôm nay trời rất đẹp và tôi muốn đi chơi.",
    "Cô ấy làm việc tại một công ty lớn ở Hà Nội.",
    "Chúng tôi sẽ đi du lịch vào cuối tuần này."
]

console.print("\n[yellow]📝 Test Results:[/yellow]")
for i, sentence in enumerate(test_sentences, 1):
    corrected = gec_inference.correct_text(sentence)
    console.print(f"\n{i}. Original: {sentence}")
    console.print(f"   Corrected: {corrected}")

# Custom input (empty lines are skipped)
console.print("\n[bold cyan]✏️ Try your own text:[/bold cyan]")
print("Enter Vietnamese text to correct (or 'quit' to exit):")

while True:
    user_input = input("> ").strip()
    if user_input.lower() == 'quit':
        break
    if not user_input:
        continue
    
    corrected = gec_inference.correct_text(user_input)
    print(f"Corrected: {corrected}\n")

In [None]:
# Comprehensive evaluation with all metrics
from evaluator import GECEvaluator
import numpy as np

console.print("[bold blue]📊 Comprehensive Evaluation with All Metrics[/bold blue]")

console.print("[yellow]🔍 Metrics to be calculated:[/yellow]")
console.print("  • F0.5 (Edit-level, precision-weighted)")
console.print("  • Precision & Recall (Edit-level)")  
console.print("  • BLEU (Traditional n-gram overlap)")
console.print("  • GLEU (Better for GEC tasks)")
console.print("  • ROUGE-1, ROUGE-2, ROUGE-L (Token overlap)")
console.print("  • Input-preserving Edit Ratio")

# Create comprehensive evaluator
evaluator = GECEvaluator(tokenizer=gec_inference.tokenizer)

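# Evaluate on a 100-sample subset of the test set for speed.
# Rows are fetched one index at a time so this works whether data['test'] is a list of
# dicts or a Hugging Face Dataset (slicing a Dataset returns a dict of columns, not rows).
n_eval = min(100, len(data['test']))
test_data_subset = [data['test'][i] for i in range(n_eval)]
sources = [item['source'] for item in test_data_subset]
references = [item['target'] for item in test_data_subset]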

console.print(f"[blue]📋 Evaluating on {len(sources)} samples...[/blue]")

# Generate predictions
console.print("[yellow]🔮 Generating predictions...[/yellow]")
predictions = []
for i, source in enumerate(sources):
    if i % 20 == 0:  # Progress indicator
        console.print(f"[blue]Processing {i+1}/{len(sources)}...[/blue]")
    pred = gec_inference.correct_text(source)
    predictions.append(pred)

# Calculate all metrics
console.print("[yellow]📈 Calculating comprehensive metrics...[/yellow]")
comprehensive_results = evaluator.calculate_all_metrics(
    sources=sources,
    predictions=predictions,
    targets=references,
    print_results=True  # This will print a nice table
)

# Additional analysis
console.print("\n[bold cyan]🔍 Detailed Analysis:[/bold cyan]")

# Show distribution of F0.5 scores
from evaluator import F05Evaluator
f05_evaluator = F05Evaluator()
individual_f05_scores = []
for src, pred, ref in zip(sources, predictions, references):
    f05 = f05_evaluator.calculate_f05(src, pred, ref)
    individual_f05_scores.append(f05)

f05_array = np.array(individual_f05_scores)
console.print(f"[blue]📊 F0.5 Score Distribution:[/blue]")
console.print(f"  Mean: {f05_array.mean():.4f}")
console.print(f"  Std:  {f05_array.std():.4f}")
console.print(f"  Min:  {f05_array.min():.4f}")
console.print(f"  Max:  {f05_array.max():.4f}")

# Count perfect predictions
perfect_predictions = sum(1 for f05 in individual_f05_scores if f05 == 1.0)
console.print(f"  Perfect predictions: {perfect_predictions}/{len(individual_f05_scores)} ({perfect_predictions/len(individual_f05_scores)*100:.1f}%)")

# Count no-change predictions
no_change_predictions = sum(1 for src, pred in zip(sources, predictions) if src.strip() == pred.strip())
console.print(f"  No-change predictions: {no_change_predictions}/{len(sources)} ({no_change_predictions/len(sources)*100:.1f}%)")

# Show some examples categorized by performance
console.print("\n[bold cyan]🔍 Sample Results by Performance:[/bold cyan]")

# Sort examples by F0.5 score
examples_with_scores = list(zip(sources, predictions, references, individual_f05_scores))
examples_with_scores.sort(key=lambda x: x[3], reverse=True)

# Best examples
console.print("\n[green]🏆 Best Examples (F0.5 ≥ 0.8):[/green]")
best_examples = [ex for ex in examples_with_scores if ex[3] >= 0.8]
for i, (src, pred, ref, score) in enumerate(best_examples[:3]):
    console.print(f"{i+1}. Source:     {src}")
    console.print(f"   Target:     {ref}")
    console.print(f"   Prediction: {pred}")
    console.print(f"   F0.5:       {score:.4f}")
    console.print()

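# Worst examples: examples_with_scores is sorted best-first, so re-sort ascending
# to show the lowest-scoring cases rather than the best of the low scorers
console.print("[red]🔍 Challenging Examples (F0.5 ≤ 0.2):[/red]")
worst_examples = sorted((ex for ex in examples_with_scores if ex[3] <= 0.2), key=lambda x: x[3])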
for i, (src, pred, ref, score) in enumerate(worst_examples[:3]):
    console.print(f"{i+1}. Source:     {src}")
    console.print(f"   Target:     {ref}")
    console.print(f"   Prediction: {pred}")
    console.print(f"   F0.5:       {score:.4f}")
    console.print()

# Save comprehensive results
import json
comprehensive_results['individual_f05_scores'] = individual_f05_scores
comprehensive_results['examples'] = {
    'best': [{'source': ex[0], 'prediction': ex[1], 'target': ex[2], 'f05': ex[3]} 
             for ex in best_examples[:5]],
    'worst': [{'source': ex[0], 'prediction': ex[1], 'target': ex[2], 'f05': ex[3]} 
              for ex in worst_examples[:5]]
}

console.print("\n[blue]💾 Comprehensive results stored in the variable 'comprehensive_results' (in memory only, not yet written to disk)[/blue]")
console.print("[green]✅ Evaluation completed with all metrics![/green]")
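The F0.5 distribution in the cell above is a macro-average: each sentence gets its own F0.5, and the scores are then averaged. The M2 scorer used in GEC shared tasks instead pools true positives, false positives and false negatives over the whole corpus and computes one F0.5 (a micro-average). The two can diverge sharply. Under the convention below, a sentence that needs no correction and is left untouched contributes a perfect 1.0 to the mean but nothing to the pooled counts. This notebook does not show whether `calculate_all_metrics` reports the pooled or the averaged figure. The sketch below therefore only illustrates the difference. It uses hypothetical helper names and assumes per-sentence `(tp, fp, fn)` counts are already available.

```python
from typing import List, Tuple

# Per-sentence edit counts: (true positives, false positives, false negatives)
Counts = Tuple[int, int, int]


def f_beta(tp: int, fp: int, fn: int, beta: float = 0.5) -> float:
    """F-beta from edit counts; no proposed edits -> precision 1, no gold edits -> recall 1."""
    precision = tp / (tp + fp) if tp + fp else 1.0
    recall = tp / (tp + fn) if tp + fn else 1.0
    if precision + recall == 0:
        return 0.0
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


def corpus_f05(counts: List[Counts]) -> float:
    """Micro-average: pool the counts over all sentences, then score once."""
    tp = sum(c[0] for c in counts)
    fp = sum(c[1] for c in counts)
    fn = sum(c[2] for c in counts)
    return f_beta(tp, fp, fn)


def mean_sentence_f05(counts: List[Counts]) -> float:
    """Macro-average: score each sentence separately, then take the mean."""
    if not counts:
        raise ValueError("need at least one sentence")
    return sum(f_beta(*c) for c in counts) / len(counts)


# Two sentences correctly left unchanged, one where the model makes 2 wrong edits and misses 3
counts = [(0, 0, 0), (0, 0, 0), (0, 2, 3)]
print(corpus_f05(counts))         # 0.0
print(mean_sentence_f05(counts))  # 0.6666666666666666
```

The same predictions score 0.0 pooled and about 0.67 averaged. When comparing against published results, check which of the two is being reported.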

## 💾 Step 8: Save and Export Results
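
The export cell below collects weight and `config.json` files under `./models`. A checkpoint that can be reloaded with `from_pretrained` also needs its tokenizer files. The sketch below is one way to package a whole checkpoint tree with paths stored relative to a chosen root, so the archive does not embed absolute Colab paths. `zip_checkpoint` and the pattern list are illustrative assumptions, not part of the repository.

```python
import zipfile
from pathlib import Path
from typing import Iterable, List

# Illustrative file patterns: weights in either format, configs, and tokenizer assets
# (SentencePiece tokenizers store *.model files; some vocabularies are plain *.txt)
DEFAULT_PATTERNS = ("*.safetensors", "*.bin", "*.json", "*.model", "*.txt")


def zip_checkpoint(root: str, zip_path: str, patterns: Iterable[str] = DEFAULT_PATTERNS) -> List[str]:
    """Zip matching files under `root`, storing paths relative to root's parent directory."""
    root_dir = Path(root).resolve()
    added: List[str] = []
    seen = set()
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for pattern in patterns:
            for file in sorted(root_dir.rglob(pattern)):
                if not file.is_file():
                    continue
                # e.g. "models/base/config.json" instead of an absolute path
                arcname = file.relative_to(root_dir.parent).as_posix()
                if arcname in seen:
                    continue
                zf.write(file, arcname=arcname)
                seen.add(arcname)
                added.append(arcname)
    return added
```

For example, `zip_checkpoint("./models", "vietnamese_gec_models.zip")` packages everything under `./models` and skips optimizer states such as `optimizer.pt`. Because mode `"w"` overwrites the archive, write a separate zip if the base model lives outside `./models` (for example under `TRAINING_CONFIG["output_dir"]`).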

In [None]:
# Save results and create export package
import json
import os
import zipfile
from datetime import datetime

console.print("[bold blue]💾 Saving Results and Creating Export Package...[/bold blue]")

# Create results summary
results_summary = {
    "timestamp": datetime.now().isoformat(),
    "model_name": MODEL_NAME,
    "training_config": TRAINING_CONFIG,
    "evaluation_results": results,
    "test_samples": len(test_data),
    "model_paths": {
        "base_model": TRAINING_CONFIG["output_dir"],
        "contrastive_model": "./models/contrastive_model"
    }
}

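# Save results (tolist() converts NumPy/PyTorch values that json cannot serialize; anything else falls back to str)
with open("./results_summary.json", "w", encoding="utf-8") as f:
    json.dump(results_summary, f, indent=2, ensure_ascii=False,
              default=lambda o: o.tolist() if hasattr(o, "tolist") else str(o))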

console.print("[green]✅ Results saved to ./results_summary.json[/green]")

# Create downloadable package
console.print("[yellow]📦 Creating export package...[/yellow]")

with zipfile.ZipFile("vietnamese_gec_models.zip", "w", zipfile.ZIP_DEFLATED) as zipf:
    # Add results
    zipf.write("results_summary.json")
    
    # Add model weights and configs (if they exist). Recent transformers versions
    # save weights as *.safetensors by default, so collect both weight formats.
    import glob
    for pattern in ("*.bin", "*.safetensors", "config.json"):
        for model_file in glob.glob(f"./models/**/{pattern}", recursive=True):
            zipf.write(model_file)
    
    # Add data samples
    if os.path.exists("./data/contrastive_train.json"):
        zipf.write("./data/contrastive_train.json")

console.print("[bold green]🎉 Export package created: vietnamese_gec_models.zip[/bold green]")
console.print("[blue]📁 You can download this file from the Colab file browser[/blue]")

# Display final summary
console.print("\n[bold cyan]🏆 Training Pipeline Completed Successfully![/bold cyan]")
console.print(f"[green]✅ Base model trained and saved[/green]")
console.print(f"[green]✅ Contrastive learning applied[/green]")
console.print(f"[green]✅ Model evaluated on test set[/green]")
console.print(f"[green]✅ Results exported for download[/green]")