In [1]:
import sys
import os
import importlib
sys.path.append('..')

import utils
from utils.dataloader import Dataloader
from utils.extract import extract_fields
from utils.rule_gen import generate_query
from utils.evaluator import evaluate_output
from utils.helpers import extracted_fields_serialized

from ipywidgets import Layout
import hyperwidget

In [2]:
def get_dataloader():
    """Build a Dataloader over the single-page W2 sample set.

    The documents are read from ``../data/sample/w2/single_clean`` and the
    labels from ``../data/sample/w2/single_label.csv`` (paths are relative to
    the notebook's working directory).
    """
    sample_root = os.path.join('../data', 'sample', 'w2')
    return Dataloader(
        os.path.join(sample_root, 'single_clean'),
        os.path.join(sample_root, 'single_label.csv'),
    )


# Shared loader used by every cell below.
dl = get_dataloader()

In [None]:
def _build_field_query(name, x_position, y_position, entity, word_neighbors):
    """Return one field query dict using the notebook's shared thresholds and weights.

    Every call produces fresh nested dicts/lists, so editing one query never
    affects another.
    """
    return {
        "name": name,
        "arguments": {
            "x-position": x_position,
            "y-position": y_position,
            "entity": entity,
            "word-neighbors": list(word_neighbors),
            "word-neighbor-top-thres": 0.05,
            "word-neighbor-left-thres": 0.1,
        },
        # All four signals are weighted equally.
        "weights": {
            "x-position": 0.25,
            "y-position": 0.25,
            "entity": 0.25,
            "word-neighbors": 0.25,
        },
    }


# Hand-written queries, one per field to extract.
field_queries = [
    # Field 1
    _build_field_query(
        "EIN", 0.1, 0.1, "EIN",
        ["Employer", "Identification", "Number"],
    ),
    # Field 2
    _build_field_query(
        "Medicare Tax withheld", 0.9, 0.1, "CARDINAL",
        ["Medicare", "Tax", "Withheld"],
    ),
]

In [None]:
# Run the field queries over every document in the loader, then collect the
# matching ground-truth labels in the same order.
num_docs = len(dl)
doc_indices = range(num_docs)

extracted_fields = []
for doc_idx in doc_indices:
    # NOTE(review): the meaning of the third argument (1000) is not visible
    # in this notebook — confirm against utils.extract.extract_fields.
    extracted_fields.append(
        extract_fields(dl.get_document(doc_idx), field_queries, 1000)
    )

labels = list(map(dl.get_label, doc_indices))

In [None]:
# Evaluate exactly the fields that were queried. The names are derived from
# `field_queries` so the evaluated list cannot drift out of sync with the
# queries when a field is added or renamed.
evaluated_field_names = [field_query["name"] for field_query in field_queries]
errors = evaluate_output(extracted_fields, labels, evaluated_field_names)

In [None]:
# Widget over the evaluation errors; it is displayed in a later cell.
error_table = hyperwidget.ErrorTable(errors=errors)

In [None]:
# Gather the three inputs of the heatmap widget: serialized documents (with
# images), their labels, and the serialized extraction results.
heatmap_documents = [
    dl.get_document(doc_idx).as_dict_with_images() for doc_idx in range(num_docs)
]
heatmap_labels = [dl.get_label(doc_idx) for doc_idx in range(num_docs)]
heatmap_fields = extracted_fields_serialized(extracted_fields)

extraction_heatmap = hyperwidget.ExtractionHeatmap(
    documents=heatmap_documents,
    labels=heatmap_labels,
    extracted_fields=heatmap_fields,
)

In [None]:
# Bare last expression: renders the error table widget as this cell's output.
error_table

In [None]:
# Bare last expression: renders the extraction heatmap widget as this cell's output.
extraction_heatmap

In [None]:
# Visualize the OCR result of the first document in the loader.
first_document_dict = dl.get_document(0).as_dict_with_images()
ocr_visualizer = hyperwidget.OCRVisualizer(document=first_document_dict)

In [None]:
# Bare last expression: renders the OCR visualizer widget as this cell's output.
# The next cell reads `ocr_visualizer.line_idxs`.
ocr_visualizer

In [None]:
# `label_page` was never defined anywhere in the notebook, so this cell raised
# NameError on a fresh kernel. The OCR visualizer was built from document 0,
# so resolve the selected line indices against that document's first page.
# NOTE(review): assumes `ocr_visualizer.line_idxs` index into `pages[0].lines`
# of document 0 — confirm against the hyperwidget implementation.
label_page = dl.get_document(0).pages[0]

selected_line_idxs = ocr_visualizer.line_idxs
if not selected_line_idxs:
    # Without a selection, `line_idxs[0]` below would fail with an opaque IndexError.
    raise ValueError(
        "No lines selected in ocr_visualizer; select at least one line in the widget first."
    )

print("Chosen Lines: ", [label_page.lines[i] for i in selected_line_idxs])
# Only the first selected line is used as the labeled example for the query.
query = generate_query("Control number", [label_page.lines[selected_line_idxs[0]]], [label_page])
print("Generated Query: ", query)

In [3]:
# Number of documents offered for labeling. Clamped to the dataset size so a
# sample set with fewer documents does not index past the end of the loader.
NUM_LABELING_DOCS = 5
documents = [dl.get_document(i) for i in range(min(NUM_LABELING_DOCS, len(dl)))]
documents_dict = [d.as_dict_with_images() for d in documents]
# First page of each document; a later cell indexes this list by document
# position to look up the lines selected in the widget.
pages = [d.pages[0] for d in documents]
multidoc_gen = hyperwidget.MultiDocGen(
    documents=documents_dict
)

In [4]:
# Bare last expression: renders the multi-document labeling widget as this cell's output.
# A later cell reads `multidoc_gen.selected_lines`.
multidoc_gen

MultiDocGen(documents=[{'path': '../data/sample/w2/single_clean/W2_XL_input_1192.pdf', 'pages': [{'lines': [{'…

In [7]:
# Collect one (page, line) training example per document that has a selection
# in the widget, then generate a query from those examples.
labeled_pages, labeled_lines = [], []
selected_lines = multidoc_gen.selected_lines  # read the widget state once
for doc in selected_lines:
    line_idxs = selected_lines[doc]
    if not line_idxs:
        # A document with an empty selection would raise IndexError on `[0]`
        # below; skip it instead.
        continue
    # Keys are converted with int() to index into `pages`.
    doc_idx = int(doc)
    labeled_pages.append(pages[doc_idx])
    # Only the first selected line of each document is used.
    labeled_lines.append(pages[doc_idx].lines[line_idxs[0]])
query = generate_query("Control number", labeled_lines, labeled_pages)

In [8]:
# Bare last expression: shows the generated query dict as this cell's output.
query

{'name': 'Control number',
 'arguments': {'x-position': 0.8120915032679739,
  'y-position': 0.09848484848484848,
  'entity': 'CARDINAL',
  'word-neighbors': ['Website',
   'Social',
   'tax',
   'Federal',
   'withheld',
   'the',
   'security',
   'Medicare',
   'www.irs.gov/efile.',
   'IRS',
   'at',
   'Visit',
   '2',
   '6',
   'income',
   '4'],
  'word-neighbor-top-thres': 0.05,
  'word-neighbor-left-thres': 0.1},
 'weights': {'x-position': 0.25,
  'y-position': 0.25,
  'entity': 0.25,
  'word-neighbors': 0.25}}