In [None]:
import sys
import os
import importlib
sys.path.append('..')

import utils
from utils.dataloader import Dataloader
from utils.extract import extract_fields
from utils.rule_gen import generate_query
from utils.evaluator import evaluate_output

from ipywidgets import Layout
import hyperwidget

In [None]:
def get_dataloader(data_root='../data'):
    """Build a Dataloader for the W-2 "single" sample set.

    Args:
        data_root: Root data directory. Defaults to '../data'. A relative path
            is resolved against the kernel's current working directory.

    Returns:
        A ``Dataloader`` constructed with the directory
        ``<data_root>/sample/w2/single_clean`` and the label file
        ``<data_root>/sample/w2/single_label.csv``.
    """
    w2_sample_dir = os.path.join(data_root, 'sample', 'w2')
    data_dir = os.path.join(w2_sample_dir, 'single_clean')
    label_path = os.path.join(w2_sample_dir, 'single_label.csv')
    return Dataloader(data_dir, label_path)

# Module-level loader used by the cells below (document count, field
# extraction, labels, and the page shown in the OCR visualizer).
dl = get_dataloader()

In [None]:
def _field_query(name, x, y, neighbors, entity="CARDINAL",
                 top_thres=50, left_thres=200, weight=0.25):
    """Return one field-query dict in the schema passed to ``extract_fields``.

    Args:
        name: Field name (e.g. "EIN").
        x: Value for the "x-position" argument.
        y: Value for the "y-position" argument.
        neighbors: Words for the "word-neighbors" argument (copied into a new list).
        entity: Value for the "entity" argument.
        top_thres: Value for "word-neighbor-top-thres"
            (NOTE(review): units presumably pixels; confirm in utils.extract).
        left_thres: Value for "word-neighbor-left-thres" (same caveat as top_thres).
        weight: Weight applied equally to each of the four scored features.

    Returns:
        A freshly built dict with "name", "arguments" and "weights" keys.
        No nested object is shared between calls.
    """
    scored_features = ("x-position", "y-position", "entity", "word-neighbors")
    return {
        "name": name,
        "arguments": {
            "x-position": x,
            "y-position": y,
            "entity": entity,
            "word-neighbors": list(neighbors),
            "word-neighbor-top-thres": top_thres,
            "word-neighbor-left-thres": left_thres,
        },
        "weights": {feature: weight for feature in scored_features},
    }


# One query per target field. Shared defaults (entity, thresholds, equal
# weights) live in _field_query so the fields cannot drift apart.
field_queries = [
    _field_query("EIN", 0.1, 0.1, ["Employer", "Identification", "Number"]),
    _field_query("Medicare Tax withheld", 0.9, 0.1, ["Medicare", "Tax", "Withheld"]),
]

In [None]:
k = 10  # NOTE(review): not referenced by any visible cell; confirm whether it is still needed
num_docs = len(dl)

# Apply every field query to every document. Extraction results and labels
# are both produced in document-index order, so position i in each list
# refers to the same document.
extracted_fields = [
    extract_fields(dl.get_document(doc_idx), field_queries)
    for doc_idx in range(num_docs)
]
labels = list(map(dl.get_label, range(num_docs)))

In [None]:
# Score the extractions against the labels for every queried field. Field
# names come from field_queries, so adding or renaming a query cannot leave
# this list stale.
errors = evaluate_output(
    extracted_fields,
    labels,
    [field_query["name"] for field_query in field_queries],
)

In [None]:
# Page shown in the interactive labeler: page 0 of document 0. Its `.lines`
# are indexed by the widget's `line_idxs` selection in a later cell.
label_page = dl.get_document(0).pages[0]
ocr_visualizer = hyperwidget.OCRVisualizer(
    page=label_page.as_dict(),          # widget receives the page via as_dict()
    layout=Layout(overflow_x='auto'),   # scroll horizontally when the page is wider than the cell
)

In [None]:
# Bare last expression so Jupyter renders the widget. Select the line(s)
# holding the target field here before running the next cell, which reads
# `ocr_visualizer.line_idxs`.
ocr_visualizer

In [None]:
# Turn the user's selection in the visualizer above into a query for the
# "Control number" field.
# NOTE(review): only the first selected line is passed to generate_query.
if not ocr_visualizer.line_idxs:
    # On a fresh Restart & Run All nothing has been clicked yet. Say so
    # clearly instead of failing with an IndexError on `line_idxs[0]`.
    query = None
    print("No lines selected in the OCR visualizer; select a line above and re-run this cell.")
else:
    print("Chosen Lines: ", [label_page.lines[i] for i in ocr_visualizer.line_idxs])
    query = generate_query("Control number", label_page.lines[ocr_visualizer.line_idxs[0]], label_page)
    print("Generated Query: ", query)