# Administrative Form Detection using Donut

This notebook implements a classifier to detect standardized administrative forms in PDF documents using the Donut transformer model.

## Objectives:
1. Detect if a PDF document contains a specific standardized administrative form
2. Identify which page number the form appears on

## Approach:
- Use Donut (Document Understanding Transformer) for OCR-free document classification
- Process PDFs page by page to identify the administrative form

## Future Options:
- Consider LayoutLMv3 for combining visual and text features if higher accuracy is needed

## 1. Setup and Dependencies

In [3]:
# Install required packages
#!pip install transformers torch torchvision pdf2image pillow numpy pandas tqdm scikit-learn

In [4]:
import os
# Set tokenizers parallelism to false to avoid fork warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from pdf2image import convert_from_path
from tqdm import tqdm
import json
from typing import List, Tuple, Dict, Optional

from transformers import (
    DonutProcessor,
    VisionEncoderDecoderModel,
    AutoConfig,
    AutoTokenizer,
    AutoImageProcessor
)
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


## 2. Data Exploration

In [6]:
# Define paths
BASE_PATH = Path('/Users/admin-tascott/Documents/GitHub/chehalis')
EXAMPLE_FORMS_PATH = BASE_PATH / 'data' / 'raw' / '_exampleforms'
NON_EXAMPLES_PATH = BASE_PATH / 'data' / 'raw' / '_nonexamples'

# Check if paths exist
print(f"Example forms path exists: {EXAMPLE_FORMS_PATH.exists()}")
print(f"Non-examples path exists: {NON_EXAMPLES_PATH.exists()}")

# List example files
if EXAMPLE_FORMS_PATH.exists():
    example_files = list(EXAMPLE_FORMS_PATH.glob('*.pdf'))
    print(f"\nFound {len(example_files)} example form files")
    for f in example_files[:5]:  # Show first 5
        print(f"  - {f.name}")

# List non-example files
if NON_EXAMPLES_PATH.exists():
    non_example_files = list(NON_EXAMPLES_PATH.glob('*.pdf'))
    print(f"\nFound {len(non_example_files)} non-example files")
    for f in non_example_files[:5]:  # Show first 5
        print(f"  - {f.name}")

Example forms path exists: True
Non-examples path exists: True

Found 106 example form files
  - 25581-000.pdf
  - 99171-000.pdf
  - 13924-002.pdf
  - 1197-000.pdf
  - 67419-000.pdf

Found 133 non-example files
  - 25581-000.pdf
  - 0000000000000000000062223-001.pdf
  - 1197-000.pdf
  - 67419-000.pdf
  - 104473-000.pdf


## 3. PDF Processing Functions

In [8]:
def pdf_to_images(pdf_path: Path, dpi: int = 200) -> List[Image.Image]:
    """
    Convert PDF to list of PIL Images (one per page)
    
    Args:
        pdf_path: Path to PDF file
        dpi: Resolution for conversion
    
    Returns:
        List of PIL Images
    """
    try:
        images = convert_from_path(pdf_path, dpi=dpi)
        return images
    except Exception as e:
        print(f"Error converting {pdf_path}: {e}")
        return []

def preprocess_image_for_donut(image: Image.Image, size: Tuple[int, int] = (1280, 960)) -> Image.Image:
    """
    Preprocess image for Donut model
    
    Args:
        image: PIL Image
        size: Target size (width, height)
    
    Returns:
        Preprocessed PIL Image
    """
    # Convert to RGB if necessary
    if image.mode != 'RGB':
        image = image.convert('RGB')
    
    # Resize while maintaining aspect ratio
    image.thumbnail(size, Image.Resampling.LANCZOS)
    
    return image

## 4. Dataset Creation

