# VeriAIDPO - English Model Training Pipeline
## BERT-base-uncased for English PDPL 2025 Compliance

**Enterprise-Ready AI Training for English Data Protection Requests**

---

### Training Configuration:
- **Model**: BERT-base-uncased (110M parameters)
- **Dataset**: 5,000 English PDPL templates (625 per category)
- **Architecture**: Separate model (Option A - Bilingual Strategy)
- **Target Accuracy**: 88-92% (production-grade)
- **Model Size**: ~440MB
- **Categories**: 8 PDPL 2025 compliance categories

### Expected Performance:
- **Training Time**: 2-3 hours on T4 GPU
- **Inference Speed**: 40-80ms per request
- **Target Accuracy**: 88-92% on test set
- **Confidence**: 88-95% average

### Quality Assurance:
- Zero data leakage detection
- Template diversity analysis (100+ unique structures)
- Overfitting prevention (early stop when accuracy reaches >=95%)
- Underfitting detection (accuracy still <=40% by epoch 2)
- Reserved company sets for train/val/test isolation

### Bilingual Integration:
- Companion to Vietnamese PhoBERT model
- Completely independent dataset (no overlap)
- Same 8 PDPL categories (English translations)
- Exports to VeriSyntra backend (veriaidpo_en)

---

## Step 1: Environment Setup & GPU Validation

**Enterprise-grade environment setup with comprehensive validation**

In [None]:
# Step banner.
_STEP1_RULE = "=" * 70
print(_STEP1_RULE, flush=True)
print("STEP 1: ENVIRONMENT SETUP & GPU VALIDATION", flush=True)
print(_STEP1_RULE + "\n", flush=True)

import sys
import os
from datetime import datetime

# The session is treated as Google Colab when the 'google.colab' module
# is already present in sys.modules.
IN_COLAB = 'google.colab' in sys.modules

# Report where, with which interpreter, and when this run started.
if IN_COLAB:
    print("Environment: Google Colab", flush=True)
else:
    print("Environment: Local/Other", flush=True)
print("Python Version: " + sys.version, flush=True)
print("Execution Start: " + datetime.now().strftime('%Y-%m-%d %H:%M:%S'), flush=True)
print(flush=True)

# Install required packages
print("Installing required packages...", flush=True)
!pip install -q transformers datasets torch scikit-learn matplotlib seaborn
print("Package installation complete!\n", flush=True)

# Import libraries
import torch
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    TrainerCallback
)
from datasets import Dataset, DatasetDict
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import json
import random
from typing import List, Dict, Tuple, Set
from collections import defaultdict, Counter
import hashlib
import re

# Report the versions of the core libraries that were just imported.
print("Library Import Status:", flush=True)
for _library_label, _library_module in (
    ("PyTorch", torch),
    ("Transformers", transformers),
    ("NumPy", np),
):
    print(f"  - {_library_label}: {_library_module.__version__}", flush=True)
print(flush=True)

# GPU Validation: pick the training device and report what was found.
print("GPU Validation:", flush=True)
if not torch.cuda.is_available():
    # CPU fallback: training still works, but the user is warned up front.
    print("  WARNING: No GPU detected. Training will be slow on CPU.", flush=True)
    device = torch.device("cpu")
    print("  Status: CPU MODE (Not recommended for production)", flush=True)
else:
    # Describe the first CUDA device (index 0); memory is shown in GiB.
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"  GPU Available: {gpu_name}", flush=True)
    print(f"  GPU Memory: {gpu_memory:.2f} GB", flush=True)
    device = torch.device("cuda")
    print("  Status: READY FOR TRAINING", flush=True)

print(flush=True)
print("Environment setup complete!\n", flush=True)
print("=" * 70, flush=True)

## Step 2: Language Configuration & Hyperparameters

**Dynamic configuration for English BERT model training**

In [None]:
print("="*70, flush=True)
print("STEP 2: LANGUAGE CONFIGURATION & HYPERPARAMETERS", flush=True)
print("="*70 + "\n", flush=True)

# === LANGUAGE CONFIGURATION ===
LANGUAGE = "en"  # English training
MODEL_NAME = "bert-base-uncased"  # BERT for English
MODEL_DISPLAY_NAME = "BERT-base-uncased"

# === DATASET CONFIGURATION ===
DATASET_TARGET = 5000  # Total samples
NUM_CATEGORIES = 8  # PDPL categories
# Floor division: any remainder of DATASET_TARGET / NUM_CATEGORIES is dropped.
SAMPLES_PER_CATEGORY = DATASET_TARGET // NUM_CATEGORIES  # 625 per category

# === FILE PATHS ===
OUTPUT_DATASET_FILE = "english_pdpl_complete.jsonl"
TRAIN_FILE = "english_pdpl_train.jsonl"
VAL_FILE = "english_pdpl_val.jsonl"
TEST_FILE = "english_pdpl_test.jsonl"
MODEL_SAVE_DIR = "./veriaidpo_en_model"
OUTPUT_DIR = "./"  # Current directory
MAX_LENGTH = 256  # Token sequence length
RANDOM_SEED = 42

# === EARLY-STOPPING THRESHOLDS ===
# Defined once and referenced from both the top level of TRAINING_CONFIG and
# its nested 'hyperparameters' dict, so the two copies cannot drift apart.
EARLY_HIGH_ACCURACY_THRESHOLD = 0.90  # 90% (vs 92% for Vietnamese)
EXTREME_OVERFITTING_THRESHOLD = 0.95

# === HYPERPARAMETERS (English-optimized) ===
TRAINING_CONFIG = {
    # Model configuration
    'model_name': MODEL_NAME,
    'hidden_dropout_prob': 0.20,  # Lower than Vietnamese (English easier)
    'attention_probs_dropout_prob': 0.20,
    'classifier_dropout': 0.20,

    # Training arguments
    'num_train_epochs': 8,  # Less than Vietnamese (faster convergence)
    'learning_rate': 3e-5,  # Standard BERT learning rate
    'per_device_train_batch_size': 8,
    'per_device_eval_batch_size': 16,
    'weight_decay': 0.01,
    'warmup_ratio': 0.1,
    'label_smoothing_factor': 0.10,  # Moderate smoothing
    'lr_scheduler_type': 'cosine',
    'save_strategy': 'epoch',
    # NOTE(review): recent transformers releases name the corresponding
    # TrainingArguments parameter `eval_strategy`; the install step does not
    # pin a version, so confirm how this key is consumed downstream.
    'evaluation_strategy': 'epoch',
    'load_best_model_at_end': True,
    'metric_for_best_model': 'eval_accuracy',

    # Early stopping thresholds (English)
    'early_high_accuracy_threshold': EARLY_HIGH_ACCURACY_THRESHOLD,
    'extreme_overfitting_threshold': EXTREME_OVERFITTING_THRESHOLD,
    'underfitting_threshold': 0.40,
    'patience': 3,

    # File paths and directories
    'output_dir': OUTPUT_DIR,
    'model_save_dir': MODEL_SAVE_DIR,
    'train_file': TRAIN_FILE,
    'val_file': VAL_FILE,
    'test_file': TEST_FILE,
    'max_length': MAX_LENGTH,
    'seed': RANDOM_SEED,

    # Nested hyperparameters dictionary for SmartTrainingCallback
    'hyperparameters': {
        'early_high_accuracy_threshold': EARLY_HIGH_ACCURACY_THRESHOLD,
        'extreme_overfitting_threshold': EXTREME_OVERFITTING_THRESHOLD,
    }
}

# === DATA SPLIT CONFIGURATION ===
# Fraction of the dataset assigned to each split.
DATA_SPLIT = dict(
    train=0.70,  # 3,500 samples
    val=0.15,    # 750 samples
    test=0.15,   # 750 samples
)

# Seed Python, NumPy and PyTorch with the same value for reproducibility.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

# Display configuration: language/model, key hyperparameters and split sizes.
print("Configuration Summary:", flush=True)
print(f"  Language: {LANGUAGE.upper()}", flush=True)
print(f"  Model: {MODEL_DISPLAY_NAME}", flush=True)
print(f"  Target Dataset Size: {DATASET_TARGET:,} samples", flush=True)
print(f"  Samples per Category: {SAMPLES_PER_CATEGORY} samples", flush=True)
print(f"  Number of Categories: {NUM_CATEGORIES}", flush=True)
print(flush=True)

print("Hyperparameters:", flush=True)
print(f"  Dropout: {TRAINING_CONFIG['hidden_dropout_prob']}", flush=True)
print(f"  Learning Rate: {TRAINING_CONFIG['learning_rate']}", flush=True)
print(f"  Label Smoothing: {TRAINING_CONFIG['label_smoothing_factor']}", flush=True)
print(f"  Epochs: {TRAINING_CONFIG['num_train_epochs']}", flush=True)
print(f"  Batch Size: {TRAINING_CONFIG['per_device_train_batch_size']}", flush=True)
print(flush=True)

print("Data Split:", flush=True)
for _split_label, _split_key in (("Train", 'train'), ("Validation", 'val'), ("Test", 'test')):
    _split_fraction = DATA_SPLIT[_split_key]
    # round() rather than int(): the product is a float, and truncation can
    # under-report by one when it lands just below a whole number
    # (e.g. int(0.29 * 100) == 28).
    _split_count = round(DATASET_TARGET * _split_fraction)
    print(f"  {_split_label}: {_split_fraction*100:.0f}% ({_split_count:,} samples)", flush=True)
print(flush=True)

print("Configuration complete!\n", flush=True)
print("="*70, flush=True)

## Step 3: PDPL Categories Definition

**8 PDPL 2025 compliance categories (bilingual - same as Vietnamese model)**

In [None]:
print("=" * 70, flush=True)
print("STEP 3: PDPL CATEGORIES DEFINITION", flush=True)
print("=" * 70 + "\n", flush=True)

# Enhanced PDPL 2025 Categories (Bilingual)
# Same categories as Vietnamese model for consistency.
# Maps category id (0-7) to its Vietnamese ("vi") and English ("en") names;
# ids follow the order of the (vi, en) pairs listed below.
PDPL_CATEGORIES = {
    category_index: {"vi": vietnamese_name, "en": english_name}
    for category_index, (vietnamese_name, english_name) in enumerate([
        ("Tinh hop phap, cong bang va minh bach", "Lawfulness, fairness and transparency"),
        ("Han che muc dich", "Purpose limitation"),
        ("Toi thieu hoa du lieu", "Data minimisation"),
        ("Tinh chinh xac", "Accuracy"),
        ("Han che luu tru", "Storage limitation"),
        ("Tinh toan ven va bao mat", "Integrity and confidentiality"),
        ("Trach nhiem giai trinh", "Accountability"),
        ("Quyen cua chu the du lieu", "Data subject rights"),
    ])
}

# Print every category id with its English and Vietnamese name.
print("PDPL 2025 Categories (English):", flush=True)
print(flush=True)
for cat_id in PDPL_CATEGORIES:
    names = PDPL_CATEGORIES[cat_id]
    print("  Category {}: {}".format(cat_id, names['en']), flush=True)
    print("    Vietnamese: {}".format(names['vi']), flush=True)
    print(flush=True)

print("Total Categories: {}".format(len(PDPL_CATEGORIES)), flush=True)
print("Target per Category: {} samples\n".format(SAMPLES_PER_CATEGORY), flush=True)
print("=" * 70, flush=True)

## Step 4: English Template Generator

**Generate diverse English PDPL templates with zero Vietnamese overlap**

In [None]:
print("=" * 70, flush=True)
print("STEP 4: ENGLISH TEMPLATE GENERATOR", flush=True)
print("=" * 70 + "\n", flush=True)

