In [1]:
# Put the notebook's parent directory on the import path.
# NOTE(review): presumably this is what lets the `digi_leap` package imported
# in the next cell resolve -- confirm against the repository layout.
import sys
sys.path.append('..')

In [2]:
import difflib as dl
import re
import statistics as stat
import string
import subprocess
import tempfile
import warnings
from collections import Counter, defaultdict, namedtuple
from dataclasses import dataclass, field
from itertools import groupby, combinations
from pathlib import Path
from pprint import pp
from typing import Union
from types import SimpleNamespace

import numpy as np
import pandas as pd
from Bio import SeqUtils
from Bio.Seq import Seq
from IPython.display import display
from ipywidgets import interact
from PIL import Image, ImageDraw

from digi_leap.pylib import (
    line_align as la,
    db,
    image_util as iu,
    label_transforms as lt,
    line_align_subs as subs,
)

In [3]:
# File-system locations, all relative to the notebook's parent directory.
DATA = Path('..') / 'data' / 'sernec'
# NOTE(review): SHEETS is rebound to a list of sheet records by the following
# cell, so this directory path is not available under this name afterwards.
SHEETS = DATA / 'sheets'
DB = DATA / 'sernec.sqlite'

In [4]:
# Materialize every row returned by db.select_sheets(DB) as a plain dict.
# NOTE(review): this rebinds SHEETS, which the configuration cell defined as
# the sheets directory Path; the name now refers to a list instead.
SHEETS = [dict(s) for s in db.select_sheets(DB)]

In [5]:
# Read all OCR records as dicts, then index them by label:
# OCR maps a label_id to the list of OCR records for that label, in the
# order db.select_ocr(DB) returned them.
ocr = [dict(ocr) for ocr in db.select_ocr(DB)]
OCR = defaultdict(list)  # a missing label_id yields (and stores) an empty list
for o in ocr:
    OCR[o['label_id']].append(o)

In [6]:
# Keep only the label records that have at least one OCR record in OCR.
LABELS = [dict(lb) for lb in db.select_labels(DB) if lb['label_id'] in OCR]

In [7]:
@dataclass
class Row:
    """One line of text assembled from OCR boxes.

    The four edge fields hold the edges of the most recently added box
    (they are overwritten on every update, not merged into a union), ``text``
    accumulates the text of every box added, and ``source`` remembers the
    last non-empty source passed to ``update``.
    """
    left: int = 0
    top: int = 0
    right: int = 0
    bottom: int = 0
    text: list[str] = field(default_factory=list)
    source: tuple[str, str] = field(default_factory=tuple)

    @property
    def center(self) -> int:
        """Return the y coordinate halfway between top and bottom (floored)."""
        return (self.top + self.bottom) // 2

    def update(self, box, source=''):
        """Add a box to the row.

        box: mapping with 'left', 'top', 'right', 'bottom' and 'text' keys.
        source: stored on the row only when it is truthy.
        """
        # Copy the edges one at a time, in the same order as the fields.
        for edge in ('left', 'top', 'right', 'bottom'):
            setattr(self, edge, box[edge])
        self.text.append(box['text'])
        if source:
            self.source = source

In [8]:
@dataclass
class Group:
    """A set of rows, from one or more OCR sources, judged to be the same line.

    The four edge fields hold the edges of the most recently added row (they
    are overwritten on every update). ``rows`` keeps every row added and
    ``aligned`` is filled in elsewhere with the aligned versions of the text.
    """
    left: int = 0
    top: int = 0
    right: int = 0
    bottom: int = 0
    rows: list[Row] = field(default_factory=list)
    aligned: list[str] = field(default_factory=list)

    @property
    def center(self) -> int:
        """Get the vertical center line of the box."""
        return (self.bottom + self.top) // 2

    @property
    def text(self) -> list[str]:
        """Return one string per row: the row's text pieces joined by spaces.

        The annotation used to say ``str``; the value has always been a list.
        """
        return [' '.join(r.text) for r in self.rows]

    def update(self, row):
        """Append a row and take over its left/top/right/bottom edges."""
        self.left = row.left
        self.top = row.top
        self.right = row.right
        self.bottom = row.bottom
        self.rows.append(row)

