In [1]:
import numpy as np 
import matplotlib.pyplot as plt 
import pytesseract
from PIL import Image
import cv2
import pandas as pd 
import skimage
from pathlib import Path

## Dataset paths and loading

In [2]:
from common import *

  if len(img.shape) is 3 and not isinstance(padColor, (list, tuple, np.ndarray)): # color image but only one color provided


## Load sample image

In [3]:
im, data = load_sample(image_id=2)

## Drawing funcs

In [4]:
# Linetypes
def draw_pipelines(image, data=data):
    draw = image.copy()
    solid_lines = np.stack(data["lines"].query("type=='solid'")["box"])
    dashed_lines = np.stack(data["lines"].query("type=='dashed'")["box"])

    draw = cv2.drawContours(draw, solid_lines.reshape(-1,2,2), -1, (255, 255, 0), thickness=2)
    draw = cv2.drawContours(draw, dashed_lines.reshape(-1,2,2), -1, (0, 255, 255), thickness=2)
    return draw

def draw_symbols(image, data=data, color=None, thickness=2):
    draw = image.copy()
    for i, group in data["symbols"].groupby("class"):
        color_ = color or (np.random.rand(3)*255).astype(np.uint8)
        symbols = np.stack(group["box"])
        draw_rects(draw, symbols, color=[int(c) for c in color_], thickness=thickness)
    return draw

def draw_text_boxes(image, data=data, color=(255,0,255), thickness=1):
    draw = image.copy()
    text_boxes = np.stack(data["words"]["box"])
    draw_rects(draw, text_boxes, color=color, thickness=thickness)
    return draw

In [45]:
%matplotlib tk


# im = cv2.imread("test.jpg")
draw = im.copy()
draw = draw_pipelines(draw)
draw = draw_symbols(draw)
draw = draw_text_boxes(draw)
plt.imshow(draw)

<matplotlib.image.AxesImage at 0x250dcfbb700>

## Text removal

In [6]:
from craft_text_detector import (
    load_craftnet_model,
    load_refinenet_model,
    get_prediction,
)

# load models
refine_net = load_refinenet_model(cuda=True)
craft_net = load_craftnet_model(cuda=True)

In [7]:
window = np.array(im.shape[:2])/5
wh, ww = window
sh, sw = (window/2).astype(int)

In [47]:
from skimage.util import view_as_windows

gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
t, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
t = view_as_windows(thresh, (wh,ww), (sh,sw))

offsets = np.zeros((*t.shape[:2],2))
for i in range(t.shape[0]):
    for j in range(t.shape[1]):
        offsets[i,j] = (sw*j, sh*i)

# plt.imshow(thresh)
# plt.plot(offsets.reshape(-1,2)[:,0],offsets.reshape(-1,2)[:,1],"rx")

plt.imshow(t[0,0])
t.shape

(9, 9, 810, 1051)

In [9]:
outputs = []
for window in t.reshape(-1,int(wh),int(ww)):
    prediction_result = get_prediction(
        image=window,
        craft_net=craft_net,
        refine_net=refine_net,
        cuda=True,
        poly=False
    )
    outputs.append(prediction_result)

In [48]:
fig, axs = plt.subplots(1,2, sharex=True, sharey=True)
# plt.imshow(thresh)
draw1 = cv2.cvtColor(thresh,cv2.COLOR_GRAY2RGB)
draw2 = cv2.cvtColor(thresh,cv2.COLOR_GRAY2RGB)


offs_boxes = []
for offs, output in zip(offsets.reshape(-1,2), outputs):
    if len(output["boxes"])>0:
        boxes = output["boxes"] + offs

        offs_boxes.append(boxes)
        draw1 = cv2.drawContours(draw1,boxes[...,np.newaxis,:].astype(int),-1, (0,0,0), thickness=-1)
        draw2 = cv2.drawContours(draw2,boxes[...,np.newaxis,:].astype(int),-1, (255,0,0), thickness=5)

axs[0].imshow(draw1)
axs[1].imshow(draw2)

<matplotlib.image.AxesImage at 0x250c67667c0>

In [11]:
r = np.vstack(offs_boxes)[4, 0::2].astype(int)
plt.imshow(thresh[rect_to_slice(r)])