# Company names grouped under a region key; these are the values substituted
# for the {company} placeholder in the sentence templates.
ENGLISH_COMPANIES = dict(
    north=[
        'VNG', 'FPT', 'VNPT', 'Viettel', 'Vingroup', 'VietinBank',
        'Agribank', 'BIDV', 'MB Bank', 'ACB', 'VPBank', 'TPBank',
    ],
    central=[
        'Vinamilk', 'Hoa Phat', 'Petrolimex', 'PVN', 'EVN',
        'Vinatex', 'DHG Pharma', 'Hau Giang Pharma',
    ],
    south=[
        'Shopee VN', 'Lazada VN', 'Tiki', 'Grab VN', 'MoMo', 'ZaloPay',
        'Techcombank', 'VCB', 'Sacombank', 'HDBank',
    ],
)

# English business vocabulary grouped by sector; these are the values
# substituted for the {context} placeholder in the sentence templates.
# Some terms (e.g. 'insurance', 'account', 'data', 'delivery') appear under
# more than one sector.
BUSINESS_CONTEXTS_EN = dict(
    banking=[
        'account', 'transaction', 'credit card', 'loan', 'deposit',
        'transfer', 'investment', 'insurance', 'mortgage', 'credit',
    ],
    ecommerce=[
        'order', 'payment', 'delivery', 'product', 'promotion',
        'review', 'cart', 'voucher', 'refund', 'return',
    ],
    healthcare=[
        'medical record', 'consultation', 'prescription', 'insurance',
        'test', 'diagnosis', 'treatment', 'surgery', 'follow-up', 'vaccine',
    ],
    education=[
        'student', 'grade', 'tuition', 'certificate', 'course',
        'degree', 'exam', 'scholarship', 'enrollment', 'schedule',
    ],
    technology=[
        'application', 'account', 'data', 'security', 'service',
        'software', 'login', 'password', 'API', 'cloud',
    ],
    insurance=[
        'policy', 'benefit', 'claim', 'premium', 'contract',
        'claim request', 'risk assessment', 'reinsurance',
    ],
    telecommunications=[
        'call', 'message', 'data', 'roaming', 'charge',
        'subscription', 'network', 'phone number', 'internet',
    ],
    logistics=[
        'shipping', 'delivery', 'warehouse', 'tracking', 'fee',
        'packaging', 'export', 'import', 'logistics',
    ],
)

class EnglishTemplateGenerator:
    """Generate English PDPL training sentences from fixed sentence patterns.

    A sample is produced by filling a category-specific pattern with a
    randomly chosen company, business-context term and modal verb. Every
    sentence accepted by ``generate_diverse_samples`` is remembered on the
    instance so that later candidates can be rejected when they are exact
    or near duplicates (word-level Jaccard similarity).

    All random choices use the module-level ``random`` generator, so output
    is reproducible only if the caller seeds it first.
    """

    # Sentence patterns keyed by category id, then by structure type.
    # Kept at class level so the table is built once instead of being
    # re-created on every generate_sample() call.
    # Placeholders: {company}, {context} and (in some patterns) {modal}.
    _CATEGORY_TEMPLATES: Dict[int, Dict[str, List[str]]] = {
        0: {  # Lawfulness, fairness and transparency
            'simple': [
                "{company} {modal} process {context} data lawfully and fairly.",
                "The company {company} ensures transparency in {context} processing.",
                "{company} provides clear information about {context} data usage.",
                "Lawful processing of {context} is required by {company}.",
                "{company} commits to fair data practices for {context}.",
                "{company} {modal} handle {context} data with transparency.",
                "Fair and lawful {context} processing is ensured by {company}."
            ],
            'compound': [
                "{company} processes {context} data lawfully and provides transparent information.",
                "The company {company} ensures fairness but also complies with legal requirements for {context}.",
                "{company} maintains transparency and lawful processing of {context} data.",
                "{company} {modal} process {context} fairly and ensure legal compliance.",
                "Legal compliance is maintained and {context} data is processed transparently by {company}."
            ],
            'complex': [
                "To ensure lawfulness, {company} establishes clear legal basis for {context} processing.",
                "When processing {context} data, {company} {modal} demonstrate compliance with legal requirements.",
                "Although complex, {company} commits to maintaining transparency in {context} processing.",
                "Before processing {context}, {company} verifies legal grounds and ensures fairness.",
                "If {context} data is processed, {company} guarantees lawful and transparent handling."
            ]
        },
        1: {  # Purpose limitation
            'simple': [
                "{company} limits {context} data to specified purposes.",
                "The company {company} restricts {context} use to stated goals.",
                "{company} {modal} collect {context} for defined purposes only.",
                "Purpose limitation applies to all {context} data at {company}.",
                "{company} ensures {context} is used for specified purposes.",
                "Data collection for {context} is purpose-limited by {company}."
            ],
            'compound': [
                "{company} collects {context} data and restricts use to stated purposes.",
                "The company {company} defines purposes and limits {context} data usage accordingly.",
                "{company} {modal} specify purposes but ensure {context} is not used beyond them.",
                "Purpose specification is done and {context} usage is limited by {company}."
            ],
            'complex': [
                "When collecting {context} data, {company} ensures purpose limitation is enforced.",
                "To prevent misuse, {company} restricts {context} data to originally stated purposes.",
                "Although {context} has multiple uses, {company} limits processing to defined goals.",
                "Before using {context}, {company} verifies alignment with specified purposes."
            ]
        },
        2: {  # Data minimisation
            'simple': [
                "{company} collects only necessary {context} data.",
                "The company {company} minimizes {context} data collection.",
                "{company} {modal} limit {context} to essential information only.",
                "Data minimisation principles apply to {context} at {company}.",
                "{company} ensures minimal {context} data is collected.",
                "Only required {context} information is gathered by {company}."
            ],
            'compound': [
                "{company} evaluates necessity and collects minimal {context} data.",
                "The company {company} reviews requirements and minimizes {context} collection.",
                "{company} {modal} assess needs but collect only essential {context} data.",
                "Necessity is evaluated and {context} minimisation is applied by {company}."
            ],
            'complex': [
                "To reduce data volume, {company} collects only necessary {context} information.",
                "When gathering {context}, {company} ensures data minimisation is implemented.",
                "Although more {context} could be collected, {company} limits to essentials.",
                "Before collecting {context}, {company} verifies necessity and minimizes data."
            ]
        },
        3: {  # Accuracy
            'simple': [
                "{company} ensures {context} data is accurate and up-to-date.",
                "The company {company} maintains accurate {context} records.",
                "{company} {modal} verify {context} data accuracy regularly.",
                "Data accuracy is maintained for {context} by {company}.",
                "{company} updates {context} information to ensure accuracy.",
                "Accurate {context} data is guaranteed by {company}."
            ],
            'compound': [
                "{company} verifies {context} data and maintains accuracy standards.",
                "The company {company} reviews records and ensures {context} is accurate.",
                "{company} {modal} check {context} but also update when needed.",
                "Verification is performed and {context} accuracy is ensured by {company}."
            ],
            'complex': [
                "To ensure reliability, {company} verifies {context} data accuracy regularly.",
                "When managing {context}, {company} implements accuracy verification processes.",
                "Although {context} changes over time, {company} maintains data accuracy.",
                "Before using {context}, {company} confirms data is accurate and current."
            ]
        },
        4: {  # Storage limitation
            'simple': [
                "{company} limits {context} data storage duration.",
                "The company {company} retains {context} for limited periods.",
                "{company} {modal} delete {context} data after specified time.",
                "Storage limitation applies to {context} at {company}.",
                "{company} ensures {context} is not kept longer than necessary.",
                "Retention periods for {context} are enforced by {company}."
            ],
            'compound': [
                "{company} sets retention periods and deletes {context} accordingly.",
                "The company {company} defines storage limits and removes old {context} data.",
                "{company} {modal} retain {context} temporarily but delete after use.",
                "Time limits are set and {context} deletion is enforced by {company}."
            ],
            'complex': [
                "To prevent excessive storage, {company} deletes {context} after retention period.",
                "When storing {context}, {company} ensures compliance with time limitations.",
                "Although {context} may be needed later, {company} limits storage duration.",
                "After retention period expires, {company} securely deletes {context} data."
            ]
        },
        5: {  # Integrity and confidentiality
            'simple': [
                "{company} protects {context} data integrity and confidentiality.",
                "The company {company} secures {context} against unauthorized access.",
                "{company} {modal} implement security measures for {context}.",
                "Data security is maintained for {context} by {company}.",
                "{company} ensures {context} confidentiality through encryption.",
                "Integrity of {context} data is protected by {company}."
            ],
            'compound': [
                "{company} encrypts {context} data and prevents unauthorized access.",
                "The company {company} implements controls and protects {context} integrity.",
                "{company} {modal} secure {context} but also monitor for breaches.",
                "Security measures are applied and {context} confidentiality is ensured by {company}."
            ],
            'complex': [
                "To prevent data breaches, {company} implements security controls for {context}.",
                "When handling {context}, {company} ensures integrity and confidentiality.",
                "Although threats exist, {company} protects {context} through robust security.",
                "Before processing {context}, {company} verifies security measures are active."
            ]
        },
        6: {  # Accountability
            'simple': [
                "{company} demonstrates accountability for {context} data processing.",
                "The company {company} maintains records of {context} activities.",
                "{company} {modal} document {context} processing decisions.",
                "Accountability measures are implemented for {context} by {company}.",
                "{company} takes responsibility for {context} data handling.",
                "Documentation of {context} processing is maintained by {company}."
            ],
            'compound': [
                "{company} documents decisions and demonstrates accountability for {context}.",
                "The company {company} maintains logs and shows responsibility for {context}.",
                "{company} {modal} record activities but also ensure accountability.",
                "Records are kept and {context} accountability is demonstrated by {company}."
            ],
            'complex': [
                "To demonstrate compliance, {company} maintains detailed {context} processing records.",
                "When processing {context}, {company} ensures accountability through documentation.",
                "Although complex, {company} takes full responsibility for {context} handling.",
                "Before making decisions about {context}, {company} documents the rationale."
            ]
        },
        7: {  # Data subject rights
            'simple': [
                "{company} respects data subject rights for {context}.",
                "The company {company} enables {context} data access requests.",
                "{company} {modal} provide {context} data upon user request.",
                "User rights are honored for {context} by {company}.",
                "{company} facilitates {context} data deletion requests.",
                "Data subject rights regarding {context} are protected by {company}."
            ],
            'compound': [
                "{company} receives requests and provides {context} data access.",
                "The company {company} honors rights and enables {context} correction.",
                "{company} {modal} process requests but also verify user identity.",
                "Requests are handled and {context} rights are respected by {company}."
            ],
            'complex': [
                "To respect user rights, {company} enables {context} data access and deletion.",
                "When users request {context} data, {company} provides it within legal timeframes.",
                "Although verification is needed, {company} honors {context} data rights.",
                "After receiving a request, {company} promptly provides {context} information."
            ]
        }
    }

    def __init__(self):
        """Bind the module-level company/context tables and reset dedup state."""
        self.companies = ENGLISH_COMPANIES
        self.contexts = BUSINESS_CONTEXTS_EN
        # Texts and hashes of every sample accepted so far (across categories).
        self.generated_templates = set()
        self.template_hashes = set()

        # Sentence structure types (English grammar)
        self.structure_types = ['simple', 'compound', 'complex']

        # Formality levels
        self.formality_levels = ['formal', 'business', 'standard']

        # Modal verbs for variation
        self.modals = ['must', 'shall', 'should', 'will', 'can', 'may']

        # Verb variations (not referenced by the methods of this class)
        self.verbs = {
            'process': ['process', 'handle', 'manage', 'deal with'],
            'collect': ['collect', 'gather', 'obtain', 'acquire'],
            'store': ['store', 'retain', 'keep', 'maintain'],
            'delete': ['delete', 'remove', 'erase', 'eliminate'],
            'protect': ['protect', 'safeguard', 'secure', 'shield'],
            'ensure': ['ensure', 'guarantee', 'maintain', 'establish']
        }

    def get_category_templates(self, category_id: int) -> Dict[str, List[str]]:
        """Return the sentence patterns for a PDPL category.

        Args:
            category_id: Category id (0-7).

        Returns:
            A mapping from structure type ('simple', 'compound', 'complex')
            to a list of pattern strings, or an empty dict for an unknown id.
            The result is a fresh copy, so callers may modify it without
            affecting the shared table.
        """
        patterns_by_structure = self._CATEGORY_TEMPLATES.get(category_id, {})
        return {
            structure: list(patterns)
            for structure, patterns in patterns_by_structure.items()
        }

    def generate_template_hash(self, text: str) -> str:
        """Return the MD5 hex digest of ``text`` after normalisation.

        Normalisation lowercases the text, strips both ends and collapses
        every whitespace run to a single space, so texts that differ only in
        case or spacing share a hash. Used for deduplication only.
        """
        normalized = re.sub(r'\s+', ' ', text.lower().strip())
        return hashlib.md5(normalized.encode()).hexdigest()

    @staticmethod
    def _jaccard(words1: Set[str], words2: Set[str]) -> float:
        """Return |words1 & words2| / |words1 | words2|, or 0.0 if either is empty."""
        if not words1 or not words2:
            return 0.0
        return len(words1 & words2) / len(words1 | words2)

    def calculate_similarity(self, text1: str, text2: str) -> float:
        """Return the word-level Jaccard similarity of two texts.

        Texts are lowercased and split on whitespace; punctuation stays
        attached to its word. Returns 0.0 when either text has no words.
        """
        return self._jaccard(set(text1.lower().split()), set(text2.lower().split()))

    def is_unique_template(self, text: str, similarity_threshold: float = 0.85) -> bool:
        """Return True if ``text`` is neither an exact nor a near duplicate.

        Args:
            text: Candidate sentence.
            similarity_threshold: A candidate is rejected when its similarity
                to any previously accepted sentence is strictly greater than
                this value.
        """
        # Exact duplicates (after normalisation) are caught by the hash set.
        if self.generate_template_hash(text) in self.template_hashes:
            return False

        # Tokenise the candidate once rather than once per comparison.
        candidate_words = set(text.lower().split())
        return not any(
            self._jaccard(candidate_words, set(existing.lower().split())) > similarity_threshold
            for existing in self.generated_templates
        )

    def generate_sample(self, category_id: int, region: str, context_type: str) -> Dict:
        """Generate a single training sample.

        Args:
            category_id: Category id (0-7); also used as the sample label.
            region: Key of ``self.companies`` to draw the company from.
            context_type: Key of ``self.contexts`` to draw the context from.

        Returns:
            A dict with 'text', 'label', 'template_id' (hash of the text),
            'language' and a 'metadata' dict describing the random choices.

        Raises:
            KeyError: If the category has no templates, or ``region`` /
                ``context_type`` is not a known key.
        """
        category_templates = self.get_category_templates(category_id)
        if not category_templates:
            raise KeyError(f"No templates defined for PDPL category id {category_id!r}")

        # Randomly select structure type and formality.
        # NOTE: formality is recorded in the metadata only; it does not
        # influence the generated text.
        structure = random.choice(self.structure_types)
        formality = random.choice(self.formality_levels)

        # Unknown structure types fall back to the 'simple' patterns.
        if structure in category_templates:
            candidate_patterns = category_templates[structure]
        else:
            candidate_patterns = category_templates['simple']
        template_pattern = random.choice(candidate_patterns)

        # Get company and context
        company = random.choice(self.companies[region])
        context = random.choice(self.contexts[context_type])
        modal = random.choice(self.modals)

        # Patterns without a {modal} placeholder simply ignore that argument.
        text = template_pattern.format(
            company=company,
            context=context,
            modal=modal
        )

        return {
            'text': text,
            'label': category_id,
            'template_id': self.generate_template_hash(text),
            'language': 'en',
            'metadata': {
                'company': company,
                'context_type': context_type,
                'context': context,
                'region': region,
                'structure': structure,
                'formality': formality,
                'category_name': PDPL_CATEGORIES[category_id]['en']
            }
        }

    def generate_diverse_samples(self, category_id: int, count: int,
                                 max_attempts_per_sample: int = 10) -> List[Dict]:
        """Generate up to ``count`` mutually distinct samples for a category.

        Candidates are drawn with a random region and context type and kept
        only if ``is_unique_template`` accepts them. Accepted texts and their
        hashes are recorded on the instance, so uniqueness also holds against
        samples from earlier calls.

        Args:
            category_id: Category id passed on to ``generate_sample``.
            count: Number of samples wanted.
            max_attempts_per_sample: Attempt budget multiplier; at most
                ``count * max_attempts_per_sample`` candidates are drawn.

        Returns:
            The accepted samples. May be shorter than ``count`` if the
            attempt budget runs out, in which case a warning is printed.
        """
        samples = []
        max_attempts = count * max_attempts_per_sample

        regions = list(self.companies.keys())
        context_types = list(self.contexts.keys())

        for _ in range(max_attempts):
            if len(samples) >= count:
                break

            # Randomly select region and context type
            region = random.choice(regions)
            context_type = random.choice(context_types)

            sample = self.generate_sample(category_id, region, context_type)

            if self.is_unique_template(sample['text']):
                samples.append(sample)
                self.generated_templates.add(sample['text'])
                self.template_hashes.add(sample['template_id'])

        if len(samples) < count:
            print(f"  Warning: Only generated {len(samples)}/{count} unique samples for category {category_id}", flush=True)

        return samples

