In [None]:
# Run-mode switch: True selects the Kaggle input folders, False the local validation set.
TEST_MODE = False

# Data and checkpoints are selected together so they always belong to the same environment.
DATA_FOLDER, WEIGHTS_FOLDER = (
    ("/kaggle/input/benetech-making-graphs-accessible/test/", "/kaggle/input/bmgaweights/")
    if TEST_MODE
    else ("./data/validation/", "./weights/")
)

In [None]:
%load_ext autoreload
%autoreload 2

# Process environment and import search path for the project-local packages.
import os
import sys

# Restrict CUDA to device 0 for this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# text_detection is placed at the FRONT of sys.path; the remaining packages are appended.
sys.path.insert(0, "./text_detection/src")
sys.path.extend(
    [
        "./classification/src",
        "./detection/src",
        "./segmentation/src",
        "./text_recognition/src",
    ]
)

# Compute metrics code

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from rapidfuzz.distance.Levenshtein import distance as levenshtein
from sklearn.metrics import r2_score

In [None]:
def sigmoid(x):
    """Return ``2 - 2 / (1 + exp(-x))``.

    The transform maps 0 to 1 and tends to 0 as ``x`` grows, so it turns a
    non-negative error into a score. Works element-wise on NumPy arrays.
    """
    denominator = 1 + np.exp(-x)
    return 2 - 2 / denominator


def normalized_rmse(y_true, y_pred):
    """Score a numeric series by passing sqrt(1 - R^2) through ``sigmoid``."""
    # sqrt(1 - R^2) is equal to rmse(y_true, y_pred) / rmse(y_true, np.mean(y_true)).
    relative_error = (1 - r2_score(y_true, y_pred)) ** 0.5
    return sigmoid(relative_error)


def normalized_levenshtein_score(y_true, y_pred):
    """Score a string series: summed edit distance over summed reference length, through ``sigmoid``."""
    pairwise_distances = [levenshtein(expected, guessed) for expected, guessed in zip(y_true, y_pred)]
    reference_lengths = [len(expected) for expected in y_true]
    return sigmoid(np.sum(pairwise_distances) / np.sum(reference_lengths))


def score_series(y_true, y_pred):
    """Score one predicted series against its ground truth.

    Returns 0.0 when the lengths differ. Otherwise the type of the first
    ground-truth element picks the scorer: Levenshtein-based for strings,
    RMSE-based for everything else.
    """
    if len(y_true) != len(y_pred):
        return 0.0
    scorer = normalized_levenshtein_score if isinstance(y_true[0], str) else normalized_rmse
    return scorer(y_true, y_pred)


def benetech_score(ground_truth: pd.DataFrame, predictions: pd.DataFrame) -> float:
    """Evaluate predictions using the metric from the Benetech - Making Graphs Accessible.

    Parameters
    ----------
    ground_truth: pd.DataFrame
        Has columns `[data_series, chart_type]` and an index `id`. Values in `data_series`
        should be either arrays of floats or arrays of strings.

    predictions: pd.DataFrame
        Must have the same index and the same columns as `ground_truth`.

    Returns
    -------
    The mean of the per-row scores. A row scores 0.0 when its chart type is
    wrong; otherwise its series is scored with `score_series`.

    Raises
    ------
    ValueError
        If the index or the columns of the two frames differ.
    """
    if not ground_truth.index.equals(predictions.index):
        raise ValueError("Must have exactly one prediction for each ground-truth instance.")
    if not ground_truth.columns.equals(predictions.columns):
        raise ValueError(f"Predictions must have columns: {ground_truth.columns}.")

    gt_rows = ground_truth.itertuples(index=False)
    pred_rows = predictions.itertuples(index=False)
    scores = [
        # A wrong chart type zeroes the row; otherwise score with RMSE or Levenshtein.
        score_series(gt_series, pred_series) if gt_type == pred_type else 0.0
        for (gt_series, gt_type), (pred_series, pred_type) in zip(gt_rows, pred_rows)
    ]
    return np.mean(scores)

### Inference Code

In [None]:
import numpy as np
import pandas as pd
import os

from classification.core import ClassificationModel
from detection.core import ObjectDetectionModel
# from segmentation.core import SegmentationModel
from text_recognition.core import TextRecognitionModel
from text_detection.core import TextDetectionModel
from postprocessing.core import Postprocessing

In [None]:
from matplotlib import pyplot as plt
import cv2
import numpy as np


In [None]:
# To experiment on the training set instead, point image_folder at "./data/train/images"
# and slice the listing (e.g. [:500]).
image_folder = os.path.join(DATA_FOLDER, "images")

# os.listdir() returns entries in arbitrary order, so the listing is sorted to make
# positional indices (used throughout the notebook) reproducible across machines.
# The directory is listed once so both lists are guaranteed to be aligned.
image_paths = sorted(
    os.path.join(image_folder, file_name)
    for file_name in os.listdir(image_folder)
    if file_name.endswith(".jpg")
)

# Independent copy: image_paths entries are overwritten later with temporary files,
# while origin_image_paths keeps pointing at the untouched originals.
origin_image_paths = list(image_paths)

In [None]:
def _weights_file(name):
    """Return the path of the file ``name`` inside WEIGHTS_FOLDER."""
    return os.path.join(WEIGHTS_FOLDER, name)


def _resnet50_classifier_config(n_classes, checkpoint):
    """Build the keyword arguments for a resnet50 ClassificationModel."""
    return {
        "model_name": "resnet50",
        "n_classes": n_classes,
        "weights_path": _weights_file(checkpoint),
    }


def _db_text_detector_config(checkpoint, config_path, thresh, box_thresh):
    """Build the keyword arguments for a TextDetectionModel (short side 768, polygon output)."""
    return {
        "weights_path": _weights_file(checkpoint),
        "config_path": config_path,
        "image_short_side": 768,
        "thresh": thresh,
        "box_thresh": box_thresh,
        "resize": False,
        "polygon": True,
    }


# --- Classifiers ---------------------------------------------------------------
graph_classfication_config = _resnet50_classifier_config(5, "graph_classification.pth")
x_type_classification_config = _resnet50_classifier_config(2, "x_type_classification.pth")
y_type_classification_config = _resnet50_classifier_config(2, "y_type_classification.pth")

# --- Line segmentation (the model itself is currently not instantiated) ---------
line_segmentation_config = dict(
    weights_path=_weights_file("line_segmentation.pth"),
    arch="Unet",
    encoder_name="tf_efficientnetv2_b1",
    drop_path=0,
    size=512,
)

# --- Keypoint detection --------------------------------------------------------
# Alternative checkpoint: "keypoint_detection.pth".
# Classes "x_label" and "y_label" are intentionally left out of the class list.
keypoint_detection_config = dict(
    name="keypoint_detection",
    experiment_path="./detection/src/exps/example/custom/bmga.py",
    weights_path=_weights_file("keypoint_detection_yoloxx.pth"),
    classes=["value", "x", "y"],
    conf_thre=0.15,
    nms_thre=0.25,
    test_size=(640, 640),
)

# --- Text detection ------------------------------------------------------------
text_detection_config = _db_text_detector_config(
    "synthtext_finetune_ic19_res50_dcn_fpn_dbv2",
    "text_detection/src/experiments/ASF/td500_resnet50_deform_thre_asf_inference.yaml",
    thresh=0.1,
    box_thresh=0.05,
)

# Alternative checkpoint: "db_x_labels".
x_labels_text_detection_config = _db_text_detector_config(
    "db_x_labels_new",
    "./text_detection/src/experiments/ASF/td500_resnet50_deform_thre_asf_inference.yaml",
    thresh=0.3,
    box_thresh=0.5,
)

y_labels_text_detection_config = _db_text_detector_config(
    "db_y_labels",
    "./text_detection/src/experiments/ASF/td500_resnet50_deform_thre_asf_inference.yaml",
    thresh=0.05,
    box_thresh=0.25,
)

# --- Text recognition ----------------------------------------------------------
# Alternative weights source: "baudm/parseq".
text_recognition_config = {
    "weights_path": _weights_file("parseq-bb5792a6.pt"),
    "model_name": "parseq",
    "config_path": _weights_file("parseq_hparams.json"),
}

# Instantiate every model used by the pipeline from the config dicts defined above.
# NOTE: statement order matters in this cell -- two config dicts are mutated in place
# between constructor calls so that the same dict can be reused for several models.
graph_classification_model = ClassificationModel(**graph_classfication_config)
x_type_classification_model = ClassificationModel(**x_type_classification_config)
y_type_classification_model = ClassificationModel(**y_type_classification_config)
# line_segmentation_model = SegmentationModel(**line_segmentation_config)

# General keypoint detector, built with the weights_path set in keypoint_detection_config.
keypoint_detection_model = ObjectDetectionModel(**keypoint_detection_config)

# keypoint_detection_config["weights_path"] = os.path.join(WEIGHTS_FOLDER, "keypoint_detection_yoloxx_xy.pth")
# xy_keypoint_detection_model = ObjectDetectionModel(**keypoint_detection_config)

# Per-chart-type value detectors: only "weights_path" is swapped before each construction.
# After this cell keypoint_detection_config["weights_path"] points at the scatter checkpoint.
# keypoint_detection_config["conf_thre"] = 0.3
# keypoint_detection_config["nms_thre"] = 0.3
keypoint_detection_config["weights_path"] = os.path.join(WEIGHTS_FOLDER, "keypoint_detection_yoloxx_line_value.pth")
line_value_keypoint_detection_model = ObjectDetectionModel(**keypoint_detection_config)

# keypoint_detection_config["conf_thre"] = 0.3
# keypoint_detection_config["nms_thre"] = 0.3
keypoint_detection_config["weights_path"] = os.path.join(WEIGHTS_FOLDER, "keypoint_detection_yoloxx_bar_value.pth")
bar_value_keypoint_detection_model = ObjectDetectionModel(**keypoint_detection_config)

# keypoint_detection_config["conf_thre"] = 0.25
# keypoint_detection_config["nms_thre"] = 0.35
keypoint_detection_config["weights_path"] = os.path.join(WEIGHTS_FOLDER, "keypoint_detection_yoloxx_scatter_value.pth")
scatter_value_keypoint_detection_model = ObjectDetectionModel(**keypoint_detection_config)

text_detection_model = TextDetectionModel(**text_detection_config)
x_labels_text_detection_model = TextDetectionModel(**x_labels_text_detection_config)
# Second x-label detector: same config, but with image_short_side raised to 1024.
# The dict is mutated in place, so it keeps the value 1024 after this cell.
x_labels_text_detection_config["image_short_side"] = 1024
x_labels_text_detection_model_2 = TextDetectionModel(**x_labels_text_detection_config)
y_labels_text_detection_model = TextDetectionModel(**y_labels_text_detection_config)
text_recognition_model = TextRecognitionModel(**text_recognition_config)
# NOTE(review): presumably switches the wrapped recognition network to inference mode -- confirm.
text_recognition_model.parseq.eval()
# NOTE(review): looks like this only prevents the previous call's return value from being
# echoed as the cell output -- confirm.
print()

In [None]:
# Read the ground truth from <DATA_FOLDER>/metadata.jsonl and keep only annotated images.
import json

if not TEST_MODE:
    # Resolved from DATA_FOLDER (instead of a machine-specific absolute path) so the
    # notebook runs wherever the data folder is configured.
    metadata_path = os.path.join(DATA_FOLDER, "metadata.jsonl")
    with open(metadata_path, "r") as f:
        # One JSON document per line; blank lines are skipped instead of crashing json.loads.
        metadata = [json.loads(line) for line in f if line.strip()]

    # Index the records by their "file_name" field; lookups below use "images/<basename>".
    metadata_dict = {record["file_name"]: record for record in metadata}

    # Drop every image that has no ground-truth record.
    filtered_image_paths = [
        image_path
        for image_path in image_paths
        if "images/" + image_path.split("/")[-1] in metadata_dict
    ]
    # Separate list object: image_paths entries are replaced later, the originals must survive.
    filtered_original_image_paths = list(filtered_image_paths)

    image_paths = filtered_image_paths
    origin_image_paths = filtered_original_image_paths

