In [None]:
import os
import shutil
import sys
import time

import cv2
import numpy as np
import tensorflow as tf

sys.path.append("/home/tal-c/Src/text-detection-ctpn")
from nets import model_train as model
from utils.rpn_msr.proposal_layer import proposal_layer
from utils.text_connector.detectors import TextDetector

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Directory holding the CTPN TensorFlow checkpoint that detect_line restores.
CHECKPOINT_PATH = "/home/tal-c/Src/text-detection-ctpn/checkpoints_mlt"

def resize_image(img, min_side=600, max_side=1200, stride=16):
    """Rescale an image for the detector and round its size up to a multiple of ``stride``.

    The short side is scaled to ``min_side`` unless that would push the long side past
    ``max_side``. In that case the long side is scaled to ``max_side`` instead. Each
    resulting dimension is then rounded up to the next multiple of ``stride``. Dimensions
    that are already multiples are left unchanged.

    Args:
        img: image array whose first two axes are height and width.
        min_side: target length of the short side.
        max_side: upper bound on the long side.
        stride: both output dimensions are made divisible by this value.

    Returns:
        (resized, (rh, rw)): the resized image and the actual height/width ratios
        new/old. These differ slightly from the scale factor because of the stride rounding.
    """
    h, w = img.shape[0], img.shape[1]
    short_side = min(h, w)
    long_side = max(h, w)

    # Align the short side to min_side, capped so the long side stays within max_side.
    scale = float(min_side) / float(short_side)
    if np.round(scale * long_side) > max_side:
        scale = float(max_side) / float(long_side)
    new_h = int(h * scale)
    new_w = int(w * scale)

    # Make both sides divisible by `stride`. Use `%` (not `//`) so sizes that are
    # already aligned are not padded by an extra stride.
    new_h = new_h if new_h % stride == 0 else (new_h // stride + 1) * stride
    new_w = new_w if new_w % stride == 0 else (new_w // stride + 1) * stride

    re_im = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    return re_im, (new_h / h, new_w / w)


def detect_line(raw_image, checkpoint_path=None):
    """Detect text lines in one image with the CTPN model.

    A fresh TF1 graph is built and the checkpoint is restored on every call. This is simple,
    but slow when looping over many images.

    Args:
        raw_image: H x W x 3 image array. test_func passes cv2.imread's image with the
            channel axis reversed (i.e. RGB).
        checkpoint_path: directory containing a TensorFlow checkpoint state file;
            defaults to CHECKPOINT_PATH.

    Returns:
        (boxes, img, (rh, rw)):
            boxes: integer array built from TextDetector.detect output, one row per line.
                Elsewhere in this notebook rows are read as 8 corner coordinates followed by
                a score; the int cast truncates that score (TODO confirm).
            img: the resized image that the box coordinates refer to.
            (rh, rw): resized/original height and width ratios from resize_image.

    Raises:
        FileNotFoundError: if no checkpoint is registered under ``checkpoint_path``.
    """
    if checkpoint_path is None:
        checkpoint_path = CHECKPOINT_PATH

    # get_checkpoint_state returns None when there is no checkpoint; fail with a clear
    # message instead of an AttributeError on None.
    ckpt_state = tf.train.get_checkpoint_state(checkpoint_path)
    if ckpt_state is None or not ckpt_state.model_checkpoint_path:
        raise FileNotFoundError("No TensorFlow checkpoint found in {}".format(checkpoint_path))
    model_path = os.path.join(checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))

    tf.reset_default_graph()
    with tf.get_default_graph().as_default():
        input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
        input_im_info = tf.placeholder(tf.float32, shape=[None, 3], name='input_im_info')

        global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)

        bbox_pred, _, cls_prob = model.model(input_image)

        # Map each variable to its moving-average shadow so the averaged weights are restored.
        variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            img, (rh, rw) = resize_image(raw_image)
            h, w, c = img.shape
            im_info = np.array([h, w, c]).reshape([1, 3])
            bbox_pred_val, cls_prob_val = sess.run(
                [bbox_pred, cls_prob],
                feed_dict={input_image: [img], input_im_info: im_info},
            )

            # Column 0 of each text segment is its score, columns 1-4 its box.
            textsegs, _ = proposal_layer(cls_prob_val, bbox_pred_val, im_info)
            scores = textsegs[:, 0]
            textsegs = textsegs[:, 1:5]

            textdetector = TextDetector(DETECT_MODE='O')
            boxes = textdetector.detect(textsegs, scores[:, np.newaxis], img.shape[:2])
            # `np.int` was only an alias of the builtin `int` and was removed in NumPy 1.24.
            boxes = np.array(boxes, dtype=int)

    return boxes, img, (rh, rw)



