# Vietnamese GEC with Contrastive Learning - Complete Pipeline

This notebook implements the end-to-end pipeline for training Vietnamese Grammatical Error Correction (GEC) models with contrastive learning, following the method described in the research paper.

## Pipeline Overview:
1. **Data Preparation** - Load and preprocess viGEC dataset
2. **Base Model Training** - Fine-tune BARTpho/ViT5 with hyperparameter optimization
3. **Negative Sample Generation** - Generate negative samples for contrastive learning
4. **Contrastive Learning Training** - Train with contrastive loss + R-Drop
5. **Inference** - Use contrastive search for generation
6. **Evaluation** - Comprehensive evaluation with F0.5, BLEU, IE/OE analysis
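
Step 6 reports F0.5, which weights precision more heavily than recall. The snippet below is a minimal sketch of that computation from edit-level counts. The function name `f_beta` and the example counts are illustrative assumptions and are not taken from the repository's `evaluator.py`.

```python
def f_beta(tp: int, fp: int, fn: int, beta: float = 0.5) -> float:
    """F-beta score from true-positive, false-positive and false-negative edit counts.

    beta < 1 favours precision; beta = 0.5 gives the F0.5 commonly used for GEC.
    """
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision == 0.0 and recall == 0.0:
        return 0.0
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


# Hypothetical counts: 8 correct edits, 2 spurious edits, 8 missed edits
score = f_beta(tp=8, fp=2, fn=8)
```

With precision above recall, as in the hypothetical counts here, F0.5 comes out higher than F1, which is why the metric rewards conservative correction systems.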

## 🚀 Setup and Installation

In [None]:
# Install required packages
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install transformers==4.36.0 datasets==2.15.0 accelerate==0.25.0
!pip install optuna==3.4.0 wandb==0.16.0 lightning==2.1.0
!pip install sentencepiece tokenizers nltk sacrebleu evaluate rouge-score
!pip install pandas numpy scikit-learn tqdm rich omegaconf hydra-core
!pip install underthesea pyvi ipywidgets matplotlib seaborn

In [None]:
# Clone the repository (if needed)
# !git clone https://github.com/your-repo/CL_GEC.git
# %cd CL_GEC

# Or upload the files directly to Colab (see the next cell).
# Either way, create the working directories used by the pipeline.
import os
os.makedirs('./models', exist_ok=True)
os.makedirs('./data', exist_ok=True)
os.makedirs('./evaluation_results', exist_ok=True)

print("📁 Directories created successfully!")

In [None]:
# Upload all Python files to Colab
# Use the file upload button in Colab to upload:
# - data_utils.py
# - base_trainer.py  
# - negative_sampler.py
# - contrastive_trainer.py
# - inference.py
# - evaluator.py
# - evaluate_model.py

# Verify files are uploaded
required_files = [
    'data_utils.py', 'base_trainer.py', 'negative_sampler.py',
    'contrastive_trainer.py', 'inference.py', 'evaluator.py', 'evaluate_model.py'
]

for file in required_files:
    if os.path.exists(file):
        print(f"✅ {file} found")
    else:
        print(f"❌ {file} missing - please upload this file")

## 📊 Step 1: Data Preparation

In [None]:
# Import necessary modules
import torch
import wandb
from rich.console import Console
from data_utils import load_vigec_dataset, save_processed_data, get_model_and_tokenizer

console = Console()

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
console.print(f"🔥 Using device: {device}")

if torch.cuda.is_available():
    console.print(f"GPU: {torch.cuda.get_device_name(0)}")
    console.print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Log in to Weights & Biases for experiment tracking.
# wandb.login() prompts for an API key when none is stored,
# so a separate `!wandb login` shell call is not required.
wandb.login()
console.print("📈 Wandb setup complete!")

In [None]:
# Load and preprocess viGEC dataset
console.print("📥 Loading viGEC dataset...")

# Load the dataset
data = load_vigec_dataset(dataset_name="phuhuy-se1/viGEC")

# Save processed data
save_processed_data(data, "./data/processed")

console.print("✅ Data preprocessing completed!")
for split, split_data in data.items():
    console.print(f"  {split}: {len(split_data)} samples")

## 🎯 Step 2: Base Model Training with Hyperparameter Optimization

In [None]:
# Choose your base model
# Options: "vinai/bartpho-syllable", "vinai/bartpho-word", "VietAI/vit5-base", "VietAI/vit5-large"

MODEL_NAME = "vinai/bartpho-syllable"  # Change this as needed

console.print(f"🤖 Selected model: {MODEL_NAME}")

In [None]:
from base_trainer import BaseTrainer

# Create base trainer
base_trainer = BaseTrainer(
    model_name=MODEL_NAME,
    data_dir="./data/processed",
    output_dir="./models/base",
    hyperopt=True  # Set to False to skip hyperparameter optimization
)

console.print("🏗️ Base trainer initialized!")

In [None]:
# Start base model training
# This will:
# 1. Run hyperparameter optimization (30 trials)
# 2. Train final model with best parameters
# 3. Save model and tokenizer

console.print("🚀 Starting base model training...")
console.print("⏰ This may take 2-4 hours depending on your setup")

base_trainer.train()

console.print("✅ Base model training completed!")

## 🎭 Step 3: Negative Sample Generation

In [None]:
from negative_sampler import NegativeSampleGenerator

# Create negative sample generator using the trained base model
BASE_MODEL_PATH = "./models/base/final"

console.print("🎭 Initializing negative sample generator...")

generator = NegativeSampleGenerator(
    model_path=BASE_MODEL_PATH,
    device="auto"
)

console.print("✅ Negative sample generator ready!")

In [None]:
from data_utils import load_processed_data
import os

# Load processed data
data = load_processed_data("./data/processed")

# Generate contrastive datasets
os.makedirs("./data/contrastive", exist_ok=True)

console.print("🔄 Generating negative samples...")
console.print("⏰ This may take 1-2 hours depending on dataset size")

for split in ['train', 'validation']:
    if split in data:
        console.print(f"Processing {split} split...")
        
        output_path = f"./data/contrastive/{split}_contrastive.json"
        
        contrastive_data = generator.generate_contrastive_dataset(
            data[split],
            output_path,
            batch_size=8,
            max_samples=None  # Set to smaller number for testing, e.g., 1000
        )
        
        # Analyze quality
        generator.analyze_negatives_quality(contrastive_data, sample_size=5)

console.print("✅ Negative sample generation completed!")

## 🔄 Step 4: Contrastive Learning Training

In [None]:
from contrastive_trainer import ContrastiveTrainer

# Create contrastive trainer
contrastive_trainer = ContrastiveTrainer(
    base_model_path=BASE_MODEL_PATH,
    contrastive_data_dir="./data/contrastive",
    output_dir="./models/contrastive",
    hyperopt=True  # Set to False to skip hyperparameter optimization
)

console.print("🔄 Contrastive trainer initialized!")

In [None]:
# Start contrastive learning training
# This will:
# 1. Run hyperparameter optimization for λ, γ, k
# 2. Train final model with contrastive loss + R-Drop
# 3. Save final contrastive model

console.print("🚀 Starting contrastive learning training...")
console.print("⏰ This may take 1-3 hours")

contrastive_trainer.train()

console.print("✅ Contrastive learning training completed!")

## 🔮 Step 5: Inference with Contrastive Search

In [None]:
from inference import GECInference

# Load the trained contrastive model
CONTRASTIVE_MODEL_PATH = "./models/contrastive/final"

# Create inference engines
console.print("🔮 Initializing inference engines...")

# Contrastive search inference
contrastive_inference = GECInference(
    model_path=CONTRASTIVE_MODEL_PATH,
    use_contrastive_search=True,
    contrastive_alpha=0.7,
    contrastive_k=5
)

# Beam search inference for comparison
beam_inference = GECInference(
    model_path=CONTRASTIVE_MODEL_PATH,
    use_contrastive_search=False
)

console.print("✅ Inference engines ready!")

In [None]:
# Test inference with sample texts
test_texts = [
    "Tôi đi học trường đại học.",
    "Hôm nay tôi không đi làm.",
    "Cô ấy rất đẹp và thông minh.",
    "Chúng tôi sẽ đi du lịch vào tuần tới.",
    "Anh ấy làm việc ở công ty lớn."
]

console.print("🧪 Testing inference on sample texts...")

for i, text in enumerate(test_texts):
    console.print(f"\n[bold cyan]Example {i+1}:[/bold cyan]")
    console.print(f"[yellow]Original:[/yellow] {text}")
    
    # Contrastive search
    contrastive_result = contrastive_inference.correct_text(text)
    console.print(f"[green]Contrastive:[/green] {contrastive_result}")
    
    # Beam search
    beam_result = beam_inference.correct_text(text)
    console.print(f"[blue]Beam:[/blue] {beam_result}")

In [None]:
# Interactive correction (optional)
# Uncomment to enable interactive mode

# console.print("🎮 Interactive mode - Enter text to correct (type 'quit' to exit):")
# contrastive_inference.interactive_correction()

## 📊 Step 6: Comprehensive Evaluation

In [None]:
from evaluate_model import ModelEvaluator

# Create model evaluator
evaluator = ModelEvaluator(
    model_path=CONTRASTIVE_MODEL_PATH,
    data_dir="./data/processed",
    output_dir="./evaluation_results"
)

console.print("📊 Model evaluator initialized!")

In [None]:
# Run comprehensive evaluation
console.print("🔍 Starting comprehensive evaluation...")
console.print("⏰ This may take 30-60 minutes")

# Evaluate on test set with different decoding strategies
evaluation_results = evaluator.evaluate_on_test_set(
    max_samples=None,  # Set to smaller number for testing, e.g., 500
    batch_size=8
)

console.print("✅ Evaluation completed!")

In [None]:
# Error type analysis
console.print("🔬 Running error type analysis...")

error_analysis = evaluator.evaluate_error_types(
    max_samples=1000  # Limit for faster analysis
)

console.print("✅ Error type analysis completed!")

In [None]:
# Display evaluation visualizations
from IPython.display import Image, display
import os

# Show evaluation comparison plot
plot_path = "./evaluation_results/evaluation_comparison.png"
if os.path.exists(plot_path):
    console.print("📈 Evaluation Comparison Visualization:")
    display(Image(plot_path))
else:
    console.print("❌ Visualization not found")

In [None]:
# Show evaluation results summary
import pandas as pd

# Load and display comparison table
csv_path = "./evaluation_results/strategy_comparison.csv"
if os.path.exists(csv_path):
    df = pd.read_csv(csv_path)
    console.print("📋 Strategy Comparison Results:")
    display(df)
else:
    console.print("❌ Comparison table not found")

## 📁 Results and Model Export

In [None]:
# Summary of trained models and results
console.print("\n[bold green]🎉 Training Pipeline Completed Successfully![/bold green]")

console.print("\n📁 [bold]Generated Models and Results:[/bold]")
console.print("  📦 Base Model: ./models/base/final")
console.print("  🔄 Contrastive Model: ./models/contrastive/final")
console.print("  📊 Evaluation Results: ./evaluation_results/")
console.print("  🎭 Contrastive Data: ./data/contrastive/")

# List all generated files
import os

def list_files_recursive(directory):
    files = []
    for root, dirs, filenames in os.walk(directory):
        for filename in filenames:
            files.append(os.path.join(root, filename))
    return files

console.print("\n📋 [bold]All Generated Files:[/bold]")

for directory in ['./models', './evaluation_results', './data/contrastive']:
    if os.path.exists(directory):
        files = list_files_recursive(directory)
        console.print(f"\n  📂 {directory}:")
        for file in files[:10]:  # Show first 10 files
            console.print(f"    📄 {file}")
        if len(files) > 10:
            console.print(f"    ... and {len(files) - 10} more files")

In [None]:
# Download models and results (for local use)
# Uncomment to create zip files for download

# import shutil

# console.print("📦 Creating downloadable archives...")

# # Create zip files
# shutil.make_archive('contrastive_gec_model', 'zip', './models/contrastive/final')
# shutil.make_archive('evaluation_results', 'zip', './evaluation_results')

# console.print("✅ Archives created:")
# console.print("  📦 contrastive_gec_model.zip - Trained model")
# console.print("  📦 evaluation_results.zip - Evaluation results")
# console.print("\n💾 Use the file browser to download these files")

## 🚀 Quick Usage Guide

Once training is complete, you can use the model for inference:

In [None]:
# Quick usage example
console.print("🚀 [bold]Quick Usage Example:[/bold]")

# Load the model
from inference import GECInference

# Initialize
gec_model = GECInference(
    model_path="./models/contrastive/final",
    use_contrastive_search=True
)

# Correct text
text = "Tôi đi học trường đại học."
corrected = gec_model.correct_text(text)

console.print(f"Original: {text}")
console.print(f"Corrected: {corrected}")

console.print("\n💡 [bold]Usage Tips:[/bold]")
console.print("  🎯 Set use_contrastive_search=True for better quality")
console.print("  ⚡ Set use_contrastive_search=False (beam search) for faster inference")
console.print("  📊 Adjust contrastive_alpha and contrastive_k to tune decoding")
console.print("  📁 Process files with correct_file() method")

## 📝 Configuration and Hyperparameters

Key hyperparameters used in this pipeline:

### Base Training:
- **Learning Rate**: Optimized via Optuna (typically 1e-5 to 1e-4)
- **Label Smoothing**: 0.1
- **Batch Size**: 8-32 (depending on GPU memory)
- **Max Length**: 384 tokens
- **Epochs**: 5-10
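
To make the label smoothing setting concrete, here is a minimal pure-Python sketch of one common formulation, in which the target distribution is a mix of the one-hot gold label (weight 1 - epsilon) and a uniform distribution over all classes (weight epsilon). This matches the behaviour of PyTorch's `CrossEntropyLoss(label_smoothing=...)`; whether `base_trainer.py` uses exactly this variant is an assumption, and `smoothed_cross_entropy` is a hypothetical name.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def smoothed_cross_entropy(logits, target, epsilon=0.1):
    """Cross-entropy against a target smoothed towards the uniform distribution."""
    log_probs = log_softmax(logits)
    nll = -log_probs[target]                    # standard cross-entropy term
    uniform = -sum(log_probs) / len(log_probs)  # cross-entropy against uniform
    return (1 - epsilon) * nll + epsilon * uniform


# Hypothetical 3-class logits where the model is confident in class 0
plain = smoothed_cross_entropy([5.0, 0.0, 0.0], target=0, epsilon=0.0)
smoothed = smoothed_cross_entropy([5.0, 0.0, 0.0], target=0, epsilon=0.1)
```

For a confident, correct prediction the smoothed loss is larger than the plain one, which discourages the model from pushing probabilities all the way to 1.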

### Contrastive Learning:
- **λ (lambda_cl)**: 1.0 (balance between CE and CL loss)
- **γ (temperature)**: 0.25 (contrastive loss temperature)
- **R-Drop α**: 4.0 (R-Drop regularization strength)
- **Epochs**: 3-5
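
The sketch below shows how these three hyperparameters could interact. It assumes an InfoNCE-style contrastive loss with temperature γ, a symmetric-KL R-Drop term, and an additive combination with cross-entropy. The exact formulation in the paper and in `contrastive_trainer.py` is not shown in this notebook, so every function name and the combination rule here are assumptions.

```python
import math


def info_nce(pos_sim, neg_sims, temperature=0.25):
    """InfoNCE-style loss: -log softmax of the positive among positive + negatives."""
    logits = [pos_sim / temperature] + [s / temperature for s in neg_sims]
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[0]


def symmetric_kl(p, q):
    """R-Drop style term: mean of KL(p||q) and KL(q||p).

    Assumes p and q are strictly positive distributions (e.g. softmax outputs
    from two forward passes with different dropout masks).
    """
    kl_pq = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    kl_qp = sum(qi * math.log(qi / pi) for pi, qi in zip(p, q))
    return 0.5 * (kl_pq + kl_qp)


def total_loss(ce, cl, rdrop_kl, lambda_cl=1.0, rdrop_alpha=4.0):
    """Assumed additive combination: CE + lambda * CL + alpha * R-Drop."""
    return ce + lambda_cl * cl + rdrop_alpha * rdrop_kl


# Hypothetical similarities: the positive is closer to the source than the negatives
cl = info_nce(pos_sim=0.9, neg_sims=[0.1, 0.2])
kl = symmetric_kl([0.7, 0.2, 0.1], [0.6, 0.3, 0.1])
loss = total_loss(ce=1.0, cl=cl, rdrop_kl=kl)
```

A lower temperature sharpens the softmax over similarities, so hard negatives that score close to the positive are penalised more strongly.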

### Contrastive Search:
- **α (`contrastive_alpha`)**: 0.7 (balance between model confidence and diversity)
- **k (`contrastive_k`)**: 5 (number of top-k candidates considered at each step)
- **Beam Size**: 1 (as recommended in paper)
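
As an illustration of how α and k enter decoding, the following is a minimal sketch of a single contrastive search step in its standard formulation: each of the top-k candidates is scored as (1 - α) times its model probability minus α times its maximum cosine similarity to the hidden states of the tokens generated so far. The function names and toy vectors are hypothetical, and this is not the code in `inference.py`.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def contrastive_search_step(candidates, context_states, alpha=0.7):
    """Pick the best token among the top-k candidates.

    candidates: list of (token, probability, hidden_state) tuples (the top-k).
    context_states: hidden states of the previously generated tokens.
    """
    best_token, best_score = None, -math.inf
    for token, prob, hidden in candidates:
        # Degeneration penalty: similarity to the most similar previous token
        penalty = max((cosine(hidden, h) for h in context_states), default=0.0)
        score = (1 - alpha) * prob - alpha * penalty
        if score > best_score:
            best_token, best_score = token, score
    return best_token, best_score


# Toy example: "a" is more probable but repeats the context representation
candidates = [("a", 0.6, [1.0, 0.0]), ("b", 0.4, [0.0, 1.0])]
context = [[1.0, 0.0]]
token, score = contrastive_search_step(candidates, context, alpha=0.7)
```

In Hugging Face `transformers`, contrastive search is enabled by passing `penalty_alpha` and `top_k` to `generate()`; whether `GECInference` forwards `contrastive_alpha` and `contrastive_k` to those arguments is an assumption.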

These parameters can be adjusted based on your specific needs and computational resources.