# Qwen3-8B-Base Fine-tuning for AI/Human Text Detection

This notebook fine-tunes [Qwen3-8B-Base](https://huggingface.co/Qwen/Qwen3-8B-Base) on the AI/Human text detection dataset.

**Model**: Qwen3-8B-Base (base model for fine-tuning)  
**Dataset**: `codefactory4791/ai-human-text-detection-balanced` (7.6M records, ~10GB)  
**Hardware**: Single A100 GPU (80GB recommended, 40GB works with reduced batch size)

**Features**:
-  Optimized for a single A100 80GB GPU (70% faster training!)
-  Configurable full fine-tuning or PEFT/LoRA (rank 16, alpha 32)
-  Pre and post training evaluation
-  Class-weighted loss for imbalanced data
-  Comprehensive metrics and visualizations
-  Memory-efficient training with 4-bit quantization + FP16
-  Large batch training (bsz=32, effective 256) for better GPU utilization

**Training Time**: ~6-8 hours (3 epochs, full 7.6M dataset)  
**Expected Accuracy**: 85-95%

## 1. Installation and Setup

In [None]:
# Install required packages (uncomment if not installed)
# !pip install -U transformers accelerate datasets evaluate peft bitsandbytes
# !pip install -U scikit-learn pandas numpy PyYAML
# !pip install -U tensorboard  # or wandb if using W&B

In [None]:
# Import libraries
import os
import sys
import yaml
import torch
import numpy as np
import pandas as pd
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, List, Tuple

# HuggingFace
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType,
)

# Metrics
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report,
)

# Performance-related environment settings for tokenizers and the CUDA allocator.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512",
    }
)

# Report the PyTorch build and, when CUDA is available, device 0's name and memory.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {gpu_props.total_memory / 1e9:.2f} GB")

## 3. Initialize Weights & Biases


In [None]:
# Initialize Weights & Biases for experiment tracking
import wandb

# Load config first to get wandb settings.
# The encoding is explicit so the YAML is read the same way on every platform.
config_path = "config.yaml"
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Initialize wandb if enabled
if config['wandb']['enabled']:
    # Login to wandb (you'll be prompted to enter your API key if not logged in)
    # Get your API key from: https://wandb.ai/authorize
    wandb.login()

    # Fall back to a timestamped name when the config does not provide one.
    run_name = config['training'].get('run_name') or f"qwen3-8b-base-{datetime.now().strftime('%Y%m%d-%H%M%S')}"

    # Log the quantization mode that will actually be applied: the model-loading
    # cell only builds a quantization config when `quantization.enabled` is set,
    # so the bit-width flags alone must not be reported as active.
    quant_cfg = config['quantization']
    if not quant_cfg['enabled']:
        quantization_mode = 'none'
    elif quant_cfg['load_in_4bit']:
        quantization_mode = '4-bit'
    elif quant_cfg['load_in_8bit']:
        quantization_mode = '8-bit'
    else:
        quantization_mode = 'none'

    train_cfg = config['training']
    wandb.init(
        project=config['wandb']['project'],
        entity=config['wandb']['entity'],
        name=run_name,
        tags=config['wandb']['tags'],
        notes=config['wandb']['notes'],
        config={
            'model': config['model']['name'],
            'batch_size': train_cfg['per_device_train_batch_size'],
            'gradient_accumulation_steps': train_cfg['gradient_accumulation_steps'],
            # Per-device batch size multiplied by the accumulation steps.
            'effective_batch_size': train_cfg['per_device_train_batch_size'] * train_cfg['gradient_accumulation_steps'],
            'learning_rate': train_cfg['learning_rate'],
            'num_epochs': train_cfg['num_train_epochs'],
            'lora_r': config['peft']['lora_r'],
            'lora_alpha': config['peft']['lora_alpha'],
            'max_length': config['tokenization']['max_length'],
            'quantization': quantization_mode,
        }
    )

    print(" Weights & Biases initialized!")
    print(f" Dashboard: {wandb.run.get_url()}")
    print(f" Run name: {run_name}")
else:
    print("  Weights & Biases disabled in config.yaml")
    print("   Training metrics will not be logged to W&B")


## 4. Load Configuration

In [None]:
# The configuration was already read by the W&B initialization cell;
# this cell only prints a short summary of it.
print("Configuration loaded successfully!")

# (printed label, config section, config key), emitted in this order.
_summary_fields = (
    ("\nModel", 'model', 'name'),
    ("Dataset", 'dataset', 'name'),
    ("PEFT enabled", 'peft', 'enabled'),
    ("Pre-training evaluation", 'evaluation', 'run_pre_training_eval'),
    ("Output directory", 'training', 'output_dir'),
    ("Logging to", 'training', 'report_to'),
)
for _label, _section, _key in _summary_fields:
    print(f"{_label}: {config[_section][_key]}")

## 5. Utility Functions

In [None]:
# Helper functions called by later cells, plus the label <-> id mappings.
# NOTE(review): these helpers are called throughout the notebook but were not
# defined anywhere in it; the implementations below follow the call sites.


def create_label_mappings(labels):
    """Build the label-to-id and id-to-label dictionaries.

    Args:
        labels: Ordered sequence of class label values; each label's position
            becomes its integer class id.

    Returns:
        A ``(label2id, id2label)`` tuple of dictionaries.
    """
    label2id = {label: idx for idx, label in enumerate(labels)}
    id2label = {idx: label for label, idx in label2id.items()}
    return label2id, id2label


def compute_class_weights(dataset, label_column, label2id):
    """Compute inverse-frequency class weights from one dataset column.

    Each class gets ``n_samples / (n_classes * class_count)``, so a perfectly
    balanced dataset yields a weight of 1.0 for every class. Labels that are
    not keys of ``label2id`` are ignored. A class with no samples keeps the
    neutral weight 1.0 to avoid dividing by zero.

    Args:
        dataset: Object supporting ``dataset[label_column]`` and returning an
            iterable of label values.
        label_column: Name of the column holding the labels.
        label2id: Mapping from label value to integer class id.

    Returns:
        A float32 ``torch.Tensor`` with one weight per class, indexed by class id.
    """
    from collections import Counter

    counts = Counter(label for label in dataset[label_column] if label in label2id)
    total = sum(counts.values())
    num_classes = len(label2id)

    weights = torch.ones(num_classes, dtype=torch.float32)
    for label, idx in label2id.items():
        if counts[label]:
            weights[idx] = total / (num_classes * counts[label])
    return weights