In [None]:
import det_layout

In [None]:
def first_tag_merge(boxes, gap_factor=1.5):
    """Merge vertically adjacent boxes into one enclosing quadrilateral per group.

    Boxes are ordered by their top-left y. Two consecutive boxes belong to the same group
    when their vertical centres are less than ``gap_factor`` times the mean box height
    apart. Every box ends up in exactly one group. An isolated box forms its own group
    instead of being dropped.

    Args:
        boxes: sequence of rows laid out as [x1, y1, x2, y2, x3, y3, x4, y4, ...].
            Corners 1/2 are the top edge and 3/4 the bottom edge (1 and 4 on the left).
        gap_factor: merge threshold as a multiple of the mean box height.

    Returns:
        ``boxes`` itself when it is empty. Otherwise a list of 9-element arrays
        [x1, y1, x2, y2, x3, y3, x4, y4, 0], one per group. Each corner takes the outermost
        value over the group, and the score is reset to 0.
    """
    if len(boxes) == 0:
        return boxes

    boxes_np = np.array(boxes)
    center_y = (boxes_np[:, 5] + boxes_np[:, 1]) / 2.0
    height_mean = np.mean(boxes_np[:, 5] - boxes_np[:, 1])
    order = np.argsort(boxes_np[:, 1])

    # Walk boxes top to bottom. A small centre gap extends the current group;
    # a large one starts a new group, so no box is lost.
    groups = [[boxes_np[order[0]]]]
    for prev_idx, cur_idx in zip(order[:-1], order[1:]):
        if center_y[cur_idx] - center_y[prev_idx] < height_mean * gap_factor:
            groups[-1].append(boxes_np[cur_idx])
        else:
            groups.append([boxes_np[cur_idx]])

    final_result = []
    for group in groups:
        g = np.array(group)
        final_result.append(np.array([
            np.min(g[:, 0]), np.min(g[:, 1]),  # top-left: leftmost x, topmost y
            np.max(g[:, 2]), np.min(g[:, 3]),  # top-right: rightmost x, topmost y
            np.max(g[:, 4]), np.max(g[:, 5]),  # bottom-right: rightmost x, lowest y
            np.min(g[:, 6]), np.max(g[:, 7]),  # bottom-left: leftmost x, lowest y
            0,
        ]))
    return final_result
            
    
            
            
        

In [None]:
# tag_analyst only keeps a coverage run as a tag column when its index span exceeds this.
SHORT_TAG = 20

def box_transfer(line_boxes):
    """Summarise each 8-point line box by its left/right edge midpoints and mean size.

    Args:
        line_boxes: 2-D array with one row per box, laid out as
            [x1, y1, x2, y2, x3, y3, x4, y4, ...]. Corners 1 and 4 form the left edge,
            2 and 3 the right edge.

    Returns:
        list with one [left_x, left_y, right_x, right_y, width, height, angle] entry per
        row. Width and height are averaged over the two opposite edges; angle is always 0.0.
    """
    def _summarise(box):
        left_x = (box[0] + box[6]) / 2
        right_x = (box[2] + box[4]) / 2
        left_y = (box[1] + box[7]) / 2
        right_y = (box[3] + box[5]) / 2
        mean_height = ((box[7] - box[1]) + (box[5] - box[3])) / 2.0
        mean_width = ((box[2] - box[0]) + (box[4] - box[6])) / 2.0
        return [left_x, left_y, right_x, right_y, mean_width, mean_height, 0.0]

    return [_summarise(line_boxes[row]) for row in range(line_boxes.shape[0])]

