In [None]:
# Uncomment to install the notebook's third-party dependencies.
# In Jupyter, prefer `%pip` over `!pip` so packages are installed into the running kernel's environment.
#!pip install --quiet langchain-anthropic langchain-neo4j cyVer langchain-google-genai json-repair "numpy<2"

In [None]:
from dotenv import load_dotenv

# Read variables from a local .env file into the process environment.
# NOTE(review): presumably this is where the LLM API keys and database credentials come from — confirm.
load_dotenv()

In [None]:
import json

import pandas as pd
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from tqdm import tqdm

from utils import (
    _value_sanitize,
    extract_json_from_markdown,
    sampling_query,
    validate_cypher,
    process_database,
    process_all_examples_with_limit,
    convert_datetime
)
from prompts import (
    system_prompt,
    simple_system_prompt,
)

# Generate dataset

In [None]:
# LLMs used for generation (the generation loop runs once per entry)
models = [
    ChatAnthropic(model="claude-opus-4-20250514"),
    ChatGoogleGenerativeAI(model="gemini-2.5-pro"),
]

# Demo server URL and the databases on it to process
db_url = "neo4j+s://demo.neo4jlabs.com"
databases = ["companies", "twitch", "network", "northwind", "ClinicalKnowledgeGraph"]


In [None]:
# Number of iterations per database for each prompt type
simple_batch_count = 1  # simple queries
multi_batch_count = 1  # complex queries

# Accumulates the records returned for every (model, database, prompt) combination
output = []

for model in models:
    print(model.model)
    for credential in tqdm(databases, desc="Processing databases"):
        # Simple questions first, then the complex ones, for the same database
        for batch_count, prompt in (
            (simple_batch_count, simple_system_prompt),
            (multi_batch_count, system_prompt),
        ):
            database_records = process_database(
                credential, db_url, model, batch_count, prompt
            )
            output.extend(database_records)

# Generate text answers

In [None]:
# Model used to generate the text answers for the validated examples
qa_model = ChatAnthropic(model='claude-3-5-haiku-latest')

In [None]:
# Keep only the generated records whose "validated" flag is truthy
validated = [record for record in output if record["validated"]]

In [None]:
# Number of records that remain after filtering on the "validated" flag
len(validated)

In [None]:
# Generate text-based answers
# Top-level `await` relies on IPython/Jupyter's autoawait support.
# NOTE(review): the return value is discarded, so this presumably updates the
# records in `validated` in place — confirm against utils.process_all_examples_with_limit.
await process_all_examples_with_limit(validated, qa_model)

In [None]:
# Tabulate the validated records and preview the first five rows
df = pd.DataFrame.from_records(validated)
print(f"Total QA pairs: {len(df)}")
df.head()

In [None]:
# Persist the validated records as JSON.
# `convert_datetime` is passed as the `default` hook, so json calls it for any
# value it cannot serialize natively.
# Requires `json` to be imported at the top of the notebook.
with open("generated_dataset.json", "w", encoding="utf-8") as f:
    json.dump(validated, f, indent=2, default=convert_datetime)