In [None]:
import sys
import os
import importlib
sys.path.append('..')

import utils
from utils.dataloader import Dataloader
from utils.extract import extract_fields
from utils.rule_gen import generate_query
from utils.evaluator import evaluate_output
from utils.helpers import extracted_fields_serialized

import spacy
from iterextras import par_for

from ipywidgets import Layout
import hyperwidget

In [None]:
def _build_w2_dataloader(main_dir: str, doc_type: str):
    """Create the Dataloader for one W-2 document type stored under `main_dir`.

    Documents are read from `<main_dir>/<doc_type>` and labels from
    `<main_dir>/<doc_type>_label.csv`.
    """
    data_dir = os.path.join(main_dir, doc_type)
    label_path = os.path.join(main_dir, f'{doc_type}_label.csv')
    # Both flags are forwarded to Dataloader unchanged; their exact effect is
    # defined in utils.dataloader (not visible from this notebook).
    return Dataloader(data_dir, label_path, concatenate_pages=True, cache_to_disk=True)


def get_w2_sample_dataloader(doc_type: str, data_root: str = '../data'):
    """Return a Dataloader over the W-2 *sample* set for `doc_type`.

    Args:
        doc_type: Sub-directory / label-file prefix; this notebook uses
            'single' and 'multi'.
        data_root: Directory containing the `sample/` tree. The default is
            relative to the current working directory.

    Returns:
        A Dataloader rooted at `<data_root>/sample/w2`.
    """
    return _build_w2_dataloader(os.path.join(data_root, 'sample', 'w2'), doc_type)


def get_w2_train_dataloader(doc_type: str, data_root: str = '../data'):
    """Return a Dataloader over the full W-2 *train* set for `doc_type`.

    Args:
        doc_type: Sub-directory / label-file prefix; this notebook uses
            'single' and 'multi'.
        data_root: Directory containing the `full/` tree. The default is
            relative to the current working directory.

    Returns:
        A Dataloader rooted at `<data_root>/full/w2/train`.
    """
    return _build_w2_dataloader(os.path.join(data_root, 'full', 'w2', 'train'), doc_type)

# One dataloader per (split, document type) pair, created in the order
# sample/single, sample/multi, train/single, train/multi.
sample_single, sample_multi = [
    get_w2_sample_dataloader(kind) for kind in ('single', 'multi')
]
full_single, full_multi = [
    get_w2_train_dataloader(kind) for kind in ('single', 'multi')
]

# spaCy pipeline that is later passed to extract_fields.
nlp = spacy.load("en_core_web_sm")

In [None]:
# Field-extraction queries that a later cell assigns to `field_queries`.
# Judging by the commented-out examples below, each entry is a dict with a
# "name", an "arguments" dict, and a "weights" dict whose keys are a subset of
# the argument names.
# NOTE(review): every entry is currently commented out, so this list is EMPTY;
# the `fields` list built from it will be empty too and extract_fields will be
# called with no queries. Confirm this is intentional before a full run.
# The "Field N Acc 1.0" remarks presumably record the accuracy reached for
# that field -- TODO confirm.
single_field_queries = [
#     {   # Field 1 Acc 1.0
#         "name": "EIN",
#         "arguments": {
#             "x-position": 0.10,
#             "y-position": 0.05,
#             "entity": "EIN",
#             "word-neighbors": ["Employer Identification Number"],
#             "word-neighbor-max-top-dist": 13,
#             "word-neighbor-max-left-dist": 15,
#             "word-neighbor-max-bottom-dist": 0,
#             "word-neighbor-max-right-dist": 0,
#         },
#         "weights": {
#             "x-position": 0.25,
#             "y-position": 0.25,
#             "entity": 0.25,
#             "word-neighbors": 0.25,
#         }
#     },
#     {   # Field 2 Acc 1.0
#         "name": "Employer's Name",
#         "arguments": {
#             "x-position": 0.22,
#             "y-position": 0.09,
#             "entity": "",
#             "word-neighbors": ["Employer's name, address, and ZIP code"],
#             "word-neighbor-max-top-dist": 15,
#             "word-neighbor-max-left-dist": 150,
#             "word-neighbor-max-bottom-dist": 0,
#             "word-neighbor-max-right-dist": 0,
#         },
#         "weights": {
#             "x-position": 0.33,
#             "y-position": 0.33,
#             "entity": 0.,
#             "word-neighbors": 0.33,
#         }
#     },
#     {   # Field 3 Acc 1.0
#         "name": "Employer's Street Address",
#         "arguments": {
#             "x-position": 0.22,
#             "y-position": 0.16,
#             "entity": "",
#             "word-neighbors": ["Control number", "Employer's name, address, and ZIP code"],
#             "word-neighbor-max-top-dist": 30,
#             "word-neighbor-max-left-dist": 150,
#             "word-neighbor-max-bottom-dist": 50,
#             "word-neighbor-max-right-dist": 0,
#         },
#         "weights": {
#             "x-position": 0.25,
#             "y-position": 0.30,
#             "entity": 0.,
#             "word-neighbors": 0.45,
#         }
#     },
#     {   # Field 4 Acc 1.0
#         "name": "Employer's City-State-Zip",
#         "arguments": {
#             "x-position": 0.22,
#             "y-position": 0.2,
#             "entity": "",
#             "word-neighbors": ["Control number", "Employer's name, address, and ZIP code"],
#             "word-neighbor-max-top-dist": 50,
#             "word-neighbor-max-left-dist": 150,
#             "word-neighbor-max-bottom-dist": 30,
#             "word-neighbor-max-right-dist": 0,
#         },
#         "weights": {
#             "x-position": 0.25,
#             "y-position": 0.30,
#             "entity": 0.,
#             "word-neighbors": 0.45,
#         }
#     },
]

In [None]:
# Choose the dataloader and the query set that the extraction and evaluation
# cells operate on; swap either name here to run a different combination.
dl, field_queries = full_single, single_field_queries

In [None]:
num_docs = len(dl)

# Field names in query order; evaluate_output receives them as its third argument.
fields = [query["name"] for query in field_queries]

# Run extract_fields once per document index through par_for with 8 workers.
# NOTE(review): the positional 1000 and workers=8 are magic numbers whose
# meaning is not visible in this notebook -- consider naming them in a config cell.
extracted_fields = par_for(
    lambda doc_idx: extract_fields(dl.get_document(doc_idx), field_queries, 1000, nlp),
    [*range(num_docs)],
    workers=8,
)

# Labels for the same document indices, in the same order.
labels = list(map(dl.get_label, range(num_docs)))

errors = evaluate_output(extracted_fields, labels, fields)

In [None]:
# Wrap the evaluation errors in hyperwidget's ErrorTable; the bare name on the
# last line makes the table this cell's displayed output.
error_table = hyperwidget.ErrorTable(errors=errors)
error_table