In [29]:
import sys
import os
import importlib

# Add the parent directory to the import path so the project-local `utils`
# package (presumably located there) can be imported from this notebook.
sys.path.append('..')

import utils
from utils.dataloader import Dataloader
# `import utils` alone does not bind submodules as attributes of the package,
# so import the ones we reload explicitly. Without this, `utils.extract` /
# `utils.evaluator` can raise AttributeError on a fresh kernel.
import utils.extract
import utils.evaluator

# Reload before the `from ... import` so edits made to these modules while the
# kernel is running are picked up by the names bound below.
importlib.reload(utils.extract)
from utils.extract import extract_fields
from utils.rule_gen import generate_query
importlib.reload(utils.evaluator)
from utils.evaluator import evaluate_output

from ipywidgets import Layout
import hyperwidget

In [11]:
def get_dataloader():
    """Build a Dataloader over the W-2 sample set.

    Documents are read from ``../data/sample/w2/single_clean`` and labels from
    ``../data/sample/w2/single_label.csv`` (paths relative to the working
    directory).
    """
    sample_root = os.path.join('../data', 'sample', 'w2')
    return Dataloader(
        os.path.join(sample_root, 'single_clean'),
        os.path.join(sample_root, 'single_label.csv'),
    )

# Shared loader instance used by the cells below.
dl = get_dataloader()

In [26]:
def _build_field_query(name, x_position, y_position, neighbors):
    """Return one field-query dict with the settings shared by every field here.

    All queries in this notebook look for a ``CARDINAL`` entity, use the same
    word-neighbor thresholds (top 50, left 200) and weight the four criteria
    equally at 0.25 each; only the name, the expected position and the
    neighboring words differ per field.
    """
    criteria = ("x-position", "y-position", "entity", "word-neighbors")
    return {
        "name": name,
        "arguments": {
            "x-position": x_position,
            "y-position": y_position,
            "entity": "CARDINAL",
            "word-neighbors": list(neighbors),
            "word-neighbor-top-thres": 50,
            "word-neighbor-left-thres": 200,
        },
        # A fresh dict per query so the entries never share mutable state.
        "weights": dict.fromkeys(criteria, 0.25),
    }


field_queries = [
    # Field 1
    _build_field_query("EIN", 0.1, 0.1, ["Employer", "Identification", "Number"]),
    # Field 2
    _build_field_query("Medicare Tax withheld", 0.9, 0.1, ["Medicare", "Tax", "Withheld"]),
]

In [30]:
k = 10
num_docs = len(dl)
doc_indices = range(num_docs)

# Run every field query against each document, in document order.
extracted_fields = [
    extract_fields(document, field_queries)
    for document in map(dl.get_document, doc_indices)
]
# Ground-truth labels, aligned index-for-index with `extracted_fields`.
labels = list(map(dl.get_label, doc_indices))

In [24]:
# Evaluate exactly the fields that were queried. Deriving the names from
# `field_queries` keeps this list from drifting out of sync when a query is
# added, removed, or renamed.
evaluated_field_names = [field_query["name"] for field_query in field_queries]
errors = evaluate_output(extracted_fields, labels, evaluated_field_names)

Field: EIN	Accuracy: 0.7777777777777778
Field: Medicare Tax withheld	Accuracy: 0.8888888888888888


In [28]:
# Dump the error collection returned by evaluate_output, keyed by field name.
# NOTE(review): each entry looks like (document index, extracted text,
# expected label) — confirm against utils.evaluator.
print(errors)

{'EIN': [(7, '199 Brent Row Suite 392', '95-5783877'), (8, '353 Sanders Fork', '55-0290753')], 'Medicare Tax withheld': [(2, '6 Medicare tax withheld', '1191.23')]}


In [3]:
# Visualize the first page of the first document in the OCR widget.
first_document = dl.get_document(0)
label_page = first_document.pages[0]

page_payload = label_page.as_dict()
widget_layout = Layout(overflow_x='auto')
ocr_visualizer = hyperwidget.OCRVisualizer(page=page_payload, layout=widget_layout)

In [4]:
# Display the widget. The next cell reads `ocr_visualizer.line_idxs`,
# presumably filled in by selecting lines in this widget — confirm.
ocr_visualizer

OCRVisualizer(layout=Layout(overflow_x='auto'), page={'width': 1228, 'height': 1636, 'b64_image': 'iVBORw0KGgo…

In [5]:
# Lines currently referenced by the widget's `line_idxs`.
chosen_lines = [label_page.lines[i] for i in ocr_visualizer.line_idxs]
print("Chosen Lines: ", chosen_lines)

if chosen_lines:
    # Build a query for the "Control number" field from the first chosen line.
    query = generate_query("Control number", chosen_lines[0], label_page)
    print("Generated Query: ", query)
else:
    # With an empty selection (e.g. after Restart & Run All) indexing
    # `line_idxs[0]` would raise IndexError; report it instead and keep
    # `query` defined.
    query = None
    print("No line selected in the OCR visualizer; select a line and re-run this cell.")

Chosen Lines:  [line(9854906)]
Generated Query:  {'name': 'Control number', 'arguments': {'x-position': 0.09201954397394137, 'y-position': 0.7121026894865525, 'entity': 'DATE', 'word-neighbors': ['‘d', 'Control', 'number', '9', 'Advance', 'EIC', 'payment', 'Shelly', 'Holmes', 'e', "Employee's", 'first', 'name', 'and', 'initial'], 'word-neighbor-top-thres': 0.05, 'word-neighbor-left-thres': 0.1}, 'weights': {'x-position': 0.5, 'y-position': 0.2, 'entity': 0.5, 'word-neighbors': 0.2}}
