In [78]:
#import libraries
import os
import json
import math
import csv
from PIL import Image
from IPython.display import display
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from shapely.geometry import Polygon as ShapelyPolygon
from matplotlib.collections import PatchCollection
import pandas as pd
import numpy as np
from matplotlib.patches import Rectangle
from tabulate import tabulate

In [79]:
def read_file(path):
    with open(path, 'r') as file:
        # Read the content of the file
        file_content = file.read()
    return file_content

In [80]:
def parse_text_blocks(content):
    lines = content.split('\n')
    data = []
    id_counter = 0
    for line in lines:
        if line.strip() == '':
            continue
        values = line.split()
        if len(values) < 5:
            continue
        x0, y0, x1, y1 = map(int, values[:4])
        text = ' '.join(values[4:-1])  # Exclude the last value which represents text_role
        text_role = values[-1]
        polygon = {
            'x0': x0,
            'x1': x1,
            'x2': x1,
            'x3': x0,
            'y0': y0,
            'y1': y0,
            'y2': y1,
            'y3': y1
        }
        data.append({'id': id_counter, 'polygon': polygon, 'text': text, 'text_role': text_role})
        id_counter += 1
    return data

In [81]:
def parse_tick_data(content):
    tick_data = {'x-axis': [], 'y-axis': []}
    lines = content.split('\n')
    id_counter = 0
    for line in lines:
        if line.strip() == '':
            continue
        values = line.split()
#         print(f"values = {values}")
        x0, y0, x1, y1, text_role, axis_type, x, y = values
        x0 = int(x0)
        y0 = int(y0)
        x1 = int(x1)
        y1 = int(y1)
#         x0, y0, x1, y1 = map(int, x0)
        x, y = map(int, (x, y))
        tick_point = {'x': x, 'y': y}
        tick_info = {'id': id_counter, 'tick_pt': tick_point}
        if 'x_axis' in axis_type:
            tick_data['x-axis'].append(tick_info)
        else:
            tick_data['y-axis'].append(tick_info)
        id_counter += 1
    return tick_data

In [82]:
def calculate_intersection_over_area(detected_bar, ground_truth_bar):
    # Calculate coordinates of intersection rectangle
    x_left = max(detected_bar['x0'], ground_truth_bar['x0'])
    y_bottom = max(detected_bar['y0'], ground_truth_bar['y0'])
    x_right = min(detected_bar['x0'] + detected_bar['width'], ground_truth_bar['x0'] + ground_truth_bar['width'])
    y_top = min(detected_bar['y0'] + detected_bar['height'], ground_truth_bar['y0'] + ground_truth_bar['height'])
    
    # Calculate intersection area
    intersection_area = max(0, x_right - x_left) * max(0, y_top - y_bottom)
    
    # Calculate IoA
    detected_area = detected_bar['width'] * detected_bar['height']
    ground_truth_area = ground_truth_bar['width'] * ground_truth_bar['height']
    IoA = intersection_area / (detected_area + ground_truth_area - intersection_area)
    
    return IoA

def map_bars(detected_bars, ground_truth_bars):
    mapped_bars = []
    gt_bars = {}
    dt_bars = {}
    idd = 1
    for gt_bar in ground_truth_bars:
        max_IoA = 0
        mapped_bar = None
        for dt_bar in detected_bars:
#             print(f"dt_bar = {dt_bar}")
#             print(f"gt_bar = {gt_bar}")
            IoA = calculate_intersection_over_area(dt_bar, gt_bar)
#             print(IoA)
            if IoA > max_IoA:
                max_IoA = IoA
                mapped_bar = dt_bar
        if mapped_bar is not None:
            gt_bars[idd] = gt_bar
            dt_bars[idd] = mapped_bar
            idd += 1
            mapped_bars.append((gt_bar, mapped_bar))
            detected_bars.remove(mapped_bar)
    
    return gt_bars,dt_bars

In [83]:
def parse_bar_data(content):
    bars = []
    lines = content.split('\n')
    for line in lines:
        if line.strip() == '':
            continue
        values = line.split()
        x0, y0, x1, y1 = map(int, values)
        height = abs(y1 - y0)
        width = abs(x1 - x0)
        bars.append({'height': height, 'width': width, 'x0': x0, 'y0': y0})
    return {'bars': bars}