# Create the generator instance and summarise the vocabulary it draws from.
print("Initializing English template generator...", flush=True)
generator = EnglishTemplateGenerator()

_company_total = sum(map(len, ENGLISH_COMPANIES.values()))
_context_total = sum(map(len, BUSINESS_CONTEXTS_EN.values()))
print("Companies available: {}".format(_company_total), flush=True)
print("Business contexts: {}".format(_context_total), flush=True)
print("Template structures: {}".format(len(generator.structure_types)), flush=True)
print("Formality levels: {}".format(len(generator.formality_levels)), flush=True)
print(flush=True)

print("Template generator ready!\n", flush=True)
print("=" * 70, flush=True)

## Step 5: Generate English Dataset

**Generate 5,000 unique English PDPL templates (625 per category)**

In [None]:
print("=" * 70, flush=True)
print("STEP 5: GENERATE ENGLISH DATASET", flush=True)
print("=" * 70 + "\n", flush=True)

# Accumulators: every generated sample, plus per-category statistics.
all_samples = []
category_stats = {}

print(f"Generating {DATASET_TARGET:,} English PDPL templates...\n", flush=True)

for category_id in range(NUM_CATEGORIES):
    category_name = PDPL_CATEGORIES[category_id]['en']
    print("Category {}: {}".format(category_id, category_name), flush=True)
    print("  Target: {} samples".format(SAMPLES_PER_CATEGORY), flush=True)

    # Ask the generator for this category's quota of distinct samples.
    samples = generator.generate_diverse_samples(category_id, SAMPLES_PER_CATEGORY)
    all_samples += samples

    # Record how many samples were produced and how they are distributed
    # over structure, formality and region.
    category_stats[category_id] = dict(
        count=len(samples),
        name=category_name,
        structures=Counter(s['metadata']['structure'] for s in samples),
        formality=Counter(s['metadata']['formality'] for s in samples),
        regions=Counter(s['metadata']['region'] for s in samples),
    )

    print("  Generated: {} samples".format(len(samples)), flush=True)
    print("  Structures: {}".format(dict(category_stats[category_id]['structures'])), flush=True)
    print(flush=True)

# Mix the categories together so samples are no longer grouped by label.
random.shuffle(all_samples)

_summary_rule = "=" * 70
print(_summary_rule, flush=True)
print("DATASET GENERATION COMPLETE", flush=True)
print(_summary_rule + "\n", flush=True)

print("Total samples generated: {:,}".format(len(all_samples)), flush=True)
print("Target samples: {:,}".format(DATASET_TARGET), flush=True)
print("Unique templates: {:,}".format(len(generator.generated_templates)), flush=True)
print(flush=True)

# Share of the whole dataset contributed by each category.
print("Category Distribution:", flush=True)
for cat_id, stats in category_stats.items():
    percentage = (stats['count'] / len(all_samples)) * 100
    print(
        "  Cat {} ({}): {} ({:.1f}%)".format(cat_id, stats['name'], stats['count'], percentage),
        flush=True,
    )

print(flush=True)

# Write the full dataset as JSON Lines: one JSON object per line.
print("Saving dataset to {}...".format(OUTPUT_DATASET_FILE), flush=True)
with open(OUTPUT_DATASET_FILE, 'w', encoding='utf-8') as f:
    for sample in all_samples:
        f.write("{}\n".format(json.dumps(sample, ensure_ascii=False)))

print("Dataset saved successfully!\n", flush=True)
print(_summary_rule, flush=True)

## Step 6: Reserved Company Sets for Data Isolation

**Critical: Prevent data leakage by isolating companies across train/val/test splits**

In [None]:
print("=" * 70, flush=True)
print("STEP 6: RESERVED COMPANY SETS (DATA LEAK PREVENTION)", flush=True)
print("=" * 70 + "\n", flush=True)

# Companies assigned to each split, grouped by region. The intent is that a
# company belongs to exactly one split; the checks below confirm that no
# name is shared between two splits.
RESERVED_COMPANIES = dict(
    train=dict(
        north=['VNG', 'FPT', 'VNPT', 'Viettel', 'Vingroup', 'MB Bank', 'ACB', 'VPBank'],
        central=['Vinamilk', 'Hoa Phat', 'Petrolimex', 'PVN', 'EVN'],
        south=['Shopee VN', 'Lazada VN', 'Tiki', 'Grab VN', 'MoMo', 'ZaloPay', 'Techcombank'],
    ),
    val=dict(
        north=['VietinBank', 'Agribank'],
        central=['Vinatex', 'DHG Pharma'],
        south=['VCB', 'Sacombank'],
    ),
    test=dict(
        north=['BIDV', 'TPBank'],
        central=['Hau Giang Pharma'],
        south=['HDBank'],
    ),
)

print("Verifying company isolation...\n", flush=True)

# Flatten each split's per-region lists into one list per split.
all_train_companies = []
all_val_companies = []
all_test_companies = []

for region in ['north', 'central', 'south']:
    all_train_companies += RESERVED_COMPANIES['train'][region]
    all_val_companies += RESERVED_COMPANIES['val'][region]
    all_test_companies += RESERVED_COMPANIES['test'][region]

# Pairwise intersections; each must be empty for the splits to be isolated.
train_val_overlap = set(all_train_companies) & set(all_val_companies)
train_test_overlap = set(all_train_companies) & set(all_test_companies)
val_test_overlap = set(all_val_companies) & set(all_test_companies)

print("Train companies: {}".format(len(all_train_companies)), flush=True)
print("Val companies: {}".format(len(all_val_companies)), flush=True)
print("Test companies: {}".format(len(all_test_companies)), flush=True)
print(flush=True)

print("Train-Val overlap: {} (should be 0)".format(len(train_val_overlap)), flush=True)
print("Train-Test overlap: {} (should be 0)".format(len(train_test_overlap)), flush=True)
print("Val-Test overlap: {} (should be 0)".format(len(val_test_overlap)), flush=True)
print(flush=True)

