In [None]:
import os
import json

In [None]:
# Training split: chart images and the JSON annotation file for each chart.
TRAIN_IMAGE_FOLDER = "./train/images/"
TRAIN_LABEL_FOLDER = "./train/annotations/"

# Test split folders (not referenced by the cells visible in this notebook).
TEST_IMAGE_FOLDER = "./test/images/"
TEST_LABEL_FOLDER = "./test/annotations/"

In [None]:
# Collect the path of every entry found in the training image and annotation folders.
image_paths = [
    os.path.join(TRAIN_IMAGE_FOLDER, entry_name)
    for entry_name in os.listdir(TRAIN_IMAGE_FOLDER)
]
label_paths = [
    os.path.join(TRAIN_LABEL_FOLDER, entry_name)
    for entry_name in os.listdir(TRAIN_LABEL_FOLDER)
]

print("Number of images: ", len(image_paths))
print("Number of labels: ", len(label_paths))

In [None]:
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt

from matplotlib.pyplot import figure


def _parse_tick_value(text):
    """Convert an axis tick label such as "1,000" or "12.5%" into a float.

    Thousands separators (",") and percent signs are removed. The decimal
    point is kept, so "0.5" parses as 0.5 rather than 5.0.

    Raises:
        ValueError: if the cleaned label is not a valid number.
    """
    return float(text.replace(",", "").replace("%", ""))


def _interpolate_pixel(value, ticks, coord):
    """Map a numerical data value to a pixel coordinate along one axis.

    The two ticks whose labels are numerically closest to ``value`` define a
    straight line from label value to pixel position; ``value`` is
    interpolated (or extrapolated) along that line.

    Args:
        value: data value to place on the axis.
        ticks: list of tick dicts with "text", "x" and "y" keys. The list is
            re-sorted in place by distance to ``value``.
        coord: "x" or "y", the pixel coordinate to compute.

    Returns:
        The interpolated pixel coordinate.

    Raises:
        ValueError: if a tick label is not numeric.
        IndexError: if the axis has fewer than two ticks.
        ZeroDivisionError: if the two nearest ticks carry the same value.
    """
    ticks.sort(key=lambda tick: abs(_parse_tick_value(tick["text"]) - value))
    first, second = ticks[0], ticks[1]
    first_value = _parse_tick_value(first["text"])
    second_value = _parse_tick_value(second["text"])
    # One formula covers values on either side of the nearest tick.
    return (value - first_value) / (second_value - first_value) * (second[coord] - first[coord]) + first[coord]


def _tick_points(axis, text_by_id):
    """Return one {"id", "x", "y", "text"} dict per tick of ``axis``."""
    return [
        {
            "id": tick["id"],
            "x": tick["tick_pt"]["x"],
            "y": tick["tick_pt"]["y"],
            "text": text_by_id[tick["id"]],
        }
        for tick in axis["ticks"]
    ]


# Failures that mean a chart's points cannot be mapped to pixels: missing keys,
# non-numeric values or tick labels, fewer than two ticks, duplicate tick values.
# Listed explicitly so KeyboardInterrupt and unrelated errors are not swallowed.
_PIXEL_MAPPING_ERRORS = (
    AttributeError,
    IndexError,
    KeyError,
    TypeError,
    ValueError,
    ZeroDivisionError,
)

ground_truths = []
keypoints_ground_truths = {}