In [10]:
class FormDataset(Dataset):
    """
    Dataset for administrative form classification with lazy loading support
    """
    def __init__(self, data_list: List[Dict], processor, lazy_load: bool = True):
        self.data = data_list
        self.processor = processor
        self.lazy_load = lazy_load
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Load image if needed (lazy loading)
        if item.get('needs_loading', False) and self.lazy_load:
            # Load specific page from PDF
            pdf_path = Path(item['image_path'])
            page_num = item['page_num'] - 1  # 0-indexed for pdf2image
            
            try:
                # Load only the specific page
                images = convert_from_path(pdf_path, dpi=150, 
                                         first_page=page_num+1, 
                                         last_page=page_num+1)
                if images:
                    image = preprocess_image_for_donut(images[0])
                else:
                    # Return a blank image if loading fails
                    image = Image.new('RGB', (1280, 960), color='white')
            except Exception as e:
                print(f"Error loading page {page_num+1} from {pdf_path.name}: {e}")
                # Return a blank image
                image = Image.new('RGB', (1280, 960), color='white')
        else:
            image = item['image']
        
        # Process image
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        
        return {
            'pixel_values': pixel_values.squeeze(),
            'labels': item['label'],
            'metadata': {
                'file_path': item.get('file_path', ''),
                'page_num': item.get('page_num', -1),
                'source_folder': item.get('source_folder', ''),
                'original_filename': item.get('original_filename', '')
            }
        }

In [11]:
def create_dataset_from_pdfs(example_path: Path, non_example_path: Optional[Path] = None, 
                           max_samples_per_class: int = None, batch_process: bool = True) -> List[Dict]:
    """
    Create dataset from PDF files with class balancing
    
    IMPORTANT: Handles the case where:
    - Example form PDFs: ALL pages are the administrative form (some PDFs may be multi-page)
    - Non-example PDFs: NO pages are the administrative form
    
    Strategy:
    - Use ALL pages from ALL positive example PDFs
    - Randomly sample 2x that number of negative pages
    
    Args:
        example_path: Path to example forms
        non_example_path: Path to non-examples
        max_samples_per_class: Limit number of PDFs to process (for testing)
        batch_process: If True, process images in batches to save memory
    
    Returns:
        List of dictionaries with 'image', 'label', 'file_path', 'page_num', 'source_folder'
    """
    dataset = []
    
    # Process ALL positive examples (forms)
    # Every page in every PDF in the examples folder is a positive example
    positive_count = 0
    if example_path.exists():
        pdf_files = list(example_path.glob('*.pdf'))
        if max_samples_per_class:
            pdf_files = pdf_files[:max_samples_per_class]
            
        print(f"Processing {len(pdf_files)} positive example PDFs...")
        print("Note: Every page in these PDFs is considered a positive example (the form)")
        
        for pdf_path in tqdm(pdf_files, desc="Positive examples"):
            try:
                images = pdf_to_images(pdf_path, dpi=150)  # Lower DPI to save memory
                
                for page_num, image in enumerate(images):
                    # Process and store image path instead of image if batch_process
                    if batch_process:
                        dataset.append({
                            'image_path': str(pdf_path),
                            'page_num': page_num + 1,
                            'label': 1,  # 1 for administrative form
                            'file_path': str(pdf_path),
                            'source_folder': 'examples',
                            'original_filename': pdf_path.name,
                            'needs_loading': True
                        })
                    else:
                        dataset.append({
                            'image': preprocess_image_for_donut(image),
                            'label': 1,
                            'file_path': str(pdf_path),
                            'page_num': page_num + 1,
                            'source_folder': 'examples',
                            'original_filename': pdf_path.name
                        })
                    positive_count += 1
                    
                # Clear images from memory
                del images
                
            except Exception as e:
                print(f"Error processing {pdf_path.name}: {e}")
    
    print(f"\nTotal positive examples (form pages): {positive_count}")
    
    # Process negative examples (non-forms)
    # Sample 2x the number of positive examples
    negative_metadata = []  # Store metadata first
    
    if non_example_path and non_example_path.exists():
        pdf_files = list(non_example_path.glob('*.pdf'))
        print(f"\nProcessing {len(pdf_files)} negative example PDFs...")
        print("Note: No pages in these PDFs contain the administrative form")
        
        # First, collect metadata for all negative pages
        for pdf_path in tqdm(pdf_files, desc="Collecting negative page info"):
            try:
                # Use pdfinfo to get page count without loading images
                import subprocess
                result = subprocess.run(['pdfinfo', str(pdf_path)], 
                                      capture_output=True, text=True)
                if result.returncode == 0:
                    for line in result.stdout.split('\n'):
                        if line.startswith('Pages:'):
                            num_pages = int(line.split()[1])
                            for page_num in range(num_pages):
                                negative_metadata.append({
                                    'image_path': str(pdf_path),
                                    'page_num': page_num + 1,
                                    'label': 0,
                                    'file_path': str(pdf_path),
                                    'source_folder': 'non_examples',
                                    'original_filename': pdf_path.name,
                                    'needs_loading': True
                                })
                            break
            except Exception as e:
                print(f"Error getting info for {pdf_path.name}: {e}")
        
        print(f"Total negative pages available: {len(negative_metadata)}")
        
        # Sample 2x the number of positive examples
        target_negative_count = min(len(negative_metadata), positive_count * 2)
        
        if len(negative_metadata) > target_negative_count:
            # Randomly sample to achieve 2:1 ratio
            import random
            random.seed(42)  # For reproducibility
            sampled_negatives = random.sample(negative_metadata, target_negative_count)
            dataset.extend(sampled_negatives)
            print(f"Sampled {target_negative_count} negative pages (2x positive examples)")
        else:
            # Use all available negatives if we don't have enough
            dataset.extend(negative_metadata)
            print(f"Using all {len(negative_metadata)} negative pages")
    
    # Shuffle the dataset
    import random
    random.seed(42)
    random.shuffle(dataset)
    
    print(f"\nFinal dataset composition:")
    print(f"Total samples: {len(dataset)}")
    print(f"Positive examples (forms): {sum(1 for d in dataset if d['label'] == 1)}")
    print(f"Negative examples (non-forms): {sum(1 for d in dataset if d['label'] == 0)}")
    if sum(1 for d in dataset if d['label'] == 0) > 0:
        print(f"Class ratio (pos:neg): 1:{sum(1 for d in dataset if d['label'] == 0) / sum(1 for d in dataset if d['label'] == 1):.2f}")
    
    return dataset