def print_gpu_memory():
    """Print allocated and reserved CUDA memory (in GB) for every visible GPU."""
    if not torch.cuda.is_available():
        print("GPU memory: CUDA not available")
        return
    for device_idx in range(torch.cuda.device_count()):
        allocated_gb = torch.cuda.memory_allocated(device_idx) / 1e9
        reserved_gb = torch.cuda.memory_reserved(device_idx) / 1e9
        print(f"GPU {device_idx} memory: {allocated_gb:.2f} GB allocated, {reserved_gb:.2f} GB reserved")


# Create label mappings. They depend only on the config, so they can be built
# before the dataset is loaded.
labels = config['model']['labels']
label2id, id2label = create_label_mappings(labels)

print(f"Label mappings created:")
print(f"  Label to ID: {label2id}")
print(f"  ID to Label: {id2label}")

# Class weights are NOT computed here: they need `train_dataset` and
# `label_column`, which are only defined by later cells. The
# "Create Label Mappings and Class Weights" cell computes them once the
# training data is loaded.

## 6. Load and Prepare Dataset

In [None]:
# Fetch the dataset named in the config from the HuggingFace Hub.
dataset_id = config['dataset']['name']
print(f"Loading dataset: {dataset_id}...")

# `streaming=True` is only passed when it is switched on in the config.
dataset_load_kwargs = {'streaming': True} if config['dataset']['streaming'] else {}
dataset = load_dataset(dataset_id, **dataset_load_kwargs)

print(f"Dataset loaded successfully!")
print(f"Dataset structure: {dataset}")

In [None]:
# Use the splits named in the config when they exist; otherwise build
# train/validation/test splits from the configured ratios.
train_split = config['dataset']['train_split']
val_split = config['dataset']['validation_split']
test_split = config['dataset']['test_split']

if train_split and train_split in dataset:
    train_dataset = dataset[train_split]
    val_dataset = dataset[val_split] if val_split in dataset else None
    test_dataset = dataset[test_split] if test_split in dataset else None

    print(f"Using existing splits:")
    print(f"  Train: {len(train_dataset)} samples")
    # Compare against None explicitly: an empty split is still a split.
    if val_dataset is not None:
        print(f"  Validation: {len(val_dataset)} samples")
    if test_dataset is not None:
        print(f"  Test: {len(test_dataset)} samples")
else:
    print("Creating train/val/test splits...")

    # Assume dataset has a single split or combine all
    if isinstance(dataset, DatasetDict):
        # Combine all splits if multiple exist
        from datasets import concatenate_datasets
        full_dataset = concatenate_datasets([dataset[split] for split in dataset.keys()])
    else:
        full_dataset = dataset

    # One seed drives the shuffle and both splits so the partition is
    # reproducible across runs.
    split_seed = config['dataset']['shuffle_seed']

    # Shuffle dataset
    full_dataset = full_dataset.shuffle(seed=split_seed)

    # Split ratios
    train_size = config['dataset']['train_ratio']
    val_size = config['dataset']['validation_ratio']
    test_size = config['dataset']['test_ratio']

    # First carve off the combined validation+test portion...
    train_test = full_dataset.train_test_split(
        test_size=(val_size + test_size),
        seed=split_seed,
    )
    train_dataset = train_test['train']

    # ...then divide that portion between validation and test.
    val_test = train_test['test'].train_test_split(
        test_size=test_size / (val_size + test_size),
        seed=split_seed,
    )
    val_dataset = val_test['train']
    test_dataset = val_test['test']

    print(f"Splits created:")
    print(f"  Train: {len(train_dataset)} samples ({train_size*100:.1f}%)")
    print(f"  Validation: {len(val_dataset)} samples ({val_size*100:.1f}%)")
    print(f"  Test: {len(test_dataset)} samples ({test_size*100:.1f}%)")

In [None]:
# Sampled dataset caching - load pre-sampled datasets if available
from datasets import load_from_disk as load_dataset_from_disk
import hashlib

def get_sample_cache_key(config):
    """Return an 8-character hex digest identifying the sampling settings.

    The digest covers the dataset name, the three ``max_*_samples`` limits and
    the shuffle seed from ``config['dataset']``.
    """
    import json

    dataset_cfg = config['dataset']
    fingerprint = {'dataset_name': dataset_cfg['name']}
    for field in ('max_train_samples', 'max_eval_samples', 'max_test_samples', 'shuffle_seed'):
        fingerprint[field] = dataset_cfg[field]

    # Sorted keys keep the serialized form (and therefore the digest) stable.
    serialized = json.dumps(fingerprint, sort_keys=True).encode()
    digest = hashlib.md5(serialized).hexdigest()
    return digest[:8]

# Cache settings for the sampled datasets (all optional in config['misc']).
_misc_cfg = config['misc']
sampled_cache_dir = _misc_cfg.get('sampled_cache_dir', './sampled_cache')
use_sampled_cache = _misc_cfg.get('save_sampled_datasets', True)
force_resample = _misc_cfg.get('force_resample', False)

# One directory per split, suffixed with the sampling fingerprint.
sample_key = get_sample_cache_key(config)
sampled_train_path, sampled_val_path, sampled_test_path = (
    os.path.join(sampled_cache_dir, f'{_split}_sampled_{sample_key}')
    for _split in ('train', 'val', 'test')
)

# The cache counts as present only when all three split directories exist.
sampled_cache_exists = all(
    os.path.exists(_path)
    for _path in (sampled_train_path, sampled_val_path, sampled_test_path)
)

_banner = "="*80

if use_sampled_cache and sampled_cache_exists and not force_resample:
    print(_banner)
    print("LOADING PRE-SAMPLED DATASETS FROM CACHE")
    print(_banner)
    print(f"Cache key: {sample_key}")
    print("Loading sampled datasets from disk (instant!)...\n")

    train_dataset = load_dataset_from_disk(sampled_train_path)
    val_dataset = load_dataset_from_disk(sampled_val_path)
    test_dataset = load_dataset_from_disk(sampled_test_path)

    print(f"Loaded sampled datasets in seconds:")
    print(f"  Train: {len(train_dataset):,} samples")
    print(f"  Validation: {len(val_dataset):,} samples")
    print(f"  Test: {len(test_dataset):,} samples")
    print(f"\nSkipped sampling step - using cached sampled data")
    print(_banner)

