In [1]:
import sys
import os
import importlib
sys.path.append('..')

import utils
from utils.dataloader import Dataloader
from utils.extract import extract_fields
from utils.rule_gen import generate_query
from utils.evaluator import evaluate_output
from utils.helpers import extracted_fields_serialized

import spacy
from iterextras import par_for

from ipywidgets import Layout
import hyperwidget

In [2]:
def get_dataloader():
    w2_sample_dir = os.path.join('../data', 'sample', 'w2')
    data_dir = os.path.join(w2_sample_dir, 'multi')
    label_path = os.path.join(w2_sample_dir, 'multi_label.csv')
    return Dataloader(data_dir, label_path, concatenate_pages=True, cache_to_disk=True)

dl = get_dataloader()

In [3]:
field_queries = [
    {   # Field 1
        "name": "EIN",
        "arguments": {
            "x-position": 0.1,
            "y-position": 0.1,
            "entity": "EIN",
            "word-neighbors": ["Employer Identification Number"],
            "word-neighbor-max-top-dist": 100,
            "word-neighbor-max-left-dist": 100,
            "word-neighbor-max-bottom-dist": 0,
            "word-neighbor-max-right-dist": 0,
        },
        "weights": {
            "x-position": 0.25,
            "y-position": 0.25,
            "entity": 0.25,
            "word-neighbors": 0.25,
        }
    },
    {   # Field 2
        "name": "Medicare Tax Withheld",
        "arguments": {
            "x-position": 0.9,
            "y-position": 0.1,
            "entity": "CARDINAL",
            "word-neighbors": ["Medicare Tax Withheld"],
            "word-neighbor-max-top-dist": 100,
            "word-neighbor-max-left-dist": 100,
            "word-neighbor-max-bottom-dist": 0,
            "word-neighbor-max-right-dist": 0,
        },
        "weights": {
            "x-position": 0.25,
            "y-position": 0.25,
            "entity": 0.25,
            "word-neighbors": 0.25,
        }
    }
]

In [None]:
num_docs = len(dl)
nlp = spacy.load("en_core_web_sm")
extracted_fields = par_for(
    lambda i:  extract_fields(dl.get_document(i), field_queries, 1000, nlp),
    list(range(num_docs)),
    workers=2,
)
labels = [dl.get_label(i) for i in range(num_docs)]

In [None]:
errors = evaluate_output(extracted_fields, labels, ['EIN', 'Medicare Tax Withheld'])

In [None]:
error_table = hyperwidget.ErrorTable(
    errors = errors
)

In [None]:
extraction_heatmap = hyperwidget.ExtractionHeatmap(
    documents=[dl.get_document(i).as_dict_with_images() for i in range(num_docs)],
    labels=[dl.get_label(i) for i in range(num_docs)],
    extracted_fields=extracted_fields_serialized(extracted_fields)
)

In [None]:
error_table

In [None]:
extraction_heatmap

In [None]:
ocr_visualizer = hyperwidget.OCRVisualizer(
    document=dl.get_document(4).as_dict_with_images()
)

In [None]:
ocr_visualizer

In [None]:
label_page = dl.get_document(0).pages[0]
print("Chosen Lines: ", [label_page.lines[i] for i in ocr_visualizer.selected_lines])
query = generate_query("Control number", [label_page.lines[ocr_visualizer.selected_lines[0]]], [label_page])
print("Generated Query: ", query)

In [4]:
documents = [dl.get_document(i) for i in range(5)]
documents_dict = [d.as_dict_with_images() for d in documents]
pages = [d.pages[0] for d in documents]
multidoc_gen = hyperwidget.MultiDocGen(
    documents=documents_dict
)

In [5]:
multidoc_gen

MultiDocGen(documents=[{'path': '../data/sample/w2/multi/W2_Multi_Sample_Data_input_ADP1_15541.pdf', 'pages': …

In [6]:
labeled_pages, labeled_lines = [], []
for doc in multidoc_gen.selected_lines:
    doc_idx = int(doc)
    labeled_pages.append(pages[doc_idx])
    labeled_lines.append(pages[doc_idx].lines[multidoc_gen.selected_lines[doc][0]])
query = generate_query("Control number", labeled_lines, labeled_pages)

ZeroDivisionError: division by zero

In [None]:
query