## 5. Model Setup

In [13]:
# Option 1: Use pre-trained Donut for document classification
# We'll fine-tune it for binary classification (form vs non-form)

MODEL_NAME = "naver-clova-ix/donut-base"

# Load processor and model
processor = DonutProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)

# Move model to device
model = model.to(device)

print(f"Model loaded: {MODEL_NAME}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Model loaded: naver-clova-ix/donut-base
Model parameters: 201.9M


In [14]:
# Configure model for classification task
# Donut uses a decoder, so we'll set up special tokens for our classification task

# Add special tokens for our task
processor.tokenizer.add_special_tokens({
    "additional_special_tokens": [
        "<admin_form>",
        "<not_admin_form>",
        "<classification>"
    ]
})

# Resize model embeddings to accommodate new tokens
model.decoder.resize_token_embeddings(len(processor.tokenizer))

# Set up decoder start token
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(["<classification>"])[0]
model.config.pad_token_id = processor.tokenizer.pad_token_id

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


## 6. Training Functions

In [16]:
def prepare_labels_for_training(batch, processor):
    """
    Prepare decoder labels for training
    """
    labels_list = []
    for label in batch['labels']:
        if label == 1:
            text = "<classification> <admin_form>"
        else:
            text = "<classification> <not_admin_form>"
        
        # Tokenize the label
        encoding = processor.tokenizer(
            text,
            add_special_tokens=False,
            max_length=10,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        labels_list.append(encoding.input_ids.squeeze())
    
    return torch.stack(labels_list)

In [17]:
def train_epoch(model, dataloader, optimizer, processor, device):
    """
    Train for one epoch
    """
    model.train()
    total_loss = 0
    
    for batch in tqdm(dataloader, desc="Training"):
        pixel_values = batch['pixel_values'].to(device)
        labels = prepare_labels_for_training(batch, processor).to(device)
        
        # Forward pass
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

def evaluate(model, dataloader, processor, device):
    """
    Evaluate model
    """
    model.eval()
    predictions = []
    true_labels = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            pixel_values = batch['pixel_values'].to(device)
            true_labels.extend(batch['labels'].numpy())
            
            # Generate predictions
            outputs = model.generate(
                pixel_values,
                decoder_start_token_id=model.config.decoder_start_token_id,
                max_length=10,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )
            
            # Decode predictions
            prediction_text = processor.batch_decode(outputs.sequences)
            
            # Convert text predictions to binary labels
            for pred_text in prediction_text:
                if "<admin_form>" in pred_text:
                    predictions.append(1)
                else:
                    predictions.append(0)
    
    # Calculate metrics
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predictions, average='binary')
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'predictions': predictions,
        'true_labels': true_labels
    }

