In [None]:
%load_ext autoreload
%autoreload 2

import csv
import json
import os
import pandas as pd
import papermill as pm
import re
import scrapbook as sb
import uuid

from functions import gpt

from tqdm.notebook import tqdm_notebook
# Register tqdm's progress-aware pandas methods (e.g. `progress_apply`).
# `pandas` is a classmethod, so it is called on the class itself: instantiating
# `tqdm_notebook()` first would create a progress bar that is never iterated
# or closed.
tqdm_notebook.pandas()

# Show full cell contents when displaying frames (no column-width truncation).
pd.set_option('display.max_colwidth', None)

from pylatexenc.latexwalker import LatexWalker, LatexMathNode, LatexMacroNode, LatexGroupNode, LatexCharsNode, LatexEnvironmentNode

In [None]:
# base_name = "2021_Hashimoto_Neural_ODE_and_holographic_QCD_PUB"
# project_folder = "diygenomics-projects"
# sub_category = "math"
# work_bucket = "AdS-CFT"
# external_id = "2023_05_22_92dc0613b4493d7b5847g"

In [None]:
# Root of the data tree, taken from the DATA_PATH environment variable
# (None when the variable is unset).
data_path = os.getenv('DATA_PATH')


def file_path(*args):
    """Return a path inside this document's `mathpix` directory.

    The directory is
    <data_path>/<project_folder>/<sub_category>/<work_bucket>/<base_name>/mathpix;
    any positional arguments are appended as further path components.

    NOTE(review): `project_folder`, `sub_category`, `work_bucket` and
    `base_name` are read as globals at call time and are not assigned in this
    cell — presumably they come from the parameters cell; confirm.
    """
    mathpix_dir = os.path.join(
        data_path, project_folder, sub_category, work_bucket, base_name, 'mathpix'
    )
    return os.path.join(mathpix_dir, *args)


# Name used for the index column of the output table.
index_col = 'uuid'

# File name of the line-level JSON for the document identified by `external_id`.
input_file = '{}.lines.json'.format(external_id)

In [None]:
# Load the line-level JSON for this document. The cells below read it as
# data['pages'][i]['lines'][j]['text'].
# The encoding is pinned to UTF-8 so the LaTeX/Unicode content decodes the same
# way on every platform instead of depending on the locale's default encoding.
with open(file_path(input_file), 'r', encoding='utf-8') as f:
    data = json.load(f)

In [None]:
# Regexes for the LaTeX wrappers found around a displayed equation group.
# All of them are raw strings so every backslash reaches the regex engine
# unchanged (a non-raw '\[' is an invalid string escape and triggers a warning).
#
# The first two patterns also consume the backslash that follows the newline,
# while the two bracket patterns match only "[\n" / "\n]" and leave the
# backslash in front of / behind them in place.
begin_gathered_pattern = r'\\begin{gathered}\n\\'
end_gathered_pattern = r'\\end{gathered}\n\\'
begin_gathered_no_newline_pattern = r'\\begin{gathered}'
end_gathered_no_newline_pattern = r'\\end{gathered}'
begin_gathered_bracket_pattern = r'\[\n'
end_gathered_bracket_pattern = r'\n\]'

def strip_latex_containers(math):
    """Remove `\\[ ... \\]` / `gathered` wrapper fragments from one equation row.

    The patterns are applied in a fixed order, each removing every match:
    newline-terminated `\\begin{gathered}` / `\\end{gathered}` first, then the
    bare environment markers, then the opening and closing display brackets.

    For a first row such as "\\[\\n\\begin{gathered}\\n\\alpha = 1" the backslash
    eaten by the `gathered` pattern is compensated by the one the bracket
    pattern leaves behind, so the row comes out as "\\alpha = 1".

    Args:
        math: LaTeX source of a single row (str).

    Returns:
        The row with all wrapper fragments removed; text that contains none of
        them is returned unchanged.
    """
    # Looked up at call time so that re-assigning a pattern in a later cell
    # still takes effect.
    container_patterns = (
        begin_gathered_pattern,
        end_gathered_pattern,
        begin_gathered_no_newline_pattern,
        end_gathered_no_newline_pattern,
        begin_gathered_bracket_pattern,
        end_gathered_bracket_pattern,
    )
    for container_pattern in container_patterns:
        math = re.sub(container_pattern, '', math)
    return math

