In [None]:
import os
import cv2
import json
import math

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
%matplotlib inline

In [None]:
# ---- Dataset location and annotation loading --------------------------------
path = '/workspace/OCR/datasets/TAL_OCR/190326_4122'
img_path = 'test/'
json_path = 'toc_jsons/'
img_c_path = os.path.join(path, img_path)
json_c_path = os.path.join(path, json_path)
# Sorted so the concatenation order of records (and therefore the order of
# everything printed below) does not depend on the file system.
json_names = sorted(os.listdir(json_c_path))
images_datas = []
for json_name in json_names:
    json_file_path = os.path.join(json_c_path, json_name)
    # Skip sub-directories (e.g. .ipynb_checkpoints) that open() cannot read.
    if not os.path.isfile(json_file_path):
        continue
    # Binary mode: json.load then detects UTF-8/16/32 (incl. a BOM) itself
    # instead of relying on the platform's default text encoding, and the
    # `with` block closes the handle, which the old bare open() leaked.
    with open(json_file_path, 'rb') as json_fp:
        records = json.load(json_fp)
    # Each file must hold a list of per-image records (the old `a + b` also
    # required a list); fail loudly instead of letting extend() iterate a
    # dict's keys.
    if not isinstance(records, list):
        raise TypeError('%s: expected a JSON list, got %s'
                        % (json_file_path, type(records).__name__))
    # extend() is linear; `images_datas = images_datas + ...` re-copied the
    # whole list for every file.
    images_datas.extend(records)

img_size = 512
# Rectangle colour per label, used by cv2.rectangle in the visualisation cell.
# Labels (literal meaning): 横式 horizontal equation, 竖式 vertical (column)
# arithmetic, 脱式 step-by-step calculation, 解式 equation solving,
# 题干 question stem, 答案 answer.
color = {
    '横式': (255, 0, 0),
    '竖式': (255, 255, 0),
    '脱式': (255, 0, 255),
    '解式': (0, 255, 0),
    '题干': (0, 255, 255),
    '答案': (0, 0, 255)
}

# ---- SSD anchor configuration (one entry per feature layer) -----------------
# Every layer satisfies feat_shape * step == img_size (128 * 4 == 512, ...),
# so each layer's anchor-centre grid spans the whole input image.
feat_shapes = [(128, 128), (64, 64), (32, 32), (16, 16), (8, 8), (8, 8), (8, 8)]
anchor_steps = [4, 8, 16, 32, 64, 64, 64]
# Previously tried: the standard SSD-512 pyramid
#   feat_shapes  = [(64, 64), (32, 32), (16, 16), (8, 8), (4, 4), (2, 2), (1, 1)]
#   anchor_steps = [8, 16, 32, 64, 128, 256, 512]
# with linearly spaced scales (smin=0.05, smax=0.6), single sizes
# 7, 14, ..., 448, or size pairs (16, 32), (32, 64), ..., (450, 500).
img_shape = (img_size, img_size)

# (min_size, max_size) in pixels per layer: min_size is the base anchor side,
# and sqrt(min_size * max_size) adds one extra square anchor.
anchor_sizes = [(12, 25),
                (25, 50),
                (50, 100),
                (100, 200),
                (200, 300),
                (300, 400),
                (400, 500)]
# Aspect ratios r per layer, with anchor width ∝ sqrt(r) and height ∝ 1/sqrt(r)
# (r = width / height). The first layer has no tall (0.5) anchor.
anchor_ratios = [[2, 3, 5, 7, 10, 15],
                 [2, .5, 3, 5, 7, 10, 15],
                 [2, .5, 3, 5, 7, 10, 15],
                 [2, .5, 3, 5, 7, 10, 15],
                 [2, .5, 3, 5],
                 [2, .5, 3, 5],
                 [2, .5, 3, 5]]
label_name = ['横式', '竖式', '脱式', '解式']



In [None]:
def is_contain(coord1, coord2):
    """Return True if box ``coord1`` lies inside box ``coord2``.

    Both boxes are ``(ymin, xmin, ymax, xmax)`` sequences. ``coord2`` is grown
    by a fixed slack of 0.03 on every side before the strict comparisons, so a
    box that pokes slightly outside still counts as contained.
    """
    slack = 0.03
    return bool(coord1[0] > coord2[0] - slack
                and coord1[1] > coord2[1] - slack
                and coord1[2] < coord2[2] + slack
                and coord1[3] < coord2[3] + slack)
