In [None]:
#!pip install --quiet langchain-anthropic langchain-neo4j cyVer langchain-google-genai json-repair "numpy<2"

In [None]:
from dotenv import load_dotenv, dotenv_values
import json
from tqdm import tqdm
import pandas as pd
import time
import os

from langchain.chat_models import init_chat_model

from utils import (
    _value_sanitize,
    extract_json_from_markdown,
    sampling_query,
    validate_cypher,
    process_database,
    process_all_examples_with_limit,
    convert_datetime
)
from prompts import (
    system_prompt,
    simple_system_prompt,
)

In [None]:
# Load run settings (API keys, target databases, model names) from run.env.
# `dotenv_values` returns an empty mapping when the file does not exist, which
# would only surface later as a confusing error while parsing DATABASES, so
# fail fast here with an explicit message instead.
ENV_FILE = "run.env"
if not os.path.isfile(ENV_FILE):
    raise FileNotFoundError(f"Configuration file {ENV_FILE!r} not found in {os.getcwd()}")
config = dotenv_values(ENV_FILE)

In [None]:
# Export every config entry whose name contains "_API_KEY" into os.environ,
# overwriting any value already present in the process environment.
for key, value in config.items():
    if "_API_KEY" not in key:
        continue
    # dotenv_values yields None for a key declared without "=value"; os.environ
    # only accepts strings, so report and skip such keys instead of raising TypeError.
    if value is None:
        print(f"skipping {key}: no value set in the config file")
        continue
    print("setup the env variable for ", key)
    os.environ[key] = value

In [None]:
def _require_config(key):
    """Return the non-empty value of `key` from `config`.

    Raises:
        KeyError: if the key is absent or has an empty/None value, so a
            misconfigured run.env fails with a clear message rather than a
            TypeError/AttributeError on None further down.
    """
    value = config.get(key)
    if not value:
        raise KeyError(f"run.env is missing a value for {key!r}")
    return value


# DATABASES is a JSON document (presumably a list of database connection
# descriptors consumed by process_database — confirm against utils).
DATABASES = json.loads(_require_config("DATABASES"))
# Comma-separated model identifiers. Strip whitespace so "a, b" works, and drop
# empty entries (e.g. from a trailing comma) before they reach init_chat_model.
LLM_CREATE_QUESTIONS = [
    model_id.strip()
    for model_id in _require_config("LLM_CREATE_QUESTIONS").split(",")
    if model_id.strip()
]
# Single model identifier used to write the text answers.
LLM_CREATE_ANSWERS = _require_config("LLM_CREATE_ANSWERS").strip()

# Generate dataset

In [None]:
# LLM selection: one chat model per identifier listed in LLM_CREATE_QUESTIONS,
# each run with temperature=0 to minimise sampling randomness.
models = [init_chat_model(model_id, temperature=0) for model_id in LLM_CREATE_QUESTIONS]

In [None]:
simple_batch_count = 1  # Number of iterations for simple queries
multi_batch_count = 1  # Number of iterations for complex queries

# (iteration count, system prompt) pairs, run in this order for every database:
# simple questions first, then complex ones.
generation_plan = [
    (simple_batch_count, simple_system_prompt),
    (multi_batch_count, system_prompt),
]

output = []

for model in models:
    # Not every LangChain chat model exposes a `.model` attribute (some store the
    # identifier in `.model_name`), so fall back instead of raising AttributeError.
    model_id = (
        getattr(model, "model", None)
        or getattr(model, "model_name", None)
        or type(model).__name__
    )
    print(model_id)
    for database in tqdm(DATABASES, desc="Processing databases"):
        for batch_count, prompt in generation_plan:
            database_records = process_database(database, model, batch_count, prompt)
            output.extend(database_records)

# Generate text answers

In [None]:
# Chat model (LLM_CREATE_ANSWERS) used further down to write the text answers
# for the validated records; temperature=0 keeps its output as stable as possible.
qa_model = init_chat_model(model=LLM_CREATE_ANSWERS, temperature=0)

In [None]:
# Keep only generated records whose "validated" flag is truthy
# (the flag is presumably set by process_database — confirm in utils).
validated = list(filter(lambda record: record["validated"], output))

In [None]:
# Number of records that passed validation and will be sent to the answer model.
n_validated = len(validated)
n_validated

In [None]:
# Generate text-based answers
await process_all_examples_with_limit(validated, qa_model)

In [None]:
# If the question cannot be answered (the answer contains "UNKNOWN"), remove the record.
# Records with a missing or empty answer (e.g. a failed generation) are dropped as
# well, instead of aborting the cell with a KeyError/TypeError.
validated = [
    el for el in validated
    if el.get("answer") and "UNKNOWN" not in el["answer"]
]

df = pd.DataFrame.from_records(validated)
print(f"Total QA pairs: {len(df)}")
df.head(5)

In [None]:
# Persist the final QA records (`validated`, not the raw `output`) as JSON. The
# timestamped filename keeps repeated runs from overwriting each other.
# `convert_datetime` serializes values json cannot encode natively (presumably
# datetime objects from query results — confirm in utils).
timestr = time.strftime("%Y%m%d-%H%M%S")
print(timestr)
output_path = f"generated_dataset_{timestr}.json"
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(validated, f, indent=2, default=convert_datetime)
print(f"Wrote {len(validated)} records to {output_path}")