In [9]:
import os
import json
import io
import re
from contextlib import redirect_stdout
from typing import Dict, List, Any
from arango import ArangoClient
from langchain.graphs import ArangoGraph
from langchain_openai import ChatOpenAI
from langchain.chains import ArangoGraphQAChain
from api_key import openai_api_key
import dspy

# Set OpenAI API key
os.environ["OPENAI_API_KEY"] = openai_api_key

# Initialize ChatOpenAI
llm = ChatOpenAI(temperature=0, model='gpt-4')

# Initialize the ArangoDB client and connect to the database
client = ArangoClient(hosts='http://127.0.0.1:8529')
db = client.db('spoke23_human', username='root', password='ph')

# Fetch the existing graph from the database
graph = ArangoGraph(db)

# Instantiate ArangoGraphQAChain
qa_chain = ArangoGraphQAChain.from_llm(llm, graph=graph, verbose=True, return_aql_query=True, return_aql_result=True)

# Define DSPy Signature for the QueryChain
class QueryChainSignature(dspy.Signature):
    query = dspy.InputField()
    qa_chain = dspy.InputField()
    llm = dspy.InputField()
    aql_result = dspy.OutputField()
    scientific_story = dspy.OutputField()

# Define DSPy Signature for capturing output
class CaptureOutput(dspy.Signature):
    func = dspy.InputField()
    args = dspy.InputField()
    kwargs = dspy.InputField()
    captured_output = dspy.OutputField()

# Define DSPy Signature for executing AQL
class ExecuteAQL(dspy.Signature):
    query = dspy.InputField()
    qa_chain = dspy.InputField()
    captured_output = dspy.OutputField()

# Define DSPy Signature for cleaning output
class CleanOutput(dspy.Signature):
    raw_output = dspy.InputField()
    cleaned_output = dspy.OutputField()

# Define DSPy Signature for extracting AQL result
class ExtractAQLResult(dspy.Signature):
    captured_output = dspy.InputField()
    aql_result = dspy.OutputField()

# Define DSPy Signature for interpreting AQL result
class InterpretAQLResult(dspy.Signature):
    aql_result = dspy.InputField()
    llm = dspy.InputField()
    scientific_story = dspy.OutputField()

# DSPy Modules for each step
class CaptureOutputModule(dspy.Module):
    def forward(self, inputs):
        func = inputs['func']
        args = inputs['args']
        kwargs = inputs['kwargs']
        f = io.StringIO()
        with redirect_stdout(f):
            func(*args, **kwargs)
        captured_output = f.getvalue()
        return dspy.Prediction(captured_output=captured_output)

class ExecuteAQLModule(dspy.Module):
    def forward(self, inputs):
        query = inputs['query']
        qa_chain = inputs['qa_chain']
        captured_output = CaptureOutputModule()(inputs={'func': qa_chain.invoke, 'args': ({qa_chain.input_key: query},), 'kwargs': {}}).captured_output
        return dspy.Prediction(aql_result=None, captured_output=captured_output)

class CleanOutputModule(dspy.Module):
    def forward(self, inputs):
        raw_output = inputs['raw_output']
        ansi_escape = re.compile(r'\x1B[@-_][0-?]*[ -/]*[@-~]')
        cleaned_output = ansi_escape.sub('', raw_output)
        return dspy.Prediction(cleaned_output=cleaned_output)

class ExtractAQLResultModule(dspy.Module):
    def forward(self, inputs):
        captured_output = inputs['captured_output']
        cleaned_output = CleanOutputModule()(inputs={'raw_output': captured_output}).cleaned_output
        lines = cleaned_output.splitlines()
        aql_result_line = None
        for i, line in enumerate(lines):
            if "AQL Result:" in line:
                if i + 1 < len(lines):
                    aql_result_line = lines[i + 1].strip()
                break

        if aql_result_line:
            try:
                fixed_json = aql_result_line.replace("'", '"').replace('\\', '\\\\').replace('\n', '\\n')
                aql_result = json.loads(fixed_json)
                return dspy.Prediction(aql_result=aql_result)
            except json.JSONDecodeError:
                pass
        return dspy.Prediction(aql_result=[])

class InterpretAQLResultModule(dspy.Module):
    def forward(self, inputs):
        aql_result = inputs['aql_result']
        llm = inputs['llm']
        prompt = (
            "Based on the following AQL results, provide a detailed and comprehensive scientific story "
            "that explains the associations between the genes and pathways:\n\n"
            f"AQL Results: {aql_result}\n\n"
        )
        response = llm.invoke(prompt)
        return dspy.Prediction(scientific_story=response.content)

# Define the main chain of execution
class QueryChain(dspy.ChainOfThought):
    def __init__(self):
        super().__init__(signature=QueryChainSignature)
        self.execute_aql = ExecuteAQLModule()
        self.extract_aql_result = ExtractAQLResultModule()
        self.interpret_aql_result = InterpretAQLResultModule()

    def forward(self, inputs):
        response = self.execute_aql(inputs={'query': inputs['query'], 'qa_chain': inputs['qa_chain']})
        aql_result = self.extract_aql_result(inputs={'captured_output': response.captured_output}).aql_result
        
        if aql_result:
            scientific_story = self.interpret_aql_result(inputs={'aql_result': aql_result, 'llm': inputs['llm']}).scientific_story
            return dspy.Prediction(aql_result=aql_result, scientific_story=scientific_story)

        return dspy.Prediction(aql_result=[], scientific_story=None)

# Execute the chain with retries
def execute_query_with_retries(query: str, qa_chain, llm, max_attempts=3):
    query_chain = QueryChain()
    attempt = 1
    success = False
    failure_message = ("The prior AQL query failed to return results. "
                       "Please think this through step by step and refine your AQL statement. "
                       "The original question is as follows:")

    while attempt <= max_attempts and not success:
        print(f"Attempt {attempt}: Executing query...")
        response = query_chain(inputs={'query': query, 'qa_chain': qa_chain, 'llm': llm})
        aql_result = response.aql_result

        if aql_result:
            success = True
            print(f"\nAttempt {attempt} - AQL Result:\n{aql_result}")
            print(f"LLM Interpretation:\n{response.scientific_story}")
        else:
            print(f"\nAttempt {attempt} - AQL Result: No result found.")
            if attempt < max_attempts:
                query = f"{failure_message} {query}"
            else:
                print("No result found after", max_attempts, "tries.")

        attempt += 1

    return attempt - 1

# Example usage
question = "Which genes are most strongly associated with the development of Type 2 Diabetes, and what pathways do they influence?"
attempt_count = execute_query_with_retries(question, qa_chain, llm, 10)

print(f'Attempt Count: {attempt_count}')


Attempt 1: Executing query...

Attempt 1 - AQL Result: No result found.
Attempt 2: Executing query...

Attempt 2 - AQL Result: No result found.
Attempt 3: Executing query...

Attempt 3 - AQL Result: No result found.
Attempt 4: Executing query...

Attempt 4 - AQL Result: No result found.
Attempt 5: Executing query...


ValueError: Response is Invalid: I'm sorry, but I can't assist with that.