def get_coord(mark_datas):
    """Turn the raw annotations of one image into labelled detection boxes.

    Coordinates are ``(ymin, xmin, ymax, xmax)`` as returned by ``get_rect``.
    Annotations whose ``marked_path`` is missing or does not split into
    exactly 9 tokens are skipped (in both passes).

    Rules:
      * '脱式' / '解方程' / '解式' are kept as-is.
      * '横式' is kept only when smaller than 0.035 high and 0.25 wide;
        larger ones are returned in ``mask_regions`` instead.
      * '答案' inside (``is_contain``) a '竖式' region becomes a '竖式' box,
        and that region is marked as matched.
      * '题干' inside a '竖式' region becomes a '横式' box.
      * '竖式' regions with no contained '答案' are appended whole at the end.

    Args:
        mark_datas: iterable of dicts with keys 'markd_label' and
            'marked_path'.

    Returns:
        (final_boxes, mask_regions): ``final_boxes`` is a list of
        ``(label, ymin, xmin, ymax, xmax)``; ``mask_regions`` is a list of
        ``(ymin, xmin, ymax, xmax)``.
    """
    def _parse(mark_data):
        # Clamped rect of one annotation, or None if the path is unusable.
        bbox = mark_data['marked_path']
        if bbox is None:
            return None
        bbox = bbox.split()
        if len(bbox) != 9:
            return None
        return get_rect(bbox)

    # Pass 1: every vertical-arithmetic region. Uses the same validation as
    # pass 2 (previously a None/short path crashed here).
    verts = []
    for mark_data in mark_datas:
        if mark_data['markd_label'] == '竖式':
            rect = _parse(mark_data)
            if rect is not None:
                verts.append(tuple(rect))

    final_boxes = []
    mask_regions = []
    matched_verts = set()
    # Pass 2: classify every annotation.
    for mark_data in mark_datas:
        label = mark_data['markd_label']
        if label is None:
            continue
        coord = _parse(mark_data)
        if coord is None:
            continue
        coord = tuple(coord)
        ymin, xmin, ymax, xmax = coord
        if label in ('脱式', '解方程', '解式'):
            final_boxes.append((label, ymin, xmin, ymax, xmax))
        elif label == '横式':
            if ymax - ymin < 0.035 and xmax - xmin < 0.25:
                final_boxes.append((label, ymin, xmin, ymax, xmax))
            else:
                mask_regions.append(coord)
        elif label == '答案':
            for i, vert in enumerate(verts):
                if is_contain(coord, vert):
                    final_boxes.append(('竖式', ymin, xmin, ymax, xmax))
                    matched_verts.add(i)
                    break
        elif label == '题干':
            for vert in verts:
                if is_contain(coord, vert):
                    final_boxes.append(('横式', ymin, xmin, ymax, xmax))
                    break

    # Vertical regions that contained no answer are kept as whole boxes.
    for i, vert in enumerate(verts):
        if i not in matched_verts:
            final_boxes.append(('竖式', vert[0], vert[1], vert[2], vert[3]))
    return final_boxes, mask_regions

In [None]:
def get_rect(bbox):
    """Axis-aligned bounding box of a tokenised four-point path.

    ``bbox`` is the whitespace-split ``marked_path`` string. The x tokens sit
    at even indices 0, 2, 4, 6 and carry a one-character prefix, which is
    stripped. The y tokens sit at odd indices 1, 3, 5, 7. Any token after
    index 7 is ignored.

    Returns:
        ``(ymin, xmin, ymax, xmax)``, where the minima are floored at 0 and
        the maxima are capped at 1.
    """
    xs = [float(bbox[k][1:]) for k in range(0, 8, 2)]
    ys = [float(bbox[k]) for k in range(1, 8, 2)]
    top = max(min(ys), 0)
    left = max(min(xs), 0)
    bottom = min(max(ys), 1)
    right = min(max(xs), 1)
    return top, left, bottom, right

