# Find Most Similar Non-Example Pages to Administrative Forms

This notebook uses the Donut model to find pages in the non-examples folder that are most similar to the administrative forms in the examples folder.

## Approach:
1. Extract visual embeddings from the Donut encoder for all example form pages
2. Extract embeddings for the non-example pages (by default, only the first 50 PDFs are processed to limit memory use)
3. Compute similarity scores and find the most similar non-example page
4. Visualize the results

## 1. Setup and Dependencies

In [None]:
import os
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from pdf2image import convert_from_path
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Tuple, Dict
import json

from transformers import DonutProcessor, VisionEncoderDecoderModel

# Pick the compute device: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Load Model and Define Paths

In [None]:
# Locations of the two PDF collections, both under <BASE_PATH>/data/raw
BASE_PATH = Path('/Users/admin-tascott/Documents/GitHub/chehalis')
EXAMPLE_FORMS_PATH = BASE_PATH.joinpath('data', 'raw', '_exampleforms')
NON_EXAMPLES_PATH = BASE_PATH.joinpath('data', 'raw', '_nonexamples')

# A locally saved model directory takes priority; when it is missing, the
# public base Donut checkpoint is loaded instead.
MODEL_DIR = "./form_classifier_model"
if not os.path.exists(MODEL_DIR):
    print("Loading base Donut model")
    MODEL_NAME = "naver-clova-ix/donut-base"
    processor = DonutProcessor.from_pretrained(MODEL_NAME)
    model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
else:
    print(f"Loading trained model from {MODEL_DIR}")
    model = VisionEncoderDecoderModel.from_pretrained(MODEL_DIR)
    processor = DonutProcessor.from_pretrained(MODEL_DIR)

# Place the model on the selected device and switch it to evaluation mode
model = model.to(device)
model.eval()

print("Model loaded successfully")

## 3. PDF Processing Functions

In [None]:
def pdf_to_images(pdf_path: Path, dpi: int = 200) -> List[Image.Image]:
    """Render the pages of a PDF as PIL images.

    Args:
        pdf_path: Location of the PDF file to convert.
        dpi: Resolution passed through to ``convert_from_path``.

    Returns:
        The images produced by ``convert_from_path``. If the conversion
        raises, the error is printed and an empty list is returned, so one
        bad file does not abort a batch run.
    """
    try:
        pages = convert_from_path(pdf_path, dpi=dpi)
    except Exception as e:
        # Best effort: report the failure and let the caller carry on.
        print(f"Error converting {pdf_path}: {e}")
        return []
    return pages

def preprocess_image_for_donut(image: Image.Image, size: Tuple[int, int] = (1280, 960)) -> Image.Image:
    """Return an RGB copy of ``image`` shrunk to fit within ``size``.

    The caller's image is never modified. ``Image.thumbnail`` resizes in
    place, so the resize is applied to a separate object: either the new
    image returned by ``convert`` or an explicit copy.

    Args:
        image: Source page image, in any PIL mode.
        size: Bounding box handed to ``Image.thumbnail``.

    Returns:
        A new RGB image that fits within ``size``.
    """
    if image.mode != 'RGB':
        # convert() already returns a new image, so no extra copy is needed
        image = image.convert('RGB')
    else:
        # Copy first so the in-place thumbnail below leaves the input intact
        image = image.copy()
    image.thumbnail(size, Image.Resampling.LANCZOS)
    return image

## 4. Extract Embeddings Function