else:
    # No usable cache (or a forced refresh): sample from the full splits.
    print(_banner)
    print("PERFORMING STRATIFIED SAMPLING BY DOMAIN")
    print(_banner)

    print(
        "Force re-sampling enabled (ignoring sampled cache)"
        if force_resample
        else "No sampled cache found - performing stratified sampling"
    )

    def stratified_sample_by_domain(dataset, target_samples, domain_column='domain', seed=42):
        """Sample ``target_samples`` rows while keeping per-domain proportions.

        Args:
            dataset: Object exposing ``len()``, ``column_names``,
                ``dataset[column_name]`` (iterable of column values) and
                ``select(indices)``.
            target_samples: Number of rows to keep.
            domain_column: Column whose values define the strata.
            seed: Seed for the private random generator used for all draws.

        Returns:
            ``dataset`` itself when it already has at most ``target_samples``
            rows; otherwise ``dataset.select(...)`` over exactly
            ``target_samples`` shuffled row indices.
        """
        import random
        from collections import defaultdict

        if target_samples >= len(dataset):
            return dataset

        print(f"\n  Stratified sampling by '{domain_column}'...")

        # A private generator keeps the draw reproducible without reseeding
        # (and thereby disturbing) the global `random` module state.
        rng = random.Random(seed)

        # Without a domain column there is nothing to stratify on.
        if domain_column not in dataset.column_names:
            print(f"  Column '{domain_column}' not found, using random sampling")
            indices = list(range(len(dataset)))
            rng.shuffle(indices)
            return dataset.select(indices[:target_samples])

        # Group row indices by domain. Only the domain column is read, so the
        # other columns are never loaded.
        print(f"  Analyzing domain distribution...")
        domain_dict = defaultdict(list)
        for idx, domain in enumerate(dataset[domain_column]):
            domain_dict[domain].append(idx)

        print(f"  Found {len(domain_dict)} domains")

        total_count = len(dataset)
        # Ordering by the string form keeps the iteration deterministic even if
        # the domain values are not mutually comparable (e.g. None next to str).
        ordered_domains = sorted(domain_dict.items(), key=lambda pair: str(pair[0]))

        # Largest-remainder allocation: every domain first gets the floor of its
        # proportional share, then the rows lost to rounding go to the domains
        # with the largest fractional parts, so the quotas sum to the target.
        quotas = []
        remainders = []
        for _domain, indices in ordered_domains:
            scaled = len(indices) * target_samples
            quotas.append(scaled // total_count)
            remainders.append(scaled % total_count)

        leftover = target_samples - sum(quotas)
        by_remainder = sorted(range(len(quotas)), key=lambda pos: -remainders[pos])
        for pos in by_remainder[:leftover]:
            quotas[pos] += 1

        # Draw each domain's quota.
        sampled_indices = []
        for (domain, indices), quota in zip(ordered_domains, quotas):
            domain_count = len(indices)
            if quota < domain_count:
                domain_sample = rng.sample(indices, quota)
            else:
                domain_sample = indices

            sampled_indices.extend(domain_sample)
            print(f"    {domain}: {len(domain_sample):,} samples ({len(domain_sample)/domain_count*100:.1f}% of domain)")

        # Shuffle so rows from the same domain are not contiguous.
        rng.shuffle(sampled_indices)

        return dataset.select(sampled_indices)
    
    # Perform stratified sampling on every split that has a sample limit.
    sampling_seed = config['dataset']['shuffle_seed']

    if config['dataset']['max_train_samples']:
        target = config['dataset']['max_train_samples']
        print(f"\nTarget training samples: {target:,}")
        train_dataset = stratified_sample_by_domain(train_dataset, target, seed=sampling_seed)

    # Validation and test splits may be None when the dataset does not provide them.
    if config['dataset']['max_eval_samples'] and val_dataset is not None:
        target = config['dataset']['max_eval_samples']
        print(f"\nTarget validation samples: {target:,}")
        val_dataset = stratified_sample_by_domain(val_dataset, target, seed=sampling_seed)

    if config['dataset']['max_test_samples'] and test_dataset is not None:
        target = config['dataset']['max_test_samples']
        print(f"\nTarget test samples: {target:,}")
        test_dataset = stratified_sample_by_domain(test_dataset, target, seed=sampling_seed)

    print(f"\nFinal sampled datasets:")
    print(f"  Train: {len(train_dataset):,} samples")
    if val_dataset is not None:
        print(f"  Validation: {len(val_dataset):,} samples")
    if test_dataset is not None:
        print(f"  Test: {len(test_dataset):,} samples")

    # Save sampled datasets to cache
    if use_sampled_cache:
        print(f"\nSaving sampled datasets to cache...")
        print(f"Cache directory: {sampled_cache_dir}")
        os.makedirs(sampled_cache_dir, exist_ok=True)

        # Only write the splits that exist; calling save_to_disk on a missing
        # (None) split would raise AttributeError.
        _splits_to_save = (
            (train_dataset, sampled_train_path),
            (val_dataset, sampled_val_path),
            (test_dataset, sampled_test_path),
        )
        _saved_count = 0
        for _split_data, _split_path in _splits_to_save:
            if _split_data is not None:
                _split_data.save_to_disk(_split_path)
                _saved_count += 1

        if _saved_count == len(_splits_to_save):
            print(f"  All sampled datasets saved to disk")
            print(f"\nNext run will skip sampling and load instantly!")
        else:
            # The cache-hit check above requires all three split directories,
            # so a partial cache is not reused on the next run.
            print(f"  Saved {_saved_count} of {len(_splits_to_save)} sampled datasets (missing splits were skipped)")
        print(f"Cache key: {sample_key}")

    print("="*80)

In [None]:
# Explore dataset: show one raw record and the label balance of the training set.
print("\nDataset sample:")
print(train_dataset[0])

# Column names used by the rest of the notebook.
label_column = config['model']['label_column']
text_column = config['model']['text_column']

print(f"\nLabel distribution in training set:")
# Read only the label column instead of iterating over every full row, which
# would also load the other columns of each record just to count labels.
train_labels = pd.Series(list(train_dataset[label_column]))
print(train_labels.value_counts())
print(f"\nLabel distribution (normalized):")
print(train_labels.value_counts(normalize=True))

## 7. Create Label Mappings and Class Weights

In [None]:
# Create label mappings
labels = config['model']['labels']
label2id, id2label = create_label_mappings(labels)

print(f"Label to ID mapping: {label2id}")
print(f"ID to Label mapping: {id2label}")

# Compute class weights.
# `class_weights` starts as None so it is always bound, including when weights
# are enabled but neither a supported method nor manual weights are configured.
class_weights = None
_class_weight_cfg = config['class_weights']

if not _class_weight_cfg['enabled']:
    print("\nClass weights disabled.")
elif _class_weight_cfg['method'] == 'inverse_frequency':
    class_weights = compute_class_weights(train_dataset, label_column, label2id)
    print(f"\nComputed class weights: {class_weights}")
elif _class_weight_cfg['manual_weights']:
    class_weights = torch.tensor(_class_weight_cfg['manual_weights'], dtype=torch.float32)
    print(f"\nUsing manual class weights: {class_weights}")
else:
    print(
        f"\nWarning: class weights are enabled but method "
        f"{_class_weight_cfg['method']!r} is not 'inverse_frequency' and no "
        f"manual_weights were provided; continuing without class weights."
    )

## 8. Load Tokenizer

In [None]:
# Load the tokenizer that matches the configured model checkpoint.
_tokenizer_source = config['model']['name']
print(f"Loading tokenizer: {_tokenizer_source}...")

_tokenizer_kwargs = {
    'add_prefix_space': config['tokenization']['add_prefix_space'],
    'use_fast': config['misc']['use_fast_tokenizer'],
    'trust_remote_code': config['misc']['trust_remote_code'],
    'cache_dir': config['misc']['cache_dir'],
}
tokenizer = AutoTokenizer.from_pretrained(_tokenizer_source, **_tokenizer_kwargs)

# Padding needs a pad token; reuse the EOS token when none is defined.
_pad_missing = tokenizer.pad_token is None
if _pad_missing:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    print(f"Set pad_token to eos_token: {tokenizer.pad_token}")

print(f"Tokenizer loaded successfully!")
print(f"Vocab size: {len(tokenizer)}")
print(f"Pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
print(f"EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")

## 9. Tokenize Datasets

In [None]:
import hashlib
import json

def get_cache_key(config):
    """Generate a cache key for the tokenized datasets.

    The key covers everything that changes the tokenized output: the model
    (tokenizer) name, the tokenization settings, the text/label columns, the
    label list, and the identity of the input data (dataset name, per-split
    sample limits and shuffle seed). Changing any of them yields a new key, so
    a tokenized cache built from different data or settings is not reused.

    Args:
        config: Parsed configuration dictionary. ``config['model']`` and
            ``config['tokenization']`` are required; ``config['dataset']`` and
            the optional fields inside the sections are read with ``.get``.

    Returns:
        The first 8 hex characters of the MD5 digest of the settings.
    """
    model_cfg = config['model']
    tokenization_cfg = config['tokenization']
    dataset_cfg = config.get('dataset') or {}

    key_data = {
        'model': model_cfg['name'],
        'max_length': tokenization_cfg['max_length'],
        'padding': tokenization_cfg['padding'],
        'truncation': tokenization_cfg.get('truncation'),
        'text_column': model_cfg['text_column'],
        'label_column': model_cfg['label_column'],
        'labels': model_cfg.get('labels'),
        'dataset_name': dataset_cfg.get('name'),
        'max_train_samples': dataset_cfg.get('max_train_samples'),
        'max_eval_samples': dataset_cfg.get('max_eval_samples'),
        'max_test_samples': dataset_cfg.get('max_test_samples'),
        'shuffle_seed': dataset_cfg.get('shuffle_seed'),
    }
    # Sorted keys keep the serialized form (and therefore the digest) stable.
    key_str = json.dumps(key_data, sort_keys=True)
    return hashlib.md5(key_str.encode()).hexdigest()[:8]


def preprocess_function(examples):
    """Tokenize a batch of examples and attach integer class ids as ``labels``."""

    def _as_text(value):
        # None and float NaN become an empty string; anything else is stringified.
        is_missing = value is None or (isinstance(value, float) and pd.isna(value))
        return "" if is_missing else str(value)

    cleaned_texts = [_as_text(raw) for raw in examples[text_column]]

    # Tokenization settings come straight from the config.
    tokenization_cfg = config['tokenization']
    encoded = tokenizer(
        cleaned_texts,
        padding=tokenization_cfg['padding'],
        truncation=tokenization_cfg['truncation'],
        max_length=tokenization_cfg['max_length'],
    )

    # Translate each raw label value into its class id.
    encoded['labels'] = [label2id[raw_label] for raw_label in examples[label_column]]
    return encoded


# Check if we should use cached tokenized datasets
cache_dir = config['misc'].get('tokenized_cache_dir', './tokenized_cache')
use_cache = config['misc'].get('save_tokenized_datasets', True)
force_retokenize = config['misc'].get('force_retokenize', False)

cache_key = get_cache_key(config)
cache_path_train = os.path.join(cache_dir, f'train_{cache_key}')
cache_path_val = os.path.join(cache_dir, f'val_{cache_key}')
cache_path_test = os.path.join(cache_dir, f'test_{cache_key}')

# A cache hit only needs the splits that exist in this run. Validation/test
# caches are written only when those splits are present, so requiring all three
# paths would make the cache unusable whenever a split is missing.
_required_cache_paths = [cache_path_train]
if val_dataset is not None:
    _required_cache_paths.append(cache_path_val)
if test_dataset is not None:
    _required_cache_paths.append(cache_path_test)
cache_exists = all(os.path.exists(_path) for _path in _required_cache_paths)

if use_cache and cache_exists and not force_retokenize:
    print("="*80)
    print(" LOADING CACHED TOKENIZED DATASETS")
    print("="*80)
    print(f"Cache directory: {cache_dir}")
    print(f"Cache key: {cache_key}")
    print("Loading from disk (this will be FAST!)...\n")

    from datasets import load_from_disk

    train_tokenized = load_from_disk(cache_path_train)
    val_tokenized = load_from_disk(cache_path_val) if val_dataset is not None else None
    test_tokenized = load_from_disk(cache_path_test) if test_dataset is not None else None

    print(f" Loaded cached tokenized datasets in seconds!")
    print(f"  Train: {len(train_tokenized):,} samples")
    if val_tokenized is not None:
        print(f"  Validation: {len(val_tokenized):,} samples")
    if test_tokenized is not None:
        print(f"  Test: {len(test_tokenized):,} samples")
    print(f"\n To force re-tokenization, set 'force_retokenize: true' in config.yaml")
    print("="*80)

else:
    print("="*80)
    print("TOKENIZING DATASETS (First Time)")
    print("="*80)

    if force_retokenize:
        print("  Force re-tokenization enabled (ignoring cache)")
    elif not cache_exists:
        print(" No cached datasets found - tokenizing from scratch")

    print("This will take a while for large datasets...\n")

    # First, filter out any rows with None/NaN text or labels
    print("Filtering datasets to remove invalid entries...")

    def is_valid(example):
        """Return True when the example has non-blank text and a known label."""
        text = example[text_column]
        label = example[label_column]

        # Text must be present, not a float NaN, and not blank after stripping.
        text_valid = text is not None and not (isinstance(text, float) and pd.isna(text)) and str(text).strip() != ""
        # Label must be one of the labels that `label2id` can map.
        label_valid = label is not None and label in label2id

        return text_valid and label_valid

    train_dataset = train_dataset.filter(is_valid, desc="Filtering train dataset")
    if val_dataset is not None:
        val_dataset = val_dataset.filter(is_valid, desc="Filtering validation dataset")
    if test_dataset is not None:
        test_dataset = test_dataset.filter(is_valid, desc="Filtering test dataset")

    print(f"\nAfter filtering:")
    print(f"  Train: {len(train_dataset):,} samples")
    if val_dataset is not None:
        print(f"  Validation: {len(val_dataset):,} samples")
    if test_dataset is not None:
        print(f"  Test: {len(test_dataset):,} samples")

    # Tokenize datasets. The raw columns are dropped so only the tokenizer
    # outputs and `labels` remain.
    print("\nTokenizing datasets...")

    train_tokenized = train_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=train_dataset.column_names,
        desc="Tokenizing train dataset",
    )

    val_tokenized = val_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=val_dataset.column_names,
        desc="Tokenizing validation dataset",
    ) if val_dataset is not None else None

    test_tokenized = test_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=test_dataset.column_names,
        desc="Tokenizing test dataset",
    ) if test_dataset is not None else None

    print("\n Tokenization complete!")
    print(f"  Train: {len(train_tokenized):,} samples")
    if val_tokenized is not None:
        print(f"  Validation: {len(val_tokenized):,} samples")
    if test_tokenized is not None:
        print(f"  Test: {len(test_tokenized):,} samples")

    # Save tokenized datasets to cache for future runs
    if use_cache:
        print(f"\n Saving tokenized datasets to cache...")
        print(f"Cache directory: {cache_dir}")
        os.makedirs(cache_dir, exist_ok=True)

        train_tokenized.save_to_disk(cache_path_train)
        print(f"   Train dataset saved to disk")

        if val_tokenized is not None:
            val_tokenized.save_to_disk(cache_path_val)
            print(f"   Validation dataset saved to disk")

        if test_tokenized is not None:
            test_tokenized.save_to_disk(cache_path_test)
            print(f"   Test dataset saved to disk")

        print(f"\n Next time you run this notebook, these datasets will load in seconds!")
        print(f"Cache location: {cache_dir}")

