In [None]:
%load_ext autoreload
%autoreload 2

# export environment variables
# Expose only GPU "0" to this process; this is set in the first cell, ahead of the
# model imports performed in the next cell.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import sys

# Make each sub-project's `src` directory importable. text_detection is inserted at
# the front of sys.path, the other three are appended to the end.
sys.path.insert(0, "./text_detection/src")
sys.path.append("./classification/src")
sys.path.append("./detection/src")
sys.path.append("./text_recognition/src")

import numpy as np
import pandas as pd
import os

from classification.core import ClassificationModel
from detection.core import ObjectDetectionModel
from text_recognition.core import TextRecognitionModel
from text_detection.core import TextDetectionModel
from postprocessing.core import Postprocessing

In [None]:
from matplotlib import pyplot as plt
import cv2
import numpy as np


In [None]:
# Collect the validation images.
# - The listing is sorted because os.listdir returns entries in arbitrary order and
#   later cells address images (and their predictions) by position.
# - Only names that END with ".jpg" are kept; a substring test would also pick up
#   names such as "foo.jpg.bak".
image_folder = "./data/validation/images"
image_paths = sorted(
    os.path.join(image_folder, file_name)
    for file_name in os.listdir(image_folder)
    if file_name.endswith(".jpg")
)

In [None]:
# Root of the project checkout used for the DB experiment configs. It defaults to the
# location this notebook was written against, so existing setups are unaffected; set
# the BMGA_ROOT environment variable to run from a different checkout.
PROJECT_ROOT = os.environ.get("BMGA_ROOT", "/home/thanh/bmga")
DB_EXPERIMENTS_DIR = PROJECT_ROOT + "/text_detection/src/experiments"
# All three DB text detectors share the same experiment definition.
DB_ASF_CONFIG_PATH = DB_EXPERIMENTS_DIR + "/ASF/td500_resnet50_deform_thre_asf.yaml"

# --- Classification models (keyword arguments for ClassificationModel) ---
graph_classfication_config = {
    "model_name": "resnet50",
    "n_classes": 5,
    "weights_path": "./weights/graph_classification.pth",
}

x_type_classification_config = {
    "model_name": "resnet50",
    "n_classes": 2,
    "weights_path": "./weights/x_type_classification.pth",
}

y_type_classification_config = {
    "model_name": "resnet50",
    "n_classes": 2,
    "weights_path": "./weights/y_type_classification.pth",
}

# --- Keypoint detector (keyword arguments for ObjectDetectionModel) ---
keypoint_detection_config = {
    "name": "keypoint_detection",
    "experiment_path": "./detection/src/exps/example/custom/bmga.py",
    "weights_path": "./weights/keypoint_detection.pth",
    "classes": ["value", "x", "y", "x_label", "y_label"],
    "conf_thre": 0.15,
    "nms_thre": 0.25,
    "test_size": (640, 640),
}

# --- Text detectors (keyword arguments for TextDetectionModel) ---
# The three configs differ only in their weights and in the two thresholds.
text_detection_config = {
    "weights_path": "./weights/synthtext_finetune_ic19_res50_dcn_fpn_dbv2",
    # "config_path": DB_EXPERIMENTS_DIR + "/seg_detector/totaltext_resnet50_deform_thre.yaml",
    "config_path": DB_ASF_CONFIG_PATH,
    "image_short_side": 768,
    "thresh": 0.1,
    "box_thresh": 0.05,
    "resize": False,
    "polygon": True,
}

x_labels_text_detection_config = {
    "weights_path": "./weights/db_x_labels",
    "config_path": DB_ASF_CONFIG_PATH,
    "image_short_side": 768,
    "thresh": 0.15,
    "box_thresh": 0.25,
    "resize": False,
    "polygon": True,
}

y_labels_text_detection_config = {
    "weights_path": "./weights/db_y_labels",
    "config_path": DB_ASF_CONFIG_PATH,
    "image_short_side": 768,
    "thresh": 0.05,
    "box_thresh": 0.25,
    "resize": False,
    "polygon": True,
}

# --- Text recognition (keyword arguments for TextRecognitionModel) ---
text_recognition_config = {
    "weights_path": "baudm/parseq",
    "model_name": "parseq",
}