In [None]:
def extract_embeddings(image: Image.Image, model, processor, device):
    """Compute a single embedding vector for a page using the model's encoder.

    Args:
        image: Page image; it is passed through ``preprocess_image_for_donut``.
        model: Object exposing an ``encoder`` callable (the notebook passes a
            ``VisionEncoderDecoderModel``).
        processor: Callable that turns the image into ``pixel_values``.
        device: Device the pixel tensor is moved to before encoding.

    Returns:
        numpy array holding the embedding of the one image in the batch.
    """
    prepared = preprocess_image_for_donut(image)

    # Tensorize the page and move it to where the model lives
    model_inputs = processor(prepared, return_tensors="pt")
    pixel_values = model_inputs.pixel_values.to(device)

    with torch.no_grad():
        outputs = model.encoder(pixel_values)

        # Prefer the encoder's pooled output; when it is absent or None,
        # fall back to averaging the last hidden states over dim 1.
        pooled = getattr(outputs, 'pooler_output', None)
        if pooled is None:
            pooled = outputs.last_hidden_state.mean(dim=1)

        vectors = pooled.cpu().numpy()

    # The batch contains exactly one image, so return its row
    return vectors[0]

## 5. Process All Example Forms

In [None]:
# Build one embedding and one metadata record per page of every example form.
# The two lists are filled in lockstep, so position i in one describes
# position i in the other.
example_embeddings = []
example_metadata = []

print("Processing example forms...")
pdf_files = list(EXAMPLE_FORMS_PATH.glob('*.pdf'))

for pdf_path in tqdm(pdf_files, desc="Example PDFs"):
    pages = pdf_to_images(pdf_path)
    page_count = len(pages)
    # Page numbers are stored 1-based
    for page_num, page in enumerate(pages, start=1):
        example_embeddings.append(extract_embeddings(page, model, processor, device))
        example_metadata.append({
            'file_path': str(pdf_path),
            'filename': pdf_path.name,
            'page_num': page_num,
            'total_pages': page_count
        })

# Stack the per-page vectors into a single array
example_embeddings = np.array(example_embeddings)
print(f"\nExtracted embeddings for {len(example_embeddings)} example form pages")
print(f"Embedding shape: {example_embeddings.shape}")

## 6. Process All Non-Example Pages

In [None]:
# Extract embeddings for the pages of a capped subset of non-example PDFs.
# The three lists below are filled in lockstep: index i refers to the same page
# in each of them.
non_example_embeddings = []
non_example_metadata = []
non_example_images = []  # Downsized page images kept for visualization

print("\nProcessing non-example documents...")
# Sort before truncating: directory listing order depends on the file system,
# so slicing an unsorted list could pick a different subset on every run.
pdf_files = sorted(NON_EXAMPLES_PATH.glob('*.pdf'))

# Limit number of PDFs for memory management (adjust as needed)
max_pdfs = 50  # Process the first 50 PDFs in sorted path order
if len(pdf_files) > max_pdfs:
    # Make the truncation visible instead of silently dropping files
    print(f"Found {len(pdf_files)} PDFs; processing only the first {max_pdfs}")
pdf_files = pdf_files[:max_pdfs]

for pdf_path in tqdm(pdf_files, desc="Non-example PDFs"):
    images = pdf_to_images(pdf_path)
    for page_num, image in enumerate(images):
        embedding = extract_embeddings(image, model, processor, device)
        non_example_embeddings.append(embedding)
        non_example_metadata.append({
            'file_path': str(pdf_path),
            'filename': pdf_path.name,
            'page_num': page_num + 1,  # stored 1-based
            'total_pages': len(images)
        })
        # Store resized image for visualization
        non_example_images.append(preprocess_image_for_donut(image))

non_example_embeddings = np.array(non_example_embeddings)
print(f"\nExtracted embeddings for {len(non_example_embeddings)} non-example pages")
print(f"Embedding shape: {non_example_embeddings.shape}")

## 7. Compute Similarity Scores

In [None]:
# Score every non-example page against every example page.
# Non-example embeddings are passed first, so each row of the matrix belongs
# to a non-example page and each column to an example page.
print("\nComputing similarity scores...")
similarity_matrix = cosine_similarity(non_example_embeddings, example_embeddings)