<matplotlib.image.AxesImage at 0x250c315a910>

In [12]:
boxes = np.vstack(offs_boxes)
print(len(boxes))
boxes_nms = non_max_suppression_fast(boxes[:,0::2].reshape(-1,4), overlapThresh=0.4)
print(boxes_nms.shape)


522
(153, 4)


In [14]:
fig, axs = plt.subplots(1,2, sharex=True, sharey=True)
# for psm in (0, 1, 3, 4, 5, 6, 7, 11, 12, 13):
psm = 7
draw2 = cv2.cvtColor(thresh,cv2.COLOR_GRAY2RGB)
text_cleanup = thresh.copy()

boxes_filtered = []

for i, r in enumerate(boxes_nms.reshape(-1,2,2)):
    crop = im[rect_to_slice(r, margin=5)]

    h, w = crop.shape[:2]

    tall = h > 1.3*w
    if tall:
        crop = cv2.rotate(crop, cv2.ROTATE_90_CLOCKWISE)


    try:
        text = pytesseract.image_to_string(crop, config=f"--oem 3 --psm {psm}")
    except pytesseract.TesseractError:
        print("Oopsie from tesseract")
        break
    
    if len(text)>0:
        alpha_percent = alpha_count(text) / len(text)

        if alpha_percent < 0.4 :
            draw_rects(draw2, r, (255,0,0), thickness=5)
        else:
            draw_rects(draw2, r, (0,255,0), thickness=5)
            draw_rects(text_cleanup, r, 0, thickness=-1)
            boxes_filtered.append(r.flatten())

        cv2.putText(draw2, text.strip(), r[0], cv2.FONT_HERSHEY_PLAIN, 2, (0,255,255))


axs[0].imshow(draw2)
axs[1].imshow(text_cleanup)

<matplotlib.image.AxesImage at 0x250e18d5430>

In [20]:

# boxA = (Ax1,Ay1,Ax2,Ay2)
# boxB = (Bx1,By1,Bx2,By2)
def boxesIntersect(boxA, boxB):
    if boxA[0] > boxB[2]:
        return False  # boxA is right of boxB
    if boxB[0] > boxA[2]:
        return False  # boxA is left of boxB
    if boxA[3] < boxB[1]:
        return False  # boxA is above boxB
    if boxA[1] > boxB[3]:
        return False  # boxA is below boxB
    return True

def getIntersectionArea(boxA, boxB):
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    # intersection area
    return (xB - xA + 1) * (yB - yA + 1)

def getUnionAreas(boxA, boxB, interArea=None):
    area_A = getArea(boxA)
    area_B = getArea(boxB)
    if interArea is None:
        interArea = getIntersectionArea(boxA, boxB)
    return area_A + area_B - interArea

def getArea(box):
    return (box[2] - box[0] + 1) * (box[3] - box[1] + 1)

def iou(boxA, boxB):
    # if boxes dont intersect
    if not boxesIntersect(boxA, boxB):
        return 0
    interArea = getIntersectionArea(boxA, boxB)
    union = getUnionAreas(boxA, boxB, interArea=interArea)
    # intersection over union
    iou = interArea / union
    assert iou >= 0
    return iou

In [83]:
gts = np.stack(data["words"]["box"])
preds = np.stack(boxes_filtered)

h,w,_ = im.shape
gts = gts[(gts[:,0]< w) & (gts[:,0]> 0) ]
gts = gts[(gts[:,1]< h) & (gts[:,1]> 0) ]

print("gts:")
print(gts.shape)
print("preds:")
print(preds.shape)

gts:
(149, 4)
preds:
(137, 4)


In [84]:
thresh = 0.3
ious = np.zeros((len(preds), len(gts)))

# intersections = ious > thresh

In [85]:
for i, pred_box in enumerate(preds):
    for j, gt_box in enumerate(gts):
        ious[i,j]=iou(pred_box, gt_box)

TP = np.any(ious>thresh, axis=1)
FP = ~TP
FN = ~np.any(ious>thresh, axis=0)

recall = np.sum(TP) / (np.sum(FN) + np.sum(TP))
print(recall)

precision = np.sum(TP) / (np.sum(FP) + np.sum(TP))
print(precision)

0.8733333333333333
0.9562043795620438