if len(train_val_overlap) == 0 and len(train_test_overlap) == 0 and len(val_test_overlap) == 0:
    print("SUCCESS: Complete company isolation achieved!", flush=True)
    print("No data leakage possible from company overlap.\n", flush=True)
else:
    print("WARNING: Company overlap detected!", flush=True)
    if train_val_overlap:
        print(f"  Train-Val: {train_val_overlap}", flush=True)
    if train_test_overlap:
        print(f"  Train-Test: {train_test_overlap}", flush=True)
    if val_test_overlap:
        print(f"  Val-Test: {val_test_overlap}", flush=True)
    print(flush=True)

print("="*70, flush=True)

## Step 7: Data Splitting with Company Isolation

**Split dataset into Train/Val/Test with reserved company sets**

In [None]:
print("=" * 70, flush=True)
print("STEP 7: DATA SPLITTING WITH COMPANY ISOLATION", flush=True)
print("=" * 70 + "\n", flush=True)

# One output bucket per split.
train_samples, val_samples, test_samples = [], [], []

print("Splitting dataset by reserved companies...\n", flush=True)

# Routing table checked in priority order: train first, then val, then test.
_split_routes = (
    (all_train_companies, train_samples),
    (all_val_companies, val_samples),
    (all_test_companies, test_samples),
)

for _record in all_samples:
    _owner = _record['metadata']['company']
    for _reserved_names, _bucket in _split_routes:
        if _owner in _reserved_names:
            _bucket.append(_record)
            break
    else:
        # Company is not in any reserved list: it goes to the training split.
        train_samples.append(_record)

# Shuffle every split in place (train, val, test order).
for _bucket in (train_samples, val_samples, test_samples):
    random.shuffle(_bucket)

_named_splits = (('Train', train_samples), ('Val', val_samples), ('Test', test_samples))

print("Split Summary:", flush=True)
for _split_label, _bucket in _named_splits:
    print(f"  {_split_label}: {len(_bucket):,} samples ({len(_bucket)/len(all_samples)*100:.1f}%)", flush=True)
print(f"  Total: {len(train_samples) + len(val_samples) + len(test_samples):,} samples", flush=True)
print(flush=True)

# Show how the category labels are distributed inside each split.
print("Category Distribution by Split:", flush=True)
for _split_label, _bucket in _named_splits:
    _label_counts = Counter(_item['label'] for _item in _bucket)
    print(f"\n  {_split_label}:", flush=True)
    for _cat in range(NUM_CATEGORIES):
        _n = _label_counts.get(_cat, 0)
        _pct = (_n / len(_bucket) * 100) if _bucket else 0
        print(f"    Cat {_cat}: {_n} ({_pct:.1f}%)", flush=True)

print(flush=True)

# Write each split to its own JSON Lines file.
print("Saving split datasets...", flush=True)

for _path, _bucket in ((TRAIN_FILE, train_samples), (VAL_FILE, val_samples), (TEST_FILE, test_samples)):
    with open(_path, 'w', encoding='utf-8') as _out:
        _out.writelines(json.dumps(_item, ensure_ascii=False) + '\n' for _item in _bucket)
    print(f"  {_path}: {len(_bucket):,} samples", flush=True)

print(flush=True)
print("Data splitting complete!\n", flush=True)
print("=" * 70, flush=True)

## Step 8: Data Leakage Detection & Diagnostics

**CRITICAL: Comprehensive data leakage analysis to ensure model validity**

In [None]:
print("=" * 70, flush=True)
print("STEP 8: DATA LEAKAGE DETECTION & DIAGNOSTICS", flush=True)
print("=" * 70 + "\n", flush=True)

# Checks 1-3 compare the train/val/test splits pairwise on three keys:
# the sample text, the metadata company and the template_id.

print("CRITICAL DATA INTEGRITY CHECKS\n", flush=True)
print("-" * 70, flush=True)

_pair_labels = ("Train-Val", "Train-Test", "Val-Test")

# CHECK 1: identical sample texts shared between splits.
print("\nCHECK 1: Template Overlap Detection", flush=True)
print("-" * 70, flush=True)

train_texts = {_item['text'] for _item in train_samples}
val_texts = {_item['text'] for _item in val_samples}
test_texts = {_item['text'] for _item in test_samples}

train_val_text_overlap = train_texts & val_texts
train_test_text_overlap = train_texts & test_texts
val_test_text_overlap = val_texts & test_texts
_text_overlaps = (train_val_text_overlap, train_test_text_overlap, val_test_text_overlap)

# These counts are the numbers of distinct texts per split.
for _split_label, _unique in (("Train", train_texts), ("Val", val_texts), ("Test", test_texts)):
    print(f"{_split_label} samples: {len(_unique):,}", flush=True)
print(flush=True)

for _pair, _shared in zip(_pair_labels, _text_overlaps):
    print(f"{_pair} template overlap: {len(_shared)} (MUST be 0)", flush=True)
print(flush=True)

template_overlap_ok = not (train_val_text_overlap or
                           train_test_text_overlap or
                           val_test_text_overlap)

if not template_overlap_ok:
    print("CRITICAL WARNING: Template overlap detected!", flush=True)
    # Show at most three offending texts per pair.
    for _pair, _shared in zip(_pair_labels, _text_overlaps):
        if _shared:
            print(f"  {_pair} overlap: {list(_shared)[:3]}", flush=True)
    print(flush=True)
else:
    print("SUCCESS: ZERO template overlap detected!", flush=True)
    print("Data splits are completely isolated.\n", flush=True)

# CHECK 2: companies that actually occur in more than one split.
print("\nCHECK 2: Company Overlap Detection", flush=True)
print("-" * 70, flush=True)

train_companies = {_item['metadata']['company'] for _item in train_samples}
val_companies = {_item['metadata']['company'] for _item in val_samples}
test_companies = {_item['metadata']['company'] for _item in test_samples}

train_val_company_overlap = train_companies & val_companies
train_test_company_overlap = train_companies & test_companies
val_test_company_overlap = val_companies & test_companies
_company_overlaps = (train_val_company_overlap, train_test_company_overlap, val_test_company_overlap)

for _split_label, _unique in (("Train", train_companies), ("Val", val_companies), ("Test", test_companies)):
    print(f"{_split_label} companies: {len(_unique)}", flush=True)
print(flush=True)

for _pair, _shared in zip(_pair_labels, _company_overlaps):
    print(f"{_pair} company overlap: {len(_shared)} (MUST be 0)", flush=True)
print(flush=True)

company_overlap_ok = not (train_val_company_overlap or
                          train_test_company_overlap or
                          val_test_company_overlap)

if not company_overlap_ok:
    print("CRITICAL WARNING: Company overlap detected!", flush=True)
    for _pair, _shared in zip(_pair_labels, _company_overlaps):
        if _shared:
            print(f"  {_pair}: {_shared}", flush=True)
    print(flush=True)
else:
    print("SUCCESS: ZERO company overlap detected!", flush=True)
    print("Companies are completely isolated across splits.\n", flush=True)

# CHECK 3: the same comparison keyed on each sample's template_id.
print("\nCHECK 3: Template Hash Overlap (Double Verification)", flush=True)
print("-" * 70, flush=True)

train_hashes = {_item['template_id'] for _item in train_samples}
val_hashes = {_item['template_id'] for _item in val_samples}
test_hashes = {_item['template_id'] for _item in test_samples}

train_val_hash_overlap = train_hashes & val_hashes
train_test_hash_overlap = train_hashes & test_hashes
val_test_hash_overlap = val_hashes & test_hashes

for _pair, _shared in zip(
    _pair_labels,
    (train_val_hash_overlap, train_test_hash_overlap, val_test_hash_overlap),
):
    print(f"{_pair} hash overlap: {len(_shared)} (MUST be 0)", flush=True)
print(flush=True)

hash_overlap_ok = not (train_val_hash_overlap or
                       train_test_hash_overlap or
                       val_test_hash_overlap)

if not hash_overlap_ok:
    print("CRITICAL WARNING: Hash overlap detected!", flush=True)
    print("  This indicates duplicate templates exist.\n", flush=True)
else:
    print("SUCCESS: ZERO hash overlap detected!", flush=True)
    print("Template uniqueness verified.\n", flush=True)

# CHECK 4: word-level similarity of randomly paired samples across splits.
print("\nCHECK 4: Cross-Split Similarity Analysis", flush=True)
print("-" * 70, flush=True)

def calculate_text_similarity(text1, text2):
    """Return the Jaccard similarity of the two texts' word sets.

    Each text is lowercased and split on whitespace; the score is the number
    of words common to both divided by the number of distinct words in
    either. Returns 0.0 when either text contains no words.
    """
    vocab_a = set(text1.lower().split())
    vocab_b = set(text2.lower().split())
    if not (vocab_a and vocab_b):
        return 0.0
    # Both sets are non-empty here, so the union cannot be empty.
    return len(vocab_a & vocab_b) / len(vocab_a | vocab_b)

# Sample up to 100 random pairs for each pair of splits.
# random.choice() raises IndexError on an empty sequence and np.max() raises
# ValueError on an empty list, so the measurement only runs when every split
# holds at least one sample.
similarity_check_ran = bool(train_samples) and bool(val_samples) and bool(test_samples)
sample_size = min(100, len(val_samples), len(test_samples)) if similarity_check_ran else 0

print(f"Analyzing {sample_size} random sample pairs...\n", flush=True)


def _sample_pair_similarities(first_split, second_split, pair_count):
    """Score `pair_count` random pairs, one sample drawn from each split.

    For every pair the sample from `first_split` is drawn before the sample
    from `second_split`; the two 'text' fields are compared with
    calculate_text_similarity. Returns the list of scores.
    """
    return [
        calculate_text_similarity(
            random.choice(first_split)['text'],
            random.choice(second_split)['text'],
        )
        for _ in range(pair_count)
    ]


train_val_similarities = _sample_pair_similarities(train_samples, val_samples, sample_size)
train_test_similarities = _sample_pair_similarities(train_samples, test_samples, sample_size)
val_test_similarities = _sample_pair_similarities(val_samples, test_samples, sample_size)

SIMILARITY_THRESHOLD = 0.85

if similarity_check_ran:
    avg_train_val_sim = np.mean(train_val_similarities)
    avg_train_test_sim = np.mean(train_test_similarities)
    avg_val_test_sim = np.mean(val_test_similarities)

    max_train_val_sim = np.max(train_val_similarities)
    max_train_test_sim = np.max(train_test_similarities)
    max_val_test_sim = np.max(val_test_similarities)

    for _pair_label, _avg, _max in (
        ("Train-Val", avg_train_val_sim, max_train_val_sim),
        ("Train-Test", avg_train_test_sim, max_train_test_sim),
        ("Val-Test", avg_val_test_sim, max_val_test_sim),
    ):
        print(f"{_pair_label} similarity:", flush=True)
        print(f"  Average: {_avg:.4f}", flush=True)
        print(f"  Maximum: {_max:.4f}", flush=True)
        print(flush=True)
else:
    # Nothing was measured; keep the names defined without inventing scores.
    avg_train_val_sim = avg_train_test_sim = avg_val_test_sim = float('nan')
    max_train_val_sim = max_train_test_sim = max_val_test_sim = float('nan')

# NaN compares False against the threshold, so a skipped measurement never
# counts as "high similarity"; it is reported through similarity_check_ran.
high_similarity_detected = (max_train_val_sim > SIMILARITY_THRESHOLD or
                            max_train_test_sim > SIMILARITY_THRESHOLD or
                            max_val_test_sim > SIMILARITY_THRESHOLD)

if not similarity_check_ran:
    print("WARNING: Similarity analysis skipped - at least one split is empty.", flush=True)
    print(f"  Train: {len(train_samples)}, Val: {len(val_samples)}, Test: {len(test_samples)} samples\n", flush=True)
