In [None]:
# Expose only GPU index 0 to CUDA through the process environment.
import os

os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [None]:
import sys

# Make the per-task source trees importable. The entries are relative, so they
# are resolved against the kernel's current working directory.
# NOTE(review): text_detection is inserted at the front of sys.path while the
# others are appended -- presumably so its modules win any name clash; confirm
# before reordering.
sys.path.insert(0, "./text_detection/src")
sys.path.append("./classification/src")
sys.path.append("./detection/src")
sys.path.append("./text_recognition/src")

import numpy as np
import pandas as pd
import os

# Project-local model wrappers (importable only after the sys.path edits above).
from classification.core import ClassificationModel
from detection.core import ObjectDetectionModel
from text_recognition.core import TextRecognitionModel
from text_detection.core import TextDetectionModel
from postprocessing.core import Postprocessing



In [None]:
# Collect the validation images.
# os.listdir() returns entries in arbitrary order, so the list is sorted to keep
# index-based browsing reproducible between runs. The extension is matched as a
# suffix so that names merely containing ".jpg" (e.g. "a.jpg.bak") are skipped.
image_folder = "./data/validation/images"
image_paths = sorted(
    os.path.join(image_folder, x)
    for x in os.listdir(image_folder)
    if x.endswith(".jpg")
)

In [None]:
# --- Model configurations ---------------------------------------------------
# Three ResNet-50 classifiers that differ only in class count and weights file.
graph_classfication_config = {
    "model_name": "resnet50",
    "n_classes": 5,
    "weights_path": "./weights/graph_classification.pth",
}

x_type_classification_config = {
    "model_name": "resnet50",
    "n_classes": 2,
    "weights_path": "./weights/x_type_classification.pth",
}

y_type_classification_config = {
    "model_name": "resnet50",
    "n_classes": 2,
    "weights_path": "./weights/y_type_classification.pth",
}

# Keypoint / label detector. Later cells select class id 3 as x-label boxes and
# class id 4 as y-label boxes, which matches the positions in "classes" below.
keypoint_detection_config = {
    "name": "keypoint_detection",
    "experiment_path": "./detection/src/exps/example/custom/bmga.py",
    "weights_path": "./weights/keypoint_detection.pth",
    "classes": ["value", "x", "y", "x_label", "y_label"],
    "conf_thre": 0.15,
    "nms_thre": 0.25,
    "test_size": (640, 640),
}

# Text detector. The config path is repository-relative like every other path in
# this notebook (it used to be an absolute path into one user's home directory).
text_detection_config = {
    "weights_path": "./weights/synthtext_finetune_ic19_res50_dcn_fpn_dbv2",
    # "config_path": "./text_detection/src/experiments/seg_detector/totaltext_resnet50_deform_thre.yaml",
    "config_path": "./text_detection/src/experiments/ASF/td500_resnet50_deform_thre_asf.yaml",
    "image_short_side": 2048,
    "thresh": 0.1,
    "box_thresh": 0.05,
    "resize": False,
    "polygon": True,
}

text_recognition_config = {
    "weights_path": "baudm/parseq",
    "model_name": "parseq",
}


# --- Model instances --------------------------------------------------------
classification_model = ClassificationModel(**graph_classfication_config)
x_type_classification_model = ClassificationModel(**x_type_classification_config)
y_type_classification_model = ClassificationModel(**y_type_classification_config)
keypoint_detection_model = ObjectDetectionModel(**keypoint_detection_config)
# The text detector is not instantiated in this cell.
# text_detection_model = TextDetectionModel(**text_detection_config)
text_recognition_model = TextRecognitionModel(**text_recognition_config)

In [None]:
# Load the ground-truth annotations (one JSON object per line) and index them by
# their "file_name" field. The path is repository-relative, like image_folder.
import json

with open("./data/validation/metadata.jsonl", "r", encoding="utf-8") as f:
    # Blank lines (e.g. a trailing newline) are skipped; json.loads rejects them.
    metadata = [json.loads(line) for line in f if line.strip()]

metadata_dict = {x["file_name"]: x for x in metadata}