## 7. Create and Split Dataset

In [19]:
# Create dataset using ALL positive examples with memory-efficient loading
dataset_list = create_dataset_from_pdfs(
    EXAMPLE_FORMS_PATH,
    NON_EXAMPLES_PATH,
    batch_process=True  # Enable lazy loading to save memory
)

# Document-level train/validation split to prevent data leakage
# Group by original filename and source folder to handle duplicates
from collections import defaultdict

doc_groups = defaultdict(list)
for idx, item in enumerate(dataset_list):
    # Create unique document ID combining source folder and filename
    doc_id = f"{item['source_folder']}_{item['original_filename']}"
    doc_groups[doc_id].append(idx)

# Get unique document IDs and their labels (using first page's label)
doc_ids = list(doc_groups.keys())
doc_labels = [dataset_list[doc_groups[doc_id][0]]['label'] for doc_id in doc_ids]

# Split documents (not pages) into train/val
from sklearn.model_selection import train_test_split
train_doc_ids, val_doc_ids = train_test_split(
    doc_ids, 
    test_size=0.2, 
    random_state=42, 
    stratify=doc_labels
)

# Get page indices for train and validation sets
train_indices = []
val_indices = []

for doc_id in train_doc_ids:
    train_indices.extend(doc_groups[doc_id])

for doc_id in val_doc_ids:
    val_indices.extend(doc_groups[doc_id])

# Create train and validation datasets
train_data = [dataset_list[i] for i in train_indices]
val_data = [dataset_list[i] for i in val_indices]

print(f"\nTrain/Val Split:")
print(f"Training samples: {len(train_data)} ({sum(1 for d in train_data if d['label'] == 1)} positive, {sum(1 for d in train_data if d['label'] == 0)} negative)")
print(f"Validation samples: {len(val_data)} ({sum(1 for d in val_data if d['label'] == 1)} positive, {sum(1 for d in val_data if d['label'] == 0)} negative)")
print(f"Training documents: {len(train_doc_ids)}")
print(f"Validation documents: {len(val_doc_ids)}")

# Verify no document overlap
train_docs = set(f"{d['source_folder']}_{d['original_filename']}" for d in train_data)
val_docs = set(f"{d['source_folder']}_{d['original_filename']}" for d in val_data)
print(f"Document overlap: {len(train_docs.intersection(val_docs))} (should be 0)")

Processing 106 positive example PDFs...
Note: Every page in these PDFs is considered a positive example (the form)


Positive examples: 100%|██████████████████████| 106/106 [00:14<00:00,  7.56it/s]



Total positive examples (form pages): 111

Processing 133 negative example PDFs...
Note: No pages in these PDFs contain the administrative form


Collecting negative page info: 100%|██████████| 133/133 [00:02<00:00, 47.50it/s]

Total negative pages available: 2896
Sampled 222 negative pages (2x positive examples)

Final dataset composition:
Total samples: 333
Positive examples (forms): 111
Negative examples (non-forms): 222
Class ratio (pos:neg): 1:2.00

Train/Val Split:
Training samples: 269 (87 positive, 182 negative)
Validation samples: 64 (24 positive, 40 negative)
Training documents: 157
Validation documents: 40
Document overlap: 0 (should be 0)