elif not high_similarity_detected:
    print(f"SUCCESS: All similarities below threshold ({SIMILARITY_THRESHOLD})", flush=True)
    print("Templates are sufficiently diverse.\n", flush=True)
else:
    print(f"WARNING: High similarity detected (>{SIMILARITY_THRESHOLD})", flush=True)
    print("Review template diversity.\n", flush=True)

# FINAL VERDICT
print("=" * 70, flush=True)
print("DATA LEAKAGE DETECTION - FINAL VERDICT", flush=True)
print("=" * 70 + "\n", flush=True)

# The verdict passes only if every check ran and succeeded; an unmeasured
# similarity check counts as a failure.
all_checks_passed = (template_overlap_ok and
                     company_overlap_ok and
                     hash_overlap_ok and
                     similarity_check_ran and
                     not high_similarity_detected)

if all_checks_passed:
    print("SUCCESS: ALL DATA INTEGRITY CHECKS PASSED!", flush=True)
    print(flush=True)
    print("Data leakage prevention verified:", flush=True)
    print("  - ZERO template overlap", flush=True)
    print("  - ZERO company overlap", flush=True)
    print("  - ZERO hash overlap", flush=True)
    print(f"  - Low cross-split similarity (<{SIMILARITY_THRESHOLD})", flush=True)
    print(flush=True)
    print("The dataset is READY for training.", flush=True)
    print("Model evaluation will be VALID and UNBIASED.\n", flush=True)
else:
    print("WARNING: Some integrity checks failed!", flush=True)
    print("Review the warnings above before proceeding.\n", flush=True)

print("=" * 70, flush=True)

## Step 9: Load Datasets from Files

Load the train/val/test datasets from JSONL files for model training.

In [None]:
print("=" * 70, flush=True)
print("STEP 9: LOADING DATASETS FROM FILES", flush=True)
print("=" * 70 + "\n", flush=True)

# Paths of the JSON Lines files to read back.
train_file, val_file, test_file = TRAIN_FILE, VAL_FILE, TEST_FILE

print("Loading datasets from:", flush=True)
print(f"  Train: {train_file}", flush=True)
print(f"  Val: {val_file}", flush=True)
print(f"  Test: {test_file}\n", flush=True)


def _read_jsonl(path):
    """Parse every line of the UTF-8 file at `path` as JSON and return the list."""
    with open(path, 'r', encoding='utf-8') as _handle:
        return [json.loads(_raw.strip()) for _raw in _handle]


train_data = _read_jsonl(train_file)
val_data = _read_jsonl(val_file)
test_data = _read_jsonl(test_file)

print("Dataset Sizes:", flush=True)
for _split_label, _rows in (("Train", train_data), ("Val", val_data), ("Test", test_data)):
    print(f"  {_split_label}: {len(_rows):,} samples", flush=True)
print(f"  Total: {len(train_data) + len(val_data) + len(test_data):,} samples\n", flush=True)

# Count training samples per category; the category id is stored under 'label'.
print("Category Distribution (Train):", flush=True)
train_categories = dict(Counter(_row['label'] for _row in train_data))

for _cat in sorted(train_categories):
    _n = train_categories[_cat]
    _pct = (_n / len(train_data)) * 100
    _title = PDPL_CATEGORIES[_cat]['en']
    print(f"  Category {_cat} ({_title}): {_n:,} ({_pct:.1f}%)", flush=True)

print("\nDatasets loaded successfully!", flush=True)
print("=" * 70, flush=True)

## Step 10: BERT Tokenizer Setup

Initialize BERT-base-uncased tokenizer for English text processing.

In [None]:
print("=" * 70, flush=True)
print("STEP 10: BERT TOKENIZER SETUP", flush=True)
print("=" * 70 + "\n", flush=True)

# Fetch the tokenizer that belongs to MODEL_NAME.
print(f"Loading tokenizer: {MODEL_NAME}", flush=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print("\nTokenizer loaded successfully!", flush=True)
print(f"  Model: {MODEL_NAME}", flush=True)
print(f"  Vocab size: {tokenizer.vocab_size:,}", flush=True)
print(f"  Max model length: {tokenizer.model_max_length}", flush=True)
print("  Case: uncased (lowercase)", flush=True)
print("  Tokenization: WordPiece\n", flush=True)

# Round-trip one sentence through tokenize / encode / decode as a smoke test.
sample_text = "VNG Corporation must process customer data lawfully under PDPL 2025."
print("Tokenizer Test:", flush=True)
print(f"  Text: {sample_text}", flush=True)

tokens = tokenizer.tokenize(sample_text)
encoded = tokenizer.encode(sample_text, add_special_tokens=True)
decoded = tokenizer.decode(encoded)

# Only the first ten tokens / ids are shown, followed by the total count.
print(f"  Tokens: {tokens[:10]}... ({len(tokens)} total)", flush=True)
print(f"  Token IDs: {encoded[:10]}... ({len(encoded)} total)", flush=True)
print(f"  Decoded: {decoded}\n", flush=True)

# Sequence length used for padding and truncation.
MAX_LENGTH = 256


def tokenize_function(examples):
    """Tokenize examples['text'], padded/truncated to MAX_LENGTH, as PyTorch tensors."""
    encoding_options = {
        'padding': 'max_length',
        'truncation': True,
        'max_length': MAX_LENGTH,
        'return_tensors': 'pt',
    }
    return tokenizer(examples['text'], **encoding_options)


print("Tokenization function defined.", flush=True)
print(f"Max sequence length: {MAX_LENGTH} tokens", flush=True)
print("Padding: max_length", flush=True)
print("Truncation: enabled\n", flush=True)

print("=" * 70, flush=True)

## Step 11: Model Loading and Preparation

Load BERT-base-uncased model and apply English-specific hyperparameters.

In [None]:
print("="*70, flush=True)
print("STEP 11: MODEL LOADING AND PREPARATION", flush=True)
print("="*70 + "\n", flush=True)

# Load BERT model for sequence classification
num_labels = len(PDPL_CATEGORIES)
print(f"Loading model: {MODEL_NAME}", flush=True)
print(f"Number of labels: {num_labels}", flush=True)
print(f"Device: {device}\n", flush=True)

# The dropout rates are passed to from_pretrained as config overrides so that
# the model is built with them. Assigning them to model.config after loading
# only changes the stored config values: the dropout layers already exist by
# then and keep the rate they were created with.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    problem_type="single_label_classification",
    hidden_dropout_prob=TRAINING_CONFIG['hidden_dropout_prob'],
    attention_probs_dropout_prob=TRAINING_CONFIG['attention_probs_dropout_prob'],
)

print("Model loaded successfully!", flush=True)
print(f"  Model class: {type(model).__name__}", flush=True)
print(f"  Number of parameters: {model.num_parameters():,}", flush=True)
print(f"  Number of labels: {model.config.num_labels}\n", flush=True)

# Report the hyperparameters in effect. The dropout values are read back from
# the loaded model's config; the rest come straight from TRAINING_CONFIG and
# are applied later through the training arguments.
print("Applying English-specific hyperparameters:", flush=True)

print(f"  Hidden dropout: {model.config.hidden_dropout_prob}", flush=True)
print(f"  Attention dropout: {model.config.attention_probs_dropout_prob}", flush=True)
print(f"  Learning rate: {TRAINING_CONFIG['learning_rate']}", flush=True)
print(f"  Label smoothing: {TRAINING_CONFIG['label_smoothing_factor']}", flush=True)
print(f"  Epochs: {TRAINING_CONFIG['num_train_epochs']}", flush=True)
print(f"  Batch size: {TRAINING_CONFIG['per_device_train_batch_size']}\n", flush=True)

# Move model to device
model = model.to(device)
print(f"Model moved to: {device}", flush=True)

# Display model architecture summary
print("\nModel Architecture Summary:", flush=True)
print(f"  Embedding dim: {model.config.hidden_size}", flush=True)
print(f"  Attention heads: {model.config.num_attention_heads}", flush=True)
print(f"  Hidden layers: {model.config.num_hidden_layers}", flush=True)
print(f"  Intermediate size: {model.config.intermediate_size}", flush=True)
print(f"  Max position embeddings: {model.config.max_position_embeddings}\n", flush=True)

# Label mapping: category id -> English category name
print("Label Mapping:", flush=True)
for cat_id, cat_data in PDPL_CATEGORIES.items():
    print(f"  {cat_id}: {cat_data['en']}", flush=True)

print("\n" + "=" * 70, flush=True)

## Step 12: Prepare Dataset Objects for Training

Create HuggingFace Dataset objects and apply tokenization.

In [None]:
print("=" * 70, flush=True)
print("STEP 12: PREPARE DATASET OBJECTS FOR TRAINING", flush=True)
print("=" * 70 + "\n", flush=True)

from datasets import Dataset

print("Converting data to HuggingFace Dataset format...\n", flush=True)

# Split every record list into parallel text / label columns.
train_texts = [_row['text'] for _row in train_data]
train_labels = [_row['label'] for _row in train_data]

val_texts = [_row['text'] for _row in val_data]
val_labels = [_row['label'] for _row in val_data]

test_texts = [_row['text'] for _row in test_data]
test_labels = [_row['label'] for _row in test_data]

# Wrap the columns in Dataset objects.
train_dataset = Dataset.from_dict({'text': train_texts, 'label': train_labels})
val_dataset = Dataset.from_dict({'text': val_texts, 'label': val_labels})
test_dataset = Dataset.from_dict({'text': test_texts, 'label': test_labels})

print("Datasets created:", flush=True)
print(f"  Train: {len(train_dataset):,} samples", flush=True)
print(f"  Val: {len(val_dataset):,} samples", flush=True)
print(f"  Test: {len(test_dataset):,} samples\n", flush=True)

print("Applying tokenization...", flush=True)


def tokenize_batch(examples):
    """Tokenize examples['text'], padded and truncated to MAX_LENGTH tokens."""
    return tokenizer(
        examples['text'], padding='max_length', truncation=True, max_length=MAX_LENGTH
    )


# Tokenize in batches of 1000 examples.
_map_options = {'batched': True, 'batch_size': 1000}
train_dataset = train_dataset.map(tokenize_batch, **_map_options)
val_dataset = val_dataset.map(tokenize_batch, **_map_options)
test_dataset = test_dataset.map(tokenize_batch, **_map_options)

print("Tokenization complete!\n", flush=True)

# Return only the model inputs and the label, as torch tensors.
for _prepared in (train_dataset, val_dataset, test_dataset):
    _prepared.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

print("Dataset format set to PyTorch tensors.", flush=True)
print("\nDataset Columns:", flush=True)
print(f"  {train_dataset.column_names}", flush=True)

# Inspect the first training example.
print("\nSample from train dataset:", flush=True)
sample = train_dataset[0]
print(f"  input_ids shape: {sample['input_ids'].shape}", flush=True)
print(f"  attention_mask shape: {sample['attention_mask'].shape}", flush=True)
print(f"  label: {sample['label']}", flush=True)

print("\nDatasets are ready for training!", flush=True)
print("=" * 70, flush=True)

## Step 13: Smart Training Callback for English Model

Implement intelligent training monitoring with English-specific thresholds.

In [None]:
print("="*70, flush=True)
print("STEP 13: SMART TRAINING CALLBACK FOR ENGLISH MODEL", flush=True)
print("="*70 + "\n", flush=True)