In [9]:
def get_label(idx):
    """Open the sheet image for LABELS[idx] and cut out the label.

    Returns a ``(label_record, cropped_image)`` pair. The sheet path stored in
    the record is resolved relative to the notebook's parent directory.
    """
    with warnings.catch_warnings():
        # Silence the EXIF-related UserWarnings raised while reading images.
        warnings.filterwarnings("ignore", category=UserWarning)
        record = LABELS[idx]
        sheet_image = Image.open(Path('..') / record['path'])
        bounds = tuple(record[edge] for edge in ("left", "top", "right", "bottom"))
        cropped = sheet_image.crop(bounds)
    return record, cropped

In [10]:
def filter_boxes(boxes, image_height, conf=0.25, height_threshold=0.25, std_devs=2.0):
    """Remove problem bounding boxes from a list of OCR boxes.

    Each box is a mapping with 'left', 'top', 'right', 'bottom', 'text' and
    'conf' keys. The 'text' of every examined box is stripped in place.

    Excuses for removing boxes include:
    - Remove bounding boxes with no text.
    - Remove boxes with a low confidence score (from the OCR engine) for the text.
    - Remove boxes that are too tall relative to the label.
    - Remove boxes that are really skinny or really short.

    Parameters:
        boxes: the boxes to filter.
        image_height: height used as the reference for the "too tall" test.
        conf: boxes with a confidence below this are dropped.
        height_threshold: fraction of image_height at or above which a box is
            considered too tall.
        std_devs: a box is too thin (or too short) when its width (or height)
            is more than this many standard deviations below the mean width
            (or height) of all the boxes passed in.

    Returns:
        A new list with the surviving boxes. When fewer than two boxes are
        given the input is returned unchanged, because a standard deviation
        needs at least two data points.
    """
    if len(boxes) < 2:
        return boxes

    too_tall = round(image_height * height_threshold)

    widths = [b['right'] - b['left'] for b in boxes]
    heights = [b['bottom'] - b['top'] for b in boxes]

    # Widths decide what counts as "thin" and heights what counts as "short".
    # These two were previously crossed, so a box's width was compared against
    # a height-derived limit and vice versa.
    too_thin = round(stat.mean(widths) - (std_devs * stat.stdev(widths)))
    too_short = round(stat.mean(heights) - (std_devs * stat.stdev(heights)))

    filtered = []
    for box, width, height in zip(boxes, widths, heights):
        # Remove boxes with nothing in them
        box['text'] = box['text'].strip()
        if not box['text']:
            continue

        # Remove boxes with low confidence
        if box['conf'] < conf:
            continue

        # Remove boxes that are too tall
        if height >= too_tall:
            continue

        # Remove boxes that are very thin or very short
        if width < too_thin or height < too_short:
            continue

        # It passes
        filtered.append(box)

    return filtered

In [11]:
def transform_label(lb, label):
    """Run the 'deskew' label transform on the image and return it as RGB.

    lb is accepted for signature compatibility with the other per-label
    helpers but is not needed here; the previous body looked up the label's
    OCR records and never used them.
    """
    deskewed = lt.transform_label('deskew', label)
    return deskewed.convert('RGB')

In [12]:
# Outline color used when drawing an OCR box, keyed by the
# (pipeline, engine) pair that produced the box.
COLORS = {
    ('deskew', 'tesseract'): 'red',
    ('deskew', 'easy'): 'blue',
    ('binarize', 'tesseract'): 'green',
    ('binarize', 'easy'): 'orange',
}


def display_boxes(lb, label):
    """Outline every OCR box of the label directly on the image.

    The image is modified in place. Each rectangle is drawn 2 pixels wide in
    the color COLORS assigns to the record's (pipeline, engine) pair; a pair
    missing from COLORS raises KeyError.
    """
    records = OCR[lb['label_id']]
    canvas = ImageDraw.Draw(label)

    for rec in records:
        corners = [rec[edge] for edge in ('left', 'top', 'right', 'bottom')]
        outline = COLORS[(rec['pipeline'], rec['engine'])]
        canvas.rectangle(corners, outline=outline, width=2)