In [None]:
# Run the keypoint/label detector over every validation image.
# As indexed by the cells below: results[0][i][0] is a tensor-like object (it has
# .cpu()) of detections for image i, where columns 0-3 are used as box
# coordinates and column 6 as the class id; results[1][i] is a dict that holds at
# least "ratio", "height" and "width".
keypoint_detection_results = keypoint_detection_model.predict(image_paths)

In [None]:
from matplotlib import pyplot as plt
import cv2

In [None]:
# ------ CURRENT ANALYSIS ------
# y column: nearly always correct -> outlier detection (draw a line: take the line that crosses the most points -> then filter by str(number))
#                                 -> we then map these boxes to the y points; if a y point is missing, we calculate it from the nearest y points
#
# x column: if there is no overlap -> always correct
#           if there is overlap -> we need to use db/another model to predict

# horizontal bar chart: usually wrong

In [None]:
# Browse the keypoint-detection output one image at a time: every run of this
# cell steps one image backwards (wrapping around) and draws the filtered
# x-label boxes of that image.
# NOTE: filter_x_boxes / filter_y_boxes are defined in a later cell, which must
# have been executed before this one.

# On a fresh kernel `idx` does not exist yet; default to 1 so that the first run
# shows image 0 instead of raising NameError.
idx = (globals().get("idx", 1) - 1) % len(image_paths)

image_info = keypoint_detection_results[1][idx]
data = keypoint_detection_results[0][idx][0].cpu().numpy()
# Column 6 is the class id (3 = x label, 4 = y label). Dividing by "ratio"
# presumably maps the boxes back to original-image pixels -- confirm against the
# detector's preprocessing.
x_label_boxes = (data[data[:, 6] == 3][:, :4] / image_info["ratio"]).astype(int)
y_label_boxes = (data[data[:, 6] == 4][:, :4] / image_info["ratio"]).astype(int)

x_label_boxes = filter_x_boxes(
    x_label_boxes,
    image_info["height"],
    image_paths[idx],
)

y_label_boxes = filter_y_boxes(
    y_label_boxes,
    image_info["width"],
    image_paths[idx],
)

# Draw the surviving x-label boxes.
image = cv2.imread(image_paths[idx])
for box in x_label_boxes:
    cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 255, 255), 2)
# OpenCV loads images as BGR while matplotlib expects RGB; convert before showing.
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

In [None]:
# Visualize the ground-truth x-label boxes of the image currently selected by `idx`.
# Metadata keys have the form "images/<file name>"; look the entry up once.
gt_parse = metadata_dict["images/" + os.path.basename(image_paths[idx])]["ground_truth"]["gt_parse"]

image = cv2.imread(image_paths[idx])
for box in gt_parse["x_labels_rects"]:
    cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 255, 255), 2)
# OpenCV loads images as BGR while matplotlib expects RGB; convert before showing.
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
print(gt_parse["value"])

In [None]:
# # find the line in line chart using opencv
# def find_line(image):
#     # convert to grayscale
#     image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
#     image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
#     image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

#     lower = np.array([0, 0, 0])
#     upper = np.array([180, 255, 120])
#     mask = cv2.inRange(image, lower, upper)

#     mask = cv2.erode(mask, np.ones((1, 1), np.uint8), iterations=1)
#     mask = cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=1)
#     return mask

# mask = find_line(image)
# plt.imshow(mask)

In [None]:
# Display which image index and path are currently selected.
idx, image_paths[idx]

In [None]:
def filter_x_boxes(boxes, img_height, img_path):
    """Keep the x-label boxes that sit on the most populated horizontal line.

    Every pixel row in ``range(img_height)`` is scored by the number of boxes
    whose vertical middle lies on it: with ``h = y2 - y1`` a box counts for the
    rows in ``[y1 + h // 2, y2 - h // 2]``. The first row that reaches the
    highest score is selected, and every box whose ``[y1, y2]`` span contains
    that row is returned.

    Args:
        boxes: re-iterable collection of ``(x1, y1, x2, y2)`` boxes.
        img_height: number of pixel rows to scan.
        img_path: not used by this function.

    Returns:
        A list with the kept boxes, in their input order.
    """
    best_row, best_hits = 0, 0

    for row in range(img_height):
        hits = sum(
            1
            for bx in boxes
            if bx[1] + (bx[3] - bx[1]) // 2 <= row <= bx[3] - (bx[3] - bx[1]) // 2
        )
        # Strict comparison: on ties the earliest (top-most) row is kept.
        if hits > best_hits:
            best_hits, best_row = hits, row

    # Keep every box that is crossed by the winning row.
    return [bx for bx in boxes if bx[1] <= best_row <= bx[3]]