# Instantiate every model from its config dict; each dict is unpacked as keyword
# arguments of the corresponding model class.
graph_classification_model = ClassificationModel(**graph_classfication_config)
x_type_classification_model = ClassificationModel(**x_type_classification_config)
y_type_classification_model = ClassificationModel(**y_type_classification_config)
keypoint_detection_model = ObjectDetectionModel(**keypoint_detection_config)
text_detection_model = TextDetectionModel(**text_detection_config)
x_labels_text_detection_model = TextDetectionModel(**x_labels_text_detection_config)
y_labels_text_detection_model = TextDetectionModel(**y_labels_text_detection_config)
text_recognition_model = TextRecognitionModel(**text_recognition_config)

In [None]:
# Read the validation ground truth (one JSON object per line) and index it by file name.
import json

# Same default location as before; set the BMGA_ROOT environment variable to read the
# metadata from a different checkout.
metadata_path = os.environ.get("BMGA_ROOT", "/home/thanh/bmga") + "/data/validation/metadata.jsonl"

with open(metadata_path, "r", encoding="utf-8") as f:
    # Iterate over the file lazily and skip blank lines, which json.loads would reject.
    metadata = [json.loads(line) for line in f if line.strip()]

# "file_name" value -> full metadata record (a later duplicate overrides an earlier one).
metadata_dict = {record["file_name"]: record for record in metadata}

### X/Y labels detection using DB model

In [None]:
def convert_polygon_to_min_rect(polygon):
    """Return the corners of the minimum-area rotated rectangle enclosing ``polygon``.

    Args:
        polygon: Point coordinates in any layout that can be reshaped to ``(-1, 2)``
            (for example a list of ``[x, y]`` pairs).

    Returns:
        Integer array with the rectangle corners returned by ``cv2.boxPoints``; the
        fractional part of each coordinate is truncated.
    """
    points = np.array(polygon).reshape(-1, 2).astype(np.float32)
    rect = cv2.minAreaRect(points)
    corners = cv2.boxPoints(rect)
    # np.int0 was only an alias of np.intp and no longer exists in NumPy 2.0;
    # casting to np.intp yields the same values and dtype on every NumPy version.
    return corners.astype(np.intp)

def crop_polygon_from_image(image, polygon):
    """Cut the minimum-area rectangle of ``polygon`` out of ``image``.

    Pixels inside the rectangle are copied from ``image``; everything else in the
    returned crop is filled with 255.

    Args:
        image: Image array; its first two dimensions define the mask size.
        polygon: Polygon points accepted by ``convert_polygon_to_min_rect``.

    Returns:
        The axis-aligned bounding region of the rotated rectangle, taken from the
        masked copy of ``image``. The region is empty when the rectangle lies
        entirely outside the top/left image border.
    """
    box = convert_polygon_to_min_rect(polygon)

    mask = np.zeros(image.shape[:2], np.uint8)
    cv2.drawContours(mask, [box], 0, 255, -1, cv2.LINE_AA)

    out = np.full_like(image, 255)
    inside = mask == 255  # only fully covered pixels are copied
    out[inside] = image[inside]

    # The rotated rectangle may stick out past the top/left border. Negative values
    # would be interpreted as "count from the end" by slicing and select the wrong
    # region, so clamp them to 0. Values beyond the bottom/right border are already
    # handled by slicing itself.
    x_min = max(int(np.min(box[:, 0])), 0)
    x_max = max(int(np.max(box[:, 0])), 0)
    y_min = max(int(np.min(box[:, 1])), 0)
    y_max = max(int(np.max(box[:, 1])), 0)

    return out[y_min:y_max, x_min:x_max]


# sample_image_path = image_paths[0]
# sample_image = cv2.imread(sample_image_path)
# sample_image = cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB)

# sample_polygon = [[20, 20], [10, 100], [100, 200], [300, 40]]

# # draw polygon
# sample_image = cv2.polylines(sample_image, [np.array(sample_polygon)], True, (0, 255, 0), 2)
# plt.imshow(sample_image)

In [None]:
# crop = crop_polygon_from_image(sample_image, sample_polygon)
# plt.imshow(crop)

