In [2]:
import ast
import io
import json
import os
import re
import sys
from contextlib import redirect_stdout
from typing import Dict, List, Any

from arango import ArangoClient
from langchain.chains import ArangoGraphQAChain
from langchain.graphs import ArangoGraph
from langchain_openai import ChatOpenAI

from api_key import openai_api_key

# Set OpenAI API key (read from the local, uncommitted `api_key` module).
os.environ["OPENAI_API_KEY"] = openai_api_key

# Initialize ChatOpenAI (deterministic output: temperature=0).
try:
    llm = ChatOpenAI(temperature=0, model='gpt-4')
    print("Initialization successful!")
except Exception as e:
    print(f"Initialization failed: {e}")
    # Re-raise: everything below needs `llm`. Swallowing the error here only
    # defers it to a confusing NameError at `ArangoGraphQAChain.from_llm(llm, ...)`.
    raise

# ArangoDB connection settings. Each can be overridden via an environment
# variable; the defaults reproduce the previously hard-coded local instance.
# NOTE(review): avoid relying on the default password outside local development.
ARANGO_URL = os.environ.get("ARANGO_URL", "http://127.0.0.1:8529")
ARANGO_DB = os.environ.get("ARANGO_DB", "spoke23_human")
ARANGO_USER = os.environ.get("ARANGO_USER", "root")
ARANGO_PASSWORD = os.environ.get("ARANGO_PASSWORD", "ph")

# Initialize the ArangoDB client and connect to the database
client = ArangoClient(hosts=ARANGO_URL)
db = client.db(ARANGO_DB, username=ARANGO_USER, password=ARANGO_PASSWORD)

# Wrap the existing database for LangChain
graph = ArangoGraph(db)

# Instantiate ArangoGraphQAChain. verbose=True matters: the helpers below scrape
# the chain's printed "AQL Result:" output.
qa_chain = ArangoGraphQAChain.from_llm(llm, graph=graph, verbose=True, return_aql_query=True, return_aql_result=True)

def capture_stdout(func, *args, **kwargs) -> str:
    """Call ``func(*args, **kwargs)`` and return the text it wrote to stdout.

    The callable's own return value is discarded; only the printed text is
    returned. Any exception raised by ``func`` propagates, and stdout is
    restored by the context manager either way.
    """
    sink = io.StringIO()
    with redirect_stdout(sink):
        func(*args, **kwargs)
    return sink.getvalue()

def execute_aql(query: str, qa_chain) -> Dict[str, str]:
    """Run ``qa_chain`` on ``query`` and return what it printed.

    The chain is invoked with ``{qa_chain.input_key: query}``. Its return value
    is discarded; only the captured stdout is kept.

    Returns:
        ``{'captured_output': <text printed by the chain>}``
    """
    chain_inputs = {qa_chain.input_key: query}
    printed = capture_stdout(qa_chain.invoke, chain_inputs)
    return dict(captured_output=printed)

def clean_output(output: str) -> str:
    """Return ``output`` with ANSI escape sequences (e.g. terminal colour codes) removed."""
    return re.sub(r'\x1B[@-_][0-?]*[ -/]*[@-~]', '', output)

def fix_json_format(aql_result_line: str) -> str:
    """Turn a printed AQL result line into JSON text that ``json.loads`` accepts.

    The chain's verbose output appears to print results as a Python repr:
    single-quoted strings, ``None``/``True``/``False`` (see the notebook output).
    Swapping every ``'`` for ``"`` corrupts values that contain apostrophes
    (e.g. ``"Alzheimer's disease"``), and it leaves ``None``/``True`` as invalid
    JSON. So the line is parsed as a Python literal and re-serialised as JSON.

    If the line is not a valid Python literal (e.g. it is already JSON with
    ``true``/``null``, or it was truncated), fall back to the original
    best-effort character substitution.

    Args:
        aql_result_line: One line of captured output holding the result.

    Returns:
        A JSON string (best effort in the fallback path).
    """
    try:
        # literal_eval only evaluates literals, so this is safe on model-generated text.
        # default=str covers non-JSON types literal_eval can yield (sets, bytes).
        return json.dumps(ast.literal_eval(aql_result_line), default=str, ensure_ascii=False)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        fixed_json = aql_result_line.replace("'", '"').replace('\\', '\\\\').replace('\n', '\\n')
        return fixed_json

