In [1]:
import os
import dspy
from dspy.datasets import DataLoader
from dspy.teleprompt import BootstrapFewShot
from dspy.evaluate import Evaluate
from api_key import api_key, chatgpt_model

# Set the OpenAI API key
os.environ["OPENAI_API_KEY"] = api_key

# Configure DSPy to use the OpenAI API key
dspy.settings.configure(api_key=os.getenv("OPENAI_API_KEY"))

# Configure the language model (example using OpenAI's GPT-3.5)
gpt3_turbo = dspy.OpenAI(model=chatgpt_model, max_tokens=300)
dspy.configure(lm=gpt3_turbo)


In [59]:
import os
import dspy
from dspy.datasets import DataLoader
from dspy.teleprompt import BootstrapFewShot
from dspy.evaluate import Evaluate
from api_key import api_key, chatgpt_model

# Set the OpenAI API key
os.environ["OPENAI_API_KEY"] = api_key

# Configure DSPy to use the OpenAI API key
dspy.settings.configure(api_key=os.getenv("OPENAI_API_KEY"))

# Configure the language model (example using OpenAI's GPT-3.5)
gpt3_turbo = dspy.OpenAI(model=chatgpt_model, max_tokens=300)
dspy.configure(lm=gpt3_turbo)

class Example:
    def __init__(self, nl_query, aql_statement):
        self.inputs = {"nl_query": nl_query}
        self.outputs = {"aql_statement": aql_statement}
    
    def to_dict(self):
        return {
            "inputs": self.inputs,
            "outputs": self.outputs
        }
    
    def __getitem__(self, item):
        return self.__dict__[item]

class ExampleWrapper:
    def __init__(self, example_dict):
        self.example_dict = example_dict

    def inputs(self):
        return self.example_dict['inputs']

    def outputs(self):
        return self.example_dict['outputs']

    def __getitem__(self, key):
        if key in self.example_dict:
            return self.example_dict[key]
        elif isinstance(key, int):
            # Handle numeric keys gracefully
            return list(self.example_dict.items())[key]
        else:
            raise KeyError(f"Key {key} not found in example_dict")

    def items(self):
        return self.example_dict.items()

    def get(self, key, default=None):
        return self.example_dict.get(key, default)

    def copy(self):
        return ExampleWrapper(self.example_dict.copy())

# Convert dictionaries to Example objects
train_dataset = [
    Example(
        nl_query="What are the known targets of the drug Metformin, and what diseases are these targets most commonly associated with?",
        aql_statement="""
            WITH Nodes, Edges
            FOR compound IN Nodes
                FILTER 'Compound' IN compound.labels
                AND (
                    compound.properties.identifier LIKE '%Metformin%'
                    OR compound.properties.name LIKE '%Metformin%'
                    OR compound.properties.synonyms LIKE '%Metformin%'
                )
                FOR edge IN Edges
                    FILTER edge._from == compound._id
                    FOR relatedNode IN Nodes
                        FILTER relatedNode._id == edge._to
                        RETURN {
                            metformin: {
                                identifier: compound.properties.identifier,
                                name: compound.properties.name,
                                chembl_id: compound.properties.chembl_id
                            },
                            related: {
                                identifier: relatedNode.properties.identifier,
                                name: relatedNode.properties.name,
                                chembl_id: relatedNode.properties.chembl_id
                            },
                            edgeLabel: edge.label
                        }
        """
    ),
    Example(
        nl_query="Which genes are most strongly associated with the development of Type 2 Diabetes, and what pathways do they influence?",
        aql_statement="""
            WITH Nodes, Edges
            LET type2DiabetesGenes = (
                FOR disease IN Nodes
                    FILTER 'Disease' IN disease.labels
                    AND (
                        (CONTAINS(LOWER(disease.properties.name), 'type 2') AND CONTAINS(LOWER(disease.properties.name), 'diabetes'))
                        OR 
                        (CONTAINS(LOWER(disease.properties.synonyms), 'type 2') AND CONTAINS(LOWER(disease.properties.synonyms), 'diabetes'))
                    )
                    FOR edge IN Edges
                        FILTER edge._from == disease._id
                        AND edge.label == 'ASSOCIATES_DaG'
                        FOR geneNode IN Nodes
                            FILTER geneNode._id == edge._to
                            AND 'Gene' IN geneNode.labels
                            COLLECT geneId = geneNode._id INTO genes
                            RETURN geneId
            )
            FOR geneId IN type2DiabetesGenes
                FOR pathwayEdge IN Edges
                    FILTER pathwayEdge._from == geneId
                    AND pathwayEdge.label == 'PARTICIPATES_GpPW'
                    FOR pathwayNode IN Nodes
                        FILTER pathwayNode._id == pathwayEdge._to
                        AND 'Pathway' IN pathwayNode.labels
                        RETURN {
                            geneId: geneId,
                            pathway: {
                                identifier: pathwayNode.properties.identifier,
                                name: pathwayNode.properties.name
                            }
                        }
        """
    )
]