class SmartTrainingCallback(TrainerCallback):
    """
    Smart callback for English PDPL model training with:
    - Early high accuracy detection (0.90 threshold)
    - Extreme overfitting prevention (0.95 threshold)
    - Training metrics monitoring

    The monitoring runs in on_evaluate rather than on_epoch_end. The Trainer
    fires on_epoch_end before it runs the end-of-epoch evaluation, so at that
    point the newest log entry is a training log (or the previous epoch's
    evaluation) and the current epoch's eval metrics do not exist yet.

    Attributes:
        best_eval_accuracy: highest eval accuracy seen during training.
        best_epoch: epoch at which best_eval_accuracy was recorded.
        training_history: one dict per in-training evaluation with the keys
            'epoch', 'train_loss', 'eval_loss' and 'eval_accuracy'.
    """
    
    def __init__(self, early_high_accuracy_threshold=0.90, extreme_overfitting_threshold=0.95):
        self.early_high_accuracy_threshold = early_high_accuracy_threshold
        self.extreme_overfitting_threshold = extreme_overfitting_threshold
        self.best_eval_accuracy = 0
        self.best_epoch = 0
        self.training_history = []
        # True only between on_train_begin and on_train_end, so evaluations
        # run outside training (e.g. on the test set) are not recorded.
        self._training_active = False
    
    def on_train_begin(self, args, state, control, **kwargs):
        """Mark training as running so on_evaluate starts recording."""
        self._training_active = True
        return control
    
    @staticmethod
    def _latest_train_loss(state):
        """Return the most recently logged training 'loss', or None if none was logged."""
        for entry in reversed(state.log_history):
            if 'loss' in entry:
                return entry['loss']
        return None
    
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Record the evaluation metrics and decide whether training should stop.

        Args:
            args, state, control: objects supplied by the Trainer.
            metrics: dict of evaluation metrics; 'eval_accuracy' and
                'eval_loss' are read from it when present.

        Returns:
            control, with should_training_stop set to True when the accuracy
            lands in the 0.88-0.92 target range at or above the early
            threshold, or reaches the extreme overfitting threshold.
        """
        if not self._training_active or not metrics:
            return control
        
        epoch = state.epoch if state.epoch is not None else 0
        
        # Eval metrics come from this evaluation; the train loss is the last
        # value the Trainer logged, since the two never share a log entry.
        train_loss = self._latest_train_loss(state)
        eval_loss = metrics.get('eval_loss', None)
        eval_accuracy = metrics.get('eval_accuracy', None)
        
        # Store in history
        self.training_history.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy
        })
        
        if eval_accuracy is None:
            return control
        
        # Update best accuracy
        if eval_accuracy > self.best_eval_accuracy:
            self.best_eval_accuracy = eval_accuracy
            self.best_epoch = epoch
        
        print(f"\nEpoch {epoch:.1f} Summary:", flush=True)
        # Compare against None so that a loss of exactly 0.0 is still shown.
        if train_loss is not None:
            print(f"  Train Loss: {train_loss:.4f}", flush=True)
        if eval_loss is not None:
            print(f"  Eval Loss: {eval_loss:.4f}", flush=True)
        print(f"  Eval Accuracy: {eval_accuracy:.4f} ({eval_accuracy*100:.2f}%)", flush=True)
        print(f"  Best Accuracy: {self.best_eval_accuracy:.4f} (Epoch {self.best_epoch:.1f})", flush=True)
        
        # Check for early high accuracy (English target: 88-92%)
        if eval_accuracy >= self.early_high_accuracy_threshold:
            print(f"\nEARLY HIGH ACCURACY DETECTED!", flush=True)
            print(f"  Accuracy {eval_accuracy:.4f} >= threshold {self.early_high_accuracy_threshold}", flush=True)
            print(f"  Target range for English: 0.88-0.92 (88-92%)", flush=True)
            
            # Check if in target range
            if 0.88 <= eval_accuracy <= 0.92:
                print(f"  PERFECT! Within target range. Consider stopping.", flush=True)
                control.should_training_stop = True
            elif eval_accuracy > 0.92:
                print(f"  WARNING: Above target (may indicate overfitting)", flush=True)
        
        # Check for extreme overfitting
        if eval_accuracy >= self.extreme_overfitting_threshold:
            print(f"\nEXTREME OVERFITTING WARNING!", flush=True)
            print(f"  Accuracy {eval_accuracy:.4f} >= {self.extreme_overfitting_threshold}", flush=True)
            print(f"  Model may not generalize well to unseen data.", flush=True)
            print(f"  Stopping training to prevent overfitting...", flush=True)
            control.should_training_stop = True
        
        # Check for overfitting (train loss < eval loss significantly)
        if train_loss is not None and eval_loss is not None and train_loss < eval_loss * 0.5:
            print(f"\n  WARNING: Potential overfitting detected", flush=True)
            print(f"  Train loss ({train_loss:.4f}) << Eval loss ({eval_loss:.4f})", flush=True)
        
        return control
    
    def on_train_end(self, args, state, control, **kwargs):
        """Print the final summary and stop recording evaluations."""
        self._training_active = False
        
        print("\n" + "="*70, flush=True)
        print("TRAINING COMPLETE - FINAL SUMMARY", flush=True)
        print("="*70, flush=True)
        print(f"\nBest Accuracy: {self.best_eval_accuracy:.4f} ({self.best_eval_accuracy*100:.2f}%)", flush=True)
        print(f"Best Epoch: {self.best_epoch:.1f}", flush=True)
        print(f"Total Epochs: {len(self.training_history)}", flush=True)
        
        # Target assessment
        print(f"\nEnglish Model Target: 88-92% accuracy", flush=True)
        if 0.88 <= self.best_eval_accuracy <= 0.92:
            print(f"SUCCESS: Best accuracy is WITHIN target range!", flush=True)
        elif self.best_eval_accuracy > 0.92:
            print(f"ABOVE TARGET: Consider using more regularization", flush=True)
        else:
            print(f"BELOW TARGET: Consider training longer or tuning hyperparameters", flush=True)
        
        print("="*70, flush=True)
        
        return control

# Initialize callback with the thresholds stored in TRAINING_CONFIG
smart_callback = SmartTrainingCallback(
    early_high_accuracy_threshold=TRAINING_CONFIG['early_high_accuracy_threshold'],
    extreme_overfitting_threshold=TRAINING_CONFIG['extreme_overfitting_threshold']
)

print("SmartTrainingCallback initialized!", flush=True)
print(f"  Early high accuracy threshold: {TRAINING_CONFIG['early_high_accuracy_threshold']} (90%)", flush=True)
print(f"  Extreme overfitting threshold: {TRAINING_CONFIG['extreme_overfitting_threshold']} (95%)", flush=True)
print(f"  Target accuracy range: 0.88-0.92 (88-92%)\n", flush=True)

print("=" * 70, flush=True)

## Step 14: Training Arguments Configuration

Configure all training parameters for English BERT model.

In [None]:
print("=" * 70, flush=True)
print("STEP 14: TRAINING ARGUMENTS CONFIGURATION", flush=True)
print("=" * 70 + "\n", flush=True)

# Build the TrainingArguments. Hyperparameters are read from TRAINING_CONFIG;
# the output directory, seed and environment flags come from module-level
# variables.
training_args = TrainingArguments(
    output_dir=MODEL_SAVE_DIR,

    # Optimisation schedule
    num_train_epochs=TRAINING_CONFIG['num_train_epochs'],
    learning_rate=TRAINING_CONFIG['learning_rate'],
    weight_decay=TRAINING_CONFIG['weight_decay'],
    warmup_ratio=TRAINING_CONFIG['warmup_ratio'],
    label_smoothing_factor=TRAINING_CONFIG['label_smoothing_factor'],

    # Batch sizes
    per_device_train_batch_size=TRAINING_CONFIG['per_device_train_batch_size'],
    per_device_eval_batch_size=TRAINING_CONFIG['per_device_eval_batch_size'],

    # Evaluate and checkpoint once per epoch; keep the most accurate checkpoint
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,

    # Logging every 50 steps
    logging_dir=os.path.join(MODEL_SAVE_DIR, 'logs'),
    logging_strategy="steps",
    logging_steps=50,
    report_to=["tensorboard"] if IN_COLAB else [],

    # Mixed precision only when CUDA is available
    fp16=torch.cuda.is_available(),
    dataloader_num_workers=2,

    seed=RANDOM_SEED,
    remove_unused_columns=False,
)

# Echo the resulting configuration, one titled section at a time.
_config_sections = (
    ("\nOutput & Logging:", (
        ("Output dir", training_args.output_dir),
        ("Logging dir", training_args.logging_dir),
        ("Logging steps", training_args.logging_steps),
    )),
    ("\nTraining Hyperparameters:", (
        ("Epochs", training_args.num_train_epochs),
        ("Train batch size", training_args.per_device_train_batch_size),
        ("Eval batch size", training_args.per_device_eval_batch_size),
        ("Learning rate", training_args.learning_rate),
        ("Weight decay", training_args.weight_decay),
        ("Warmup ratio", training_args.warmup_ratio),
        ("Label smoothing", training_args.label_smoothing_factor),
    )),
    ("\nEvaluation Settings:", (
        ("Eval strategy", training_args.eval_strategy),
        ("Save strategy", training_args.save_strategy),
        ("Load best model", training_args.load_best_model_at_end),
        ("Best model metric", training_args.metric_for_best_model),
    )),
    ("\nPerformance:", (
        ("FP16 (mixed precision)", training_args.fp16),
        ("Dataloader workers", training_args.dataloader_num_workers),
        ("Device", device),
    )),
)

print("Training Arguments Configuration:", flush=True)
for _section_title, _section_rows in _config_sections:
    print(_section_title, flush=True)
    for _row_label, _row_value in _section_rows:
        print(f"  {_row_label}: {_row_value}", flush=True)

print("\nReproducibility:", flush=True)
print(f"  Seed: {training_args.seed}\n", flush=True)

print("Training arguments configured successfully!", flush=True)
print("=" * 70, flush=True)

## Step 15: Compute Metrics Function

Define metrics computation for evaluation during training.

In [None]:
print("=" * 70, flush=True)
print("STEP 15: COMPUTE METRICS FUNCTION", flush=True)
print("=" * 70 + "\n", flush=True)

from sklearn.metrics import f1_score


def compute_metrics(eval_pred):
    """Return accuracy and macro-averaged F1 for one evaluation pass.

    Args:
        eval_pred: a (predictions, labels) pair; the predicted class is the
            argmax of predictions along axis 1.

    Returns:
        dict with the keys 'accuracy' and 'f1'.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=1)
    return {
        'accuracy': accuracy_score(gold, predicted),
        'f1': f1_score(gold, predicted, average='macro'),
    }


print("Compute metrics function defined!", flush=True)
print("  Metrics:", flush=True)
print("    - Accuracy (primary metric)", flush=True)
print("    - F1 Score (macro average)\n", flush=True)

print("=" * 70, flush=True)

## Step 16: Trainer Initialization

Initialize HuggingFace Trainer with all components.

In [None]:
import math

print("="*70, flush=True)
print("STEP 16: TRAINER INITIALIZATION", flush=True)
print("="*70 + "\n", flush=True)

# Initialize Trainer: wires the model, arguments, train/validation datasets,
# tokenizer, metric function and the custom callback together.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour
# of `processing_class=` - confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[smart_callback]
)

print("Trainer initialized successfully!", flush=True)
print(f"\nTrainer Configuration:", flush=True)
print(f"  Model: {MODEL_NAME}", flush=True)
print(f"  Train samples: {len(train_dataset):,}", flush=True)
print(f"  Eval samples: {len(val_dataset):,}", flush=True)
print(f"  Callbacks: SmartTrainingCallback", flush=True)
print(f"  Compute metrics: accuracy, f1\n", flush=True)