# For each non-example page: its best score and the example that produced it
most_similar_example_idx = np.argmax(similarity_matrix, axis=1)
max_similarities = np.max(similarity_matrix, axis=1)

# The single non-example page with the highest best-score overall
most_similar_idx = np.argmax(max_similarities)
most_similar_example = most_similar_example_idx[most_similar_idx]
most_similar_score = max_similarities[most_similar_idx]

best_page = non_example_metadata[most_similar_idx]
best_match = example_metadata[most_similar_example]

print("\nMost similar non-example page:")
print(f"  File: {best_page['filename']}")
print(f"  Page: {best_page['page_num']}")
print(f"  Similarity score: {most_similar_score:.4f}")
print(f"  Most similar to example: {best_match['filename']}, page {best_match['page_num']}")

## 8. Find Top K Most Similar Pages

In [None]:
# Rank the non-example pages by their best match and keep the K strongest.
K = 10
# Ascending argsort, reversed, then cut to K entries: highest score first
top_k_indices = np.argsort(max_similarities)[::-1][:K]

print(f"\nTop {K} most similar non-example pages:")
results = []

for rank, idx in enumerate(top_k_indices, start=1):
    page_meta = non_example_metadata[idx]
    matched_example = example_metadata[most_similar_example_idx[idx]]
    score = max_similarities[idx]

    results.append({
        'rank': rank,
        'filename': page_meta['filename'],
        'page': page_meta['page_num'],
        'similarity_score': score,
        'similar_to_example': matched_example['filename'],
        'similar_to_page': matched_example['page_num']
    })

    print(f"\n{rank}. File: {page_meta['filename']}, Page: {page_meta['page_num']}")
    print(f"   Similarity: {score:.4f}")
    print(f"   Most similar to: {matched_example['filename']}, page {matched_example['page_num']}")

# Tabular view of the ranking; the bare expression makes the notebook render it
results_df = pd.DataFrame(results)
results_df

## 9. Visualize Most Similar Pages

In [None]:
# Visualize the most similar non-example pages (top row) above the example
# page each one matched best (bottom row).
n_columns = 5
shown_indices = top_k_indices[:n_columns]

fig, axes = plt.subplots(2, n_columns, figsize=(20, 10))
# Report how many pages are actually shown; with 5 or more results this reads
# "Top 5 ..." as before.
fig.suptitle(f'Top {len(shown_indices)} Most Similar Non-Example Pages', fontsize=16)

# Hide every axis up front so columns without a result stay blank instead of
# showing an empty framed plot.
for ax in axes.ravel():
    ax.axis('off')

for i, idx in enumerate(shown_indices):
    # Non-example page
    ax_non = axes[0, i]
    ax_non.imshow(non_example_images[idx])
    ax_non.set_title(f"Non-Example\n{non_example_metadata[idx]['filename']}\nPage {non_example_metadata[idx]['page_num']}\nScore: {max_similarities[idx]:.3f}")

    # Most similar example page
    similar_example_idx = most_similar_example_idx[idx]
    example_meta = example_metadata[similar_example_idx]

    # Example images were not kept in memory, so re-render the source PDF
    example_pdf_path = Path(example_meta['file_path'])
    example_images = pdf_to_images(example_pdf_path)
    page_index = example_meta['page_num'] - 1  # metadata is 1-based

    ax_ex = axes[1, i]
    if page_index < len(example_images):
        example_image = preprocess_image_for_donut(example_images[page_index])
        ax_ex.imshow(example_image)
    else:
        # pdf_to_images returns [] when conversion fails; show a placeholder
        # rather than raising IndexError and losing the whole figure.
        ax_ex.text(0.5, 0.5, 'Page could not be rendered',
                   ha='center', va='center', transform=ax_ex.transAxes)
    ax_ex.set_title(f"Similar Example\n{example_meta['filename']}\nPage {example_meta['page_num']}")

plt.tight_layout()
plt.show()

## 10. Analyze Similarity Distribution