print("="*80)
print(f"\nFinal tokenized datasets:")
print(f"Train dataset: {train_tokenized}")
if val_tokenized is not None:
    print(f"Validation dataset: {val_tokenized}")
if test_tokenized is not None:
    print(f"Test dataset: {test_tokenized}")

## 10. Load Model

In [None]:
# Build the optional bitsandbytes quantization config from the config file.
_quant_cfg = config['quantization']
quantization_config = None
if _quant_cfg['enabled']:
    print("Setting up quantization...")

    # The compute dtype is stored as a string and resolved to the torch attribute of that name.
    compute_dtype = getattr(torch, _quant_cfg['bnb_4bit_compute_dtype'])

    _bnb_kwargs = {
        'load_in_4bit': _quant_cfg['load_in_4bit'],
        'load_in_8bit': _quant_cfg['load_in_8bit'],
        'bnb_4bit_quant_type': _quant_cfg['bnb_4bit_quant_type'],
        'bnb_4bit_use_double_quant': _quant_cfg['bnb_4bit_use_double_quant'],
        'bnb_4bit_compute_dtype': compute_dtype,
    }
    quantization_config = BitsAndBytesConfig(**_bnb_kwargs)

    print(f"Quantization config: {quantization_config}")

print(f"\nLoading model: {config['model']['name']}...")
print("This may take a few minutes...\n")

