In [1]:
from gemini_models import *

In [4]:
F1040_p1.model_fields

{'form_number': FieldInfo(annotation=str, required=True, description='The Form Number'),
 'tax_year': FieldInfo(annotation=str, required=True, description='The tax year of this form'),
 'primary_first_name': FieldInfo(annotation=str, required=True, description='The first name of the primary tax payer'),
 'primary_last_name': FieldInfo(annotation=str, required=True, description='The last name of the primary tax payer'),
 'primary_ssn_last_4': FieldInfo(annotation=str, required=True, description='Only the last 4 digits of the SSN number of the primary tax payer.'),
 'spouse_first_name': FieldInfo(annotation=str, required=True, description='The first name of the spouse. If not provided, please return an empty string.'),
 'spouse_last_name': FieldInfo(annotation=str, required=True, description='The first name of the spouse. If not provided, please return an empty string.'),
 'spouse_ssn_last_4': FieldInfo(annotation=str, required=True, description='Only the last 4 digits of the SSN of the 

In [9]:
import sqlite3
import pandas as pd

conn = sqlite3.connect('documents.db')

# List of unique filenames
query = '''
select distinct filename from pages order by created_at desc
'''
all_files = pd.read_sql_query(query, conn)

In [29]:
all_files['filename'].values

array(['acord_28_2025_02_09_102520_906.pdf', '25-GL.pdf',
       'ACORD25TEST.png', 'test_sched_c.pdf', 'testacord25.pdf',
       'g217064g51y66.jpg', '1120S_bal_sheet_2024_12_27_085143_649.pdf',
       'Merged PDF File.pdf', '1065_k1_2024_12_27_085231_070.pdf',
       'drivers_license_test.pdf', 'RFD 2022 TAX RETURN.pdf'],
      dtype=object)

In [30]:
ex = 'RFD 2022 TAX RETURN.pdf' # all_files['filename']

In [31]:
# Distinct list of files already uploadeded
# df_pages, df_extracted, df_info
query = f'''
select * from pages
where filename = '{ex}'
'''
df_pages = pd.read_sql_query(query, conn)

query = f'''
select p.filename base_file, e2.* from extracted2 e2 
join (select filename, preprocessed from pages) p on e2.filename = p.preprocessed
where p.filename = '{ex}'
'''
df_extracted = pd.read_sql_query(query, conn)

query = f'''
select p.filename base_file, i.* from call_info i 
join (select filename, preprocessed from pages) p on i.filename = p.preprocessed
where p.filename = '{ex}'
'''
df_info = pd.read_sql_query(query, conn)

In [32]:
df_extracted

Unnamed: 0,base_file,filename,key,value,page_label,page_score,page_num,created_at
0,RFD 2022 TAX RETURN.pdf,debug_images\RFD 2022 TAX RETURN\page_6\prepro...,business_name,"RAINFLOW DEVELOPMENTS, LLC",business_license,0.629119,6,2025-02-17 23:43:55
1,RFD 2022 TAX RETURN.pdf,debug_images\RFD 2022 TAX RETURN\page_6\prepro...,current_standing,,business_license,0.629119,6,2025-02-17 23:43:55
2,RFD 2022 TAX RETURN.pdf,debug_images\RFD 2022 TAX RETURN\page_11\prepr...,business_name,"RAINFLOW DEVELOPMENTS, LLC",1065_p1,1.000000,11,2025-02-17 23:43:55
3,RFD 2022 TAX RETURN.pdf,debug_images\RFD 2022 TAX RETURN\page_11\prepr...,city_state,"Los Banos, CA 93635",1065_p1,1.000000,11,2025-02-17 23:43:55
4,RFD 2022 TAX RETURN.pdf,debug_images\RFD 2022 TAX RETURN\page_11\prepr...,cost_of_goods_sold,,1065_p1,1.000000,11,2025-02-17 23:43:55
...,...,...,...,...,...,...,...,...
211,RFD 2022 TAX RETURN.pdf,debug_images\RFD 2022 TAX RETURN\page_91\prepr...,shareholder_name,Raaj V Desor,1065_k1,1.000000,91,2025-02-17 23:43:55
212,RFD 2022 TAX RETURN.pdf,debug_images\RFD 2022 TAX RETURN\page_91\prepr...,ssn_last_4,9948,1065_k1,1.000000,91,2025-02-17 23:43:55
213,RFD 2022 TAX RETURN.pdf,debug_images\RFD 2022 TAX RETURN\page_91\prepr...,tax_year,2022,1065_k1,1.000000,91,2025-02-17 23:43:55
214,RFD 2022 TAX RETURN.pdf,debug_images\RFD 2022 TAX RETURN\page_92\prepr...,business_name,"RAINFLOW DEVELOPMENTS, LLC",business_license,0.716266,92,2025-02-17 23:43:55


In [8]:
all_files

Unnamed: 0,filename
0,acord_28_2025_02_09_102520_906.pdf
1,25-GL.pdf
2,ACORD25TEST.png
3,test_sched_c.pdf
4,testacord25.pdf
5,g217064g51y66.jpg
6,1120S_bal_sheet_2024_12_27_085143_649.pdf
7,Merged PDF File.pdf
8,1065_k1_2024_12_27_085231_070.pdf
9,drivers_license_test.pdf


In [4]:
df_results.columns #.groupby('page_label')['page_score'].mean().plot(kind='bar')

Index(['filename', 'preprocessed', 'page_number', 'image_width',
       'image_height', 'lines', 'words', 'bboxes', 'normalized_bboxes',
       'tokens', 'words_for_clf', 'processing_time', 'clf_type', 'page_label',
       'page_score', 'created_at'],
      dtype='object')

In [4]:
prompt = (
                "Extract the structured data from this document. "
                "If SPII is requested, only return partial data. "
                "If a field exists but contains no value, return an empty string."
            )

In [5]:
print(prompt)

Extract the structured data from this document. If SPII is requested, only return partial data. If a field exists but contains no value, return an empty string.


In [None]:
import os
import pickle
from PIL import Image
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm

import torch
from paddleocr import PaddleOCR
from transformers import AutoTokenizer, LayoutLMModel

from scipy.optimize import linear_sum_assignment


# Initialize OCR and LayoutLM
ocr = PaddleOCR(use_angle_cls=True, rec=False, lang="en")
tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
model = LayoutLMModel.from_pretrained("impira/layoutlm-document-qa")
model = model.eval()  # Set model to inference mode

def extract_features(image_path, question="What is in the document?"):
    """
    Extract words, bounding boxes (normalized), and embeddings from the given image.
    Normalization of bounding boxes to a 0-1000 range is performed here.
    
    Args:
        image_path (str): Path to the image.
        question (str): A question for tokenization with LayoutLM. This can be a dummy question 
                        as we primarily need embeddings.
                        
    Returns:
        words (list of str): Detected words from OCR.
        normalized_bboxes (list of lists): Normalized bounding boxes in [x1, y1, x2, y2] format.
        embeddings (np.ndarray): Token embeddings from LayoutLM.
    """
    # Step 1: Run PaddleOCR on the image
    ocr_results = ocr.ocr(image_path, cls=True)[0]

    # Extract words and bounding boxes
    words = [line[1][0] for line in ocr_results]
    boxes = [line[0] for line in ocr_results]

    # Convert quadrilateral OCR boxes to rectangular bounding boxes
    bboxes = []
    for box in boxes:
        x1 = min(point[0] for point in box)
        y1 = min(point[1] for point in box)
        x2 = max(point[0] for point in box)
        y2 = max(point[1] for point in box)
        bboxes.append([x1, y1, x2, y2])

    # Load image to get dimensions and normalize bboxes
    image = Image.open(image_path)
    image_width, image_height = image.size
    normalized_bboxes = [
        [
            int((x1 / image_width) * 1000),
            int((y1 / image_height) * 1000),
            int((x2 / image_width) * 1000),
            int((y2 / image_height) * 1000),
        ]
        for (x1, y1, x2, y2) in bboxes
    ]

    # Step 3: Tokenize question and words for LayoutLM
    encoding = tokenizer(
        question.split(),
        words,
        is_split_into_words=True,
        return_token_type_ids=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Align bounding boxes with tokens
    word_ids = encoding.word_ids(0)
    bbox = []
    for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), word_ids):
        if s == 1 and w is not None:
            bbox.append(normalized_bboxes[w])
        elif i == tokenizer.sep_token_id:
            bbox.append([1000] * 4)
        else:
            bbox.append([0] * 4)

    encoding["bbox"] = torch.tensor([bbox])

    # Move encoding to model device if needed
    # If model is on CPU, this is not strictly necessary
    params = {k: v.to(model.device) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = model(**params)
    embeddings = outputs.last_hidden_state[0].cpu().numpy()

    return words, normalized_bboxes, embeddings

def build_template_database(base_dir='clf_images'):
    """
    Traverse the directory structure:
    clf_images/
       label1/
         base/
           template1.png
           template2.png
       label2/
         base/
           template1.png
    ...
    
    Extract features for each template image and store in a dictionary:
    
    template_db = {
      'label_name': [
         {
           'filename': 'template_image_name.png',
           'words': [...],
           'bboxes': [...],
           'embeddings': np.array([...])
         },
         ...
      ],
      ...
    }
    """
    template_db = {}
    for label_name in os.listdir(base_dir):
        label_path = os.path.join(base_dir, label_name)
        if not os.path.isdir(label_path):
            continue
        base_path = os.path.join(label_path, 'base')
        if not os.path.exists(base_path):
            continue
        
        templates = []
        for fname in os.listdir(base_path):
            if fname.lower().endswith(('png', 'jpg', 'jpeg')):
                image_path = os.path.join(base_path, fname)
                words, bboxes, embeddings = extract_features(image_path)
                templates.append({
                    'filename': fname,
                    'words': words,
                    'bboxes': bboxes,
                    'embeddings': embeddings
                })
        if templates:
            template_db[label_name] = templates

    # Save the template database for future use
    with open('template_db.pkl', 'wb') as f:
        pickle.dump(template_db, f)

    return template_db

def load_template_database(db_path='template_db.pkl'):
    with open(db_path, 'rb') as f:
        template_db = pickle.load(f)
    return template_db

from scipy.stats import wasserstein_distance

def semantic_similarity(emb1, emb2):
    """
    Compute semantic similarity by comparing only the minimum number of embeddings
    between the two sets. This avoids issues caused by differing lengths.

    Args:
        emb1 (np.ndarray): Embeddings for the first set of tokens, shape (N1, D).
        emb2 (np.ndarray): Embeddings for the second set of tokens, shape (N2, D).

    Returns:
        float: Semantic similarity score.
    """
    # Determine the minimum length
    min_len = min(len(emb1), len(emb2))

    # Truncate embeddings to the minimum length
    emb1_trimmed = emb1[:min_len]
    emb2_trimmed = emb2[:min_len]

    # Compute cosine similarity on the trimmed embeddings
    avg_emb1 = np.mean(emb1_trimmed, axis=0)
    avg_emb2 = np.mean(emb2_trimmed, axis=0)
    sim = np.dot(avg_emb1, avg_emb2) / (np.linalg.norm(avg_emb1) * np.linalg.norm(avg_emb2) + 1e-10)

    return sim


def structural_similarity(bboxes1, bboxes2):
    """
    Compute a structural similarity score based on bounding boxes.
    Simple heuristic:
    - Compare corresponding boxes (up to the min length).
    - Compute Euclidean distances, then convert to similarity.
    """
    min_len = min(len(bboxes1), len(bboxes2))
    b1 = np.array(bboxes1[:min_len])
    b2 = np.array(bboxes2[:min_len])
    # Euclidean distances
    dists = np.sqrt(np.sum((b1 - b2)**2, axis=1))
    # Convert distance to similarity: sim = 1/(1+dist)
    similarities = 1 / (1 + dists)
    return float(np.mean(similarities))

def compare_image_to_templates(image_path, template_db):
    """
    Given a new image, extract features and compare them to each template in the database.
    Return the best-matching label, the highest score, and a dictionary of all scores.

    Score is defined as min(semantic_similarity, structural_similarity) as per user request.

    Returns:
        best_label (str): The label with the highest final score.
        best_score (float): The highest final score.
        all_scores (dict): A dictionary where keys are labels, and values are lists of dicts:
                           [
                             {
                               'filename': <str>,
                               'sem_sim': <float>,
                               'struc_sim': <float>,
                               'final_score': <float>
                             },
                             ...
                           ]
    """
    words, bboxes, embeddings = extract_features(image_path)
    best_score = -float('inf')
    best_label = None
    all_scores = {}

    for label, templates in template_db.items():
        label_scores = []
        for t in templates:
            sem_sim = semantic_similarity(embeddings, t['embeddings'])
            struc_sim = structural_similarity(bboxes, t['bboxes'])
            final_score = min(sem_sim, struc_sim)

            # Record the details for this template
            template_result = {
                'filename': t['filename'],
                'sem_sim': sem_sim,
                'struc_sim': struc_sim,
                'final_score': final_score
            }
            label_scores.append(template_result)

            # Check if this is the best score so far
            if final_score > best_score:
                best_score = final_score
                best_label = label

        all_scores[label] = label_scores

    return best_label, best_score, all_scores


# if __name__ == "__main__":
#     # Example usage:
#     # 1. Build the template database once
#     # template_db = build_template_database('clf_images')

#     # Or load it if already built
#     template_db = load_template_database('template_db.pkl')

#     # 2. Classify a new image
#     test_image = r'test_pages\passport_example.png'
#     label, score, all_scores = compare_image_to_templates(test_image, template_db)
#     print(f"Predicted label: {label}, Score: {score}")
#     print(all_scores)


In [2]:
template_db = load_template_database('template_db.pkl')

In [None]:
template_db.keys()

In [None]:
test_image = r'test_pages\passport_example.png'
label, score, all_scores = compare_image_to_templates(test_image, template_db)
print(f"Predicted label: {label}, Score: {score}")
print(all_scores)

# Second Tier Classifier

In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Load model and tokenizer
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"  # Replace with your chosen model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Create inference pipeline
classifier = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Define prompt
text = "Based on the text below, determine whether the document is one of the following: lease_document, certificate_of_good_standing, or business_license. If it is none of these, respond 'None'.\n\nText: This agreement is between..."
response = classifier(text, max_length=100, num_return_sequences=1)

print(response[0]["generated_text"])


  from .autonotebook import tqdm as notebook_tqdm
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Downloading shards: 100%|██████████| 2/2 [06:21<00:00, 190.92s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.11s/it]
Device set to use cuda:0
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
  attn_output = torch.nn.functional.scaled_dot_product_attention(


Based on the text below, determine whether the document is one of the following: lease_document, certificate_of_good_standing, or business_license. If it is none of these, respond 'None'.

Text: This agreement is between... and.... It is a lease agreement for the property located at... The parties agree to rent the property for the amount of... per month. This lease is for a term of... months. The lessees are responsible for paying all utilities and taxes. The


In [2]:
import sqlite3
import pandas as pd
conn = sqlite3.connect('documents.db')

In [3]:
# Query to show all tables
query = "SELECT name FROM sqlite_master WHERE type='table';"
tables = pd.read_sql_query(query, conn)

import sqlite3

def recreate_extracted_table():
    conn = sqlite3.connect('documents.db')
    cursor = conn.cursor()
    
    # Drop the existing extracted table
    cursor.execute('DROP TABLE IF EXISTS extracted')
    
    # Recreate the extracted table with the desired schema
    cursor.execute('''
        CREATE TABLE extracted (
            key TEXT,
            label TEXT,
            label_bbox TEXT,
            label_confidence REAL,
            value TEXT,
            value_bbox TEXT,
            value_confidence REAL,
            page_num INTEGER,
            annotated_image_path TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    ''')
    
    conn.commit()
    conn.close()

# Call the function to recreate the table
recreate_extracted_table()

In [4]:
df = pd.read_sql_query("select * from pages", conn)
# df = df.loc[df['filename'].str.contains('dl')]

In [5]:
df.loc[98,['preprocessed','lines']]

preprocessed    debug_images\drivers_license_test\page_1\prepr...
lines           ["NEW YORK", "DRIVER", "Courtesy of Governor E...
Name: 98, dtype: object

In [None]:
df['tokens']

In [7]:
s = eval(df['words_for_clf'].values[0])

In [9]:
img = df['preprocessed'].values[0]

In [8]:
list(['a', 'b', 'c'])

['a', 'b', 'c']

In [13]:
def funfun():
    return 1,2,3,4,5

In [14]:
t = funfun()

In [None]:
if t:
    print('yes')

In [16]:
a,b,c,d,e = t

# Fallback Tier

In [9]:
from transformers import CLIPProcessor, CLIPModel, pipeline
from PIL import Image

# Load models
text_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
image_model = CLIPModel.from_pretrained("zer0int/CLIP-GmP-ViT-L-14")
image_processor = CLIPProcessor.from_pretrained("zer0int/CLIP-GmP-ViT-L-14")

# Define labels (candidate classes)
# Align values
fallback_labels = [
    "Drivers License", 
    "Passport", 
    "Lease Document", 
    "Certificate of Good Standing", 
    "Business License"
]

# Text-Based Classification
def classify_using_text(text, labels, threshold=0.6):
    """
    Classify document using text-based zero-shot classification.
    """
    result = text_classifier(text, candidate_labels=labels)
    all_scores = result["scores"]
    best_label, best_score = result["labels"][0], result["scores"][0]
    if best_score >= threshold:
        return best_label, best_score, None, all_scores, 'text_clf'
    return None

# Image-Based Classification
def classify_using_image(image_path, labels, threshold=0.6):
    """
    Classify document using image-based zero-shot classification.
    """
    image = Image.open(image_path)
    inputs = image_processor(text=labels, images=image, return_tensors="pt", padding=True)
    outputs = image_model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)  # Image-text similarity scores
    all_scores = {l: p.item() for l, p in zip(labels, probs[0])}
    best_label = labels[probs.argmax()]
    best_score = probs.max().item()
    if best_score >= threshold:
        return best_label, best_score, None, all_scores, 'image_clf'
    return None

# Combined Workflow
def classify_document(image_path, text, labels, threshold=0.6):
    """
    Classify document using text-based classification first,
    and fall back to image-based classification if needed.
    """
    # Step 1: Text-based classification
    text_result = classify_using_text(text, labels, threshold)
    if text_result:
        return text_result

    # Step 2: Fallback to image-based classification
    image_result = classify_using_image(image_path, labels, threshold)
    if image_result:
        return image_result

    # Step 3: Final fallback
    return 'Unknown', 0, None, None, None
    # return image_result

# # Example Usage
# image_path = "path/to/your/image.jpg"
# text = """
# This agreement is made on the 1st of January, 2025, between the Landlord and the Tenant. 
# It outlines the terms and conditions for renting the property located at 123 Main Street.
# """  # Replace with OCR-extracted text

# result = classify_document(img, ' '.join(t for t in list(s)[:100]), fallback_labels)
# print(result)


  from .autonotebook import tqdm as notebook_tqdm
Device set to use cuda:0


In [10]:
# Define labels (candidate classes)
fallback_labels = {
    "This is an official drivers license document":"drivers_license", 
    "This is a government-issued passport":"passport", 
    "This is a legal lease agreement document":"lease_document", 
    "This is a certificate verifying good standing for a business":"certificate_of_good_standing", 
    "This is an official business license document":"business_license"
}

In [11]:
img = r'test_pages\cert_test.png'
classify_using_image(img, list(fallback_labels.keys()), threshold=0.6)

('This is an official business license document',
 0.988182008266449,
 None,
 {'This is an official drivers license document': 0.00013374103582464159,
  'This is a government-issued passport': 4.688190529122949e-06,
  'This is a legal lease agreement document': 1.4443784493778367e-05,
  'This is a certificate verifying good standing for a business': 0.011665068566799164,
  'This is an official business license document': 0.988182008266449},
 'image_clf')

In [None]:
import random
import string

def generate_100_word_statement():
    words = []
    while len(words) < 100:
        word_length = random.randint(1, 10)
        word = ''.join(random.choices(string.ascii_lowercase, k=word_length))
        words.append(word)
    
    # Join the words into a single string
    statement = ' '.join(words)
    
    # Ensure the statement is exactly 100 words long
    statement_words = statement.split()
    if len(statement_words) > 100:
        statement_words = statement_words[:100]
    
    # Truncate the last word if it is too lengthy
    if len(statement_words[-1]) > 10:
        statement_words[-1] = statement_words[-1][:10]
    
    return ' '.join(statement_words)

# Example usage
statement = generate_100_word_statement()
print(statement)

In [None]:
' '.join(list(set(statement.split())))

# Individual Testing

In [1]:
from fast_processor import main

Device set to use cuda:0


In [2]:
m = main(r'test_pages\lease_example.pdf')


converting pages...: 1it [00:03,  3.96s/it]
0it [00:00, ?it/s]Some weights of the model checkpoint at Snowflake/snowflake-arctic-embed-l-v2.0 were not used when initializing XLMRobertaModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
1it [00:01,  1.86s/it]


ValueError: No objects to concatenate

In [3]:
d = {'key1':1, 'key2':2}

In [None]:
list(d.keys())

# Extractor Testing

In [7]:
import json
from transformers import CLIPProcessor, CLIPModel, pipeline
from PIL import Image

# -----------------------------
# Step 0: Define your labels.json content.
# In practice, this JSON would be stored on disk and loaded via json.load().
# Here, we define it as a dictionary for demonstration.
labels_questions = {
    "Drivers License": [
        "What is the full name of this license holder?",
        "What US state is this drivers license from?",
        "What is the expiration date of this license?"
    ],
    "Passport": [
        "What is the passport number?",
        "What is the nationality of the passport holder?",
        "What is the date of issue of the passport?"
    ],
    "Lease Document": [
        "What is the lease start date?",
        "What is the monthly rent amount?",
        "Who is the landlord or lessor?"
    ],
    "Certificate of Good Standing": [
        "What is the certificate number?",
        "What is the date of issuance?",
        "What is the registered company name?"
    ],
    "Business License": [
        "What is the business license number?",
        "What is the expiration date of the license?",
        "What is the name of the business?"
    ]
}

# Optionally, if you store this on disk as labels.json, you can load it with:
# with open("labels.json", "r") as f:
#     labels_questions = json.load(f)

# -----------------------------
# Pre-existing fallback labels list
fallback_labels = [
    "Drivers License", 
    "Passport", 
    "Lease Document", 
    "Certificate of Good Standing", 
    "Business License"
]

# -----------------------------
# Load classification models
text_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
image_model = CLIPModel.from_pretrained("zer0int/CLIP-GmP-ViT-L-14")
image_processor = CLIPProcessor.from_pretrained("zer0int/CLIP-GmP-ViT-L-14")

# -----------------------------
# Text-Based Classification
def classify_using_text(text, labels, threshold=0.6):
    """
    Classify document using text-based zero-shot classification.
    """
    result = text_classifier(text, candidate_labels=labels)
    all_scores = result["scores"]
    best_label, best_score = result["labels"][0], result["scores"][0]
    if best_score >= threshold:
        return best_label, best_score, None, all_scores, 'text_clf'
    return None

# -----------------------------
# Image-Based Classification
def classify_using_image(image_path, labels, threshold=0.6):
    """
    Classify document using image-based zero-shot classification.
    """
    image = Image.open(image_path)
    inputs = image_processor(text=labels, images=image, return_tensors="pt", padding=True)
    outputs = image_model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)  # Image-text similarity scores
    all_scores = {l: p.item() for l, p in zip(labels, probs[0])}
    best_label = labels[probs.argmax()]
    best_score = probs.max().item()
    if best_score >= threshold:
        return best_label, best_score, None, all_scores, 'image_clf'
    return None

# -----------------------------
# Combined Workflow: Document Classification
def classify_document(image_path, text, labels, threshold=0.6):
    """
    Classify document using text-based classification first,
    and fall back to image-based classification if needed.
    """
    # Step 1: Text-based classification
    text_result = classify_using_text(text, labels, threshold)
    if text_result:
        return text_result

    # Step 2: Fallback to image-based classification
    image_result = classify_using_image(image_path, labels, threshold)
    if image_result:
        return image_result

    # Step 3: Final fallback: return Unknown if both methods fail.
    return 'Unknown', 0, None, None, None

# -----------------------------
# Load a QA model for information extraction.
qa_pipeline = pipeline("document-question-answering", model="impira/layoutlm-document-qa")

from custom_pipeline import layoutlm_paddleocr_pipeline

def extract_information(input, doc_label, labels_questions):
    """
    Extracts information from document_text based on the questions
    associated with doc_label, returning both answers and confidence scores.
    """
    # Retrieve the list of questions for the document type.
    questions = labels_questions.get(doc_label, [])
    answers = {}
    for question in questions:
        result = layoutlm_paddleocr_pipeline(image_path=input, question=question)
        # result includes "answer" and "score"
        answers[question] = {"answer": result["answer"], "confidence": result["score"]}
    return answers



Device set to use cuda:0
Device set to use cuda:0


In [8]:
# Example usage with your updated snippet:
image_path = df.loc[98, ['preprocessed']].values[0]
document_text = df.loc[98, ['lines']].values[0]

# Step 1: Classify the document.
doc_label, score, _, scores_dict, method = classify_document(image_path, document_text, fallback_labels)
print(f"Document classified as: {doc_label} (score: {score}, method: {method})")
print("Full scores:", scores_dict)

# Step 2: If classified, extract the relevant information.
if doc_label != 'Unknown':
    # extracted_data = extract_information(image_path, doc_label, labels_questions)
    extracted_data = extract_information(image_path, doc_label, labels_questions)
    print("\nExtracted Information:")
    for question, info in extracted_data.items():
        print(f"- {question}: {info['answer']} (confidence: {info['confidence']:.3f})")
else:
    print("Document classification failed. No extraction performed.")

Document classified as: Drivers License (score: 0.6841147541999817, method: text_clf)
Full scores: [0.6841147541999817, 0.25053372979164124, 0.04078324884176254, 0.012909149751067162, 0.011659120209515095]
[2025/02/03 21:12:22] ppocr DEBUG: dt_boxes num : 16, elapsed : 0.041593074798583984
[2025/02/03 21:12:22] ppocr DEBUG: cls num  : 16, elapsed : 0.017036914825439453
[2025/02/03 21:12:22] ppocr DEBUG: rec_res num  : 16, elapsed : 0.15475130081176758
[2025/02/03 21:12:22] ppocr DEBUG: dt_boxes num : 16, elapsed : 0.04526019096374512
[2025/02/03 21:12:22] ppocr DEBUG: cls num  : 16, elapsed : 0.019980192184448242
[2025/02/03 21:12:23] ppocr DEBUG: rec_res num  : 16, elapsed : 0.1740555763244629
[2025/02/03 21:12:23] ppocr DEBUG: dt_boxes num : 16, elapsed : 0.0445399284362793
[2025/02/03 21:12:23] ppocr DEBUG: cls num  : 16, elapsed : 0.01824164390563965
[2025/02/03 21:12:23] ppocr DEBUG: rec_res num  : 16, elapsed : 0.2006824016571045

Extracted Information:
- What is the full name of

In [None]:
!pip install pytesseract