In [1]:
import sys
import os
import importlib
sys.path.append('..')

import utils
from utils.dataloader import Dataloader
from utils.extract import extract_fields
from utils.rule_gen import generate_query
from utils.evaluator import evaluate_output

from ipywidgets import Layout
import hyperwidget

In [2]:
def get_dataloader():
    """Build the ``Dataloader`` for the W-2 sample set.

    The loader is given the ``single_clean`` directory and the
    ``single_label.csv`` file that live under ``../data/sample/w2``. Both
    paths are relative, so they resolve against the current working directory.

    Returns:
        The ``Dataloader`` constructed from those two paths.
    """
    sample_root = os.path.join('../data', 'sample', 'w2')
    return Dataloader(
        os.path.join(sample_root, 'single_clean'),
        os.path.join(sample_root, 'single_label.csv'),
    )


# Shared loader instance used by the cells below.
dl = get_dataloader()

In [None]:
def _make_field_query(name, x_position, y_position, entity, word_neighbors):
    """Assemble one field query dict with the shared thresholds and weights.

    Only the name, the two position values, the entity tag and the neighbour
    words vary between fields. The neighbour thresholds (0.05 top, 0.1 left)
    and the four equal weights (0.25 each) are the same for every query.
    A fresh ``weights`` dict is created per call so queries never share
    mutable state.
    """
    return {
        "name": name,
        "arguments": {
            "x-position": x_position,
            "y-position": y_position,
            "entity": entity,
            "word-neighbors": word_neighbors,
            "word-neighbor-top-thres": 0.05,
            "word-neighbor-left-thres": 0.1,
        },
        "weights": {
            "x-position": 0.25,
            "y-position": 0.25,
            "entity": 0.25,
            "word-neighbors": 0.25,
        },
    }


# One query per field to extract; the order here is the order handed to
# extract_fields.
field_queries = [
    # Field 1
    _make_field_query(
        "EIN", 0.1, 0.1, "EIN",
        ["Employer", "Identification", "Number"],
    ),
    # Field 2
    _make_field_query(
        "Medicare Tax withheld", 0.9, 0.1, "CARDINAL",
        ["Medicare", "Tax", "Withheld"],
    ),
]

In [None]:
# Third argument handed to extract_fields for every document
# (presumably the number of candidates kept per field — TODO confirm).
k = 100
num_docs = len(dl)
doc_indices = range(num_docs)

# Run all field queries against each document, in dataloader order.
extracted_fields = [
    extract_fields(dl.get_document(doc_idx), field_queries, k)
    for doc_idx in doc_indices
]
# Labels in the same order, so labels[i] pairs with extracted_fields[i].
labels = list(map(dl.get_label, doc_indices))

In [None]:
# Evaluate exactly the fields that were queried. The names are read from
# field_queries rather than retyped, so this list cannot drift out of sync
# when a query is added, removed or renamed.
field_names = [field_query["name"] for field_query in field_queries]
errors = evaluate_output(extracted_fields, labels, field_names)

In [None]:
# Wrap the evaluation result in the ErrorTable widget; it is displayed by a
# later cell.
error_table = hyperwidget.ErrorTable(errors=errors)

In [None]:
# First page of every document, converted with as_dict() for the widget.
# Deliberately NOT called `pages`: a later cell binds `pages` to the page
# objects themselves and reads `.lines` from them, so reusing the name here
# would break that cell whenever this one is re-run afterwards.
heatmap_pages = [dl.get_document(i).pages[0].as_dict() for i in range(num_docs)]
# Labels are re-read from the dataloader, using the same call as the
# extraction cell.
labels = [dl.get_label(i) for i in range(num_docs)]
# Mirror extracted_fields, replacing every extracted field object by its
# as_dict() form. The key variable is named `field_name` to avoid shadowing
# the module-level `k` used by the extraction cell.
extracted_fields_serializable = [
    {
        field_name: [field.as_dict() for field in candidates]
        for field_name, candidates in fields_by_name.items()
    }
    for fields_by_name in extracted_fields
]