def _digits_parse_as_float(text):
    """Return True if the digits and dots contained in ``text`` form a valid float."""
    digits = "".join(c for c in text if c in "0123456789.")
    try:
        float(digits)
    except ValueError:
        # Raised for an empty string, a lone ".", several dots, etc.
        return False
    return True


def filter_y_boxes(boxes, img_width, img_path):
    """Keep the y-label boxes that are column-aligned and read as numbers.

    Step 1: every pixel column in ``range(img_width)`` is scored by the number
    of boxes whose ``[x1, x2]`` span contains it; the first column with the
    highest score is selected and only boxes crossing it are kept.
    Step 2: the kept boxes are cropped from the image and passed to the global
    ``text_recognition_model``; a box survives only if the digits and dots of
    its recognized text can be converted to ``float``.

    TODO: this only covers charts whose y labels are numbers.

    Args:
        boxes: re-iterable collection of ``(x1, y1, x2, y2)`` boxes.
        img_width: number of pixel columns to scan.
        img_path: path of the image the boxes refer to.

    Returns:
        A list with the surviving boxes, in their input order.

    Raises:
        FileNotFoundError: if the image at ``img_path`` cannot be read.
    """
    # Nothing to align or recognize.
    if len(boxes) == 0:
        return []

    # Step 1: find the vertical line crossed by the most boxes.
    max_count = 0
    max_count_line_x = 0
    for line_x in range(img_width):
        count = sum(1 for box in boxes if box[0] <= line_x <= box[2])
        # Strict comparison: on ties the earliest (left-most) column is kept.
        if count > max_count:
            max_count = count
            max_count_line_x = line_x

    aligned_boxes = [box for box in boxes if box[0] <= max_count_line_x <= box[2]]
    if not aligned_boxes:
        # Do not call the recognizer with an empty batch.
        return []

    # Step 2: recognize the text inside each aligned box.
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {img_path}")

    crops = []
    for box in aligned_boxes:
        # Clamp to 0: negative indices would wrap around in NumPy slicing.
        x1, y1, x2, y2 = (max(int(v), 0) for v in box[:4])
        crops.append(image[y1:y2, x1:x2])

    text_recognition_results = text_recognition_model.predict(crops)

    # results[0][i][0] is assumed to be the recognized string of crop i -- confirm
    # against TextRecognitionModel.predict.
    return [
        box
        for box, prediction in zip(aligned_boxes, text_recognition_results[0])
        if _digits_parse_as_float(prediction[0])
    ]

# Quick check of filter_y_boxes on the image currently selected by `idx`.
data = keypoint_detection_results[0][idx][0].cpu().numpy()
# Class id 3 -> x-label boxes, class id 4 -> y-label boxes (column 6 of `data`).
x_label_boxes, y_label_boxes = (
    (data[data[:, 6] == class_id][:, :4] / keypoint_detection_results[1][idx]["ratio"]).astype(int)
    for class_id in (3, 4)
)

filter_y_boxes(
    boxes=y_label_boxes,
    img_width=keypoint_detection_results[1][idx]["width"],
    img_path=image_paths[idx],
)

In [None]:
# calculate accuracy for x/y label boxes
def calculate_iou(box1, box2):
    """Intersection-over-union of two axis-aligned boxes.

    Args:
        box1: ``(x1, y1, x2, y2)`` of the first box.
        box2: ``(x1, y1, x2, y2)`` of the second box.

    Returns:
        The IoU as a float; ``0.0`` when the boxes do not overlap or when their
        union has no area (e.g. both boxes are degenerate points or lines).
    """
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2

    # Corners of the overlap rectangle.
    x_left = max(x1, x3)
    y_top = max(y1, y3)
    x_right = min(x2, x4)
    y_bottom = min(y2, y4)

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    intersection_area = (x_right - x_left) * (y_bottom - y_top)
    box1_area = (x2 - x1) * (y2 - y1)
    box2_area = (x4 - x3) * (y4 - y3)
    union_area = box1_area + box2_area - intersection_area

    # Two zero-area boxes give a zero union; avoid dividing by zero.
    if union_area <= 0:
        return 0.0

    return intersection_area / float(union_area)