In [None]:
def ssd_anchor_one_layer(img_shape,
                         feat_shape,
                         sizes,
                         ratios,
                         step,
                         offset=0.5,
                         dtype=np.float32):
    """Compute SSD default-box geometry for one feature layer.

    Args:
        img_shape: (height, width) of the network input in pixels.
        feat_shape: (rows, cols) of the feature map.
        sizes: anchor sizes in pixels. ``sizes[0]`` is the base size; if a
            second value is given, one extra square anchor of side
            ``sqrt(sizes[0] * sizes[1])`` is added.
        ratios: aspect ratios r. Each one adds an anchor of width
            ``sizes[0] * sqrt(r)`` and height ``sizes[0] / sqrt(r)``.
        step: distance in pixels between neighbouring anchor centres.
        offset: position of the centre inside a grid cell, as a fraction of
            ``step``; applied to both axes.
        dtype: numpy dtype of the returned arrays.

    Returns:
        y, x: anchor-centre coordinates normalised by image height / width,
            shape ``(rows, cols, 1)``.
        h, w: normalised per-anchor heights / widths, shape
            ``(num_anchors,)`` where
            ``num_anchors = (2 if len(sizes) > 1 else 1) + len(ratios)``.
    """
    y, x = np.mgrid[0:feat_shape[0], 0:feat_shape[1]]
    # Both axes honour `offset` (x previously hard-coded 0.5).
    y = (y.astype(dtype) + offset) * step / img_shape[0]
    x = (x.astype(dtype) + offset) * step / img_shape[1]

    # Trailing axis broadcasts against the per-anchor h / w vectors.
    y = np.expand_dims(y, axis=-1)
    x = np.expand_dims(x, axis=-1)

    # Square anchor(s) first, then one per ratio; follows the original SSD
    # implementation's ordering. Sizing the arrays from the actual number of
    # square anchors avoids a zero-sized phantom anchor when len(sizes) == 1.
    num_square = 2 if len(sizes) > 1 else 1
    num_anchors = num_square + len(ratios)
    h = np.zeros((num_anchors, ), dtype=dtype)
    w = np.zeros((num_anchors, ), dtype=dtype)
    h[0] = sizes[0] / img_shape[0]
    w[0] = sizes[0] / img_shape[1]
    if num_square == 2:
        side = math.sqrt(sizes[0] * sizes[1])
        h[1] = side / img_shape[0]
        w[1] = side / img_shape[1]
    for i, r in enumerate(ratios):
        h[i + num_square] = sizes[0] / img_shape[0] / math.sqrt(r)
        w[i + num_square] = sizes[0] / img_shape[1] * math.sqrt(r)
    return y, x, h, w

In [None]:
def IOU_calculation(predict_box, ground_box):
    """Intersection-over-union of two axis-aligned boxes.

    Both boxes are ``(ymin, xmin, ymax, xmax)`` in the same (continuous)
    units. Elements may be plain numbers or numpy scalars / size-1 arrays.

    Coordinates are continuous, so no "+1 pixel" term is applied. The old
    version added +1 to the intersection but not to the union, which let
    identical boxes score above 1. Without that term IoU is scale invariant,
    so the previous rescaling by 512 is unnecessary.

    Returns:
        IoU in (0, 1] when the boxes overlap with positive area, otherwise -1
        (callers only compare against positive thresholds).
    """
    top = max(ground_box[0], predict_box[0])
    left = max(ground_box[1], predict_box[1])
    bottom = min(ground_box[2], predict_box[2])
    right = min(ground_box[3], predict_box[3])
    inter_h = bottom - top
    inter_w = right - left
    if inter_h > 0 and inter_w > 0:
        inter_area = inter_h * inter_w
        ground_area = (ground_box[2] - ground_box[0]) * (ground_box[3] - ground_box[1])
        predict_area = (predict_box[2] - predict_box[0]) * (predict_box[3] - predict_box[1])
        # inter_area > 0 implies union >= inter_area > 0, so no zero division.
        return inter_area / (ground_area + predict_area - inter_area)
    return -1