# Estimate the optimizer-step budget the way Trainer schedules it:
#   - the last partial batch still counts (ceil, not floor);
#   - the effective batch size covers all devices (train_batch_size);
#   - gradient accumulation merges several batches into one optimizer step;
#   - warmup is rounded up.
# This is an upper bound: callbacks may stop training earlier.
effective_batch_size = training_args.train_batch_size
batches_per_epoch = math.ceil(len(train_dataset) / effective_batch_size)
steps_per_epoch = max(
    math.ceil(batches_per_epoch / training_args.gradient_accumulation_steps), 1
)
# num_train_epochs may be a float; ceil keeps total_steps an int so the
# thousands-separator format below does not print a trailing ".0".
total_steps = math.ceil(steps_per_epoch * training_args.num_train_epochs)
warmup_steps = math.ceil(total_steps * training_args.warmup_ratio)

print(f"Training Steps Calculation:", flush=True)
print(f"  Steps per epoch: {steps_per_epoch}", flush=True)
print(f"  Total epochs: {training_args.num_train_epochs}", flush=True)
print(f"  Total steps: {total_steps:,}", flush=True)
print(f"  Warmup steps: {warmup_steps:,} ({training_args.warmup_ratio*100:.0f}%)\n", flush=True)

print("Ready to start training!", flush=True)
print("=" * 70, flush=True)

## Step 17: Training Execution - CRITICAL

Execute the training process with real-time monitoring.

In [None]:
banner = "=" * 70

print(banner, flush=True)
print("STEP 17: TRAINING EXECUTION - STARTING NOW", flush=True)
print(banner + "\n", flush=True)

# Run header: what is being trained and the configured stopping thresholds.
print("ENGLISH BERT MODEL TRAINING", flush=True)
print(f"Model: {MODEL_NAME}", flush=True)
print(f"Dataset: {len(train_dataset):,} training samples", flush=True)
print(f"Target Accuracy: 88-92%", flush=True)
max_epochs = training_args.num_train_epochs
print(f"Max Epochs: {max_epochs}", flush=True)
early_stop_pct = TRAINING_CONFIG['early_high_accuracy_threshold']*100
print(f"Early Stop Threshold: {early_stop_pct}%\n", flush=True)

print("Training will automatically stop if:", flush=True)
print(f"  - Accuracy reaches {early_stop_pct}% and is in target range (88-92%)", flush=True)
overfit_pct = TRAINING_CONFIG['extreme_overfitting_threshold']*100
print(f"  - Accuracy exceeds {overfit_pct}% (overfitting prevention)", flush=True)
print(f"  - Maximum epochs ({max_epochs}) completed\n", flush=True)

print(banner, flush=True)
print("TRAINING IN PROGRESS...", flush=True)
print(banner + "\n", flush=True)

# Train, then report throughput and validation metrics. Any failure is
# announced and re-raised so the notebook stops at this cell.
try:
    train_result = trainer.train()

    print("\n" + banner, flush=True)
    print("TRAINING COMPLETED SUCCESSFULLY!", flush=True)
    print(banner + "\n", flush=True)

    # Throughput figures reported by the training run
    run_stats = train_result.metrics
    print("Training Results:", flush=True)
    print(f"  Final train loss: {train_result.training_loss:.4f}", flush=True)
    print(f"  Total training time: {run_stats.get('train_runtime', 0):.2f} seconds", flush=True)
    print(f"  Samples per second: {run_stats.get('train_samples_per_second', 0):.2f}", flush=True)
    print(f"  Steps per second: {run_stats.get('train_steps_per_second', 0):.2f}\n", flush=True)

    # Metrics on the evaluation dataset configured on the trainer
    eval_result = trainer.evaluate()
    final_accuracy = eval_result.get('eval_accuracy', 0)
    print("Final Evaluation Metrics:", flush=True)
    print(f"  Eval loss: {eval_result.get('eval_loss', 0):.4f}", flush=True)
    print(f"  Eval accuracy: {final_accuracy:.4f} ({final_accuracy*100:.2f}%)", flush=True)
    print(f"  Eval F1 score: {eval_result.get('eval_f1', 0):.4f}\n", flush=True)

    # Compare the achieved accuracy with the 88-92% target band
    print("Target Assessment:", flush=True)
    print(f"  Target range: 88-92%", flush=True)
    print(f"  Achieved: {final_accuracy*100:.2f}%", flush=True)

    if final_accuracy > 0.92:
        status_line = "  STATUS: ABOVE TARGET - Excellent but may indicate overfitting"
    elif final_accuracy >= 0.88:
        status_line = "  STATUS: SUCCESS - Within target range!"
    elif final_accuracy >= 0.85:
        status_line = "  STATUS: CLOSE - Near target, acceptable performance"
    else:
        status_line = "  STATUS: BELOW TARGET - Consider retraining with different hyperparameters"
    print(status_line, flush=True)

    print("\n" + banner, flush=True)

except Exception as err:
    print(f"\nTRAINING ERROR: {err!s}", flush=True)
    print("Please check the error message above and retry.", flush=True)
    raise

## Step 18: Test Set Evaluation

Evaluate the trained model on the held-out test set.

In [None]:
from sklearn.metrics import classification_report

print("="*70, flush=True)
print("STEP 18: TEST SET EVALUATION", flush=True)
print("="*70 + "\n", flush=True)

print(f"Evaluating on test set ({len(test_dataset):,} samples)...\n", flush=True)

# Aggregate metrics on the held-out test set. Keys carry the 'eval_' prefix
# because they come from trainer.evaluate(); later cells read these keys.
test_results = trainer.evaluate(test_dataset)

print("Test Set Results:", flush=True)
print(f"  Test loss: {test_results.get('eval_loss', 0):.4f}", flush=True)
print(f"  Test accuracy: {test_results.get('eval_accuracy', 0):.4f} ({test_results.get('eval_accuracy', 0)*100:.2f}%)", flush=True)
print(f"  Test F1 score: {test_results.get('eval_f1', 0):.4f}\n", flush=True)

# Per-sample predictions for the detailed analysis below and for the
# confusion-matrix cell (which reads `predictions` and `true_labels`).
predictions_output = trainer.predict(test_dataset)
predictions = np.argmax(predictions_output.predictions, axis=1)
true_labels = predictions_output.label_ids

print("Prediction Statistics:", flush=True)
print(f"  Total predictions: {len(predictions):,}", flush=True)
print(f"  Correct predictions: {np.sum(predictions == true_labels):,}", flush=True)
print(f"  Incorrect predictions: {np.sum(predictions != true_labels):,}\n", flush=True)

# Calculate per-class metrics
print("Per-Category Performance:", flush=True)
print("-" * 70, flush=True)

# Pass the full label set explicitly. Without `labels=`, the report infers
# classes from the data and raises when a category is missing from both the
# true labels and the predictions, because `target_names` would then be
# longer than the inferred class list. `zero_division=0` reports 0 for
# categories with no predicted samples instead of emitting a warning.
category_ids = list(range(len(PDPL_CATEGORIES)))
report = classification_report(
    true_labels,
    predictions,
    labels=category_ids,
    target_names=[PDPL_CATEGORIES[i]['en'] for i in category_ids],
    digits=4,
    zero_division=0,
)
print(report, flush=True)

print("=" * 70, flush=True)

## Step 19: Confusion Matrix & Visualization

Visualize model performance with a confusion matrix.

In [None]:
print("="*70, flush=True)
print("STEP 19: CONFUSION MATRIX & VISUALIZATION", flush=True)
print("="*70 + "\n", flush=True)

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Fix the label set so the matrix always has one row/column per category.
# Without `labels=`, a category missing from both the true labels and the
# predictions shrinks the matrix, which misaligns the tick labels and makes
# the per-category loop below index out of bounds.
category_ids = list(range(len(PDPL_CATEGORIES)))
cm = confusion_matrix(true_labels, predictions, labels=category_ids)

print("Confusion Matrix (Test Set):", flush=True)
print(cm, flush=True)
print(flush=True)

# Category names truncated to 20 characters to keep the axes readable
tick_names = [PDPL_CATEGORIES[i]['en'][:20] for i in category_ids]