In [120]:
def is_numerical(text):
    try:
        # Remove commas and other special characters
        cleaned_text = ''.join(char for char in text if char.isdigit() or char in ['.', '-'])
        float(cleaned_text)
        return True
    except ValueError:
        return False


In [134]:
def find_xaxis_label(bar, xaxis_labels):
    bar_center_x = (bar[0][1]['y0'] + bar[0][1]['y0'] + bar[0][1]['height']) / 2
    for label in xaxis_labels:
        if label['polygon']['y0'] <= bar_center_x <= label['polygon']['y1']:
            return label['text']
    return None

In [84]:
def polygon_to_key(polygon):
    # Convert the polygon to a hashable tuple representation
    return tuple(sorted(polygon.items()))


def find_polygon_center(vertices):
    num_vertices = len(vertices)

    if num_vertices < 3:
        raise ValueError("A polygon must have at least 3 vertices.")

    sum_x = sum(v[0] for v in vertices)
    sum_y = sum(v[1] for v in vertices)

    center_x = sum_x / num_vertices
    center_y = sum_y / num_vertices

    return center_x, center_y

def assign_polygons_to_ticks(xaxis_ticks, yaxis_ticks, polygons):
    assigned_xaxis_polygons = []
    assigned_yaxis_polygons = []
    box_tick_mapping = []


    for polygon in polygons:
        idd = polygon['id']
        text_role = polygon['text_role']
        
#         temp =  text_roles.loc[text_roles['text'] == polygon['text'],'role'].values[0]
        if text_role != 'tick_label':
            continue
            
        polygon_vertices = [
            (polygon['polygon'][f'x{i}'], polygon['polygon'][f'y{i}']) for i in range(4)
        ]
        center = find_polygon_center(polygon_vertices)

        # Find the nearest tick label (x, y) coordinate
#         x_distances = [abs(center[0] - tick['tick_pt']['x']) for tick in xaxis_ticks]
#         y_distances = [abs(center[1] - tick['tick_pt']['y']) for tick in yaxis_ticks]
        
        x_distances = 0
        y_distances = 0
        
        x1 = center[0]
        y1 = center[1]
        
        min_xtick = []
        min_ytick = []
        
        for item in xaxis_ticks:
            x2 = item['tick_pt']['x']
            y2 = item['tick_pt']['y']
            dist = math.sqrt((x2-x1) ** 2 + (y2-y1)**2)
            
            if x_distances == 0:
                x_distances = dist
                min_xtick.append((x2,y2))
            elif x_distances > dist:
                x_distances = dist
                min_xtick[0] = ((x2,y2))
                
        
        for item in yaxis_ticks:
            x2 = item['tick_pt']['x']
            y2 = item['tick_pt']['y']
            dist = math.sqrt((x2-x1)**2 + (y2-y1)**2)
            
            if y_distances == 0:
                y_distances = dist
                min_ytick.append((x2,y2))
            elif y_distances > dist:
                y_distances =  dist
                min_ytick.append((x2,y2))
                
        if len(xaxis_ticks) == 0:
            assigned_yaxis_polygons.append(polygon)
            box_tick_mapping.append((polygon, min_ytick))
            continue
        
    
        if x_distances <= y_distances:
            assigned_xaxis_polygons.append(polygon)
            box_tick_mapping.append((polygon,min_xtick))
        else:
            assigned_yaxis_polygons.append(polygon)
            box_tick_mapping.append((polygon,min_ytick))

    return assigned_xaxis_polygons, assigned_yaxis_polygons, box_tick_mapping

In [85]:
def per_pixel_diff(yaxis_ticks,yaxis_polygons):
#     print(f"yaxis_ticks = {yaxis_ticks}")
    yaxis_polygons = sorted(yaxis_polygons, key=lambda x: x['polygon']['x0'])
    pixel_diff = abs(yaxis_ticks[0]['tick_pt']['x'] - yaxis_ticks[1]['tick_pt']['x'])
    yy1 = float(yaxis_polygons[0]['text'].replace("%", ""))
    yy2 = float(yaxis_polygons[1]['text'].replace("%",""))