def _column_coverage(width, lefts, rights):
    """Per-column count of boxes whose open interval (left, right) strictly contains the column index."""
    cols = np.arange(width)[:, np.newaxis]
    return ((cols > lefts) & (cols < rights)).sum(axis=1).astype(float)


def _tag_columns(coverage, min_run):
    """Find runs of consecutive columns sharing one coverage level whose index span exceeds min_run.

    Levels are scanned from 1 up to the maximum coverage; within a level, runs are reported
    left to right. Returns {id: [first_col, last_col, span]} with ids numbered from 1, where
    span = number of columns in the run - 1.
    """
    tags = {}
    top_level = int(coverage.max()) if coverage.size else 0
    for level in range(1, top_level + 1):
        cols = np.flatnonzero(coverage == level)
        if cols.size == 0:
            continue
        # A jump of more than one column ends a run. Each new run starts at the element
        # right after the jump, not at the last column of the previous run.
        breaks = np.flatnonzero(np.diff(cols) != 1) + 1
        for run in np.split(cols, breaks):
            span = len(run) - 1
            if span > min_run:
                tags[len(tags) + 1] = [run[0], run[-1], span]
    return tags


def _guess_left_to_right(width, centers):
    """Vote on a left-to-right layout from how many box centres fall strictly inside each third of the width."""
    bounds = [0, width // 3, width * 2 // 3, width]
    left, middle, right = (
        int(np.count_nonzero((centers > lo) & (centers < hi)))
        for lo, hi in zip(bounds[:-1], bounds[1:])
    )
    if middle < left and middle < right:
        # Fewer centres in the middle third than on either side: not left-to-right.
        return False
    return middle + right > left


def tag_analyst(image, line_boxes, left_slack=21):
    """Split detected text lines into first-column tags, second-column tags and the rest.

    Heuristic, as implemented:
      1. For every image column x, count how many line boxes strictly span it
         (left_x < x < right_x, using box_transfer's left/right edge midpoints).
      2. For each coverage level k = 1..max, take the columns covered by exactly k boxes
         and split them into runs of consecutive columns. Runs whose span exceeds
         SHORT_TAG become candidate tag columns.
      3. The two candidates with the smallest horizontal centre (ties broken by start)
         become tag column 1 and tag column 2. A line whose left x lies in
         [start - left_slack, end] of column 1 goes to "1"; otherwise, if it lies in that
         range for column 2, it goes to "2"; otherwise it goes to "others". The "1" boxes
         are then merged with first_tag_merge.
      4. Independently, box centres are counted per horizontal third of the image to guess
         a left-to-right layout.

    Args:
        image: image array; only ``image.shape[1]`` (the width) is used.
        line_boxes: 2-D numpy array, one row per line [x1, y1, x2, y2, x3, y3, x4, y4, ...].
            Rows placed in "others" are the caller's rows, not copies.
        left_slack: how many pixels left of a tag column's start a line may begin and
            still be assigned to that column.

    Returns:
        (tag, final_boxes, ltor):
            tag: dict id -> [start_col, end_col, span] of every candidate tag column.
            final_boxes: dict with keys "1", "2" and "others" holding box rows.
            ltor: bool layout guess; forced to False when fewer than two candidates exist.
    """
    line_det_result = box_transfer(line_boxes)
    width = image.shape[1]
    lefts = np.array([det[0] for det in line_det_result], dtype=float)
    rights = np.array([det[2] for det in line_det_result], dtype=float)

    coverage = _column_coverage(width, lefts, rights)
    tag = _tag_columns(coverage, SHORT_TAG)
    sorted_tag = sorted(tag.items(), key=lambda d: ((d[1][0] + d[1][1]) / 2.0, d[1][0]))
    ltor = _guess_left_to_right(width, (lefts + rights) / 2.0)

    final_boxes = {"1": [], "2": [], "others": []}

    if len(sorted_tag) < 2:
        # NOTE(review): degenerate layout. Only the first detected line is kept (as a "1"
        # box); all other lines are dropped from every category and ltor is forced to
        # False. Behaviour kept as-is; confirm this is intended.
        if len(line_boxes) > 0:
            final_boxes["1"].append(line_boxes[0])
        return tag, final_boxes, False

    first_start, first_end = sorted_tag[0][1][0], sorted_tag[0][1][1]
    second_start, second_end = sorted_tag[1][1][0], sorted_tag[1][1][1]
    for row, left_x in zip(line_boxes, lefts):
        if first_start - left_slack <= left_x <= first_end:
            final_boxes["1"].append(row.copy())
        elif second_start - left_slack <= left_x <= second_end:
            final_boxes["2"].append(row.copy())
        else:
            final_boxes["others"].append(row)

    final_boxes["1"] = first_tag_merge(final_boxes["1"])

    return tag, final_boxes, ltor

def get_interval_stats(bboxes):
    """Mean and standard deviation of the vertical gaps between consecutive text lines.

    Lines are ordered by vertical centre (y1 + y3) / 2. For each neighbouring pair, the gap is
    |bottom(upper line) - top(lower line)|, with top = mean(y1, y2) and
    bottom = mean(y3, y4). Because of the absolute value, overlapping lines also produce a
    positive gap.

    Args:
        bboxes: array-like of rows [x1, y1, x2, y2, x3, y3, x4, y4, ...]; nested lists
            are accepted as well as numpy arrays.

    Returns:
        (interval_mean, interval_std). Both are nan when fewer than two boxes are given,
        since there is no gap to measure.
    """
    bboxes_np = np.asarray(bboxes)
    if bboxes_np.ndim != 2 or bboxes_np.shape[0] < 2:
        return float("nan"), float("nan")

    upper = (bboxes_np[:, 1] + bboxes_np[:, 3]) / 2.0
    bottom = (bboxes_np[:, 5] + bboxes_np[:, 7]) / 2.0
    center_y = (bboxes_np[:, 5] + bboxes_np[:, 1]) / 2.0

    order = np.argsort(center_y)
    # Gap between each line's bottom and the next line's top, in top-to-bottom order.
    intervals = np.abs(bottom[order[:-1]] - upper[order[1:]])

    return np.mean(intervals), np.std(intervals)
    
    
    
    
    

In [None]:
# Smoke check: a single box has nothing to merge with, so first_tag_merge must return it
# unchanged (its score column, already 0 here, is reset to 0).
test_case = [[32, 195, 129, 200, 128, 231, 30, 226, 0]]

merged_test_case = first_tag_merge(test_case)
assert len(merged_test_case) == 1
assert np.array_equal(merged_test_case[0], test_case[0])
merged_test_case


In [None]:
import det_layout


# Empty classification result (same keys as tag_analyst's final_boxes). Kept at module
# level for code that reads it; test_func itself returns fresh copies.
test_final_boxes = {"1": [], "2": [], "others": []}


def _empty_final_boxes():
    """Return a fresh empty classification result so callers never share one mutable dict."""
    return {"1": [], "2": [], "others": []}


def test_func(img_path, output_dir="../det_result", show=True):
    """Detect, classify and visualise the text lines of one image.

    Runs detect_line and det_layout.det_layout on the image, then splits the lines with
    tag_analyst. It draws the result (tag boxes, their labels, the left-to-right guess and
    the blank strip below the lowest line), optionally shows it, and saves it to
    ``output_dir``.

    Args:
        img_path: path of the image to analyse.
        output_dir: directory the visualisation is written to (created if missing).
        show: display the visualisation inline with matplotlib.

    Returns:
        (score, final_boxes, ltor, white_space_length, interval_mean, interval_std).
        For an unreadable file: ([0.0] * 5, empty boxes, None, None, 0.0, 0.0).
        When no line is detected: (score, empty boxes, None, None, 0.0, 0.0).
    """
    bgr_img = cv2.imread(img_path)
    if bgr_img is None:
        return [0.0, 0.0, 0.0, 0.0, 0.0], _empty_final_boxes(), None, None, 0.0, 0.0

    # cv2.imread returns BGR; the detector is fed the channel-reversed (RGB) copy.
    rgb_img = bgr_img[:, :, ::-1].copy()
    test_boxes, test_img, _ = detect_line(rgb_img)

    _, score, overall_box = det_layout.det_layout(test_img)
    if test_boxes.shape[0] == 0:
        return score, _empty_final_boxes(), None, None, 0.0, 0.0

    interval_mean, interval_std = get_interval_stats(test_boxes)

    # The lowest text line has the largest mean bottom y (y3, y4); its lower bottom corner is the text base.
    bottoms = (test_boxes[:, 5] + test_boxes[:, 7]) / 2.0
    lowest = test_boxes[np.argmax(bottoms)]
    tmp_y_base = max(lowest[5], lowest[7])
    # NOTE(review): det_layout's box format is not visible here; presumably overall_box[1][1]
    # is the bottom y of the layout region — confirm.
    canvas_base = overall_box[1][1]
    white_space_length = canvas_base - tmp_y_base

    src_img = test_img.copy()
    _, tmp_final_boxes, ltor = tag_analyst(src_img, test_boxes.copy())

    show_img = src_img.copy()
    # OpenCV points are (x, y), so the strip's right edge is the image width (shape[1]).
    cv2.rectangle(show_img, (0, int(tmp_y_base)), (test_img.shape[1], int(canvas_base)), color=(0, 0, 255))
    cv2.putText(show_img, "l2r:" + str(ltor), (30, 30), cv2.FONT_HERSHEY_COMPLEX, 1, (255, 0, 0), 2)
    for label, boxes in tmp_final_boxes.items():
        for box in boxes:
            cv2.polylines(show_img,
                          [box[:8].astype(np.int32).reshape((-1, 1, 2))],
                          True,
                          color=(0, 255, 0),
                          thickness=2)
            cv2.putText(show_img, str(label), (int(box[0]), int(box[1])),
                        cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 0, 0), 2)

    if show:
        fig, ax = plt.subplots(figsize=(16, 24), dpi=50)
        ax.imshow(show_img)
        plt.show()
        plt.close(fig)  # avoid accumulating open figures when called in a loop

    os.makedirs(output_dir, exist_ok=True)
    # show_img is RGB but cv2.imwrite expects BGR; convert so saved colours match the display.
    cv2.imwrite(os.path.join(output_dir, os.path.basename(img_path)),
                cv2.cvtColor(show_img, cv2.COLOR_RGB2BGR))
    return score, tmp_final_boxes, ltor, white_space_length, interval_mean, interval_std

    

