In [1]:
import sys
import os
import importlib
sys.path.append('..')

import utils
from utils.dataloader import Dataloader
from utils.extract import extract_fields
from utils.rule_gen import generate_query
from utils.evaluator import evaluate_output
from utils.helpers import extracted_fields_serialized

import spacy
from iterextras import par_for

from ipywidgets import Layout
import hyperwidget

In [2]:
def get_w2_sample_dataloader(doc_type: str):
    main_dir = os.path.join('../data', 'sample', 'w2')
    data_dir = os.path.join(main_dir, doc_type)
    label_path = os.path.join(main_dir, f'{doc_type}_label.csv')
    return Dataloader(data_dir, label_path, concatenate_pages=True, cache_to_disk=True)

def get_w2_train_dataloader(doc_type: str):
    main_dir = os.path.join('../data', 'full', 'w2', 'train')
    data_dir = os.path.join(main_dir, doc_type)
    label_path = os.path.join(main_dir, f'{doc_type}_label.csv')
    return Dataloader(data_dir, label_path, concatenate_pages=True, cache_to_disk=True)

sample_single = get_w2_sample_dataloader('single')
sample_multi = get_w2_sample_dataloader('multi')
full_single = get_w2_train_dataloader('single')
full_multi = get_w2_train_dataloader('multi')

nlp = spacy.load("en_core_web_sm")

In [21]:
visualizer = hyperwidget.OCRVisualizer(
    document=dl.get_document(8).as_dict_with_images()
)
visualizer

OCRVisualizer(document={'path': '../data/sample/w2/single/W2_XL_input_2807.pdf', 'pages': [{'lines': [{'words'…

In [8]:
page = dl.get_document(0).pages[0]
labeled_pages, labeled_lines = [page], [page[idx] for idx in visualizer.selected_lines]
query = generate_query("TEST", labeled_lines, labeled_pages)
query

{'name': 'TEST',
 'arguments': {'x-position': 0.35587188612099646,
  'y-position': 0.08022922636103152,
  'entity': 'SSN',
  'word-neighbors': [],
  'word-neighbor-max-top-dist': 21.928571428571427,
  'word-neighbor-max-left-dist': 21.928571428571427,
  'word-neighbor-max-bottom-dist': 21.928571428571427,
  'word-neighbor-max-right-dist': 21.928571428571427},
 'weights': {'x-position': 0.25,
  'y-position': 0.25,
  'entity': 0.25,
  'word-neighbors': 0.25}}

EIN,Employer's Name,Employer's Street Address,Employer's City-State-Zip,Employee Social Security Number,Employee Name,Employee Street Address,Employee's City-State-Zip,Control Number,"Wages, Tips & Other Compensation",Federal Income Tax Withheld,Social Security Wages,Social Security Tax Withheld,Medicare Wages & Tips,Medicare Tax Withheld,Social Security Tips,Allocated Tips,Advance EIC Payment,Dependent Care Benefits,Non-qualified Plans,12a Column 1,12a Column 2,12b Column 1,12b Column 2,12c Column 1,12c Column 2,12d Column 1,12d Column 2,Statutary Employee,Retirement Plan,Third Party Sick Pay,State_1,Employee State ID_1,State Wages & Tips_1,State Income Tax_1,Local Wages & Tips_1,Local Income Tax_1,Locality Name_1,State_2,Employee State ID_2,State Wages & Tips_2,State Income Tax_2,Local Wages & Tips_2,Local Income Tax_2,Locality Name_2


In [30]:
single_field_queries = [
    {   # Field 5
        "name": "Employee Social Security Number",
        "arguments": {
            "x-position": 0.36,
            "y-position": 0.08,
            "entity": "SSN",
            "word-neighbors": ["social security number"],
            "word-neighbor-max-top-dist": 100,
            "word-neighbor-max-left-dist": 100,
            "word-neighbor-max-bottom-dist": 0,
            "word-neighbor-max-right-dist": 0,
        },
        "weights": {
            "x-position": 0.,
            "y-position": 0.,
            "entity": 0.,
            "word-neighbors": 0.25,
        }
    },
#         {   # Field x
#         "name": "",
#         "arguments": {
#             "x-position": 0.5,
#             "y-position": 0.5,
#             "entity": "",
#             "word-neighbors": [""],
#             "word-neighbor-max-top-dist": 0,
#             "word-neighbor-max-left-dist": 0,
#             "word-neighbor-max-bottom-dist": 0,
#             "word-neighbor-max-right-dist": 0,
#         },
#         "weights": {
#             "x-position": 0.,
#             "y-position": 0.,
#             "entity": 0.,
#             "word-neighbors": 0.,
#         }
#     },
]
dl = sample_single
field_queries = single_field_queries
fields = [f["name"] for f in field_queries]

In [31]:
num_docs = len(dl)
extracted_fields = par_for(
    lambda i:  extract_fields(dl.get_document(i), field_queries, 1000, nlp),
    list(range(num_docs)),
    workers=2,
)
labels = [dl.get_label(i) for i in range(num_docs)]

errors = evaluate_output(extracted_fields, labels, fields)
error_table = hyperwidget.ErrorTable(
    errors = errors
)

extraction_heatmap = hyperwidget.ExtractionHeatmap(
    documents=[dl.get_document(i).as_dict_with_images() for i in range(num_docs)],
    labels=[dl.get_label(i) for i in range(num_docs)],
    extracted_fields=extracted_fields_serialized(extracted_fields)
)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=9.0), HTML(value='')))


Field: Employee Social Security Number	Accuracy: 1.0


In [32]:
error_table

ErrorTable(errors={'Employee Social Security Number': []})

In [33]:
extraction_heatmap

ExtractionHeatmap(documents=[{'path': '../data/sample/w2/single/W2_XL_input_1192.pdf', 'pages': [{'lines': [{'…

In [34]:
"Employee's social security number" == "Employee's social security number"

True