#     print(f"yy1 = {yy1} yy2 = {yy2} yaxis_ticks[0] = {yaxis_ticks[0]} yaxis_ticks[1] = {yaxis_ticks[1]}")
    val_diff_bw_two_coord = abs(yy1 - yy2)
    one_pixel_val = val_diff_bw_two_coord / pixel_diff
    return one_pixel_val

In [86]:
def sort_polygons(polygons):
    sorted_polygons = sorted(polygons, key=lambda x: x['polygon']['y0'], reverse=True)
    return sorted_polygons

In [87]:
def sort_bars(bars):
    sorted_bars = sorted(bars['bars'], key=lambda x: x['y0'], reverse=True)
    return sorted_bars

In [88]:
def predict_bar_values2(bars, yaxis_ticks,yaxis_polygons):
    predicted_y_values = []
    
    sorted_bars = sorted(bars.items(), key=lambda x: x[1]['y0'], reverse=True)
    
#     print(yaxis_ticks)
    
    one_pixel_val = per_pixel_diff(yaxis_ticks,yaxis_polygons)
    
    for bar in sorted_bars:
#         print(bar[1])
        value = bar[1]['width'] * one_pixel_val
        predicted_y_values.append((bar,value))
    
    return predicted_y_values;

In [89]:
def find_legend_pairs(polygons):
    legend_pairs = []
    for item in polygons:
        if item['text_role'] == 'legend_label':
            legend_pairs.append(item)
    
    return legend_pairs

In [157]:
def data_extraction(predicted_y_values, polygons,legend_pairs):
    sorted_polygons = sort_polygons(polygons)
    
    data_mapping = {}
    
    for polygon in polygons:
        data_mapping[polygon['text']] = []
    
    
    for item in predicted_y_values:
        bar = item[0][1]
        min_dist = 0
        y_mapped_polygon = {}
        for polygon in polygons:
            polygon_vertices = [
            (polygon['polygon'][f'x{i}'], polygon['polygon'][f'y{i}']) for i in range(4)]
            # center = find_polygon_center(polygon_vertices)
            center =  polygon['polygon']['x1'], (polygon['polygon']['y0'] + polygon['polygon']['y1'])/2
            
            x1, y1 = center[0], center[1]
            
            x2 = bar['x0']
            y2 = (bar['y0'] + bar['height'] /2)
            
            dist =  math.sqrt((x2-x1)**2 + (y2-y1)**2)
            print(f"item = {item}")
            if min_dist == 0 or min_dist > dist:
                y_mapped_polygon[item[0][0]] = polygon
                min_dist = dist
                
        
        data_mapping[y_mapped_polygon[item[0][0]]['text']].append(item)
            
    return data_mapping

In [158]:
def parse_gt_bar_data(content):
    bars = []
    lines = content.split('\n')
    for line in lines:
        if line.strip() == '':
            continue
        values = line.split()
        x0, y0, x1, y1 = map(int, values[:-1])
        value = float(values[-1])
        height = abs(y1 - y0)
        width = abs(x1 - x0)
        bars.append({'height': height, 'width': width, 'x0': x0, 'y0': y0, 'value':value})
    return {'bars': bars}

In [159]:
image_path = "./images/"
text_detect_role_path = "./text_detect_role_classify/"
bar_plot_path = './Bar_plot_coordinates/'
axis_analysis_path = './Axis_Analysis/'
gt_coordinates_path = "./GT_coordinate_value/"


In [264]:
def dataTable(filename, text_result, data_mapping):
    legend_pairs = find_legend_pairs(text_result)

    num_of_legends = len(legend_pairs)

    headers = []


    headers.append("")

    for item in legend_pairs:
        headers.append(item['text'])


    if num_of_legends == 0:
        headers.append("value")

    table = []

    for item in data_mapping.items():
        temp = []

        temp.append(item[0])
        cnt = 0

        if len(item[1]) > num_of_legends and num_of_legends != 0:
            print(f"dataTable = {filename}")
            continue
        for value in item[1]:
            temp.append(value[1])
            cnt += 1
   
        while cnt < num_of_legends:
            temp.append(0)
            cnt += 1

        table.append(temp)

    
    csv_filename = filename.rsplit('.', 1)[0] + '.csv'
    csv_filename = './csv_files/' + csv_filename

    with open(csv_filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)

        writer.writerow(headers)

        writer.writerows(table)
        
    return headers,table
          