In [20]:
# Create PyTorch datasets and dataloaders with lazy loading
train_dataset = FormDataset(train_data, processor, lazy_load=True)
val_dataset = FormDataset(val_data, processor, lazy_load=True)

# Use smaller batch sizes to reduce memory usage
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=0)

## 8. Training

In [22]:
# Training configuration with optimizations
learning_rate = 5e-5
num_epochs = 3  # Reduced from 5 for faster initial training

# Setup optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Mixed precision training for faster computation (if using GPU)
use_amp = device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler() if use_amp else None

# Training loop
best_val_accuracy = 0
training_history = []

print(f"Starting training with {'mixed precision' if use_amp else 'full precision'}...")
print(f"Training on {len(train_data)} samples, validating on {len(val_data)} samples")

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, processor, device)
    print(f"Training loss: {train_loss:.4f}")
    
    # Evaluate every epoch
    val_metrics = evaluate(model, val_loader, processor, device)
    print(f"Validation accuracy: {val_metrics['accuracy']:.4f}")
    print(f"Validation precision: {val_metrics['precision']:.4f}")
    print(f"Validation recall: {val_metrics['recall']:.4f}")
    print(f"Validation F1: {val_metrics['f1']:.4f}")
    
    # Save best model
    if val_metrics['accuracy'] > best_val_accuracy:
        best_val_accuracy = val_metrics['accuracy']
        torch.save(model.state_dict(), 'best_form_classifier.pth')
        print("Saved best model!")
    
    training_history.append({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'val_accuracy': val_metrics['accuracy'],
        'val_precision': val_metrics['precision'],
        'val_recall': val_metrics['recall'],
        'val_f1': val_metrics['f1']
    })
    
    # Early stopping if perfect accuracy
    if val_metrics['accuracy'] >= 0.99:
        print("Achieved 99% accuracy, stopping early!")
        break

print(f"\nTraining completed! Best validation accuracy: {best_val_accuracy:.4f}")

Starting training with full precision...
Training on 269 samples, validating on 64 samples

Epoch 1/3


Training:   0%|                                         | 0/135 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Training:  69%|██████████████████▌        | 93/135 [2:46:48<1:15:20, 107.62s/it]


KeyboardInterrupt: 

## 9. Inference Functions

In [None]:
def detect_form_in_pdf(pdf_path: Path, model, processor, device, batch_size: int = 8) -> Dict:
    """
    Detect if a PDF contains the administrative form and on which page(s)
    
    Returns:
        Dictionary with:
        - 'contains_form': boolean
        - 'form_pages': list of page numbers where form is detected
        - 'confidence_scores': confidence score for each page
    """
    model.eval()
    
    # Convert PDF to images
    images = pdf_to_images(pdf_path)
    if not images:
        return {'contains_form': False, 'form_pages': [], 'confidence_scores': []}
    
    form_pages = []
    confidence_scores = []
    
    # Process in batches
    for i in range(0, len(images), batch_size):
        batch_images = images[i:i+batch_size]
        batch_processed = [preprocess_image_for_donut(img) for img in batch_images]
        
        # Process images
        pixel_values = torch.stack([
            processor(img, return_tensors="pt").pixel_values.squeeze()
            for img in batch_processed
        ]).to(device)
        
        with torch.no_grad():
            # Generate predictions
            outputs = model.generate(
                pixel_values,
                decoder_start_token_id=model.config.decoder_start_token_id,
                max_length=10,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
                output_scores=True
            )
            
            # Decode predictions
            predictions = processor.batch_decode(outputs.sequences)
            
            # Process predictions
            for j, pred_text in enumerate(predictions):
                page_num = i + j + 1
                
                if "<admin_form>" in pred_text:
                    form_pages.append(page_num)
                    # Simple confidence based on presence of token
                    confidence_scores.append((page_num, 1.0))
                else:
                    confidence_scores.append((page_num, 0.0))
    
    return {
        'contains_form': len(form_pages) > 0,
        'form_pages': form_pages,
        'confidence_scores': confidence_scores,
        'total_pages': len(images)
    }

