In [None]:
%load_ext autoreload
%autoreload 2

import csv
import json
import os
import re
import uuid
import warnings

import pandas as pd
import papermill as pm
import scrapbook as sb

from functions import gpt

from tqdm.notebook import tqdm_notebook
# `pandas()` is a classmethod that registers tqdm's progress-bar helpers on pandas
# objects. Call it on the class: instantiating `tqdm_notebook()` first would also
# display an empty progress bar that is never closed.
tqdm_notebook.pandas()

# Show full LaTeX strings in DataFrame output instead of truncating them.
pd.set_option('display.max_colwidth', None)

from pylatexenc.latexwalker import LatexWalker, LatexMathNode, LatexMacroNode, LatexGroupNode, LatexCharsNode, LatexEnvironmentNode

In [None]:
# Identify the paper being processed. `file_path` (next cell) resolves files under
# <DATA_PATH>/<project_folder>/<sub_category>/<work_bucket>/<base_name>/mathpix,
# and the Mathpix input file is named after `external_id`.
# NOTE(review): papermill is imported; if this is the tagged parameters cell, keep
# these as plain one-name assignments so parameter inference keeps working.
project_folder = "diygenomics-projects"
sub_category = "math"
work_bucket = "AdS-CFT"
base_name = "2021_Hashimoto_Neural_ODE_and_holographic_QCD_PUB"
external_id = "2023_05_22_92dc0613b4493d7b5847g"

In [None]:
data_path = os.getenv('DATA_PATH')
if data_path is None:
    # Fail here with a clear message instead of a TypeError from os.path.join later.
    raise RuntimeError("DATA_PATH environment variable is not set; "
                       "point it at the root of the data directory.")


def file_path(*args):
    """Return a path inside this paper's ``mathpix`` folder.

    ``args`` are joined onto
    ``<DATA_PATH>/<project_folder>/<sub_category>/<work_bucket>/<base_name>/mathpix``.
    """
    return os.path.join(data_path, project_folder, sub_category, work_bucket,
                        base_name, 'mathpix', *args)


model = 'gpt-4'  # alternative: 'gpt-3.5-turbo'; presumably consumed by the `gpt` helper
openai_model = 'text-embedding-ada-002'
index_col = 'uuid'  # name of the DataFrame index written to the CSV

input_file = f'{external_id}.lines.json'

In [None]:
# Load the Mathpix "lines" JSON; the next cell walks data['pages'][*]['lines'][*]['text'].
# JSON is UTF-8 by spec, so say so explicitly instead of relying on the platform's
# locale-dependent default encoding (which breaks on non-ASCII LaTeX on some systems).
with open(file_path(input_file), 'r', encoding='utf-8') as f:
    data = json.load(f)

In [None]:
# An equation number as Mathpix emits it on its own line, e.g. "(1)", "(2.3)", "(A1)".
pattern = r'^\([A-Za-z]?\d+(\.\d+)?\)$'


def extract_annotated_math(texts, annotation_pattern=pattern):
    """Pair each run of equation-number lines with the line that precedes it.

    ``texts`` is an iterable of line texts in document order. Returns a list of
    ``(math_text, annotations)`` tuples, ``annotations`` ordered most-recent-first:

    * the first number of a run yields ``(math_text, [first_number])``;
    * a run of two or more numbers additionally yields
      ``(math_text, [last_number, ..., first_number])`` once the run ends. The next
      cell uses that entry to split a multi-line display into numbered rows.

    A run with no preceding line (start of the document) is ignored.
    """
    captured = []
    run_math = None       # the line the current run of numbers refers to
    run = []              # numbers of the current run, most recent first
    previous_text = None

    for text in texts:
        if re.match(annotation_pattern, text):
            if not run:
                # First number of a new run: it labels the line right before it.
                run_math = previous_text
                if run_math is not None:
                    captured.append((run_math, [text]))
            run.insert(0, text)
        else:
            # A non-number line closes any open run.
            if run_math is not None and len(run) > 1:
                captured.append((run_math, run))
            run = []
        previous_text = text

    # The document may end while a run is still open.
    if run_math is not None and len(run) > 1:
        captured.append((run_math, run))
    return captured


