In [1]:
# Mount Google Drive at /content/drive; the input and output folders used later live under it.
# The google.colab module is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install vietocr==0.3.5

Collecting vietocr==0.3.5
[?25l  Downloading https://files.pythonhosted.org/packages/95/c1/c6343a49b124586c0cf3359ee013079c21a3271720fd456ac2b0b11354df/vietocr-0.3.5-py3-none-any.whl (61kB)
[K     |████████████████████████████████| 61kB 3.3MB/s 
Collecting lmdb==1.0.0
[?25l  Downloading https://files.pythonhosted.org/packages/83/32/b809a148132dc76ca9ada176c08c7996f1d36c7cb24dac0907e50e472562/lmdb-1.0.0-cp37-cp37m-manylinux1_x86_64.whl (280kB)
[K     |████████████████████████████████| 286kB 5.8MB/s 
[?25hCollecting gdown==3.11.0
  Downloading https://files.pythonhosted.org/packages/db/f9/757abd4b0ebf60f3d276b599046c515c070fab5161b22abb952e35f3c0a4/gdown-3.11.0.tar.gz
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting einops==0.2.0
  Downloading https://files.pythonhosted.org/packages/89/32/5ded0a73d2e14ef5a6908a930c3e1e9f92ffead482a2f153182b7429066e/einops-0.2.0

In [3]:
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
import os
import imutils
from collections import defaultdict
import string
from openpyxl import Workbook
from PIL import Image
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg

In [4]:
# Start from VietOCR's named 'vgg_transformer' configuration and override a few fields.
config = Cfg.load_config_from_name('vgg_transformer')
# Weights location; the output of the next cell shows them being downloaded from this link.
config['weights'] = 'https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA'
# presumably skips loading separately pretrained CNN backbone weights - TODO confirm
config['cnn']['pretrained']=False
# NOTE(review): hard-coded GPU device; there is no CPU fallback here.
config['device'] = 'cuda:0'
# presumably disables beam-search decoding in the predictor - TODO confirm
config['predictor']['beamsearch']=False

In [5]:
# Build the OCR predictor. get_sheet_cell_and_merge reads this module-level name `detector`.
detector = Predictor(config)

Cached Downloading: /root/.cache/gdown/https-COLON--SLASH--SLASH-drive.google.com-SLASH-uc-QUESTION-id-EQUAL-13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
Downloading...
From: https://drive.google.com/uc?id=13327Y1tz1ohsm5YZMyXVMPIOjoOA0OaA
To: /root/.cache/gdown/tmp9beoh4vq/dl
152MB [00:00, 161MB/s]


In [6]:
def invert_img(img):
    '''
    Binarise a BGR image so that lines and text come out white on a black background.

    Steps: grayscale conversion, bilateral smoothing (d=9, sigmas 15/15),
    inversion, then mean adaptive thresholding (block size 11, constant -4).
    '''
    smoothed = cv2.bilateralFilter(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 9, 15, 15)
    return cv2.adaptiveThreshold(
        cv2.bitwise_not(smoothed),
        255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY,
        11,
        -4,
    )

In [7]:
def get_vertical_lines(img_bin):
    '''
    Keep only the vertical strokes of a binary image.

    Erodes and then dilates (3 iterations each) with a rectangular kernel
    that is 1 px wide and 10 px tall, and returns the resulting image.
    '''
    column_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 10))
    eroded = cv2.erode(img_bin, column_kernel, iterations=3)
    return cv2.dilate(eroded, column_kernel, iterations=3)

In [8]:
def get_horizontal_lines(img_bin):
    '''
    Keep only the horizontal strokes of a binary image.

    Erodes and then dilates (3 iterations each) with a rectangular kernel
    that is 20 px wide and 1 px tall, and returns the resulting image.
    '''
    row_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (20, 1))
    eroded = cv2.erode(img_bin, row_kernel, iterations=3)
    return cv2.dilate(eroded, row_kernel, iterations=3)

In [9]:
def combine_box(box_i, box_j):
    '''
    Return the smallest [x, y, w, h] box that encloses two (x, y, w, h) boxes.
    '''
    pair = (box_i, box_j)
    left = min(b[0] for b in pair)
    top = min(b[1] for b in pair)
    right = max(b[0] + b[2] for b in pair)
    bottom = max(b[1] + b[3] for b in pair)
    return [left, top, right - left, bottom - top]

def combine_point(x1, y1, x2, y2):
    '''
    Build an [x, y, w, h] box from two corner points given in any order.

    A zero width or height is replaced by 1 so the box is never degenerate.
    '''
    left, right = sorted((x1, x2))
    top, bottom = sorted((y1, y2))
    # The difference is never negative here, so `or 1` only replaces an exact zero.
    width = (right - left) or 1
    height = (bottom - top) or 1
    return [left, top, width, height]

In [10]:
def joints_ver_hor_lines(vertical_lines_img, horizontal_lines_img):
    '''
    Merge the vertical-line and horizontal-line images into one grid image.

    Args:
        vertical_lines_img: image of the vertical lines.
        horizontal_lines_img: image of the horizontal lines (same size and type,
            as required by cv2.addWeighted).

    Returns:
        The combined, thickened and Otsu-binarised line image, inverted back
        at the end so it has the same polarity as the inputs.

    The function no longer draws the intermediate image with matplotlib;
    plot the return value at the call site if a preview is wanted.
    '''
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    # addWeighted sums the two images using the given weights (0.5 each).
    blended = cv2.addWeighted(vertical_lines_img, 0.5, horizontal_lines_img, 0.5, 0.0)
    # Invert the image and shrink the white regions, so the (now dark) line
    # strokes come out thicker.
    thickened = cv2.erode(cv2.bitwise_not(blended), kernel, iterations=2)
    # With THRESH_OTSU the threshold is chosen automatically; the returned
    # threshold value itself is not needed.
    _, binarised = cv2.threshold(thickened, 0, 255, cv2.THRESH_OTSU)
    return cv2.bitwise_not(binarised)

In [11]:
def extend_boxes(horizontal_boxes, vertical_boxes, image_h, image_w, scale = 1):
    '''
    Add bridging boxes that close gaps between detected table line boxes.

    Every box is unpacked as (x, y, w, h). New boxes are built with
    combine_point and appended to copies of the input lists; the inputs
    themselves are not mutated.

    Args:
        horizontal_boxes: list of horizontal line boxes. Must be a non-empty
            list (``.copy()`` and ``[-1]`` are used on it).
        vertical_boxes: list of vertical line boxes. Must be a non-empty list
            (``.copy()``, ``[0]`` and ``[-1]`` are used on it).
        image_h: image height, only used when placing the synthetic bottom line.
        image_w: not used in this function.
        scale: not used in this function.

    Returns:
        (horizontal_boxes, vertical_boxes): the extended lists, the first
        sorted by (y, x) and the second sorted by (x, y).
    '''
    INTERSECT_THRESHOLD = 15   # base pixel tolerance; some checks below use 2x or 3x of it
    extended_horizontal_boxes = horizontal_boxes.copy()
    extended_vertical_boxes = vertical_boxes.copy()

    #####################################
    # Step 1: create a new horizontal line connecting the bottom of the outer vertical lines
    horizontal_boxes = sorted(horizontal_boxes, key=lambda k: [k[1], k[0]])
    vertical_boxes = sorted(vertical_boxes, key=lambda k: [k[0], k[1]])
    bottom_horizontal = horizontal_boxes[-1]  # NOTE(review): assigned but never read
    # After sorting by (x, y) the first/last entries have the smallest/largest x.
    most_left_ver = vertical_boxes[0]
    most_right_ver = vertical_boxes[-1]
    check_bottom_hori = False
    for i in range(len(horizontal_boxes)):
        # NOTE(review): the w/h names are swapped in this unpacking; only y1 is read, so it has no effect.
        x1, y1, h1, w1 = horizontal_boxes[i]
        # Is there already a horizontal line near the lower of the two outer verticals' bottom ends?
        if abs(y1 - max(most_left_ver[1] + most_left_ver[3], most_right_ver[1] + most_right_ver[3])) < INTERSECT_THRESHOLD:
            check_bottom_hori = True
            break
    if not check_bottom_hori:
        # No bottom line found: synthesise one at the higher of the two bottom ends.
        y_h = min(most_left_ver[1] + most_left_ver[3], most_right_ver[1] + most_right_ver[3])
        if y_h <= image_h:
            # presumably nudges the line 5 px up to keep it inside the image - TODO confirm
            y_h = y_h - 5
        extended_horizontal_boxes.append(combine_point(most_left_ver[0], y_h, most_right_ver[0], y_h))
    
    ##############################

    # Step 2: create horizontal lines to connect 2 horizontal lines
    horizontal_boxes = sorted(horizontal_boxes, key=lambda k: [k[1], k[0]])
    for i in range(len(horizontal_boxes)-1):
        box1 = horizontal_boxes[i]
        x1, y1, w1, h1 = box1
        for j in range(i+1, len(horizontal_boxes)):
            box2 = horizontal_boxes[j]
            x2, y2, w2, h2 = box2
            # box2 starts within 45 px of box1's right end and at almost the same y:
            # bridge the gap between them, then stop searching for this box1.
            if abs(x2 - x1 - w1) < INTERSECT_THRESHOLD*3  and abs(y2 - y1) < 5:
                extended_horizontal_boxes.append(combine_point(x1 + w1, y1, x2, y2))
                break
    
    ###################################

    # Step 3: create vertical lines to connect 2 vertical lines
    for i in range(len(vertical_boxes)-1):
        box1 = vertical_boxes[i]
        x1, y1, w1, h1 = box1
        for j in range(i+1, len(vertical_boxes)):
            box2 = vertical_boxes[j]
            x2, y2, w2, h2 = box2
            # box2 starts within 45 px of box1's bottom end and within 15 px in x.
            # NOTE(review): the new box runs from y1 (box1's top) to y2, not from y1 + h1
            # as the horizontal case does; it therefore also covers box1 itself.
            if abs(y2 - y1 - h1) < INTERSECT_THRESHOLD*3 and abs (x2 - x1) < 15:
                extended_vertical_boxes.append(combine_point(x1, y1, x1, y2))
                break

    ###################################

    # Step 4: extend horizontal lines till they meet the outer vertical lines.
    # NOTE(review): each branch ends with `break`, so at most ONE horizontal line is
    # extended per call - confirm this is intended.
    for i in range(len(horizontal_boxes)):
        x1, y1, w1, h1 = horizontal_boxes[i]
        ## left: the line (wider than 100 px) starts less than 30 px right of the left-most vertical
        if x1 - most_left_ver[0] < INTERSECT_THRESHOLD*2 and w1 > 100:
            extended_horizontal_boxes.append(combine_point(most_left_ver[0], y1, x1, y1))
            break
        ## right
        # NOTE(review): this compares the line's LEFT edge (x1) with the right-most vertical,
        # not its right end (x1 + w1) - looks unintended, verify.
        if most_right_ver[0] - x1 < INTERSECT_THRESHOLD*2 and w1 > 100:
            extended_horizontal_boxes.append(combine_point(x1, y1, most_right_ver[0], y1))
            break

    # Return the extended lists in the same sort orders used above.
    vertical_boxes = sorted(extended_vertical_boxes, key=lambda k: [k[0], k[1]])
    horizontal_boxes = sorted(extended_horizontal_boxes, key=lambda k: [k[1], k[0]])

    return horizontal_boxes, vertical_boxes

In [12]:
def extract_line_boxes(horizontal_lines, vertical_lines):
    '''
    Turn the two line images into lists of line bounding boxes.

    Contours are located on a copy of each image and converted to
    (x, y, w, h) rectangles. Horizontal boxes are ordered by (y, x) and
    vertical boxes by (x, y). When both lists are non-empty they are passed
    through extend_boxes; otherwise they are returned as found.
    '''
    def _bounding_rects(mask):
        found, _ = cv2.findContours(mask.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        return [cv2.boundingRect(contour) for contour in found]

    height, width = horizontal_lines.shape[:2]
    hori_boxes = sorted(_bounding_rects(horizontal_lines), key=lambda r: [r[1], r[0]])
    ver_boxes = sorted(_bounding_rects(vertical_lines), key=lambda r: [r[0], r[1]])

    if hori_boxes and ver_boxes:
        hori_boxes, ver_boxes = extend_boxes(hori_boxes, ver_boxes, height, width, scale = 1)
    return hori_boxes, ver_boxes


In [13]:
def draw_cell_lines(cell_img, horizontal_boxes, vertical_boxes):
    '''
    Draw every line box as a 1 px white centre line and return the image.

    Horizontal boxes become a horizontal segment through the box's vertical
    middle; vertical boxes become a vertical segment through its horizontal
    middle.
    '''
    white = (255,255,255)

    for x, y, w, h in horizontal_boxes:
        mid_y = int(y+h/2)
        cell_img = cv2.line(cell_img, (x, mid_y), (x+w, mid_y), white, 1)

    for x, y, w, h in vertical_boxes:
        mid_x = int(x+w/2)
        cell_img = cv2.line(cell_img, (mid_x, y), (mid_x, y+h), white, 1)

    return cell_img

In [14]:
def draw_cells(img, cells, font_path='arial.pil'):
    '''
    Return a copy of `img` with every cell outlined and labelled with its index.

    Args:
        img: image array to annotate; it is copied, not modified.
        cells: iterable of (x, y, w, h) cell boxes.
        font_path: unused; kept so existing callers remain valid.

    Returns:
        The annotated copy of the image.
    '''
    COLOR_TEXT = (0, 0, 255)
    THICK_BOX = 2
    img_draw = img.copy()
    for idx, (x, y, w, h) in enumerate(cells):
        cv2.rectangle(img_draw, (x, y), (x + w, y + h), COLOR_TEXT, THICK_BOX)
        cv2.putText(img_draw, str(idx), (x , y), cv2.FONT_HERSHEY_DUPLEX, 0.6,
                    COLOR_TEXT, 1)
    # The former array -> PIL image -> array round trip produced an identical
    # array, so the annotated copy is returned directly.
    return img_draw

In [15]:
def table_segment(contours):
    '''
    Return the bounding box that covers all table-sized contours.

    Only contours whose area exceeds 50000 px are considered.

    Args:
        contours: iterable of OpenCV contours.

    Returns:
        (x_min, y_min, x_max, y_max) over the bounding rectangles of the
        qualifying contours. When no contour qualifies, the empty box
        (0, 0, 0, 0) is returned; slicing an image with it yields an empty
        array, as the previous (10000, 10000, 0, 0) sentinel did.
    '''
    MIN_TABLE_AREA = 50000
    rects = [cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > MIN_TABLE_AREA]
    if not rects:
        return 0, 0, 0, 0
    # Take the extremes directly instead of starting from a magic 10000
    # sentinel, which silently clamped the result.
    x_min = min(x for x, _, _, _ in rects)
    y_min = min(y for _, y, _, _ in rects)
    x_max = max(x + w for x, _, w, _ in rects)
    y_max = max(y + h for _, y, _, h in rects)
    return x_min, y_min, x_max, y_max

In [16]:
def iou_cal(box1, box2):
    '''
    Overlap ratio of two boxes: intersection area divided by the SMALLER box area.

    Despite the name this is not intersection-over-union; a box fully
    contained in another scores 1.0.

    Args:
        box1, box2: (x, y, w, h, area) tuples.

    Returns:
        0.0 when the boxes do not overlap or when the smaller area is not
        positive (previously a ZeroDivisionError), otherwise the ratio.
    '''
    x1, y1, w1, h1, area_1 = box1
    x2, y2, w2, h2, area_2 = box2
    left = max(x1, x2)
    right = min(x1 + w1, x2 + w2)
    top = max(y1, y2)
    bottom = min(y1 + h1, y2 + h2)
    if right < left or bottom < top:
        return 0.0
    smaller_area = min(area_1, area_2)
    if smaller_area <= 0:
        # Degenerate box: there is nothing to overlap with.
        return 0.0
    return (right - left) * (bottom - top) / smaller_area

In [17]:
def resize_image(img):
    '''
    Rescale an image uniformly so that its longer side becomes 2048 px.
    '''
    height, width = img.shape[:2]
    factor = 2048 / max(height, width)
    return cv2.resize(img, None, fx = factor, fy = factor)

In [18]:
def find_cell_bbox(contours, hierarchy, margin_x, margin_y, img = None):
    '''
    Collect the bounding boxes of table cells from a contour hierarchy.

    A contour counts as a cell when its bounding box is larger than 10x10 px,
    it has no child contour (hierarchy[0][idx][2] == -1) and it has a parent
    (hierarchy[0][idx][3] != -1).

    Args:
        contours: contours as returned by cv2.findContours.
        hierarchy: the matching hierarchy array from cv2.findContours.
        margin_x, margin_y: offsets added to every box, to translate from
            crop coordinates to the coordinates of `img`.
        img: optional image; when given, each cell is outlined on it IN PLACE
            with a green 2 px rectangle.

    Returns:
        (list_cell, img): list of (x, y, w, h) tuples and the `img` argument.
    '''
    list_cell = []
    for idx, c in enumerate(contours):
        xc, yc, wc, hc = cv2.boundingRect(c)
        if wc <= 10 or hc <= 10:
            # Too small to be a cell.
            continue
        has_child = hierarchy[0][idx][2] != -1
        has_parent = hierarchy[0][idx][3] != -1
        if has_child or not has_parent:
            continue
        xc, yc = xc + margin_x, yc + margin_y
        if img is not None:
            cv2.rectangle(img,(xc, yc),(xc + wc, yc + hc),(0, 255, 0), 2)
        list_cell.append((xc, yc, wc, hc))
    return list_cell, img

In [19]:
def get_row_col_dict(cell_boxes):
    '''
    Build both coordinate lookups for a list of cell boxes.

    Returns the pair (row dict from get_row_dict, column dict from get_col_dict).
    '''
    return get_row_dict(cell_boxes), get_col_dict(cell_boxes)

In [20]:
def get_row_dict(list_cell):
    '''
    Map spreadsheet row labels ('1', '2', ...) to a y coordinate per row.

    Cells are sorted by (y, x). Whenever the top edges of two consecutive
    cells differ by more than 10 px, the midpoint of those two y values is
    recorded as the next row marker. A final marker is placed at the vertical
    centre of the last cell.

    Args:
        list_cell: list of (x, y, w, h) cell boxes.

    Returns:
        Dict of row label -> y marker, in ascending order; empty for empty input.
    '''
    row_dict = defaultdict()
    if not list_cell:
        return row_dict
    list_cell = sorted(list_cell, key=lambda k: [k[1], k[0]])
    row_count = 1
    # Compare EVERY adjacent pair. Stopping one pair early dropped the marker
    # before the last row whenever that row held a single cell.
    for current, following in zip(list_cell, list_cell[1:]):
        y, y_next = current[1], following[1]
        if y_next - y > 10:
            row_dict[str(row_count)] = int((y_next + y) / 2)
            row_count += 1
    # Marker for the last row: vertical centre of the last cell.
    last = list_cell[-1]
    row_dict[str(row_count)] = int(last[1] + last[3] / 2)
    return row_dict

In [21]:
def _column_letter(index):
    '''Spreadsheet-style column label for a 0-based index: 0 -> 'A', 25 -> 'Z', 26 -> 'AA'.'''
    letters = ''
    index += 1
    while index > 0:
        index, remainder = divmod(index - 1, 26)
        letters = string.ascii_uppercase[remainder] + letters
    return letters


def get_col_dict(list_cell):
    '''
    Map spreadsheet column labels to an x coordinate per column,
    e.g. col_dict = {'A': 51, 'B': 334, 'C': 1064}.

    Cells are sorted by (x, y). Whenever the left edges of two consecutive
    cells differ by more than 10 px, the midpoint of those two x values is
    recorded as the next column marker. A final marker is placed at the
    horizontal centre of the last cell.

    Args:
        list_cell: list of (x, y, w, h) cell boxes.

    Returns:
        Dict of column label -> x marker, in ascending order; empty for empty input.
    '''
    col_dict = defaultdict()
    if not list_cell:
        return col_dict
    list_cell = sorted(list_cell, key=lambda k: [k[0], k[1]])
    col_count = 0
    # Compare EVERY adjacent pair. Stopping one pair early dropped the marker
    # before the last column whenever that column held a single cell.
    for current, following in zip(list_cell, list_cell[1:]):
        x, x_next = current[0], following[0]
        if x_next - x > 10:
            col_dict[_column_letter(col_count)] = int((x_next + x) / 2)
            col_count += 1
    # Marker for the last column: horizontal centre of the last cell.
    last = list_cell[-1]
    col_dict[_column_letter(col_count)] = int(last[0] + last[2] / 2)
    return col_dict

In [22]:
def get_sheet_cell_and_merge(list_cell, row_dict, col_dict, img):
    '''
    OCR every table cell and work out its spreadsheet address and merge range.

    A cell belongs to each row whose marker y lies strictly inside its
    vertical extent, and to each column whose marker x lies strictly inside
    its horizontal extent. Both dicts must iterate in ascending coordinate
    order, because the scan stops at the first marker beyond the cell.

    Args:
        list_cell: list of (x, y, w, h) cell boxes in `img` coordinates.
        row_dict: row label -> y marker.
        col_dict: column label -> x marker.
        img: image array the cell crops are taken from.
            NOTE(review): the crop is handed to PIL unchanged, so the caller
            decides the channel order the OCR model sees.

    Returns:
        (sheet_cell_dict, list_sheet_merge): mapping of top-left address
        (e.g. 'B3') to recognised text, and a list of 'A1:B2'-style ranges for
        cells spanning more than one row or column.

    Uses the module-level `detector` for OCR. Cells that map to no row or no
    column are skipped BEFORE recognition, so no OCR time is spent on them.
    '''
    sheet_cell_dict = defaultdict()
    list_sheet_merge = []
    for x, y, w, h in list_cell:
        merge_row = []
        for row, row_y in row_dict.items():
            if y + h < row_y:
                break
            if y < row_y < y + h:
                merge_row.append(row)
        merge_col = []
        for col, col_x in col_dict.items():
            if x + w < col_x:
                break
            if x < col_x < x + w:
                merge_col.append(col)
        if not merge_row or not merge_col:
            continue

        cell_image = Image.fromarray(img[y:y+h, x:x+w])
        text = detector.predict(cell_image)

        merge_start = merge_col[0] + merge_row[0]
        sheet_cell_dict[merge_start] = text
        if len(merge_row) > 1 or len(merge_col) > 1:
            merge_end = merge_col[-1] + merge_row[-1]
            list_sheet_merge.append('{}:{}'.format(merge_start, merge_end))
    return sheet_cell_dict, list_sheet_merge

In [23]:
def write_sheet(sheet, sheet_cell_dict, list_sheet_merge):
    '''
    Fill a worksheet with recognised text and apply the merge ranges.

    Each key of `sheet_cell_dict` is used as a cell address and its value is
    assigned to that cell; afterwards every range in `list_sheet_merge` is
    passed to `sheet.merge_cells`. The same sheet object is returned.
    '''
    for address in sheet_cell_dict:
        sheet[address] = sheet_cell_dict[address]
    for span in list_sheet_merge:
        sheet.merge_cells(span)
    return sheet

In [24]:
# folder_path = '/content/drive/MyDrive/Colab Notebooks/KTTV'

In [27]:
# Input folder on the mounted Drive; the processing loop below reads every entry in it.
# NOTE(review): hard-coded absolute path - adjust for your own Drive layout.
folder_path = '/content/drive/MyDrive/CourseAI-VBD/CV-VBD/data-KTTV/data-KTTV-Loc/data-KTTV-jpeg'

In [29]:
# Process every page image in `folder_path`: locate the tables, OCR each cell,
# and write one annotated image plus one workbook (one sheet per table) to `out_path`.
plt.rcParams["figure.figsize"] = (50,10)
out_path = "/content/drive/MyDrive/CourseAI-VBD/CV-VBD/data_output"
# Make sure the output folder exists before anything is written to it.
os.makedirs(out_path, exist_ok=True)
# Sorted so that the processing order is the same on every run.
for name in sorted(os.listdir(folder_path)):
    img_path = os.path.join(folder_path, name)
    print(img_path)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None for sub-folders and files it cannot decode.
        print("Skipping unreadable file:", img_path)
        continue
    img = resize_image(img)

    # Trim a 20 px frame from every side of the page.
    page_h, page_w = img.shape[:2]
    img = img[20:page_h-20, 20:page_w-20]
    # Clean RGB copy for OCR: `img` is BGR (cv2.imread) and gets green cell
    # outlines drawn on it below, neither of which should reach the recogniser.
    ocr_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_not = invert_img(img = img)

    vertical_lines = get_vertical_lines(img_not)
    horizontal_lines = get_horizontal_lines(img_not)
    ver_hor_lines = joints_ver_hor_lines(vertical_lines, horizontal_lines)

    ### find the largest bbox region that covers all the tables
    cnts = cv2.findContours(ver_hor_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    x_tl, y_tl, x_br, y_br = table_segment(cnts)
    segmented_table = ver_hor_lines[y_tl:y_br, x_tl:x_br]
    x_crop, y_crop = x_tl, y_tl

    ### find the individual tables inside the bbox above
    area_contours = []
    if segmented_table.size > 0:
        # An empty crop means no table-sized contour was found; skip the search.
        cnts = cv2.findContours(segmented_table, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        area_image = segmented_table.shape[0]*segmented_table.shape[1]
        if len(cnts) == 1:
            area_contours.append(cv2.boundingRect(cnts[0]))
        elif len(cnts) > 1:
            for cnt in cnts:
                tx, ty, tw, th = cv2.boundingRect(cnt)
                # Keep only regions covering more than 1% of the crop.
                if (tw*th)/area_image > 0.01:
                    area_contours.append((tx, ty, tw, th))
            area_contours = sorted(area_contours, key = lambda k: [k[0], k[1]])

    count = 0
    wb = Workbook()
    ws1 = wb.active
    print("Số lượng bảng tìm được:", len(area_contours))
    for contour in area_contours:
        x, y, w, h = contour
        print(x, y, w, h)
        table_bw = segmented_table[y:y+h, x:x+w]
        vert = get_vertical_lines(table_bw)
        hori = get_horizontal_lines(table_bw)
        hori_boxes, vert_boxes = extract_line_boxes(hori, vert)
        # Redraw the (extended) lines on a new black image.
        # NOTE(review): cell_image is never read afterwards - the cell contours
        # below are taken from table_bw. Confirm whether cell_image was meant
        # to be used there.
        cell_image = np.zeros(table_bw.shape,dtype=np.uint8)
        cell_image = draw_cell_lines(cell_image, hori_boxes, vert_boxes)
        # Offsets that translate table-crop coordinates back to page coordinates.
        margin_x = x_crop + x
        margin_y = y_crop + y
        contours, hierarchy = cv2.findContours(table_bw, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
        # Draws the cell outlines onto `img` (the annotated page).
        list_cells, img = find_cell_bbox(contours, hierarchy, margin_x, margin_y, img)
        if len(list_cells) == 0:
            continue
        row_dict, col_dict = get_row_col_dict(list_cells)
        sheet_cell_dict, list_sheet_merge = get_sheet_cell_and_merge(list_cells, row_dict, col_dict, ocr_img)
        # The first table goes to the workbook's default sheet, later ones get new sheets.
        if count > 0:
            sheet = wb.create_sheet("Sheet {}".format(count))
        else:
            sheet = ws1
        sheet = write_sheet(sheet, sheet_cell_dict, list_sheet_merge)
        count+=1
    cv2.imwrite(os.path.join(out_path, name), img)
    # matplotlib expects RGB, so convert the BGR page for display only.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.show()
    wb.save(os.path.join(out_path,"{}.xlsx".format(name.split(".pdf.jpg")[0])))

Output hidden; open in https://colab.research.google.com to view.