In [None]:
def filter_x_polygons(polygons, img_height, img_path):
    """Keep the x-label polygons that lie on the most populated horizontal line.

    Every row ``y`` in ``range(img_height)`` is scored by the number of polygons
    whose vertical extent contains it (``min_y <= y <= max_y``). The first row with
    the highest score is selected and only the polygons containing that row are
    returned.

    Args:
        polygons: Sequence of polygons, each a sequence of ``(x, y)`` points.
        img_height: Number of image rows to scan.
        img_path: Unused; kept so existing callers keep working.

    Returns:
        List of the retained polygons in their original order. When no row is
        covered by any polygon, row 0 is used as the reference line.
    """
    # Vertical extent of every polygon, computed once instead of once per scanned row.
    extents = [
        (min(point[1] for point in polygon), max(point[1] for point in polygon))
        for polygon in polygons
    ]

    best_count = 0
    best_line_y = 0
    for line_y in range(img_height):
        count = sum(1 for min_y, max_y in extents if min_y <= line_y <= max_y)
        # Strict comparison: on ties the first (topmost) row wins.
        if count > best_count:
            best_count = count
            best_line_y = line_y

    # Keep only the polygons crossed by the selected row.
    return [
        polygon
        for polygon, (min_y, max_y) in zip(polygons, extents)
        if min_y <= best_line_y <= max_y
    ]


