In [None]:
# ============================================================
# Cell 1: Install & Import
# ============================================================
# !pip install torch transformers datasets sentencepiece pillow tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    VisionEncoderDecoderModel,
    AutoImageProcessor,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
)
from datasets import load_dataset
from PIL import Image
import sentencepiece as spm
import io
import os
import shutil
from tqdm import tqdm

# Environment report for this kernel: confirms the imports above resolved and
# shows the installed PyTorch version and whether CUDA is usable.
print("‚úÖ All imports successful")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is reported here.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# ============================================================
# Cell 2: Define Thai Tokenizer Class
# ============================================================
class ThaiTokenizerFixed(PreTrainedTokenizer):
    """Thai tokenizer exposing a SentencePiece model through ``PreTrainedTokenizer``.

    The special tokens are registered as ``<pad>``, ``<unk>``, ``<s>`` and
    ``</s>``; the id properties below hard-code them to 0, 1, 2 and 3.
    NOTE(review): this presumes the SentencePiece model was trained with exactly
    that id layout -- confirm against the model's training options.
    """

    vocab_files_names = {"vocab_file": "spm.model"}

    def __init__(self, vocab_file=None, **kwargs):
        """Load the SentencePiece model, then initialise the base tokenizer.

        Args:
            vocab_file: Path to the SentencePiece ``.model`` file. A falsy value
                falls back to ``'thai_sp_30000.model'``.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        self.vocab_file = vocab_file if vocab_file else 'thai_sp_30000.model'

        # The processor is loaded before super().__init__() runs, so
        # ``vocab_size``/``get_vocab`` already work if the base class uses them.
        processor = spm.SentencePieceProcessor()
        processor.load(self.vocab_file)
        self.sp = processor

        special_tokens = dict(
            pad_token="<pad>",
            unk_token="<unk>",
            bos_token="<s>",
            eos_token="</s>",
        )
        super().__init__(**special_tokens, **kwargs)

    @property
    def pad_token_id(self):
        """Id of ``<pad>`` (fixed)."""
        return 0

    @property
    def unk_token_id(self):
        """Id of ``<unk>`` (fixed)."""
        return 1

    @property
    def bos_token_id(self):
        """Id of ``<s>`` (fixed)."""
        return 2

    @property
    def eos_token_id(self):
        """Id of ``</s>`` (fixed)."""
        return 3

    @property
    def vocab_size(self):
        """Number of pieces reported by the SentencePiece model."""
        return self.sp.vocab_size()

    def get_vocab(self):
        """Return a ``{piece: id}`` mapping covering every id in the model."""
        vocabulary = {}
        for piece_id in range(self.sp.vocab_size()):
            vocabulary[self.sp.id_to_piece(piece_id)] = piece_id
        return vocabulary

    def _tokenize(self, text):
        """Split ``text`` into SentencePiece pieces."""
        return self.sp.encode_as_pieces(text)

    def _convert_token_to_id(self, token):
        """Map one piece to its id."""
        return self.sp.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Map one id to its piece."""
        return self.sp.id_to_piece(index)

    def convert_tokens_to_string(self, tokens):
        """Join pieces back into text using the SentencePiece decoder."""
        return self.sp.decode_pieces(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Wrap each sequence as ``<s> ... </s>``; a pair is the two wrapped sequences concatenated."""
        wrapped = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        if token_ids_1 is not None:
            wrapped = wrapped + [self.bos_token_id] + token_ids_1 + [self.eos_token_id]
        return wrapped

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """Return a 0/1 mask with 1 at the positions ``build_inputs_with_special_tokens`` adds.

        When ``already_has_special_tokens`` is true the decision is delegated to
        the base-class implementation.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True
            )

        mask = [1] + [0] * len(token_ids_0) + [1]
        if token_ids_1 is not None:
            mask = mask + [1] + [0] * len(token_ids_1) + [1]
        return mask

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Copy the SentencePiece model into ``save_directory`` as ``[prefix-]spm.model``.

        The directory is created when missing, and the copy is skipped when the
        source and destination resolve to the same path.

        Returns:
            A one-element tuple holding the destination path.
        """
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory)

        prefix = filename_prefix + "-" if filename_prefix else ""
        out_file = os.path.join(save_directory, prefix + "spm.model")

        same_file = os.path.abspath(self.vocab_file) == os.path.abspath(out_file)
        if not same_file:
            shutil.copy(self.vocab_file, out_file)

        return (out_file,)

# Cell-completion marker: printed once the class definition above has executed.
print("‚úÖ ThaiTokenizerFixed class defined")


In [None]:
# ============================================================
# Cell 3: Define Dataset Class
# ============================================================
class ThaiHandwritingDatasetFixed(Dataset):
    """Map-style dataset pairing a handwriting image with tokenized label ids.

    Each underlying record is read through its ``'image'`` and ``'text'`` keys.
    The text is tokenized with special tokens, padded/truncated to
    ``max_length``, and every padding position in the labels is set to -100.
    """

    def __init__(self, dataset, tokenizer, image_processor, max_length=128):
        """Store the wrapped dataset and the two preprocessors.

        Args:
            dataset: Indexable collection of records with ``'image'`` and ``'text'``.
            tokenizer: Callable tokenizer exposing ``pad_token_id``.
            image_processor: Callable returning an object with ``pixel_values``.
            max_length: Fixed length of the label sequence.
        """
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length

    def __len__(self):
        """Number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'pixel_values', 'labels', 'text'}`` for record ``idx``."""
        record = self.dataset[idx]

        # The image may arrive as a dict carrying encoded bytes, as a PIL image,
        # or as something ``Image.fromarray`` accepts.
        raw_image = record['image']
        if isinstance(raw_image, dict):
            picture = Image.open(io.BytesIO(raw_image['bytes']))
        elif isinstance(raw_image, Image.Image):
            picture = raw_image
        else:
            picture = Image.fromarray(raw_image)

        if picture.mode != 'RGB':
            picture = picture.convert('RGB')

        processed = self.image_processor(picture, return_tensors="pt")
        pixel_values = processed.pixel_values.squeeze(0)

        caption = record['text']
        encoded = self.tokenizer(
            caption,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

        token_ids = encoded['input_ids'].squeeze(0)
        # Padding positions become -100 in a fresh tensor; ``token_ids`` itself
        # is left untouched.
        padding_positions = token_ids == self.tokenizer.pad_token_id
        labels = token_ids.masked_fill(padding_positions, -100)

        return {
            'pixel_values': pixel_values,
            'labels': labels,
            'text': caption,
        }

def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    Args:
        batch: Sequence of dicts with ``'pixel_values'`` and ``'labels'``
            tensors (stackable, i.e. equal shapes) and a ``'text'`` entry.

    Returns:
        Dict with stacked ``'pixel_values'`` and ``'labels'`` tensors and the
        raw texts as a list under ``'texts'``.
    """
    images, targets, texts = [], [], []
    for sample in batch:
        images.append(sample['pixel_values'])
        targets.append(sample['labels'])
        texts.append(sample['text'])

    return {
        'pixel_values': torch.stack(images),
        'labels': torch.stack(targets),
        'texts': texts,
    }

# Cell-completion marker: printed once the dataset class and collate function above are defined.
print("‚úÖ Dataset class and collate_fn defined")



In [None]:
# ============================================================
# Cell 4: Define Trainer Class (with Resume Support)
# ============================================================
class Trainer:
    """Training loop for a vision encoder-decoder OCR model with checkpoint resume.

    The trainer owns an AdamW optimizer and a linear warmup/decay schedule,
    runs train + validation once per epoch, and writes three kinds of files to
    ``output_dir``: ``checkpoint-latest.pt`` (full state, every epoch),
    ``checkpoint-epoch-N.pt`` (copy of the latest, every
    ``PERIODIC_SAVE_EVERY`` epochs) and ``best_model.pt`` (weights of the epoch
    with the lowest validation loss).

    Resume semantics: ``num_epochs`` is the number of *additional* epochs. The
    schedule horizon therefore covers the epochs already completed plus the new
    ones, and the schedule is re-positioned at the saved step. Building the
    schedule from the new epochs alone while restoring the old step counter
    would push the decay factor to 0 as soon as the restored step exceeds the
    new horizon.
    """

    # Gradients are clipped to this global norm before every optimizer update.
    MAX_GRAD_NORM = 1.0
    # A never-overwritten copy of the latest checkpoint is kept every N epochs.
    PERIODIC_SAVE_EVERY = 10

    def __init__(
        self,
        model,
        train_dataloader,
        val_dataloader,
        tokenizer,
        device='cuda',
        learning_rate=5e-5,
        num_epochs=10,
        warmup_steps=500,
        output_dir='./checkpoints',
        gradient_accumulation_steps=1,
        resume_from=None,
    ):
        """Create the optimizer/scheduler and optionally restore a checkpoint.

        Args:
            model: Module called as ``model(pixel_values=..., labels=...)`` whose
                output exposes ``.loss``.
            train_dataloader: Sized iterable of batches with ``'pixel_values'``
                and ``'labels'`` tensors.
            val_dataloader: Same batch format, used by ``validate``.
            tokenizer: Kept on the instance for downstream cells; not used here.
            device: Device the model and batches are moved to.
            learning_rate: Base LR for AdamW. On resume the optimizer state in
                the checkpoint replaces it.
            num_epochs: Number of epochs to run (additional epochs on resume).
            warmup_steps: Optimizer steps of linear warmup.
            output_dir: Directory for checkpoints (created if missing).
            gradient_accumulation_steps: Batches accumulated per optimizer step.
            resume_from: Path to a ``checkpoint-latest.pt``-style file, or None.

        Raises:
            ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
        """
        if gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")

        self.model = model.to(device)
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.tokenizer = tokenizer
        self.device = device
        self.num_epochs = num_epochs
        self.output_dir = output_dir
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps

        os.makedirs(output_dir, exist_ok=True)

        # Ceil division: the trailing partial accumulation window also produces
        # an optimizer step (see ``_is_update_step``).
        self.steps_per_epoch = -(-len(train_dataloader) // gradient_accumulation_steps)

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            betas=(0.9, 0.999),
            eps=1e-8
        )

        # Training state; must exist before the scheduler is built because the
        # schedule horizon depends on ``start_epoch``.
        self.start_epoch = 0
        self.global_step = 0
        self.best_val_loss = float('inf')
        self.training_history = []

        self.scheduler = self._build_scheduler()

        if resume_from is not None:
            self._load_checkpoint(resume_from)

    def _total_scheduled_steps(self):
        """Optimizer steps covered by the schedule: completed epochs plus the new ones."""
        return self.steps_per_epoch * (self.start_epoch + self.num_epochs)

    def _build_scheduler(self, completed_steps=0):
        """Create the linear warmup/decay schedule positioned at ``completed_steps``.

        With ``completed_steps == 0`` this is a fresh schedule. Otherwise the
        scheduler is constructed with ``last_epoch=completed_steps - 1``; its
        constructor performs one initial step, which lands it on
        ``completed_steps`` and writes the matching LR into the optimizer.
        """
        if completed_steps > 0:
            # A scheduler built with last_epoch != -1 reads 'initial_lr' from
            # every param group; it is normally present in a restored optimizer
            # state, this is only a fallback.
            for group in self.optimizer.param_groups:
                group.setdefault('initial_lr', self.learning_rate)

        return get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self._total_scheduled_steps(),
            last_epoch=completed_steps - 1,
        )

    def _is_update_step(self, step, num_batches):
        """Return True when the optimizer should step after batch index ``step``.

        That is the case at the end of every full accumulation window and on the
        last batch of the epoch, so gradients of a trailing partial window are
        applied instead of being discarded by the next ``zero_grad``.
        """
        batches_seen = step + 1
        return (
            batches_seen % self.gradient_accumulation_steps == 0
            or batches_seen == num_batches
        )

    def _load_checkpoint(self, checkpoint_path):
        """Restore model, optimizer, schedule position and bookkeeping from a checkpoint.

        The checkpoint must contain ``'model_state_dict'`` and
        ``'optimizer_state_dict'``; the remaining keys are optional and fall
        back to their fresh-run defaults.
        """
        print(f"\nüìÇ Loading checkpoint from: {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        print("   ‚úÖ Model restored")

        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("   ‚úÖ Optimizer restored")

        # The stored epoch is the last completed one, so training continues
        # with the next.
        self.start_epoch = checkpoint.get('epoch', 0) + 1
        self.global_step = checkpoint.get('global_step', 0)
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))

        if 'training_history' in checkpoint:
            self.training_history = checkpoint['training_history']

        # Rebuild the schedule over the extended horizon and position it at the
        # saved step. Only the step counter is taken from the saved scheduler
        # state; the horizon itself is not part of that state.
        scheduler_state = checkpoint.get('scheduler_state_dict')
        if scheduler_state is not None:
            completed_steps = scheduler_state.get('last_epoch', self.global_step)
        else:
            # The scheduler steps once per optimizer step, so the global step
            # is an equivalent position.
            completed_steps = self.global_step
        self.scheduler = self._build_scheduler(completed_steps)
        if scheduler_state is not None:
            print("   ‚úÖ Scheduler restored")

        actual_lr = self.optimizer.param_groups[0]['lr']
        print(f"\nüìä Resume State:")
        print(f"   Completed epochs: {self.start_epoch}")
        print(f"   Global step: {self.global_step}")
        print(f"   Best val loss: {self.best_val_loss:.4f}")
        print(f"   Current LR: {actual_lr:.2e}")

    def _set_learning_rate(self, new_lr):
        """Manually set the learning rate (use with caution).

        ``new_lr`` is applied immediately and also becomes the scheduler's base
        LR. Without the latter the next ``scheduler.step()`` would overwrite the
        manual value with one derived from the old base LR. Subsequent steps
        therefore use ``new_lr`` multiplied by the schedule's decay factor.
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr
            param_group['initial_lr'] = new_lr

        scheduler = getattr(self, 'scheduler', None)
        if scheduler is not None and hasattr(scheduler, 'base_lrs'):
            scheduler.base_lrs = [new_lr for _ in scheduler.base_lrs]

        print(f"‚ö†Ô∏è  LR manually set to: {new_lr:.2e}")

    def train_epoch(self, epoch):
        """Run one pass over the training data and return the mean batch loss.

        Args:
            epoch: Zero-based epoch index, used only for the progress-bar label.
        """
        self.model.train()
        num_batches = len(self.train_dataloader)
        total_loss = 0.0

        progress_bar = tqdm(
            self.train_dataloader,
            desc=f"Epoch {epoch+1}/{self.start_epoch + self.num_epochs}"
        )

        self.optimizer.zero_grad()

        for step, batch in enumerate(progress_bar):
            pixel_values = batch['pixel_values'].to(self.device)
            labels = batch['labels'].to(self.device)

            outputs = self.model(
                pixel_values=pixel_values,
                labels=labels
            )

            batch_loss = outputs.loss.item()
            # Scale so that the accumulated gradient is the window average.
            (outputs.loss / self.gradient_accumulation_steps).backward()

            if self._is_update_step(step, num_batches):
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.MAX_GRAD_NORM)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1

            total_loss += batch_loss

            current_lr = self.optimizer.param_groups[0]['lr']
            progress_bar.set_postfix({
                'loss': f"{batch_loss:.4f}",
                'lr': f"{current_lr:.2e}"
            })

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        return total_loss / max(num_batches, 1)

    def validate(self):
        """Return the mean per-batch loss over the validation loader (no gradients)."""
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for batch in tqdm(self.val_dataloader, desc="Validating"):
                pixel_values = batch['pixel_values'].to(self.device)
                labels = batch['labels'].to(self.device)

                outputs = self.model(
                    pixel_values=pixel_values,
                    labels=labels
                )

                total_loss += outputs.loss.item()

        return total_loss / max(len(self.val_dataloader), 1)

    def save_checkpoint(self, epoch, val_loss):
        """Write the latest checkpoint and, when applicable, periodic and best copies.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            val_loss: Validation loss of that epoch.
        """
        # Decide "best" first so the value stored in checkpoint-latest.pt
        # already includes this epoch. Otherwise a resume would start from a
        # stale best and could overwrite best_model.pt with a worse model.
        is_best = val_loss < self.best_val_loss
        if is_best:
            self.best_val_loss = val_loss

        self.training_history.append({
            'epoch': epoch,
            'val_loss': val_loss,
            'lr': self.optimizer.param_groups[0]['lr'],
            'global_step': self.global_step,
        })

        latest_path = os.path.join(self.output_dir, 'checkpoint-latest.pt')
        torch.save({
            'epoch': epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'best_val_loss': self.best_val_loss,
            'training_history': self.training_history,
            'config': {
                'learning_rate': self.learning_rate,
                'gradient_accumulation_steps': self.gradient_accumulation_steps,
            }
        }, latest_path)
        print(f"üíæ Checkpoint saved (epoch {epoch+1}, step {self.global_step})")

        if (epoch + 1) % self.PERIODIC_SAVE_EVERY == 0:
            periodic_path = os.path.join(self.output_dir, f'checkpoint-epoch-{epoch+1}.pt')
            shutil.copy(latest_path, periodic_path)
            print(f"üìÅ Periodic checkpoint: {periodic_path}")

        if is_best:
            best_path = os.path.join(self.output_dir, 'best_model.pt')
            # Weights only: this file is meant for inference, not for resuming.
            torch.save({
                'epoch': epoch,
                'global_step': self.global_step,
                'model_state_dict': self.model.state_dict(),
                'val_loss': val_loss,
            }, best_path)
            print(f"üèÜ New best model! Val Loss: {val_loss:.4f}")
        else:
            print(f"üìä Val loss: {val_loss:.4f} (best: {self.best_val_loss:.4f})")

    def train(self):
        """Run ``num_epochs`` epochs starting at ``start_epoch``, checkpointing after each."""
        end_epoch = self.start_epoch + self.num_epochs

        print("\n" + "="*60)
        print("üöÄ TRAINING START")
        print("="*60)
        print(f"   From epoch: {self.start_epoch + 1}")
        print(f"   To epoch: {end_epoch}")
        print(f"   Total new epochs: {self.num_epochs}")
        print(f"   Learning rate: {self.optimizer.param_groups[0]['lr']:.2e}")
        print(f"   Batch size: {self.train_dataloader.batch_size}")
        print(f"   Gradient accumulation: {self.gradient_accumulation_steps}")
        print("="*60 + "\n")

        for epoch in range(self.start_epoch, end_epoch):
            print(f"\n{'='*50}")
            print(f"üìä Epoch {epoch+1}/{end_epoch}")
            print(f"   Global step: {self.global_step}")
            print(f"   LR: {self.optimizer.param_groups[0]['lr']:.2e}")
            print(f"{'='*50}")

            train_loss = self.train_epoch(epoch)
            print(f"üìà Train Loss: {train_loss:.5f}")

            val_loss = self.validate()
            print(f"üìâ Val Loss: {val_loss:.4f}")

            self.save_checkpoint(epoch, val_loss)

        print("\n" + "="*60)
        print("‚úÖ TRAINING COMPLETED")
        print("="*60)
        print(f"   Total epochs trained: {end_epoch}")
        print(f"   Final global step: {self.global_step}")
        print(f"   Best validation loss: {self.best_val_loss:.4f}")
        print("="*60)

# Cell-completion marker: printed once the Trainer class above has been defined.
print("‚úÖ Trainer class defined (with resume support)")

In [None]:
# ============================================================
# Cell 5: Setup & Train
# ============================================================
def setup_training(
    # Paths
    spm_model_path='thai_sp_30000.model',
    output_dir='./thai-trocr-custom-tokenizer',
    checkpoint_path=None,  # None = train from scratch, path = resume

    # Training params
    batch_size=8,
    learning_rate=5e-5,
    num_epochs=100,
    warmup_steps=500,
    gradient_accumulation_steps=2,
    max_length=128,

    # Data
    dataset_name="Pongthorn/HW-Sentence",
    train_split="train",
    val_split="test",
):
    """Build tokenizer, data loaders, model and Trainer, then run training.

    Args:
        spm_model_path: SentencePiece model file for ``ThaiTokenizerFixed``.
        output_dir: Directory the Trainer writes checkpoints to.
        checkpoint_path: Checkpoint to resume from; None or an empty string
            trains from the pretrained base model.
        batch_size: Batch size for both data loaders.
        learning_rate: Base LR for a fresh run. On resume the optimizer state
            stored in the checkpoint replaces it, so this value has no effect.
        num_epochs: Epochs to run; on resume, the number of additional epochs.
        warmup_steps: Linear-warmup steps for the LR schedule.
        gradient_accumulation_steps: Batches accumulated per optimizer step.
        max_length: Label sequence length and the model's generation max length.
        dataset_name: Dataset identifier passed to ``load_dataset``.
        train_split: Name of the training split.
        val_split: Name of the validation split.

    Returns:
        The ``Trainer`` instance after ``train()`` has completed.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is given but does not exist.
            This is raised before any data or model is loaded.

    Usage:
        # Train from scratch
        setup_training(num_epochs=100)

        # Resume training
        setup_training(
            checkpoint_path='./thai-trocr-custom-tokenizer/checkpoint-latest.pt',
            num_epochs=50,  # Additional epochs to train
        )
    """
    # Fail fast: without this check a mistyped path would only surface inside
    # torch.load, after the dataset and the pretrained model had been loaded.
    if checkpoint_path and not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    resuming = bool(checkpoint_path)

    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"üñ•Ô∏è  Device: {DEVICE}")

    # ========== Load Tokenizer ==========
    print("\nüìù Loading tokenizer...")
    tokenizer = ThaiTokenizerFixed(vocab_file=spm_model_path)
    print(f"   Vocab size: {tokenizer.vocab_size}")

    # ========== Load Image Processor ==========
    print("\nüñºÔ∏è  Loading image processor...")
    image_processor = AutoImageProcessor.from_pretrained(
        "microsoft/trocr-base-handwritten"
    )

    # ========== Load Dataset ==========
    print(f"\nüìö Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name)
    print(f"   Train samples: {len(dataset[train_split])}")
    print(f"   Val samples: {len(dataset[val_split])}")

    train_dataset = ThaiHandwritingDatasetFixed(
        dataset[train_split], tokenizer, image_processor, max_length
    )
    val_dataset = ThaiHandwritingDatasetFixed(
        dataset[val_split], tokenizer, image_processor, max_length
    )

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine just produces a warning.
    loader_options = dict(
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=(DEVICE == 'cuda'),
    )
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_options)

    print(f"   Train batches: {len(train_dataloader)}")
    print(f"   Val batches: {len(val_dataloader)}")

    # ========== Load Model ==========
    print("\nü§ñ Loading model...")

    # Both modes start from the same pretrained base model; on resume the
    # Trainer overwrites the weights from the checkpoint afterwards.
    if resuming:
        print("   Mode: RESUME TRAINING")
    else:
        print("   Mode: TRAIN FROM SCRATCH")
    model = VisionEncoderDecoderModel.from_pretrained(
        "microsoft/trocr-base-handwritten"
    )

    # Resize the decoder embeddings to the Thai vocabulary. This has to happen
    # before a checkpoint is loaded so the parameter shapes match.
    model.decoder.resize_token_embeddings(tokenizer.vocab_size)

    # Configure special tokens and generation defaults
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.vocab_size = tokenizer.vocab_size
    model.config.max_length = max_length
    model.config.early_stopping = True
    model.config.num_beams = 4

    print(f"   Decoder vocab size: {model.config.vocab_size}")

    # ========== Create Trainer ==========
    print("\nüéØ Initializing trainer...")

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        tokenizer=tokenizer,
        device=DEVICE,
        learning_rate=learning_rate,
        num_epochs=num_epochs,
        warmup_steps=warmup_steps,
        output_dir=output_dir,
        gradient_accumulation_steps=gradient_accumulation_steps,
        resume_from=checkpoint_path if resuming else None,
    )

    # ========== Start Training ==========
    trainer.train()

    return trainer


# ============================================================
# HOW TO USE
# ============================================================
"""
# ===== TRAIN FROM SCRATCH =====
trainer = setup_training(
    spm_model_path='thai_sp_30000.model',
    output_dir='./thai-trocr-custom-tokenizer',
    batch_size=8,
    learning_rate=5e-5,
    num_epochs=100,
    warmup_steps=500,
    gradient_accumulation_steps=2,
)


# ===== RESUME TRAINING (Continue with same settings) =====
trainer = setup_training(
    spm_model_path='thai_sp_30000.model',
    output_dir='./thai-trocr-custom-tokenizer',
    checkpoint_path='./thai-trocr-custom-tokenizer/checkpoint-latest.pt',
    batch_size=8,
    learning_rate=5e-5,  # Will be restored from checkpoint
    num_epochs=50,       # Additional epochs to train
    gradient_accumulation_steps=2,
)


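# ===== RESUME WITH NEW LEARNING RATE (Fine-tuning) =====
# Caveat: setup_training() calls trainer.train() before it returns, and on resume
# the optimizer state loaded from the checkpoint replaces the `learning_rate`
# argument, so a new LR cannot be applied through setup_training() alone.
trainer = setup_training(
    checkpoint_path='./thai-trocr-custom-tokenizer/checkpoint-latest.pt',
    num_epochs=50,
)
# Adjusting the LR at this point happens after those 50 epochs have already run:
# trainer._set_learning_rate(1e-6)
# trainer.train()  # Does not work as-is: trainer.start_epoch is unchanged, so this repeats the same epoch range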


# ===== QUICK RESUME (just run this) =====
if __name__ == "__main__":
    import sys
    
    checkpoint = './thai-trocr-custom-tokenizer/checkpoint-latest.pt'
    
    if os.path.exists(checkpoint):
        print("üìÇ Found checkpoint, resuming...")
        trainer = setup_training(
            checkpoint_path=checkpoint,
            num_epochs=50,
        )
    else:
        print("üÜï No checkpoint found, starting fresh...")
        trainer = setup_training(
            num_epochs=100,
        )
"""

# Final banner for this cell: prints copy-pasteable commands for starting a
# fresh run and for resuming from the latest checkpoint. Nothing is trained here.
print("\n" + "="*60)
print("‚úÖ ALL CELLS READY!")
print("="*60)
print("\nTo start training, run:")
print("  trainer = setup_training(num_epochs=100)")
print("\nTo resume training, run:")
print("  trainer = setup_training(")
print("      checkpoint_path='./thai-trocr-custom-tokenizer/checkpoint-latest.pt',")
print("      num_epochs=50,")
print("  )")
print("="*60)


In [None]:
# ============================================================
# Cell 6: Test Model
# ============================================================
import random

print("\n" + "="*80)
print("üß™ TESTING MODEL")
print("="*80)

# This cell needs a `trainer` object created by setup_training() in Cell 5.
# If the kernel has none, rebuild it from the latest checkpoint without
# training any further:
# trainer = setup_training(
#     checkpoint_path='./thai-trocr-custom-tokenizer/checkpoint-latest.pt',
#     num_epochs=0,  # no additional training
# )

# Which weights to evaluate: "latest", "best", or "current" (any other value
# keeps the weights currently held by the trainer).
USE_CHECKPOINT = "latest"  # "latest", "best", or "current"

model = trainer.model
tokenizer = trainer.tokenizer
DEVICE = trainer.device

# Checkpoint files are resolved inside the directory the trainer actually
# wrote to, so a non-default output_dir is honoured.
CHECKPOINT_FILES = {
    "latest": "checkpoint-latest.pt",
    "best": "best_model.pt",
}

if USE_CHECKPOINT in CHECKPOINT_FILES:
    ckpt_name = CHECKPOINT_FILES[USE_CHECKPOINT]
    ckpt_path = os.path.join(trainer.output_dir, ckpt_name)
    print(f"\nüì• Loading {ckpt_name}...")
    ckpt = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(ckpt['model_state_dict'])
    print(f"‚úÖ Loaded: Epoch {ckpt['epoch']+1}, Loss {ckpt['val_loss']:.4f}")
else:
    print("\n‚úÖ Using current model in trainer")
    ckpt = None

model.eval()

# Test function
def test_model_samples(model, dataloader, tokenizer, device, num_samples=10, title="Test"):
    """Test model with random samples"""
    model.eval()
    
    # Collect all samples
    all_samples = []
    for batch in dataloader:
        for i in range(batch['pixel_values'].size(0)):
            all_samples.append({
                'pixel_values': batch['pixel_values'][i],
                'text': batch['texts'][i]
            })
        if len(all_samples) >= 100:  # Enough for random sampling
            break
    
    # Random sample
    samples = random.sample(all_samples, min(num_samples, len(all_samples)))
    
    print(f"\n{'='*80}")
    print(f"üìä {title}")
    print(f"{'='*80}\n")
    
    results = []
    
    with torch.no_grad():
        for idx, sample in enumerate(samples):
            ground_truth = sample['text']
            pixel_values = sample['pixel_values'].unsqueeze(0).to(device)
            
            generated_ids = model.generate(
                pixel_values,
                max_length=150,
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=1.5,
                pad_token_id=tokenizer.pad_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            
            predicted = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            
            # Character accuracy
            correct = sum(1 for a, b in zip(predicted, ground_truth) if a == b)
            accuracy = correct / max(len(ground_truth), 1) * 100
            
            results.append({
                'ground_truth': ground_truth,
                'predicted': predicted,
                'accuracy': accuracy
            })
            
            print(f"Sample {idx+1}/{num_samples}")
            print(f"‚îú‚îÄ GT:   {ground_truth}")
            print(f"‚îú‚îÄ Pred: {predicted}")
            print(f"‚îú‚îÄ Acc:  {accuracy:.1f}%")
            print(f"‚îî‚îÄ {'‚úÖ MATCH' if ground_truth == predicted else '‚ùå'}")
            print()
    
    return results

# Run tests
# Number of random samples requested from each split.
NUM_TEST_SAMPLES = 10

train_results = test_model_samples(
    model, trainer.train_dataloader, tokenizer, DEVICE,
    num_samples=NUM_TEST_SAMPLES, title="üéØ TRAIN SET"
)

val_results = test_model_samples(
    model, trainer.val_dataloader, tokenizer, DEVICE,
    num_samples=NUM_TEST_SAMPLES, title="üéØ VALIDATION SET"
)

# Summary
print(f"\n{'='*80}")
print("üìà SUMMARY")
print(f"{'='*80}")

train_perfect = sum(1 for r in train_results if r['ground_truth'] == r['predicted'])
val_perfect = sum(1 for r in val_results if r['ground_truth'] == r['predicted'])
# max(..., 1) keeps an empty result list from raising ZeroDivisionError.
train_acc = sum(r['accuracy'] for r in train_results) / max(len(train_results), 1)
val_acc = sum(r['accuracy'] for r in val_results) / max(len(val_results), 1)

# Report against the number of samples actually evaluated, which can be
# smaller than NUM_TEST_SAMPLES when a split has fewer samples.
print(f"\nüéØ Train: {train_perfect}/{len(train_results)} perfect, {train_acc:.1f}% avg char acc")
print(f"üéØ Val:   {val_perfect}/{len(val_results)} perfect, {val_acc:.1f}% avg char acc")
print("\n‚úÖ Testing completed!")