def extract_aql_result(captured_output: str) -> Dict[str, list]:
    """Pull the parsed AQL result out of the chain's captured console output.

    Looks for the first line containing ``"AQL Result:"`` (after stripping ANSI
    codes) and parses the line right after it as JSON, via ``fix_json_format``.

    Returns:
        ``{'aql_result': <parsed value>}``. The value is ``[]`` when the marker
        is missing, is the last line, is followed by a blank line, or the next
        line is not parseable.
    """
    output_lines = clean_output(captured_output).splitlines()
    marker_at = next(
        (idx for idx, text in enumerate(output_lines) if "AQL Result:" in text),
        None,
    )

    result_text = None
    if marker_at is not None and marker_at + 1 < len(output_lines):
        result_text = output_lines[marker_at + 1].strip()

    if not result_text:
        return {'aql_result': []}

    try:
        parsed = json.loads(fix_json_format(result_text))
    except json.JSONDecodeError:
        return {'aql_result': []}
    return {'aql_result': parsed}

def interpret_aql_result(
    aql_result: List[Dict[str, Any]],
    llm,
    disease: str = "Type 2 Diabetes",
    entities: str = "genes and pathways",
) -> str:
    """Ask ``llm`` to narrate ``aql_result`` as a scientific story.

    Args:
        aql_result: Rows returned by the AQL query. They are embedded verbatim
            (via ``str`` formatting) into the prompt.
        llm: Chat model exposing ``invoke(prompt)``; its response must have a
            ``.content`` attribute.
        disease: Disease the story should relate the entities to. The default
            reproduces the previously hard-coded prompt exactly.
        entities: Noun phrase describing what the rows associate, e.g.
            ``"proteins and drugs"``.

    Returns:
        The text content of the model's response.
    """
    # NOTE(review): the whole result is embedded; very large results may exceed
    # the model's context window.
    prompt = (
        "Based on the following AQL results, provide a detailed and comprehensive scientific story "
        f"that explains the associations between the {entities}:\n\n"
        f"AQL Results: {aql_result}\n\n"
        f"Please include the significance of the associations, the role of the {entities} in {disease}, "
        "and any relevant biological processes involved."
    )
    response = llm.invoke(prompt)
    return response.content

def sequential_chain(query: str, qa_chain) -> Dict[str, Any]:
    """Run the query, parse its AQL result and, if non-empty, add an LLM story.

    Pipeline: ``execute_aql`` -> ``extract_aql_result`` -> ``interpret_aql_result``
    (using the module-level ``llm``).

    Returns:
        The dict from ``extract_aql_result``. When its ``'aql_result'`` is
        truthy, it also gets a ``'scientific_story'`` key.
    """
    printed = execute_aql(query, qa_chain)['captured_output']
    outcome = extract_aql_result(printed)

    rows = outcome.get('aql_result', [])
    if not rows:
        return outcome

    outcome['scientific_story'] = interpret_aql_result(rows, llm)
    return outcome

def execute_query_with_retries(query: str, qa_chain, max_attempts=3):
    """Run ``sequential_chain`` until it yields a non-empty AQL result.

    Each retry re-sends the ORIGINAL ``query`` prefixed with one instruction
    asking the LLM to refine its AQL. The prefix is not stacked on top of
    earlier retries.

    Args:
        query: Full prompt (question plus schema/few-shot text).
        qa_chain: Chain passed through to ``sequential_chain``.
        max_attempts: Maximum number of attempts.

    Returns:
        int: Number of attempts actually made (1-based). This equals
        ``max_attempts`` when every attempt found nothing, and is 0 when
        ``max_attempts`` < 1.
    """
    failure_message = ("The prior AQL query failed to return results. "
                       "Please think this through step by step and refine your AQL statement. "
                       "The original question is as follows:")
    original_query = query

    attempt = 0
    for attempt in range(1, max_attempts + 1):
        print(f"Attempt {attempt}: Executing query...")
        response = sequential_chain(query, qa_chain)
        aql_result = response.get('aql_result', [])

        if aql_result:
            print(f"\nAttempt {attempt} - AQL Result:\n{aql_result}")
            print(f"LLM Interpretation:\n{response.get('scientific_story')}")
            return attempt

        print(f"\nAttempt {attempt} - AQL Result: No result found.")
        if attempt < max_attempts:
            # Build from the original query so the prefix appears exactly once.
            query = f"{failure_message} {original_query}"
        else:
            print("No result found after", max_attempts, "tries.")

    return attempt