# NOTE(review): the first 50 annotation files are skipped by this slice — confirm it is intended.
for path in tqdm(label_paths[50:]):
    with open(path) as json_file:
        data = json.load(json_file)

    # Annotation "<name>.json" describes image "<name>.jpg".
    image_name = os.path.splitext(os.path.basename(path))[0] + ".jpg"

    # Round float values of every data point to 4 decimals (in place).
    for item in data["data-series"]:
        for k in item:
            if isinstance(item[k], float):
                item[k] = round(item[k], 4)

    x_axis = data["axes"]["x-axis"]
    y_axis = data["axes"]["y-axis"]
    x_type = x_axis["values-type"]
    y_type = y_axis["values-type"]

    x_axis_ids = [tick["id"] for tick in x_axis["ticks"]]
    y_axis_ids = [tick["id"] for tick in y_axis["ticks"]]

    # Tick labels ordered by the position of their text polygon along each axis.
    x_labels = [item for item in data["text"] if item["id"] in x_axis_ids]
    x_labels.sort(key=lambda x: x["polygon"]["x0"])
    x_labels = [item["text"].strip() for item in x_labels]

    y_labels = [item for item in data["text"] if item["id"] in y_axis_ids]
    y_labels.sort(key=lambda x: x["polygon"]["y0"])
    y_labels = [item["text"].strip() for item in y_labels]

    # NOTE(review): "value" is the same list object as data["data-series"], so the
    # x_pixel / y_pixel keys added further down also appear in this ground truth
    # (only on the leading items when a chart is skipped part-way). Confirm intended.
    gt = {
        "file_name": os.path.join("images", image_name),
        "ground_truth": {
            "gt_parse": {
                "class": data["chart-type"],
                "value": data["data-series"],
                "x_type": x_type,
                "y_type": y_type,
                "x_labels": x_labels,
                "y_labels": y_labels,
            }
        },
        "source": data["source"],
    }
    ground_truths.append(gt)

    # TODO: for multi stages pipeline, we need to save the points of each axis and the points of each data series
    # we need to calculate pixel value of each point in data-series based on axis points
    text_by_id = {item["id"]: item["text"] for item in data["text"]}
    x_data = _tick_points(x_axis, text_by_id)
    y_data = _tick_points(y_axis, text_by_id)

    # Lookup by label text, used for categorical axes.
    x_data_dict = {item["text"]: item for item in x_data}
    y_data_dict = {item["text"]: item for item in y_data}

    # calculate pixel values
    skip = False
    for item in data["data-series"]:
        try:
            if x_type == "numerical":
                x_pixel = _interpolate_pixel(float(item["x"]), x_data, "x")
            else:  # categorical
                x_pixel = x_data_dict[item["x"]]["x"]

            if y_type == "numerical":
                y_pixel = _interpolate_pixel(float(item["y"]), y_data, "y")
            else:  # categorical
                y_pixel = y_data_dict[item["y"]]["y"]
        except _PIXEL_MAPPING_ERRORS:
            # This chart cannot be mapped reliably; leave it out of the keypoints.
            skip = True
            break

        item["x_pixel"] = x_pixel
        item["y_pixel"] = y_pixel

    # DEBUG: plot data-series points to image
    # image_path = os.path.join(TRAIN_IMAGE_FOLDER, image_name)
    # image = cv2.imread(image_path)
    # for item in data["data-series"]:
    #     cv2.circle(image, (int(item["x_pixel"]), int(item["y_pixel"])), 5, (0, 0, 255), -1)
    # cv2.imwrite(os.path.join("debug", image_name), image)

    if not skip:
        keypoints_ground_truths[image_name] = data["data-series"]

In [None]:
# figure(figsize=(8, 6), dpi=80)

# plt.imshow(image)
# plt.show()

In [None]:
# Tally how many ground-truth entries come from each annotation source.
num_item_from_source = {}
for entry in ground_truths:
    origin = entry["source"]
    num_item_from_source[origin] = num_item_from_source.get(origin, 0) + 1

print("Number of items from each source: ", num_item_from_source)

In [None]:
# Image paths (under "train") of every entry whose source is "extracted".
source_image_paths = [
    os.path.join("train", entry["file_name"])
    for entry in ground_truths
    if entry["source"] == "extracted"
]

In [None]:
from PIL import Image

# Preview one "extracted" chart: the entry at index 10 of source_image_paths.
# Raises IndexError when fewer than 11 such images were collected.
Image.open(source_image_paths[10])

In [None]:
# Hold out every entry whose source is "extracted" for validation;
# all remaining entries form the training set.
val_ground_truths = [entry for entry in ground_truths if entry["source"] == "extracted"]
train_ground_truths = [entry for entry in ground_truths if entry["source"] != "extracted"]

print("Number of train ground truths: ", len(train_ground_truths))
print("Number of val ground truths: ", len(val_ground_truths))


In [None]:
# Bare image file names (no directory) belonging to each split, for fast membership checks.
train_file_names = {os.path.basename(entry["file_name"]) for entry in train_ground_truths}
val_file_names = {os.path.basename(entry["file_name"]) for entry in val_ground_truths}

In [None]:
# Route each chart's keypoints to the split its image belongs to.
# Images found in neither set of file names are left out of both outputs.
train_keypoints_ground_truths = {
    name: points
    for name, points in keypoints_ground_truths.items()
    if name in train_file_names
}
val_keypoints_ground_truths = {
    name: points
    for name, points in keypoints_ground_truths.items()
    if name not in train_file_names and name in val_file_names
}

# Write each split to its own JSON file in the working directory.
for output_name, payload in (
    ("train_keypoints_ground_truths.json", train_keypoints_ground_truths),
    ("val_keypoints_ground_truths.json", val_keypoints_ground_truths),
):
    with open(output_name, "w") as f:
        json.dump(payload, f)

In [None]:
# !pip install jsonlines

In [None]:
# copy all images to train/images and val/images
import shutil

# Create the validation image folder (if missing), then copy every
# held-out image from the training folder into it.
validation_image_dir = "./validation/images/"
os.makedirs(validation_image_dir, exist_ok=True)

for entry in tqdm(val_ground_truths):
    source_path = os.path.join("./train", entry["file_name"])
    shutil.copy2(source_path, validation_image_dir)


In [None]:
# save to jsonl file
import jsonlines

# Write each split's ground-truth records to that split's metadata.jsonl file.
for metadata_path, records in (
    ("./train/metadata.jsonl", train_ground_truths),
    ("./validation/metadata.jsonl", val_ground_truths),
):
    with jsonlines.open(metadata_path, mode="w") as writer:
        writer.write_all(records)