def filter_y_polygons(polygons, img_width, image):
    """Keep the y-label polygons that lie on the most populated vertical line.

    Every column ``x`` in ``range(img_width)`` is scored by the number of polygons
    whose *trimmed* horizontal extent contains it, i.e. the extent with a quarter of
    the polygon width (``w // 4``) removed on each side. The first column with the
    highest score is selected. The final selection then uses the *full* horizontal
    extent: every polygon with ``min_x <= x <= max_x`` is returned.

    Args:
        polygons: Sequence of polygons, each a sequence of ``(x, y)`` points.
        img_width: Number of image columns to scan.
        image: Unused by the active code; kept so existing callers keep working.

    Returns:
        List of the retained polygons in their original order. When no column is
        covered by any trimmed extent, column 0 is used as the reference line.
    """
    # Horizontal extent of every polygon, computed once instead of once per column.
    extents = [
        (min(point[0] for point in polygon), max(point[0] for point in polygon))
        for polygon in polygons
    ]
    # Only the middle part of each polygon votes for a column.
    trimmed = [
        (min_x + (max_x - min_x) // 4, max_x - (max_x - min_x) // 4)
        for min_x, max_x in extents
    ]

    best_count = 0
    best_line_x = 0
    for line_x in range(img_width):
        count = sum(1 for low, high in trimmed if low <= line_x <= high)
        # Strict comparison: on ties the first (leftmost) column wins.
        if count > best_count:
            best_count = count
            best_line_x = line_x

    # Keep every polygon whose full extent is crossed by the selected column.
    return [
        polygon
        for polygon, (min_x, max_x) in zip(polygons, extents)
        if min_x <= best_line_x <= max_x
    ]
    # # second, do text recognition on y_label_boxes
    # crops = []
    # for polygon in filtered_y_label_polygons:
    #     crop = crop_polygon_from_image(image, polygon)
    #     crops.append(crop)

    # text_recognition_results = text_recognition_model.predict(crops)

    # # filter out those boxes that the values can't be converted to float: TODO: only case that y labels are numbers, have to update
    # filtered_y_label_boxes_2 = []
    # for i, box in enumerate(filtered_y_label_polygons):
    #     try:
    #         text = "".join([c for c in text_recognition_results[0][i][0] if c in "0123456789."])
    #         if not text:
    #             float(text)
    #         filtered_y_label_boxes_2.append(box)
    #     except:
    #         pass

    # return filtered_y_label_boxes_2

def _min_rect_mask(polygon, mask_shape):
    """Rasterise the minimum-area rectangle of ``polygon`` into a uint8 mask of ``mask_shape``."""
    points = np.array(polygon).reshape(-1, 2).astype(np.float32)
    # np.int0 (an alias of np.intp) was removed in NumPy 2.0.
    box = cv2.boxPoints(cv2.minAreaRect(points)).astype(np.intp)
    mask = np.zeros(mask_shape, np.uint8)
    cv2.drawContours(mask, [box], 0, 255, -1, cv2.LINE_AA)
    return mask


def calculate_iou(polygon1, polygon2, image):
    """Compute the pixel-level IoU of the minimum-area rectangles of two polygons.

    Both polygons are replaced by their minimum-area rotated rectangle, each
    rectangle is drawn as a filled mask with the height/width of ``image``, and the
    intersection-over-union of the non-zero mask pixels is returned.

    Args:
        polygon1: First polygon; any layout that reshapes to ``(-1, 2)``.
        polygon2: Second polygon; same layout as ``polygon1``.
        image: Array whose first two dimensions give the mask size.

    Returns:
        The IoU as a float, or ``0.0`` when neither mask contains any pixel
        (which previously produced a division by zero).
    """
    mask_shape = image.shape[:2]
    mask1 = _min_rect_mask(polygon1, mask_shape)
    mask2 = _min_rect_mask(polygon2, mask_shape)

    intersection = np.sum(np.logical_and(mask1, mask2))
    union = np.sum(np.logical_or(mask1, mask2))
    if union == 0:
        return 0.0

    return intersection / union

def _label_polygon_points(polygon):
    """Return ``polygon`` as a sequence of ``[x, y]`` points.

    Ground-truth polygons from the metadata are dicts with the keys
    ``x0, y0, ..., x3, y3``; those are converted to four points. Anything else is
    assumed to already be a sequence of points and is returned unchanged.
    """
    if isinstance(polygon, dict):
        return [[polygon["x%d" % i], polygon["y%d" % i]] for i in range(4)]
    return polygon


def calculate_label_polygons_accuracy(pred_polygons, gt_polygons, image, is_x_label=True, iou_thre=0.5):
    """Score one image: 1 if every predicted label polygon matches the ground truth, else 0.

    Predictions and ground truth are each sorted along the axis direction (by their
    smallest x for x labels, by their smallest y for y labels), paired up in that
    order and compared with ``calculate_iou``.

    The previous implementation overwrote the sorted ground truth with the sorted
    predictions and then compared every polygon with itself, so it effectively only
    checked that both lists had the same length.

    Args:
        pred_polygons: Predicted polygons, each a sequence of ``(x, y)`` points.
        gt_polygons: Ground-truth polygons, either point sequences or dicts with the
            keys ``x0, y0, ..., x3, y3``.
        image: Image array passed to ``calculate_iou`` to size the masks.
        is_x_label: Sort by minimum x when True, by minimum y otherwise.
        iou_thre: A pair counts as a match only when its IoU is strictly greater
            than this value.

    Returns:
        1 when both lists have the same length and every pair matches (this includes
        two empty lists), otherwise 0.
    """
    if len(pred_polygons) != len(gt_polygons):
        return 0

    axis = 0 if is_x_label else 1

    def _sort_key(points):
        return min(point[axis] for point in points)

    pred_sorted = sorted((_label_polygon_points(p) for p in pred_polygons), key=_sort_key)
    gt_sorted = sorted((_label_polygon_points(g) for g in gt_polygons), key=_sort_key)

    for pred, gt in zip(pred_sorted, gt_sorted):
        iou = calculate_iou(pred, gt, image)
        # "not (iou > thre)" also rejects a NaN IoU.
        if not (iou > iou_thre):
            return 0

    return 1

In [None]:
# Run the x-label and y-label DB text detectors over all validation images.
# Later cells read the polygons of image i as `<predictions>[i][0][0]`.
x_labels_predictions = x_labels_text_detection_model.predict(image_paths=image_paths)
y_labels_predictions = y_labels_text_detection_model.predict(image_paths=image_paths)

In [None]:
from tqdm import tqdm

# Calculate accuracy: an image counts as correct when calculate_label_polygons_accuracy
# returns 1 for its filtered predictions; x and y labels are scored separately.
x_acc = 0
y_acc = 0

for idx in tqdm(range(len(image_paths))):
    image_path = image_paths[idx]
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals an unreadable file by returning None instead of raising.
        raise FileNotFoundError("Could not read image: " + image_path)

    # Metadata keys have the form "images/<file name>"; look the record up once.
    gt_parse = metadata_dict["images/" + os.path.basename(image_path)]["ground_truth"]["gt_parse"]

    x_labels_polygons = filter_x_polygons(
        x_labels_predictions[idx][0][0],
        image.shape[0],
        image_path,
    )

    y_labels_polygons = filter_y_polygons(
        y_labels_predictions[idx][0][0],
        image.shape[1],
        image,
    )

    x_acc += calculate_label_polygons_accuracy(
        x_labels_polygons,
        gt_parse["x_labels_polygons"],
        image=image,
        is_x_label=True,
    )

    y_acc += calculate_label_polygons_accuracy(
        y_labels_polygons,
        gt_parse["y_labels_polygons"],
        image=image,
        is_x_label=False,
    )

print("x_acc: ", x_acc / len(image_paths))
print("y_acc: ", y_acc / len(image_paths))

# Results recorded from an earlier run:
# x_acc:  0.8872987477638641
# y_acc:  0.8354203935599285



In [None]:
# Visualize the filtered x/y label detections of a single image.
idx = 137
# idx = (idx + 1) % len(image_paths)

# cv2.imread returns BGR while matplotlib expects RGB. Convert right after loading so
# the chart is shown in its true colors and the overlay colors below are RGB tuples.
image = cv2.cvtColor(cv2.imread(image_paths[idx]), cv2.COLOR_BGR2RGB)

x_labels_polygons = filter_x_polygons(
    x_labels_predictions[idx][0][0],
    image.shape[0],
    image_paths[idx],
)

y_labels_polygons = filter_y_polygons(
    y_labels_predictions[idx][0][0],
    image.shape[1],
    image,
)

# x labels are outlined in green, y labels in blue.
for polygons, color in ((x_labels_polygons, (0, 255, 0)), (y_labels_polygons, (0, 0, 255))):
    for polygon in polygons:
        contour = np.array(polygon).reshape(-1, 2).astype(np.int32)
        cv2.drawContours(image, [contour], 0, color, 2)

plt.figure(figsize=(10, 10))
plt.imshow(image)
print(idx, image_paths[idx])

In [None]:
# Visualize the ground-truth label polygons of the image selected by `idx`.
gt_parse = metadata_dict["images/" + os.path.basename(image_paths[idx])]["ground_truth"]["gt_parse"]

# cv2.imread returns BGR while matplotlib expects RGB; convert before drawing so the
# chart is shown in its true colors and the overlay colors below are RGB tuples.
image = cv2.cvtColor(cv2.imread(image_paths[idx]), cv2.COLOR_BGR2RGB)

# y labels are outlined in blue, x labels in green.
for key, color in (("y_labels_polygons", (0, 0, 255)), ("x_labels_polygons", (0, 255, 0))):
    for polygon in gt_parse[key]:
        # Each ground-truth polygon is a dict holding four corners x0, y0 ... x3, y3.
        contour = np.array(
            [[polygon["x%d" % i], polygon["y%d" % i]] for i in range(4)]
        ).astype(np.int32)
        cv2.drawContours(image, [contour], 0, color, 2)

plt.imshow(image)
print(gt_parse["value"])

### Graph classification model, x/y labels classification model

In [None]:
graph_classes = ["dot", "line", "scatter", "vertical_bar", "horizontal_bar"]

# Take the arg-max along axis 1 of the classifier output and translate every class
# index into the corresponding graph type name.
graph_type_predictions = [
    graph_classes[class_index]
    for class_index in np.argmax(
        graph_classification_model.predict(image_paths=image_paths),
        axis=1,
    )
]

In [None]:
# Ground-truth graph type of every image, in the same order as image_paths.
gt_classes = []
for image_path in image_paths:
    record = metadata_dict["images/" + image_path.split("/")[-1]]
    gt_classes.append(record["ground_truth"]["gt_parse"]["class"])

# calculate accuracy
acc = 0
for idx in range(len(image_paths)):
    acc += int(graph_type_predictions[idx] == gt_classes[idx])

print("acc: ", acc / len(image_paths))
print(np.unique(gt_classes, return_counts=True))

In [None]:
type_classes = ["numerical", "categorical"]

# For each axis: arg-max along axis 1 of the classifier output, then map the class
# index to its type name.
x_type_predictions = [
    type_classes[class_index]
    for class_index in np.argmax(
        x_type_classification_model.predict(image_paths=image_paths),
        axis=1,
    )
]

y_type_predictions = [
    type_classes[class_index]
    for class_index in np.argmax(
        y_type_classification_model.predict(image_paths=image_paths),
        axis=1,
    )
]

In [None]:
# Ground-truth axis types of every image, in the same order as image_paths.
x_type_gt_classes = []
y_type_gt_classes = []

for image_path in image_paths:
    metadata_key = "images/" + image_path.split("/")[-1]
    x_type_gt_classes.append(metadata_dict[metadata_key]["ground_truth"]["gt_parse"]["x_type"])
    y_type_gt_classes.append(metadata_dict[metadata_key]["ground_truth"]["gt_parse"]["y_type"])

# calculate accuracy
x_type_acc = 0
y_type_acc = 0
for idx in range(len(image_paths)):
    x_type_acc += int(x_type_predictions[idx] == x_type_gt_classes[idx])
    y_type_acc += int(y_type_predictions[idx] == y_type_gt_classes[idx])

print("x_type_acc: ", x_type_acc / len(image_paths))
print("y_type_acc: ", y_type_acc / len(image_paths))

### Object detection model to detect point on graphs
1. Detect x_labels and y_labels points
2. Map these points with x_labels and y_labels texts
3. Post processing depends on the graph type

In [None]:
# Run the keypoint detector over all validation images. The next cell reads the
# detections of image i from `keypoint_predictions[0][i][0]` and its "ratio" from
# `keypoint_predictions[1][i]`.
keypoint_predictions = keypoint_detection_model.predict(image_paths=image_paths)

In [None]:
# Step to the next image each time this cell is run (wrapping around at the end).
idx = (idx + 1) % len(image_paths)
data = keypoint_predictions[0][idx][0].cpu().numpy()
ratio = keypoint_predictions[1][idx]["ratio"]

# The first four columns are the box coordinates and column 6 is the class id.
# Dividing by `ratio` presumably maps the boxes back to original image pixels - TODO confirm.
value_boxes = (data[data[:, 6] == 0][:, :4] / ratio).astype(int)
x_boxes = (data[data[:, 6] == 1][:, :4] / ratio).astype(int)
y_boxes = (data[data[:, 6] == 2][:, :4] / ratio).astype(int)

# cv2.imread returns BGR while matplotlib expects RGB. Convert right after loading so
# the chart is shown in its true colors and the overlay colors below are RGB tuples.
image = cv2.cvtColor(cv2.imread(image_paths[idx]), cv2.COLOR_BGR2RGB)

# visualize value_boxes (blue), x_boxes (green), y_boxes (red)
for boxes, color in (
    (value_boxes, (0, 0, 255)),
    (x_boxes, (0, 255, 0)),
    (y_boxes, (255, 0, 0)),
):
    for box in boxes:
        # Plain Python ints keep cv2.rectangle independent of the numpy scalar type.
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        cv2.rectangle(image, top_left, bottom_right, color, 2)

x_labels_polygons = filter_x_polygons(
    x_labels_predictions[idx][0][0],
    image.shape[0],
    image_paths[idx],
)

y_labels_polygons = filter_y_polygons(
    y_labels_predictions[idx][0][0],
    image.shape[1],
    image,
)

# visualize x label polygons (yellow) and y label polygons (cyan)
for polygons, color in (
    (x_labels_polygons, (255, 255, 0)),
    (y_labels_polygons, (0, 255, 255)),
):
    for polygon in polygons:
        contour = np.array(polygon).reshape(-1, 2).astype(np.int32)
        cv2.drawContours(image, [contour], 0, color, 2)

plt.figure(figsize=(8, 8))
plt.imshow(image)

# ground truth
gt_parse = metadata_dict["images/" + os.path.basename(image_paths[idx])]["ground_truth"]["gt_parse"]
print("-------- GROUND TRUTH ---------")
print("graph type: ", gt_parse["class"])
for v in gt_parse["value"]:
    print(v)

In [None]:
# GENERAL RULES:
# 1. Filter x_points and y_points by drawing a line_y and a line_x.
# 2. Map x_points and x_labels one-to-one; the matched boxes should overlap.
# 3. If there is only one x_point, keep all the numbers in x_labels and map them based on their connection to that x_point.

# 4. Map each x_label to the closest overlapping x_point; if x_points are missing, map based on the mean y of the x_points and the mean x of the x_labels.

# 5. Draw Ox, Oy, max_x and max_y to filter out predictions outside of the graph.
# 6. In case the x_label is skewed, use the nearest points of the two boxes.

# ------------ VERTICAL BAR GRAPH -------------
# 1. Prioritize the value predictions: map them one-to-one with x_points and ignore any outlier x_points/values; match on the closest x2-x1 first, then on y2-y1.

# ------------ HORIZONTAL BAR GRAPH ------------

# ------------ SCATTER PLOT ------------
# 1. use value prediction

# ------------ LINE GRAPH ------------
# For the line graph, only keep the predictions that are on the line. If the number of predictions equals the number of x_points, keep all the predictions; otherwise, map onto the line to get the data.
# Only predict for those x_labels that are on the line plot, projected down to Ox.

# ------------ DOT PLOT ------------


### Line chart postprocessing

In [None]:
# # find the line in line chart using opencv
# def find_line(image):
#     # convert to grayscale
#     image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
#     image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
#     image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

#     lower = np.array([0, 0, 0])
#     upper = np.array([180, 255, 120])
#     mask = cv2.inRange(image, lower, upper)

#     mask = cv2.erode(mask, np.ones((1, 1), np.uint8), iterations=1)
#     mask = cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=1)
#     return mask

# mask = find_line(image)
# plt.imshow(mask)