In [None]:
import matplotlib.pyplot as plt
import os
import pandas as pd
import papermill as pm
import re
import scrapbook as sb

from IPython.display import Image
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from pylatexenc.latex2text import LatexNodes2Text
# from matplotlib.font_manager import FontProperties

# Render all figure text with a serif face, preferring Times New Roman.
plt.rcParams['font.family'] = 'serif'
# Put Times New Roman at the front of the serif fallback list. Any existing
# entry is filtered out first so that re-running this cell does not keep
# growing the list with duplicate names.
plt.rcParams['font.serif'] = ['Times New Roman'] + [
    font for font in plt.rcParams['font.serif'] if font != 'Times New Roman'
]

In [None]:
# base_name = "2021_Hashimoto_Neural_ODE_and_holographic_QCD_PUB"
# project_folder = "diygenomics-projects"
# sub_category = "math"
# work_bucket = "AdS-CFT"
# external_id = "2023_05_22_92dc0613b4493d7b5847g"
# human_title = '2021 Hashimoto Neural ODE'

In [None]:
# Root directory for all project data, read from the DATA_PATH environment variable.
data_path = os.getenv('DATA_PATH')
if data_path is None:
    # Fail here with a clear message; otherwise the first file_path() call
    # raises an opaque TypeError from os.path.join(None, ...).
    raise EnvironmentError("The DATA_PATH environment variable must be set")


def file_path(*args):
    """Build a path inside this paper's ``mathpix`` working directory.

    The path is ``data_path/project_folder/sub_category/work_bucket/base_name/mathpix``
    followed by any extra components in ``args``.

    ``project_folder``, ``sub_category``, ``work_bucket`` and ``base_name`` are
    module-level names that are not defined in this cell; they are looked up
    when the function is called (presumably injected as papermill parameters --
    see the commented example values in the preceding cell; verify against the
    caller).

    Args:
        *args: Additional path components appended after ``mathpix``.

    Returns:
        The joined path as a string.
    """
    return os.path.join(data_path, project_folder, sub_category, work_bucket,
                        base_name, 'mathpix', *args)


# Column of the CSV used as the DataFrame index.
index_col = 'uuid'

# The notebook reads and writes the same CSV file name.
input_file = 'extracted_annotated_math.csv'
output_file = 'extracted_annotated_math.csv'

# exist_ok=True replaces the check-then-create pattern and keeps the cell
# safe to re-run when the directory is already present.
os.makedirs(file_path('math_images'), exist_ok=True)

In [None]:
# Load the extracted math table, indexed by the configured index column.
math_csv_path = file_path(input_file)
df = pd.read_csv(math_csv_path, index_col=index_col)

In [None]:
def clean_math(math, extra_prefixes=("ansatz 1:", "{ ansatz 2: }")):
    """Reduce a raw LaTeX snippet to text suitable for wrapping in ``$...$``.

    The cleaning steps run in this order:

    1. strip surrounding newlines and remove the ``aligned`` environment
       markers and the ``\\[`` / ``\\]`` display-math delimiters;
    2. collapse every double backslash to a single backslash;
    3. remove ``\\text{...}`` groups, bare ``\\text`` and ``&`` via
       ``remove_text_command``;
    4. strip surrounding whitespace, then trailing commas, then trailing
       periods;
    5. drop a backslash that directly precedes a stand-alone single word
       character;
    6. remove each of ``extra_prefixes`` that the text starts with;
    7. remove every ``$``.

    Args:
        math: Raw LaTeX string.
        extra_prefixes: Paper-specific leading labels to drop. They are tested
            in the given order, each against the result of the previous
            removal. The default reproduces the two labels that were
            previously hard-coded. ``None`` or an empty sequence disables the
            step.

    Returns:
        The cleaned string.
    """
    math = math.strip('\n')
    # Remove environment markers and display-math delimiters.
    for token in ("\\begin{aligned}", "\\end{aligned}", '\\[', '\\]'):
        math = math.replace(token, '')
    # A double backslash becomes a single one; callers strip the remaining
    # trailing backslash from each line before rendering.
    math = math.replace('\\\\', '\\')

    # str.strip() with no argument already removes newlines as well.
    math = math.strip()
    math = remove_text_command(math)
    math = math.strip().rstrip(',').rstrip('.')
    # Drop the backslash in front of an isolated single word character.
    math = re.sub(r'\\(?=\b\w\b)', '', math)

    for prefix in extra_prefixes or ():
        if math.startswith(prefix):
            math = math[len(prefix):]

    # The callers add their own $...$ wrapper, so stray dollars are removed.
    return math.replace('$', '')