# --- Prompt components -------------------------------------------------------
# `graph_info`: natural-language description of the ArangoDB node/edge layout.
# It is interpolated into `base_prompt` inside <Graph Description> tags.
# NOTE(review): this text refers to `node_sample` / `edge_sample` JSON, but no
# such samples are interpolated into `base_prompt` in this notebook. Confirm
# whether they were meant to be included.

graph_info = """
                ### Contextual Intro
                ArangoDB graph DB represents a biomedical entity network, structured with nodes & edges, each carrying biomedical data types. Nodes = entities like proteins, drugs, diseases, genes. Edges = relationships/interactions. Aim: facilitate complex queries for insights into drug discovery, disease understanding, bio research.

                ### Node Struct
                - **Sample Node**: `node_sample` JSON shows a protein node. Elements:
                - IDs: `_key`, `_id`, `_rev`.
                - `type`: "Protein".
                - `labels`: ["Protein"].
                - `properties`: Dict of relevant properties, e.g., `identifier` ("A0A1B0GTW7"), `gene`, `description`, `org_ncbi_id`, `name`, etc.

                ### Edge Struct
                - **Sample Edge**: `edge_sample` JSON. Elements:
                - IDs: `_key`, `_id`, `_rev`.
                - Connects: `_from`, `_to`.
                - `label`: Type of relationship, e.g., "INCLUDES_PCiC".
                - `properties`: Edge attributes, like `license`, `source`, `vestige`, `forward_degrees`, etc.
                - Nodes: `start`, `end` with their properties.

                ### Edge Labels
                - Variety of labels for relationship types, e.g., `ADVRESPONSE_TO_mGarC`, `ASSOCIATES_DaG`, etc.
                - Each label, like `INCLUDES_PCiC`, signifies a specific interaction or association.
                - Crucial for query construction; guide graph traversal linking entities.

                ### Aim
                By understanding node/edge structure & labels, construct effective AQL queries for exploring bio networks, uncovering insights in drug-target interactions, gene-disease associations, etc.

            """