# All Mathpix line records, in document order across pages.
text_nodes = [line for page in data['pages'] for line in page['lines']]
captured_math = extract_annotated_math(line['text'] for line in text_nodes)

In [None]:
# Fragments Mathpix wraps around a multi-line display, e.g.
#   "\[\n\begin{gathered}\n a \\\n b\n\end{gathered}\n\]"
# They are stripped, in this order, from each row after the display is split on the
# LaTeX row separator "\\". Raw strings avoid invalid-escape warnings.
begin_gathered_pattern = r'\\begin{gathered}\n'   # must not eat the next macro's "\"
end_gathered_pattern = r'\\end{gathered}\n\\'     # also eats the "\" of the closing "\]"
begin_gathered_no_newline_pattern = r'\\begin{gathered}'
end_gathered_no_newline_pattern = r'\\end{gathered}'
begin_gathered_bracket_pattern = r'\[\n'
end_gathered_bracket_pattern = r'\n]'             # what remains of "\n\]" after the above


def _strip_display_wrappers(row):
    """Remove display-math / ``gathered`` wrapper fragments from one split-off row."""
    for wrapper in (begin_gathered_pattern, end_gathered_pattern,
                    begin_gathered_no_newline_pattern, end_gathered_no_newline_pattern,
                    begin_gathered_bracket_pattern, end_gathered_bracket_pattern):
        row = re.sub(wrapper, '', row)
    return row


# Build (math, equation_number) rows. Single-number displays are kept verbatim,
# wrappers included; multi-number displays are split into one row per number.
modified_tuples = []
tuples = captured_math
for math_text, annotations in tuples:
    if len(annotations) <= 1:
        modified_tuples.append((math_text, annotations[0]))
        continue

    # Several numbers label one display: replace the single-number row emitted for
    # it just before, and split the display into one row per number.
    original_math = modified_tuples.pop()[0] if modified_tuples else math_text
    original_annotations = annotations[::-1]   # back to document order
    rows = original_math.split('\\\\')
    if len(rows) == len(original_annotations):
        modified_tuples.extend(
            (_strip_display_wrappers(row), label)
            for row, label in zip(rows, original_annotations)
        )
    else:
        # Row count and number count disagree (e.g. "\\" inside a matrix), so a
        # row-wise split could mislabel equations; keep the display whole instead
        # of crashing or silently dropping numbers.
        warnings.warn(
            f"{len(rows)} rows but {len(original_annotations)} equation numbers "
            f"{original_annotations}; keeping the display unsplit"
        )
        modified_tuples.append((original_math, ', '.join(original_annotations)))

In [None]:
# One random UUID per extracted equation as the row key (regenerated on every run).
uuids = [uuid.uuid4() for _ in modified_tuples]
df = (
    pd.DataFrame(modified_tuples, columns=['math', 'paper_annotation'], index=uuids)
    .rename_axis('uuid')
)

In [None]:
# Persist the extracted equations in the same mathpix folder as the input JSON.
output_csv = file_path('extracted_annotated_math.csv')
df.to_csv(output_csv, quoting=csv.QUOTE_MINIMAL)

In [None]:
# Disabled debugging cell: for every pair of consecutive equation-number lines,
# prints the first one (i.e. shows displays that carry several numbers).
# NOTE(review): if re-enabled, initialize `previous_line = None` before the loop —
# as written, the first iteration reads `previous_line` before this cell assigns it.
# TODO: delete this cell once the multi-number handling is trusted.
# text_nodes = []

# for page in data['pages']:
#     for line in page['lines']:
#         text = line['text']
        
#         if re.match(pattern, text):
#             if previous_line is not None:
#                 if re.match(pattern, previous_line['text']):
#                     print(previous_line['text'])
        
#         previous_line = line