In [None]:
def stat_anchors(y, x, h, w, img_size, final_boxes, flag):
    """Flag ground-truth boxes matched by an anchor of one layer.

    For every box not yet flagged, the anchors centred in the grid cell that
    contains the box centre are tried in order. The box is flagged as soon as
    one of them reaches ``IOU_calculation(...) > 0.5``.

    Args:
        y, x: anchor-centre grids indexed as ``y[row, col]`` (shape
            ``(rows, cols, 1)`` from ``ssd_anchor_one_layer``).
        h, w: per-anchor heights / widths, in the same normalisation as the
            boxes.
        img_size: unused; kept for call compatibility.
        final_boxes: sequence of ``(ymin, xmin, ymax, xmax)``.
        flag: per-box match flags; updated in place and returned.

    Returns:
        ``flag``.
    """
    rows, cols = y.shape[0], y.shape[1]
    for i, p_box in enumerate(final_boxes):
        if flag[i]:
            continue
        # Rows and cols are now taken from their own axes (cols used to come
        # from rows). The index is clamped: a centre exactly on the
        # bottom/right border (1.0) indexed one past the grid, and a negative
        # centre wrapped around to the far side.
        y_c = max(0, min(int(0.5 * (p_box[0] + p_box[2]) * rows), rows - 1))
        x_c = max(0, min(int(0.5 * (p_box[1] + p_box[3]) * cols), cols - 1))
        cy, cx = y[y_c, x_c], x[y_c, x_c]
        for k in range(len(h)):
            anchor = [cy - 0.5 * h[k], cx - 0.5 * w[k],
                      cy + 0.5 * h[k], cx + 0.5 * w[k]]
            if IOU_calculation(anchor, p_box) > 0.5:
                flag[i] = True
                break
    return flag

In [None]:
# ---- Anchor recall: fraction of annotated boxes matched by some anchor ------
# Centre offsets tried per layer (the earlier `range(1)` loop only used 0.5).
anchor_offsets = (0.5,)

total_boxes = 0
hit_boxes = 0

for image_datas in tqdm(images_datas):
    # Clamped (ymin, xmin, ymax, xmax) of every labelled, well-formed mark.
    final_boxes = []
    for mark_data in image_datas['mark_datas']:
        if mark_data['label'] == []:
            continue
        bbox = mark_data['marked_path']
        if bbox is None:
            continue
        bbox = bbox.split()
        if len(bbox) != 9:
            continue
        final_boxes.append(get_rect(bbox))
    n = len(final_boxes)
    flag = np.zeros(n)
    # Stop trying further layers once every box has been matched.
    for feat_shape, sizes, ratios, step in zip(feat_shapes, anchor_sizes,
                                               anchor_ratios, anchor_steps):
        for offset in anchor_offsets:
            y, x, h, w = ssd_anchor_one_layer(img_shape, feat_shape, sizes,
                                              ratios, step, offset=offset,
                                              dtype=np.float32)
            flag = stat_anchors(y, x, h, w, img_size, final_boxes, flag)
            if np.sum(flag) == n:
                break
        if np.sum(flag) == n:
            break

    hit = int(np.sum(flag))
    hit_boxes += hit
    total_boxes += n
    # Report images with unmatched boxes that exist in the local image folder.
    if hit != n:
        pic_file = image_datas['pic_name'].split('/')[4]
        if os.path.exists(os.path.join(img_c_path, pic_file)):
            print(pic_file)
# Guard against an empty dataset instead of raising ZeroDivisionError.
print(hit_boxes, total_boxes,
      hit_boxes / total_boxes if total_boxes else float('nan'))