def remove_text_command(math):
    """Strip ``\\text`` commands and alignment ampersands from ``math``.

    Complete ``\\text{...}`` groups (up to the first closing brace) are removed
    together with their content; any remaining bare ``\\text`` token and every
    ``&`` are then deleted.
    """
    without_text_groups = re.sub(r'\\text\{[^}]*\}', '', math)
    without_bare_text = without_text_groups.replace('\\text', '')
    return without_bare_text.replace('&', '')

def fallback_img(math, image_name, image_title=None, dpi=300, fig_width=1, fig_height=1):
    """Render ``math`` to an image after converting its LaTeX to plain text.

    ``latex_to_img`` calls this when rendering the cleaned LaTeX directly
    raises. The input is first passed through ``LatexNodes2Text`` and then
    split on newlines; each non-empty line is drawn, wrapped in ``$...$``, one
    axes unit below the previous one.

    Args:
        math: LaTeX source to convert and render.
        image_name: File name of the image, saved under ``math_images`` via
            ``file_path``.
        image_title: Optional caption drawn one unit below the last line.
        dpi: Resolution passed to ``savefig``.
        fig_width: Figure width passed to ``Figure(figsize=...)``.
        fig_height: Figure height passed to ``Figure(figsize=...)``.

    Returns:
        None. The image is written to disk.
    """
    math = LatexNodes2Text().latex_to_text(math)
    maths = math.split('\n')

    fig = Figure(figsize=(fig_width, fig_height))
    # Attach an Agg canvas so the figure can be saved without going through pyplot.
    FigureCanvas(fig)
    ax = fig.subplots(1, 1)

    # str.split always yields at least one element, so ``offset`` is always
    # bound by the time the loop finishes.
    for index, math_piece in enumerate(maths):
        math_piece = math_piece.rstrip('\\')
        offset = 0.5 - index
        if len(math_piece) > 0:
            ax.text(0.5, offset, f"${math_piece}$", size=30, va='center', ha='center')
    # Identity comparison is the idiomatic (PEP 8) test for None.
    if image_title is not None:
        ax.text(0.5, offset - 1, image_title, size=30, va='center', ha='center')
    ax.axis('off')
    fig.savefig(file_path('math_images', image_name), dpi=dpi, bbox_inches='tight')

def _render_math_lines(maths, image_name, image_title=None, dpi=300, fig_width=1, fig_height=1):
    """Draw each entry of ``maths`` as a centred ``$...$`` line and save the figure."""
    fig = Figure(figsize=(fig_width, fig_height))
    # Attach an Agg canvas so the figure can be saved without going through pyplot.
    FigureCanvas(fig)
    ax = fig.subplots(1, 1)

    offset = 0.5
    for index, math_piece in enumerate(maths):
        math_piece = math_piece.rstrip('\\')
        # Each line sits one axes unit below the previous one.
        offset = 0.5 - index
        if len(math_piece) > 0:
            ax.text(0.5, offset, f"${math_piece}$", size=30, va='center', ha='center')
    if image_title is not None:
        # The caption goes one unit below the last line.
        ax.text(0.5, offset - 1, image_title, size=30, va='center', ha='center')
    ax.axis('off')
    fig.savefig(file_path('math_images', image_name), dpi=dpi, bbox_inches='tight')