In [None]:
import pandas as pd

data_dir = "/home/tal-c/Src/text-detection-ctpn/data/MindMap"
# Each source image is copied into <tag_dir>/True when some lines were not assigned to tag
# column 1 or 2 ("others" non-empty), otherwise into <tag_dir>/False.
tag_dir = "/home/tal-c/Src/text-detection-ctpn/data/tag"

det_result = []
color_result = []  # not populated by this loop

# Sorted so the CSV row order is reproducible across runs / filesystems.
for img in sorted(os.listdir(data_dir)):
    tmp_path = os.path.join(data_dir, img)
    print(tmp_path)
    # s: det_layout scores (named by `heading` below); fb: tag_analyst classification;
    # lr: left-to-right guess; wpl: white-space length; imean/istd: line-gap statistics.
    s, fb, lr, wpl, imean, istd = test_func(tmp_path)
    det_result.append([img, s[0], s[1], s[2], s[3], s[4],
                       len(fb["1"]), len(fb["2"]), len(fb["others"]),
                       lr, wpl, imean, istd])

    timg = cv2.imread(tmp_path)
    if timg is None:
        # Unreadable / non-image file: it still gets a CSV row, but cv2.imwrite would
        # raise on an empty image and abort the whole run.
        continue
    bucket_dir = os.path.join(tag_dir, "True" if len(fb["others"]) > 0 else "False")
    os.makedirs(bucket_dir, exist_ok=True)
    cv2.imwrite(os.path.join(bucket_dir, img), timg)

heading = ["img_name", "area_p", "relative_p", "box_count", "width_std", "height_mean",
           "1stTag", "2ndTag", "3rdTag", "l2r", "white_space_height", "interval_mean", "interval_std"]
df = pd.DataFrame(det_result, columns=heading)
df.to_csv("det_result_l2r.csv")