### Utility Functions

In [None]:
# function to convert polygon points to smallest 4 points polygon
def convert_polygon_to_min_rect(polygon):
    """Return the 4 corners of the minimum-area (possibly rotated) rectangle around ``polygon``.

    Parameters
    ----------
    polygon : array-like
        Point coordinates; reshaped to ``(-1, 2)`` before use.

    Returns
    -------
    np.ndarray
        Integer array of shape ``(4, 2)`` as produced by ``cv2.boxPoints``,
        with the fractional part truncated.
    """
    points = np.array(polygon).reshape(-1, 2).astype(np.float32)
    rect = cv2.minAreaRect(points)
    box = cv2.boxPoints(rect)
    # np.int0 was an alias of np.intp and was removed in NumPy 2.0; astype(np.intp)
    # yields the same dtype and the same truncation on every NumPy version.
    return box.astype(np.intp)

def crop_polygon_from_image(image, polygon):
    """Cut the minimum-area rectangle around ``polygon`` out of ``image``.

    Pixels outside the rectangle are painted white (255). The returned crop is
    the axis-aligned bounding window of the rectangle, clipped to the image.

    Parameters
    ----------
    image : np.ndarray
        Source image; its first two dimensions are used as height and width.
    polygon : array-like
        Point coordinates accepted by ``convert_polygon_to_min_rect``.

    Returns
    -------
    np.ndarray
        The crop, or a white ``(32, 100, 3)`` uint8 placeholder when the
        clipped window is empty.
    """
    box = convert_polygon_to_min_rect(polygon)

    mask = np.zeros(image.shape[:2], np.uint8)
    cv2.drawContours(mask, [box], 0, 255, -1, cv2.LINE_AA)

    # Start from an all-white canvas and copy only the fully covered pixels.
    out = 255 - np.zeros_like(image)
    inside = mask == 255
    out[inside] = image[inside]

    # The rectangle corners may lie outside the image. A negative bound would be
    # interpreted by NumPy as "count from the end" and silently select the wrong
    # region, so every bound is clamped into the image first.
    height, width = image.shape[:2]
    x_min = min(max(int(np.min(box[:, 0])), 0), width)
    x_max = min(max(int(np.max(box[:, 0])), 0), width)
    y_min = min(max(int(np.min(box[:, 1])), 0), height)
    y_max = min(max(int(np.max(box[:, 1])), 0), height)

    crop = out[y_min:y_max, x_min:x_max]
    if crop.size == 0:
        # Degenerate or fully out-of-image rectangle: hand back a blank placeholder.
        crop = np.ones((32, 100, 3), dtype=np.uint8) * 255
    return crop


# sample_image_path = image_paths[0]
# sample_image = cv2.imread(sample_image_path)
# sample_image = cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB)

# sample_polygon = [[20, 20], [10, 100], [100, 200], [300, 40]]

# # draw polygon
# sample_image = cv2.polylines(sample_image, [np.array(sample_polygon)], True, (0, 255, 0), 2)
# plt.imshow(sample_image)

In [None]:
# crop = crop_polygon_from_image(sample_image, sample_polygon)
# plt.imshow(crop)

In [None]:
def filter_x_polygons(polygons, img_height, img_path):
    """Keep the polygons crossed by the most crowded horizontal scan line.

    Every integer row in ``range(img_height)`` is tested against the vertical
    extent (min y .. max y, inclusive) of each polygon. The first row crossing
    the largest number of polygons wins, and the polygons it crosses are
    returned in their original order. An empty polygon is treated as having
    the extent 0..0. ``img_path`` is not used.
    """
    def vertical_extent(points):
        if points:
            ys = [pt[1] for pt in points]
            return min(ys), max(ys)
        return 0, 0

    extents = [vertical_extent(points) for points in polygons]

    best_row, best_hits = 0, 0
    for row in range(img_height):
        hits = sum(1 for low, high in extents if low <= row <= high)
        # Strict comparison: on ties the earliest row is kept.
        if hits > best_hits:
            best_hits, best_row = hits, row

    return [
        points
        for points, (low, high) in zip(polygons, extents)
        if low <= best_row <= high
    ]


def filter_y_polygons(polygons, img_width, image):
    """Keep the polygons crossed by the most crowded vertical scan line.

    Every integer column in ``range(img_width)`` is tested against the CENTRAL
    half of each polygon (its horizontal extent shrunk by ``width // 4`` on both
    sides). The first column hitting the largest number of polygons wins. The
    final selection then uses the FULL horizontal extent, so a polygon whose
    central half misses the winning column is still kept if it spans it.
    An empty polygon is treated as having the extent 0..0.

    ``image`` is not used. An earlier idea, never executed, was to run text
    recognition on the kept polygons and drop those whose text is not numeric.
    """
    def horizontal_extent(points):
        if points:
            xs = [pt[0] for pt in points]
            return min(xs), max(xs)
        return 0, 0

    extents = [horizontal_extent(points) for points in polygons]

    best_col, best_hits = 0, 0
    for col in range(img_width):
        hits = 0
        for left, right in extents:
            margin = (right - left) // 4
            if left + margin <= col <= right - margin:
                hits += 1
        # Strict comparison: on ties the earliest column is kept.
        if hits > best_hits:
            best_hits, best_col = hits, col

    return [
        points
        for points, (left, right) in zip(polygons, extents)
        if left <= best_col <= right
    ]

def calculate_iou(polygon1, polygon2, image):
    """Intersection-over-union of the minimum-area rectangles of two polygons.

    Both rectangles are rasterised onto masks with the height and width of
    ``image``, and the IoU is computed on the non-zero pixels of those masks.

    Parameters
    ----------
    polygon1, polygon2 : array-like
        Point coordinates; each is reshaped to ``(-1, 2)``.
    image : np.ndarray
        Only ``image.shape[:2]`` is used, to size the masks.

    Returns
    -------
    float
        The IoU, or 0.0 when neither rectangle covers any pixel.
    """
    def rasterize(polygon):
        points = np.array(polygon).reshape(-1, 2).astype(np.float32)
        # np.int0 was removed in NumPy 2.0; np.intp is the dtype it aliased.
        box = cv2.boxPoints(cv2.minAreaRect(points)).astype(np.intp)
        mask = np.zeros(image.shape[:2], np.uint8)
        cv2.drawContours(mask, [box], 0, 255, -1, cv2.LINE_AA)
        return mask

    mask1 = rasterize(polygon1)
    mask2 = rasterize(polygon2)

    intersection_area = np.sum(np.logical_and(mask1, mask2))
    union_area = np.sum(np.logical_or(mask1, mask2))
    if union_area == 0:
        # Avoid 0 / 0 (nan plus a RuntimeWarning) when both masks are empty.
        return 0.0
    return intersection_area / union_area

def calculate_label_polygons_accuracy(pred_polygons, gt_polygons, image, is_x_label=True, iou_thre=0.5):
    """Return 1 when every predicted label polygon matches its ground-truth polygon, else 0.

    Both lists are sorted along the axis direction (by minimum x for x labels,
    by minimum y for y labels), paired position by position, and each pair must
    have an IoU strictly greater than ``iou_thre``.

    Parameters
    ----------
    pred_polygons, gt_polygons : list
        Polygons given as sequences of ``[x, y]`` points. A polygon given as a
        dict is read through its keys ``x0, y0, ..., x3, y3``
        (assumes that 4-point layout -- TODO confirm against the metadata schema).
    image : np.ndarray
        Forwarded to ``calculate_iou``.
    is_x_label : bool
        True sorts by x, False sorts by y.
    iou_thre : float
        Minimum IoU (exclusive) for a pair to count as a match.

    Returns
    -------
    int
        0 when the counts differ or any pair fails the threshold, otherwise 1
        (including the case where both lists are empty).
    """
    def as_points(polygon):
        if isinstance(polygon, dict):
            return [[polygon[f"x{i}"], polygon[f"y{i}"]] for i in range(4)]
        return polygon

    if len(pred_polygons) != len(gt_polygons):
        return 0

    axis = 0 if is_x_label else 1

    def sort_key(polygon):
        return min(point[axis] for point in polygon) if polygon else 0

    # Sort the two lists separately. The earlier code overwrote the ground truth
    # with the sorted predictions and then compared that list with itself, so
    # every pair trivially matched.
    gt_sorted = sorted((as_points(polygon) for polygon in gt_polygons), key=sort_key)
    pred_sorted = sorted((as_points(polygon) for polygon in pred_polygons), key=sort_key)

    for gt_polygon, pred_polygon in zip(gt_sorted, pred_sorted):
        # "not (iou > thre)" also rejects nan.
        if not calculate_iou(gt_polygon, pred_polygon, image) > iou_thre:
            return 0
    return 1