In [13]:
def trim_boxes(lb, label):
    """Tighten the top and bottom of each OCR box of the label, in place.

    For every record the box is cropped out of the image and handed to
    iu.profile_projection. The first and last positions where that profile is
    positive become the new top and bottom offsets.
    NOTE(review): presumably the profile has one entry per pixel row of the
    crop -- confirm against image_util.
    """
    for rec in OCR[lb['label_id']]:
        patch = label.crop([rec['left'], rec['top'], rec['right'], rec['bottom']])

        # Nothing to measure in a degenerate crop.
        if 0 in (patch.size[0], patch.size[1]):
            continue

        profile = iu.profile_projection(patch)
        hits = np.where(profile > 0)
        if not hits or len(hits[0]) == 0:
            continue

        first, last = hits[0][0], hits[0][-1]
        # The bottom must be derived from the old top before the top moves.
        rec['bottom'] = rec['top'] + last
        rec['top'] += first

In [14]:
def line_overlap(top1, bottom1, top2, bottom2, eps=1):
    """Return how much two vertical extents overlap, relative to the shorter one.

    The result is ``intersection / (min_height + eps)`` where intersection is
    the length of the shared vertical span (0 when the extents do not touch)
    and min_height is the height of the shorter extent. eps keeps the
    denominator non-zero for zero-height extents.
    """
    shorter = min(bottom1 - top1, bottom2 - top2)
    y_min = max(top1, top2)
    y_max = min(bottom1, bottom2)
    # Was `y_min - y_min`, which made every overlap come out as zero.
    inter = max(0, y_max - y_min)
    return inter / (shorter + eps)

In [15]:
def get_rows(lb, label, overlap_fraction=0.1):
    """Arrange the label's OCR boxes into rows of text, one set per OCR source.

    Parameters:
        lb: label record; its 'label_id' selects the OCR records to use.
        label: unused here, kept so all per-label helpers share a signature.
        overlap_fraction: minimum line_overlap score for a box to join a row
            whose center line it does not cross.

    Returns:
        A list with one entry per (pipeline, engine) source, in order of first
        appearance. Each entry is a list of Row objects sorted top to bottom.
    """
    lines = []
    ocr = OCR[lb['label_id']]

    # Bucket the records by source. itertools.groupby only merges *adjacent*
    # records with equal keys, so unsorted input would split one source into
    # several row sets; a dict keeps first-seen order and cannot split.
    by_source = {}
    for o in ocr:
        by_source.setdefault((o['pipeline'], o['engine']), []).append(o)

    for source, boxes in by_source.items():
        # NOTE(review): 'height' is presumably the label image height stored
        # on every OCR record -- confirm against the database schema.
        boxes = filter_boxes(boxes, boxes[0]['height'])
        if not boxes:
            continue

        # Build each row left to right
        boxes = sorted(boxes, key=lambda b: b['left'])

        rows: list[Row] = []

        for box in boxes:
            # Rows whose center line passes through this box
            center_line = [r for r in rows if box['top']
                           <= r.center <= box['bottom']]

            # Did we find a row that this box overlaps?
            if len(center_line) == 1:
                center_line[0].update(box)
                continue

            # Is the box big and covers multiple rows?
            if len(center_line) > 1:
                # Pick the row whose center is nearest the box's own center
                mid = (box['top'] + box['bottom']) // 2
                center_line = sorted(
                    center_line, key=lambda r: abs(r.center - mid))
                center_line[0].update(box)
                continue

            # Do we have significant overlap anyway?
            over = [(line_overlap(box['top'], box['bottom'], r.top, r.bottom), r) for r in rows]
            over = sorted(over, key=lambda r: -r[0])
            if over and over[0][0] > overlap_fraction:
                over[0][1].update(box)
                continue

            # We have a new row
            row = Row()
            row.update(box, source=source)
            rows.append(row)

        lines.append(sorted(rows, key=lambda r: r.top))

    return lines