In [None]:
# Histogram of each non-example page's best similarity score, with the overall
# maximum marked by a dashed red line.
hist_fig, hist_ax = plt.subplots(figsize=(10, 6))
hist_ax.hist(max_similarities, bins=50, alpha=0.7, edgecolor='black')
hist_ax.axvline(x=most_similar_score, color='red', linestyle='--', label=f'Most similar: {most_similar_score:.3f}')
hist_ax.set_xlabel('Maximum Cosine Similarity Score')
hist_ax.set_ylabel('Number of Non-Example Pages')
hist_ax.set_title('Distribution of Similarity Scores\n(Non-Example Pages vs Most Similar Example Form)')
hist_ax.legend()
hist_ax.grid(True, alpha=0.3)
plt.show()

# Summary statistics of the best-match scores
print("\nSimilarity Statistics:")
for label, value in (
    ("Mean similarity", max_similarities.mean()),
    ("Std deviation", max_similarities.std()),
    ("Min similarity", max_similarities.min()),
    ("Max similarity", max_similarities.max()),
):
    print(f"{label}: {value:.4f}")

# Blank line, then the number of pages above each cutoff
print()
for cutoff in (0.9, 0.8, 0.7):
    print(f"Pages with similarity > {cutoff}: {(max_similarities > cutoff).sum()}")

## 11. Save Results

In [None]:
# Build one record per non-example page, pairing it with its best-matching
# example page, then write all records to CSV ordered by descending score.
all_results = [
    {
        'non_example_file': page_meta['filename'],
        'non_example_page': page_meta['page_num'],
        'non_example_total_pages': page_meta['total_pages'],
        'similarity_score': score,
        'most_similar_example_file': example_metadata[example_idx]['filename'],
        'most_similar_example_page': example_metadata[example_idx]['page_num']
    }
    for page_meta, score, example_idx in zip(
        non_example_metadata, max_similarities, most_similar_example_idx
    )
]

results_df = pd.DataFrame(all_results).sort_values('similarity_score', ascending=False)
results_df.to_csv('similarity_results.csv', index=False)

print("\nResults saved to similarity_results.csv")
print("\nTop 10 most similar pages:")
print(results_df.head(10))

## 12. Identify Potential Misclassified Pages

In [None]:
# Non-example pages scoring above the cutoff may really be administrative
# forms that were filed in the wrong folder.
threshold = 0.85  # Adjust based on your needs
above_cutoff = results_df['similarity_score'] > threshold
potential_misclassified = results_df[above_cutoff]

print(f"\nPotential misclassified pages (similarity > {threshold}):")
print(f"Found {len(potential_misclassified)} pages")

if not potential_misclassified.empty:
    print("\nThese non-example pages are very similar to the administrative forms:")
    for _, page in potential_misclassified.iterrows():
        file_name = page['non_example_file']
        page_no = page['non_example_page']
        print(f"\n- File: {file_name}, Page: {page_no}")
        print(f"  Similarity: {page['similarity_score']:.4f}")
        print(f"  Similar to: {page['most_similar_example_file']}, page {page['most_similar_example_page']}")

    # Persist the flagged pages for manual review
    potential_misclassified.to_csv('potential_misclassified_pages.csv', index=False)
    print("\nSaved to potential_misclassified_pages.csv")

## Usage Notes

This notebook helps identify:

1. **Most Similar Pages**: Which pages in the non-examples are most visually similar to the administrative forms
2. **Potential Misclassifications**: Pages with very high similarity scores might actually be administrative forms that were incorrectly placed in the non-examples folder
3. **Similarity Distribution**: Understanding how similar non-form pages are to form pages helps set classification thresholds

### Next Steps:
- Review pages with high similarity scores (>0.85) to check if they're actually administrative forms
- Use the similarity threshold to improve the classifier's decision boundary
- Consider adding highly similar non-forms as hard negative examples in training