# Administrative Form Detection using Donut

This notebook implements a classifier to detect standardized administrative forms in PDF documents using the Donut transformer model.

## Objectives:
1. Detect if a PDF document contains a specific standardized administrative form
2. Identify which page number the form appears on

## Approach:
- Use Donut (Document Understanding Transformer) for OCR-free document classification
- Process PDFs page by page to identify the administrative form

## Future Options:
- Consider LayoutLMv3 for combining visual and text features if higher accuracy is needed

## 1. Setup and Dependencies

In [None]:
# Install required packages
!pip install transformers torch torchvision pdf2image pillow numpy pandas tqdm scikit-learn

In [None]:
import os
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from pdf2image import convert_from_path
from tqdm import tqdm
import json
from typing import List, Tuple, Dict, Optional

from transformers import (
    DonutProcessor,
    VisionEncoderDecoderModel,
    AutoConfig,
    AutoTokenizer,
    AutoImageProcessor
)
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

# Pick the compute device: CUDA when PyTorch can use it, CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Data Exploration

In [None]:
# Define paths
# NOTE: BASE_PATH is an absolute, machine-specific location; edit it when running elsewhere.
BASE_PATH = Path('/Users/admin-tascott/Documents/GitHub/chehalis')
EXAMPLE_FORMS_PATH = BASE_PATH / 'data' / 'raw' / '_exampleforms'
NON_EXAMPLES_PATH = BASE_PATH / 'data' / 'raw' / '_nonexamples'

# Check if paths exist
print(f"Example forms path exists: {EXAMPLE_FORMS_PATH.exists()}")
print(f"Non-examples path exists: {NON_EXAMPLES_PATH.exists()}")

# Bind both lists up front so later cells can reference them even when a
# folder is missing (previously they were undefined in that case).
example_files = []
non_example_files = []

# List example files
if EXAMPLE_FORMS_PATH.exists():
    example_files = list(EXAMPLE_FORMS_PATH.glob('*.pdf'))
    print(f"\nFound {len(example_files)} example form files")
    for f in example_files[:5]:  # Show first 5
        print(f"  - {f.name}")

# List non-example files
if NON_EXAMPLES_PATH.exists():
    non_example_files = list(NON_EXAMPLES_PATH.glob('*.pdf'))
    print(f"\nFound {len(non_example_files)} non-example files")
    for f in non_example_files[:5]:  # Show first 5
        print(f"  - {f.name}")

## 3. PDF Processing Functions

In [None]:
def pdf_to_images(pdf_path: Path, dpi: int = 200) -> List[Image.Image]:
    """
    Render a PDF into PIL Images (one per page)
    
    Args:
        pdf_path: Location of the PDF to render
        dpi: Resolution handed to the converter
    
    Returns:
        The rendered pages, or an empty list if the conversion raised
    """
    try:
        return convert_from_path(pdf_path, dpi=dpi)
    except Exception as err:
        # Best effort: report the failure and let the caller see "no pages"
        print(f"Error converting {pdf_path}: {err}")
    return []

def preprocess_image_for_donut(image: Image.Image, size: Tuple[int, int] = (1280, 960)) -> Image.Image:
    """
    Preprocess image for Donut model
    
    The input image is never modified: the work is done on an RGB copy.
    
    Args:
        image: PIL Image
        size: Maximum size (width, height); the image is shrunk to fit inside
            this box while keeping its aspect ratio
    
    Returns:
        A new RGB PIL Image no larger than `size`
    """
    # convert() already returns a new image; for RGB input take an explicit
    # copy, because thumbnail() below resizes in place and would otherwise
    # shrink the caller's image.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    else:
        image = image.copy()
    
    # Resize while maintaining aspect ratio
    image.thumbnail(size, Image.Resampling.LANCZOS)
    
    return image

## 4. Dataset Creation

In [None]:
class FormDataset(Dataset):
    """
    Torch dataset over page records for form / non-form classification

    Each record is a dict holding an 'image' and a 'label', and optionally
    'file_path' and 'page_num'.
    """
    def __init__(self, data_list: List[Dict], processor):
        self.processor = processor
        self.data = data_list
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        record = self.data[idx]
        page_image, target = record['image'], record['label']
        
        # Run the processor on the page and drop singleton dimensions
        encoded = self.processor(page_image, return_tensors="pt")
        pixels = encoded.pixel_values.squeeze()
        
        # Missing provenance falls back to '' / -1
        metadata = {
            'file_path': record.get('file_path', ''),
            'page_num': record.get('page_num', -1)
        }
        return {'pixel_values': pixels, 'labels': target, 'metadata': metadata}

In [None]:
def _list_pdfs(pdf_dir: Path, max_files: Optional[int]) -> List[Path]:
    """Return up to `max_files` PDFs from `pdf_dir` in sorted (reproducible) order."""
    # glob() order depends on the filesystem; sorting makes the subset chosen
    # by the limit stable across runs and machines.
    return sorted(pdf_dir.glob('*.pdf'))[:max_files]


def _pages_from_pdfs(pdf_files: List[Path], label: int, desc: str) -> List[Dict]:
    """Render every page of each PDF and tag each page record with `label`."""
    records = []
    for pdf_path in tqdm(pdf_files, desc=desc):
        for page_num, image in enumerate(pdf_to_images(pdf_path), start=1):
            records.append({
                'image': preprocess_image_for_donut(image),
                'label': label,
                'file_path': str(pdf_path),
                'page_num': page_num  # 1-indexed
            })
    return records


def create_dataset_from_pdfs(example_path: Path, non_example_path: Optional[Path] = None, 
                           max_samples_per_class: Optional[int] = None) -> List[Dict]:
    """
    Create a page-level dataset from two folders of PDF files
    
    Args:
        example_path: Folder of PDFs whose pages are labelled 1 (form)
        non_example_path: Optional folder of PDFs whose pages are labelled 0
        max_samples_per_class: Maximum number of PDFs (not pages) taken from
            each folder; None means no limit
    
    Returns:
        List of dictionaries with 'image', 'label', 'file_path', 'page_num'.
        Folders that do not exist contribute nothing.
    """
    dataset = []
    
    # Process positive examples (forms)
    # NOTE(review): every page of a positive PDF is labelled 1; if example PDFs
    # also contain non-form pages this adds label noise - confirm with the data.
    if example_path.exists():
        pdf_files = _list_pdfs(example_path, max_samples_per_class)
        print(f"Processing {len(pdf_files)} positive example PDFs...")
        dataset.extend(_pages_from_pdfs(pdf_files, label=1, desc="Positive examples"))
    
    # Process negative examples (non-forms)
    if non_example_path is not None and non_example_path.exists():
        pdf_files = _list_pdfs(non_example_path, max_samples_per_class)
        print(f"\nProcessing {len(pdf_files)} negative example PDFs...")
        dataset.extend(_pages_from_pdfs(pdf_files, label=0, desc="Negative examples"))
    
    print(f"\nTotal dataset size: {len(dataset)} pages")
    print(f"Positive examples (forms): {sum(1 for d in dataset if d['label'] == 1)}")
    print(f"Negative examples (non-forms): {sum(1 for d in dataset if d['label'] == 0)}")
    
    return dataset

## 5. Model Setup

In [None]:
# Option 1: start from the pre-trained Donut checkpoint and
# fine-tune it for binary classification (form vs non-form)

MODEL_NAME = "naver-clova-ix/donut-base"

# Processor first, then the encoder-decoder model placed directly on the
# selected device
processor = DonutProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device)

print(f"Model loaded: {MODEL_NAME}")
print(f"Model parameters: {sum(param.numel() for param in model.parameters()) / 1e6:.1f}M")

In [None]:
# Classification is expressed through Donut's decoder: the model is asked to
# generate task-specific tokens, so those tokens are registered here.

# Register the task tokens with the tokenizer
processor.tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<admin_form>", "<not_admin_form>", "<classification>"]}
)

# Grow the decoder embedding table to cover the enlarged vocabulary
model.decoder.resize_token_embeddings(len(processor.tokenizer))

# Padding id and the token every decoded sequence starts from
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(["<classification>"])[0]

## 6. Training Functions

In [None]:
def prepare_labels_for_training(batch, processor):
    """
    Prepare decoder labels for training
    
    Each binary label becomes the target text "<classification> <admin_form>"
    or "<classification> <not_admin_form>" followed by the tokenizer's EOS
    token, tokenized and padded/truncated to 10 positions. Padding positions
    are replaced with -100 so the loss ignores them.
    
    Args:
        batch: Mapping with a 'labels' iterable of 0/1 values
        processor: Object exposing `.tokenizer`
    
    Returns:
        Tensor of token ids, one row of length 10 per label
    """
    tokenizer = processor.tokenizer
    pad_id = tokenizer.pad_token_id
    # Terminate targets with EOS so the decoder learns when to stop; without
    # it, masking the padding would leave generation after the class token
    # untrained.
    eos = tokenizer.eos_token or ""
    
    labels_list = []
    for label in batch['labels']:
        if label == 1:
            text = "<classification> <admin_form>" + eos
        else:
            text = "<classification> <not_admin_form>" + eos
        
        # Tokenize the label
        encoding = tokenizer(
            text,
            add_special_tokens=False,
            max_length=10,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        token_ids = encoding.input_ids.squeeze(0).clone()
        
        # -100 is the ignore index of the cross-entropy loss: padding must not
        # contribute, or the loss is dominated by trivially predictable pads.
        if pad_id is not None:
            token_ids[token_ids == pad_id] = -100
        labels_list.append(token_ids)
    
    return torch.stack(labels_list)

In [None]:
def train_epoch(model, dataloader, optimizer, processor, device):
    """
    Run one optimisation pass over `dataloader`

    Returns:
        Sum of the per-batch losses divided by the number of batches
    """
    model.train()
    batch_losses = []
    
    for batch in tqdm(dataloader, desc="Training"):
        inputs = batch['pixel_values'].to(device)
        targets = prepare_labels_for_training(batch, processor).to(device)
        
        # Forward: the model computes the loss itself when labels are given
        loss = model(pixel_values=inputs, labels=targets).loss
        
        # Backward and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        batch_losses.append(loss.item())
    
    return sum(batch_losses) / len(dataloader)

def evaluate(model, dataloader, processor, device):
    """
    Score the model on `dataloader` by generating a token sequence per page

    A page counts as predicted-positive when the decoded text contains
    "<admin_form>".

    Returns:
        Dict with 'accuracy', 'precision', 'recall', 'f1', and the raw
        'predictions' and 'true_labels' lists
    """
    model.eval()
    true_labels, predictions = [], []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            pixel_values = batch['pixel_values'].to(device)
            true_labels.extend(batch['labels'].numpy())
            
            # Greedy decoding of a short sequence
            generation_kwargs = dict(
                decoder_start_token_id=model.config.decoder_start_token_id,
                max_length=10,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )
            outputs = model.generate(pixel_values, **generation_kwargs)
            
            # Text -> binary label
            decoded = processor.batch_decode(outputs.sequences)
            predictions.extend(1 if "<admin_form>" in text else 0 for text in decoded)
    
    # Aggregate metrics over the whole dataloader
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predictions, average='binary')
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'predictions': predictions,
        'true_labels': true_labels
    }

## 7. Create and Split Dataset

In [None]:
# Build the page-level dataset; the per-class PDF cap keeps the first run small
dataset_list = create_dataset_from_pdfs(EXAMPLE_FORMS_PATH, NON_EXAMPLES_PATH, max_samples_per_class=10)

# 80/20 train/validation split, stratified on the page label with a fixed seed
train_data, val_data = train_test_split(
    dataset_list,
    test_size=0.2,
    random_state=42,
    stratify=[sample['label'] for sample in dataset_list],
)

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

In [None]:
# Training side: wrap the split and shuffle it when batching
train_dataset = FormDataset(train_data, processor)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=4)

# Validation side: same batch size, original order
val_dataset = FormDataset(val_data, processor)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=4)

## 8. Training

In [None]:
# Hyperparameters
num_epochs = 5
learning_rate = 5e-5

# AdamW over every model parameter
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Bookkeeping: per-epoch metrics and the best validation accuracy seen so far
training_history = []
best_val_accuracy = 0

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    
    # Optimise on the training split
    train_loss = train_epoch(model, train_loader, optimizer, processor, device)
    print(f"Training loss: {train_loss:.4f}")
    
    # Score on the validation split
    val_metrics = evaluate(model, val_loader, processor, device)
    print(
        f"Validation accuracy: {val_metrics['accuracy']:.4f}",
        f"Validation precision: {val_metrics['precision']:.4f}",
        f"Validation recall: {val_metrics['recall']:.4f}",
        f"Validation F1: {val_metrics['f1']:.4f}",
        sep="\n",
    )
    
    # Checkpoint only when validation accuracy strictly improves
    if val_metrics['accuracy'] > best_val_accuracy:
        best_val_accuracy = val_metrics['accuracy']
        torch.save(model.state_dict(), 'best_form_classifier.pth')
        print("Saved best model!")
    
    training_history.append({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'val_accuracy': val_metrics['accuracy'],
        'val_precision': val_metrics['precision'],
        'val_recall': val_metrics['recall'],
        'val_f1': val_metrics['f1']
    })

## 9. Inference Functions

In [None]:
def detect_form_in_pdf(pdf_path: Path, model, processor, device, batch_size: int = 8) -> Dict:
    """
    Detect if a PDF contains the administrative form and on which page(s)
    
    Args:
        pdf_path: PDF file to inspect
        model: Model whose `generate` output is decoded into label text
        processor: Processor used for image preparation and decoding
        device: Device the pixel tensors are moved to
        batch_size: Number of pages sent through the model at once
    
    Returns:
        Dictionary with:
        - 'contains_form': boolean
        - 'form_pages': list of 1-indexed page numbers where form is detected
        - 'confidence_scores': list of (page_num, score) tuples; the score is
          a placeholder (1.0 when the form token was generated, else 0.0),
          not a calibrated probability
        - 'total_pages': number of pages rendered (0 if the PDF produced none)
    """
    model.eval()
    
    # Convert PDF to images
    images = pdf_to_images(pdf_path)
    if not images:
        # Same keys as the normal return so callers can read 'total_pages'
        # without a KeyError when the PDF could not be rendered.
        return {
            'contains_form': False,
            'form_pages': [],
            'confidence_scores': [],
            'total_pages': 0
        }
    
    form_pages = []
    confidence_scores = []
    
    # Process in batches
    for i in range(0, len(images), batch_size):
        batch_images = images[i:i+batch_size]
        batch_processed = [preprocess_image_for_donut(img) for img in batch_images]
        
        # Process images; squeeze(0) drops only the per-image batch dimension
        pixel_values = torch.stack([
            processor(img, return_tensors="pt").pixel_values.squeeze(0)
            for img in batch_processed
        ]).to(device)
        
        with torch.no_grad():
            # Generate predictions (token scores are not requested because
            # the confidence below does not use them)
            outputs = model.generate(
                pixel_values,
                decoder_start_token_id=model.config.decoder_start_token_id,
                max_length=10,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True
            )
            
        # Decode predictions
        predictions = processor.batch_decode(outputs.sequences)
        
        # Process predictions; i is the 0-based offset of this batch
        for page_num, pred_text in enumerate(predictions, start=i + 1):
            if "<admin_form>" in pred_text:
                form_pages.append(page_num)
                # Simple confidence based on presence of token
                confidence_scores.append((page_num, 1.0))
            else:
                confidence_scores.append((page_num, 0.0))
    
    return {
        'contains_form': len(form_pages) > 0,
        'form_pages': form_pages,
        'confidence_scores': confidence_scores,
        'total_pages': len(images)
    }

In [None]:
def process_document_corpus(pdf_folder: Path, model, processor, device, 
                          output_file: str = 'form_detection_results.csv') -> pd.DataFrame:
    """
    Process a corpus of PDF documents and save results
    
    Args:
        pdf_folder: Folder whose '*.pdf' files are processed (non-recursive)
        model, processor, device: Passed through to detect_form_in_pdf
        output_file: CSV path the results table is written to
    
    Returns:
        DataFrame with one row per PDF and the columns 'filename', 'filepath',
        'contains_form', 'form_pages', 'num_form_pages', 'total_pages', plus
        'error' when at least one document failed. A failed document has
        'contains_form' set to None.
    """
    results = []
    # Sorted so the CSV row order is reproducible across filesystems
    pdf_files = sorted(pdf_folder.glob('*.pdf'))
    
    print(f"Processing {len(pdf_files)} PDF files...")
    
    for pdf_path in tqdm(pdf_files):
        try:
            detection_result = detect_form_in_pdf(pdf_path, model, processor, device)
            
            results.append({
                'filename': pdf_path.name,
                'filepath': str(pdf_path),
                'contains_form': detection_result['contains_form'],
                'form_pages': ','.join(map(str, detection_result['form_pages'])),
                'num_form_pages': len(detection_result['form_pages']),
                'total_pages': detection_result['total_pages']
            })
            
        except Exception as e:
            # Per-document boundary: record the failure and keep going
            print(f"Error processing {pdf_path.name}: {e}")
            results.append({
                'filename': pdf_path.name,
                'filepath': str(pdf_path),
                'contains_form': None,
                'form_pages': '',
                'num_form_pages': 0,
                'total_pages': 0,
                'error': str(e)
            })
    
    # Create DataFrame and save. Columns are given explicitly so an empty
    # folder still yields a well-formed (header-only) table.
    columns = ['filename', 'filepath', 'contains_form', 'form_pages', 'num_form_pages', 'total_pages']
    if any('error' in row for row in results):
        columns.append('error')
    df_results = pd.DataFrame(results, columns=columns)
    df_results.to_csv(output_file, index=False)
    
    # Count outcomes from the raw rows: 'contains_form' is True/False/None,
    # and boolean operators on a column containing None raise TypeError.
    num_with_form = sum(1 for row in results if row['contains_form'] is True)
    num_without_form = sum(1 for row in results if row['contains_form'] is False)
    num_errors = sum(1 for row in results if row['contains_form'] is None)
    
    # Print summary
    print(f"\nResults saved to {output_file}")
    print(f"Total documents processed: {len(df_results)}")
    print(f"Documents with forms: {num_with_form}")
    print(f"Documents without forms: {num_without_form}")
    print(f"Processing errors: {num_errors}")
    
    return df_results

## 10. Test on Example Files

In [None]:
# Smoke-test the detector on the first example PDF, if any were found
if example_files:
    test_pdf = example_files[0]
    print(f"Testing on: {test_pdf.name}")
    
    result = detect_form_in_pdf(test_pdf, model, processor, device)
    
    print(
        f"\nResults:",
        f"Contains form: {result['contains_form']}",
        f"Form pages: {result['form_pages']}",
        f"Total pages: {result['total_pages']}",
        sep="\n",
    )

## 11. Process Full Corpus (Example)

In [None]:
# Example: Process a folder of documents
# Uncomment and modify the path to process your full corpus

# CORPUS_PATH = BASE_PATH / 'data' / 'raw' / 'contracts'  # Adjust to your corpus location
# results_df = process_document_corpus(CORPUS_PATH, model, processor, device)

# # Display some results
# print("\nSample results:")
# print(results_df.head(10))

## 12. Save Model and Configuration

In [None]:
# Save the fine-tuned model
save_directory = "./form_classifier_model"
os.makedirs(save_directory, exist_ok=True)

# The training loop checkpoints the best epoch to 'best_form_classifier.pth'.
# Restore it so the exported weights match the 'best_validation_accuracy'
# recorded below rather than the last epoch. The accuracy guard avoids
# loading a stale file from an earlier run when this run saved nothing.
best_checkpoint = 'best_form_classifier.pth'
if best_val_accuracy > 0 and os.path.exists(best_checkpoint):
    model.load_state_dict(torch.load(best_checkpoint, map_location=device))

# Save model and processor
model.save_pretrained(save_directory)
processor.save_pretrained(save_directory)

# Save training configuration
config = {
    'model_name': MODEL_NAME,
    'task': 'administrative_form_detection',
    'num_epochs': num_epochs,
    'learning_rate': learning_rate,
    'best_validation_accuracy': best_val_accuracy,
    'training_history': training_history
}

with open(os.path.join(save_directory, 'training_config.json'), 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=2)

print(f"Model saved to {save_directory}")

## Usage Instructions

### To use the trained model on new documents:

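```python
# Load the saved model
import torch
from pathlib import Path
from transformers import DonutProcessor, VisionEncoderDecoderModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The model must live on the same device the page tensors are moved to
model = VisionEncoderDecoderModel.from_pretrained('./form_classifier_model').to(device)
processor = DonutProcessor.from_pretrained('./form_classifier_model')

# Detect form in a PDF
# (detect_form_in_pdf and the PDF helpers from sections 3 and 9 must be defined in the session)
pdf_path = Path('path/to/document.pdf')
result = detect_form_in_pdf(pdf_path, model, processor, device)
print(f"Form found on pages: {result['form_pages']}")
```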

### Future Improvements:

1. **Use LayoutLMv3** if you need to combine visual and text features for higher accuracy
2. **Data Augmentation**: Add rotations, noise, etc. to training images
3. **Confidence Calibration**: Implement proper probability scores for predictions
4. **Multi-GPU Training**: For processing the full 190k document corpus
5. **Active Learning**: Use model uncertainty to identify which documents to manually label next