In [None]:
# Number of data points in each chart's data series, and the largest such count
# (0 when there are no ground truths).
lens = [len(entry["ground_truth"]["gt_parse"]["value"]) for entry in ground_truths]
max_len = max(lens, default=0)

print("Max number of data-series: ", max_len)

In [None]:
# draw histogram of length of data-series
# `lens` holds one data-point count per ground-truth entry; the counts are grouped into 10 bins.
import matplotlib.pyplot as plt
plt.hist(lens, bins=10)

In [None]:
import numpy as np
import Levenshtein as lev
from sklearn.metrics import r2_score


def sigmoid2(x):
    """Return ``2 - 2 / (1 + exp(-x))``.

    The result is 1 at ``x = 0`` and approaches 0 as ``x`` grows, so a
    non-negative error maps to a score in (0, 1].
    """
    denominator = 1 + np.exp(-x)
    return 2 - 2 / denominator


def rmse(y_true, y_pred):
    """Return the root-mean-square error between two equally shaped sequences."""
    residuals = np.array(y_true) - np.array(y_pred)
    mean_squared_error = np.mean(np.square(residuals))
    return np.sqrt(mean_squared_error)


def nrmse(y_true, y_pred):
    """Score two numeric series by their normalised RMSE, mapped through ``sigmoid2``.

    The normalised RMSE is ``rmse(y_true, y_pred) / rmse(y_true, mean(y_true))``,
    which equals ``sqrt(1 - R^2)``. It is computed from ``r2_score`` rather
    than by dividing the two RMSE values directly.

    Args:
        y_true: reference values.
        y_pred: predicted values.

    Returns:
        0 when the two series differ in length; otherwise
        ``sigmoid2(sqrt(1 - r2_score(y_true, y_pred)))``, which is 1 for a
        perfect match and decreases as the error grows.
    """
    if len(y_true) != len(y_pred):
        return 0
    # 1 - R^2 is the ratio of squared errors; the square root turns it into
    # the RMSE ratio that the metric is named after.
    squared_error_ratio = 1 - r2_score(y_true, y_pred)
    return sigmoid2(np.sqrt(squared_error_ratio))


def nlev(y_true, y_pred):
    """Score two string series by their normalised Levenshtein distance.

    The summed edit distance between paired strings is divided by the total
    length of the reference strings and mapped through ``sigmoid2``.

    Args:
        y_true: reference strings.
        y_pred: predicted strings.

    Returns:
        0 when the two series differ in length; otherwise a score that is 1
        for identical series and decreases as the edit distance grows.
    """
    if len(y_true) != len(y_pred):
        return 0

    true_length = sum(len(y) for y in y_true)
    if true_length == 0:
        # All reference strings are empty, so the ratio would divide by zero.
        # The distance to an empty string is the length of the other string:
        # an all-empty prediction is a perfect match, anything else is the
        # limiting worst score.
        return 1.0 if all(len(y) == 0 for y in y_pred) else 0.0

    total_distance = sum(lev.distance(y_t, y_p) for y_t, y_p in zip(y_true, y_pred))
    return sigmoid2(total_distance / true_length)


def calculate_score(pred, gt):
    """Score a predicted chart against its ground truth.

    Args:
        pred: dict with "class" and "value" (list of {"x", "y"} points).
        gt: dict with the same layout, holding the reference chart.

    Returns:
        0 when the chart class or the number of points differs, 1 when both
        series are empty, otherwise the mean of the x-axis and y-axis scores.
        An axis is scored with ``nlev`` when its first ground-truth value is
        a string and with ``nrmse`` otherwise.
    """
    if pred["class"] != gt["class"]:
        return 0

    if len(pred["value"]) != len(gt["value"]):
        return 0

    # Lengths are equal at this point, so one emptiness check covers both.
    if len(gt["value"]) == 0:
        return 1

    def axis_score(key):
        gt_values = [point[key] for point in gt["value"]]
        pred_values = [point[key] for point in pred["value"]]
        metric = nlev if isinstance(gt_values[0], str) else nrmse
        # Both metrics take (y_true, y_pred) and normalise by y_true, so the
        # ground truth must be passed first.
        return metric(gt_values, pred_values)

    return (axis_score("x") + axis_score("y")) / 2

In [None]:
# Sanity check of calculate_score: the first argument is the prediction, the second
# the ground truth. The two scatter series are identical except for the first
# point's y value (66.683 in the prediction vs 6.683 in the ground truth).
calculate_score(
    {
        'class': 'scatter',
        'value': [
            {'x': 1949.4201, 'y': 66.683},
            {'x': 1954.6107, 'y': 66.2785},
            {'x': 1959.9936, 'y': 65.6718},
            {'x': 1964.7997, 'y': 64.0537},
        ]
    },
    {
        'class': 'scatter',
        'value': [
            {'x': 1949.4201, 'y': 6.683},
            {'x': 1954.6107, 'y': 66.2785},
            {'x': 1959.9936, 'y': 65.6718},
            {'x': 1964.7997, 'y': 64.0537},
        ]
    },
)