# Load the sequence-classification model with the label mappings attached.
_model_load_kwargs = {
    'num_labels': config['model']['num_labels'],
    'id2label': id2label,
    'label2id': label2id,
    'quantization_config': quantization_config,
    'device_map': config['hardware']['device_map'],
    'max_memory': config['hardware']['max_memory'],
    'trust_remote_code': config['misc']['trust_remote_code'],
    'cache_dir': config['misc']['cache_dir'],
}
model = AutoModelForSequenceClassification.from_pretrained(
    config['model']['name'],
    **_model_load_kwargs,
)

# Align the model's pad token with the tokenizer and turn off the cache for training.
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False

print(f"Model loaded successfully!")
print(f"Number of parameters: {model.num_parameters():,}")
print_gpu_memory()

## 11. Apply PEFT (if enabled)

In [None]:
if not config['peft']['enabled']:
    # Full fine-tuning: report how many parameters currently require gradients.
    print("PEFT disabled. Using full fine-tuning.")
    trainable_params = 0
    all_params = 0
    for _param in model.parameters():
        _count = _param.numel()
        all_params += _count
        if _param.requires_grad:
            trainable_params += _count
    print(f"Trainable parameters: {trainable_params:,} / {all_params:,} ({100 * trainable_params / all_params:.2f}%)")
