# Form Detection using Donut Model

This notebook uses the Donut (Document Understanding Transformer) model to classify PDF pages as forms or non-forms.
Donut is an OCR-free document understanding model that can directly process document images.

In [26]:
# Set environment variable to avoid tokenizer warnings
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import pandas as pd
import fitz  # PyMuPDF for PDF handling
import torch
import numpy as np
from tqdm import tqdm
import json

# Paths
pdf_dir = "../../data/raw/_contracts/"
formpage_dir = "../../data/raw/_formpage/"
example_forms_dir = "../../data/raw/_exampleforms/"
nonexample_forms_dir = "../../data/raw/_nonexamples/"

# Create formpage directory if it doesn't exist
os.makedirs(formpage_dir, exist_ok=True)

# Get list of PDF files
pdf_files = [f for f in os.listdir(pdf_dir) if f.endswith('.pdf')]
print(f"Found {len(pdf_files)} PDF files to process")

Found 193450 PDF files to process


In [28]:
# Load Donut model for document classification
print("Loading Donut model...")

# Using the document classification version of Donut
model_name = "naver-clova-ix/donut-base-finetuned-rvlcdip"
processor = DonutProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
model = model.to(device)
model.eval()

print(f"Model loaded successfully on {device}!")

Loading Donut model...
Model loaded successfully on mps!