In [16]:
def group_rows(rows, overlap_fraction=0.1):
    """Merge the rows from every OCR source into groups of matching lines.

    Parameters:
        rows: one list of rows per source, as produced by get_rows.
        overlap_fraction: minimum line_overlap score for a row to join a group
            whose center line it does not cross.

    Returns:
        The groups, sorted by their top edge.
    """
    groups = []

    for ln in (row for per_source in rows for row in per_source):
        # Groups whose center line passes through this row
        crossed = [g for g in groups if ln.top <= g.center <= ln.bottom]

        if crossed:
            # With several candidates take the one whose center is nearest the
            # row's own center (the earliest group wins a tie).
            mid = (ln.top + ln.bottom) // 2
            nearest = min(crossed, key=lambda g: abs(g.center - mid))
            nearest.update(ln)
            continue

        # No center line was crossed; fall back to the best vertical overlap.
        scored = [(line_overlap(g.top, g.bottom, ln.top, ln.bottom), g) for g in groups]
        if scored:
            score, best = max(scored, key=lambda pair: pair[0])
            if score > overlap_fraction:
                best.update(ln)
                continue

        # Nothing matched, so this row starts a group of its own.
        fresh = Group()
        fresh.update(ln)
        groups.append(fresh)

    groups.sort(key=lambda g: g.top)
    return groups

In [17]:
def consensus(group):
    """Build one consensus string from the texts of a group.

    The texts are aligned with la.align_all (the result is also stored on
    ``group.aligned``). For every column of the alignment the most frequent
    character wins, with ties going to the smaller character. Every '⋄' is
    removed from the result.
    NOTE(review): '⋄' is presumably the gap marker inserted by the aligner --
    confirm against line_align.

    Returns '' when the group has no text.
    """
    texts = group.text
    if len(texts) < 1:
        return ''

    group.aligned = la.align_all(texts, subs.SUBS)
    width = len(group.aligned[0])

    winners = []
    for col in range(width):
        tally = Counter(line[col] for line in group.aligned)
        # Highest count first, then the smallest character
        char, _ = min(tally.items(), key=lambda pair: (-pair[1], pair[0]))
        winners.append(char)

    return ''.join(winners).replace('⋄', '')

In [18]:
# Placeholders: both tables are empty and nothing in the visible cells reads
# them. NOTE(review): presumably reserved for tie-breaking and character
# replacement rules -- confirm before relying on them.
TIE_BREAKER = {}
REPLACER = {}

# TODO: Add a dictionary check for words where the consensus is weak.


def heuristics(groups):
    """Return a cleaned-up consensus string for each usable group.

    Groups holding exactly one text are skipped. For the others the consensus
    is computed, a single whitespace character between a word character and a
    following punctuation mark is removed, and underscores are deleted.
    """
    space_before_punct = r"(\w)\s([.?!,;:\-'\"])"

    cleaned = []
    for grp in groups:
        if len(grp.text) != 1:
            merged = consensus(grp)
            merged = re.sub(space_before_punct, r'\1\2', merged)
            cleaned.append(merged.replace('_', ''))
    return cleaned

In [20]:
def ocr_label(idx):
    """Show the OCR consensus for the label at LABELS[idx].

    Prints the consensus lines, displays the deskewed label with every OCR box
    outlined, then prints the aligned texts of each group followed by a blank
    line.
    """
    lb, cropped = get_label(idx)
    label = transform_label(lb, cropped)
    # trim_boxes(lb, label) is deliberately left switched off here.

    rows = get_rows(lb, label)
    display_boxes(lb, label)  # draws on `label` in place
    groups = group_rows(rows)
    consensus_lines = heuristics(groups)

    print('\n'.join(consensus_lines))
    display(label)
    for group in groups:
        print('\n'.join(group.aligned), end='\n\n')


# Browse the labels with a slider: idx ranges over every index of LABELS and
# each change re-runs ocr_label. The trailing semicolon hides the cell's repr.
interact(ocr_label, idx=(0, len(LABELS) - 1));
# Non-interactive alternative, useful for a reproducible run:
# ocr_label(0)

interactive(children=(IntSlider(value=5561, description='idx', max=11122), Output()), _dom_classes=('widget-in…