# Newline-separated list of edge `label` values, interpolated into `base_prompt`
# so the LLM filters on exact label strings (e.g. `edge.label == 'ASSOCIATES_DaG'`).
# NOTE(review): presumably mirrors the distinct labels in the Edges collection;
# it is maintained by hand here, so verify it against the database if the schema changes.
available_edge_labels = """ADVRESPONSE_TO_mGarC
                            ASSOCIATES_DaG
                            ASSOCIATES_GaS
                            BINDS_CbP
                            BINDS_CbPD
                            CATALYZES_ECcR
                            CAUSES_CcSE
                            CAUSES_OcD
                            CLEAVESTO_PctP
                            CONSUMES_RcC
                            CONTAINS_CcG
                            CONTAINS_FcC
                            CONTRAINDICATES_CcD
                            DECREASEDIN_PdD
                            DOWNREGULATES_AdG
                            DOWNREGULATES_CdG
                            DOWNREGULATES_GPdG
                            DOWNREGULATES_KGdG
                            DOWNREGULATES_OGdG
                            ENCODES_GeM
                            ENCODES_GeP
                            EXPRESSEDIN_GeiCT
                            EXPRESSEDIN_GeiD
                            EXPRESSEDIN_PeCT
                            EXPRESSES_AeG
                            HAS_PhEC
                            INCLUDES_OiPW
                            INCLUDES_PCiC
                            INCREASEDIN_PiD
                            INTERACTS_PDiPD
                            INTERACTS_PiC
                            INTERACTS_PiP
                            ISA_AiA
                            ISA_CTiCT
                            ISA_DiD
                            ISA_ECiEC
                            ISA_FiF
                            ISA_OiO
                            ISA_PWiPW
                            LOCALIZES_DlA
                            MARKER_NEG_GmnD
                            MARKER_POS_GmpD
                            MEMBEROF_PDmPF
                            PARTICIPATES_CpR
                            PARTICIPATES_GpBP
                            PARTICIPATES_GpCC
                            PARTICIPATES_GpMF
                            PARTICIPATES_GpPW
                            PARTICIPATES_GpR
                            PARTICIPATES_PpR
                            PARTOF_ApA
                            PARTOF_CTpA
                            PARTOF_PDpP
                            PARTOF_PpC
                            PARTOF_RpPW
                            PRESENTS_DpS
                            PRODUCES_RpC
                            REDUCES_SEN_mGrsC
                            RESEMBLES_DrD
                            RESISTANT_TO_mGrC
                            RESPONSE_TO_mGrC
                            TARGETS_MtG
                            TRANSPORTS_PtC
                            TREATS_CtD
                            UPREGULATES_AuG
                            UPREGULATES_CuG
                            UPREGULATES_GPuG
                            UPREGULATES_KGuG
                            UPREGULATES_OGuG
                            """

# Two hand-written question -> AQL examples (few-shot), interpolated into
# `base_prompt`. Both use the flat `Nodes` / `Edges` collections and join on
# `edge._from` / `edge._to` rather than graph traversal syntax.
# NOTE(review): Example 2 is the same question asked in the example usage below,
# so the model can copy this AQL almost verbatim. Consider a different example
# if you want to measure generalisation.
few_shot = """<Example Question 1>Question 1: What are the known targets of the drug Metformin, and what diseases are these targets most commonly associated with? </Example Question 1>
              <Example Answer 1>AQL Statement 1: WITH Nodes, Edges
                                FOR compound IN Nodes
                                    FILTER 'Compound' IN compound.labels
                                    AND (
                                        compound.properties.identifier LIKE '%Metformin%'
                                        OR compound.properties.name LIKE '%Metformin%'
                                        OR compound.properties.synonyms LIKE '%Metformin%'
                                    )
                                    FOR edge IN Edges
                                        FILTER edge._from == compound._id
                                        FOR relatedNode IN Nodes
                                            FILTER relatedNode._id == edge._to
                                            RETURN {
                                                metformin: {
                                                    identifier: compound.properties.identifier,
                                                    name: compound.properties.name,
                                                    chembl_id: compound.properties.chembl_id
                                                },
                                                related: {
                                                    identifier: relatedNode.properties.identifier,
                                                    name: relatedNode.properties.name,
                                                    chembl_id: relatedNode.properties.chembl_id,
                                                    // Include any other fields you need from relatedNode
                                                },
                                                edgeLabel: edge.label
                                            }</Example Answer 1>
                
                <Example Question 2>Question 2: Which genes are most strongly associated with the development of Type 2 Diabetes, and what pathways do they influence?</Example Question 2>
                <Example Answer 2>AQL Statement 2: WITH Nodes, Edges
                                LET type2DiabetesGenes = (
                                    FOR disease IN Nodes
                                        FILTER 'Disease' IN disease.labels
                                        AND (
                                            (CONTAINS(LOWER(disease.properties.name), 'type 2') AND CONTAINS(LOWER(disease.properties.name), 'diabetes'))
                                            OR 
                                            (CONTAINS(LOWER(disease.properties.synonyms), 'type 2') AND CONTAINS(LOWER(disease.properties.synonyms), 'diabetes'))
                                        )
                                        FOR edge IN Edges
                                            FILTER edge._from == disease._id
                                            AND edge.label == 'ASSOCIATES_DaG'
                                            FOR geneNode IN Nodes
                                                FILTER geneNode._id == edge._to
                                                AND 'Gene' IN geneNode.labels
                                                COLLECT geneId = geneNode._id INTO genes
                                                RETURN geneId
                                )
                                FOR geneId IN type2DiabetesGenes
                                    FOR pathwayEdge IN Edges
                                        FILTER pathwayEdge._from == geneId
                                        AND pathwayEdge.label == 'PARTICIPATES_GpPW' // Assuming this label connects genes to pathways
                                        FOR pathwayNode IN Nodes
                                            FILTER pathwayNode._id == pathwayEdge._to
                                            AND 'Pathway' IN pathwayNode.labels
                                            RETURN {
                                                geneId: geneId,
                                                pathway: {
                                                    identifier: pathwayNode.properties.identifier,
                                                    name: pathwayNode.properties.name
                                                    // Add other properties you need
                                                }
                                            }</Example Answer 2>
"""