def calculate_label_boxes_accuracy(pred_boxes, gt_boxes, is_x_label=True, iou_thre=0.5):
    """Score one image: 1 if every predicted label box matches, else 0.

    Predictions and ground truth are each sorted along the axis (by ``x1`` for
    x labels, by ``y1`` for y labels) and compared pairwise by position. The
    image scores 1 only when both lists have the same length and every pair has
    an IoU strictly above ``iou_thre``.

    Args:
        pred_boxes: predicted ``(x1, y1, x2, y2)`` boxes.
        gt_boxes: ground-truth ``(x1, y1, x2, y2)`` boxes.
        is_x_label: sort by the first coordinate if true, otherwise by the second.
        iou_thre: IoU a pair must exceed to count as a match.

    Returns:
        ``1`` for a complete match, ``0`` otherwise.
    """
    # A different number of boxes can never be a complete match.
    if len(pred_boxes) != len(gt_boxes):
        return 0

    sort_axis = 0 if is_x_label else 1
    ordered_gt = sorted(gt_boxes, key=lambda b: b[sort_axis])
    ordered_pred = sorted(pred_boxes, key=lambda b: b[sort_axis])

    matched = sum(
        1
        for gt_box, pred_box in zip(ordered_gt, ordered_pred)
        if calculate_iou(gt_box, pred_box) > iou_thre
    )

    return 1 if matched == len(ordered_gt) else 0

In [None]:
# Image-level accuracy of the detected label boxes against the ground truth:
# an image scores 1 only if all of its boxes match (see calculate_label_boxes_accuracy).
x_acc = 0
y_acc = 0

# The loop variable stays `idx` because other cells read that name.
for idx, image_path in enumerate(image_paths):
    image_info = keypoint_detection_results[1][idx]
    data = keypoint_detection_results[0][idx][0].cpu().numpy()
    # Column 6 is the class id: 3 = x label, 4 = y label.
    x_label_boxes = (data[data[:, 6] == 3][:, :4] / image_info["ratio"]).astype(int)
    y_label_boxes = (data[data[:, 6] == 4][:, :4] / image_info["ratio"]).astype(int)

    x_label_boxes = filter_x_boxes(
        x_label_boxes,
        image_info["height"],
        image_path,
    )

    # NOTE: filter_y_boxes is currently not applied here; the y boxes are scored
    # exactly as detected.

    # Metadata keys have the form "images/<file name>"; look the entry up once.
    gt_parse = metadata_dict["images/" + os.path.basename(image_path)]["ground_truth"]["gt_parse"]

    x_acc += calculate_label_boxes_accuracy(
        x_label_boxes,
        gt_parse["x_labels_rects"],
        is_x_label=True,
    )

    y_acc += calculate_label_boxes_accuracy(
        y_label_boxes,
        gt_parse["y_labels_rects"],
        is_x_label=False,
    )

# Avoid a ZeroDivisionError when no images were found.
n_images = max(len(image_paths), 1)
print("x_acc: ", x_acc / n_images)
print("y_acc: ", y_acc / n_images)

In [None]:
# Run text detection on the validation images and draw the polygons found in the
# first image.
import cv2
import numpy as np
import matplotlib.pyplot as plt

# classification_results = classification_model.predict(image_paths)
# x_type_classification_results = x_type_classification_model.predict(image_paths)
# y_type_classification_results = y_type_classification_model.predict(image_paths)

# The detector's constructor call is commented out in the model-setup cell, so
# build it here on first use; re-running this cell reuses the existing instance.
if "text_detection_model" not in globals():
    text_detection_model = TextDetectionModel(**text_detection_config)

text_detection_results = text_detection_model.predict(image_paths)
polygons = text_detection_results[0][0][0]
path = image_paths[0]

# visualize
img = cv2.imread(path)
# OpenCV loads images as BGR while matplotlib expects RGB.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

for polygon in polygons:
    polygon = np.array(polygon).astype(np.int32)
    img = cv2.polylines(img, [polygon], True, (0, 255, 0), 2)

plt.imshow(img)
plt.show()