# Zero-Shot Administrative Form Classifier with Parallel Processing

This notebook implements a zero-shot classifier that identifies administrative forms using embedding similarity, with multiprocessing support for efficient large-scale processing.

In [None]:
# 1. Setup and imports
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from pdf2image import convert_from_path
from tqdm import tqdm
import json
import pickle
from typing import List, Dict, Tuple, Optional
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import seaborn as sns

# Multiprocessing imports
from multiprocessing import Pool, cpu_count
import time

# Model imports
from transformers import (
    DonutProcessor, 
    VisionEncoderDecoderModel,
    CLIPProcessor, 
    CLIPModel,
    AutoImageProcessor, 
    AutoModel
)

# Pick the GPU when CUDA is usable, otherwise fall back to the CPU, and
# report how many CPU cores multiprocessing can see.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
print(f"Available CPU cores: {cpu_count()}")

In [None]:
# 2. Configuration
# Project root; the example / non-example PDF folders live under data/raw.
BASE_PATH = Path('/Users/admin-tascott/Documents/GitHub/chehalis')
EXAMPLE_FORMS_PATH = BASE_PATH.joinpath('data', 'raw', '_exampleforms')
NON_EXAMPLES_PATH = BASE_PATH.joinpath('data', 'raw', '_nonexamples')

# Embedding backbone to use. Options: "donut", "clip", "dinov2"
MODEL_TYPE = "clip"

# Rendering resolution for PDF pages, default embedding batch size, and the
# default similarity cut-off used by the classification functions.
IMAGE_DPI = 150
BATCH_SIZE = 8
SIMILARITY_THRESHOLD = 0.85

# Parallel processing parameters
USE_MULTIPROCESSING = True
# Leave one core for the parent process, but never go below one worker:
# on a single-core machine cpu_count() - 1 is 0, and Pool(0) raises ValueError.
N_WORKERS = max(1, cpu_count() - 1)
# Together with N_WORKERS this caps how many PDFs go into one parallel batch.
MAX_PDFS_PER_WORKER = 100

# Output paths
EMBEDDINGS_CACHE_PATH = BASE_PATH / 'code' / 'preprocessing' / 'cached_embeddings'
# parents=True so a fresh checkout without code/preprocessing/ does not fail
# with FileNotFoundError; exist_ok=True keeps re-running the cell harmless.
EMBEDDINGS_CACHE_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# 3. Load model
def load_model(model_type: str):
    """Return ``(encoder, processor)`` for the requested embedding backbone.

    Supported values are ``"donut"``, ``"clip"`` and ``"dinov2"``. The returned
    encoder is moved to the module-level ``device``. For Donut, a fine-tuned
    checkpoint in ``./form_classifier_model`` is used when that path exists;
    otherwise the base checkpoint is loaded.

    Raises:
        ValueError: if ``model_type`` is not one of the supported values.
    """
    if model_type not in ("donut", "clip", "dinov2"):
        raise ValueError(f"Unknown model type: {model_type}")

    if model_type == "clip":
        print("Loading CLIP model")
        clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # Only the vision tower is needed for image embeddings.
        return clip.vision_model.to(device), clip_processor

    if model_type == "dinov2":
        print("Loading DINOv2 model")
        dino_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        dino = AutoModel.from_pretrained('facebook/dinov2-base')
        return dino.to(device), dino_processor

    # Remaining case: Donut. Pick the checkpoint first, then load both parts.
    MODEL_DIR = "./form_classifier_model"
    if os.path.exists(MODEL_DIR):
        print(f"Loading fine-tuned Donut model from {MODEL_DIR}")
        checkpoint = MODEL_DIR
    else:
        print("Loading base Donut model")
        checkpoint = "naver-clova-ix/donut-base"
    donut = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    donut_processor = DonutProcessor.from_pretrained(checkpoint)
    return donut.encoder.to(device), donut_processor

# Build the encoder/processor pair chosen by MODEL_TYPE and put the encoder
# into evaluation mode before any embeddings are extracted.
model, processor = load_model(model_type=MODEL_TYPE)
model = model.eval()
print(f"Model loaded successfully")