In [312]:
def data_extraction_for_GT(gt_bar_results, polygons,legend_pairs):
    sorted_polygons = sort_polygons(polygons)
    sorted_bars = sorted(gt_bar_results['bars'], key=lambda x: x['y0'])
    sorted_legends = sorted(legend_pairs, key=lambda x: x['polygon']['x0'], reverse=True)

    num_of_legends = len(legend_pairs)

    data_mapping = {}

    for polygon in polygons:
        data_mapping[polygon['text']] = []

    n = 0

    if num_of_legends == 0:
        for i in range(0, len(sorted_bars)):
            label = sorted_polygons[n]['text']

            data_mapping[label].append(sorted_bars[i]['value'])


        return data_mapping


    n = 0

    # print(f"len(sorted_bars['bars']) = {sorted_bars}")
    for i in range(0,len(sorted_bars),num_of_legends):
        # print(f"sorted_polygons = {sorted_polygons}")
        label = sorted_polygons[n]['text']

        for j in range(0,num_of_legends):
            # print(sorted_bars[i+j])
            data_mapping[label].append(sorted_bars[i+j]['value'])

        n += 1

    return data_mapping

In [313]:
def convert_table(table):
    data_mapping = {}

    for row in table:
        data_mapping[row[0]] = row[1:]

    return data_mapping


In [314]:
def convert_data_for_metrics(data_mapping_gt, data_mapping_pred):
    predictions = []
    ground_truth = []

    for key, value in data_mapping_gt.items():
        # print(f"value = {len(value)}")
        for item in value:
            ground_truth.append((key,item))

    for key, value in data_mapping_pred.items():
        for item in value:
            predictions.append((key,item))

    return predictions, ground_truth

In [316]:
score_val = 0.0
counter = 0

for filename in os.listdir(text_detect_role_path):
#     print(filename)
    image_name = image_path + filename
    text_detect_filename = text_detect_role_path + filename
    bar_plot_filename = bar_plot_path + filename
    axis_analysis_filename = axis_analysis_path + filename
    gt_coordinates_filename = gt_coordinates_path + filename
    
    # text detect role
    file_content = read_file(text_detect_filename)
    result = parse_text_blocks(file_content)

    #axis analysis
    axis_file_content = read_file(axis_analysis_filename)
    axis_result = parse_tick_data(axis_file_content)

    #bar analysis
    bar_file_content = read_file(bar_plot_filename)
    bar_results = parse_bar_data(bar_file_content)

    #gt_coordinates
    gt_coordinates = read_file(gt_coordinates_filename)
    gt_bar_results = parse_gt_bar_data(gt_coordinates)
    
    try:
        gt_bars, dt_bars = map_bars(bar_results['bars'], gt_bar_results['bars'])

        xaxis_ticks = axis_result['x-axis']
        yaxis_ticks = axis_result['y-axis']
        xaxis_polygons, yaxis_polygons, box_tick_mapping = assign_polygons_to_ticks(xaxis_ticks, yaxis_ticks, result)

        temp_xaxis_polygon = xaxis_polygons
        for polygon in temp_xaxis_polygon:
            if polygon['text'] == '0':
                xaxis_polygons.remove(polygon)


        predicted_y_values = predict_bar_values2(dt_bars, yaxis_ticks,yaxis_polygons)

        legend_pairs = find_legend_pairs(result)
        data_mapping = data_extraction(predicted_y_values, xaxis_polygons,legend_pairs)

        headers, table = dataTable(filename, result, data_mapping )

        
        data_mapping_pred = convert_table(table)

        
        data_mapping_gt = data_extraction_for_GT(gt_bar_results,xaxis_polygons, legend_pairs)

        counter += 1
        predictions, gt = convert_data_for_metrics(data_mapping_pred, data_mapping_gt)
    
    except:
        print(f"{filename}")

#     print(data_mapping)
    

dataTable = PMC7177019___2.txt
dataTable = PMC7177019___2.txt
PMC7154115___1.txt
dataTable = PMC7079625___g005.txt
dataTable = PMC6084502___2_OC.txt
dataTable = PMC6084502___2_OC.txt
dataTable = PMC6084502___2_OC.txt
PMC6114288___g006.txt
PMC6635378___3_HTML.txt