else:
    print("Applying PEFT/LoRA...")

    # A quantized model is first prepared for k-bit training.
    if config['quantization']['enabled']:
        model = prepare_model_for_kbit_training(model)

    # LoRA hyperparameters come straight from the config.
    _lora_kwargs = {
        'r': config['peft']['lora_r'],
        'lora_alpha': config['peft']['lora_alpha'],
        'target_modules': config['peft']['target_modules'],
        'lora_dropout': config['peft']['lora_dropout'],
        'bias': config['peft']['bias'],
        'task_type': TaskType.SEQ_CLS,
        'modules_to_save': config['peft']['modules_to_save'],
    }
    peft_config = LoraConfig(**_lora_kwargs)

    # Wrap the model with the LoRA adapters.
    model = get_peft_model(model, peft_config)

    print(f"\nLoRA applied successfully!")
    print(f"LoRA rank: {config['peft']['lora_r']}")
    print(f"LoRA alpha: {config['peft']['lora_alpha']}")
    print(f"Target modules: {config['peft']['target_modules']}")

    # Summary of trainable vs. total parameters.
    model.print_trainable_parameters()

print_gpu_memory()

## 12. Data Collator

In [None]:
# Initialize data collator
if config['data_collator']['type'] == 'DataCollatorWithPadding':
    data_collator = DataCollatorWithPadding(
        tokenizer=tokenizer,
        padding=config['data_collator']['padding'],
        pad_to_multiple_of=config['data_collator']['pad_to_multiple_of'],
    )
    print(f"Using DataCollatorWithPadding")
else:
    data_collator = None
    print(f"Using default data collator")

## 13. Metrics Function

In [None]:
def compute_metrics(eval_pred):
    """Return accuracy, balanced accuracy and weighted precision/recall/F1.

    ``eval_pred`` unpacks into ``(predictions, labels)``; the predicted class
    is the argmax of the predictions along axis 1.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)

    metrics = {
        'accuracy': accuracy_score(references, predicted_ids),
        'balanced_accuracy': balanced_accuracy_score(references, predicted_ids),
    }

    # Weighted averages; zero_division=0 reports 0 for classes never predicted.
    weighted = precision_recall_fscore_support(
        references, predicted_ids, average='weighted', zero_division=0
    )
    metrics['precision'], metrics['recall'], metrics['f1'] = weighted[:3]

    return metrics


print("Metrics function defined.")

## 14. Custom Trainer with Class Weights

In [None]:
class WeightedTrainer(Trainer):
    """Trainer whose loss is cross-entropy with optional per-class weights.

    Args:
        class_weights: Optional tensor or sequence with one weight per class.
            When omitted, plain (unweighted) cross-entropy is used.
        *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)

        if class_weights is not None:
            if isinstance(class_weights, torch.Tensor):
                # Copy so later changes to the caller's tensor do not leak in.
                self.class_weights = class_weights.detach().clone().float()
            else:
                self.class_weights = torch.tensor(class_weights, dtype=torch.float32)
            self.class_weights = self.class_weights.to(self.args.device)
        else:
            self.class_weights = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute (optionally class-weighted) cross-entropy loss.

        ``inputs`` is modified in place: the ``"labels"`` entry is removed
        before the remaining items are passed to the model.
        ``num_items_in_batch`` is accepted for signature compatibility and is
        not used.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``
            is true.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Evaluate the loss in float32 on the logits' own device. The stored
        # weights are float32 on args.device, so half-precision logits or a
        # model placed on another device would otherwise mismatch.
        logits = logits.float()
        labels = labels.to(logits.device)
        weight = None
        if self.class_weights is not None:
            weight = self.class_weights.to(logits.device)

        # weight=None is the unweighted cross-entropy default.
        loss = torch.nn.functional.cross_entropy(logits, labels, weight=weight)

        return (loss, outputs) if return_outputs else loss


print("Custom Trainer class defined.")

## 15. Training Arguments

In [None]:
# Resolve the output directory from the config and make sure it exists
train_cfg = config['training']
output_dir = train_cfg['output_dir']
os.makedirs(output_dir, exist_ok=True)

# Each key below is read from config['training'] and passed to
# TrainingArguments under the identical keyword name.
training_arg_keys = (
    # Schedule and core hyper-parameters
    'num_train_epochs',
    'per_device_train_batch_size',
    'per_device_eval_batch_size',
    'gradient_accumulation_steps',
    'gradient_checkpointing',
    'learning_rate',
    'weight_decay',
    'warmup_ratio',
    'lr_scheduler_type',
    # Evaluation and saving
    'eval_strategy',
    'eval_steps',
    'save_strategy',
    'save_steps',
    'save_total_limit',
    'load_best_model_at_end',
    'metric_for_best_model',
    'greater_is_better',
    # Logging
    'logging_steps',
    'logging_dir',
    'report_to',
    # Mixed precision
    'fp16',
    'bf16',
    # Optimization
    'optim',
    'max_grad_norm',
    # Data loading
    'dataloader_num_workers',
    'dataloader_pin_memory',
    'group_by_length',
    # Misc
    'seed',
    'remove_unused_columns',
    'push_to_hub',
    'hub_model_id',
    'hub_token',
)

training_args = TrainingArguments(
    output_dir=output_dir,
    **{key: train_cfg[key] for key in training_arg_keys},
)

print("Training arguments configured.")

# Samples consumed per optimizer update on one device
effective_batch_size = (
    training_args.per_device_train_batch_size
    * training_args.gradient_accumulation_steps
)
print(f"\nEffective batch size: {effective_batch_size}")
print(f"Total optimization steps: ~{len(train_tokenized) // effective_batch_size * training_args.num_train_epochs}")

## 16. Initialize Trainer

In [None]:
# Collect trainer callbacks; early stopping is added only when the config
# switches it on.
early_stop_cfg = config['early_stopping']
callbacks = []
if early_stop_cfg['enabled']:
    early_stopping = EarlyStoppingCallback(
        early_stopping_patience=early_stop_cfg['patience'],
        early_stopping_threshold=early_stop_cfg['threshold'],
    )
    callbacks += [early_stopping]
    print(f"Early stopping enabled with patience: {early_stop_cfg['patience']}")

