# Specific Form Type Detection using Donut

This notebook uses Donut to detect a specific type of form (as represented by examples in `_exampleforms`),
not just any form. We'll use Donut's ability to extract features and compare documents.

In [29]:
# Set environment variables
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers import ViTModel, ViTImageProcessor
from PIL import Image
import pandas as pd
import fitz  # PyMuPDF
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
import transformers
transformers.logging.set_verbosity_error()

# Paths
pdf_dir = "../../data/raw/_contracts/"
formpage_dir = "../../data/raw/_formpage/"
example_forms_dir = "../../data/raw/_exampleforms/"
nonexample_forms_dir = "../../data/raw/_nonexamples/"

os.makedirs(formpage_dir, exist_ok=True)

pdf_files = [f for f in os.listdir(pdf_dir) if f.endswith('.pdf')]
print(f"Found {len(pdf_files)} PDF files to process")

Found 193450 PDF files to process


In [31]:
# Load Donut model
print("Loading Donut model for feature extraction...")

# Load the base Donut model (not fine-tuned for any specific task)
model_name = "naver-clova-ix/donut-base"
processor = DonutProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
model = model.to(device)
model.eval()

print(f"Model loaded on {device}")

Loading Donut model for feature extraction...
Model loaded on mps


In [33]:
# Alternative: Load vision encoder directly for better feature extraction
def extract_donut_features_v2(image, processor, model, device):
    """
    Alternative feature extraction that might work better
    """
    # Process image
    pixel_values = processor(image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    
    # Get encoder outputs
    with torch.no_grad():
        # Try different pooling strategies
        encoder_outputs = model.encoder(pixel_values)
        
        # Option 1: Global average pooling
        features_avg = encoder_outputs.last_hidden_state.mean(dim=[1, 2])  # Average over spatial dimensions
        
        # Option 2: Use pooler output if available
        if hasattr(encoder_outputs, 'pooler_output') and encoder_outputs.pooler_output is not None:
            features_pooler = encoder_outputs.pooler_output
        else:
            features_pooler = features_avg
        
        # Use average pooling
        features = features_avg.cpu().numpy()
        
        # L2 normalize
        features = features / (np.linalg.norm(features, axis=1, keepdims=True) + 1e-6)
    
    return features

# Keep original function but add v2
def extract_donut_features(image, processor, model, device):
    """
    Extract visual features from document using Donut's vision encoder
    """
    # Process image
    pixel_values = processor(image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    
    # Extract features from the encoder (Swin Transformer)
    with torch.no_grad():
        # Get encoder outputs
        encoder_outputs = model.encoder(pixel_values)
        # Use the last hidden states and pool them
        features = encoder_outputs.last_hidden_state.mean(dim=1)
        features = features.cpu().numpy()
        # Normalize
        features = features / (np.linalg.norm(features, axis=1, keepdims=True) + 1e-6)
    
    return features

In [35]:
# Function to extract visual features using Donut's encoder
def extract_donut_features(image, processor, model, device):
    """
    Extract visual features from document using Donut's vision encoder
    """
    # Process image
    pixel_values = processor(image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    
    # Extract features from the encoder (Swin Transformer)
    with torch.no_grad():
        # Get encoder outputs
        encoder_outputs = model.encoder(pixel_values)
        
        # Get the shape of the hidden states
        hidden_states = encoder_outputs.last_hidden_state
        
        # Check dimensions before pooling
        if len(hidden_states.shape) == 4:  # [batch, height, width, channels]
            # Mean pool over spatial dimensions (height and width)
            features = hidden_states.mean(dim=[1, 2])
        elif len(hidden_states.shape) == 3:  # [batch, sequence, features]
            # Mean pool over sequence dimension
            features = hidden_states.mean(dim=1)
        else:
            # Fallback: flatten and take mean
            features = hidden_states.flatten(start_dim=1).mean(dim=1, keepdim=True)
        
        features = features.cpu().numpy()
        
        # L2 normalize
        if len(features.shape) == 1:
            features = features.reshape(1, -1)
        norm = np.linalg.norm(features, axis=1, keepdims=True)
        features = features / (norm + 1e-6)
    
    return features

In [37]:
# Debug: Check encoder output shape
print("Checking encoder output shape...")

# Test with a dummy image
test_img = Image.new('RGB', (224, 224), color='white')
pixel_values = processor(test_img, return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device)

with torch.no_grad():
    encoder_outputs = model.encoder(pixel_values)
    print(f"Encoder output shape: {encoder_outputs.last_hidden_state.shape}")
    print(f"Number of dimensions: {len(encoder_outputs.last_hidden_state.shape)}")
    
# Test feature extraction
test_features = extract_donut_features(test_img, processor, model, device)
print(f"Extracted features shape: {test_features.shape}")

Checking encoder output shape...
Encoder output shape: torch.Size([1, 4800, 1024])
Number of dimensions: 3
Extracted features shape: (1, 1024)


In [None]:
# This cell has been removed - features are re-extracted in the cell above

In [41]:
# Load example forms and extract their features
print("Loading example forms (specific type we're looking for)...")

example_features = []
example_names = []

if os.path.exists(example_forms_dir):
    example_files = [f for f in os.listdir(example_forms_dir) if f.endswith('.pdf')]
    print(f"Found {len(example_files)} example forms of the specific type")
    
    for example_file in tqdm(example_files, desc="Processing examples"):
        try:
            pdf_path = os.path.join(example_forms_dir, example_file)
            pdf = fitz.open(pdf_path)
            page = pdf[0]  # First page only
            pix = page.get_pixmap()
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            pdf.close()
            
            # Extract features
            features = extract_donut_features(img, processor, model, device)
            example_features.append(features)
            example_names.append(example_file)
            
        except Exception as e:
            print(f"Error loading {example_file}: {e}")
    
    print(f"Successfully loaded {len(example_features)} example features")
else:
    print(f"No example forms found at {example_forms_dir}")

Loading example forms (specific type we're looking for)...
Found 74 example forms of the specific type


Processing examples: 100%|██████████████████████| 74/74 [01:51<00:00,  1.51s/it]

Successfully loaded 74 example features





In [45]:
# Load non-examples (different types of documents)
print("\nLoading non-examples (other document types)...")

nonexample_features = []

if os.path.exists(nonexample_forms_dir):
    nonexample_files = [f for f in os.listdir(nonexample_forms_dir) if f.endswith('.pdf')]
    print(f"Found {len(nonexample_files)} non-example documents")
    
    # Process first page of each non-example
    for nonexample_file in tqdm(nonexample_files, desc="Processing non-examples"):  # Limit to 10
        try:
            pdf_path = os.path.join(nonexample_forms_dir, nonexample_file)
            pdf = fitz.open(pdf_path)
            
            # Just use first page of non-examples
            page = pdf[0]
            pix = page.get_pixmap()
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            
            features = extract_donut_features(img, processor, model, device)
            nonexample_features.append(features)
            
            pdf.close()
            
        except Exception as e:
            print(f"Error loading {nonexample_file}: {e}")
    
    print(f"Loaded {len(nonexample_features)} non-example features")


Loading non-examples (other document types)...
Found 11 non-example documents


Processing non-examples: 100%|██████████████████| 11/11 [00:16<00:00,  1.46s/it]

Loaded 11 non-example features





In [47]:
# Analyze similarity between examples
if len(example_features) > 1:
    print("\nAnalyzing similarity between example forms...")
    
    similarities = []
    for i in range(len(example_features)):
        for j in range(i+1, len(example_features)):
            sim = cosine_similarity(example_features[i], example_features[j])[0][0]
            similarities.append(sim)
    
    avg_similarity = np.mean(similarities)
    min_similarity = np.min(similarities)
    max_similarity = np.max(similarities)
    
    print(f"Example-to-example similarity:")
    print(f"  Average: {avg_similarity:.3f}")
    print(f"  Min: {min_similarity:.3f}")
    print(f"  Max: {max_similarity:.3f}")
    
    # Suggest threshold - use a lower threshold since we're looking for specific forms
    suggested_threshold = min_similarity * 0.8  # 80% of minimum similarity
    print(f"\nSuggested threshold: {suggested_threshold:.3f}")
else:
    suggested_threshold = 0.7  # Default threshold


Analyzing similarity between example forms...
Example-to-example similarity:
  Average: 0.990
  Min: 0.940
  Max: 1.000

Suggested threshold: 0.752


In [49]:
# Detection function for specific form type
def detect_specific_form_type(image, processor, model, device, 
                             example_features, nonexample_features=None,
                             similarity_threshold=0.7, negative_threshold=0.8):
    """
    Detect if a page is the specific form type represented by examples
    """
    # Extract features from the page
    page_features = extract_donut_features(image, processor, model, device)
    
    # Compare to positive examples
    positive_similarities = []
    for ex_feat in example_features:
        sim = cosine_similarity(page_features, ex_feat)[0][0]
        positive_similarities.append(sim)
    
    max_positive_sim = max(positive_similarities) if positive_similarities else 0
    avg_positive_sim = np.mean(positive_similarities) if positive_similarities else 0
    
    # Compare to negative examples if provided
    if nonexample_features:
        negative_similarities = []
        for neg_feat in nonexample_features:
            sim = cosine_similarity(page_features, neg_feat)[0][0]
            negative_similarities.append(sim)
        max_negative_sim = max(negative_similarities) if negative_similarities else 0
        avg_negative_sim = np.mean(negative_similarities) if negative_similarities else 0
    else:
        max_negative_sim = 0
        avg_negative_sim = 0
    
    # Decision logic: must be similar to examples
    is_specific_form = max_positive_sim > similarity_threshold
    
    # Additional check: if too similar to non-examples relative to examples, reject
    if nonexample_features and max_negative_sim > 0:
        # Only reject if negative similarity is very close to positive similarity
        similarity_ratio = max_positive_sim / (max_negative_sim + 1e-6)
        if similarity_ratio < 1.2:  # Less than 20% more similar to positives than negatives
            is_specific_form = False
    
    # Confidence based on:
    # 1. How similar to positive examples
    # 2. How much more similar to positives than negatives
    if nonexample_features and len(negative_similarities) > 0:
        # Confidence is high when positive sim is high AND negative sim is low
        confidence = max_positive_sim * (1 - max_negative_sim)
    else:
        confidence = max_positive_sim
    
    return {
        'is_specific_form': is_specific_form,
        'confidence': confidence,
        'max_similarity_to_examples': max_positive_sim,
        'avg_similarity_to_examples': avg_positive_sim,
        'max_similarity_to_nonexamples': max_negative_sim
    }

In [51]:
# Debug: Check if features are being extracted correctly
print("=== DEBUGGING FEATURE EXTRACTION ===\n")

# Test with first example
if len(example_features) > 0 and len(example_names) > 0:
    test_file = example_names[0]
    test_feat_stored = example_features[0]
    
    # Re-load and re-extract features for the same file
    pdf_path = os.path.join(example_forms_dir, test_file)
    pdf = fitz.open(pdf_path)
    page = pdf[0]
    pix = page.get_pixmap()
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    pdf.close()
    
    # Extract features again
    test_feat_new = extract_donut_features(img, processor, model, device)
    
    # Compare
    print(f"Testing with: {test_file}")
    print(f"Stored feature shape: {test_feat_stored.shape}")
    print(f"New feature shape: {test_feat_new.shape}")
    
    # Self-similarity (should be 1.0)
    self_similarity = cosine_similarity(test_feat_new, test_feat_new)[0][0]
    print(f"Self-similarity (should be 1.0): {self_similarity:.6f}")
    
    # Similarity between stored and new extraction (should be 1.0)
    stored_vs_new = cosine_similarity(test_feat_stored, test_feat_new)[0][0]
    print(f"Stored vs new extraction (should be 1.0): {stored_vs_new:.6f}")
    
    # Check against all stored features
    print(f"\nSimilarity to all stored examples:")
    for i, ex_feat in enumerate(example_features[:5]):
        sim = cosine_similarity(test_feat_new, ex_feat)[0][0]
        print(f"  Example {i}: {sim:.6f}")
    
    # Check feature statistics
    print(f"\nFeature statistics:")
    print(f"  Stored - Min: {test_feat_stored.min():.4f}, Max: {test_feat_stored.max():.4f}, Mean: {test_feat_stored.mean():.4f}")
    print(f"  New    - Min: {test_feat_new.min():.4f}, Max: {test_feat_new.max():.4f}, Mean: {test_feat_new.mean():.4f}")
    
    # Check normalization
    print(f"\nNormalization check:")
    print(f"  Stored norm: {np.linalg.norm(test_feat_stored):.6f}")
    print(f"  New norm: {np.linalg.norm(test_feat_new):.6f}")

=== DEBUGGING FEATURE EXTRACTION ===

Testing with: 25581-000.pdf
Stored feature shape: (1, 1024)
New feature shape: (1, 1024)
Self-similarity (should be 1.0): 1.000000
Stored vs new extraction (should be 1.0): 1.000000

Similarity to all stored examples:
  Example 0: 1.000000
  Example 1: 0.998013
  Example 2: 0.996304
  Example 3: 0.999286
  Example 4: 0.992019

Feature statistics:
  Stored - Min: -0.9341, Max: 0.1705, Mean: -0.0010
  New    - Min: -0.9341, Max: 0.1705, Mean: -0.0010

Normalization check:
  Stored norm: 1.000000
  New norm: 1.000000


In [61]:
# Debug the detection function itself
print("\n=== DEBUGGING DETECTION FUNCTION ===\n")

if len(example_features) > 0:
    # Load first example
    test_file = example_names[0]
    pdf_path = os.path.join(example_forms_dir, test_file)
    pdf = fitz.open(pdf_path)
    page = pdf[0]
    pix = page.get_pixmap()
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    pdf.close()
    
    # Manual detection steps
    print(f"Testing detection on: {test_file}")
    print(f"Threshold: {suggested_threshold}")
    
    # Extract features
    page_features = extract_donut_features(img, processor, model, device)
    print(f"\\nPage features shape: {page_features.shape}")
    print(f"Page features norm: {np.linalg.norm(page_features):.6f}")
    
    # Compare to each example
    print(f"\\nComparing to {len(example_features)} examples:")
    similarities = []
    for i, ex_feat in enumerate(example_features):
        sim = cosine_similarity(page_features, ex_feat)[0][0]
        similarities.append(sim)
        if i < 5:  # Show first 5
            print(f"  Example {i}: {sim:.6f} {'*' if i == 0 else ''}")  # Mark if it's comparing to itself
    
    max_sim = max(similarities)
    print(f"\\nMax similarity: {max_sim:.6f}")
    print(f"Is above threshold ({suggested_threshold:.3f})? {max_sim > suggested_threshold}")
    
    # The first example should have similarity ~1.0 to itself
    print(f"\\nExpected ~1.0 for first example, got: {similarities[0]:.6f}")


=== DEBUGGING DETECTION FUNCTION ===

Testing detection on: 25581-000.pdf
Threshold: 0.7523807525634766
\nPage features shape: (1, 1024)
Page features norm: 1.000000
\nComparing to 74 examples:
  Example 0: 1.000000 *
  Example 1: 0.998013 
  Example 2: 0.996304 
  Example 3: 0.999286 
  Example 4: 0.992019 
\nMax similarity: 1.000000
Is above threshold (0.752)? True
\nExpected ~1.0 for first example, got: 1.000000


In [63]:
# Debug: Check similarity to non-examples
print("\n=== CHECKING SIMILARITY TO NON-EXAMPLES ===\n")

if len(example_features) > 0 and len(nonexample_features) > 0:
    # Test first example against non-examples
    test_file = example_names[0]
    pdf_path = os.path.join(example_forms_dir, test_file)
    pdf = fitz.open(pdf_path)
    page = pdf[0]
    pix = page.get_pixmap()
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    pdf.close()
    
    # Extract features
    test_features = extract_donut_features(img, processor, model, device)
    
    print(f"Testing {test_file} against non-examples:")
    
    # Check against non-examples
    for i, neg_feat in enumerate(nonexample_features):
        sim = cosine_similarity(test_features, neg_feat)[0][0]
        print(f"  Non-example {i}: {sim:.3f}")
    
    # Also check an actual non-example against examples
    print("\n\nChecking first non-example against examples:")
    for i, ex_feat in enumerate(example_features[:5]):
        sim = cosine_similarity(nonexample_features[0], ex_feat)[0][0]
        print(f"  Example {i}: {sim:.3f}")


=== CHECKING SIMILARITY TO NON-EXAMPLES ===

Testing 25581-000.pdf against non-examples:
  Non-example 0: 0.942
  Non-example 1: 0.972
  Non-example 2: 0.987
  Non-example 3: 0.982
  Non-example 4: 0.969
  Non-example 5: 0.967
  Non-example 6: 0.970
  Non-example 7: 0.971
  Non-example 8: 0.968
  Non-example 9: 0.970
  Non-example 10: 0.969


Checking first non-example against examples:
  Example 0: 0.942
  Example 1: 0.937
  Example 2: 0.942
  Example 3: 0.944
  Example 4: 0.935


In [65]:
# Test on example forms themselves
print("\nTesting detection on example forms (should all be detected)...")

threshold = suggested_threshold if 'suggested_threshold' in locals() else 0.5

for i, (example_file, example_feat) in enumerate(zip(example_names[:5], example_features[:5])):
    # Load the image
    pdf_path = os.path.join(example_forms_dir, example_file)
    pdf = fitz.open(pdf_path)
    page = pdf[0]
    pix = page.get_pixmap()
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    pdf.close()
    
    # Test detection
    result = detect_specific_form_type(
        img, processor, model, device,
        example_features, nonexample_features,
        similarity_threshold=threshold
    )
    
    print(f"\n{example_file}:")
    print(f"  Detected: {result['is_specific_form']}")
    print(f"  Max similarity: {result['max_similarity_to_examples']:.3f}")
    print(f"  Confidence: {result['confidence']:.3f}")


Testing detection on example forms (should all be detected)...

25581-000.pdf:
  Detected: False
  Max similarity: 1.000
  Confidence: 0.013

99171-000.pdf:
  Detected: False
  Max similarity: 1.000
  Confidence: 0.015

13924-002.pdf:
  Detected: False
  Max similarity: 1.000
  Confidence: 0.018

1197-000.pdf:
  Detected: False
  Max similarity: 1.000
  Confidence: 0.014

67419-000.pdf:
  Detected: False
  Max similarity: 1.000
  Confidence: 0.022


In [19]:
# Process PDFs to find specific form type
def process_pdfs_for_specific_form(pdf_files, pdf_dir, formpage_dir, 
                                  processor, model, device,
                                  example_features, nonexample_features=None,
                                  similarity_threshold=0.7, max_files=None):
    """
    Process PDFs to find pages matching the specific form type
    """
    if max_files:
        pdf_files = pdf_files[:max_files]
    
    results = []
    
    for pdf_file in tqdm(pdf_files, desc="Processing PDFs"):
        try:
            pdf_path = os.path.join(pdf_dir, pdf_file)
            pdf_document = fitz.open(pdf_path)
            
            best_match_page = None
            best_similarity = 0
            best_confidence = 0
            page_results = []
            
            # Check each page
            for page_num in range(len(pdf_document)):
                page = pdf_document[page_num]
                pix = page.get_pixmap()
                img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                
                # Detect specific form type
                detection = detect_specific_form_type(
                    img, processor, model, device,
                    example_features, nonexample_features,
                    similarity_threshold
                )
                
                page_results.append({
                    'page': page_num + 1,
                    'is_specific_form': detection['is_specific_form'],
                    'similarity': detection['max_similarity_to_examples'],
                    'confidence': detection['confidence']
                })
                
                # Track best matching page
                if detection['is_specific_form'] and detection['max_similarity_to_examples'] > best_similarity:
                    best_similarity = detection['max_similarity_to_examples']
                    best_confidence = detection['confidence']
                    best_match_page = page_num
                    
                    # Stop early if very high similarity
                    if best_similarity > 0.95:
                        break
            
            # Store results
            result = {
                'file': pdf_file,
                'total_pages': len(pdf_document),
                'has_specific_form': best_match_page is not None,
                'best_match_page': best_match_page + 1 if best_match_page is not None else None,
                'best_similarity': best_similarity,
                'confidence': best_confidence
            }
            results.append(result)
            
            # Extract best matching page if found
            if best_match_page is not None:
                output_pdf = fitz.open()
                output_pdf.insert_pdf(pdf_document, from_page=best_match_page, to_page=best_match_page)
                
                output_path = os.path.join(formpage_dir, pdf_file)
                output_pdf.save(output_path)
                output_pdf.close()
            
            pdf_document.close()
            
        except Exception as e:
            print(f"\nError processing {pdf_file}: {str(e)}")
            results.append({
                'file': pdf_file,
                'error': str(e)
            })
    
    return results

In [None]:
# Process a test batch
print("\nProcessing first 10 PDFs to find specific form type...")

results = process_pdfs_for_specific_form(
    pdf_files[:10],
    pdf_dir,
    formpage_dir,
    processor,
    model,
    device,
    example_features,
    nonexample_features,
    similarity_threshold=threshold
)

# Convert to DataFrame
results_df = pd.DataFrame(results)
print(f"\nProcessed {len(results_df)} PDFs")
print(f"Found {results_df['has_specific_form'].sum()} PDFs with the specific form type")

# Show results
print("\nResults:")
for _, row in results_df.iterrows():
    if 'error' not in row:
        if row['has_specific_form']:
            print(f"{row['file']}: Found on page {row['best_match_page']} (similarity: {row['best_similarity']:.3f})")
        else:
            print(f"{row['file']}: Not found")

In [None]:
# Save results
output_path = '../../data/intermediate_products/donut_specific_form_detection.csv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
results_df.to_csv(output_path, index=False)
print(f"\nResults saved to: {output_path}")

In [None]:
# Full processing
"""
print("\n=== PROCESSING ALL PDFs FOR SPECIFIC FORM TYPE ===")
print(f"Looking for forms similar to the {len(example_features)} examples")
print(f"Using similarity threshold: {threshold:.3f}")

all_results = process_pdfs_for_specific_form(
    pdf_files,
    pdf_dir,
    formpage_dir,
    processor,
    model,
    device,
    example_features,
    nonexample_features,
    similarity_threshold=threshold
)

# Save results
all_results_df = pd.DataFrame(all_results)
all_results_df.to_csv('../../data/intermediate_products/donut_specific_form_all.csv', index=False)

print(f"\nProcessing complete!")
print(f"Total processed: {len(all_results_df)}")
print(f"Found specific form type in: {all_results_df['has_specific_form'].sum()} documents")
"""

## Approach Summary

This notebook uses Donut's vision encoder to:
1. Extract visual features from your specific example forms
2. Compare new documents against these examples
3. Find pages that match your specific form type (not just any form)

### Key Differences:
- **Not asking "is this a form?"** - Instead asking "is this like my example forms?"
- **Uses similarity matching** - Compares visual features to examples
- **Handles non-examples** - Can reject documents that are forms but not your type

### Advantages:
- No need to describe the form in words
- Learns from your examples
- Can distinguish between different form types
- Uses Donut's document understanding capabilities