def is_index_in_list(index, lst):
    """Return True when `index` is a valid non-negative position in `lst`.

    Negative indices are rejected even though Python would accept them for
    indexing from the end.
    """
    out_of_range = index < 0 or index >= len(lst)
    return not out_of_range

In [None]:
# Equation tags that sit alone on a line, e.g. "(12)", "(3.4)", "(A.1)" or
# "(.5)": an optional single letter followed by a number, in parentheses.
pattern = r'^\([A-Za-z]?(\.\d+|\d+(\.\d+)?)\)$'

# tag -> LaTeX source of the equation it labels
math_with_annotation = {}
# Tags already paired with an equation; a tag is only taken the first time it
# starts a run.
found_annotations = []

for page in data['pages']:
    lines = page['lines']
    line_idx = 0
    while line_idx < len(lines):
        stripped_text = lines[line_idx]['text'].strip('$')
        if not re.match(pattern, stripped_text) or stripped_text in found_annotations:
            line_idx += 1
            continue

        # Collect the run of consecutive tag lines that starts here. The run
        # may extend to the end of the page; it is still recorded below.
        current_annotations = [stripped_text]
        run_end = line_idx + 1
        while run_end < len(lines):
            next_text = lines[run_end]['text'].strip('$')
            if not re.match(pattern, next_text):
                break
            current_annotations.append(next_text)
            run_end += 1

        if line_idx > 0:
            # The line directly before the run is taken to be the equation(s)
            # that the tags label.
            math_block = lines[line_idx - 1]['text']
            if len(current_annotations) > 1:
                # Several tags: one `\\`-separated row per tag, in order.
                # Surplus rows are ignored; surplus tags get no entry.
                maths = math_block.split("\\\\")
                for annotation, math in zip(current_annotations, maths):
                    math_with_annotation[annotation] = strip_latex_containers(math)
            else:
                # A single tag labels the whole preceding line, stored as-is
                # (wrappers are not stripped in this case).
                math_with_annotation[stripped_text] = math_block
            found_annotations.extend(current_annotations)
        # A run that opens the page has no preceding line on this page to pair
        # with, so it is skipped rather than paired with an unrelated block.

        # Resume after the run so its tags are not treated as new run starts.
        line_idx = run_end

In [None]:
# One (tag, equation) row per extracted entry, keyed by a fresh random UUID.
annotation_rows = list(math_with_annotation.items())
uuids = [uuid.uuid4() for _ in annotation_rows]

df = pd.DataFrame(
    annotation_rows,
    columns=['paper_annotation', 'math'],
    index=uuids,
).rename_axis('uuid', axis='index')

In [None]:
# Write the tag/equation table into the same `mathpix` directory the input was
# read from; only fields that need quoting are quoted.
output_csv = file_path('extracted_annotated_math.csv')
df.to_csv(output_csv, quoting=csv.QUOTE_MINIMAL)

In [None]:
# Record a scrap named 'status' with the value 'completed' through scrapbook.
# NOTE(review): presumably read back by whatever executes this notebook to
# detect a successful run — confirm against the caller.
sb.glue('status', 'completed')

In [None]:
# text_nodes = []

# for page in data['pages']:
#     for line in page['lines']:
#         text = line['text']
        
#         if re.match(pattern, text):
#             if previous_line is not None:
#                 if re.match(pattern, previous_line['text']):
#                     print(previous_line['text'])
        
#         previous_line = line