In [30]:
# Simple classification function using Donut
def classify_page_with_donut(image, processor, model, device):
    """
    Classify a document page as form or non-form using Donut
    
    Returns:
        dict with 'is_form', 'confidence', and 'raw_output'
    """
    # Prepare the prompt for classification
    # We'll ask Donut to classify the document
    task_prompt = "<s_rvlcdip><s_class>"
    
    # Process image
    pixel_values = processor(image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    
    # Prepare decoder input
    decoder_input_ids = processor.tokenizer(task_prompt, 
                                           add_special_tokens=False, 
                                           return_tensors="pt").input_ids
    decoder_input_ids = decoder_input_ids.to(device)
    
    # Generate classification
    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
    
    # Decode the output
    prediction = processor.batch_decode(outputs.sequences)[0]
    prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    prediction = prediction.replace(task_prompt, "").strip()
    
    # Parse the classification result
    # Donut RVL-CDIP outputs document types like "form", "letter", "memo", etc.
    # We'll consider it a form if it contains "form" in the classification
    is_form = "form" in prediction.lower()
    
    # For confidence, we can use the model's confidence scores
    # For now, we'll use a simple binary confidence
    confidence = 0.9 if is_form else 0.1
    
    return {
        'is_form': is_form,
        'confidence': confidence,
        'classification': prediction
    }

In [32]:
# Alternative: Use Donut with custom prompting for better form detection
def classify_form_with_prompt(image, processor, model, device):
    """
    Use Donut with a custom prompt to detect forms
    This approach asks a direct question about the document
    """
    # Create a prompt that asks if this is a form
    prompt = "<s_docvqa><s_question>Is this document a form with fields to fill out?</s_question><s_answer>"
    
    # Process image
    pixel_values = processor(image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    
    # Prepare decoder input
    decoder_input_ids = processor.tokenizer(prompt, 
                                           add_special_tokens=False, 
                                           return_tensors="pt").input_ids
    decoder_input_ids = decoder_input_ids.to(device)
    
    # Generate answer
    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=20,  # Short answer expected
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
    
    # Decode the answer
    answer = processor.batch_decode(outputs.sequences)[0]
    answer = answer.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    answer = answer.replace(prompt, "").strip().lower()
    
    # Interpret the answer
    is_form = "yes" in answer or "form" in answer
    confidence = 0.95 if "yes" in answer else 0.7 if "form" in answer else 0.2
    
    return {
        'is_form': is_form,
        'confidence': confidence,
        'answer': answer
    }

In [34]:
# Test on a single PDF
if pdf_files:
    test_file = pdf_files[0]
    print(f"Testing on: {test_file}")
    
    pdf_path = os.path.join(pdf_dir, test_file)
    pdf_document = fitz.open(pdf_path)
    
    # Test first page
    page = pdf_document[0]
    pix = page.get_pixmap()
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    
    # Test classification
    result = classify_page_with_donut(img, processor, model, device)
    print(f"\nClassification result:")
    print(f"  Is form: {result['is_form']}")
    print(f"  Confidence: {result['confidence']:.2f}")
    print(f"  Classification: {result['classification']}")
    
    pdf_document.close()

Testing on: 25581-000.pdf


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



Classification result:
  Is form: True
  Confidence: 0.90
  Classification: <form/></s_class>


In [15]:
# Process PDFs to find form pages
def process_pdfs_with_donut(pdf_files, pdf_dir, formpage_dir, processor, model, device,
                           confidence_threshold=0.5, max_files=None):
    """
    Process PDFs to find and extract form pages using Donut
    """
    if max_files:
        pdf_files = pdf_files[:max_files]
    
    results = []
    
    for pdf_file in tqdm(pdf_files, desc="Processing PDFs"):
        try:
            pdf_path = os.path.join(pdf_dir, pdf_file)
            pdf_document = fitz.open(pdf_path)
            
            best_form_page = None
            best_confidence = 0
            page_results = []
            
            # Check each page
            for page_num in range(len(pdf_document)):
                page = pdf_document[page_num]
                pix = page.get_pixmap()
                img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                
                # Classify page
                detection = classify_page_with_donut(img, processor, model, device)
                
                page_results.append({
                    'page': page_num + 1,
                    'is_form': detection['is_form'],
                    'confidence': detection['confidence'],
                    'classification': detection.get('classification', '')
                })
                
                # Track best form page
                if detection['is_form'] and detection['confidence'] > best_confidence:
                    best_confidence = detection['confidence']
                    best_form_page = page_num
                    
                    # Stop early if very confident
                    if best_confidence > 0.9:
                        break
            
            # Store results
            result = {
                'file': pdf_file,
                'total_pages': len(pdf_document),
                'has_form': best_form_page is not None,
                'form_page': best_form_page + 1 if best_form_page is not None else None,
                'confidence': best_confidence,
                'page_details': page_results
            }
            results.append(result)
            
            # Extract best form page if found
            if best_form_page is not None and best_confidence > confidence_threshold:
                output_pdf = fitz.open()
                output_pdf.insert_pdf(pdf_document, from_page=best_form_page, to_page=best_form_page)
                
                output_path = os.path.join(formpage_dir, pdf_file)
                output_pdf.save(output_path)
                output_pdf.close()
            
            pdf_document.close()
            
        except Exception as e:
            print(f"\nError processing {pdf_file}: {str(e)}")
            results.append({
                'file': pdf_file,
                'error': str(e)
            })
    
    return results

In [None]:
# Process a small batch
print("Processing first 10 PDFs...")
results = process_pdfs_with_donut(
    pdf_files[:10], 
    pdf_dir, 
    formpage_dir, 
    processor, 
    model, 
    device,
    confidence_threshold=0.5
)

# Convert to DataFrame for analysis
results_df = pd.DataFrame(results)
print(f"\nProcessed {len(results_df)} PDFs")
print(f"Found {results_df['has_form'].sum()} PDFs with forms")

# Show some results
if len(results_df) > 0:
    print("\nSample results:")
    for _, row in results_df.head().iterrows():
        if 'error' not in row:
            status = "Form found" if row['has_form'] else "No form"
            print(f"{row['file']}: {status} (confidence: {row['confidence']:.2f})")

Processing first 10 PDFs...


Processing PDFs:   0%|                                   | 0/10 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFOR

In [None]:
# Test on example forms to verify detection
if os.path.exists(example_forms_dir):
    print("\n=== Testing on known example forms ===")
    example_files = [f for f in os.listdir(example_forms_dir) if f.endswith('.pdf')][:5]
    
    for example_file in example_files:
        pdf_path = os.path.join(example_forms_dir, example_file)
        
        try:
            pdf = fitz.open(pdf_path)
            page = pdf[0]
            pix = page.get_pixmap()
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            pdf.close()
            
            # Test detection
            result = classify_page_with_donut(img, processor, model, device)
            
            print(f"\n{example_file}:")
            print(f"  Detected as form: {result['is_form']}")
            print(f"  Classification: {result['classification']}")
            
        except Exception as e:
            print(f"Error testing {example_file}: {e}")

In [None]:
# Save results
output_path = '../../data/intermediate_products/donut_form_detection_results.csv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
results_df.to_csv(output_path, index=False)
print(f"Results saved to: {output_path}")

In [None]:
# Full processing - uncomment to process all PDFs
"""
print("\n=== PROCESSING ALL PDFs WITH DONUT ===")
print(f"Total PDFs to process: {len(pdf_files)}")
print("This will take a while...")

all_results = process_pdfs_with_donut(
    pdf_files, 
    pdf_dir, 
    formpage_dir, 
    processor, 
    model, 
    device,
    confidence_threshold=0.5
)

# Save full results
all_results_df = pd.DataFrame(all_results)
all_results_df.to_csv('../../data/intermediate_products/donut_form_detection_all.csv', index=False)

print(f"\nProcessing complete!")
print(f"Total processed: {len(all_results_df)}")
print(f"Forms found: {all_results_df['has_form'].sum()}")
"""

## Notes on Donut for Form Detection

### Advantages:
1. **No OCR needed** - Donut processes images directly
2. **Document understanding** - Trained specifically on documents
3. **Simple to use** - Just asks "what type of document is this?"
4. **Fast** - Single pass through the model

### Limitations:
1. The base model may need fine-tuning for specific form types
2. Classification categories are pre-defined (but include "form")
3. Less flexible than example-based matching

### Next Steps:
1. Fine-tune Donut on your specific forms if needed
2. Use the extracted form pages for downstream processing
3. Consider using Donut for form field extraction later