extraction_heatmap = hyperwidget.ExtractionHeatmap(
    pages=heatmap_pages,
    labels=labels,
    extracted_fields=extracted_fields_serializable,
)

In [None]:
# Display the ErrorTable widget as this cell's output.
error_table

In [None]:
# Display the ExtractionHeatmap widget as this cell's output.
extraction_heatmap

In [None]:
# Show the first page of the first document in the OCR visualizer. The widget
# receives the page in its as_dict() form.
first_page = dl.get_document(0).pages[0]
ocr_visualizer = hyperwidget.OCRVisualizer(page=first_page.as_dict())

In [None]:
# Display the OCRVisualizer widget; a later cell reads its `line_idxs`.
ocr_visualizer

In [None]:
# The page whose lines were selected: the same page that was handed to the
# OCR visualizer (first page of document 0). It is bound here so the cell no
# longer relies on a `label_page` left over from an earlier kernel session.
label_page = dl.get_document(0).pages[0]

# `line_idxs` is read from the widget; presumably it holds the indices of the
# lines picked interactively — TODO confirm against hyperwidget.
chosen_lines = [label_page.lines[i] for i in ocr_visualizer.line_idxs]
print("Chosen Lines: ", chosen_lines)

if chosen_lines:
    # Only the first chosen line is used as the labelled example.
    query = generate_query("Control number", [chosen_lines[0]], [label_page])
    print("Generated Query: ", query)
else:
    # With no selection (for example right after Restart & Run All) there is
    # nothing to generalize from; say so instead of raising IndexError.
    print("No line selected in the OCR visualizer; select a line and re-run this cell.")

In [7]:
# Upper bound on how many documents are handed to the multi-document widget.
MAX_MULTIDOC_PAGES = 5

# First page of each of the leading documents, kept as page objects because a
# later cell reads `.lines` from them. The count is clamped to the dataset
# size so a smaller dataset does not index past its end.
pages = [
    dl.get_document(i).pages[0]
    for i in range(min(MAX_MULTIDOC_PAGES, len(dl)))
]
# The widget receives the as_dict() form of each page.
pages_dict = [page.as_dict() for page in pages]
multidoc_gen = hyperwidget.MultiDocGen(pages=pages_dict)

In [8]:
# Display the MultiDocGen widget; a later cell reads its `selected_lines`.
multidoc_gen

MultiDocGen(pages=[{'width': 1236, 'height': 1658, 'b64_image': 'iVBORw0KGgoAAAANSUhEUgAABNQAAAZ6CAIAAACbo6xpA…

In [9]:
# Collect one labelled line per document from the widget selection, together
# with the page it came from, then generate a query from those examples.
labeled_pages, labeled_lines = [], []
print(multidoc_gen.selected_lines)
# selected_lines maps a document index (as a string key) to a list of line
# indices, e.g. {'0': [7], '1': [6]}.
for doc_key, line_idxs in multidoc_gen.selected_lines.items():
    if not line_idxs:
        # A document with nothing selected contributes no example; skipping
        # it avoids an IndexError on the [0] lookup below.
        continue
    page = pages[int(doc_key)]
    labeled_pages.append(page)
    # Only the first selected line of each document is used.
    labeled_lines.append(page.lines[line_idxs[0]])
query = generate_query("Control number", labeled_lines, labeled_pages)

{'0': [7], '1': [6], '2': [6], '3': [7], '4': [8]}
{'SSN': 2.0, 'CARDINAL': 3.0}


In [10]:
# Show the query produced by generate_query in the previous cell.
query

{'name': 'Control number',
 'arguments': {'x-position': 0.3131067961165049,
  'y-position': 0.020989143546441498,
  'entity': 'CARDINAL',
  'word-neighbors': ["Employee's", '{a', 'social', 'security', 'number'],
  'word-neighbor-top-thres': 0.05,
  'word-neighbor-left-thres': 0.1},
 'weights': {'x-position': 0.25,
  'y-position': 0.25,
  'entity': 0.25,
  'word-neighbors': 0.25}}