def visualize(image_path, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons):
    """Draw detections on top of an image and show it in a matplotlib figure.

    Parameters
    ----------
    image_path : str
        Image file to load.
    value_boxes, x_boxes, y_boxes : iterable of boxes or None
        Boxes whose first four values are ``x1, y1, x2, y2``; drawn as
        rectangles in blue, green and red respectively.
    x_labels_polygons, y_labels_polygons : iterable of polygons or None
        Point lists drawn as closed outlines in yellow and cyan respectively.
        Any argument given as ``None`` is skipped.

    Raises
    ------
    FileNotFoundError
        If the image cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # cv2.imread returns BGR while plt.imshow expects RGB. Converting BEFORE drawing
    # fixes the chart's own colours and keeps the annotation colours as displayed so far.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    box_layers = (
        (value_boxes, (0, 0, 255)),
        (x_boxes, (0, 255, 0)),
        (y_boxes, (255, 0, 0)),
    )
    for boxes, color in box_layers:
        if boxes is None:
            continue
        for box in boxes:
            # Plain Python ints: OpenCV point arguments reject some NumPy scalar types.
            x1, y1, x2, y2 = (int(value) for value in box[:4])
            cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

    polygon_layers = (
        (x_labels_polygons, (255, 255, 0)),
        (y_labels_polygons, (0, 255, 255)),
    )
    for polygons, color in polygon_layers:
        if polygons is None:
            continue
        for polygon in polygons:
            contour = np.array(polygon).reshape(-1, 2).astype(np.int32)
            cv2.drawContours(image, [contour], 0, color, 2)

    plt.figure(figsize=(8, 8))
    plt.imshow(image)

### Graph classification model, x/y labels classification model

In [None]:
# Class names in the order of the classifier's output columns.
graph_classes = ["dot", "line", "scatter", "vertical_bar", "horizontal_bar"]

graph_type_scores = graph_classification_model.predict(image_paths=image_paths)

# Highest-scoring column per image, translated to its class name.
graph_type_predictions = [
    graph_classes[class_idx] for class_idx in np.argmax(graph_type_scores, axis=1)
]

In [None]:
if not TEST_MODE:
    # Ground-truth chart class of every image, looked up by "images/<file name>".
    gt_classes = [
        metadata_dict["images/" + image_path.split("/")[-1]]["ground_truth"]["gt_parse"]["class"]
        for image_path in image_paths
    ]

    # Number of images whose predicted chart type equals the ground truth.
    acc = sum(
        1
        for predicted, expected in zip(graph_type_predictions, gt_classes)
        if predicted == expected
    )

    print("acc: ", acc / len(image_paths))
    print(np.unique(gt_classes, return_counts=True))

In [None]:
# Class names in the order of the axis-type classifiers' output columns.
type_classes = ["numerical", "categorical"]

x_type_scores = x_type_classification_model.predict(image_paths=image_paths)
x_type_predictions = [type_classes[class_idx] for class_idx in np.argmax(x_type_scores, axis=1)]

y_type_scores = y_type_classification_model.predict(image_paths=image_paths)
y_type_predictions = [type_classes[class_idx] for class_idx in np.argmax(y_type_scores, axis=1)]

if not TEST_MODE:
    # Ground-truth axis types, looked up by "images/<file name>".
    gt_parses = [
        metadata_dict["images/" + image_path.split("/")[-1]]["ground_truth"]["gt_parse"]
        for image_path in image_paths
    ]
    x_type_gt_classes = [gt_parse["x_type"] for gt_parse in gt_parses]
    y_type_gt_classes = [gt_parse["y_type"] for gt_parse in gt_parses]

    # Count the exact matches per axis.
    x_type_acc = sum(
        1 for predicted, expected in zip(x_type_predictions, x_type_gt_classes) if predicted == expected
    )
    y_type_acc = sum(
        1 for predicted, expected in zip(y_type_predictions, y_type_gt_classes) if predicted == expected
    )

    print("x_type_acc: ", x_type_acc / len(image_paths))
    print("y_type_acc: ", y_type_acc / len(image_paths))

### X/Y labels detection using DB model

In [None]:
# x_labels_predictions = x_labels_text_detection_model.predict(image_paths=image_paths)
# y_labels_predictions = y_labels_text_detection_model.predict(image_paths=image_paths)

In [None]:
# from tqdm import tqdm
# # calucate accuracy
# x_acc = 0
# y_acc = 0

# for idx in tqdm(range(len(image_paths))):
#     image = cv2.imread(image_paths[idx])

#     x_labels_polygons = x_labels_predictions[idx][0][0]
#     y_labels_polygons = y_labels_predictions[idx][0][0]

#     x_labels_polygons = filter_x_polygons(
#         x_labels_polygons,
#         image.shape[0],
#         image_paths[idx],
#     )

#     y_labels_polygons = filter_y_polygons(
#         y_labels_polygons,
#         image.shape[1],
#         image
#     )
    
#     x_acc += calculate_label_polygons_accuracy(
#         x_labels_polygons,
#         metadata_dict["images/" + image_paths[idx].split("/")[-1]]["ground_truth"]["gt_parse"]["x_labels_polygons"],
#         image=image,
#         is_x_label=True,
#     )
    
#     y_acc += calculate_label_polygons_accuracy(
#         y_labels_polygons,
#         metadata_dict["images/" + image_paths[idx].split("/")[-1]]["ground_truth"]["gt_parse"]["y_labels_polygons"],
#         image=image,
#         is_x_label=False,    
#     )

# print("x_acc: ", x_acc / len(image_paths))
# print("y_acc: ", y_acc / len(image_paths))

# # x_acc:  0.8872987477638641
# # y_acc:  0.8461538461538461


In [None]:
# # visualize keypoint detection results, data is boxes
# idx = 162
# # idx = (idx + 1) % len(image_paths)
# image = cv2.imread(image_paths[idx])
# x_labels_polygons = x_labels_predictions[idx][0][0]
# y_labels_polygons = y_labels_predictions[idx][0][0]

# x_labels_polygons = filter_x_polygons(
#     x_labels_polygons,
#     image.shape[0],
#     image_paths[idx],
# )

# y_labels_polygons = filter_y_polygons(
#     y_labels_polygons,
#     image.shape[1],
#     image
# )

# # visualize x_label_boxes
# image = cv2.imread(image_paths[idx])
# for polygon in x_labels_polygons:
#     polygon = np.array(polygon)
#     polygon = polygon.reshape(-1, 2)
#     polygon = polygon.astype(np.int32)
#     cv2.drawContours(image, [polygon], 0, (0, 255, 0), 2)

# # visualize y_label_boxes
# for polygon in y_labels_polygons:
#     polygon = np.array(polygon)
#     polygon = polygon.reshape(-1, 2)
#     polygon = polygon.astype(np.int32)
#     cv2.drawContours(image, [polygon], 0, (0, 0, 255), 2)

# plt.figure(figsize=(10, 10))
# plt.imshow(image)
# print(idx, image_paths[idx])

In [None]:
# # visualize ground truth
# image = cv2.imread(image_paths[idx])
# for polygon in metadata_dict["images/" + image_paths[idx].split("/")[-1]]["ground_truth"]["gt_parse"]["y_labels_polygons"]:
#     x0, y0, x1, y1, x2, y2, x3, y3 = polygon["x0"], polygon["y0"], polygon["x1"], polygon["y1"], polygon["x2"], polygon["y2"], polygon["x3"], polygon["y3"]
#     polygon = np.array([[x0, y0], [x1, y1], [x2, y2], [x3, y3]])
#     polygon = polygon.reshape(-1, 2)
#     polygon = polygon.astype(np.int32)
#     cv2.drawContours(image, [polygon], 0, (0, 0, 255), 2)

# for polygon in metadata_dict["images/" + image_paths[idx].split("/")[-1]]["ground_truth"]["gt_parse"]["x_labels_polygons"]:
#     x0, y0, x1, y1, x2, y2, x3, y3 = polygon["x0"], polygon["y0"], polygon["x1"], polygon["y1"], polygon["x2"], polygon["y2"], polygon["x3"], polygon["y3"]
#     polygon = np.array([[x0, y0], [x1, y1], [x2, y2], [x3, y3]])
#     polygon = polygon.reshape(-1, 2)
#     polygon = polygon.astype(np.int32)
#     cv2.drawContours(image, [polygon], 0, (0, 255, 0), 2)
# plt.imshow(image)
# print(metadata_dict["images/" + image_paths[idx].split("/")[-1]]["ground_truth"]["gt_parse"]["value"])

### Object detection model to detect points on graphs
1. Detect the x_labels and y_labels points
2. Map these points to the x_labels and y_labels texts
3. Apply post-processing that depends on the graph type

In [None]:
# keypoint_predictions = keypoint_detection_model.predict(image_paths=image_paths)

In [None]:
# # sample image of horizontal bar chart
# sample_image = cv2.imread("./data/validation/images/" + image_paths[idx].split("/")[-1])
# sample_image = cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB)

# # rotate image using cv2
# sample_image = cv2.rotate(sample_image, cv2.ROTATE_90_COUNTERCLOCKWISE)

# # horizontal flip
# sample_image = cv2.flip(sample_image, 1)

# # save temporary image
# temp_path = "./data/temp.jpg"
# cv2.imwrite(temp_path, sample_image)

# plt.imshow(sample_image)

In [None]:
# x_labels_polygons = x_labels_text_detection_model.predict(image_paths=[temp_path])[0][0][0]
# y_labels_polygons = y_labels_text_detection_model.predict(image_paths=[temp_path])[0][0][0]

# x_labels_polygons = filter_x_polygons(
#     x_labels_polygons,
#     image.shape[0],
#     image_paths[idx],
# )

# y_labels_polygons = filter_y_polygons(
#     y_labels_polygons,
#     image.shape[1],
#     image
# )

# # visualize x_label_boxes
# for polygon in x_labels_polygons:
#     polygon = np.array(polygon)
#     polygon = polygon.reshape(-1, 2)
#     polygon = polygon.astype(np.int32)
#     cv2.drawContours(sample_image, [polygon], 0, (0, 255, 0), 2)

# # visualize y_label_boxes
# for polygon in y_labels_polygons:
#     polygon = np.array(polygon)
#     polygon = polygon.reshape(-1, 2)
#     polygon = polygon.astype(np.int32)
#     cv2.drawContours(sample_image, [polygon], 0, (0, 0, 255), 2)

# single_keypoint_predictions = keypoint_detection_model.predict(image_paths=[temp_path])
# data = single_keypoint_predictions[0][0][0].cpu().numpy()

# value_boxes = (data[data[:, 6] == 0][:, :4] / single_keypoint_predictions[1][0]["ratio"]).astype(int)
# x_boxes = (data[data[:, 6] == 1][:, :4] / single_keypoint_predictions[1][0]["ratio"]).astype(int)
# y_boxes = (data[data[:, 6] == 2][:, :4] / single_keypoint_predictions[1][0]["ratio"]).astype(int)


# visualize(temp_path, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)

In [None]:
def process_labels_polygons(idx, graph_type):
    """Detect axis-label polygons and keypoint boxes for the image at position ``idx``.

    Steps:
      1. For ``"horizontal_bar"`` charts the original image is rotated 90 degrees
         counter-clockwise and flipped horizontally, written to a temporary
         folder, and used in place of the original.
      2. x/y label polygons are detected. If that fails the image is resized
         and detection is retried once; if the retry fails both lists are empty.
      3. The polygons are reduced with ``filter_x_polygons`` / ``filter_y_polygons``.
      4. x/y tick boxes come from the general keypoint model and value boxes from
         the chart-type-specific model ("line", "scatter", everything else uses
         the bar model). Any failure here empties all three box lists.

    Side effects: ``image_paths[idx]`` may be replaced by the path of the
    temporary image, and files are written under ``./data/temporary/``.

    Parameters
    ----------
    idx : int
        Position in ``image_paths`` / ``origin_image_paths``.
    graph_type : str
        Predicted chart type of that image.

    Returns
    -------
    tuple
        ``(value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)``.
    """
    TEMP_IMAGE_FOLDER = "./data/temporary/"
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(TEMP_IMAGE_FOLDER, exist_ok=True)

    is_horizontal_bar = graph_type == "horizontal_bar"

    if is_horizontal_bar:
        chosen_x_labels_text_detection_model = x_labels_text_detection_model_2
        try:
            # Always start from the untouched original so repeated calls do not stack rotations.
            sample_image = cv2.imread(os.path.join(image_folder, origin_image_paths[idx].split("/")[-1]))

            # rotate image using cv2
            sample_image = cv2.rotate(sample_image, cv2.ROTATE_90_COUNTERCLOCKWISE)

            # horizontal flip
            sample_image = cv2.flip(sample_image, 1)

            # save temporary image
            temp_path = os.path.join(TEMP_IMAGE_FOLDER, os.path.basename(image_paths[idx]))
            cv2.imwrite(temp_path, sample_image)

            image_paths[idx] = temp_path
        except Exception as error:
            # Best effort: keep going with the un-rotated image.
            print(f"can't rotate image {image_paths[idx]}: {error}")
    else:
        chosen_x_labels_text_detection_model = x_labels_text_detection_model

    def detect_label_polygons():
        current_path = image_paths[idx]
        detected_x = chosen_x_labels_text_detection_model.predict(image_paths=[current_path])[0][0][0]
        detected_y = y_labels_text_detection_model.predict(image_paths=[current_path])[0][0][0]
        return detected_x, detected_y

    try:
        x_labels_polygons, y_labels_polygons = detect_label_polygons()
    except Exception:
        print("can't predict original image, using resized image")
        sample_image = cv2.imread(image_paths[idx])
        # cv2.resize takes (width, height): rotated horizontal bars become portrait.
        target_size = (512, 768) if is_horizontal_bar else (768, 512)
        sample_image = cv2.resize(sample_image, target_size)

        temp_path = os.path.join(TEMP_IMAGE_FOLDER, os.path.basename(image_paths[idx]))
        cv2.imwrite(temp_path, sample_image)
        image_paths[idx] = temp_path

        try:
            x_labels_polygons, y_labels_polygons = detect_label_polygons()
        except Exception:
            x_labels_polygons, y_labels_polygons = [], []

    image = cv2.imread(image_paths[idx])
    x_labels_polygons = filter_x_polygons(
        x_labels_polygons,
        image.shape[0],
        image_paths[idx],
    )

    y_labels_polygons = filter_y_polygons(
        y_labels_polygons,
        image.shape[1],
        image
    )

    try:
        # General model: column 6 of each detection row is the class id (1 = x, 2 = y).
        single_keypoint_predictions = keypoint_detection_model.predict(image_paths=[image_paths[idx]])
        data = single_keypoint_predictions[0][0][0].cpu().numpy()
        ratio = single_keypoint_predictions[1][0]["ratio"]
        x_boxes = (data[data[:, 6] == 1][:, :4] / ratio).astype(int)
        y_boxes = (data[data[:, 6] == 2][:, :4] / ratio).astype(int)

        # Chart-type-specific model for the value points (class id 0).
        if graph_type == "line":
            value_model = line_value_keypoint_detection_model
        elif graph_type == "scatter":
            value_model = scatter_value_keypoint_detection_model
        else:
            value_model = bar_value_keypoint_detection_model

        keypoint_predictions = value_model.predict(image_paths=[image_paths[idx]])
        value_data = keypoint_predictions[0][0][0].cpu().numpy()
        value_boxes = (value_data[value_data[:, 6] == 0][:, :4] / keypoint_predictions[1][0]["ratio"]).astype(int)
    except Exception as error:
        # Exception (not a bare except) so Ctrl-C / SystemExit still interrupt the run.
        print(f"keypoint detection failed for {image_paths[idx]}: {error}")
        value_boxes, x_boxes, y_boxes = [], [], []

    return value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons


if not TEST_MODE:
    # ------------ SELECT ONE SAMPLE ------------
    # Position of the inspected image; to browse, advance it with
    # idx = (idx + 1) % len(image_paths), optionally until a wanted chart type is hit.
    idx = 558
    (
        value_boxes,
        x_boxes,
        y_boxes,
        x_labels_polygons,
        y_labels_polygons,
    ) = process_labels_polygons(idx, graph_type_predictions[idx])

    print("-------- PREDICTION ---------")
    print(
        "len(x_labels_polygons): ", len(x_labels_polygons),
        "len(x_boxes): ", len(x_boxes),
        "len(value_boxes): ", len(value_boxes),
    )

    visualize(image_paths[idx], value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)

    # ground truth
    print("-------- GROUND TRUTH ---------")
    sample_file_name = image_paths[idx].split("/")[-1]
    sample_gt_parse = metadata_dict["images/" + sample_file_name]["ground_truth"]["gt_parse"]
    print("graph type: ", sample_gt_parse["class"])

    gt_values = sample_gt_parse["value"]
    for gt_value in gt_values:
        print(gt_value)

    print("len(gt_values): ", len(gt_values))

    print(idx, sample_file_name)

In [None]:
# GENERAL RULE:
# 1. Filter x_points, y_points by draw a line_y, line_x
def process_filter_xy_value_boxes(idx, x_boxes, y_boxes, value_boxes):
    """Clean up the tick boxes and drop value boxes lying outside the plot area.

    * x boxes are kept only if they sit on the most crowded horizontal line
      (``filter_x_polygons`` over the full image height).
    * y boxes are kept only if they sit on the most crowded vertical line
      (``filter_y_polygons``, scanning the left half of the image width).
    * The axes are then estimated from the surviving tick boxes: the mean x
      centre of the y boxes and the mean y centre of the x boxes. A value box
      is kept when its right edge is right of the first and its top edge is
      above the second.

    If either tick list ends up empty the corresponding mean is nan, every
    comparison is False and no value box is kept.

    Returns
    -------
    tuple
        ``(x_boxes, y_boxes, filter_value_boxes)`` where the first two are
        lists of ``[x_min, y_min, x_max, y_max]``.
    """
    def box_to_polygon(box):
        left, top, right, bottom = box
        return [[left, top], [right, top], [right, bottom], [left, bottom]]

    def polygon_to_box(polygon):
        if not polygon:
            return [0, 0, 0, 0]
        xs = [pt[0] for pt in polygon]
        ys = [pt[1] for pt in polygon]
        return [min(xs), min(ys), max(xs), max(ys)]

    image = cv2.imread(image_paths[idx])

    kept_x_polygons = filter_x_polygons(
        [box_to_polygon(box) for box in x_boxes], image.shape[0], image_paths[idx]
    )
    x_boxes = [polygon_to_box(polygon) for polygon in kept_x_polygons]

    kept_y_polygons = filter_y_polygons(
        [box_to_polygon(box) for box in y_boxes], image.shape[1] // 2, image
    )
    y_boxes = [polygon_to_box(polygon) for polygon in kept_y_polygons]

    # Oy: x position of the vertical axis; Ox: y position of the horizontal axis.
    Oy = np.mean([(box[0] + box[2]) / 2 for box in y_boxes])
    Ox = np.mean([(box[1] + box[3]) / 2 for box in x_boxes])

    def inside_plot_area(box):
        x1, y1, x2, y2 = box
        return x2 > Oy and y1 < Ox

    filter_value_boxes = [box for box in value_boxes if inside_plot_area(box)]

    return x_boxes, y_boxes, filter_value_boxes


if not TEST_MODE:
    # Filter the boxes of the selected sample, then draw the result.
    # Note: the inputs are overwritten, so running this cell again filters the already filtered boxes.
    x_boxes, y_boxes, value_boxes = process_filter_xy_value_boxes(
        idx=idx,
        x_boxes=x_boxes,
        y_boxes=y_boxes,
        value_boxes=value_boxes,
    )

    visualize(
        image_path=image_paths[idx],
        value_boxes=value_boxes,
        x_boxes=x_boxes,
        y_boxes=y_boxes,
        x_labels_polygons=x_labels_polygons,
        y_labels_polygons=y_labels_polygons,
    )

In [None]:
def compute_iou_and_distance_one_direction(polygon, box, direction="x"):
    """Overlap ratio of ``box`` with ``polygon`` along one axis, plus center distance along the other.

    Parameters
    ----------
    polygon : sequence of ``(x, y)`` points; an empty polygon is treated as the
        degenerate extent ``0..0`` on both axes.
    box : ``[x1, y1, x2, y2]``.
    direction : ``"x"`` or ``"y"`` - the axis along which the overlap is measured.

    Returns
    -------
    (iou, distance)
        ``iou`` is the number of integer pixels shared by the box extent and the
        polygon extent along ``direction`` divided by the box extent (not a true
        IoU: the denominator is the box only). ``distance`` is the absolute
        distance between the two centers along the *other* axis.
        A box whose integer extent is empty yields ``iou == 0.0`` instead of
        raising ``ZeroDivisionError``.

    Raises
    ------
    ValueError
        If ``direction`` is neither ``"x"`` nor ``"y"``.
    """
    x1, y1, x2, y2 = box
    if len(polygon) == 0:
        polygon_min_x, polygon_max_x, polygon_min_y, polygon_max_y = 0, 0, 0, 0
    else:
        polygon_min_x = min([p[0] for p in polygon])
        polygon_max_x = max([p[0] for p in polygon])
        polygon_min_y = min([p[1] for p in polygon])
        polygon_max_y = max([p[1] for p in polygon])

    if direction == "x":
        box_low, box_high = x1, x2
        polygon_low, polygon_high = polygon_min_x, polygon_max_x
        distance = abs((y2 + y1) / 2 - (polygon_max_y + polygon_min_y) / 2)
    elif direction == "y":
        box_low, box_high = y1, y2
        polygon_low, polygon_high = polygon_min_y, polygon_max_y
        distance = abs((x2 + x1) / 2 - (polygon_max_x + polygon_min_x) / 2)
    else:
        raise ValueError("direction must be x or y")

    # Work on half-open integer pixel intervals [low, high), exactly like
    # range(int(low), int(high)) would, but without materialising the pixels.
    box_low, box_high = int(box_low), int(box_high)
    polygon_low, polygon_high = int(polygon_low), int(polygon_high)

    box_length = max(box_high - box_low, 0)
    if box_length == 0:
        # Degenerate (zero-width or inverted) box: nothing can overlap it.
        return 0.0, distance

    overlap = max(min(box_high, polygon_high) - max(box_low, polygon_low), 0)
    return overlap / box_length, distance

def compute_iou_and_all_distances(polygon, box, direction="x"):
    """Overlap ratio of ``box`` with ``polygon`` along ``direction`` plus both center distances.

    ``polygon`` is a sequence of ``(x, y)`` points (an empty one counts as the
    extent ``0..0``), ``box`` is ``[x1, y1, x2, y2]``.

    Returns ``(iou, distance_x, distance_y)`` where ``iou`` is the shared integer
    pixel count along ``direction`` divided by the box extent on that axis.
    A box with an empty integer extent returns ``(0, 0, 0)``.
    Raises ``ValueError`` for a direction other than ``"x"`` / ``"y"``.
    """
    x1, y1, x2, y2 = box

    if len(polygon) == 0:
        min_x = max_x = min_y = max_y = 0
    else:
        xs = [point[0] for point in polygon]
        ys = [point[1] for point in polygon]
        min_x, max_x = min(xs), max(xs)
        min_y, max_y = min(ys), max(ys)

    distance_x = abs((x2 + x1) / 2 - (max_x + min_x) / 2)
    distance_y = abs((y2 + y1) / 2 - (max_y + min_y) / 2)

    if direction == "x":
        box_start, box_stop = int(x1), int(x2)
        poly_start, poly_stop = int(min_x), int(max_x)
    elif direction == "y":
        box_start, box_stop = int(y1), int(y2)
        poly_start, poly_stop = int(min_y), int(max_y)
    else:
        raise ValueError("direction must be x or y")

    # Number of integer pixels in the half-open interval [box_start, box_stop).
    box_pixels = max(box_stop - box_start, 0)
    if box_pixels == 0:
        return 0, 0, 0

    # Size of the intersection of the two half-open integer intervals.
    shared_pixels = max(min(box_stop, poly_stop) - max(box_start, poly_start), 0)
    return shared_pixels / box_pixels, distance_x, distance_y


def mapping_labels_and_value(graph_type, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons):
    """Pair every y-axis label polygon with its best overlapping y tick box.

    A (label, box) pair is a candidate when the overlap ratio along the y
    direction exceeds 0.2. Candidates are ranked by overlap and then by center
    distance, both descending, and each label keeps its top-ranked candidate.
    Labels without any candidate are dropped; a tick box may be reused by
    several labels.

    ``graph_type`` is currently not used. The x boxes and x label polygons are
    returned untouched.

    Returns ``(x_boxes, remaining_y_boxes, x_labels_polygons, remaining_y_labels_polygons)``
    with the two "remaining" lists aligned element by element.
    """
    candidates = []  # (y_label_index, y_box_index, iou, distance)
    for label_index, label_polygon in enumerate(y_labels_polygons):
        for box_index, tick_box in enumerate(y_boxes):
            overlap, gap = compute_iou_and_distance_one_direction(label_polygon, tick_box, direction="y")
            if overlap > 0.2:
                candidates.append((label_index, box_index, overlap, gap))

    candidates.sort(key=lambda candidate: (candidate[2], candidate[3]), reverse=True)

    # First occurrence per label wins; dict insertion order keeps the ranking.
    best_per_label = {}
    for candidate in candidates:
        best_per_label.setdefault(candidate[0], candidate)

    remaining_y_labels_polygons = []
    remaining_y_boxes = []
    for label_index, box_index, _, _ in best_per_label.values():
        remaining_y_labels_polygons.append(y_labels_polygons[label_index])
        remaining_y_boxes.append(y_boxes[box_index])

    return x_boxes, remaining_y_boxes, x_labels_polygons, remaining_y_labels_polygons

if not TEST_MODE:
    # Demo for the sample image `idx` (skipped in TEST_MODE): pair the y labels
    # with the y tick boxes, then draw the result.
    # NOTE: this rebinds the notebook-level boxes / polygons in place.
    graph_type = graph_type_predictions[idx]
    x_boxes, y_boxes, x_labels_polygons, y_labels_polygons = mapping_labels_and_value(graph_type, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)

    # `visualize` is defined elsewhere in the notebook.
    visualize(image_paths[idx], value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)

In [None]:
# ------------ VERTICAL BAR GRAPH -------------
# Data: x is always categorical, y is always numerical.
# 1. Prioritize the value predictions: map them 1-1 with the x points and ignore
#    any outlier x points/values; match on the closest x2-x1 first, then y2-y1.
# If the number of x_labels equals the number of x_boxes, map them 1-1.
# If the number of x_labels differs from the number of x_boxes, then:
#     - get the (minimum-area) rectangle of each x_label
#     - build a rhombus whose vertices are the midpoints of the rectangle's edges
#     - build a rectangle centered on the highest point of the rhombus
#     - then map these rectangles 1-1 with the x_boxes

def map_x_labels_polygons_and_x_boxes(x_labels_polygons, x_boxes):
    """Greedily match x-axis label polygons with x tick boxes, one to one.

    For each label polygon the minimum-area rectangle is computed, the midpoints
    of its four edges are taken, and a 20x20 pixel square is centered on the
    topmost midpoint (smallest y). That square is compared with every tick box
    by horizontal overlap; pairs with an overlap ratio above 0.2 are ranked by
    overlap and then by vertical distance (both descending) and accepted
    greedily, so that every label and every box is used at most once.

    Returns ``(remaining_x_labels_polygons, remaining_x_boxes)`` - two lists
    aligned element by element, in ranking order. Unmatched labels and boxes are
    dropped.
    """
    half_side = 10  # half of the anchor square's width and height, in pixels

    anchor_squares = []
    for label_polygon in x_labels_polygons:
        # corners of the minimum-area (possibly rotated) rectangle, as integers
        corners = np.intp(cv2.boxPoints(cv2.minAreaRect(np.array(label_polygon))))
        # midpoint of every edge: each corner averaged with the next one (cyclic)
        edge_midpoints = (corners + np.roll(corners, -1, axis=0)) / 2
        # topmost midpoint = smallest y; argmin keeps the first one on ties
        top_x, top_y = edge_midpoints[np.argmin(edge_midpoints[:, 1])]

        left = top_x - half_side
        top = top_y - half_side
        right = left + 2 * half_side
        bottom = top + 2 * half_side
        anchor_squares.append(np.array([[left, top], [right, top], [right, bottom], [left, bottom]]))

    candidates = []  # (label_index, box_index, iou, distance)
    for label_index, anchor_square in enumerate(anchor_squares):
        for box_index, x_box in enumerate(x_boxes):
            overlap, gap = compute_iou_and_distance_one_direction(anchor_square, x_box, direction="x")
            if overlap > 0.2:
                candidates.append((label_index, box_index, overlap, gap))

    candidates.sort(key=lambda candidate: (candidate[2], candidate[3]), reverse=True)

    # Greedy one-to-one assignment in ranking order.
    taken_labels = set()
    taken_boxes = set()
    indices_mapping = []
    for label_index, box_index, _, _ in candidates:
        if label_index in taken_labels or box_index in taken_boxes:
            continue
        taken_labels.add(label_index)
        taken_boxes.add(box_index)
        indices_mapping.append((label_index, box_index))

    remaining_x_labels_polygons = [x_labels_polygons[label_index] for label_index, _ in indices_mapping]
    remaining_x_boxes = [x_boxes[box_index] for _, box_index in indices_mapping]

    return remaining_x_labels_polygons, remaining_x_boxes


def map_x_boxes_and_value_boxes(x_boxes, value_boxes, graph_type=""):
    """Match x tick boxes with value boxes and return ``(x_index, value_index)`` pairs.

    When both lists have the same length and ``graph_type`` is one of
    ``"vertical_bar"``, ``"horizontal_bar"`` or ``"dot"``, the boxes are paired
    by position. Otherwise every (x box, value box) pair whose horizontal overlap
    ratio exceeds 0.2 is a candidate; candidates are ranked by vertical center
    distance (closest first) and then by overlap (largest first) and accepted
    greedily so that each x box and each value box is used at most once.
    x boxes without a match are simply absent from the result.
    """
    if len(x_boxes) == len(value_boxes) and graph_type in ["vertical_bar", "horizontal_bar", "dot"]:
        return [(position, position) for position in range(len(value_boxes))]

    candidates = []  # (x_index, value_index, iou, vertical_distance)
    for x_index, x_box in enumerate(x_boxes):
        for value_index, value_box in enumerate(value_boxes):
            left, top, right, bottom = x_box
            x_box_polygon = np.array([[left, top], [right, top], [right, bottom], [left, bottom]])
            overlap, _, vertical_gap = compute_iou_and_all_distances(x_box_polygon, value_box, direction="x")
            if overlap > 0.2:
                candidates.append((x_index, value_index, overlap, vertical_gap))

    # reverse=True on (-distance, iou): smallest distance first, then largest iou
    ranked = sorted(candidates, key=lambda candidate: (-candidate[3], candidate[2]), reverse=True)

    used_x = set()
    used_values = set()
    value_indices_mapping = []
    for x_index, value_index, _, _ in ranked:
        if x_index in used_x or value_index in used_values:
            continue
        used_x.add(x_index)
        used_values.add(value_index)
        value_indices_mapping.append((x_index, value_index))

    return value_indices_mapping


def filter_non_numerical_boxes_and_polygons(image_path, boxes, polygons, graph_type=None):
    """Run text recognition on the label polygons and keep only the numerical ones.

    Parameters
    ----------
    image_path : path of the chart image; it is read and converted to RGB.
    boxes, polygons : parallel lists - ``polygons[i]`` is the label region
        cropped for recognition, ``boxes[i]`` is the tick box that goes with it.
    graph_type : optional chart type of the image. When omitted the function
        falls back to the notebook globals ``graph_type_predictions[idx]`` (the
        previous behavior), so existing callers keep working. Crops are flipped
        horizontally for ``"horizontal_bar"``.

    Returns
    -------
    (filtered_texts, filtered_polygons, filtered_boxes)
        Aligned lists for the labels whose cleaned text parses as a float. The
        text is cleaned by keeping only digits, ``"."`` and ASCII letters; one
        trailing ``"K"`` / ``"k"`` is ignored for the parse test but kept in the
        returned text.
    """
    origin_image = cv2.imread(image_path)
    origin_image = cv2.cvtColor(origin_image, cv2.COLOR_BGR2RGB)

    crops = []
    for polygon, _box in zip(polygons, boxes):
        try:
            crop = crop_polygon_from_image(origin_image, polygon)
        except Exception:
            # best effort: an uncroppable polygon becomes a blank white crop
            crop = np.ones((32, 100, 3), dtype=np.uint8) * 255
        if min(crop.shape) == 0:
            # degenerate crop -> blank white crop
            crop = np.ones((32, 100, 3), dtype=np.uint8) * 255

        # Prefer the explicit argument; only fall back to the notebook globals
        # when the caller did not provide the chart type.
        chart_type = graph_type if graph_type is not None else graph_type_predictions[idx]
        if chart_type == "horizontal_bar":
            # horizontal flip
            crop = cv2.flip(crop, 1)
        crops.append(crop)

    allowed_characters = set("0123456789.abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

    filtered_texts, filtered_polygons, filtered_boxes = [], [], []
    crops_texts = [p[0] for p in text_recognition_model.predict(crops)[0]]
    for text, polygon, box in zip(crops_texts, polygons, boxes):
        try:
            # drop every character that is not a digit, a dot or a letter
            # (this already removes ",", "%" and "-")
            text = "".join(c for c in text if c in allowed_characters)

            number_text = text[:-1] if text.endswith(("K", "k")) else text
            float(number_text)
        except (TypeError, ValueError):
            print("Not numerical", text)
            continue

        filtered_texts.append(text)
        filtered_polygons.append(polygon)
        filtered_boxes.append(box)

    return filtered_texts, filtered_polygons, filtered_boxes


def get_pixel_to_value_pair(boxes, texts, direction="y"):
    """Turn axis tick boxes and their recognized texts into ``(pixel, value)`` pairs.

    Parameters
    ----------
    boxes : list of ``[x1, y1, x2, y2]`` tick boxes.
    texts : recognized text for each box, same order as ``boxes``.
    direction : ``"y"`` uses the vertical center of the box as pixel position;
        any other value uses the horizontal center.

    Returns
    -------
    list of ``(pixel_center, numeric_value)`` tuples, one per box.

    Notes
    -----
    Only digits, ``"."`` and ASCII letters are kept from the text (so ``","``,
    ``"%"`` and ``"-"`` are removed, which also drops a minus sign). One trailing
    ``"K"`` or ``"k"`` is interpreted as a thousands multiplier, so ``"5k"`` is
    5000 and ``"1.5K"`` is 1500.

    Raises
    ------
    ValueError
        If the cleaned text cannot be parsed as a float.
    """
    allowed_characters = set("0123456789.abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

    pixel_to_value_pairs = []
    for i, box in enumerate(boxes):
        value_text = "".join(c for c in texts[i] if c in allowed_characters)

        # Handle the thousands suffix numerically instead of by appending "000"
        # to the string, which breaks on decimals and ignored the lowercase "k".
        multiplier = 1
        if value_text.endswith(("K", "k")):
            value_text = value_text[:-1]
            multiplier = 1000

        value = float(value_text) * multiplier

        if direction == "y":
            pixel = (box[1] + box[3]) / 2
        else:
            pixel = (box[0] + box[2]) / 2
        pixel_to_value_pairs.append((pixel, value))

    return pixel_to_value_pairs

def get_x_values_from_value_boxes(value_boxes, pixel_to_value_pairs, default_len):
    """Convert the horizontal position of each value box into an x-axis value.

    Parameters
    ----------
    value_boxes : list of ``[x1, y1, x2, y2]`` boxes; they are processed in
        order of increasing ``x1``.
    pixel_to_value_pairs : list of ``(pixel, value)`` calibration points of the
        axis (see ``get_pixel_to_value_pair``).
    default_len : in case there is no pixel_to_value_pairs, we return
        ``[0] * default_len`` for all x labels.

    Returns
    -------
    list of values, one per value box (sorted by ``x1``). Each value is
    interpolated / extrapolated linearly from the two calibration points nearest
    to the box center, with values growing to the right. Negative or NaN results
    are replaced by the smallest calibration value, ``+inf`` by 0.

    Notes
    -----
    When the two nearest calibration points share the same pixel (in particular
    when only one calibration point exists) no scale can be derived; the
    smallest calibration value is used instead of dividing by zero.
    """
    if len(pixel_to_value_pairs) == 0:
        print("No y_boxes found!!!")
        all_values = [0] * default_len
    else:
        min_value = min(p[1] for p in pixel_to_value_pairs)

        # process the value boxes from left to right
        value_boxes = sorted(value_boxes, key=lambda x: x[0])
        all_values = []
        for value_box in value_boxes:
            value_x_pixel = (value_box[0] + value_box[2]) / 2

            # the (up to) two calibration points nearest to the box center;
            # with a single calibration point both references are that point
            nearest_pixel_value_pairs = sorted(pixel_to_value_pairs, key=lambda x: abs(x[0] - value_x_pixel))[:2]
            x1_pixel, x1_value = nearest_pixel_value_pairs[0]
            x2_pixel, x2_value = nearest_pixel_value_pairs[-1]

            pixel_span = x2_pixel - x1_pixel
            if pixel_span == 0:
                # TODO: handle the single-calibration-point case properly
                # (e.g. use the origin as 0); for now fall back to min_value
                value_box_value = min_value
            else:
                offset = abs((x2_value - x1_value) / pixel_span * (value_x_pixel - x1_pixel))
                if value_x_pixel > x1_pixel:  # on the right of x1_pixel
                    value_box_value = x1_value + offset
                else:
                    value_box_value = x1_value - offset

            if value_box_value < 0 or math.isnan(value_box_value):
                value_box_value = min_value
            all_values.append(value_box_value)

    # set all infinite values to 0
    all_values = [0 if v == float("inf") else v for v in all_values]

    return all_values


def get_y_values_from_value_boxes(value_boxes, pixel_to_value_pairs, default_len):
    """Convert the vertical position of each value box into a y-axis value.

    Parameters
    ----------
    value_boxes : list of ``[x1, y1, x2, y2]`` boxes; they are processed in
        order of increasing ``x1``.
    pixel_to_value_pairs : list of ``(pixel, value)`` calibration points of the
        y axis (see ``get_pixel_to_value_pair``).
    default_len : in case there is no pixel_to_value_pairs, we return
        ``[0] * default_len`` for all x labels.

    Returns
    -------
    list of values, one per value box (sorted by ``x1``). Each value is
    interpolated / extrapolated linearly from the two calibration points nearest
    to the box center, with values growing upwards (towards smaller pixel rows).
    Negative or NaN results are replaced by the smallest calibration value,
    ``+inf`` by 0.

    Notes
    -----
    When the two nearest calibration points share the same pixel (in particular
    when only one calibration point exists) no scale can be derived; the
    smallest calibration value is used instead of dividing by zero.
    """
    if len(pixel_to_value_pairs) == 0:
        print("No y_boxes found!!!")
        all_values = [0] * default_len
    else:
        min_value = min(p[1] for p in pixel_to_value_pairs)

        # process the value boxes from left to right
        value_boxes = sorted(value_boxes, key=lambda x: x[0])
        all_values = []
        for value_box in value_boxes:
            value_y_pixel = (value_box[1] + value_box[3]) / 2

            # the (up to) two calibration points nearest to the box center;
            # with a single calibration point both references are that point
            nearest_pixel_value_pairs = sorted(pixel_to_value_pairs, key=lambda x: abs(x[0] - value_y_pixel))[:2]
            y1_pixel, y1_value = nearest_pixel_value_pairs[0]
            y2_pixel, y2_value = nearest_pixel_value_pairs[-1]

            pixel_span = y2_pixel - y1_pixel
            if pixel_span == 0:
                # TODO: handle the single-calibration-point case properly
                # (e.g. use the origin as 0); for now fall back to min_value
                value_box_value = min_value
            else:
                offset = abs((y2_value - y1_value) / pixel_span * (value_y_pixel - y1_pixel))
                if value_y_pixel > y1_pixel:  # below the y1_pixel
                    value_box_value = y1_value - offset
                else:
                    value_box_value = y1_value + offset

            if value_box_value < 0 or math.isnan(value_box_value):
                value_box_value = min_value
            all_values.append(value_box_value)

    # set all infinite values to 0
    all_values = [0 if v == float("inf") else v for v in all_values]

    return all_values


def read_text_from_polygons(image_path, polygons, graph_type):
    """Recognize the text inside each polygon of an image.

    Parameters
    ----------
    image_path : path of the chart image (read with ``cv2.imread``).
    polygons : regions to crop and recognize.
    graph_type : chart type; crops are flipped horizontally for
        ``"horizontal_bar"``.

    Returns
    -------
    list of recognized strings, one per polygon, in input order.

    Notes
    -----
    A polygon that cannot be cropped, or whose crop is empty, is replaced by a
    blank white 32x100 crop so that the output stays aligned with ``polygons``.
    NOTE(review): unlike ``filter_non_numerical_boxes_and_polygons`` the image is
    not converted from BGR to RGB here - confirm which channel order the
    recognition model expects.
    """
    origin_image = cv2.imread(image_path)

    crops = []
    for polygon in polygons:
        try:
            crop = crop_polygon_from_image(origin_image, polygon)
        except Exception:
            # best effort: keep going with a blank crop, but let
            # KeyboardInterrupt / SystemExit propagate
            crop = np.ones((32, 100, 3), dtype=np.uint8) * 255
        if min(crop.shape) == 0:
            # degenerate crop -> blank white crop
            crop = np.ones((32, 100, 3), dtype=np.uint8) * 255
        if graph_type == "horizontal_bar":
            # horizontal flip
            crop = cv2.flip(crop, 1)
        crops.append(crop)

    texts = [p[0] for p in text_recognition_model.predict(crops)[0]]

    return texts

In [None]:
import math

def postprocess_bar_graph(idx, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons):
    """Turn the detections of a bar chart into x labels and y values.

    Steps: match x labels with x tick boxes, match x tick boxes with value
    boxes (an x box without a value box uses itself as value box), keep only the
    numerical y labels, calibrate the y axis from them, convert every value box
    into a y value and read the x label texts (sorted left to right). For a
    ``"horizontal_bar"`` prediction the texts and values are reversed.

    Returns
    -------
    (value_boxes, x_boxes, filtered_y_boxes, x_labels_polygons,
     filtered_y_labels_polygons, x_labels_texts, all_values)
    """
    x_labels_polygons, x_boxes = map_x_labels_polygons_and_x_boxes(x_labels_polygons, x_boxes)
    value_indices_mapping = map_x_boxes_and_value_boxes(x_boxes, value_boxes)

    # x boxes that received no value box are recorded with value index -1
    matched_x_indices = set(x_index for x_index, _ in value_indices_mapping)
    for unmatched_x_index in set(range(len(x_boxes))) - matched_x_indices:
        value_indices_mapping.append((unmatched_x_index, -1))
    value_indices_mapping.sort(key=lambda pair: pair[0])

    # one value box per x box, in x-index order; -1 falls back to the x box itself
    aligned_value_boxes = [
        x_boxes[x_index] if value_index == -1 else value_boxes[value_index]
        for x_index, value_index in value_indices_mapping
    ]
    aligned_x_boxes = [x_boxes[x_index] for x_index, _ in value_indices_mapping]
    value_boxes, x_boxes = aligned_value_boxes, aligned_x_boxes

    # keep only the y labels that read as numbers and calibrate the y axis
    image_path = image_paths[idx]
    filtered_texts, filtered_y_labels_polygons, filtered_y_boxes = filter_non_numerical_boxes_and_polygons(
        image_path, y_boxes, y_labels_polygons
    )
    pixel_to_value_pairs = get_pixel_to_value_pair(filtered_y_boxes, filtered_texts, direction="y")

    all_values = get_y_values_from_value_boxes(
        value_boxes, pixel_to_value_pairs, default_len=len(x_labels_polygons)
    )

    # read the x labels from left to right (ordered by the polygon's smallest x)
    x_labels_polygons = sorted(x_labels_polygons, key=lambda polygon: min([point[0] for point in polygon]))
    x_labels_texts = read_text_from_polygons(image_path, x_labels_polygons, graph_type_predictions[idx])

    if graph_type_predictions[idx] == "horizontal_bar":
        x_labels_texts = list(reversed(x_labels_texts))
        all_values = list(reversed(all_values))

    return value_boxes, x_boxes, filtered_y_boxes, x_labels_polygons, filtered_y_labels_polygons, x_labels_texts, all_values

In [None]:
# 1. get the line, then map each x pixel to its y pixel
# 2. for the x boxes that have no value box, use that mapping to create a value box at the corresponding y pixel

# find the line in line chart using opencv
def find_line(image_path, text_detection_prediction, erode_size=2):
    """Extract a binary mask of the plotted line from a chart image using OpenCV.

    The image is converted to grayscale, detected text regions are painted
    white, pixels with intensity 0..180 are kept, the mask is eroded with an
    ``erode_size`` square kernel and dilated with a 3x3 kernel, long horizontal
    and vertical runs (40 pixel structuring elements) are erased, and connected
    components of at most 16 pixels are removed.

    Parameters
    ----------
    image_path : path of the chart image.
    text_detection_prediction : iterable of text polygons; a polygon covering
        more than 20% of the image is not treated as text and is left alone.
    erode_size : side of the square erosion kernel.

    Returns
    -------
    (gray_image, mask) - the grayscale image with text painted out, and the mask.
    """
    gray = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2GRAY)

    # paint every text region white; regions above 20% of the image are not text
    largest_text_area = 0.2 * gray.shape[0] * gray.shape[1]
    for polygon in text_detection_prediction:
        if cv2.contourArea(np.array(polygon)) > largest_text_area:
            continue
        cv2.fillPoly(gray, [np.array(polygon)], 255)

    # keep the dark pixels, then clean the mask up
    mask = cv2.inRange(gray, np.array([0]), np.array([180]))
    mask = cv2.erode(mask, np.ones((erode_size, erode_size), np.uint8), iterations=1)
    mask = cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=1)

    def erase_straight_runs(kernel_shape):
        # Find structures that survive an opening with the given kernel and
        # blank their filled contours in the mask.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_shape)
        straight_runs = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=2)
        found = cv2.findContours(straight_runs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # findContours returns 2 or 3 items depending on the OpenCV version
        contours = found[0] if len(found) == 2 else found[1]
        for contour in contours:
            cv2.drawContours(mask, [contour], -1, 0, -1)

    erase_straight_runs((40, 1))  # horizontal lines
    erase_straight_runs((1, 40))  # vertical lines

    # drop connected components of at most `min_size` pixels (label 0 is skipped)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask, 4, cv2.CV_32S)
    min_size = 16
    for component_offset, component_size in enumerate(stats[1:, -1]):
        if component_size <= min_size:
            mask[labels == component_offset + 1] = 0

    return gray, mask

def add_missing_value_box_for_line_graph(image_path, x_boxes, y_boxes, value_boxes, x_labels_polygons, value_indices_mapping, missing_x_indices, is_visualize=False):
    """Create value boxes for x ticks of a line chart that have no detected value box.

    The plotted line is located with ``find_line``; for every x tick listed in
    ``missing_x_indices`` the line's y position at the tick's x center gives a
    synthetic value box of height 10 and the width of the tick box.

    Parameters
    ----------
    image_path : path of the chart image.
    x_boxes, y_boxes, value_boxes : lists of ``[x1, y1, x2, y2]`` boxes.
    x_labels_polygons : label polygons aligned with ``x_boxes``.
    value_indices_mapping : ``(x_index, value_index)`` pairs already matched.
    missing_x_indices : indices of ``x_boxes`` without a matched value box.
    is_visualize : when True the mask and the interpolated line are drawn with
        matplotlib.

    Returns
    -------
    (filtered_value_boxes, filtered_x_boxes, filtered_x_labels_polygons)
        The already matched entries (in ``value_indices_mapping`` order)
        followed by the newly inserted ones. The three lists are aligned.

    Notes
    -----
    * Text is removed before the line search. If text detection fails on the
      original image it is retried on a copy padded to multiples of 128 (written
      to ``./data/temp.png``); if that fails too, no text is removed.
    * If the line is found in at most one pixel column (after a retry with a
      smaller erosion), every missing tick gets a placeholder value box: the
      first detected value box, or a 10x10 box at the image center.
    * Otherwise a missing tick whose x position is not covered by the line is
      dropped from the result.
    """
    # TODO: remove all the text before find line
    try:
        text_detection_prediction = text_detection_model.predict([image_path])[0][0][0]
    except Exception:
        try:
            print("Can't predict on original image, try to pad image to multiples of 128")
            # use the image passed in, not the notebook-level image_paths[idx]
            temp_image = cv2.imread(image_path)
            h, w = temp_image.shape[:2]

            # find nearest padding size to multiples of 128
            h = int(np.ceil(h / 128) * 128)
            w = int(np.ceil(w / 128) * 128)

            # pad on the bottom and on the right with white
            temp_image = cv2.copyMakeBorder(temp_image, 0, h - temp_image.shape[0], 0, w - temp_image.shape[1], cv2.BORDER_CONSTANT, value=[255, 255, 255])
            cv2.imwrite("./data/temp.png", temp_image)

            text_detection_prediction = text_detection_model.predict(["./data/temp.png"])[0][0][0]
        except Exception:
            text_detection_prediction = []

    # Oy: x position of the y axis, Ox: y position of the x axis
    if y_boxes and x_boxes:
        Oy = np.mean([(box[0] + box[2]) / 2 for box in y_boxes])
        Ox = np.mean([(box[1] + box[3]) / 2 for box in x_boxes])
    else:
        Oy = 0
        Ox = 0

    # topmost y tick center and rightmost x tick center
    if len(y_boxes) == 0:
        min_y = 0
    else:
        min_y = min([(box[1] + box[3]) // 2 for box in y_boxes])

    if len(x_boxes) == 0:
        max_x = None
    else:
        max_x = max([(box[0] + box[2]) // 2 for box in x_boxes])

    def detect_line_columns(*erode_args):
        # Run find_line, blank everything outside the plot area and return
        # (mask, {x_pixel: mean y_pixel of the line in that column}).
        _, mask = find_line(image_path, text_detection_prediction, *erode_args)
        right_limit = mask.shape[1] if max_x is None else max_x

        # blank what is left of the y axis and below the x axis
        # NOTE(review): negative slice bounds (axis closer than 5 px to the
        # image border, or no y boxes) wrap around - kept as in the original.
        mask[:, :int(Oy) + 3] = 0
        mask[int(Ox) - 5:, :] = 0

        # blank above the topmost y tick, unless a value box lies above it
        if not any([(box[1] + box[3]) / 2 < min_y for box in value_boxes]):
            mask[:int(min_y) - 5, :] = 0
        mask[:, int(right_limit):] = 0
        if is_visualize:
            plt.imshow(mask)

        columns = {}
        for column in range(mask.shape[1]):
            rows = np.where(mask[:, column] > 0)[0]
            # skip empty columns instead of taking the mean of an empty slice
            if rows.size > 0:
                columns[column] = np.mean(rows)
        return mask, right_limit, columns

    mask, max_x_limit, mapping = detect_line_columns()
    if len(mapping) == 0:
        print("Not found x_pixels, try with erode = 1")
        mask, max_x_limit, mapping = detect_line_columns(1)

    x_pixels = list(mapping.keys())
    y_pixels = [mapping[x_pixel] for x_pixel in x_pixels]

    inserted_value_boxes = []
    inserted_x_boxes = []
    inserted_x_labels_polygons = []

    if len(x_pixels) <= 1:
        # no usable line: add a placeholder value box for every missing tick
        for index in missing_x_indices:
            if len(value_boxes):
                inserted_value_boxes.append(value_boxes[0])
            else:
                mid_x = mask.shape[1] // 2
                mid_y = mask.shape[0] // 2
                inserted_value_boxes.append([mid_x - 5, mid_y - 5, mid_x + 5, mid_y + 5])
            inserted_x_boxes.append(x_boxes[index])
            inserted_x_labels_polygons.append(x_labels_polygons[index])
    else:
        # the detected columns may be discontinuous: fill every gap linearly
        all_x_pixels = []
        all_y_pixels = []
        for i in range(len(x_pixels) - 1):
            steps = x_pixels[i + 1] - x_pixels[i] + 1
            all_x_pixels.extend(np.linspace(x_pixels[i], x_pixels[i + 1], steps))
            all_y_pixels.extend(np.linspace(y_pixels[i], y_pixels[i + 1], steps))

        mapping = dict(sorted(zip(all_x_pixels, all_y_pixels)))
        if is_visualize:
            plt.plot(list(mapping.keys()), list(mapping.values()))

        for index in missing_x_indices:
            x_pixel = (x_boxes[index][0] + x_boxes[index][2]) // 2
            # pull ticks sitting on the plot borders slightly inside the plot
            x_pixel = int(Oy + 4) if x_pixel <= Oy + 3 else x_pixel
            x_pixel = int(max_x_limit - 4) if x_pixel >= max_x_limit - 3 else x_pixel
            y_pixel = mapping.get(x_pixel, None)
            if y_pixel is not None:
                inserted_value_boxes.append([int(x_boxes[index][0]), int(y_pixel - 5), int(x_boxes[index][2]), int(y_pixel + 5)])
                inserted_x_boxes.append(x_boxes[index])
                inserted_x_labels_polygons.append(x_labels_polygons[index])

    filtered_x_boxes = [x_boxes[p[0]] for p in value_indices_mapping]
    filtered_x_labels_polygons = [x_labels_polygons[p[0]] for p in value_indices_mapping]
    filtered_value_boxes = [value_boxes[p[1]] for p in value_indices_mapping]

    filtered_x_boxes.extend(inserted_x_boxes)
    filtered_value_boxes.extend(inserted_value_boxes)
    filtered_x_labels_polygons.extend(inserted_x_labels_polygons)

    return filtered_value_boxes, filtered_x_boxes, filtered_x_labels_polygons


if not TEST_MODE:
    # ---------- EXAMPLE ----------
    # Demo for the sample image `idx` (skipped in TEST_MODE); every step rebinds
    # the notebook-level boxes / polygons, so the statement order matters.
    # 1. match the x labels with the x tick boxes
    x_labels_polygons, x_boxes = map_x_labels_polygons_and_x_boxes(x_labels_polygons, x_boxes)
    # visualize(image_paths[idx], value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)

    # 2. match the x tick boxes with the value boxes
    value_indices_mapping = map_x_boxes_and_value_boxes(x_boxes, value_boxes, graph_type=graph_type_predictions[idx])

    # 3. x tick boxes that received no value box
    missing_x_indices = set(range(len(x_boxes))) - set([p[0] for p in value_indices_mapping])

    # 4. create value boxes for them from the detected line, then draw the result
    value_boxes, x_boxes, x_labels_polygons = add_missing_value_box_for_line_graph(image_paths[idx], x_boxes, y_boxes, value_boxes, x_labels_polygons, value_indices_mapping, missing_x_indices, is_visualize=True)
    visualize(image_paths[idx], value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)


In [None]:
def postprocess_line_graph(idx, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons):
    """Turn the detections of one line chart into x-label texts and y values.

    Steps, in order:
      1. pair x-label polygons with x boxes, then x boxes with value boxes
         (``map_x_boxes_and_value_boxes`` is called without ``graph_type``);
      2. if some x box has no paired value box, let
         ``add_missing_value_box_for_line_graph`` rebuild the three lists;
      3. keep only the numerical y-axis labels and derive pixel/value pairs;
      4. read y values from the value boxes and OCR the x labels.

    Parameters
    ----------
    idx : index into the notebook-level ``image_paths`` / ``graph_type_predictions``.
    value_boxes, x_boxes, y_boxes : detected boxes for this image.
    x_labels_polygons, y_labels_polygons : detected text polygons for each axis.

    Returns
    -------
    tuple
        ``(value_boxes, x_boxes, filtered_y_boxes, x_labels_polygons,
        filtered_y_labels_polygons, x_labels_texts, all_values)``.
    """
    x_labels_polygons, x_boxes = map_x_labels_polygons_and_x_boxes(x_labels_polygons, x_boxes)
    value_indices_mapping = map_x_boxes_and_value_boxes(x_boxes, value_boxes)

    # The first element of every mapping pair is an index into x_boxes.
    matched_x_indices = {pair[0] for pair in value_indices_mapping}
    missing_x_indices = set(range(len(x_boxes))) - matched_x_indices
    if missing_x_indices:
        # Some x labels have no value box: project them onto the line with the CV helper.
        value_boxes, x_boxes, x_labels_polygons = add_missing_value_box_for_line_graph(
            image_paths[idx],
            x_boxes,
            y_boxes,
            value_boxes,
            x_labels_polygons,
            value_indices_mapping,
            missing_x_indices,
        )

    # Drop y-axis labels that are not numerical.
    image_path = image_paths[idx]
    filtered_texts, filtered_y_labels_polygons, filtered_y_boxes = filter_non_numerical_boxes_and_polygons(
        image_path, y_boxes, y_labels_polygons
    )

    # Pixel/value pairs along the y axis, then the y value of each value box.
    pixel_to_value_pairs = get_pixel_to_value_pair(filtered_y_boxes, filtered_texts, direction="y")
    all_values = get_y_values_from_value_boxes(
        value_boxes, pixel_to_value_pairs, default_len=len(x_labels_polygons)
    )

    def _left_edge(polygon):
        # Smallest x coordinate of the polygon; empty polygons sort as 0.
        return min(point[0] for point in polygon) if len(polygon) > 0 else 0

    # Order the x labels left to right before reading their text.
    x_labels_polygons = sorted(x_labels_polygons, key=_left_edge)
    x_labels_texts = read_text_from_polygons(image_path, x_labels_polygons, graph_type_predictions[idx])

    return (
        value_boxes,
        x_boxes,
        filtered_y_boxes,
        x_labels_polygons,
        filtered_y_labels_polygons,
        x_labels_texts,
        all_values,
    )

In [None]:
def postprocess_scatter_graph(idx, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons):
    """Convert the value boxes of one scatter plot into float x and y readings.

    Both axes go through the same treatment: non-numerical labels are filtered
    out, pixel/value pairs are derived from what remains, and every value box
    is read against them.

    Parameters
    ----------
    idx : index into the notebook-level ``image_paths``.
    value_boxes, x_boxes, y_boxes : detected boxes for this image.
    x_labels_polygons, y_labels_polygons : detected text polygons for each axis.

    Returns
    -------
    tuple
        ``(value_boxes, filtered_x_boxes, filtered_y_boxes,
        filtered_x_labels_polygons, filtered_y_labels_polygons,
        all_x_values, all_y_values)`` where the last two are lists of floats.
    """
    x_labels_polygons, x_boxes = map_x_labels_polygons_and_x_boxes(x_labels_polygons, x_boxes)

    # Keep only the numerical labels on each axis.
    image_path = image_paths[idx]
    x_texts, x_polygons_kept, x_boxes_kept = filter_non_numerical_boxes_and_polygons(
        image_path, x_boxes, x_labels_polygons
    )
    y_texts, y_polygons_kept, y_boxes_kept = filter_non_numerical_boxes_and_polygons(
        image_path, y_boxes, y_labels_polygons
    )

    # Pixel/value pairs per axis.
    x_pixel_to_value_pairs = get_pixel_to_value_pair(x_boxes_kept, x_texts, direction="x")
    y_pixel_to_value_pairs = get_pixel_to_value_pair(y_boxes_kept, y_texts, direction="y")

    # Read every value box on both axes, then normalise the readings to float.
    raw_x_values = get_x_values_from_value_boxes(value_boxes, x_pixel_to_value_pairs, default_len=0)
    raw_y_values = get_y_values_from_value_boxes(value_boxes, y_pixel_to_value_pairs, default_len=0)
    all_x_values = list(map(float, raw_x_values))
    all_y_values = list(map(float, raw_y_values))

    return (
        value_boxes,
        x_boxes_kept,
        y_boxes_kept,
        x_polygons_kept,
        y_polygons_kept,
        all_x_values,
        all_y_values,
    )


In [None]:
def postprocess_dot_graph(idx, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons):
    """Post-process a dot plot by delegating to the bar or scatter routine.

    The choice depends on ``x_type_predictions[idx]``: a categorical x axis is
    handled by ``postprocess_bar_graph``; anything else is handled by
    ``postprocess_scatter_graph``, after which both the x and the y readings
    are rounded to integers.

    Returns
    -------
    tuple
        ``(value_boxes, x_boxes, y_boxes, x_labels_polygons,
        y_labels_polygons, x_labels_texts, all_values)``.
    """
    if x_type_predictions[idx] != "categorical":
        # Numerical x axis: reuse the scatter post-processing, then round both series.
        (
            value_boxes,
            x_boxes,
            y_boxes,
            x_labels_polygons,
            y_labels_polygons,
            x_labels_texts,
            all_values,
        ) = postprocess_scatter_graph(idx, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)
        all_values = list(map(round, all_values))
        x_labels_texts = list(map(round, x_labels_texts))
    else:
        # Categorical x axis: reuse the bar-graph post-processing unchanged.
        (
            value_boxes,
            x_boxes,
            y_boxes,
            x_labels_polygons,
            y_labels_polygons,
            x_labels_texts,
            all_values,
        ) = postprocess_bar_graph(idx, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)

    return (
        value_boxes,
        x_boxes,
        y_boxes,
        x_labels_polygons,
        y_labels_polygons,
        x_labels_texts,
        all_values,
    )

In [None]:
if not TEST_MODE:
    # Run the post-processing routine that matches the predicted type of the
    # example chart at `idx`; an unrecognised type leaves every variable untouched.
    if graph_type_predictions[idx] == "line":
        (value_boxes, x_boxes, y_boxes, x_labels_polygons,
         y_labels_polygons, x_labels_texts, all_values) = postprocess_line_graph(
            idx, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)
    elif graph_type_predictions[idx] == "scatter":
        (value_boxes, x_boxes, y_boxes, x_labels_polygons,
         y_labels_polygons, x_labels_texts, all_values) = postprocess_scatter_graph(
            idx, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)
    elif graph_type_predictions[idx] == "dot":
        (value_boxes, x_boxes, y_boxes, x_labels_polygons,
         y_labels_polygons, x_labels_texts, all_values) = postprocess_dot_graph(
            idx, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)
    elif graph_type_predictions[idx] in ("vertical_bar", "horizontal_bar"):
        (value_boxes, x_boxes, y_boxes, x_labels_polygons,
         y_labels_polygons, x_labels_texts, all_values) = postprocess_bar_graph(
            idx, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)

    # Show the boxes and polygons that survived post-processing.
    visualize(image_paths[idx], value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)

In [None]:
import math
import uuid

# compute metrics
def compute_metrics(idx, all_values, x_labels_texts, graph_type):
    """Score one chart's predicted series against its annotation.

    The annotation is read from the notebook-level ``metadata_dict`` under the
    key ``"images/<file name of image_paths[idx]>"``. Ground truth and
    prediction are each packed into a two-row frame (``<id>_x`` / ``<id>_y``,
    with a random id shared by both) and handed to ``benetech_score``.

    When the predicted ``graph_type`` differs from the annotated class, the
    ground-truth series are left empty and the predicted series are ``[1]``.
    For ``horizontal_bar`` charts the annotated x/y coordinates and axis types
    are swapped before comparison.

    Parameters
    ----------
    idx : index into ``image_paths``.
    all_values : predicted values, paired positionally with ``x_labels_texts``.
    x_labels_texts : predicted labels for the other series.
    graph_type : predicted chart type.

    Returns
    -------
    tuple
        ``(score, ground_truth, predictions)``.
    """
    def _cast(series, axis_type):
        # Categorical axes are compared as strings, all other axes as floats.
        if axis_type == "categorical":
            return [str(item) for item in series]
        return [float(item) for item in series]

    def _as_frame(series_id, xs, ys, chart_type):
        # Two rows, '<id>_x' and '<id>_y', each holding (series, chart type).
        return pd.DataFrame.from_dict({
            f'{series_id}_x': (xs, chart_type),
            f'{series_id}_y': (ys, chart_type),
        }, orient='index', columns=['data_series', 'chart_type']).rename_axis('id')

    annotation_key = "images/" + image_paths[idx].split("/")[-1]
    gt_data = metadata_dict[annotation_key]["ground_truth"]["gt_parse"]
    x_type = gt_data["x_type"]
    y_type = gt_data["y_type"]
    same_class = gt_data["class"] == graph_type

    # --------- GROUND TRUTH ------------
    gt_xs, gt_ys = [], []
    if same_class:
        # For horizontal bars the annotated "x" holds the number and "y" the label.
        if gt_data["class"] == "horizontal_bar":
            label_key, number_key = "y", "x"
        else:
            label_key, number_key = "x", "y"

        for point in gt_data["value"]:
            # Skip points whose numeric coordinate parses to NaN.
            if math.isnan(float(point[number_key])):
                continue
            gt_xs.append(point[label_key])
            gt_ys.append(point[number_key])

        if graph_type == "horizontal_bar":
            x_type, y_type = y_type, x_type

        gt_xs = _cast(gt_xs, x_type)
        gt_ys = _cast(gt_ys, y_type)

    random_id = str(uuid.uuid4())[:10]
    ground_truth = _as_frame(random_id, gt_xs, gt_ys, gt_data["class"])

    # --------- PREDICTION ------------
    if same_class:
        # zip() truncates to the shorter of the two predicted sequences.
        pairs = list(zip(x_labels_texts, all_values))
        pred_xs = _cast([label for label, _ in pairs], x_type)
        pred_ys = _cast([value for _, value in pairs], y_type)
    else:
        pred_xs, pred_ys = [1], [1]

    predictions = _as_frame(random_id, pred_xs, pred_ys, graph_type)

    print("prediction graph type: ", graph_type)
    print("ground truth graph type: ", gt_data["class"])

    return benetech_score(ground_truth, predictions), ground_truth, predictions

if not TEST_MODE:
    # Score the example chart at `idx` and dump the frames that went into the metric.
    print(graph_type_predictions[idx])
    score, gt, pred = compute_metrics(
        idx,
        all_values,
        x_labels_texts,
        graph_type_predictions[idx],
    )
    print(f"score = {score}\n----------\npred =\n {pred}\n----------\ngt =\n {gt}")


In [None]:
def predict(idx, is_visualize=False):
    """Run the whole post-processing pipeline for the image at ``idx``.

    The predicted chart type (``graph_type_predictions[idx]``) selects the
    post-processing routine. What is returned depends on the notebook-level
    ``TEST_MODE`` flag.

    Parameters
    ----------
    idx : index into ``image_paths`` / ``graph_type_predictions``.
    is_visualize : when true (validation mode only), draw the final boxes and
        polygons with ``visualize`` before scoring.

    Returns
    -------
    Validation mode (``TEST_MODE`` false)
        ``(score, gt, pred)`` exactly as produced by ``compute_metrics``.
    Test mode (``TEST_MODE`` true)
        A two-row DataFrame indexed by ``<image_id>_x`` / ``<image_id>_y`` with
        columns ``data_series`` (values joined by ``";"``) and ``chart_type``.
        If any step fails, a placeholder frame with ``"0;0"`` series is
        returned instead so that every image still yields a row pair.

    Raises
    ------
    Exception
        In validation mode the original error is reported and re-raised. The
        placeholder frame cannot stand in for the ``(score, gt, pred)`` triple
        callers unpack, and hiding the failure would only surface later as an
        unrelated unpacking error.
    """
    graph_type = graph_type_predictions[idx]

    try:
        value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons = process_labels_polygons(idx, graph_type)
        x_boxes, y_boxes, value_boxes = process_filter_xy_value_boxes(idx, x_boxes, y_boxes, value_boxes)
        x_boxes, y_boxes, x_labels_polygons, y_labels_polygons = mapping_labels_and_value(graph_type, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)

        if graph_type in ["vertical_bar", "horizontal_bar"]:
            postprocess = postprocess_bar_graph
        elif graph_type == "line":
            postprocess = postprocess_line_graph
        elif graph_type == "scatter":
            postprocess = postprocess_scatter_graph
        elif graph_type == "dot":
            postprocess = postprocess_dot_graph
        else:
            # Fail explicitly instead of through an undefined variable further down.
            raise ValueError(f"unsupported graph type: {graph_type!r}")

        value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons, x_labels_texts, all_values = postprocess(
            idx, value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons
        )

        if not TEST_MODE:
            if is_visualize:
                visualize(image_paths[idx], value_boxes, x_boxes, y_boxes, x_labels_polygons, y_labels_polygons)
            score, gt, pred = compute_metrics(idx, all_values, x_labels_texts, graph_type)
            print("Score =", score)
            return score, gt, pred

        # zip() truncates to the shorter of the two sequences, keeping the series aligned.
        pairs = list(zip(x_labels_texts, all_values))
        pred_xs = [str(label) for label, _ in pairs]
        pred_ys = [str(value) for _, value in pairs]

        # Horizontal bars are post-processed with the axes exchanged; swap them back.
        if graph_type == "horizontal_bar":
            pred_xs, pred_ys = pred_ys, pred_xs

        image_id = os.path.basename(image_paths[idx]).split(".")[0]
        predictions = pd.DataFrame.from_dict({
            f'{image_id}_x': (";".join(pred_xs), graph_type),
            f'{image_id}_y': (";".join(pred_ys), graph_type),
        }, orient='index', columns=['data_series', 'chart_type']).rename_axis('id')
        return predictions
    except Exception as exc:
        # `Exception` rather than a bare `except`, so KeyboardInterrupt and
        # SystemExit still stop a long notebook run.
        print(f"predict failed for idx={idx} (graph_type={graph_type!r}): {exc!r}")
        if not TEST_MODE:
            raise

        image_id = os.path.basename(image_paths[idx]).split(".")[0]
        predictions = pd.DataFrame.from_dict({
            f'{image_id}_x': ("0;0", graph_type),
            f'{image_id}_y': ("0;0", graph_type),
        }, orient='index', columns=['data_series', 'chart_type']).rename_axis('id')
        return predictions

In [None]:
# Run `predict` over every image. Validation mode also collects the
# ground-truth frames and reports the overall score.
preds = []
if not TEST_MODE:
    gts = []

idx = 0
while idx < len(image_paths):
    print(f"------------ {idx} -------------")
    if TEST_MODE:
        # Test mode: `predict` returns only the prediction frame.
        pred = predict(idx)
    else:
        # Validation mode: `predict` returns (score, ground truth, prediction).
        score, gt, pred = predict(idx)
        gts.append(gt)
    preds.append(pred)
    idx += 1

concat_preds = pd.concat(preds)

if not TEST_MODE:
    concat_gts = pd.concat(gts)

    # Score all images at once.
    score = benetech_score(concat_gts, concat_preds)

    print("----------------------------------------------------------------------------")
    print("AVERAGE Score =", score)
    


In [None]:
if not TEST_MODE:
    # Per-chart-type breakdown. Rows are grouped by their *ground-truth* chart
    # type, so the types to report must also come from the ground truth.
    # Enumerating the predicted types would skip a true type that was never
    # predicted, and would score an empty frame for a predicted type that has
    # no ground-truth rows.
    for chart_type in concat_gts["chart_type"].unique():
        chart_gts = concat_gts[concat_gts["chart_type"] == chart_type]
        # Prediction rows for the same ids, whatever chart type was predicted.
        chart_preds = concat_preds[concat_preds.index.isin(chart_gts.index)]

        print("chart_type =", chart_type, "score =", benetech_score(chart_gts, chart_preds))

In [None]:
# current
# chart_type = line score = 0.6746424764386838
# chart_type = vertical_bar score = 0.8582137215819002
# chart_type = scatter score = 0.4922595596579721
# chart_type = horizontal_bar score = 0.8800508930068293

# single model for each type of chart: 
# chart_type = line score = 0.6809506968108486
# chart_type = vertical_bar score = 0.8574843338854254
# chart_type = scatter score = 0.5517285698807664
# chart_type = horizontal_bar score = 0.8816170433989658
