In [1]:
import sys
import os
import importlib
sys.path.append('..')

import utils
from utils.dataloader import Dataloader
from utils.extract import extract_fields
from utils.rule_gen import generate_query
from utils.evaluator import evaluate_output

from ipywidgets import Layout
import hyperwidget

In [2]:
def get_dataloader():
    """Build the Dataloader for the sample receipt data.

    The image directory is the ``image`` sub-folder and the label file is
    ``label.csv``, both located under ``../data/sample/receipt`` (relative to
    the current working directory).

    Returns:
        The ``Dataloader`` constructed from those two paths.
    """
    receipt_root = os.path.join('../data', 'sample', 'receipt')
    return Dataloader(
        os.path.join(receipt_root, 'image'),
        os.path.join(receipt_root, 'label.csv'),
    )

# Shared dataloader instance used by every later cell (documents and labels are read through it).
dl = get_dataloader()

In [15]:
# Queries handed to extract_fields, one dict per field to extract.
# Each query carries a field "name", the "arguments" describing the target,
# and per-argument "weights" (0.25 each, summing to 1.0 for both queries).
field_queries = [
    # Query for the "total" field (labelled "Generated query 1").
    {
        "name": "total",
        "arguments": {
            "x-position": 0.6506024096385542,
            "y-position": 0.5963172804532578,
            "entity": "CARDINAL",
            "word-neighbors": ["Total", "Payable"],
            "word-neighbor-top-thres": 0.05,
            "word-neighbor-left-thres": 0.1,
        },
        "weights": {
            "x-position": 0.25,
            "y-position": 0.25,
            "entity": 0.25,
            "word-neighbors": 0.25,
        },
    },
    # Manually written query for the "address" field.
    {
        "name": "address",
        "arguments": {
            "x-position": 0.5,
            "y-position": 0.2,
            "entity": "GPE",
            "word-neighbors": ["Tel"],
            "word-neighbor-top-thres": 0.05,
            "word-neighbor-left-thres": 0.1,
        },
        "weights": {
            "x-position": 0.25,
            "y-position": 0.25,
            "entity": 0.25,
            "word-neighbors": 0.25,
        },
    },
]

In [9]:
# `k` is forwarded unchanged as the third argument of extract_fields.
k = 20
num_docs = len(dl)

# Run every field query against each document, in document-index order.
extracted_fields = [
    extract_fields(dl.get_document(doc_idx), field_queries, k)
    for doc_idx in range(num_docs)
]
# Labels in the same index order as extracted_fields.
labels = list(map(dl.get_label, range(num_docs)))

In [10]:
# Evaluate exactly the fields that have a query. Deriving the names from
# field_queries (currently 'total' and 'address', in that order) keeps this
# list from drifting when a query is added or renamed.
errors = evaluate_output(
    extracted_fields,
    labels,
    [field_query['name'] for field_query in field_queries],
)

Field: total	Accuracy: 0.18181818181818177
Field: address	Accuracy: 0.0


In [11]:
# Widget built from the evaluation errors computed above; displayed in a later cell.
error_table = hyperwidget.ErrorTable(errors=errors)

In [12]:
# First page of every document, converted with as_dict() for the widget.
pages = [
    dl.get_document(doc_idx).pages[0].as_dict()
    for doc_idx in range(num_docs)
]
labels = list(map(dl.get_label, range(num_docs)))

# Mirror extracted_fields (one mapping per document), replacing each
# extracted item with its as_dict() form.
extracted_fields_serializable = [
    {
        field_name: [candidate.as_dict() for candidate in candidates]
        for field_name, candidates in doc_fields.items()
    }
    for doc_fields in extracted_fields
]

# Heatmap widget fed with the pages, their labels, and the dict-converted extractions.
extraction_heatmap = hyperwidget.ExtractionHeatmap(
    pages=pages, labels=labels, extracted_fields=extracted_fields_serializable
)

In [13]:
# Bare expression: renders the ErrorTable widget as this cell's output.
error_table

ErrorTable(errors={'total': [(0, '45.90', '15.90'), (1, 'Total @% supplies: — 30.23', '48.04'), (2, '233 = zo …

In [14]:
# Bare expression: renders the ExtractionHeatmap widget as this cell's output.
extraction_heatmap

ExtractionHeatmap(extracted_fields=[{'total': [{'line': {'height': 33, 'width': 78, 'left': 447, 'top': 826, '…

In [4]:
# OCR visualizer for the first page of document 0; its `line_idxs` attribute
# is read by the query-generation cell.
ocr_visualizer = hyperwidget.OCRVisualizer(page=dl.get_document(0).pages[0].as_dict())

In [5]:
# Bare expression: renders the OCRVisualizer widget as this cell's output.
ocr_visualizer

OCRVisualizer(page={'width': 747, 'height': 1412, 'b64_image': 'iVBORw0KGgoAAAANSUhEUgAAAusAAAWECAIAAAD0nt1rAA…

In [7]:
# Generate a "total" query from the first line chosen in the OCR visualizer.
# NOTE: this cell depends on interactive widget state (ocr_visualizer.line_idxs),
# so its result is not reproducible from code alone.
first_page = dl.get_document(0).pages[0]  # looked up once instead of three times
selected_line_idxs = ocr_visualizer.line_idxs

print("Chosen Lines: ", [first_page.lines[i] for i in selected_line_idxs])

if len(selected_line_idxs) == 0:
    # Without a selection, indexing line_idxs[0] would raise IndexError
    # (this is always the case right after a fresh "Run All").
    query = None
    print("No line selected in the OCR visualizer; select a line and re-run this cell.")
else:
    query = generate_query("total", first_page.lines[selected_line_idxs[0]], first_page)
    print("Generated Query: ", query)

Chosen Lines:  [line(45.90)]
Generated Query:  {'name': 'total', 'arguments': {'x-position': 0.6506024096385542, 'y-position': 0.5963172804532578, 'entity': 'CARDINAL', 'word-neighbors': ['§0.00', '15.90'], 'word-neighbor-top-thres': 0.05, 'word-neighbor-left-thres': 0.1}, 'weights': {'x-position': 0.25, 'y-position': 0.25, 'entity': 0.25, 'word-neighbors': 0.25}}
