# 🇻🇳 Vietnamese GEC with Contrastive Learning - Google Colab (BARTpho Fixed)

**✅ Bug Fix Applied:** Fixed `'BartphoTokenizer' object has no attribute 'vocab'` error

Complete pipeline for training Vietnamese Grammatical Error Correction models with Contrastive Learning.

## 🐛 Recent Fixes:
- **BARTpho Tokenizer Fix**: Resolved vocabulary access issue for SentencePiece tokenizers
- **AdamW Import Fix**: Fixed `ImportError: cannot import name 'AdamW' from 'transformers'` 
- **PyTorch Lightning**: Added missing Lightning installation to dependencies
- **Improved Compatibility**: Better support for both BARTpho and ViT5 models
- **Error Handling**: Added safe vocabulary checking methods
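
The tokenizer fix listed above comes down to not assuming that every tokenizer exposes a `.vocab` attribute. Below is a minimal sketch of that idea, assuming duck-typed tokenizers. The helper names `get_vocab_safely` and `missing_tokens` and the two stand-in classes are illustrative only and are not part of the notebook's `data_utils.py`; the stand-ins let the example run without downloading a model.

```python
from typing import Dict, List, Optional


def get_vocab_safely(tokenizer) -> Optional[Dict[str, int]]:
    """Return a token-to-id mapping, or None if the tokenizer exposes none."""
    # Prefer the get_vocab() method, which SentencePiece-backed tokenizers provide.
    get_vocab = getattr(tokenizer, "get_vocab", None)
    if callable(get_vocab):
        return dict(get_vocab())
    # Fall back to a plain .vocab dictionary if one exists.
    vocab = getattr(tokenizer, "vocab", None)
    if isinstance(vocab, dict):
        return dict(vocab)
    return None


def missing_tokens(tokenizer, candidates: List[str]) -> List[str]:
    """List the candidate tokens that are not yet in the vocabulary."""
    vocab = get_vocab_safely(tokenizer)
    if vocab is None:
        # No way to inspect the vocabulary: report nothing rather than guess.
        return []
    return [token for token in candidates if token not in vocab]


class SentencePieceLike:
    """Stand-in for a tokenizer that only offers get_vocab()."""

    def get_vocab(self) -> Dict[str, int]:
        return {"<s>": 0, "</s>": 1, "<gec>": 2}


class VocabAttributeOnly:
    """Stand-in for a tokenizer that only offers a .vocab attribute."""

    vocab = {"[CLS]": 0, "[SEP]": 1}


print(missing_tokens(SentencePieceLike(), ["<gec>", "</gec>"]))
print(missing_tokens(VocabAttributeOnly(), ["<gec>", "</gec>"]))
```

The sketch tries `get_vocab()` first, whereas the notebook checks `.vocab` first; either order avoids the `AttributeError`, as long as a missing attribute is handled instead of assumed.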

## 📋 Pipeline Overview:
1. **Setup & Installation** - Install dependencies and create project structure
2. **Data Preparation** - Load and preprocess viGEC dataset  
3. **Base Model Training** - Fine-tune BARTpho/ViT5 with hyperparameter optimization
4. **Negative Sample Generation** - Generate negative samples for contrastive learning
5. **Contrastive Learning Training** - Train with contrastive loss + R-Drop
6. **Inference & Evaluation** - Test and evaluate the model
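
Step 5 combines a generation loss with a contrastive term and an R-Drop term. The sketch below is one common shape for such an objective, assumed here for illustration: an InfoNCE-style contrastive loss over one positive and several negatives, plus a symmetric KL term between two dropout passes. It is plain Python for readability and is not the trainer's implementation. The function names are hypothetical, and the default values mirror the `contrastive` section of the notebook's `config.json`.

```python
import math
from typing import Sequence


def info_nce(pos_score: float, neg_scores: Sequence[float], temperature: float = 0.25) -> float:
    """Contrastive loss: -log softmax of the positive among all candidates."""
    logits = [pos_score / temperature] + [s / temperature for s in neg_scores]
    # Log-sum-exp with the maximum subtracted for numerical stability.
    peak = max(logits)
    log_denominator = peak + math.log(sum(math.exp(l - peak) for l in logits))
    return log_denominator - logits[0]


def symmetric_kl(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """R-Drop style regulariser: average of KL(p||q) and KL(q||p)."""
    kl_pq = sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))
    kl_qp = sum(qi * math.log((qi + eps) / (pi + eps)) for pi, qi in zip(p, q))
    return 0.5 * (kl_pq + kl_qp)


def combined_loss(ce: float, cl: float, rdrop: float,
                  lambda_cl: float = 1.0, rdrop_alpha: float = 4.0) -> float:
    """Weighted sum of cross-entropy, contrastive, and R-Drop terms."""
    return ce + lambda_cl * cl + rdrop_alpha * rdrop


# Similarity scores here are made-up numbers, not model outputs.
cl = info_nce(pos_score=0.8, neg_scores=[0.3, 0.1, -0.2])
rd = symmetric_kl([0.7, 0.2, 0.1], [0.6, 0.3, 0.1])
print(combined_loss(ce=1.2, cl=cl, rdrop=rd))
```

A lower temperature sharpens the softmax, so a negative that scores close to the positive is penalised more heavily.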

⏰ **Estimated Total Time**: 4-9 hours (depending on GPU)

🔧 **BARTpho Issue**: RESOLVED ✅

## 🚀 Setup and Installation

In [None]:
# Install required packages - FIXED VERSIONS FOR COMPATIBILITY
print("📦 Installing compatible package versions...")

# Install packages in correct order to avoid conflicts
!pip install "numpy>=1.21.0,<2.0.0"  # Install numpy < 2.0 to avoid wandb conflicts
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install transformers==4.36.0 datasets==2.15.0 accelerate==0.25.0
!pip install sentencepiece tokenizers nltk sacrebleu evaluate rouge-score
!pip install pandas scikit-learn tqdm rich omegaconf
!pip install underthesea pyvi ipywidgets matplotlib seaborn
!pip install "optuna>=3.4.0,<4.0.0" "wandb>=0.16.0,<0.17.0"
!pip install "lightning>=2.0.0" "pytorch-lightning>=2.0.0"  # Fixed: Added PyTorch Lightning

print("✅ Package installation commands finished (check the pip output above for errors)")

# Create the COMPLETE fixed data_utils.py file
data_utils_complete = '''"""
Data utilities for Vietnamese GEC with viGEC dataset - COMPLETE FIXED VERSION
Resolves: 'BartphoTokenizer' object has no attribute 'vocab' error
"""

import os
import re
import unicodedata
from typing import Dict, List, Tuple, Optional, Union
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, Dataset as HFDataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Tokenizer
import logging
import json
from rich.console import Console
from rich.progress import track
import numpy as np

console = Console()
logger = logging.getLogger(__name__)

class ViGECDataset(Dataset):
    """Vietnamese GEC Dataset for training"""
    
    def __init__(
        self,
        data: List[Dict],
        tokenizer: AutoTokenizer,
        max_length: int = 384,
        is_train: bool = True
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_train = is_train
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        source = item['source']
        target = item['target']
        
        # Add task prefix for ViT5
        if hasattr(self.tokenizer, 'task_prefix'):
            source = self.tokenizer.task_prefix + source
        
        # Tokenize source
        source_encoding = self.tokenizer(
            source,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        
        # Tokenize target
        target_encoding = self.tokenizer(
            target,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        
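        # Mask padding positions in the labels with -100 so the loss ignores them
        labels = target_encoding['input_ids'].squeeze(0).clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        
        return {
            'input_ids': source_encoding['input_ids'].squeeze(0),
            'attention_mask': source_encoding['attention_mask'].squeeze(0),
            'labels': labels,
            'decoder_attention_mask': target_encoding['attention_mask'].squeeze(0),
            'source_text': source,
            'target_text': target
        }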

class ContrastiveDataset(Dataset):
    """Dataset for contrastive learning with positive/negative pairs"""
    
    def __init__(
        self,
        data: List[Dict],
        tokenizer: AutoTokenizer,
        max_length: int = 384
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        source = item['source']
        positive = item.get('positive', item.get('target'))  # gold target
        negatives = item.get('negatives', [])  # list of negative samples
        
        # Add task prefix for ViT5
        if hasattr(self.tokenizer, 'task_prefix'):
            source = self.tokenizer.task_prefix + source
        
        # Tokenize source
        source_encoding = self.tokenizer(
            source,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        
        # Tokenize positive
        positive_encoding = self.tokenizer(
            positive,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        
        # Tokenize negatives (at most 3 are used; the list is padded to exactly 3 below)
        negative_encodings = []
        for neg in negatives[:3]:  # Use up to 3 negatives
            neg_encoding = self.tokenizer(
                neg,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )
            negative_encodings.append(neg_encoding)
        
        # Pad to exactly 3 entries by reusing the positive encoding (filler, not a true negative)
        while len(negative_encodings) < 3:
            negative_encodings.append(positive_encoding)
        
        return {
            'input_ids': source_encoding['input_ids'].squeeze(),
            'attention_mask': source_encoding['attention_mask'].squeeze(),
            'positive_ids': positive_encoding['input_ids'].squeeze(),
            'positive_attention_mask': positive_encoding['attention_mask'].squeeze(),
            'negative_ids': torch.stack([neg['input_ids'].squeeze() for neg in negative_encodings]),
            'negative_attention_mask': torch.stack([neg['attention_mask'].squeeze() for neg in negative_encodings]),
            'source_text': source,
            'positive_text': positive,
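            # Fixed length of 3 (padded with the positive text) so the default collate_fn can batch it
            'negative_texts': (list(negatives[:3]) + [positive] * 3)[:3]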
        }

def normalize_text(text: str) -> str:
    """Normalize Vietnamese text to UTF-8 NFC"""
    text = unicodedata.normalize('NFC', text)
    text = re.sub(r'\\s+', ' ', text)
    text = text.strip()
    return text

def clean_text(text: str) -> str:
    """Clean and preprocess Vietnamese text"""
    text = normalize_text(text)
    # Remove special characters but keep Vietnamese diacritics
    text = re.sub(r'[^\\w\\s\\u00C0-\\u1EF9\\u0300-\\u036F.,!?;:()"\\'\\'-]', '', text)
    # Fix spacing around punctuation
    text = re.sub(r'\\s*([.,!?;:])\\s*', r'\\1 ', text)
    text = re.sub(r'\\s*([()"\\'\\'])\\s*', r' \\1 ', text)
    # Remove extra spaces
    text = re.sub(r'\\s+', ' ', text).strip()
    return text

def load_vigec_dataset(
    dataset_name: str = "phuhuy-se1/viGEC",
    cache_dir: Optional[str] = None,
    test_subset_ratio: float = 0.05
) -> Dict[str, List[Dict]]:
    """Load and preprocess viGEC dataset"""
    
    console.print(f"[bold blue]Loading dataset: {dataset_name}[/bold blue]")
    
    try:
        dataset = load_dataset(dataset_name, cache_dir=cache_dir)
    except Exception as e:
        console.print(f"[red]❌ Error loading dataset: {e}[/red]")
        raise
    
    processed_data = {}
    
    for split in ['train', 'validation', 'test']:
        if split in dataset:
            console.print(f"[yellow]Processing {split} split...[/yellow]")
            
            split_data = []
            for item in track(dataset[split], description=f"Processing {split}"):
                # Handle different column names
                source = item.get('incorrect_text', item.get('source', ''))
                target = item.get('correct_text', item.get('target', ''))
                
                if isinstance(source, str) and isinstance(target, str):
                    source = clean_text(source)
                    target = clean_text(target)
                    
                    # Skip empty or very short texts
                    if len(source.split()) < 2 or len(target.split()) < 2:
                        continue
                    
                    split_data.append({
                        'source': source,
                        'target': target,
                        'id': item.get('id', len(split_data))
                    })
            
            # For test split, use only a subset for faster evaluation
            if split == 'test' and test_subset_ratio < 1.0:
                import random
                random.seed(42)  # For reproducibility
                subset_size = int(len(split_data) * test_subset_ratio)
                split_data = random.sample(split_data, subset_size)
                console.print(f"[blue]Using {subset_size} samples ({test_subset_ratio*100:.1f}%) from test set[/blue]")
            
            processed_data[split] = split_data
            console.print(f"[green]{split}: {len(split_data)} samples[/green]")
    
    return processed_data

def get_model_and_tokenizer(model_name: str):
    """Get model and tokenizer for Vietnamese GEC - COMPLETE FIXED VERSION"""
    
    console.print(f"[bold blue]Loading model: {model_name}[/bold blue]")
    
    try:
        if 'bartpho' in model_name.lower():
            # BARTpho models
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            console.print("[green]✅ BARTpho model loaded[/green]")
            
        elif 'vit5' in model_name.lower():
            # ViT5 models
            tokenizer = T5Tokenizer.from_pretrained(model_name)
            model = T5ForConditionalGeneration.from_pretrained(model_name)
            
            # Add task prefix for ViT5
            if not hasattr(tokenizer, 'task_prefix'):
                tokenizer.task_prefix = "grammatical error correction: "
                console.print(f"[yellow]Added ViT5 task prefix: {tokenizer.task_prefix}[/yellow]")
            
            console.print("[green]✅ ViT5 model loaded[/green]")
            
        else:
            # Generic seq2seq models
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            console.print("[green]✅ Generic seq2seq model loaded[/green]")
        
        # FIXED: Safe vocabulary checking for different tokenizer types
        special_tokens = ['<gec>', '</gec>']
        
        try:
            # Method 1: Standard vocab attribute (BERT, GPT, etc.)
            if hasattr(tokenizer, 'vocab'):
                vocab = tokenizer.vocab
                console.print("[blue]Using .vocab attribute[/blue]")
                
            # Method 2: get_vocab() method (SentencePiece, BARTpho, etc.)
            elif hasattr(tokenizer, 'get_vocab'):
                vocab = tokenizer.get_vocab()
                console.print(f"[blue]Using .get_vocab() method - vocab size: {len(vocab)}[/blue]")
                
            # Method 3: Fallback - try to access vocab through other methods
            elif hasattr(tokenizer, '_tokenizer') and hasattr(tokenizer._tokenizer, 'get_vocab'):
                vocab = tokenizer._tokenizer.get_vocab()
                console.print(f"[blue]Using ._tokenizer.get_vocab() - vocab size: {len(vocab)}[/blue]")
                
            else:
                # No vocab access method found - skip token addition
                console.print("[yellow]No vocab access method found, skipping special token addition[/yellow]")
                return model, tokenizer
            
            # Check which tokens are new
            new_tokens = [token for token in special_tokens if token not in vocab]
            
            if new_tokens:
                # Add new tokens
                added_tokens = tokenizer.add_tokens(new_tokens)
                if added_tokens > 0:
                    model.resize_token_embeddings(len(tokenizer))
                    console.print(f"[yellow]Added {added_tokens} new tokens: {new_tokens}[/yellow]")
                else:
                    console.print("[blue]Tokens already exist in vocabulary[/blue]")
            else:
                console.print("[blue]All special tokens already in vocabulary[/blue]")
                
        except Exception as e:
            console.print(f"[yellow]Warning: Could not check/add vocabulary - {e}[/yellow]")
            console.print("[yellow]Continuing without special token addition[/yellow]")
        
        return model, tokenizer
        
    except Exception as e:
        console.print(f"[red]❌ Error loading model {model_name}: {e}[/red]")
        raise

def create_data_loaders(
    data: Dict[str, List[Dict]],
    tokenizer: AutoTokenizer,
    batch_size: int = 16,
    max_length: int = 384,
    num_workers: int = 0  # Set to 0 for Colab compatibility
) -> Dict[str, DataLoader]:
    """Create data loaders for training"""
    
    data_loaders = {}
    
    for split, split_data in data.items():
        dataset = ViGECDataset(
            data=split_data,
            tokenizer=tokenizer,
            max_length=max_length,
            is_train=(split == 'train')
        )
        
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split == 'train'),
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            drop_last=(split == 'train')
        )
        
        data_loaders[split] = data_loader
    
    return data_loaders

def create_contrastive_data_loaders(
    data_dir: str,
    tokenizer: AutoTokenizer,
    batch_size: int = 8,
    max_length: int = 384,
    num_workers: int = 0
) -> Dict[str, DataLoader]:
    """Create data loaders for contrastive learning"""
    
    data_loaders = {}
    
    for split in ['train', 'validation']:
        file_path = os.path.join(data_dir, f"{split}_contrastive.json")
        
        if os.path.exists(file_path):
            # Load contrastive data
            with open(file_path, 'r', encoding='utf-8') as f:
                split_data = json.load(f)
            
            dataset = ContrastiveDataset(
                data=split_data,
                tokenizer=tokenizer,
                max_length=max_length
            )
            
            data_loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=(split == 'train'),
                num_workers=num_workers,
                pin_memory=torch.cuda.is_available(),
                drop_last=(split == 'train')
            )
            
            data_loaders[split] = data_loader
            console.print(f"[green]Created {split} contrastive dataloader: {len(dataset)} samples[/green]")
    
    return data_loaders

def save_processed_data(data: Dict[str, List[Dict]], output_dir: str):
    """Save processed data to disk"""
    
    os.makedirs(output_dir, exist_ok=True)
    
    for split, split_data in data.items():
        output_path = os.path.join(output_dir, f"{split}.json")
        
        df = pd.DataFrame(split_data)
        df.to_json(output_path, orient='records', force_ascii=False, indent=2)
        
        console.print(f"[green]Saved {split} data to {output_path}[/green]")

def load_processed_data(data_dir: str) -> Dict[str, List[Dict]]:
    """Load processed data from disk"""
    
    data = {}
    
    for split in ['train', 'validation', 'test']:
        file_path = os.path.join(data_dir, f"{split}.json")
        
        if os.path.exists(file_path):
            df = pd.read_json(file_path, orient='records')
            data[split] = df.to_dict('records')
            console.print(f"[green]Loaded {split}: {len(data[split])} samples[/green]")
    
    return data

def test_tokenizer_compatibility(model_name: str) -> bool:
    """Test tokenizer compatibility and vocabulary access"""
    
    console.print(f"[bold]🧪 Testing compatibility for: {model_name}[/bold]")
    
    try:
        model, tokenizer = get_model_and_tokenizer(model_name)
        
        # Test basic tokenization
        test_text = "Tôi đi học trường đại học."
        if hasattr(tokenizer, 'task_prefix'):
            test_text = tokenizer.task_prefix + test_text
        
        tokens = tokenizer(test_text, return_tensors="pt")
        console.print(f"[green]✅ Tokenization successful - shape: {tokens['input_ids'].shape}[/green]")
        
        # Test vocabulary access methods
        vocab_methods = []
        if hasattr(tokenizer, 'vocab'):
            vocab_methods.append('.vocab')
        if hasattr(tokenizer, 'get_vocab'):
            vocab_methods.append('.get_vocab()')
        if hasattr(tokenizer, '_tokenizer') and hasattr(tokenizer._tokenizer, 'get_vocab'):
            vocab_methods.append('._tokenizer.get_vocab()')
        
        console.print(f"[green]✅ Available vocab methods: {vocab_methods}[/green]")
        
        # Test data loading
        sample_data = [{'source': 'Tôi đi học.', 'target': 'Tôi đi học.'}]
        dataset = ViGECDataset(sample_data, tokenizer, max_length=128)
        sample = dataset[0]
        console.print(f"[green]✅ Dataset creation successful[/green]")
        
        return True
        
    except Exception as e:
        console.print(f"[red]❌ Error testing {model_name}: {e}[/red]")
        import traceback
        traceback.print_exc()
        return False

def verify_bartpho_fix() -> bool:
    """Verify that the BARTpho fix is working"""
    
    console.print("[bold green]🧪 Verifying BARTpho Fix...[/bold green]")
    
    models_to_test = [
        "vinai/bartpho-syllable",
        # "VietAI/vit5-base"  # Uncomment to test ViT5 as well
    ]
    
    results = {}
    
    for model_name in models_to_test:
        console.print(f"\\n{'='*50}")
        console.print(f"Testing: {model_name}")
        console.print(f"{'='*50}")
        
        success = test_tokenizer_compatibility(model_name)
        results[model_name] = success
        
        if success:
            console.print(f"[green]✅ {model_name}: PASSED[/green]")
        else:
            console.print(f"[red]❌ {model_name}: FAILED[/red]")
    
    # Summary
    console.print(f"\\n{'='*60}")
    console.print("[bold]🎯 Test Summary[/bold]")
    console.print(f"{'='*60}")
    
    for model_name, success in results.items():
        status = "✅ PASSED" if success else "❌ FAILED"
        console.print(f"{model_name}: {status}")
    
    all_passed = all(results.values())
    
    if all_passed:
        console.print("\\n[bold green]🎉 All tests passed! BARTpho fix is working correctly.[/bold green]")
    else:
        console.print("\\n[bold red]⚠️ Some tests failed. Check the errors above.[/bold red]")
    
    return all_passed

def check_system_requirements() -> Dict[str, bool]:
    """Check system requirements for training"""
    
    console.print("[bold blue]🔍 Checking System Requirements[/bold blue]")
    
    checks = {}
    
    # Check CUDA availability
    cuda_available = torch.cuda.is_available()
    checks['cuda'] = cuda_available
    console.print(f"GPU Available: {'✅' if cuda_available else '❌'}")
    
    if cuda_available:
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        console.print(f"GPU: {gpu_name}")
        console.print(f"GPU Memory: {gpu_memory:.1f} GB")
        checks['gpu_memory'] = gpu_memory
    else:
        checks['gpu_memory'] = 0
    
    # Check disk space
    import shutil
    disk_space = shutil.disk_usage('.').free / 1e9
    checks['disk_space'] = disk_space
    console.print(f"Available Disk Space: {disk_space:.1f} GB")
    
    # Check RAM
    try:
        import psutil
        ram = psutil.virtual_memory().total / 1e9
        checks['ram'] = ram
        console.print(f"Total RAM: {ram:.1f} GB")
    except ImportError:
        checks['ram'] = 0
        console.print("RAM: Unable to check")
    
    # Overall assessment
    ready = (
        checks['cuda'] and 
        checks['gpu_memory'] >= 6 and 
        checks['disk_space'] >= 5
    )
    
    console.print(f"\\n[bold]System Ready for Training: {'✅' if ready else '❌'}[/bold]")
    
    if not ready:
        console.print("[yellow]⚠️ Recommendations:[/yellow]")
        if not checks['cuda']:
            console.print("  - Enable GPU runtime in Colab")
        if checks['gpu_memory'] < 6:
            console.print("  - Use smaller batch sizes")
        if checks['disk_space'] < 5:
            console.print("  - Free up disk space")
    
    return checks

if __name__ == "__main__":
    # Run verification when script is executed directly
    verify_bartpho_fix()
'''

# Write the complete file
with open('data_utils.py', 'w', encoding='utf-8') as f:
    f.write(data_utils_complete)

print("✅ Created COMPLETE data_utils.py with all functions")

# Test the fix immediately
try:
    from data_utils import verify_bartpho_fix, check_system_requirements
    
    print("\n🔍 Running system checks...")
    system_checks = check_system_requirements()
    
    print("\n🧪 Testing BARTpho tokenizer fix...")
    fix_success = verify_bartpho_fix()
    
    if fix_success:
        print("\n🎉 BARTpho tokenizer fix is working correctly!")
        print("✅ Ready to proceed with training")
    else:
        print("\n❌ BARTpho test failed - please check the errors above")
    
except Exception as e:
    print(f"\n❌ Error during testing: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "="*60)
print("🎯 Setup Status: finished - review the check results above before training")
print("="*60)
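
The cell above writes `data_utils.py` from an ordinary (non-raw) triple-quoted string, which is why the regular expressions inside it use doubled backslashes such as `\\s+`. Here is a small self-contained illustration of that escaping rule. The module text and the `collapse_spaces` name are made up for the demonstration, and `exec` stands in for writing and importing a file.

```python
# Source text for a tiny module, held in a normal (non-raw) string.
# Python turns each doubled backslash into a single one when it parses
# this literal, so the generated code ends up containing r'\s+'.
module_source = '''
import re

def collapse_spaces(text):
    return re.sub(r'\\s+', ' ', text).strip()
'''

# Execute the generated source in its own namespace instead of writing a file.
namespace = {}
exec(module_source, namespace)
collapse_spaces = namespace["collapse_spaces"]

print(collapse_spaces("Tôi   đi \n học"))  # prints: Tôi đi học
```

With a single backslash in the outer string, Python would treat `\s` as an unrecognised escape sequence in the notebook cell itself and warn about it, so doubling the backslash (or using a raw outer string) is the safer choice.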

In [None]:
# Clone the repository (if needed)
# !git clone https://github.com/your-repo/CL_GEC.git
# %cd CL_GEC

# Or upload files directly to Colab
import os
import json
from rich.console import Console

console = Console()

# Create directory structure
directories = [
    './models/base',
    './models/contrastive', 
    './data/processed',
    './data/contrastive',
    './evaluation_results',
    './logs'
]

for directory in directories:
    os.makedirs(directory, exist_ok=True)
    console.print(f"✅ Created: {directory}")

# Create configuration file
config = {
    "model": {
        "name": "vinai/bartpho-syllable",  # Change to "VietAI/vit5-base" for ViT5
        "max_length": 384,
        "batch_size": 8,  # Will be adjusted based on GPU memory
        "learning_rate": 2e-5,
        "num_epochs": 3
    },
    "contrastive": {
        "lambda_cl": 1.0,
        "temperature": 0.25,
        "rdrop_alpha": 4.0,
        "cl_epochs": 2
    },
    "inference": {
        "use_contrastive_search": True,
        "contrastive_alpha": 0.7,
        "contrastive_k": 5,
        "num_beams": 5
    },
    "training": {
        "gradient_accumulation_steps": 4,
        "warmup_steps": 500,
        "weight_decay": 0.01,
        "label_smoothing": 0.1,
        "mixed_precision": True
    }
}

# Save configuration
with open('config.json', 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=2, ensure_ascii=False)

console.print("📁 Project structure created successfully!")
console.print("⚙️ Configuration saved to config.json")

# Display configuration
console.print("\n📋 [bold]Current Configuration:[/bold]")
for section, settings in config.items():
    console.print(f"\n  [yellow]{section.upper()}:[/yellow]")
    for key, value in settings.items():
        console.print(f"    {key}: {value}")

console.print("\n💡 [bold]Tip:[/bold] You can modify config.json to adjust parameters")
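
The `batch_size`, `gradient_accumulation_steps`, `num_epochs`, and `warmup_steps` values in the configuration above interact: together they determine the effective batch size and how many optimizer updates actually happen. The arithmetic can be sketched as follows, assuming the training loader drops the last incomplete batch and that leftover micro-batches at the end of an epoch do not trigger an update. The `training_schedule` helper and the dataset size of 10,000 are hypothetical.

```python
import math


def training_schedule(num_examples: int, batch_size: int = 8,
                      gradient_accumulation_steps: int = 4,
                      num_epochs: int = 3, drop_last: bool = True) -> dict:
    """Derive effective batch size and optimizer-update count from the config."""
    if drop_last:
        batches_per_epoch = num_examples // batch_size
    else:
        batches_per_epoch = math.ceil(num_examples / batch_size)
    # One optimizer update per full group of accumulated micro-batches.
    updates_per_epoch = batches_per_epoch // gradient_accumulation_steps
    return {
        "effective_batch_size": batch_size * gradient_accumulation_steps,
        "batches_per_epoch": batches_per_epoch,
        "optimizer_updates": updates_per_epoch * num_epochs,
    }


# Hypothetical dataset size, chosen only to make the numbers concrete.
schedule = training_schedule(num_examples=10_000)
print(schedule)
```

For this hypothetical size the run has 936 optimizer updates, so a `warmup_steps` value of 500 would spend more than half of training in warmup. It is worth comparing the two numbers for the real dataset before starting a long run.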

In [None]:
# Upload all Python files to Colab
# Use the file upload button in Colab to upload:
# - data_utils.py
# - negative_sampler.py
# - contrastive_trainer.py
# - inference.py
# - evaluator.py
# - evaluate_model.py

# Verify files are uploaded
required_files = [
    'data_utils.py', 'negative_sampler.py',
    'contrastive_trainer.py', 'inference.py', 'evaluator.py', 'evaluate_model.py'
]

for file in required_files:
    if os.path.exists(file):
        print(f"✅ {file} found")
    else:
        print(f"❌ {file} missing - please upload this file")

# Create all necessary training files directly in the notebook
# This eliminates the need for file uploads

console.print("📝 Creating training modules...")

# 1. Create Simple Base Trainer
base_trainer_code = '''"""
Simple Base Trainer for Vietnamese GEC - Colab Optimized
"""

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSeq2SeqLM, 
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    TrainingArguments,
    Trainer
)
import json
from tqdm.auto import tqdm
from rich.console import Console
import logging

console = Console()
logger = logging.getLogger(__name__)

class SimpleBaseTrainer:
    """Simple trainer for base model fine-tuning"""
    
    def __init__(self, config_path: str = 'config.json'):
        # Load configuration
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        console.print(f"[blue]Using device: {self.device}[/blue]")
        
        # Adjust batch size based on GPU memory
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
            if gpu_memory < 8:
                self.config['model']['batch_size'] = 4
                self.config['training']['gradient_accumulation_steps'] = 8
            elif gpu_memory < 16:
                self.config['model']['batch_size'] = 8
                self.config['training']['gradient_accumulation_steps'] = 4
            
            console.print(f"[yellow]Using batch_size {self.config['model']['batch_size']} with gradient_accumulation_steps {self.config['training']['gradient_accumulation_steps']} on {gpu_memory:.1f}GB GPU[/yellow]")
    
    def load_data(self):
        """Load processed data"""
        from data_utils import load_processed_data
        
        console.print("[blue]📥 Loading processed data...[/blue]")
        data = load_processed_data("./data/processed")
        
        if not data:
            raise ValueError("No data found in ./data/processed")
        
        return data
    
    def prepare_model_and_tokenizer(self):
        """Load model and tokenizer"""
        from data_utils import get_model_and_tokenizer
        
        model_name = self.config['model']['name']
        console.print(f"[blue]🤖 Loading model: {model_name}[/blue]")
        
        model, tokenizer = get_model_and_tokenizer(model_name)
        model.to(self.device)
        
        return model, tokenizer
    
    def train(self):
        """Train the base model"""
        console.print("[bold green]🚀 Starting base model training...[/bold green]")
        
        # Load data and model
        data = self.load_data()
        model, tokenizer = self.prepare_model_and_tokenizer()
        
        # Create datasets
        from data_utils import create_data_loaders
        data_loaders = create_data_loaders(
            data,
            tokenizer,
            batch_size=self.config['model']['batch_size'],
            max_length=self.config['model']['max_length']
        )
        
        # Setup optimizer
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=self.config['model']['learning_rate'],
            weight_decay=self.config['training']['weight_decay']
        )
        
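        # Setup scheduler (it is stepped once per optimizer update, not once per batch)
        accumulation_steps = self.config['training']['gradient_accumulation_steps']
        updates_per_epoch = max(1, len(data_loaders['train']) // accumulation_steps)
        total_steps = updates_per_epoch * self.config['model']['num_epochs']
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.config['training']['warmup_steps'],
            num_training_steps=total_steps
        )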
        
        # Training loop
        model.train()
        for epoch in range(self.config['model']['num_epochs']):
            console.print(f"[blue]📚 Epoch {epoch + 1}/{self.config['model']['num_epochs']}[/blue]")
            
            epoch_loss = 0
            num_batches = 0
            
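            accumulation_steps = self.config['training']['gradient_accumulation_steps']
            
            for batch in tqdm(data_loaders['train'], desc="Training"):
                # Move tensors to the device and drop the text fields
                # (source_text, target_text), which the model cannot accept
                batch = {k: v.to(self.device) for k, v in batch.items() if torch.is_tensor(v)}
                
                # Forward pass
                outputs = model(**batch)
                loss = outputs.loss
                
                # Backward pass, scaled so accumulated gradients average over the micro-batches
                (loss / accumulation_steps).backward()
                
                # Update weights once per accumulation window
                if (num_batches + 1) % accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                
                epoch_loss += loss.item()
                num_batches += 1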
            
            avg_loss = epoch_loss / num_batches
            console.print(f"[green]📊 Epoch {epoch + 1} - Average Loss: {avg_loss:.4f}[/green]")
            
            # Validation
            if 'validation' in data_loaders:
                val_loss = self._validate(model, data_loaders['validation'])
                console.print(f"[yellow]📊 Validation Loss: {val_loss:.4f}[/yellow]")
        
        # Save model
        output_dir = "./models/base/final"
        os.makedirs(output_dir, exist_ok=True)
        
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        
        console.print(f"[green]✅ Model saved to {output_dir}[/green]")
        console.print("[bold green]✅ Base model training completed![/bold green]")
        
        return output_dir
    
    def _validate(self, model, val_loader):
        """Validate the model"""
        model.eval()
        total_loss = 0
        num_batches = 0
        
        with torch.no_grad():
            for batch in val_loader:
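                # Keep only tensors; the text fields cannot be moved to the device or passed to the model
                batch = {k: v.to(self.device) for k, v in batch.items() if torch.is_tensor(v)}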
                outputs = model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1
        
        model.train()
        return total_loss / num_batches

def train_base_model():
    """Convenience function to train base model"""
    trainer = SimpleBaseTrainer()
    return trainer.train()
'''

# 2. Create Simple Negative Sampler
negative_sampler_code = '''"""
Simple Negative Sample Generator for Contrastive Learning
"""

import torch
import json
import os
from tqdm.auto import tqdm
from rich.console import Console
from data_utils import get_model_and_tokenizer

console = Console()

class SimpleNegativeSampler:
    """Generate negative samples for contrastive learning"""
    
    def __init__(self, model_path: str):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        console.print(f"[blue]🎭 Loading model from {model_path}[/blue]")
        self.model, self.tokenizer = get_model_and_tokenizer(model_path)
        self.model.to(self.device)
        self.model.eval()
        
        # Check if this is ViT5
        self.use_prefix = hasattr(self.tokenizer, 'task_prefix')
    
    def generate_negatives(self, source_text: str, target_text: str, num_negatives: int = 3):
        """Generate negative samples for one example"""
        
        # Add prefix for ViT5
        input_text = source_text
        if self.use_prefix:
            input_text = self.tokenizer.task_prefix + source_text
        
        # Tokenize
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=384,
            truncation=True
        ).to(self.device)
        
        # Generate with beam search to get diverse candidates
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                num_beams=6,
                num_return_sequences=6,
                max_length=384,
                do_sample=False,
                early_stopping=True,
                length_penalty=1.0
            )
        
        # Decode outputs
        candidates = []
        for output in outputs:
            decoded = self.tokenizer.decode(output, skip_special_tokens=True)
            
            # Remove prefix if present
            if self.use_prefix and decoded.startswith(self.tokenizer.task_prefix):
                decoded = decoded[len(self.tokenizer.task_prefix):].strip()
            
            candidates.append(decoded)
        
        # Filter negatives (different from target)
        negatives = []
        for candidate in candidates:
            if candidate != target_text and candidate != source_text and candidate.strip():
                negatives.append(candidate)
                if len(negatives) >= num_negatives:
                    break
        
        # Add source as negative if it contains errors
        if source_text != target_text and len(negatives) < num_negatives:
            negatives.append(source_text)
        
        # Pad with variations if we don't have enough
        while len(negatives) < num_negatives:
            if negatives:
                negatives.append(negatives[0])  # Duplicate first negative
            else:
                negatives.append(source_text)  # Fallback
        
        return negatives[:num_negatives]
    
    def generate_contrastive_dataset(self, data, output_path: str, batch_size: int = 4):
        """Generate contrastive dataset"""
        
        console.print(f"[blue]🎭 Generating negative samples for {len(data)} examples...[/blue]")
        
        contrastive_data = []
        
        for i in tqdm(range(0, len(data), batch_size), desc="Generating negatives"):
            batch = data[i:i + batch_size]
            
            for item in batch:
                source = item['source']
                target = item['target']
                
                # Generate negatives
                negatives = self.generate_negatives(source, target)
                
                # Create contrastive example
                contrastive_item = {
                    'source': source,
                    'target': target,
                    'positive': target,  # For compatibility
                    'negatives': negatives
                }
                contrastive_data.append(contrastive_item)
        
        # Save to file
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(contrastive_data, f, ensure_ascii=False, indent=2)
        
        console.print(f"[green]💾 Saved {len(contrastive_data)} contrastive examples to {output_path}[/green]")
        return contrastive_data

def generate_negative_samples():
    """Convenience function to generate negative samples"""
    from data_utils import load_processed_data
    
    # Load data
    data = load_processed_data("./data/processed")
    
    # Create sampler
    sampler = SimpleNegativeSampler("./models/base/final")
    
    # Generate for train and validation splits
    for split in ['train', 'validation']:
        if split in data:
            console.print(f"[yellow]Processing {split} split...[/yellow]")
            
            output_path = f"./data/contrastive/{split}_contrastive.json"
            sampler.generate_contrastive_dataset(
                data[split],
                output_path,
                batch_size=2  # Small batch for memory efficiency
            )
    
    console.print("[green]✅ Negative sample generation completed![/green]")
'''

# 3. Create Simple Contrastive Trainer
contrastive_trainer_code = '''"""
Simple Contrastive Learning Trainer
"""

import torch
import torch.nn.functional as F
import json
import os
from tqdm.auto import tqdm
from rich.console import Console
from data_utils import get_model_and_tokenizer, create_contrastive_data_loaders

console = Console()

class SimpleContrastiveTrainer:
    """Simple contrastive learning trainer"""
    
    def __init__(self, config_path: str = 'config.json'):
        # Load configuration
        with open(config_path, 'r') as f:
            self.config = json.load(f)
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.lambda_cl = self.config['contrastive']['lambda_cl']
        self.temperature = self.config['contrastive']['temperature']
    
    def contrastive_loss(self, source_hidden, positive_hidden, negative_hidden):
        """Compute contrastive loss"""
        
        # Normalize representations
        source_hidden = F.normalize(source_hidden, dim=-1)
        positive_hidden = F.normalize(positive_hidden, dim=-1)
        negative_hidden = F.normalize(negative_hidden, dim=-1)  # [batch, num_neg, hidden]
        
        # Positive similarity
        pos_sim = torch.sum(source_hidden * positive_hidden, dim=-1) / self.temperature  # [batch]
        
        # Negative similarities
        neg_sim = torch.bmm(
            negative_hidden, 
            source_hidden.unsqueeze(-1)
        ).squeeze(-1) / self.temperature  # [batch, num_neg]
        
        # Contrastive loss
        logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)  # [batch, 1 + num_neg]
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        
        loss = F.cross_entropy(logits, labels)
        return loss
    
    def train(self):
        """Train with contrastive learning"""
        console.print("[bold blue]🔄 Starting contrastive learning training...[/bold blue]")
        
        # Load model
        model, tokenizer = get_model_and_tokenizer("./models/base/final")
        model.to(self.device)
        
        # Create contrastive data loaders
        data_loaders = create_contrastive_data_loaders(
            "./data/contrastive",
            tokenizer,
            batch_size=max(1, self.config['model']['batch_size'] // 2),  # Smaller batch for contrastive
            max_length=self.config['model']['max_length']
        )
        
        if not data_loaders:
            raise ValueError("No contrastive data found!")
        
        # Setup optimizer
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=self.config['model']['learning_rate'] / 2,  # Lower LR for contrastive
            weight_decay=self.config['training']['weight_decay']
        )
        
        # Training loop
        model.train()
        for epoch in range(self.config['contrastive']['cl_epochs']):
            console.print(f"[blue]🔄 Contrastive Epoch {epoch + 1}/{self.config['contrastive']['cl_epochs']}[/blue]")
            
            epoch_ce_loss = 0
            epoch_cl_loss = 0
            num_batches = 0
            
            for batch in tqdm(data_loaders['train'], desc="Contrastive Training"):
                # Move batch to device
                batch = {k: v.to(self.device) for k, v in batch.items()}
                
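                # Mask padding positions so the cross-entropy loss ignores them
                labels = batch['positive_ids'].masked_fill(
                    batch['positive_attention_mask'] == 0, -100
                )

                # Forward pass for source
                source_outputs = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=labels
                )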
                
                # Cross-entropy loss
                ce_loss = source_outputs.loss
                
                # Get encoder hidden states for contrastive loss
                source_hidden = source_outputs.encoder_last_hidden_state.mean(dim=1)  # [batch, hidden]
                
                # Forward pass for positive
                positive_outputs = model.encoder(
                    input_ids=batch['positive_ids'],
                    attention_mask=batch['positive_attention_mask']
                )
                positive_hidden = positive_outputs.last_hidden_state.mean(dim=1)  # [batch, hidden]
                
                # Forward pass for negatives
                batch_size, num_neg, seq_len = batch['negative_ids'].shape
                negative_ids = batch['negative_ids'].view(batch_size * num_neg, seq_len)
                negative_mask = batch['negative_attention_mask'].view(batch_size * num_neg, seq_len)
                
                negative_outputs = model.encoder(
                    input_ids=negative_ids,
                    attention_mask=negative_mask
                )
                negative_hidden = negative_outputs.last_hidden_state.mean(dim=1)  # [batch*num_neg, hidden]
                negative_hidden = negative_hidden.view(batch_size, num_neg, -1)  # [batch, num_neg, hidden]
                
                # Contrastive loss
                cl_loss = self.contrastive_loss(source_hidden, positive_hidden, negative_hidden)
                
                # Combined loss
                total_loss = ce_loss + self.lambda_cl * cl_loss
                
                # Backward pass
                total_loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                
                epoch_ce_loss += ce_loss.item()
                epoch_cl_loss += cl_loss.item()
                num_batches += 1
            
            avg_ce_loss = epoch_ce_loss / num_batches
            avg_cl_loss = epoch_cl_loss / num_batches
            
            console.print(f"[green]📊 Epoch {epoch + 1} - CE Loss: {avg_ce_loss:.4f}, CL Loss: {avg_cl_loss:.4f}[/green]")
        
        # Save model
        output_dir = "./models/contrastive/final"
        os.makedirs(output_dir, exist_ok=True)
        
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        
        console.print(f"[green]✅ Contrastive model saved to {output_dir}[/green]")
        console.print("[bold green]✅ Contrastive learning training completed![/bold green]")
        
        return output_dir

def train_contrastive_model():
    """Convenience function to train contrastive model"""
    trainer = SimpleContrastiveTrainer()
    return trainer.train()
'''

# 4. Create Simple Inference
inference_code = '''"""
Simple Inference for Vietnamese GEC
"""

import torch
import json
from rich.console import Console
from data_utils import get_model_and_tokenizer

console = Console()

class SimpleInference:
    """Simple inference for Vietnamese GEC"""
    
    def __init__(self, model_path: str, config_path: str = 'config.json'):
        # Load configuration
        with open(config_path, 'r') as f:
            self.config = json.load(f)
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Load model
        console.print(f"[blue]🔮 Loading model from {model_path}[/blue]")
        self.model, self.tokenizer = get_model_and_tokenizer(model_path)
        self.model.to(self.device)
        self.model.eval()
        
        # Check if this is ViT5
        self.use_prefix = hasattr(self.tokenizer, 'task_prefix')
        
        # Inference settings
        self.use_contrastive = self.config['inference']['use_contrastive_search']
        self.alpha = self.config['inference']['contrastive_alpha']
        self.k = self.config['inference']['contrastive_k']
        self.num_beams = self.config['inference']['num_beams']
    
    def correct_text(self, text: str) -> str:
        """Correct a single text"""
        
        # Add prefix for ViT5
        input_text = text
        if self.use_prefix:
            input_text = self.tokenizer.task_prefix + text
        
        # Tokenize
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=384,
            truncation=True
        ).to(self.device)
        
        # Generate
        with torch.no_grad():
            if self.use_contrastive:
                # Contrastive search
                outputs = self.model.generate(
                    **inputs,
                    penalty_alpha=self.alpha,
                    top_k=self.k,
                    max_length=384,
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            else:
                # Beam search
                outputs = self.model.generate(
                    **inputs,
                    num_beams=self.num_beams,
                    max_length=384,
                    do_sample=False,
                    early_stopping=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
        
        # Decode
        corrected = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Remove prefix if present
        if self.use_prefix and corrected.startswith(self.tokenizer.task_prefix):
            corrected = corrected[len(self.tokenizer.task_prefix):].strip()
        
        return corrected
    
    def correct_batch(self, texts: list) -> list:
        """Correct a batch of texts"""
        return [self.correct_text(text) for text in texts]

def create_inference_engine(model_path: str = "./models/contrastive/final"):
    """Create inference engine"""
    return SimpleInference(model_path)
'''

# Write all files
files_to_create = {
    'base_trainer.py': base_trainer_code,
    'negative_sampler.py': negative_sampler_code,
    'contrastive_trainer.py': contrastive_trainer_code,
    'inference.py': inference_code
}

for filename, code in files_to_create.items():
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(code)
    console.print(f"✅ Created {filename}")

console.print("\n🎯 All training modules created successfully!")
console.print("📋 Available modules:")
console.print("  🎯 base_trainer.py - Base model training")
console.print("  🎭 negative_sampler.py - Negative sample generation")
console.print("  🔄 contrastive_trainer.py - Contrastive learning")
console.print("  🔮 inference.py - Text correction inference")
console.print("\n✅ Ready for training pipeline!")

## 📊 Step 1: Data Preparation
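
The next cell sizes the batch from the detected GPU memory. As a minimal sketch of that rule (the helper name `pick_batch_settings` is hypothetical and not part of `data_utils`), the tiers can be written as one function. Every tier keeps the effective batch size, batch size times accumulation steps, at 32:

```python
from typing import Optional, Tuple


def pick_batch_settings(gpu_memory_gb: Optional[float]) -> Tuple[int, int]:
    """Return (batch_size, gradient_accumulation_steps) for a memory budget.

    Hypothetical helper mirroring the thresholds used in this notebook.
    Pass None when no GPU is available.
    """
    if gpu_memory_gb is None:
        return 2, 16   # CPU fallback
    if gpu_memory_gb < 8:
        return 4, 8    # limited GPU memory
    if gpu_memory_gb < 16:
        return 8, 4    # medium GPU memory
    return 16, 2       # high GPU memory


for memory in (None, 6.0, 15.0, 40.0):
    batch_size, accumulation = pick_batch_settings(memory)
    print(memory, batch_size, accumulation, batch_size * accumulation)
```

Keeping the product constant means a smaller GPU trades speed for memory without changing how many samples contribute to each optimizer step, provided the trainer actually applies the accumulation setting.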

In [None]:
# Import necessary modules and verify environment
import torch
import numpy as np
import sys
from rich.console import Console
from data_utils import load_vigec_dataset, save_processed_data, get_model_and_tokenizer, check_system_requirements
import json

console = Console()

# Check versions for compatibility
console.print("[bold blue]🔍 Environment Check[/bold blue]")
console.print(f"Python: {sys.version}")
console.print(f"PyTorch: {torch.__version__}")
console.print(f"NumPy: {np.__version__}")

# Verify numpy version is compatible
if np.__version__.startswith('2.'):
    console.print("[red]⚠️ WARNING: NumPy 2.0 detected - this may cause wandb issues[/red]")
    console.print("[yellow]Training will continue without wandb if conflicts occur[/yellow]")
else:
    console.print("[green]✅ NumPy version is compatible[/green]")

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
console.print(f"🔥 Using device: {device}")

if torch.cuda.is_available():
    console.print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    console.print(f"Memory: {gpu_memory:.1f} GB")
    
    # Optimize settings based on GPU memory
    if gpu_memory < 8:
        console.print("[yellow]⚠️ Limited GPU memory - using smaller batch sizes[/yellow]")
        BATCH_SIZE = 4
        GRADIENT_ACCUMULATION = 8
    elif gpu_memory < 16:
        console.print("[blue]🔄 Medium GPU memory - using moderate batch sizes[/blue]")
        BATCH_SIZE = 8
        GRADIENT_ACCUMULATION = 4
    else:
        console.print("[green]✅ High GPU memory - using optimal batch sizes[/green]")
        BATCH_SIZE = 16
        GRADIENT_ACCUMULATION = 2
else:
    console.print("[red]❌ No GPU available - training will be very slow[/red]")
    BATCH_SIZE = 2
    GRADIENT_ACCUMULATION = 16

console.print(f"📊 Optimized settings: batch_size={BATCH_SIZE}, grad_accumulation={GRADIENT_ACCUMULATION}")

# Update configuration
with open('config.json', 'r') as f:
    config = json.load(f)

config['model']['batch_size'] = BATCH_SIZE
config['training']['gradient_accumulation_steps'] = GRADIENT_ACCUMULATION

with open('config.json', 'w') as f:
    json.dump(config, f, indent=2, ensure_ascii=False)

console.print(f"[green]⚙️ Updated config: batch_size={BATCH_SIZE}, grad_accumulation={GRADIENT_ACCUMULATION}[/green]")

# Test wandb availability
try:
    import wandb
    console.print("[green]✅ Wandb available for experiment tracking[/green]")
    WANDB_AVAILABLE = True
except ImportError as e:
    console.print(f"[yellow]⚠️ Wandb not available: {e}[/yellow]")
    console.print("[yellow]Training will continue without experiment tracking[/yellow]")
    WANDB_AVAILABLE = False

# Data Preparation - Load and preprocess viGEC dataset

# Check system requirements first
console.print("[bold blue]🔍 Checking System Requirements[/bold blue]")
system_checks = check_system_requirements()

# Batch size and gradient accumulation were already chosen from the detected
# hardware and written to config.json above, so they are not recomputed here
# (recomputing them from system_checks could overwrite the CPU fallback settings).

# Load and preprocess dataset
console.print("\n[bold blue]📥 Loading viGEC Dataset[/bold blue]")

try:
    # Load the dataset with small test subset for faster evaluation
    data = load_vigec_dataset(
        dataset_name="phuhuy-se1/viGEC",
        test_subset_ratio=0.05  # Use 5% of test set for faster evaluation
    )
    
    # Save processed data
    save_processed_data(data, "./data/processed")
    
    console.print("[bold green]✅ Data preprocessing completed![/bold green]")
    
    # Display statistics
    console.print("\n[bold]📊 Dataset Statistics:[/bold]")
    total_samples = 0
    for split, split_data in data.items():
        console.print(f"  {split}: {len(split_data):,} samples")
        total_samples += len(split_data)
    
    console.print(f"  [bold]Total: {total_samples:,} samples[/bold]")
    
    # Show sample data
    if 'train' in data and len(data['train']) > 0:
        console.print("\n[bold]📝 Sample Data:[/bold]")
        sample = data['train'][0]
        console.print(f"  Source: {sample['source']}")
        console.print(f"  Target: {sample['target']}")
    
    console.print("\n[green]🎯 Ready for model training![/green]")
    
except Exception as e:
    console.print(f"[red]❌ Error loading dataset: {e}[/red]")
    import traceback
    traceback.print_exc()
    raise

In [None]:
# Optional: Login to Wandb for experiment tracking
# Uncomment the lines below if you want to use Wandb for tracking

# !wandb login

# Or login programmatically:
# import wandb
# wandb.login()

console.print("📈 [bold]Wandb Setup (Optional):[/bold]")
console.print("  🔸 Uncomment the lines above to enable experiment tracking")
console.print("  🔸 Visit https://wandb.ai to get your API key")
console.print("  🔸 Training will work without Wandb")

# Test wandb availability (this only checks the import, not the login state)
try:
    import wandb
    console.print("[green]✅ Wandb is installed[/green]")
except ImportError:
    console.print("[yellow]⚠️ Wandb not installed - training will continue without tracking[/yellow]")

console.print("[blue]🚀 Ready to proceed with or without experiment tracking[/blue]")

## 🎯 Step 2: Base Model Training with Hyperparameter Optimization
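
The model selection cell below halves the batch size (with a floor of 2) and doubles the gradient accumulation steps when `vit5-large` is chosen. A small sketch of that adjustment, using a hypothetical helper name, shows when the effective batch size is preserved and when it is not:

```python
from typing import Tuple


def adjust_for_large_model(batch_size: int, grad_accumulation: int) -> Tuple[int, int]:
    """Halve the batch (never below 2) and double the accumulation steps.

    Hypothetical helper that mirrors the vit5-large branch in this notebook.
    """
    return max(2, batch_size // 2), grad_accumulation * 2


for batch_size, accumulation in [(16, 2), (8, 4), (4, 8), (2, 16)]:
    new_batch, new_accumulation = adjust_for_large_model(batch_size, accumulation)
    print(
        f"{batch_size}x{accumulation} -> {new_batch}x{new_accumulation} "
        f"(effective {new_batch * new_accumulation})"
    )
```

For the three GPU tiers the effective batch size stays at 32. If the batch size is already 2, the floor keeps it there while accumulation still doubles, so the effective batch size becomes 64.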

In [None]:
# Model Selection and Configuration
import json
from rich.console import Console

console = Console()

# Available models for Vietnamese GEC
AVAILABLE_MODELS = {
    "vinai/bartpho-syllable": "BARTpho (syllable-level) - Recommended for Vietnamese",
    "vinai/bartpho-word": "BARTpho (word-level) - Alternative option",
    "VietAI/vit5-base": "ViT5 Base - T5-based model for Vietnamese",
    "VietAI/vit5-large": "ViT5 Large - Larger model (requires more memory)"
}

# Choose your model here
MODEL_NAME = "vinai/bartpho-syllable"  # Change this to experiment with different models

console.print("[bold blue]🤖 Model Selection[/bold blue]")
console.print(f"[green]Selected Model: {MODEL_NAME}[/green]")
console.print(f"[yellow]Description: {AVAILABLE_MODELS[MODEL_NAME]}[/yellow]")

console.print("\n[bold]📋 Available Models:[/bold]")
for model, description in AVAILABLE_MODELS.items():
    marker = "✅" if model == MODEL_NAME else "  "
    console.print(f"{marker} {model}: {description}")

# Update configuration with selected model
with open('config.json', 'r') as f:
    config = json.load(f)

config['model']['name'] = MODEL_NAME

# Adjust settings based on model type
if 'vit5-large' in MODEL_NAME:
    config['model']['batch_size'] = max(2, config['model']['batch_size'] // 2)
    config['training']['gradient_accumulation_steps'] *= 2
    console.print("[yellow]⚠️ Adjusted batch size for large model[/yellow]")

with open('config.json', 'w') as f:
    json.dump(config, f, indent=2, ensure_ascii=False)

console.print(f"\n[green]⚙️ Configuration updated with {MODEL_NAME}[/green]")
console.print(f"[blue]📊 Batch size: {config['model']['batch_size']}[/blue]")
console.print(f"[blue]📊 Learning rate: {config['model']['learning_rate']}[/blue]")
console.print(f"[blue]📊 Max length: {config['model']['max_length']}[/blue]")

# Test model loading
console.print(f"\n[yellow]🧪 Testing model loading...[/yellow]")
try:
    from data_utils import test_tokenizer_compatibility
    success = test_tokenizer_compatibility(MODEL_NAME)
    if success:
        console.print("[green]✅ Model loading test passed![/green]")
    else:
        console.print("[red]❌ Model loading test failed![/red]")
except Exception as e:
    console.print(f"[red]❌ Error testing model: {e}[/red]")

console.print("\n[green]🎯 Ready for base model training![/green]")

In [None]:
from base_trainer import train_base_model
from rich.console import Console
import time
import json
import os
import torch

console = Console()

console.print("[bold green]🚀 Starting Base Model Training[/bold green]")
console.print("⏰ This will take approximately 1-3 hours depending on your setup")

# Display current configuration
with open('config.json', 'r') as f:
    config = json.load(f)

console.print(f"\n[bold blue]📋 Training Configuration:[/bold blue]")
console.print(f"  Model: {config['model']['name']}")
console.print(f"  Batch Size: {config['model']['batch_size']}")
console.print(f"  Learning Rate: {config['model']['learning_rate']}")
console.print(f"  Epochs: {config['model']['num_epochs']}")
console.print(f"  Max Length: {config['model']['max_length']}")
console.print(f"  Gradient Accumulation: {config['training']['gradient_accumulation_steps']}")

# Confirm system readiness
if torch.cuda.is_available():
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    console.print(f"  GPU Memory: {gpu_memory:.1f} GB")
else:
    console.print("  [yellow]⚠️ No GPU available - training will be slow[/yellow]")

console.print("\n[yellow]🔄 Starting training...[/yellow]")
start_time = time.time()

try:
    # Train the base model
    model_path = train_base_model()
    
    end_time = time.time()
    training_time = (end_time - start_time) / 60  # Convert to minutes
    
    console.print(f"\n[bold green]✅ Base model training completed![/bold green]")
    console.print(f"[green]📊 Training time: {training_time:.1f} minutes[/green]")
    console.print(f"[green]📁 Model saved to: {model_path}[/green]")
    
    # Verify model was saved
    if os.path.exists(model_path):
        files = os.listdir(model_path)
        console.print(f"[blue]📦 Saved files: {files}[/blue]")
    
    console.print("\n[green]🎯 Ready for negative sample generation![/green]")
    
except Exception as e:
    console.print(f"\n[red]❌ Training failed: {e}[/red]")
    import traceback
    traceback.print_exc()
    raise

In [None]:
# Optional: re-run base model training.
# The previous cell already calls train_base_model(), which trains the model
# and saves the model and tokenizer to ./models/base/final.
# This cell retrains only when that directory is missing.
import os
from base_trainer import train_base_model

if os.path.exists("./models/base/final"):
    console.print("✅ Base model already trained - skipping")
else:
    console.print("🚀 Starting base model training...")
    train_base_model()
    console.print("✅ Base model training completed!")

## 🎭 Step 3: Negative Sample Generation

In [None]:
from negative_sampler import generate_negative_samples
from rich.console import Console
import time
import os
import json

console = Console()

# Check if base model exists
BASE_MODEL_PATH = "./models/base/final"
if not os.path.exists(BASE_MODEL_PATH):
    console.print("[red]❌ Base model not found! Please run base training first.[/red]")
    raise FileNotFoundError("Base model not found")

console.print("[bold blue]🎭 Starting Negative Sample Generation[/bold blue]")
console.print("⏰ This will take approximately 30-60 minutes")

# Display configuration
with open('config.json', 'r') as f:
    config = json.load(f)

console.print(f"\n[bold blue]📋 Generation Configuration:[/bold blue]")
console.print(f"  Base Model: {BASE_MODEL_PATH}")
console.print(f"  Batch Size: 2 (optimized for memory)")
console.print(f"  Negatives per Sample: 3")

start_time = time.time()

try:
    # Generate negative samples
    generate_negative_samples()
    
    end_time = time.time()
    generation_time = (end_time - start_time) / 60
    
    console.print(f"\n[bold green]✅ Negative sample generation completed![/bold green]")
    console.print(f"[green]📊 Generation time: {generation_time:.1f} minutes[/green]")
    
    # Verify generated files
    contrastive_dir = "./data/contrastive"
    if os.path.exists(contrastive_dir):
        files = os.listdir(contrastive_dir)
        console.print(f"[blue]📦 Generated files: {files}[/blue]")
        
        # Show statistics
        for file in files:
            if file.endswith('.json'):
                file_path = os.path.join(contrastive_dir, file)
                with open(file_path, 'r', encoding='utf-8') as f:
                    samples = json.load(f)
                console.print(f"  {file}: {len(samples)} samples")
    
    console.print("\n[green]🎯 Ready for contrastive learning training![/green]")
    
except Exception as e:
    console.print(f"\n[red]❌ Negative sample generation failed: {e}[/red]")
    import traceback
    traceback.print_exc()
    raise

In [None]:
from data_utils import load_processed_data
from negative_sampler import SimpleNegativeSampler
import os

# Optional: generate any contrastive split the previous cell did not produce
data = load_processed_data("./data/processed")
os.makedirs("./data/contrastive", exist_ok=True)

console.print("🔄 Generating negative samples...")
console.print("⏰ This may take 1-2 hours depending on dataset size")

sampler = None  # Loaded lazily so the model is only read when needed

for split in ['train', 'validation']:
    output_path = f"./data/contrastive/{split}_contrastive.json"
    if split not in data or os.path.exists(output_path):
        continue  # Nothing to do for this split

    if sampler is None:
        sampler = SimpleNegativeSampler("./models/base/final")

    console.print(f"Processing {split} split...")
    contrastive_data = sampler.generate_contrastive_dataset(
        data[split],  # Slice, e.g. data[split][:1000], for a quick test
        output_path,
        batch_size=8
    )

    # Print a few examples to inspect the negatives by eye
    for item in contrastive_data[:5]:
        console.print(f"  Target: {item['target']}", markup=False)
        console.print(f"  Negatives: {item['negatives']}", markup=False)

console.print("✅ Negative sample generation completed!")

## 🔄 Step 4: Contrastive Learning Training
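
The trainer written to `contrastive_trainer.py` adds a contrastive term to the cross-entropy loss: the source representation should be closer to the positive (the target) than to any negative. The sketch below reproduces that computation for a single example in plain Python so the arithmetic is easy to follow. The vectors and the temperature of 0.1 are illustrative assumptions, not values taken from `config.json`:

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def contrastive_loss(source, positive, negatives, temperature=0.1):
    """Cross-entropy over similarities, with the positive at index 0."""
    logits = [cosine(source, positive) / temperature]
    logits += [cosine(source, negative) / temperature for negative in negatives]
    # Numerically stable log-sum-exp over all candidates
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[0]


# Positive aligned with the source, negative orthogonal: loss near zero
aligned = contrastive_loss([1.0, 0.0], [1.0, 0.0], [[0.0, 1.0]])
# Roles swapped: the negative is the closer one, so the loss is large
swapped = contrastive_loss([1.0, 0.0], [0.0, 1.0], [[1.0, 0.0]])
print(f"aligned={aligned:.6f} swapped={swapped:.6f}")
```

In the trainer this term is averaged over the batch and combined with the cross-entropy loss as `ce_loss + lambda_cl * cl_loss`, so `lambda_cl` controls how strongly the encoder representations are pulled apart.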

In [None]:
# Start contrastive learning training
# This will:
# 1. Load the base model from ./models/base/final
# 2. Train with cross-entropy loss + λ · contrastive loss
# 3. Save the final contrastive model to ./models/contrastive/final
from contrastive_trainer import train_contrastive_model

console.print("🚀 Starting contrastive learning training...")
console.print("⏰ This may take 1-3 hours")

contrastive_model_path = train_contrastive_model()

console.print("✅ Contrastive learning training completed!")

## 🔮 Step 5: Inference with Contrastive Search
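
When `use_contrastive_search` is enabled, the inference engine passes `penalty_alpha` and `top_k` to `generate`. Contrastive search picks, among the top-k candidate tokens, the one that balances model confidence against similarity to what has already been generated. The following is a toy sketch of a single decoding step, not the Transformers implementation. The candidate tokens, probabilities, similarities, and the default `alpha=0.6` are made-up values for illustration:

```python
def contrastive_search_step(candidates, alpha=0.6):
    """Pick one token from the top-k candidates.

    candidates: list of (token, probability, max_similarity) tuples, where
    max_similarity is the largest cosine similarity between the candidate's
    hidden state and the hidden states of the tokens generated so far.
    """
    def score(candidate):
        _, probability, max_similarity = candidate
        # Reward confidence, penalize repeating earlier context
        return (1 - alpha) * probability - alpha * max_similarity

    return max(candidates, key=score)[0]


# Toy top-k list: the most probable token is also the most repetitive one
top_k = [("học", 0.50, 0.95), ("tập", 0.40, 0.20), ("và", 0.10, 0.10)]

print(contrastive_search_step(top_k, alpha=0.0))  # greedy choice
print(contrastive_search_step(top_k, alpha=0.6))  # degeneration penalty applied
```

With `alpha=0.0` the rule reduces to greedy decoding. Larger values of `alpha` push the decoder away from tokens whose representations closely match earlier context. For a correction task, where the output should mostly copy the input, a high penalty may hurt, so it is worth comparing against the beam search branch.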

In [None]:
from inference import create_inference_engine
from rich.console import Console
import os
import json

console = Console()

# Check if contrastive model exists
contrastive_model_path = "./models/contrastive/final"
base_model_path = "./models/base/final"

if os.path.exists(contrastive_model_path):
    model_path = contrastive_model_path
    model_type = "Contrastive Learning Model"
elif os.path.exists(base_model_path):
    model_path = base_model_path
    model_type = "Base Model"
    console.print("[yellow]⚠️ Using base model (contrastive model not found)[/yellow]")
else:
    console.print("[red]❌ No trained model found! Please run training first.[/red]")
    raise FileNotFoundError("No trained model found")

console.print("[bold blue]🔮 Initializing Inference Engine[/bold blue]")
console.print(f"[green]Using: {model_type}[/green]")
console.print(f"[blue]Model Path: {model_path}[/blue]")

# Display inference configuration
with open('config.json', 'r') as f:
    config = json.load(f)

console.print(f"\n[bold blue]📋 Inference Configuration:[/bold blue]")
console.print(f"  Contrastive Search: {config['inference']['use_contrastive_search']}")
console.print(f"  Alpha (α): {config['inference']['contrastive_alpha']}")
console.print(f"  K: {config['inference']['contrastive_k']}")
console.print(f"  Beam Size: {config['inference']['num_beams']}")

try:
    # Create inference engine
    inference_engine = create_inference_engine(model_path)
    
    console.print("\n[green]✅ Inference engine initialized successfully![/green]")
    
    # Test with a simple example
    test_text = "Tôi đi học trường đại học."
    console.print(f"\n[yellow]🧪 Testing inference...[/yellow]")
    console.print(f"Input: {test_text}")
    
    corrected = inference_engine.correct_text(test_text)
    console.print(f"Output: {corrected}")
    
    if corrected != test_text:
        console.print("[green]✅ Model is making corrections![/green]")
    else:
        console.print("[blue]ℹ️ Model returned same text (may be already correct)[/blue]")
    
    console.print("\n[green]🎯 Ready for text correction![/green]")
    
except Exception as e:
    console.print(f"\n[red]❌ Inference setup failed: {e}[/red]")
    import traceback
    traceback.print_exc()
    raise

In [None]:
# Interactive correction (optional)
# Uncomment to enable interactive mode

# console.print("🎮 Interactive mode - Enter text to correct (type 'quit' to exit):")
# while (text := input("> ").strip()) != "quit":
#     console.print(inference_engine.correct_text(text), markup=False)

## 📊 Step 6: Comprehensive Evaluation

In [None]:
from evaluate_model import ModelEvaluator

# Create model evaluator
evaluator = ModelEvaluator(
    model_path=model_path,  # Selected in the inference cell above
    data_dir="./data/processed",
    output_dir="./evaluation_results"
)

console.print("📊 Model evaluator initialized!")

In [None]:
# Run comprehensive evaluation
console.print("🔍 Starting comprehensive evaluation...")
console.print("⏰ This may take 30-60 minutes")

# Evaluate on test set with different decoding strategies
evaluation_results = evaluator.evaluate_on_test_set(
    max_samples=None,  # Set to smaller number for testing, e.g., 500
    batch_size=8
)

console.print("✅ Evaluation completed!")

In [None]:
# Error type analysis
console.print("🔬 Running error type analysis...")

error_analysis = evaluator.evaluate_error_types(
    max_samples=1000  # Limit for faster analysis
)

console.print("✅ Error type analysis completed!")

In [None]:
# Display evaluation visualizations
from IPython.display import Image, display
import os

# Show evaluation comparison plot
plot_path = "./evaluation_results/evaluation_comparison.png"
if os.path.exists(plot_path):
    console.print("📈 Evaluation Comparison Visualization:")
    display(Image(plot_path))
else:
    console.print("❌ Visualization not found")

In [None]:
# Show evaluation results summary
import pandas as pd

# Load and display comparison table
csv_path = "./evaluation_results/strategy_comparison.csv"
if os.path.exists(csv_path):
    df = pd.read_csv(csv_path)
    console.print("📋 Strategy Comparison Results:")
    display(df)
else:
    console.print("❌ Comparison table not found")

## 📁 Results and Model Export

In [None]:
# Training Pipeline Summary and Results
from rich.console import Console
from rich.table import Table
import os
import json
import time

console = Console()

console.print("[bold green]🎉 Vietnamese GEC Training Pipeline Completed![/bold green]")

# Check what was created
results = {
    "base_model": os.path.exists("./models/base/final"),
    "contrastive_model": os.path.exists("./models/contrastive/final"),
    "processed_data": os.path.exists("./data/processed"),
    "contrastive_data": os.path.exists("./data/contrastive"),
    "inference_ready": 'inference_engine' in globals()
}

# Create results table
table = Table(title="Training Pipeline Results")
table.add_column("Component", style="cyan")
table.add_column("Status", style="green")
table.add_column("Location", style="yellow")

table.add_row(
    "Base Model",
    "✅ Complete" if results["base_model"] else "❌ Missing",
    "./models/base/final"
)

table.add_row(
    "Contrastive Model", 
    "✅ Complete" if results["contrastive_model"] else "❌ Missing",
    "./models/contrastive/final"
)

table.add_row(
    "Processed Data",
    "✅ Complete" if results["processed_data"] else "❌ Missing", 
    "./data/processed"
)

table.add_row(
    "Contrastive Data",
    "✅ Complete" if results["contrastive_data"] else "❌ Missing",
    "./data/contrastive"
)

table.add_row(
    "Inference Engine",
    "✅ Ready" if results["inference_ready"] else "❌ Not Ready",
    "In Memory"
)

console.print(table)

# Model information
if results["contrastive_model"] or results["base_model"]:
    console.print("\n[bold blue]📊 Model Information:[/bold blue]")
    
    # Load configuration
    with open('config.json', 'r') as f:
        config = json.load(f)
    
    console.print(f"  🤖 Base Model: {config['model']['name']}")
    console.print(f"  📏 Max Length: {config['model']['max_length']}")
    console.print(f"  🎯 Batch Size: {config['model']['batch_size']}")
    console.print(f"  📚 Training Epochs: {config['model']['num_epochs']}")
    
    if results["contrastive_model"]:
        console.print(f"  🔄 Contrastive λ: {config['contrastive']['lambda_cl']}")
        console.print(f"  🌡️ Temperature γ: {config['contrastive']['temperature']}")

# File statistics
console.print(f"\n[bold blue]📁 Generated Files:[/bold blue]")

def count_files_recursive(directory):
    if not os.path.exists(directory):
        return 0, 0
    
    file_count = 0
    total_size = 0
    
    for root, dirs, files in os.walk(directory):
        file_count += len(files)
        for file in files:
            try:
                total_size += os.path.getsize(os.path.join(root, file))
            except OSError:
                # Skip files that vanish or are unreadable during the walk
                pass
    
    return file_count, total_size

dirs_to_check = [
    ("./models", "Models"),
    ("./data", "Data"),
    ("./logs", "Logs")
]

for directory, name in dirs_to_check:
    file_count, total_size = count_files_recursive(directory)
    size_mb = total_size / (1024 * 1024)
    console.print(f"  📂 {name}: {file_count} files ({size_mb:.1f} MB)")

# Usage instructions
console.print(f"\n[bold green]🚀 Quick Usage Guide:[/bold green]")

if results["inference_ready"]:
    console.print("✅ Your model is ready to use! Try this:")
    console.print("""
# Correct Vietnamese text
text = "Your Vietnamese text here"
corrected = inference_engine.correct_text(text)
print(f"Original: {text}")
print(f"Corrected: {corrected}")
""")

console.print(f"\n[bold blue]💡 Next Steps:[/bold blue]")
console.print("  🔸 Test your model with more examples")
console.print("  🔸 Fine-tune hyperparameters if needed")
console.print("  🔸 Export model for production use")
console.print("  🔸 Review the evaluation metrics in ./evaluation_results")

if results["contrastive_model"]:
    console.print(f"\n[bold green]🎊 Congratulations![/bold green]")
    console.print("You have successfully trained a Vietnamese Grammatical Error Correction model")
    console.print("with Contrastive Learning! The model is ready for use.")
else:
    console.print(f"\n[bold yellow]⚠️ Partial Completion[/bold yellow]")
    console.print("Some components may be missing. Check the status table above.")

# Performance tips
console.print(f"\n[bold blue]⚡ Performance Tips:[/bold blue]")
console.print("  🔸 Set use_contrastive_search=True for better quality")
console.print("  🔸 Set use_contrastive_search=False for faster inference")
console.print("  🔸 Adjust contrastive_alpha and contrastive_k in config.json to tune decoding")
console.print("  🔸 Batch process multiple texts for efficiency")

In [None]:
# Download models and results (for local use)
# Uncomment to create zip files for download

# import shutil

# console.print("📦 Creating downloadable archives...")

# # Create zip files
# shutil.make_archive('contrastive_gec_model', 'zip', './models/contrastive/final')
# shutil.make_archive('evaluation_results', 'zip', './evaluation_results')

# console.print("✅ Archives created:")
# console.print("  📦 contrastive_gec_model.zip - Trained model")
# console.print("  📦 evaluation_results.zip - Evaluation results")
# console.print("\n💾 Use the file browser to download these files")
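
Before downloading an archive, it can help to confirm that it actually contains the expected files. The following is a minimal sketch, not part of the pipeline: the helper name `archive_dir` is hypothetical, and the demo runs on a throwaway directory rather than `./models/contrastive/final`.

```python
import os
import shutil
import tempfile
import zipfile

def archive_dir(src_dir, archive_base):
    """Zip src_dir into <archive_base>.zip and return (zip path, member names)."""
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(f"Nothing to archive: {src_dir}")
    zip_path = shutil.make_archive(archive_base, "zip", src_dir)
    # Re-open the archive to list what was actually written
    with zipfile.ZipFile(zip_path) as zf:
        return zip_path, sorted(zf.namelist())

# Demo on a temporary directory standing in for a model folder
work = tempfile.mkdtemp()
model_dir = os.path.join(work, "final")
os.makedirs(model_dir)
with open(os.path.join(model_dir, "config.json"), "w") as f:
    f.write("{}")

zip_path, members = archive_dir(model_dir, os.path.join(work, "demo_model"))
print(zip_path, members)
```

Raising on a missing source directory avoids silently producing an empty archive when a training step was skipped.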

## 🚀 Quick Usage Guide

Once training is complete, you can use the model for inference:

In [None]:
# Quick usage example
console.print("🚀 [bold]Quick Usage Example:[/bold]")

# Load the model
from inference import GECInference

# Initialize
gec_model = GECInference(
    model_path="./models/contrastive/final",
    use_contrastive_search=True
)

# Correct text
text = "Tôi đi học trường đại học."
corrected = gec_model.correct_text(text)

console.print(f"Original: {text}")
console.print(f"Corrected: {corrected}")

console.print("\n💡 [bold]Usage Tips:[/bold]")
console.print("  🎯 Set use_contrastive_search=True for better quality")
console.print("  ⚡ Set use_contrastive_search=False for faster inference")
console.print("  📊 Adjust the alpha and k parameters to tune decoding")
console.print("  📁 Process files with correct_file() method")

# Practical Usage Examples and Export
from rich.console import Console
import json

console = Console()

console.print("[bold blue]🚀 Practical Usage Examples[/bold blue]")

# Example 1: Single text correction
console.print("\n[yellow]Example 1: Single Text Correction[/yellow]")
example_code_1 = '''
# Correct a single Vietnamese sentence
text = "Tôi đang học tiếng Việt ở trường đại học."
corrected = inference_engine.correct_text(text)
print(f"Original: {text}")
print(f"Corrected: {corrected}")
'''
console.print(example_code_1)

if 'inference_engine' in globals():
    text = "Tôi đang học tiếng Việt ở trường đại học."
    corrected = inference_engine.correct_text(text)
    console.print(f"✅ [green]Original:[/green] {text}")
    console.print(f"✅ [green]Corrected:[/green] {corrected}")

# Example 2: Batch processing
console.print("\n[yellow]Example 2: Batch Processing[/yellow]")
example_code_2 = '''
# Correct multiple texts at once
texts = [
    "Hôm nay trời đẹp quá.",
    "Chúng ta cần phải học bài này.",
    "Tôi thích ăn phở nhiều lắm."
]

corrected_texts = inference_engine.correct_batch(texts)
for original, corrected in zip(texts, corrected_texts):
    print(f"'{original}' → '{corrected}'")
'''
console.print(example_code_2)

if 'inference_engine' in globals():
    texts = [
        "Hôm nay trời đẹp quá.",
        "Chúng ta cần phải học bài này.", 
        "Tôi thích ăn phở nhiều lắm."
    ]
    try:
        corrected_texts = inference_engine.correct_batch(texts)
        for original, corrected in zip(texts, corrected_texts):
            console.print(f"✅ '{original}' → '{corrected}'")
    except Exception:
        # Fall back to individual corrections if batch correction fails
        for text in texts:
            corrected = inference_engine.correct_text(text)
            console.print(f"✅ '{text}' → '{corrected}'")

# Example 3: Configuration adjustment
console.print("\n[yellow]Example 3: Adjusting Inference Settings[/yellow]")
example_code_3 = '''
# Modify inference settings for different scenarios
import json

# Load current config
with open('config.json', 'r') as f:
    config = json.load(f)

# Option A: faster inference (lower quality)
config['inference']['use_contrastive_search'] = False
config['inference']['num_beams'] = 3

# Option B: better quality (slower)
# Apply only one option; if both run, Option B overrides Option A.
config['inference']['use_contrastive_search'] = True
config['inference']['contrastive_alpha'] = 0.8
config['inference']['contrastive_k'] = 6

# Save updated config
with open('config.json', 'w') as f:
    json.dump(config, f, indent=2, ensure_ascii=False)

# Recreate inference engine with new settings
inference_engine = create_inference_engine("./models/contrastive/final")
'''
console.print(example_code_3)

# Export functions
console.print("\n[bold blue]📦 Model Export and Deployment[/bold blue]")

export_code = '''
# Export model for deployment
import os
import shutil

def export_model(model_path, export_name):
    """Export trained model as a zip file"""
    shutil.make_archive(export_name, 'zip', model_path)
    print(f"Model exported as {export_name}.zip")

# Export the contrastive model
if os.path.exists("./models/contrastive/final"):
    export_model("./models/contrastive/final", "vietnamese_gec_model")

# Export configuration
shutil.copy("config.json", "vietnamese_gec_config.json")
print("Configuration exported as vietnamese_gec_config.json")
'''

console.print(export_code)

# Performance benchmarking
console.print("\n[bold blue]📊 Performance Benchmarking[/bold blue]")

if 'inference_engine' in globals():
    import time
    
    # Benchmark inference speed
    test_sentences = [
        "Tôi đi học ở trường đại học.",
        "Hôm nay thời tiết rất đẹp.",
        "Chúng tôi sẽ đi du lịch vào tuần sau.",
        "Cô ấy là một người rất thông minh.",
        "Anh ấy làm việc tại một công ty lớn."
    ]
    
    console.print(f"⏱️ Testing inference speed with {len(test_sentences)} sentences...")
    
    start_time = time.time()
    for sentence in test_sentences:
        _ = inference_engine.correct_text(sentence)
    end_time = time.time()
    
    total_time = end_time - start_time
    avg_time = total_time / len(test_sentences)
    
    console.print(f"📊 Results:")
    console.print(f"  Total time: {total_time:.2f} seconds")
    console.print(f"  Average per sentence: {avg_time:.2f} seconds")
    console.print(f"  Throughput: {len(test_sentences)/total_time:.1f} sentences/second")

console.print(f"\n[bold green]🎯 Your Vietnamese GEC model is ready for production use![/bold green]")
console.print("💡 Remember to save your work and download the model files!")
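
The timing loop above includes any one-time start-up cost in the first call. A slightly more careful measurement adds a warm-up call and uses `time.perf_counter()`, which is intended for measuring short durations. This is a sketch under assumptions: `benchmark` is a hypothetical helper, and `str.strip` stands in for `inference_engine.correct_text` so the example runs without a trained model.

```python
import time

def benchmark(correct_fn, sentences, warmup=1):
    """Time correct_fn over sentences after a few untimed warm-up calls."""
    for _ in range(warmup):
        correct_fn(sentences[0])  # warm-up: not included in the timing
    start = time.perf_counter()
    outputs = [correct_fn(s) for s in sentences]
    elapsed = time.perf_counter() - start
    return {
        "outputs": outputs,
        "total_s": elapsed,
        "avg_s": elapsed / len(sentences),
    }

# Stand-in correction function; replace with inference_engine.correct_text
stats = benchmark(str.strip, ["  Tôi đi học.  ", " Hôm nay trời đẹp. "])
print(f"avg per sentence: {stats['avg_s']:.6f} s")
```

With a real model, passing `inference_engine.correct_text` as `correct_fn` gives figures comparable to the loop above, minus the first-call overhead.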

## 📝 Configuration and Hyperparameters

Key hyperparameters used in this pipeline:

### Base Training:
- **Learning Rate**: Optimized via Optuna (typically 1e-5 to 1e-4)
- **Label Smoothing**: 0.1
- **Batch Size**: 8-32 (depending on GPU memory)
- **Max Length**: 384 tokens
- **Epochs**: 5-10
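
To make the label smoothing setting concrete, here is a minimal pure-Python sketch of label-smoothed cross-entropy for a single token. It assumes the common formulation in which the target distribution mixes the one-hot label (weight 1 - ε) with a uniform distribution over the vocabulary (weight ε). The function name is illustrative, and the training script itself may compute this differently.

```python
import math

def label_smoothed_nll(logits, target, epsilon=0.1):
    """Cross-entropy for one token with uniform label smoothing."""
    # Numerically stable log-softmax
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]

    nll = -log_probs[target]                     # standard cross-entropy term
    uniform = -sum(log_probs) / len(log_probs)   # cross-entropy against uniform
    return (1.0 - epsilon) * nll + epsilon * uniform

logits = [2.0, 0.5, -1.0]
plain = label_smoothed_nll(logits, target=0, epsilon=0.0)
smoothed = label_smoothed_nll(logits, target=0, epsilon=0.1)
print(f"plain CE: {plain:.4f}, smoothed CE: {smoothed:.4f}")
```

When the model is already confident in the gold token, the smoothed loss is larger than the plain loss, which discourages over-confident predictions.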

### Contrastive Learning:
- **λ (lambda_cl)**: 1.0 (weight of the contrastive (CL) loss relative to the cross-entropy (CE) loss)
- **γ (temperature)**: 0.25 (contrastive loss temperature)
- **R-Drop α**: 4.0 (R-Drop regularization strength)
- **Epochs**: 3-5
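
One plausible way these hyperparameters combine is `total = CE + λ · CL + α · KL`. The sketch below computes an InfoNCE-style contrastive term with temperature γ and the symmetric KL term used by R-Drop, in plain Python. The exact loss used by the training script is not shown in this notebook, so the formulas and function names here are assumptions for illustration.

```python
import math

def info_nce(sim_pos, sim_negs, temperature=0.25):
    """-log( exp(s_pos/γ) / (exp(s_pos/γ) + Σ exp(s_neg/γ)) )."""
    scaled = [sim_pos / temperature] + [s / temperature for s in sim_negs]
    m = max(scaled)  # subtract the max for numerical stability
    log_denom = m + math.log(sum(math.exp(s - m) for s in scaled))
    return log_denom - scaled[0]

def symmetric_kl(p, q):
    """R-Drop term: 0.5 * (KL(p||q) + KL(q||p)); p and q must be strictly positive."""
    kl_pq = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    kl_qp = sum(qi * math.log(qi / pi) for pi, qi in zip(p, q))
    return 0.5 * (kl_pq + kl_qp)

def total_loss(ce, cl, kl, lambda_cl=1.0, rdrop_alpha=4.0):
    """Weighted sum of cross-entropy, contrastive, and R-Drop terms."""
    return ce + lambda_cl * cl + rdrop_alpha * kl

# Toy values: similarity to the positive vs. two negative samples,
# and output distributions from two dropout passes over the same input
cl = info_nce(0.9, [0.2, 0.1])
kl = symmetric_kl([0.7, 0.2, 0.1], [0.6, 0.3, 0.1])
print(f"CL={cl:.4f}  KL={kl:.4f}  total={total_loss(1.2, cl, kl):.4f}")
```

A lower temperature sharpens the contrast between the positive and the negatives, so the same similarity gap yields a smaller contrastive loss.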

### Contrastive Search:
- **α (alpha)**: 0.7 (balance between model confidence and the penalty for repeating earlier context)
- **k**: 5 (top-k candidates)
- **Beam Size**: 1 (as recommended in the paper)
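
The roles of α and k can be illustrated with the contrastive search scoring rule: among the top-k candidates, pick the token maximizing `(1 - α) · p(token) - α · max cosine similarity to previous hidden states`. The sketch below is a toy, pure-Python version with made-up probabilities and 2-dimensional vectors; it is not the pipeline's inference code. In Hugging Face `generate()`, the corresponding arguments are `penalty_alpha` and `top_k`.

```python
import math

def cosine(u, v):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def contrastive_pick(candidates, prev_hidden, alpha=0.7, k=5):
    """candidates: list of (token, probability, hidden_vector) tuples."""
    # Restrict to the k most probable candidates
    top_k = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
    best_token, best_score = None, -math.inf
    for token, prob, hidden in top_k:
        # Penalize candidates whose representation resembles earlier context
        penalty = max(cosine(hidden, h) for h in prev_hidden)
        score = (1 - alpha) * prob - alpha * penalty
        if score > best_score:
            best_token, best_score = token, score
    return best_token, best_score

prev_hidden = [[1.0, 0.0]]
candidates = [
    ("học", 0.6, [1.0, 0.0]),     # most probable, but identical to context
    ("trường", 0.3, [0.0, 1.0]),  # less probable, orthogonal to context
    ("đi", 0.1, [0.7, 0.7]),
]
token, score = contrastive_pick(candidates, prev_hidden)
print(token, round(score, 4))
```

With `alpha=0.0` the rule reduces to greedy decoding over the top-k candidates; larger α trades probability for less repetition.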

These parameters can be adjusted based on your specific needs and computational resources.
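
Since the notebook reads these values from `config.json`, one way to adjust them is a small helper that overrides nested keys and writes the file back. The keys below (`inference.contrastive_alpha`, `contrastive.temperature`, and so on) are the ones read elsewhere in this notebook; the helper name `update_config` and the dotted-key convention are illustrative assumptions.

```python
import json
import os
import tempfile

def update_config(path, overrides):
    """Apply {"section.key": value} overrides to a JSON config and save it."""
    with open(path, "r", encoding="utf-8") as f:
        config = json.load(f)
    for dotted_key, value in overrides.items():
        section, key = dotted_key.split(".", 1)
        # Reject unknown entries so typos do not silently add new keys
        if section not in config or key not in config[section]:
            raise KeyError(f"Unknown config entry: {dotted_key}")
        config[section][key] = value
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    return config

# Demo on a throwaway file so the real config.json is untouched
demo_path = os.path.join(tempfile.mkdtemp(), "config.json")
with open(demo_path, "w", encoding="utf-8") as f:
    json.dump({
        "inference": {"contrastive_alpha": 0.7, "contrastive_k": 5},
        "contrastive": {"lambda_cl": 1.0, "temperature": 0.25},
    }, f)

updated = update_config(demo_path, {
    "inference.contrastive_alpha": 0.6,
    "contrastive.temperature": 0.5,
})
print(updated["inference"], updated["contrastive"])
```

After changing inference settings this way, the inference engine would need to be recreated for the new values to take effect.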