def latex_to_img(row, dpi=300, fig_width=1, fig_height=1):
    """Render one row's LaTeX to two PNG files: a plain one and a captioned one.

    The plain image is named ``{paper_annotation}_{row.name}.png`` and the
    captioned one ``display_{paper_annotation}_{row.name}.png``, where
    parentheses are removed from ``paper_annotation``. The caption is
    ``"{paper_annotation} - {human_title}"``; ``human_title`` is a module-level
    name that is not defined in this function (presumably a papermill
    parameter -- verify against the caller).

    If rendering an image raises, the exception and the cleaned math are
    printed and ``fallback_img`` is used for that image instead.

    Args:
        row: A row object providing ``row['paper_annotation']``,
            ``row['math']`` and ``row.name`` (e.g. a row passed by
            ``DataFrame.apply(..., axis=1)``).
        dpi: Resolution passed to ``savefig``.
        fig_width: Figure width passed to ``Figure(figsize=...)``.
        fig_height: Figure height passed to ``Figure(figsize=...)``.

    Returns:
        Tuple ``(clean_math, image_name, display_image_name)``.
    """
    paper_annotation = row['paper_annotation'].replace('(', '').replace(')', '')
    image_name = f'{paper_annotation}_{row.name}.png'
    display_image_name = f'display_{paper_annotation}_{row.name}.png'
    image_title = f"{paper_annotation} - {human_title}"

    math = clean_math(row['math'])
    maths = math.split('\n')

    # Same rendering twice: once without a caption, once with it.
    for target_name, title in ((image_name, None), (display_image_name, image_title)):
        try:
            _render_math_lines(maths, target_name, title, dpi, fig_width, fig_height)
        except Exception as e:
            # Best effort: report the failing snippet and fall back to the
            # plain-text rendering rather than aborting the whole batch.
            print(e)
            print(math)
            fallback_img(math, target_name, title, dpi, fig_width, fig_height)

    return math, image_name, display_image_name

In [None]:
# df[33:34].apply(lambda row: latex_to_img(row), axis=1)
# df[0:1].apply(lambda row: latex_to_img(row), axis=1)
# df[17:18].apply(lambda row: latex_to_img(row), axis=1)
# df[20:21].apply(lambda row: latex_to_img(row), axis=1)
# df[5:6].apply(lambda row: latex_to_img(row), axis=1)
# Render every row and store the returned (cleaned math, image name,
# display image name) triple in three columns.
rendered = df.apply(latex_to_img, axis=1, result_type='expand')
df[['clean_math', 'math_image', 'display_math_image']] = rendered

In [None]:
# Write to the configured output file. The original wrote to ``input_file``,
# which only worked because both names currently hold the same file name;
# a change to ``output_file`` would have been silently ignored.
df.to_csv(file_path(output_file))

In [None]:
# Record a scrap named 'status' with the value 'completed' via scrapbook.
# NOTE(review): presumably read back by whatever process executes this
# notebook to detect a successful run -- confirm against the caller.
sb.glue('status', 'completed')

In [None]:
# math = clean_math(df.iloc[16]['math'])

# dpi=300; fig_width=1; fig_height=1

# fig = Figure(figsize=(fig_width, fig_height))
# FigureCanvas(fig)
# ax = fig.subplots(1,1)

# ax.text(0.5, 0.5, f'${math}$ (3)', size=30, va='center', ha='center')

# ax.axis('off')

# fig.savefig('test.png', dpi=dpi, bbox_inches='tight')
# Image(filename='test.png')

In [None]:
# math = df.iloc[0]['math']
# pattern_a = r'\\\\\n|\\\\mathrm'
# math = re.sub(pattern_a, '', math)
# math

In [None]:
# s = '\\[\n\\mathrm{d} s^{2}=-f(\\eta) \\mathrm{d} t^{2}+\\mathrm{d} \\eta^{2}+g(\\eta)\\left(\\mathrm{d} x_{1}^{2}+\\cdots+\\mathrm{d} x_{d-1}^{2}\\right)\n\\]'
# lines = s.split('\n')

# # Remove '\\[' from the first line and '\\mathrm' from the second line
# lines[0] = lines[0].replace('\\[', '', 1)
# lines[1] = lines[1].replace('\\mathrm', '', 1)

# # Remove '\\]' from the last line
# lines[-1] = lines[-1].replace('\\]', '', 1)

# s = '\n'.join(lines)
# s.strip()

In [None]:
# s = '\\[\n\\mathrm{d} s^{2}=-f(\\eta) \\mathrm{d} t^{2}+\\mathrm{d} \\eta^{2}+g(\\eta)\\left(\\mathrm{d} x_{1}^{2}+\\cdots+\\mathrm{d} x_{d-1}^{2}\\right)\n\\]'
# # pattern_a = r'\\\[\n\\mathrm'
# pattern_a = r'^\\\[\n\\mathrm'
# s = re.sub(pattern_a, '', s)
# # pattern_b = r'\n\\\]'
# # s = re.sub(pattern_b, '', s)
# print(s)