# Base Prompt: combines the schema description, the edge-label list and the
# few-shot examples into XML-style tagged sections. It is meant to be appended
# AFTER the user's question (`question + base_prompt` in the example usage);
# that is why the system instructions refer to "the above question".
base_prompt = f"""
    <System Instructions>Answer the above question using the following data model and AQL query template.</System Instructions>
    
    <Graph Description>{graph_info}</Graph Description>

    <Edge Label Description>This is a list of the available edge labels in the graph. You can use these to filter edges in your AQL query.</Edge Label Description>
    <Available Edge Labels>{available_edge_labels}</Available Edge Labels>
    
    <Example Few-Shot Description>These questions and AQL queries demonstrate how to construct working AQL queries based on natural language questions using the provided node, edge, and edge label information. To adapt this query for different scenarios, modify the entity types, filter conditions, and return statements based on your specific data and question.</Example Few-Shot Description>
    <Example Few-Shot>{few_shot}</Example Few-Shot>
"""

# Example usage: ask one question (with the schema/few-shot prompt appended)
# and allow up to three attempts.
question = (
    "Which genes are most strongly associated with the development of Type 2 Diabetes, "
    "and what pathways do they influence?"
)
attempt_count = execute_query_with_retries(
    question + base_prompt,
    qa_chain,
    max_attempts=3,
)
print(f'Attempt Count: {attempt_count}')


Initialization successful!
Attempt 1: Executing query...

Attempt 1 - AQL Result:
[{'geneId': 'Nodes/100005', 'pathway': {'identifier': 'WP3594_r117682', 'name': 'Circadian rhythm genes'}}, {'geneId': 'Nodes/100005', 'pathway': {'identifier': 'WP3594_r108782', 'name': 'Circadian rhythm related genes'}}, {'geneId': 'Nodes/100005', 'pathway': {'identifier': 'WP706_r117178', 'name': 'Sudden infant death syndrome (SIDS) susceptibility pathways'}}, {'geneId': 'Nodes/100005', 'pathway': {'identifier': 'WP706_r113813', 'name': 'Sudden Infant Death Syndrome (SIDS) Susceptibility Pathways'}}, {'geneId': 'Nodes/100005', 'pathway': {'identifier': 'WP2855_r117704', 'name': 'Dopaminergic neurogenesis'}}, {'geneId': 'Nodes/100005', 'pathway': {'identifier': 'WP2855_r106728', 'name': 'Dopaminergic Neurogenesis'}}, {'geneId': 'Nodes/100005', 'pathway': {'identifier': 'WP3925_r117062', 'name': 'Amino acid metabolism'}}, {'geneId': 'Nodes/100005', 'pathway': {'identifier': 'WP3925_r115752', 'name': 'Ami

In [None]:
# Re-run the same question. `qa_chain` is the second positional parameter, so
# it must be passed explicitly; a bare `3` there would bind to `qa_chain` and
# fail on `3.invoke`.
question = "Which genes are most strongly associated with the development of Type 2 Diabetes, and what pathways do they influence?"
attempt_count = execute_query_with_retries(question + base_prompt, qa_chain, max_attempts=3)

print(f'Attempt Count: {attempt_count}')