In [None]:
def process_document_corpus(pdf_folder: Path, model, processor, device, 
                          output_file: str = 'form_detection_results.csv') -> pd.DataFrame:
    """
    Process a corpus of PDF documents and save results
    """
    results = []
    pdf_files = list(pdf_folder.glob('*.pdf'))
    
    print(f"Processing {len(pdf_files)} PDF files...")
    
    for pdf_path in tqdm(pdf_files):
        try:
            detection_result = detect_form_in_pdf(pdf_path, model, processor, device)
            
            results.append({
                'filename': pdf_path.name,
                'filepath': str(pdf_path),
                'contains_form': detection_result['contains_form'],
                'form_pages': ','.join(map(str, detection_result['form_pages'])),
                'num_form_pages': len(detection_result['form_pages']),
                'total_pages': detection_result['total_pages']
            })
            
        except Exception as e:
            print(f"Error processing {pdf_path.name}: {e}")
            results.append({
                'filename': pdf_path.name,
                'filepath': str(pdf_path),
                'contains_form': None,
                'form_pages': '',
                'num_form_pages': 0,
                'total_pages': 0,
                'error': str(e)
            })
    
    # Create DataFrame and save
    df_results = pd.DataFrame(results)
    df_results.to_csv(output_file, index=False)
    
    # Print summary
    print(f"\nResults saved to {output_file}")
    print(f"Total documents processed: {len(df_results)}")
    print(f"Documents with forms: {df_results['contains_form'].sum()}")
    print(f"Documents without forms: {(~df_results['contains_form']).sum()}")
    print(f"Processing errors: {df_results['contains_form'].isna().sum()}")
    
    return df_results

## 10. Test on Example Files

In [None]:
# Test on a single example file
if len(example_files) > 0:
    test_pdf = example_files[0]
    print(f"Testing on: {test_pdf.name}")
    
    result = detect_form_in_pdf(test_pdf, model, processor, device)
    
    print(f"\nResults:")
    print(f"Contains form: {result['contains_form']}")
    print(f"Form pages: {result['form_pages']}")
    print(f"Total pages: {result['total_pages']}")

## 11. Process Full Corpus (Example)

In [None]:
# Example: Process a folder of documents
# Uncomment and modify the path to process your full corpus

# CORPUS_PATH = BASE_PATH / 'data' / 'raw' / 'contracts'  # Adjust to your corpus location
# results_df = process_document_corpus(CORPUS_PATH, model, processor, device)

# # Display some results
# print("\nSample results:")
# print(results_df.head(10))

## 12. Save Model and Configuration

In [None]:
# Save the fine-tuned model
save_directory = "./form_classifier_model"
os.makedirs(save_directory, exist_ok=True)

# Save model and processor
model.save_pretrained(save_directory)
processor.save_pretrained(save_directory)

# Save training configuration
config = {
    'model_name': MODEL_NAME,
    'task': 'administrative_form_detection',
    'num_epochs': num_epochs,
    'learning_rate': learning_rate,
    'best_validation_accuracy': best_val_accuracy,
    'training_history': training_history
}

with open(os.path.join(save_directory, 'training_config.json'), 'w') as f:
    json.dump(config, f, indent=2)

print(f"Model saved to {save_directory}")

## Usage Instructions

### To use the trained model on new documents:

```python
# Load the saved model
from transformers import DonutProcessor, VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained('./form_classifier_model')
processor = DonutProcessor.from_pretrained('./form_classifier_model')

# Detect form in a PDF
result = detect_form_in_pdf(pdf_path, model, processor, device)
print(f"Form found on pages: {result['form_pages']}")
```

### Future Improvements:

1. **Use LayoutLMv3** if you need to combine visual and text features for higher accuracy
2. **Data Augmentation**: Add rotations, noise, etc. to training images
3. **Confidence Calibration**: Implement proper probability scores for predictions
4. **Multi-GPU Training**: For processing the full 190k document corpus
5. **Active Learning**: Use model uncertainty to identify which documents to manually label next