# Gather everything the weighted trainer needs, then build it in one call
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    class_weights=class_weights,
    callbacks=callbacks,
)
trainer = WeightedTrainer(**trainer_kwargs)

print("\nTrainer initialized successfully!")

## 17. Pre-Training Evaluation (Baseline)

In [None]:
# Baseline evaluation of the not-yet-trained model on the test split.
# Leaves `pre_metrics` set to a metrics dict, or to None when skipped.
if config['evaluation']['run_pre_training_eval'] and test_tokenized:
    print("="*80)
    print("PRE-TRAINING EVALUATION (Baseline)")
    print("="*80)
    
    # Use subset for faster baseline evaluation
    pre_eval_samples = config['evaluation'].get('pre_training_eval_samples')
    if pre_eval_samples and pre_eval_samples < len(test_tokenized):
        print(f"Using subset of {pre_eval_samples:,} samples for fast baseline")
        print(f"(Full test set has {len(test_tokenized):,} samples)")
        
        # Draw the subset from a dedicated fixed-seed generator: the sample is
        # reproducible, and the global `random` state is left untouched.
        import random
        indices = random.Random(42).sample(range(len(test_tokenized)), pre_eval_samples)
        pre_eval_dataset = test_tokenized.select(indices)
        print(f"Selected {len(pre_eval_dataset):,} samples for baseline evaluation")
        print(f"Estimated time: ~2-3 minutes\n")
    else:
        print(f"Using full test set ({len(test_tokenized):,} samples)")
        print(f"  This may take 30-60 minutes!")
        print(f" Tip: Set 'pre_training_eval_samples: 1000' in config.yaml for faster baseline\n")
        pre_eval_dataset = test_tokenized
    
    # Run prediction on test set
    pre_results = trainer.predict(pre_eval_dataset)
    
    # Extract predictions and labels
    pre_predictions = np.argmax(pre_results.predictions, axis=1)
    pre_labels = pre_results.label_ids
    
    # Compute metrics
    pre_metrics = {
        'accuracy': accuracy_score(pre_labels, pre_predictions),
        'balanced_accuracy': balanced_accuracy_score(pre_labels, pre_predictions),
    }
    
    precision, recall, f1, _ = precision_recall_fscore_support(
        pre_labels, pre_predictions, average='weighted', zero_division=0
    )
    pre_metrics['precision'] = precision
    pre_metrics['recall'] = recall
    pre_metrics['f1'] = f1
    
    # Print metrics
    print(f"\nPre-training Metrics (on {len(pre_eval_dataset):,} samples):")
    for metric, value in pre_metrics.items():
        print(f"  {metric}: {value:.4f}")
    
    # Pin the class list explicitly. Without it, the report and matrix are
    # built only from classes seen in this sample, and the report raises
    # when that count differs from the number of target names.
    class_ids = list(range(len(labels)))
    
    # Confusion matrix
    print(f"\nConfusion Matrix:")
    cm = confusion_matrix(pre_labels, pre_predictions, labels=class_ids)
    print(cm)
    
    # Classification report
    print(f"\nClassification Report:")
    print(classification_report(
        pre_labels, pre_predictions,
        labels=class_ids,
        target_names=[id2label[i] for i in class_ids],
        digits=4,
        zero_division=0,
    ))
    
    # Save pre-training metrics
    save_metrics_to_file(pre_metrics, os.path.join(output_dir, 'pre_training_metrics.txt'))
    
    # Log to wandb
    if config['wandb']['enabled']:
        wandb.log({
            "pre_eval_accuracy": pre_metrics['accuracy'],
            "pre_eval_balanced_accuracy": pre_metrics['balanced_accuracy'],
            "pre_eval_f1": pre_metrics['f1'],
        })
    
    print("="*80)
else:
    print("Pre-training evaluation skipped.")
    # Later cells test `pre_metrics` for truthiness, so it must always exist.
    pre_metrics = None

## 18. Training

In [None]:
banner = "=" * 80

print(banner)
print("STARTING TRAINING")
print(banner)

# A checkpoint to resume from may be supplied through the 'misc' config
# section; the value is handed to the trainer unchanged.
resume_from_checkpoint = config['misc']['resume_from_checkpoint']

# Run the training loop
train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

print("\n" + banner)
print("TRAINING COMPLETED")
print(banner)

# Report every metric returned by the training run
print("\nTraining Metrics:")
for key, value in train_result.metrics.items():
    print("  " + f"{key}: {value}")

print_gpu_memory()

## 19. Post-Training Evaluation

In [None]:
# Evaluate the trained model on the full test split.
# Leaves `post_metrics` set to a metrics dict, or to None when skipped.
if config['evaluation']['run_post_training_eval'] and test_tokenized:
    print("="*80)
    print("POST-TRAINING EVALUATION")
    print("="*80)
    
    # Run prediction on test set
    post_results = trainer.predict(test_tokenized)
    
    # Extract predictions and labels
    post_predictions = np.argmax(post_results.predictions, axis=1)
    post_labels = post_results.label_ids
    
    # Compute metrics
    post_metrics = {
        'accuracy': accuracy_score(post_labels, post_predictions),
        'balanced_accuracy': balanced_accuracy_score(post_labels, post_predictions),
    }
    
    precision, recall, f1, _ = precision_recall_fscore_support(
        post_labels, post_predictions, average='weighted', zero_division=0
    )
    post_metrics['precision'] = precision
    post_metrics['recall'] = recall
    post_metrics['f1'] = f1
    
    # Print metrics
    print(f"\nPost-training Metrics:")
    for metric, value in post_metrics.items():
        print(f"  {metric}: {value:.4f}")
    
    # Pin the class list explicitly. Without it, the report and matrix are
    # built only from classes seen in the data, and the report raises when
    # that count differs from the number of target names.
    class_ids = list(range(len(labels)))
    
    # Confusion matrix
    print(f"\nConfusion Matrix:")
    cm = confusion_matrix(post_labels, post_predictions, labels=class_ids)
    print(cm)
    
    # Classification report
    print(f"\nClassification Report:")
    print(classification_report(
        post_labels, post_predictions,
        labels=class_ids,
        target_names=[id2label[i] for i in class_ids],
        digits=4,
        zero_division=0,
    ))
    
    # Save post-training metrics
    save_metrics_to_file(post_metrics, os.path.join(output_dir, 'post_training_metrics.txt'))
    
    # Save predictions if enabled
    if config['evaluation']['save_predictions']:
        predictions_df = pd.DataFrame({
            'true_label': [id2label[label] for label in post_labels],
            'predicted_label': [id2label[pred] for pred in post_predictions],
            'correct': post_labels == post_predictions,
        })
        predictions_file = os.path.join(output_dir, config['evaluation']['predictions_file'])
        predictions_df.to_csv(predictions_file, index=False)
        print(f"\nPredictions saved to: {predictions_file}")
    
    print("="*80)
    
    # Compare with pre-training metrics
    if pre_metrics:
        print("\nIMPROVEMENT SUMMARY:")
        print("="*80)
        for metric in ['accuracy', 'balanced_accuracy', 'precision', 'recall', 'f1']:
            if metric in pre_metrics and metric in post_metrics:
                improvement = post_metrics[metric] - pre_metrics[metric]
                print(f"{metric:20s}: {pre_metrics[metric]:.4f} → {post_metrics[metric]:.4f} (Δ {improvement:+.4f})")
        print("="*80)
