In [None]:
import ast
import altair as alt
import json
import numpy as np
import os
import pandas as pd
import pprint
import re

from altair_saver import save
from collections import defaultdict
from selenium.webdriver import Chrome, ChromeOptions

from collections import defaultdict, Counter

# Never truncate cell text when a DataFrame is displayed (None = no width limit).
pd.options.display.max_colwidth = None

In [None]:
# ! pip install -U selenium==4.2.0

In [None]:
# Run configuration: the file names below are resolved by file_path().
topic_file = 'topics.json'
index_col = 'uuid'

original_file = '2021_Wightman-Posthuma_A_genomewide_association_study_with_112_563_individuals_identifies_new_risk_loci_for_Alzheimers_disease'
external_id = '2023_05_02_27142069922ab9506d3dg'

input_file = f'gpt_topics_{external_id}.csv'

# Root of the data tree; it has to be supplied through the environment.
data_path = os.getenv('DATA_PATH')
if data_path is None:
    # Fail here with an actionable message rather than with the opaque
    # TypeError that os.path.join(None, ...) would raise further down.
    raise RuntimeError(
        "The DATA_PATH environment variable is not set; "
        "point it at the root of the data directory before running this notebook."
    )


def file_path(*args):
    """Build a path inside this document's ``mathpix`` directory.

    Args:
        *args: Zero or more path components appended after ``mathpix``.

    Returns:
        ``data_path/diygenomics-projects/experiment-a/<original_file>/mathpix``
        joined with ``args``.
    """
    return os.path.join(data_path, 'diygenomics-projects', 'experiment-a',
                        original_file, 'mathpix', *args)


# exist_ok=True removes the check-then-create race and keeps the cell re-runnable.
os.makedirs(file_path('charts'), exist_ok=True)

In [None]:
def snake_case(s):
    """Return ``s`` as lowercase words joined by underscores.

    Every character that is neither a word character nor whitespace is
    treated as a separator; runs of separators collapse, and leading or
    trailing separators are dropped.
    """
    without_punctuation = re.sub(r'[^\w\s]', ' ', s)
    return '_'.join(without_punctuation.lower().split())

In [None]:
# Parse the topics JSON, then keep its top-level keys as a list.
topics_json_path = file_path(topic_file)
with open(topics_json_path, 'r') as f:
    corpus_topics = json.load(f)

corpus_topics_keys = [*corpus_topics.keys()]

In [None]:
# Read the GPT topics CSV, using the `index_col` column as the row index.
gpt_topics_csv = file_path(input_file)
df = pd.read_csv(gpt_topics_csv, index_col=index_col)

In [None]:
# Keep only the rows that have a value in the 'topics' column.
has_topics = df['topics'].notna()
df_topics = df[has_topics]

In [None]:
# The 'topics' cell of every row that has one (the former identity
# `.apply(lambda x: x)` was a no-op and has been dropped).
topics_list = df_topics['topics'].tolist()

# For each corpus topic, count the rows whose 'topics' cell contains it.
# NOTE(review): these cells were read from a CSV, so they are presumably
# strings, which makes `topic in row` a substring test. A topic whose name is
# contained in another topic's name would then be counted for both. Confirm
# this is intended, or parse the cell into a list first.
topic_dict = {
    topic: sum(topic in row for row in topics_list)
    for topic in corpus_topics_keys
}

# Persist the counts next to the other artefacts for this document.
with open(file_path('topics_counts.json'), 'w') as f:
    json.dump(topic_dict, f)

In [None]:
# Tabulate the counts as one row per topic, with columns 'topic' and 'count'.
# NOTE: this rebinds `df_topics`, which earlier held the filtered rows of `df`;
# re-running the counting cell after this one requires re-running the filter first.
df_topics = (
    pd.DataFrame.from_dict(topic_dict, orient='index', columns=['count'])
    .reset_index()
    .rename(columns={'index': 'topic'})
)

# Fix the colour scale's domain to the topics in table order.
color_scale = alt.Scale(domain=list(df_topics['topic']))

chart = alt.Chart(df_topics).mark_bar().encode(
    x=alt.X('count:Q', axis=alt.Axis(title='Count')),
    # The bars are laid out along y, so ordering them by count means sorting
    # the y channel by the x value (descending). The sort used to sit on the
    # quantitative x channel as '-y', where it cannot order the bars.
    y=alt.Y('topic', axis=alt.Axis(title='Topic'), sort='-x'),
    color=alt.Color('topic', scale=color_scale)
).properties(
    title='Topic Counts'
)

# Render the chart to a PNG in the 'charts' directory.
save(chart, file_path('charts', 'topic_counts.png'))