# Create visualization
plt.figure(figsize=(12, 10))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=tick_names,
    yticklabels=tick_names,
    cbar_kws={'label': 'Count'}
)
plt.title('Confusion Matrix - English BERT PDPL Model\nTest Set Performance', fontsize=14, pad=20)
plt.xlabel('Predicted Category', fontsize=12)
plt.ylabel('True Category', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()

# Save confusion matrix; create the target directory first so savefig does
# not fail when nothing has been written there yet.
cm_dir = TRAINING_CONFIG['model_save_dir']
if cm_dir:
    os.makedirs(cm_dir, exist_ok=True)
cm_path = os.path.join(cm_dir, 'confusion_matrix.png')
plt.savefig(cm_path, dpi=300, bbox_inches='tight')
print(f"Confusion matrix saved to: {cm_path}", flush=True)

plt.show()

# Per-category accuracy: diagonal entry divided by the row total
print("\nPer-Category Accuracy:", flush=True)
print("-" * 70, flush=True)

row_totals = cm.sum(axis=1)
for i in category_ids:
    class_total = row_totals[i]
    class_correct = cm[i, i]
    # Guard against categories with no test samples
    class_accuracy = class_correct / class_total if class_total > 0 else 0

    cat_name = PDPL_CATEGORIES[i]['en']
    print(f"Category {i} ({cat_name}):", flush=True)
    print(f"  Correct: {class_correct}/{class_total} ({class_accuracy*100:.2f}%)", flush=True)

print("\n" + "=" * 70, flush=True)

## Step 20: Model Export & Packaging

Export the trained model with metadata and deployment guide.

In [None]:
import shutil
from datetime import datetime

print("="*70, flush=True)
print("STEP 20: MODEL EXPORT & PACKAGING", flush=True)
print("="*70 + "\n", flush=True)

# Unique export directory name per run. (The previous bare `except:` around
# this was removed: it guarded nothing that can realistically fail and would
# have swallowed KeyboardInterrupt.)
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
export_dir = f"veriaidpo_en_run_{run_timestamp}"

export_path = os.path.join(OUTPUT_DIR, export_dir)

print(f"Exporting model to: {export_path}\n", flush=True)

# Save model and tokenizer
model.save_pretrained(export_path)
tokenizer.save_pretrained(export_path)

print("Model saved successfully!", flush=True)
print(f"  Model files: {export_path}", flush=True)

# Save training configuration. Only entries whose top-level value is a plain
# JSON-compatible type are kept; anything else is dropped from the export.
config_path = os.path.join(export_path, 'training_config.json')
training_config_export = {k: v for k, v in TRAINING_CONFIG.items() if isinstance(v, (str, int, float, bool, list, dict))}
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(training_config_export, f, indent=2, ensure_ascii=False)

print(f"  Training config: {config_path}", flush=True)

# Save test results
results_path = os.path.join(export_path, 'test_results.json')
test_results_dict = {
    'test_accuracy': float(test_results.get('eval_accuracy', 0)),
    'test_f1': float(test_results.get('eval_f1', 0)),
    'test_loss': float(test_results.get('eval_loss', 0)),
    'dataset_size': {
        'train': len(train_dataset),
        'val': len(val_dataset),
        'test': len(test_dataset)
    },
    'model_name': MODEL_NAME,
    'num_labels': num_labels
}

with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(test_results_dict, f, indent=2)

print(f"  Test results: {results_path}", flush=True)

# Save label mapping
labels_path = os.path.join(export_path, 'label_mapping.json')
with open(labels_path, 'w', encoding='utf-8') as f:
    json.dump(PDPL_CATEGORIES, f, indent=2, ensure_ascii=False)

print(f"  Label mapping: {labels_path}", flush=True)

# Bundle the confusion matrix image when the visualization step has produced
# one (`cm_path` is set by that cell). The completion summary lists the image
# as part of the package, so it has to live inside the export directory to
# end up in the ZIP archive.
confusion_matrix_src = globals().get('cm_path')
confusion_matrix_copied = False
if confusion_matrix_src and os.path.isfile(confusion_matrix_src):
    shutil.copy2(confusion_matrix_src, os.path.join(export_path, 'confusion_matrix.png'))
    confusion_matrix_copied = True
    print(f"  Confusion matrix: {os.path.join(export_path, 'confusion_matrix.png')}", flush=True)

# Create deployment guide: header, one line per category, then usage notes
guide_header = f"""# VeriAIDPO English Model - Deployment Guide

## Model Information
- Model: {MODEL_NAME}
- Language: English
- Task: PDPL 2025 Compliance Classification
- Accuracy: {test_results.get('eval_accuracy', 0)*100:.2f}%
- F1 Score: {test_results.get('eval_f1', 0):.4f}

## Dataset
- Training samples: {len(train_dataset):,}
- Validation samples: {len(val_dataset):,}
- Test samples: {len(test_dataset):,}
- Total samples: {len(train_dataset) + len(val_dataset) + len(test_dataset):,}

## Categories (8 PDPL Categories)
"""

category_lines = "".join(
    f"{cat_id}. {cat_data['en']} (VN: {cat_data['vi']})\n"
    for cat_id, cat_data in PDPL_CATEGORIES.items()
)

guide_footer = f"""
## Inference Example

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load model
model_path = "{export_path}"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()

# Example inference
text = "VNG must process customer data lawfully under PDPL 2025."
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)

with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(predictions, dim=-1).item()
    confidence = predictions[0][predicted_class].item()

print(f"Predicted category: {{predicted_class}}")
print(f"Confidence: {{confidence*100:.2f}}%")
```

## System Requirements
- Python 3.8+
- PyTorch 2.0+
- Transformers 4.30+
- RAM: 2-4GB (inference)
- GPU: Optional (CPU inference supported)

## Performance
- Inference time: 40-80ms per sample (CPU)
- Batch inference: Supported
- Max sequence length: 256 tokens

## Integration with Vietnamese Model
This English model is designed to work alongside the Vietnamese PhoBERT model for bilingual PDPL compliance detection.

## Support
For issues or questions, refer to VeriSyntra documentation.
"""

deployment_guide = guide_header + category_lines + guide_footer

guide_path = os.path.join(export_path, 'DEPLOYMENT_GUIDE_EN.md')
with open(guide_path, 'w', encoding='utf-8') as f:
    f.write(deployment_guide)

print(f"  Deployment guide: {guide_path}\n", flush=True)

# Report the weights file that was actually written: the name depends on the
# serialization format chosen by save_pretrained, so it is looked up rather
# than assumed.
weights_file = next(
    (
        candidate
        for candidate in ("model.safetensors", "pytorch_model.bin")
        if os.path.exists(os.path.join(export_path, candidate))
    ),
    "see export directory",
)

print("Export complete!", flush=True)
print(f"\nExported files:", flush=True)
print(f"  - Model weights ({weights_file})", flush=True)
print(f"  - Tokenizer files", flush=True)
print(f"  - Configuration (config.json)", flush=True)
print(f"  - Training config (training_config.json)", flush=True)
print(f"  - Test results (test_results.json)", flush=True)
print(f"  - Label mapping (label_mapping.json)", flush=True)
print(f"  - Deployment guide (DEPLOYMENT_GUIDE_EN.md)", flush=True)
if confusion_matrix_copied:
    print(f"  - Confusion matrix (confusion_matrix.png)", flush=True)

print("\n" + "=" * 70, flush=True)

## Step 21: Create ZIP Archive for Download

Package all model files into a ZIP archive for easy download.

In [None]:
import shutil
import zipfile

print("="*70, flush=True)
print("STEP 21: CREATE ZIP ARCHIVE FOR DOWNLOAD", flush=True)
print("="*70 + "\n", flush=True)

# Archive the directory the export step actually wrote to. The root is
# derived from `export_path` (built from OUTPUT_DIR) rather than from
# TRAINING_CONFIG['output_dir'], so the ZIP cannot point at a different
# location than the exported files.
zip_root = os.path.dirname(export_path) or os.curdir
zip_filename = f"{export_dir}.zip"
zip_path = os.path.join(zip_root, zip_filename)

print(f"Creating ZIP archive: {zip_filename}\n", flush=True)

# Create ZIP archive
try:
    # Fail loudly instead of producing an empty archive
    if not os.path.isdir(export_path):
        raise FileNotFoundError(f"Export directory not found: {export_path}")

    # base_name has no extension: make_archive appends '.zip' itself.
    # root_dir/base_dir make paths inside the archive start at `export_dir`.
    shutil.make_archive(
        base_name=os.path.join(zip_root, export_dir),
        format='zip',
        root_dir=zip_root,
        base_dir=export_dir
    )

    print("ZIP archive created successfully!", flush=True)
    print(f"  File: {zip_path}", flush=True)

    # Get ZIP file size (read by the download and summary cells)
    zip_size_mb = os.path.getsize(zip_path) / (1024 * 1024)
    print(f"  Size: {zip_size_mb:.2f} MB\n", flush=True)

    # List contents (first 10 entries only, to keep the log short)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        file_list = zip_ref.namelist()
        print(f"ZIP archive contains {len(file_list)} files:", flush=True)
        for filename in file_list[:10]:
            print(f"  - {filename}", flush=True)
        if len(file_list) > 10:
            print(f"  ... and {len(file_list) - 10} more files", flush=True)

    print(flush=True)

except Exception as e:
    print(f"Error creating ZIP: {str(e)}", flush=True)
    raise

print("=" * 70, flush=True)

## Step 22: Download Model (Google Colab Only)

Download the ZIP archive directly from Google Colab.

In [None]:
divider = "=" * 70
print(divider, flush=True)
print("STEP 22: DOWNLOAD MODEL (GOOGLE COLAB ONLY)", flush=True)
print(divider + "\n", flush=True)

if not IN_COLAB:
    # Outside Colab there is no browser download; just point at the files.
    print("NOT running in Google Colab.\nModel files are available at:", flush=True)
    print(f"  {export_path}", flush=True)
    print(f"\nZIP archive available at:", flush=True)
    print(f"  {zip_path}", flush=True)
    print(f"  Size: {zip_size_mb:.2f} MB\n", flush=True)
else:
    print("Initiating download...\n", flush=True)

    # Best effort: a failed download is reported with the on-disk location
    # of the archive and is not re-raised.
    try:
        from google.colab import files

        print(f"Downloading: {zip_filename}", flush=True)
        print(f"Size: {zip_size_mb:.2f} MB", flush=True)
        print("Please wait for the download to complete...\n", flush=True)

        files.download(zip_path)

        print(
            "Download initiated successfully!\n"
            "Check your browser's download folder.\n",
            flush=True,
        )

    except Exception as err:
        print(f"Download error: {err!s}", flush=True)
        print(f"\nAlternative: Access the file directly at:", flush=True)
        print(f"  {zip_path}\n", flush=True)

print(divider, flush=True)

## Step 23: Training Completion Summary

Final summary of the entire English model training pipeline.

In [None]:
print("\n" + "="*70, flush=True)
print("VERIAIDPO ENGLISH MODEL TRAINING - COMPLETION SUMMARY", flush=True)
print("="*70 + "\n", flush=True)

print("TRAINING PIPELINE COMPLETED SUCCESSFULLY!", flush=True)
print(flush=True)

# MODEL_NAME is the identifier the other cells use; the lowercase
# `model_name` previously referenced here is not used anywhere else in the
# visible pipeline.
print("Model Information:", flush=True)
print(f"  Base Model: {MODEL_NAME}", flush=True)
print(f"  Language: English", flush=True)
print(f"  Task: PDPL 2025 Compliance Classification", flush=True)
print(f"  Categories: 8 PDPL categories", flush=True)
print(flush=True)

# Split sizes and shares are computed from the actual datasets so the
# summary stays correct if the split or dataset size changes.
split_sizes = {
    'train': len(train_dataset),
    'val': len(val_dataset),
    'test': len(test_dataset),
}
total_samples = sum(split_sizes.values())
split_share = {
    name: (count / total_samples if total_samples else 0.0)
    for name, count in split_sizes.items()
}

print("Dataset Statistics:", flush=True)
print(f"  Total samples: {total_samples:,}", flush=True)
print(f"  Training: {split_sizes['train']:,} samples ({split_share['train']:.0%})", flush=True)
print(f"  Validation: {split_sizes['val']:,} samples ({split_share['val']:.0%})", flush=True)
print(f"  Test: {split_sizes['test']:,} samples ({split_share['test']:.0%})", flush=True)
print(flush=True)

print("Model Performance:", flush=True)
print(f"  Test Accuracy: {test_results.get('eval_accuracy', 0)*100:.2f}%", flush=True)
print(f"  Test F1 Score: {test_results.get('eval_f1', 0):.4f}", flush=True)
print(f"  Target Range: 88-92%", flush=True)

# From here on `final_accuracy` refers to the TEST accuracy
final_accuracy = test_results.get('eval_accuracy', 0)
if 0.88 <= final_accuracy <= 0.92:
    print(f"  Status: SUCCESS - Within target range!", flush=True)
elif final_accuracy > 0.92:
    print(f"  Status: EXCELLENT - Above target", flush=True)
else:
    print(f"  Status: Needs improvement", flush=True)
print(flush=True)

# NOTE(review): the integrity lines below are static text; nothing in this
# cell recomputes them. Presumably they restate checks done in earlier
# steps - confirm those checks ran before relying on this output.
print("Data Integrity:", flush=True)
print(f"  Template overlap: 0 (verified)", flush=True)
print(f"  Company overlap: 0 (verified)", flush=True)
print(f"  Data leakage: None detected", flush=True)
print(flush=True)

print("Exported Files:", flush=True)
print(f"  Model directory: {export_path}", flush=True)
print(f"  ZIP archive: {zip_path}", flush=True)
print(f"  ZIP size: {zip_size_mb:.2f} MB", flush=True)
print(flush=True)

print("What's Included:", flush=True)
print("  - Trained BERT model (model.safetensors / pytorch_model.bin)", flush=True)
print("  - Tokenizer files", flush=True)
print("  - Model configuration", flush=True)
print("  - Training configuration", flush=True)
print("  - Test results (JSON)", flush=True)
print("  - Label mapping (8 categories)", flush=True)
print("  - Deployment guide (DEPLOYMENT_GUIDE_EN.md)", flush=True)
# Only advertise the confusion matrix when it is really part of the export
if os.path.exists(os.path.join(export_path, 'confusion_matrix.png')):
    print("  - Confusion matrix visualization", flush=True)
print(flush=True)

print("Next Steps:", flush=True)
print("  1. Download the ZIP archive (if in Colab)", flush=True)
print("  2. Extract the files locally", flush=True)
print("  3. Review DEPLOYMENT_GUIDE_EN.md for integration instructions", flush=True)
print("  4. Test inference with sample English PDPL texts", flush=True)
print("  5. Integrate with Vietnamese PhoBERT model for bilingual system", flush=True)
print(flush=True)

print("Bilingual System Status:", flush=True)
print("  Vietnamese Model: PhoBERT-base (6,984 samples, ~100% accuracy)", flush=True)
print(f"  English Model: BERT-base-uncased ({total_samples:,} samples, {final_accuracy*100:.2f}% accuracy)", flush=True)
print(f"  Combined System: 980MB total, 92-96% weighted accuracy target", flush=True)
print(flush=True)

print("="*70, flush=True)
print("TRAINING COMPLETE - ALL STEPS EXECUTED SUCCESSFULLY!", flush=True)
print("="*70, flush=True)
print(flush=True)

print("Thank you for using VeriAIDPO English Model Training Pipeline!", flush=True)
print("For support, refer to VeriSyntra documentation.", flush=True)