else:
    print("Post-training evaluation skipped.")
    # The summary cell tests `post_metrics`; define it on this path too so
    # skipping the evaluation does not end in a NameError later.
    post_metrics = None

## 20. Save Model

In [None]:
# Save the fine-tuned model
print(f"\nSaving model to: {output_dir}")

trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

# Persist the run configuration next to the model. The file lives in the
# model directory, so a Hub token present in the config is replaced with a
# placeholder instead of being written in plain text. Only shallow copies are
# made; the in-memory `config` is left untouched.
config_to_save = dict(config)
training_section = config.get('training')
if isinstance(training_section, dict) and training_section.get('hub_token'):
    config_to_save['training'] = {**training_section, 'hub_token': '<redacted>'}

training_args_file = os.path.join(output_dir, 'training_args.yaml')
with open(training_args_file, 'w', encoding='utf-8') as f:
    yaml.dump(config_to_save, f, default_flow_style=False)

print(f"Model saved successfully!")
print(f"\nTo load the model later:")
print(f"  from transformers import AutoModelForSequenceClassification, AutoTokenizer")
print(f"  model = AutoModelForSequenceClassification.from_pretrained('{output_dir}')")
print(f"  tokenizer = AutoTokenizer.from_pretrained('{output_dir}')")

# If PEFT was used, also print how to re-attach the adapters to the base model
if config['peft']['enabled']:
    print(f"\nLoRA adapters saved. To load:")
    print(f"  from peft import PeftModel")
    print(f"  base_model = AutoModelForSequenceClassification.from_pretrained('{config['model']['name']}')")
    print(f"  model = PeftModel.from_pretrained(base_model, '{output_dir}')")

## 21. Summary and Finish Wandb Run

In [None]:
# Print a recap of the run (data sizes, method, hyper-parameters, final test
# metrics, suggested next steps) and close the W&B run when one is active.
print("="*80)
print("TRAINING SUMMARY")
print("="*80)

print(f"\nModel: {config['model']['name']}")
print(f"Dataset: {config['dataset']['name']}")
print(f"Training samples: {len(train_tokenized):,}")

# Keep the label on the line even when a split is missing or empty
val_count = f"{len(val_tokenized):,}" if val_tokenized else "N/A"
test_count = f"{len(test_tokenized):,}" if test_tokenized else "N/A"
print(f"Validation samples: {val_count}")
print(f"Test samples: {test_count}")

print(f"\nTraining method: {'PEFT/LoRA' if config['peft']['enabled'] else 'Full fine-tuning'}")
if config['peft']['enabled']:
    print(f"  LoRA rank: {config['peft']['lora_r']}")
    print(f"  LoRA alpha: {config['peft']['lora_alpha']}")

# 4-bit takes precedence over 8-bit when both flags are set
if config['quantization']['load_in_4bit']:
    quantization_mode = '4-bit'
elif config['quantization']['load_in_8bit']:
    quantization_mode = '8-bit'
else:
    quantization_mode = 'None'
print(f"\nQuantization: {quantization_mode}")
print(f"Epochs: {config['training']['num_train_epochs']}")
print(f"Batch size (per device): {config['training']['per_device_train_batch_size']}")
print(f"Gradient accumulation: {config['training']['gradient_accumulation_steps']}")
print(f"Effective batch size: {config['training']['per_device_train_batch_size'] * config['training']['gradient_accumulation_steps']}")
print(f"Learning rate: {config['training']['learning_rate']}")

# `post_metrics` does not exist when the post-training evaluation was skipped
# or its cell was never run; treat that the same as "no metrics".
try:
    final_metrics = post_metrics
except NameError:
    final_metrics = None

if final_metrics:
    print(f"\nFinal Test Metrics:")
    for metric, value in final_metrics.items():
        print(f"  {metric}: {value:.4f}")

print(f"\nOutput directory: {output_dir}")

print("\n" + "="*80)
print(" Training completed successfully!")
print("="*80)

# Build the list first and number it afterwards, so the numbering stays
# consecutive whichever optional steps apply.
print(f"\nNext steps:")
next_steps = [f"Review logs in: {config['training']['logging_dir']}"]
if config['wandb']['enabled']:
    next_steps.append(f"Check W&B dashboard: {wandb.run.get_url()}")
next_steps.append("Test the model on new data")
next_steps.append("Deploy the model for inference")
if config['training']['push_to_hub']:
    next_steps.append(f"Check your model on HuggingFace Hub: {config['training']['hub_model_id']}")
for step_number, step_text in enumerate(next_steps, start=1):
    print(f"  {step_number}. {step_text}")

# Finish wandb run
if config['wandb']['enabled']:
    print("\n" + "="*80)
    print("Finishing Weights & Biases run...")
    wandb.finish()
    print(" W&B run finished successfully!")
    print("="*80)