In [None]:
def draw_anchors(img, y, x, h, w, img_size, final_boxes, flag, draw=False,
                 anchor_color=(0, 255, 255)):
    """Flag labelled boxes whose best anchor on one layer has IoU > 0.5.

    Uses the same matching as ``stat_anchors``, but for boxes of the form
    ``[label, ymin, xmin, ymax, xmax]``. For each unflagged box, every anchor
    centred in the grid cell containing the box centre is scored with
    ``IOU_calculation``. The box is flagged when the best score exceeds 0.5.

    Args:
        img: image array; drawn on in place only when ``draw`` is True.
        y, x: anchor-centre grids indexed as ``y[row, col]`` (shape
            ``(rows, cols, 1)`` from ``ssd_anchor_one_layer``).
        h, w: per-anchor heights / widths, in the same normalisation as the
            boxes.
        img_size: pixels per unit coordinate, used only for drawing.
        final_boxes: sequence of ``[label, ymin, xmin, ymax, xmax]``.
        flag: per-box match flags; updated in place and returned.
        draw: if True, outline the best-matching anchor of each newly flagged
            box on ``img`` with ``cv2.rectangle`` (thickness 1).
        anchor_color: colour tuple for those outlines.

    Returns:
        ``flag``.
    """
    rows, cols = y.shape[0], y.shape[1]
    for i, p_box in enumerate(final_boxes):
        if flag[i]:
            continue
        coords = p_box[1:5]
        # Clamp to the grid: a centre on the bottom/right border (1.0) used to
        # index past the last cell, and cols was previously taken from rows.
        y_c = max(0, min(int(0.5 * (coords[0] + coords[2]) * rows), rows - 1))
        x_c = max(0, min(int(0.5 * (coords[1] + coords[3]) * cols), cols - 1))
        cy, cx = y[y_c, x_c], x[y_c, x_c]
        best_iou, best_box = 0.5, None
        for k in range(len(h)):
            anchor = [cy - 0.5 * h[k], cx - 0.5 * w[k],
                      cy + 0.5 * h[k], cx + 0.5 * w[k]]
            iou = IOU_calculation(anchor, coords)
            if iou > best_iou:
                best_iou, best_box = iou, anchor
        if best_box is not None:
            flag[i] = True
            if draw:
                top, left, bottom, right = (
                    int(np.asarray(v).item() * img_size) for v in best_box)
                cv2.rectangle(img, (left, top), (right, bottom), anchor_color, 1)
    return flag

In [None]:
# ---- Visualise the boxes of one image that no anchor matches ----------------
# Image to inspect; the loop stops after the first record with this file name.
target_pic = 'IMG_20181229_122653.jpg'

for iters, image_datas in enumerate(images_datas):
    pic_file = image_datas['pic_name'].split('/')[4]
    if pic_file != target_pic:
        continue
    print(iters, image_datas['pic_name'])

    # Collect [label, ymin, xmin, ymax, xmax] for every usable annotation.
    final_boxes = []
    for mark_data in image_datas['mark_datas']:
        label = mark_data['label']
        if label == []:
            continue
        children = label[0]["children"]
        if '文本' in children or '横式' in children:
            # Plain text ('文本') is drawn with the horizontal-equation label.
            box_label = '横式'
        elif children:
            box_label = children[0]
        else:
            # No sub-label to colour by (previously raised IndexError).
            continue
        bbox = mark_data['marked_path']
        if bbox is None:
            continue
        bbox = bbox.split()
        if len(bbox) != 9:
            continue
        ymin, xmin, ymax, xmax = get_rect(bbox)
        final_boxes.append([box_label, ymin, xmin, ymax, xmax])

    img_file = os.path.join(img_c_path, pic_file)
    img = cv2.imread(img_file)
    if img is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError('cannot read image: %s' % img_file)
    # OpenCV decodes to BGR. Convert before drawing so plt.imshow shows the
    # photo in true colour, while the colour tuples render on screen exactly
    # as they did when the BGR array was displayed unconverted.
    img = cv2.cvtColor(cv2.resize(img, (img_size, img_size)), cv2.COLOR_BGR2RGB)

    flag = np.zeros(len(final_boxes))
    for feat_shape, sizes, ratios, step in zip(feat_shapes, anchor_sizes,
                                               anchor_ratios, anchor_steps):
        y, x, h, w = ssd_anchor_one_layer(img_shape, feat_shape, sizes, ratios,
                                          step, offset=0.5, dtype=np.float32)
        flag = draw_anchors(img, y, x, h, w, img_size, final_boxes, flag)

    # Outline, and report the size of, every box that no layer matched.
    for matched, (box_label, ymin, xmin, ymax, xmax) in zip(flag, final_boxes):
        if matched:
            continue
        box_h = ymax - ymin
        box_w = xmax - xmin
        # Height and width in pixels, then the width / height aspect ratio.
        print(box_h * img_size, box_w * img_size,
              box_w / box_h if box_h else float('inf'))
        # Unknown labels are drawn in white instead of raising KeyError.
        cv2.rectangle(img, (int(xmin * img_size), int(ymin * img_size)),
                      (int(xmax * img_size), int(ymax * img_size)),
                      color.get(box_label, (255, 255, 255)), 1)

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.axis('equal')
    ax.imshow(img)
    ax.set_title('%s: annotated boxes with no anchor IoU > 0.5' % pic_file)
    plt.show()
    break