In [None]:
# 4. Basic helper functions
def pdf_to_images(pdf_path: Path, dpi: int = IMAGE_DPI) -> List[Image.Image]:
    """Render the pages of ``pdf_path`` as PIL images at ``dpi``.

    Conversion errors are reported on stdout and swallowed; in that case an
    empty list is returned so callers can treat the PDF as unreadable.
    """
    pages: List[Image.Image] = []
    try:
        pages = convert_from_path(pdf_path, dpi=dpi)
    except Exception as e:
        print(f"Error converting {pdf_path}: {e}")
    return pages

def preprocess_image(image: Image.Image, model_type: str) -> Image.Image:
    """Return an RGB, down-scaled copy of ``image`` sized for ``model_type``.

    Donut inputs are bounded by 1280x960; every other model type by 224x224.
    ``thumbnail`` keeps the aspect ratio and never enlarges an image.

    The input image is left untouched: ``Image.thumbnail`` resizes in place,
    so the work is always done on a fresh image (``convert`` already returns
    one; an image that is already RGB is copied explicitly).
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    else:
        # Without this copy the caller's image would be shrunk as a side effect.
        image = image.copy()

    max_size = (1280, 960) if model_type == "donut" else (224, 224)
    image.thumbnail(max_size, Image.Resampling.LANCZOS)

    return image

In [None]:
# 5. Embedding extraction functions
@torch.no_grad()
def extract_embeddings(images: List[Image.Image], model, processor, model_type: str, 
                      batch_size: int = BATCH_SIZE) -> np.ndarray:
    """Extract embeddings for a list of images, ``batch_size`` images at a time.

    Each batch is preprocessed with ``preprocess_image``, run through
    ``processor`` and ``model`` on the module-level ``device``, and reduced to
    one vector per image: ``pooler_output`` when the model provides one,
    otherwise the mean of ``last_hidden_state`` over dim 1, otherwise the mean
    of the first output over dim 1.

    Returns:
        The per-batch results stacked with ``np.vstack`` in input order. An
        empty ``images`` list yields an empty 2-D array instead of the
        ``ValueError`` that ``np.vstack([])`` would raise.
    """
    if not images:
        return np.empty((0, 0), dtype=np.float32)

    embeddings = []

    for start in range(0, len(images), batch_size):
        batch_images = [
            preprocess_image(img, model_type)
            for img in images[start:start + batch_size]
        ]

        # All supported model types take the same processor/model call.
        inputs = processor(images=batch_images, return_tensors="pt").to(device)
        outputs = model(**inputs)

        batch_embeddings = getattr(outputs, 'pooler_output', None)
        if batch_embeddings is None:
            if hasattr(outputs, 'last_hidden_state'):
                batch_embeddings = outputs.last_hidden_state.mean(dim=1)
            else:
                batch_embeddings = outputs[0].mean(dim=1)

        embeddings.append(batch_embeddings.cpu().numpy())

    return np.vstack(embeddings)

@torch.no_grad()
def extract_embedding_single(image: Image.Image, model, processor, model_type: str) -> np.ndarray:
    """Return the embedding vector of one image as a NumPy array.

    The image is preprocessed for ``model_type``, pushed through ``processor``
    and ``model`` on the module-level ``device``, and pooled the same way as in
    batch extraction (``pooler_output``, else mean over dim 1).
    """
    prepared = preprocess_image(image, model_type)

    # The processor/model call is the same for donut, clip and dinov2.
    model_inputs = processor(images=prepared, return_tensors="pt").to(device)
    outputs = model(**model_inputs)

    pooled = getattr(outputs, 'pooler_output', None)
    if pooled is None:
        if hasattr(outputs, 'last_hidden_state'):
            pooled = outputs.last_hidden_state.mean(dim=1)
        else:
            pooled = outputs[0].mean(dim=1)

    # Drop the leading batch axis of the single-image batch.
    return pooled.cpu().numpy()[0]

In [None]:
# 6. Parallel processing functions
def init_worker(model_type_arg):
    """Load the embedding model once per worker process.

    Meant to be used as a ``Pool`` initializer. Stores the encoder, its
    processor and the model type in the module globals ``worker_model``,
    ``worker_processor`` and ``worker_model_type``. The model is not moved to
    another device here, so it stays wherever ``from_pretrained`` puts it.

    For Donut the fine-tuned checkpoint in ``./form_classifier_model`` is used
    when it exists, mirroring ``load_model``, so embeddings produced in workers
    are comparable with those produced by the main-process model.

    Raises:
        ValueError: if ``model_type_arg`` is not "clip", "donut" or "dinov2".
    """
    global worker_model, worker_processor, worker_model_type

    # Fail with a clear message instead of a NameError further down.
    if model_type_arg not in ("clip", "donut", "dinov2"):
        raise ValueError(f"Unknown model type: {model_type_arg}")

    if model_type_arg == "clip":
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        encoder = model.vision_model
    elif model_type_arg == "donut":
        MODEL_DIR = "./form_classifier_model"
        source = MODEL_DIR if os.path.exists(MODEL_DIR) else "naver-clova-ix/donut-base"
        model = VisionEncoderDecoderModel.from_pretrained(source)
        processor = DonutProcessor.from_pretrained(source)
        encoder = model.encoder
    else:  # dinov2
        processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        encoder = AutoModel.from_pretrained('facebook/dinov2-base')

    encoder.eval()

    # Publish the globals only after everything loaded successfully.
    worker_model = encoder
    worker_processor = processor
    worker_model_type = model_type_arg

def process_pdf_parallel(pdf_path: Path) -> Dict:
    """Embed every page of one PDF inside a worker process.

    Relies on the globals set by ``init_worker``. Returns a dict with the keys
    ``filename``, ``path``, ``embeddings`` (one vector per successfully
    embedded page), ``page_numbers`` (1-based, aligned with ``embeddings``),
    ``num_pages`` and ``error``.

    ``num_pages`` is always present (0 when the PDF could not be converted).
    ``error`` is ``None`` on success and a message when the PDF could not be
    converted, when no page at all could be embedded, or when an unexpected
    exception occurred. Individual page failures are printed and skipped.
    """
    global worker_model, worker_processor, worker_model_type
    
    results = {
        'filename': pdf_path.name,
        'path': str(pdf_path),
        'embeddings': [],
        'page_numbers': [],
        'num_pages': 0,
        'error': None
    }
    
    try:
        images = pdf_to_images(pdf_path, dpi=IMAGE_DPI)
        if not images:
            results['error'] = "Failed to convert PDF"
            return results
        
        results['num_pages'] = len(images)
        
        for page_num, image in enumerate(images, 1):
            try:
                image = preprocess_image(image, worker_model_type)
                
                with torch.no_grad():
                    # Same call for every supported model type.
                    inputs = worker_processor(images=image, return_tensors="pt")
                    outputs = worker_model(**inputs)
                    
                    embedding = getattr(outputs, 'pooler_output', None)
                    if embedding is None:
                        if hasattr(outputs, 'last_hidden_state'):
                            embedding = outputs.last_hidden_state.mean(dim=1)
                        elif isinstance(outputs, tuple):
                            embedding = outputs[0].mean(dim=1)
                        else:
                            embedding = outputs.mean(dim=1)
                    
                    # Drop the batch axis of the single-image batch.
                    embedding = embedding.numpy()[0]
                
                results['embeddings'].append(embedding)
                results['page_numbers'].append(page_num)
                
            except Exception as e:
                print(f"Error processing page {page_num} of {pdf_path.name}: {e}")
        
        # A PDF whose pages all failed must not look like a clean result.
        if not results['embeddings']:
            results['error'] = "Failed to embed any page"
        
    except Exception as e:
        results['error'] = str(e)
        print(f"Error processing {pdf_path.name}: {e}")
    
    return results

def process_pdfs_parallel(pdf_files: List[Path], model_type: str, n_workers: int = None) -> List[Dict]:
    """Embed multiple PDFs in parallel with a pool of worker processes.

    Args:
        pdf_files: PDFs to process; results come back in the same order.
        model_type: passed to ``init_worker`` in every worker.
        n_workers: number of worker processes; defaults to ``N_WORKERS``. The
            value is clamped to ``[1, len(pdf_files)]`` because ``Pool``
            rejects 0 and every extra worker would load a model for nothing.

    Returns:
        One result dict per PDF as produced by ``process_pdf_parallel``; an
        empty list when ``pdf_files`` is empty.

    NOTE(review): on platforms where multiprocessing uses the "spawn" start
    method, worker functions defined interactively in a notebook may not be
    importable by the workers -- confirm this runs in your environment.
    """
    # Nothing to do; also avoids the division by zero in the average below.
    if not pdf_files:
        print("\nNo PDFs to process")
        return []
    
    if n_workers is None:
        n_workers = N_WORKERS
    n_workers = max(1, min(n_workers, len(pdf_files)))
    
    print(f"\nProcessing {len(pdf_files)} PDFs using {n_workers} workers...")
    start_time = time.time()
    
    with Pool(n_workers, initializer=init_worker, initargs=(model_type,)) as pool:
        results = list(tqdm(
            pool.imap(process_pdf_parallel, pdf_files),
            total=len(pdf_files),
            desc="Processing PDFs"
        ))
    
    elapsed = time.time() - start_time
    print(f"Processed {len(pdf_files)} PDFs in {elapsed:.1f} seconds")
    print(f"Average: {elapsed/len(pdf_files):.2f} seconds per PDF")
    
    return results

In [None]:
# 7. Build reference embeddings
# Reference embeddings are cached per model type so the example forms are only
# embedded once. An empty reference set is treated as a hard error and is never
# written to the cache; otherwise every later run would reload the empty cache.
cache_file = EMBEDDINGS_CACHE_PATH / f"{MODEL_TYPE}_reference_embeddings.pkl"

if cache_file.exists():
    print(f"Loading cached reference embeddings from {cache_file}")
    # NOTE: unpickling executes code stored in the file; only load cache files
    # that were written by this notebook.
    with open(cache_file, 'rb') as f:
        cache_data = pickle.load(f)
    reference_embeddings = cache_data['embeddings']
    reference_metadata = cache_data['metadata']
    
    if np.ndim(reference_embeddings) != 2 or len(reference_embeddings) == 0:
        raise RuntimeError(
            f"Cached reference embeddings in {cache_file} are empty or malformed; "
            "delete the cache file and re-run this cell"
        )
else:
    print("Building reference embeddings from example forms...")
    pdf_files = list(EXAMPLE_FORMS_PATH.glob('*.pdf'))
    print(f"Processing {len(pdf_files)} example form PDFs")
    
    reference_embeddings = []
    reference_metadata = []
    
    if USE_MULTIPROCESSING and len(pdf_files) > 10:
        # Use parallel processing
        print("Using parallel processing for reference embeddings...")
        results = process_pdfs_parallel(pdf_files, MODEL_TYPE, N_WORKERS)
        
        for result in results:
            if result['error']:
                print(f"Skipping {result['filename']} due to error: {result['error']}")
                continue
            
            for embedding, page_num in zip(result['embeddings'], result['page_numbers']):
                reference_embeddings.append(embedding)
                reference_metadata.append({
                    'file_path': result['path'],
                    'filename': result['filename'],
                    'page_num': page_num,
                    'total_pages': result['num_pages']
                })
    else:
        # Use sequential processing
        for pdf_path in tqdm(pdf_files, desc="Processing example forms"):
            images = pdf_to_images(pdf_path)
            if images:
                embeddings = extract_embeddings(images, model, processor, MODEL_TYPE)
                for i, embedding in enumerate(embeddings):
                    reference_embeddings.append(embedding)
                    reference_metadata.append({
                        'file_path': str(pdf_path),
                        'filename': pdf_path.name,
                        'page_num': i + 1,
                        'total_pages': len(images)
                    })
    
    # Stop before caching: without reference pages nothing can be classified.
    if not reference_embeddings:
        raise RuntimeError(
            f"No reference embeddings could be extracted from {EXAMPLE_FORMS_PATH}; "
            "check that the folder contains readable PDFs"
        )
    
    reference_embeddings = np.array(reference_embeddings)
    
    # Cache the embeddings
    print(f"Caching reference embeddings to {cache_file}")
    with open(cache_file, 'wb') as f:
        pickle.dump({
            'embeddings': reference_embeddings,
            'metadata': reference_metadata,
            'model_type': MODEL_TYPE
        }, f)

print(f"\nReference set contains {len(reference_embeddings)} form pages")
print(f"Embedding dimension: {reference_embeddings.shape[1]}")

In [None]:
# 8. Similarity computation functions
# The prototype is the element-wise mean over all reference page embeddings.
reference_prototype = np.mean(reference_embeddings, axis=0)
print(f"Reference prototype shape: {reference_prototype.shape}")

def compute_similarity_to_prototype(embedding: np.ndarray, prototype: np.ndarray) -> float:
    """Return the cosine similarity between ``embedding`` and ``prototype``."""
    # cosine_similarity works on 2-D inputs, so present each vector as one row.
    query_row = embedding.reshape(1, -1)
    prototype_row = prototype.reshape(1, -1)
    similarity_matrix = cosine_similarity(query_row, prototype_row)
    return similarity_matrix[0, 0]

def compute_similarity_to_references(embedding: np.ndarray, references: np.ndarray, 
                                   method: str = 'max') -> float:
    """Score ``embedding`` against every row of ``references``.

    The per-reference cosine similarities are reduced according to ``method``:
    ``'max'`` (best match), ``'mean'`` (average over all references) or
    ``'top_k'`` (average of the up-to-five highest similarities).

    Raises:
        ValueError: for any other ``method``.
    """
    scores = cosine_similarity(embedding.reshape(1, -1), references)[0]

    if method == 'max':
        return scores.max()
    if method == 'mean':
        return scores.mean()
    if method == 'top_k':
        top_count = min(5, len(scores))
        ascending = np.sort(scores)
        return ascending[-top_count:].mean()

    raise ValueError(f"Unknown method: {method}")

In [None]:
# 9. Classification functions
def classify_pdf_zero_shot(pdf_path: Path, 
                          model, 
                          processor,
                          model_type: str,
                          reference_embeddings: np.ndarray,
                          reference_prototype: np.ndarray,
                          threshold: float = SIMILARITY_THRESHOLD,
                          use_prototype: bool = True) -> Dict:
    """Classify the pages of one PDF by similarity to the reference forms.

    Args:
        pdf_path: PDF to classify.
        model, processor, model_type: embedding backbone, as returned by
            ``load_model``.
        reference_embeddings: reference page embeddings, used when
            ``use_prototype`` is False (max similarity over all references).
        reference_prototype: mean reference embedding, used when
            ``use_prototype`` is True.
        threshold: a page counts as a form when its similarity is >= this.
        use_prototype: choose prototype vs. full-reference scoring.

    Returns:
        Dict with ``filename``, ``contains_form``, ``form_pages`` (1-based),
        ``page_scores`` (``{'page', 'similarity'}`` per page),
        ``max_similarity``, ``total_pages`` and ``error``. All similarities
        are plain Python floats. ``max_similarity`` is the true maximum over
        the scored pages (it may be negative) and 0.0 when no page was scored.
        Exceptions are caught and reported through ``error``; pages scored
        before the failure are kept.
    """
    results = {
        'filename': pdf_path.name,
        'contains_form': False,
        'form_pages': [],
        'page_scores': [],
        'max_similarity': 0.0,
        'total_pages': 0,
        'error': None
    }
    
    try:
        images = pdf_to_images(pdf_path)
        if not images:
            results['error'] = "Failed to convert PDF"
            return results
        
        results['total_pages'] = len(images)
        
        for page_num, image in enumerate(images, 1):
            embedding = extract_embedding_single(image, model, processor, model_type)
            
            if use_prototype:
                similarity = compute_similarity_to_prototype(embedding, reference_prototype)
            else:
                similarity = compute_similarity_to_references(embedding, reference_embeddings, method='max')
            
            # Convert once so every stored score is a plain float, not a
            # NumPy scalar.
            similarity = float(similarity)
            
            results['page_scores'].append({
                'page': page_num,
                'similarity': similarity
            })
            
            if similarity >= threshold:
                results['form_pages'].append(page_num)
                results['contains_form'] = True
            
            # Seed the maximum from the first page rather than from 0.0, so
            # all-negative similarities are not reported as 0.
            if len(results['page_scores']) == 1:
                results['max_similarity'] = similarity
            else:
                results['max_similarity'] = max(results['max_similarity'], similarity)
        
    except Exception as e:
        results['error'] = str(e)
    
    return results

def classify_pdf_batch_parallel(pdf_paths: List[Path],
                               reference_embeddings: np.ndarray,
                               reference_prototype: np.ndarray,
                               model_type: str,
                               threshold: float = SIMILARITY_THRESHOLD,
                               use_prototype: bool = True,
                               n_workers: int = None) -> List[Dict]:
    """Classify a batch of PDFs, extracting page embeddings in parallel.

    Embeddings come from ``process_pdfs_parallel``; scoring happens in the
    calling process. Each returned dict has the keys ``filename``,
    ``contains_form``, ``form_pages``, ``page_scores``, ``max_similarity``,
    ``total_pages`` and ``error``.

    Similarities are plain Python floats and ``max_similarity`` is the true
    maximum over the scored pages (0.0 when there are none). A PDF for which
    no page embedding is available is reported with an ``error`` instead of
    silently counting as "no form".
    """
    
    # Extract embeddings in parallel
    embedding_results = process_pdfs_parallel(pdf_paths, model_type, n_workers)
    
    # Classify based on embeddings
    classification_results = []
    
    for result in embedding_results:
        if result['error']:
            classification_results.append({
                'filename': result['filename'],
                'contains_form': False,
                'form_pages': [],
                'page_scores': [],
                'max_similarity': 0.0,
                'total_pages': result.get('num_pages', 0),
                'error': result['error']
            })
            continue
        
        form_pages = []
        page_scores = []
        
        for embedding, page_num in zip(result['embeddings'], result['page_numbers']):
            if use_prototype:
                similarity = compute_similarity_to_prototype(embedding, reference_prototype)
            else:
                similarity = compute_similarity_to_references(embedding, reference_embeddings, method='max')
            
            # Plain float keeps the result free of NumPy scalars.
            similarity = float(similarity)
            
            page_scores.append({
                'page': page_num,
                'similarity': similarity
            })
            
            if similarity >= threshold:
                form_pages.append(page_num)
        
        # True maximum (may be negative); 0.0 only when nothing was scored.
        max_similarity = max((score['similarity'] for score in page_scores), default=0.0)
        
        classification_results.append({
            'filename': result['filename'],
            'contains_form': len(form_pages) > 0,
            'form_pages': form_pages,
            'page_scores': page_scores,
            'max_similarity': max_similarity,
            'total_pages': result.get('num_pages', len(result['embeddings'])),
            'error': None if page_scores else "No page embeddings available"
        })
    
    return classification_results

In [None]:
# 10. Threshold tuning
# Estimates a similarity cut-off from two score samples:
#   * positives: leave-one-out best-match similarity of up to 50 reference pages
#   * negatives: best-match similarity of pages from up to 20 non-example PDFs
# Falls back to SIMILARITY_THRESHOLD when either sample is empty.
print("Testing on known examples to tune threshold...")

positive_scores = []
negative_scores = []

# Test on subset of examples
print("\nComputing self-similarity for positive examples...")
# Leave-one-out needs at least two reference pages; with a single page the
# "other" set would be empty and cosine_similarity would raise.
if len(reference_embeddings) > 1:
    for i in range(min(50, len(reference_embeddings))):
        embedding = reference_embeddings[i]
        other_embeddings = np.delete(reference_embeddings, i, axis=0)
        similarity = compute_similarity_to_references(embedding, other_embeddings, method='max')
        positive_scores.append(similarity)

# Test on non-examples
print("\nTesting on non-examples...")
non_example_files = list(NON_EXAMPLES_PATH.glob('*.pdf'))[:20]

for pdf_path in tqdm(non_example_files, desc="Processing non-examples"):
    # threshold=0.0: only the raw page scores are of interest here.
    result = classify_pdf_zero_shot(
        pdf_path, model, processor, MODEL_TYPE,
        reference_embeddings, reference_prototype,
        threshold=0.0,
        use_prototype=False
    )
    
    for page_score in result['page_scores']:
        negative_scores.append(page_score['similarity'])

# Analyze scores
positive_scores = np.array(positive_scores)
negative_scores = np.array(negative_scores)

if positive_scores.size == 0 or negative_scores.size == 0:
    # min()/max()/percentile are undefined on empty samples.
    OPTIMAL_THRESHOLD = SIMILARITY_THRESHOLD
    print(f"\nNot enough data to tune the threshold "
          f"({positive_scores.size} positive, {negative_scores.size} negative scores); "
          f"using default {OPTIMAL_THRESHOLD:.3f}")
else:
    print(f"\nPositive scores - Mean: {positive_scores.mean():.3f}, Std: {positive_scores.std():.3f}")
    print(f"Negative scores - Mean: {negative_scores.mean():.3f}, Std: {negative_scores.std():.3f}")
    
    # Find optimal threshold: the larger of the midpoint between the weakest
    # positive and the strongest negative, and the 99th percentile of negatives.
    gap_threshold = (positive_scores.min() + negative_scores.max()) / 2
    percentile_threshold = np.percentile(negative_scores, 99)
    OPTIMAL_THRESHOLD = float(max(gap_threshold, percentile_threshold))
    print(f"\nSelected optimal threshold: {OPTIMAL_THRESHOLD:.3f}")

In [None]:
# 11. Batch processing function
def _result_to_row(result: Dict) -> Dict:
    """Flatten one classification result dict into a CSV-friendly row."""
    return {
        'filename': result['filename'],
        'contains_form': result['contains_form'],
        'form_pages': ','.join(map(str, result['form_pages'])),
        'num_form_pages': len(result['form_pages']),
        'total_pages': result.get('total_pages', 0),
        'max_similarity': result['max_similarity'],
        'error': result['error']
    }


def process_document_folder(folder_path: Path,
                          model,
                          processor,
                          model_type: str,
                          reference_embeddings: np.ndarray,
                          reference_prototype: np.ndarray,
                          threshold: float,
                          output_file: str = 'zero_shot_results.csv',
                          max_files: Optional[int] = None,
                          use_prototype: bool = True,
                          use_parallel: bool = None) -> pd.DataFrame:
    """Classify every ``*.pdf`` directly inside ``folder_path`` and save a CSV.

    Args:
        folder_path: directory to scan (not recursive).
        model, processor, model_type: embedding backbone for the sequential
            path; the parallel path loads its own models from ``model_type``.
        reference_embeddings, reference_prototype: reference data to score
            against.
        threshold: similarity at or above which a page counts as a form.
        output_file: CSV path for the result table.
        max_files: if truthy, only the first ``max_files`` PDFs are processed.
        use_prototype: score against the prototype instead of all references.
        use_parallel: force or disable multiprocessing; ``None`` means use
            ``USE_MULTIPROCESSING``. The parallel path is only taken for more
            than 10 PDFs.

    Returns:
        DataFrame with one row per PDF and the columns ``filename``,
        ``contains_form``, ``form_pages``, ``num_form_pages``,
        ``total_pages``, ``max_similarity`` and ``error``. The columns are
        present even when no PDF was found.
    """
    columns = ['filename', 'contains_form', 'form_pages', 'num_form_pages',
               'total_pages', 'max_similarity', 'error']
    
    if use_parallel is None:
        use_parallel = USE_MULTIPROCESSING
    
    results = []
    # Sorted so that the file order, and the max_files subset, is
    # reproducible across runs and file systems.
    pdf_files = sorted(folder_path.glob('*.pdf'))
    
    if max_files:
        pdf_files = pdf_files[:max_files]
    
    print(f"Processing {len(pdf_files)} PDF files...")
    
    if use_parallel and len(pdf_files) > 10:
        # Use parallel processing
        print(f"Using parallel processing with {N_WORKERS} workers...")
        batch_size = min(MAX_PDFS_PER_WORKER * N_WORKERS, len(pdf_files))
        n_batches = (len(pdf_files) + batch_size - 1) // batch_size
        
        for i in range(0, len(pdf_files), batch_size):
            batch_files = pdf_files[i:i+batch_size]
            print(f"\nProcessing batch {i//batch_size + 1}/{n_batches}")
            
            batch_results = classify_pdf_batch_parallel(
                batch_files,
                reference_embeddings,
                reference_prototype,
                model_type,
                threshold=threshold,
                use_prototype=use_prototype,
                n_workers=N_WORKERS
            )
            
            results.extend(_result_to_row(result) for result in batch_results)
    elif pdf_files:
        # Use sequential processing
        print("Using sequential processing...")
        for pdf_path in tqdm(pdf_files):
            result = classify_pdf_zero_shot(
                pdf_path, model, processor, model_type,
                reference_embeddings, reference_prototype,
                threshold=threshold,
                use_prototype=use_prototype
            )
            
            results.append(_result_to_row(result))
    
    # Create DataFrame and save. Explicit columns keep the schema stable and
    # the summary below working when no PDF was found.
    df_results = pd.DataFrame(results, columns=columns)
    df_results.to_csv(output_file, index=False)
    
    # Print summary
    n_with_forms = int(df_results['contains_form'].astype(bool).sum())
    print(f"\nResults saved to {output_file}")
    print(f"Total documents processed: {len(df_results)}")
    print(f"Documents with forms: {n_with_forms}")
    print(f"Documents without forms: {len(df_results) - n_with_forms}")
    print(f"Processing errors: {df_results['error'].notna().sum()}")
    
    return df_results

In [None]:
# 12. Test classification
# Smoke test: classify up to 50 non-example PDFs against the full reference
# set with the tuned threshold, then show the first rows of the result table.
print("Testing on sample documents...")

test_results = process_document_folder(
    folder_path=NON_EXAMPLES_PATH,
    model=model,
    processor=processor,
    model_type=MODEL_TYPE,
    reference_embeddings=reference_embeddings,
    reference_prototype=reference_prototype,
    threshold=OPTIMAL_THRESHOLD,
    output_file='zero_shot_test_results.csv',
    max_files=50,
    use_prototype=False,
)

print("\nSample results:")
sample_rows = test_results.head(10)
print(sample_rows)

In [None]:
# 13. Performance comparison (Sequential vs Parallel)
# Times the same up-to-20 non-example PDFs through both code paths and
# extrapolates to the full corpus.
# NOTE(review): the sequential run goes first, so the parallel run presumably
# benefits from a warm file cache -- treat the speedup as a rough estimate.
print("\nPerformance Comparison: Sequential vs Parallel Processing")
print("=" * 60)

test_files = list(NON_EXAMPLES_PATH.glob('*.pdf'))[:20]

# process_document_folder only takes its parallel path for MORE than 10 files;
# with exactly 10 both runs would be sequential and the speedup meaningless.
if len(test_files) > 10:
    # Sequential processing
    print("\n1. Sequential Processing:")
    start_time = time.time()
    seq_results = process_document_folder(
        NON_EXAMPLES_PATH,
        model,
        processor,
        MODEL_TYPE,
        reference_embeddings,
        reference_prototype,
        threshold=OPTIMAL_THRESHOLD,
        output_file='sequential_test.csv',
        max_files=20,
        use_prototype=False,
        use_parallel=False
    )
    seq_time = time.time() - start_time
    print(f"Sequential time: {seq_time:.1f} seconds")
    
    # Parallel processing
    print("\n2. Parallel Processing:")
    start_time = time.time()
    par_results = process_document_folder(
        NON_EXAMPLES_PATH,
        model,
        processor,
        MODEL_TYPE,
        reference_embeddings,
        reference_prototype,
        threshold=OPTIMAL_THRESHOLD,
        output_file='parallel_test.csv',
        max_files=20,
        use_prototype=False,
        use_parallel=True
    )
    par_time = time.time() - start_time
    print(f"Parallel time: {par_time:.1f} seconds")
    
    # Results
    speedup = seq_time / par_time
    print(f"\nSpeedup: {speedup:.1f}x faster with {N_WORKERS} workers")
    print(f"\nEstimated time for 190k documents:")
    print(f"Sequential: {190000 * seq_time / len(test_files) / 3600:.1f} hours")
    print(f"Parallel: {190000 * par_time / len(test_files) / 3600:.1f} hours")
else:
    print(f"\nSkipping comparison: need more than 10 PDFs, found {len(test_files)}")

In [None]:
# 14. Process full corpus (example code)
# The snippet is only printed, never executed; copy it into a new cell to run
# the classifier over the full corpus.
_FULL_CORPUS_EXAMPLE = """
# To process your full 190k corpus:
CORPUS_PATH = BASE_PATH / 'data' / 'raw' / 'all_contracts'

results_df = process_document_folder(
    CORPUS_PATH,
    model,
    processor,
    MODEL_TYPE,
    reference_embeddings,
    reference_prototype,
    threshold=OPTIMAL_THRESHOLD,
    output_file='full_corpus_results.csv',
    use_prototype=False,  # Use full reference set for better accuracy
    use_parallel=True     # Use parallel processing
)
"""

print("\nExample code for processing full corpus:")
print(_FULL_CORPUS_EXAMPLE)