In [1]:
from real_world_graphs.llms import OpenAILLM, OpenAIConfig

from real_world_graphs.cause_net_tasks import GraphEstimationTask
import os
import json

In [2]:
# Directory holding the narrative files; each file is read line by line as JSON
# by get_max_narrative_len. Replace the placeholder before running.
NARRATIVE_DIR = "path/to/narrative_directory"
# os.listdir returns entries in arbitrary order, so sort the paths to make the
# task / prompt order reproducible across runs and machines. Sub-directories
# and hidden files (e.g. .DS_Store) are skipped because they cannot be parsed
# as narrative files.
NARRATIVE_PATHS = sorted(
    os.path.join(NARRATIVE_DIR, fn)
    for fn in os.listdir(NARRATIVE_DIR)
    if not fn.startswith(".") and os.path.isfile(os.path.join(NARRATIVE_DIR, fn))
)

In [7]:
# LLM client shared by every GraphEstimationTask below.
# NOTE(review): max_tokens=100 is presumably the per-response token limit —
# confirm in OpenAIConfig that this is enough for a full graph answer.
llm = OpenAILLM(config=OpenAIConfig(max_tokens=100))
# model_name / max_workers are assigned after OpenAILLM is constructed.
# NOTE(review): if OpenAILLM reads these in __init__, the assignments would
# not take effect — confirm against the OpenAILLM implementation.
llm_config = llm.config
llm_config.model_name = "gpt-4o"
llm_config.max_workers = 1

In [8]:
def get_max_narrative_len(narrative_path: str) -> int:
    """Return the largest number of nodes in any narrative of a JSON Lines file.

    Each non-blank line of ``narrative_path`` is parsed as JSON and the length
    of its ``"nodes"`` entry is measured.

    Args:
        narrative_path: Path to a UTF-8 encoded JSON Lines file.

    Returns:
        The maximum ``len(record["nodes"])`` over all records, or 0 if the
        file contains no records.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``"nodes"`` key.
    """
    max_len = 0
    # Decode explicitly as UTF-8: JSON exchanged between systems is UTF-8, and
    # the platform default encoding (e.g. cp1252 on Windows) can mis-decode
    # non-ASCII narrative text.
    with open(narrative_path, "r", encoding="utf-8") as file:
        for line in file:
            # Skip blank / whitespace-only lines (e.g. a trailing empty line at
            # EOF); json.loads would otherwise raise on them.
            if not line.strip():
                continue
            narrative_data = json.loads(line)
            max_len = max(max_len, len(narrative_data["nodes"]))
    return max_len

In [9]:
def _build_graph_task(narrative_path: str) -> GraphEstimationTask:
    """Create the estimation task for one narrative file.

    ``min_chain_length`` is set to the longest ``"nodes"`` list found in that
    file, as computed by ``get_max_narrative_len``.
    """
    return GraphEstimationTask(
        graph_path=None,
        llm=llm,
        narrative_path=narrative_path,
        min_chain_length=get_max_narrative_len(narrative_path),
    )


# One task per narrative file, in NARRATIVE_PATHS order.
tasks = list(map(_build_graph_task, NARRATIVE_PATHS))

In [1]:
# Gather the prompt data of every task into one flat list so all prompts can
# be sent in a single batched call; the cell displays how many were collected.
prompt_list = []
for task in tasks:
    prompt_list.extend(task.generate_prompt_data())
len(prompt_list)

In [2]:
# Send every collected prompt through a single task's prompts_to_response.
# Previously this relied on `task`, the loop variable leaked from the previous
# cell; reference the same object (the last task) explicitly so the cell does
# not depend on hidden loop state. All tasks here were built with the same `llm`.
if not tasks:
    raise ValueError("No narrative tasks were built; check NARRATIVE_DIR.")
responses = tasks[-1].prompts_to_response(prompt_data=prompt_list, show_progress=True)