dev_dataset = [
    Example(
        nl_query="Find all nodes connected to node 123.",
        aql_statement="""
            WITH Nodes, Edges
            FOR node IN Nodes
                FILTER node._id == 'node/123'
                FOR edge IN Edges
                    FILTER edge._from == node._id OR edge._to == node._id
                    RETURN edge
        """
    )
]

wrapped_train_dataset = [ExampleWrapper(e.to_dict()) for e in train_dataset]
wrapped_dev_dataset = [ExampleWrapper(e.to_dict()) for e in dev_dataset]

class NLToAQLSignature(dspy.Signature):
    """Convert natural language query to AQL statement."""
    nl_query = dspy.InputField(desc="Natural language query")
    aql_statement = dspy.OutputField(desc="Generated AQL statement")

class NLToAQL(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_aql = dspy.ChainOfThought(NLToAQLSignature)
    
    def forward(self, nl_query):
        aql_statement = self.generate_aql(nl_query=nl_query)
        return dspy.Prediction(aql_statement=aql_statement.aql_statement)

# Corrected validate_aql function
def validate_aql(example, pred, trace=None):
    # Add debugging statements to confirm structure
    print("Debug: Example structure:", example)
    print("Debug: Example dict:", example.example_dict)
    print("Debug: Prediction structure:", pred)
    
    if 'outputs' not in example.example_dict or 'aql_statement' not in example.example_dict['outputs']:
        raise KeyError("'outputs' or 'aql_statement' key not found in example")
    
    if not hasattr(pred, 'aql_statement'):
        raise KeyError("'aql_statement' key not found in prediction")
    
    return example.example_dict['outputs']['aql_statement'].strip() == pred.aql_statement.strip()

teleprompter = BootstrapFewShot(metric=validate_aql)

# Compile the program
compiled_nl_to_aql = teleprompter.compile(student=NLToAQL(), trainset=wrapped_train_dataset)

# Evaluate the program
evaluate_program = Evaluate(devset=wrapped_dev_dataset, metric=validate_aql)
results = evaluate_program(compiled_nl_to_aql)
print(f"Evaluation Results: {results}")

# Test the prediction
nl_query = "Find all nodes connected to node 123."
pred = compiled_nl_to_aql(nl_query=nl_query)
print(f"Generated AQL: {pred.aql_statement}")


100%|██████████| 2/2 [00:00<00:00, 2337.96it/s]

Debug: Example structure: <__main__.ExampleWrapper object at 0x7fabd3d36ea0>
Debug: Example dict: {'inputs': {'nl_query': 'What are the known targets of the drug Metformin, and what diseases are these targets most commonly associated with?'}, 'outputs': {'aql_statement': "\n            WITH Nodes, Edges\n            FOR compound IN Nodes\n                FILTER 'Compound' IN compound.labels\n                AND (\n                    compound.properties.identifier LIKE '%Metformin%'\n                    OR compound.properties.name LIKE '%Metformin%'\n                    OR compound.properties.synonyms LIKE '%Metformin%'\n                )\n                FOR edge IN Edges\n                    FILTER edge._from == compound._id\n                    FOR relatedNode IN Nodes\n                        FILTER relatedNode._id == edge._to\n                        RETURN {\n                            metformin: {\n                                identifier: compound.properties.identifier,\n   


