In [None]:
import pandas as pd
from utils import *
import os
import matplotlib.pyplot as plt
import seaborn as sns
import textwrap
# Notebook-wide plot defaults: seaborn "whitegrid" style and a base font size of 17.
sns.set_style("whitegrid")
plt.rcParams["font.size"] = 17

## Generation

In [None]:
## Generation I/O
# Files under ../prompt, handed to generation_1.py as --prompt_file and --seed_file.
generation_prompt = "../prompt/generation_1.txt"
seed_1 = "../prompt/seed_1.md"

# Files under ../data, handed to generation_1.py as --data, --out_file and --topic_file.
data_file = "../data/master_papers.jsonl"
generation_out = "../data/generation_1_paper.jsonl"
generation_topic = "../data/master_paper.md"

In [None]:
# Execute generation_1.py through IPython's %run magic. Each `$name` below is expanded by
# IPython to the value of the notebook variable of that name (set in the "Generation I/O" cell).
# NOTE(review): temperature 0.0 / top_p 0.0 presumably aim for repeatable output - confirm in generation_1.py.
%run generation_1.py --deployment_name gpt-4-1106-preview \
                    --max_tokens 300 --temperature 0.0 --top_p 0.0 \
                    --data $data_file \
                    --prompt_file $generation_prompt \
                    --seed_file $seed_1 \
                    --out_file $generation_out \
                    --topic_file $generation_topic \
                    --verbose True

## Filtering

In [None]:
# Build the topic tree (`tree`) and the flat list of its nodes (`nodes`) from the topic file.
# The path comes from `generation_topic` so it is defined in exactly one place and cannot
# drift from the --topic_file argument used in the generation step.
tree, nodes = generate_tree(read_seed(generation_topic))
# Show a text rendering of the tree for a quick visual check.
print(tree_view(tree))

In [None]:
# Sum of `count` over every node currently in the tree. Computed BEFORE pruning,
# so it still includes the topics dropped below.
topic_count = sum(node.count for node in tree.descendants)
# A level-1 topic is kept only if its count is at least this value.
threshold = 5

# Walk an explicit snapshot: detaching nodes changes the tree while we iterate.
for node in list(tree.descendants):
    if node.lvl == 1 and node.count < threshold:
        print(f"Removing {node.name} ({node.count} counts)")
        # Record the whole subtree before detaching it. Once detached, none of these
        # nodes belong to `tree`, so none of them should remain in `nodes` either
        # (otherwise sub-topics of a pruned topic would still show up downstream).
        dropped_ids = {id(node)} | {id(sub) for sub in node.descendants}
        node.parent = None
        # Slice-assign to mutate the existing list object, as the original `.remove` did.
        nodes[:] = [kept for kept in nodes if id(kept) not in dropped_ids]

## Visualization

In [None]:
# Horizontal bar chart of `count` per remaining topic, largest first,
# saved to ../data/topic_distribution.png and displayed inline.
topics = [node.name for node in nodes]
counts = [node.count for node in nodes]
if not topics:
    # With no nodes, zip(*[]) yields nothing and the two-name unpacking below would fail
    # with an opaque "not enough values to unpack" error; report the real cause instead.
    raise ValueError("No topics left to plot; lower `threshold` or check the topic file.")

# sorted() is stable (also with reverse=True), so topics with equal counts keep their order from `nodes`.
sorted_topics, sorted_counts = zip(
    *sorted(zip(topics, counts), key=lambda pair: pair[1], reverse=True)
)

# Explicit figure/axes objects instead of the implicit pyplot "current figure" state.
fig, ax = plt.subplots(figsize=(10, 20))
sns.barplot(x=list(sorted_counts), y=list(sorted_topics), color='purple', ax=ax)
ax.set_xlabel('Number of papers')
ax.set_title('Topic distribution')
fig.tight_layout()
fig.savefig('../data/topic_distribution.png')
plt.show()