In [1]:
from neo4j import GraphDatabase
import openai
import ell
from dotenv import load_dotenv
import os
import networkx as nx
from pyvis.network import Network
from datetime import datetime

load_dotenv()  # take environment variables from .env before any of them are read below.

# Database connection settings are read from the environment (or .env) so that
# credentials are never committed with the notebook:
#   MEMGRAPH_URI       bolt URI; falls back to the host previously hard-coded here
#   MEMGRAPH_USER      database user name
#   MEMGRAPH_PASSWORD  database password
# NOTE(review): a password was previously committed in this cell; it should be rotated.
URI = os.getenv("MEMGRAPH_URI", "bolt://memgraph.chain-insights.io:30768")
AUTH = (os.getenv("MEMGRAPH_USER", ""), os.getenv("MEMGRAPH_PASSWORD", ""))

openai.api_key = os.getenv('OPENAI_API_KEY')
client = openai.Client()

def execute_query(query, uri=URI, auth=AUTH, database="memgraph", parameters=None, raise_on_error=False):
    """
    Execute a Cypher query on a Neo4j or Memgraph instance using the Neo4j driver.

    A new driver is opened for each call and is always closed before returning.
    The query text is executed verbatim. When it comes from an LLM, prefer a
    read-only database account.

    Args:
        query (str): The Cypher query to execute.
        uri (str): The connection URI, typically 'bolt://host:port'.
        auth (tuple): A (username, password) tuple used for authentication.
        database (str): The name of the database to query.
        parameters (dict, optional): Values for ``$name`` placeholders in ``query``.
            Prefer these over formatting values into the query string.
        raise_on_error (bool): If True, re-raise any exception instead of
            printing it and returning an empty list.

    Returns:
        list: One dict per result row (``record.data()``). With the default
        ``raise_on_error=False``, a failed query also yields an empty list, so an
        empty result can mean either "no rows" or "error" (see printed message).

    Raises:
        Exception: Whatever the driver raised, only when ``raise_on_error`` is True.
    """
    try:
        # The driver's context manager closes it on exit, including on error.
        with GraphDatabase.driver(uri, auth=auth) as driver:
            with driver.session(database=database) as session:
                result = session.run(query, parameters)
                # Materialize rows while the session is still open.
                return [record.data() for record in result]
    except Exception as e:
        if raise_on_error:
            raise
        # Best-effort default kept for existing callers: report and return no rows.
        print(f"An error occurred: {e}")
        return []

def preprocess_query(raw_query):
    """
    Preprocess the raw query text to extract and clean the Cypher query for direct use.

    Removes a surrounding Markdown code fence, together with a language tag on
    the opening fence line such as ```cypher. It then collapses all whitespace
    runs into single spaces.

    Args:
        raw_query (str): The raw input containing a Cypher query, possibly
            wrapped in a Markdown code fence.

    Returns:
        str: The cleaned, single-line Cypher query ready for database execution.

    Note:
        Whitespace collapsing also applies inside quoted string literals, and
        literal two-character \\n sequences are replaced by spaces.
    """
    # Language tags an LLM commonly writes after the opening fence. Only these
    # are stripped, so a query that really starts on the fence line is kept.
    fence_language_tags = {"cypher", "cql", "neo4j", "memgraph", "sql", "text", "plaintext"}

    clean_query = raw_query.strip()

    # If the query starts and ends with triple backticks, remove them
    if clean_query.startswith("```") and clean_query.endswith("```"):
        clean_query = clean_query[3:-3]
        first_line, newline, rest = clean_query.partition("\n")
        # Without this, "```cypher\nMATCH ..." left "cypher" at the start of the
        # query, which the database rejects as a syntax error.
        if newline and first_line.strip().lower() in fence_language_tags:
            clean_query = rest
        clean_query = clean_query.strip()

    # Remove escaped ("\\n") and real newlines
    clean_query = clean_query.replace("\\n", " ").replace("\n", " ").strip()

    # Collapse any remaining whitespace runs into single spaces
    return ' '.join(clean_query.split())

@ell.simple(model="gpt-4o-mini", client=client)
def text2cypher(text: str):
    """
    <instruction>
    - You are an expert in Neo4j Database Administration (DBA), specializing in Cypher query generation.
    - Your mission is to generate accurate and efficient Cypher queries to answer users' requests in the Neo4j or Memgraph database.
    - Generate a single Cypher query that directly fulfills the user's request by interacting with the database.
    - Use the <schema> to understand the structure, labels, relationships, and properties of the database.
    - If the request involves multiple steps, ensure each step is integrated into the final Cypher query (e.g., filters, conditions, and specific properties).
    - Use only the node labels, relationship types, and properties listed in the <schema>; never invent new ones.
    - If the request cannot be fully answered with the <schema>, still output the closest valid Cypher query built only from schema elements. Never answer in prose: your output is executed verbatim against the database.
    - Output only the Cypher query: no explanations, comments, metadata, or Markdown code fences.
    - If the user's instructions are ambiguous, make reasonable assumptions based on typical graph modeling patterns, but do not write those assumptions in the output.
    - When the user's question involves aggregates, filters, or sorting, add the appropriate clauses (MATCH, WHERE, RETURN, ORDER BY, etc.).
    - When returning the addresses on either side of a transfer, alias them as sender_address and receiver_address.
    - Only read data: never use CREATE, MERGE, SET, DELETE, REMOVE, DROP, or any other write clause.
    </instruction>

    <schema>
    Nodes:
    - Address: Properties include address (string).
    - Transaction: Properties include block_height (int), in_total_amount (int), in_coinbase (bool), out_total_amount (int), timestamp (int), tx_id (string).

    Relationships:
    - SENT: goes from an Address to a Transaction (inputs) and from a Transaction to an Address (outputs), with property value_satoshi (int).
      Pattern: (sender:Address)-[:SENT]->(t:Transaction)-[:SENT]->(receiver:Address)
    </schema>
    """
    # The docstring above is the prompt text sent to the model (presumably as
    # the system prompt, per ell.simple). The returned string is the user
    # message. Every reply is executed verbatim as Cypher downstream, hence the
    # "no prose" rules.
    return str(text)

def visualize_results(results, output_path="cypher_query_visualization.html"):
    """
    Render sender -> receiver transfers from Cypher query results as an interactive PyVis graph.

    Each record adds a directed edge from ``sender_address`` to ``receiver_address``.
    Records missing either address are skipped. The edge tooltip shows the
    ``t.block_height``, ``t.timestamp``, ``t.in_total_amount`` and
    ``t.out_total_amount`` values of the record.

    Args:
        results (list): Result rows as dicts, e.g. the output of ``execute_query``.
        output_path (str): HTML file the network is written to.

    Returns:
        Whatever ``Network.show`` returns. In notebook mode this is presumably an
        IPython IFrame (pyvis behaviour — confirm for the installed version), so
        calling this as the last line of a cell displays the graph inline.
    """
    from pyvis.network import Network
    from datetime import datetime, timezone

    net = Network(notebook=True, height="750px", width="100%", directed=True, cdn_resources='in_line')

    for record in results:
        from_address = record.get("sender_address")
        to_address = record.get("receiver_address")
        if not (from_address and to_address):
            continue

        timestamp = record.get("t.timestamp")
        block_height = record.get("t.block_height")
        in_total_amount = record.get("t.in_total_amount")
        out_total_amount = record.get("t.out_total_amount")

        # NOTE(review): assumes timestamp is Unix seconds — confirm against the data.
        # Timezone-aware conversion replaces the deprecated datetime.utcfromtimestamp.
        formatted_time = (
            datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
            if timestamp
            else "Unknown"
        )

        net.add_node(from_address, label=from_address, title=f"From Address: {from_address}", color='blue')
        net.add_node(to_address, label=to_address, title=f"To Address: {to_address}", color='green')

        edge_title = (f"Block Height: {block_height}, "
                      f"Timestamp: {formatted_time}, "
                      f"In Total Amount: {in_total_amount}, "
                      f"Out Total Amount: {out_total_amount}")

        net.add_edge(from_address, to_address, title=edge_title, color='black')

    net.show_buttons(filter_=['physics'])
    # Return the show() result so the graph renders when this is a cell's last expression.
    return net.show(output_path)

def visualize_results_dynamic(results, node_key_pairs=None, relationship_keys=None, node_colors=None,
                              output_path="cypher_query_visualization.html"):
    """
    Flexible visualization for graph schemas using PyVis.

    Each record adds a directed edge from ``record[source_key]`` to
    ``record[target_key]``. Records missing either value are skipped.

    Args:
        results (list): A list of result records (dicts) from the Cypher query.
        node_key_pairs (tuple): Keys for source and target nodes
            (e.g., ("sender_address", "receiver_address")). Defaults to ("from", "to").
        relationship_keys (list): Keys to include in the edge tooltip
            (e.g., ["t.block_height", "t.timestamp"]). Integer values of keys
            named "timestamp" or ending in ".timestamp" are shown as UTC dates.
        node_colors (dict): Colors by result key
            (e.g., {"sender_address": "blue", "receiver_address": "green"}).
            An optional "default" entry applies to keys without their own color;
            otherwise sources are blue and targets green.
        output_path (str): HTML file the network is written to.

    Returns:
        Whatever ``Network.show`` returns. In notebook mode this is presumably an
        IPython IFrame (pyvis behaviour — confirm for the installed version), so
        calling this as the last line of a cell displays the graph inline.
    """
    from pyvis.network import Network
    from datetime import datetime, timezone

    # Initialize PyVis Network
    net = Network(notebook=True, height="750px", width="100%", directed=True, cdn_resources='in_line')

    source_key, target_key = node_key_pairs or ("from", "to")
    relationship_keys = relationship_keys or []
    node_colors = node_colors or {}

    # Colors do not depend on the record, so resolve them once.
    source_color = node_colors.get(source_key, node_colors.get("default", "blue"))
    target_color = node_colors.get(target_key, node_colors.get("default", "green"))

    for record in results:
        source = record.get(source_key)
        target = record.get(target_key)
        if not (source and target):
            continue

        net.add_node(source, label=source, title=f"Node: {source}", color=source_color)
        net.add_node(target, label=target, title=f"Node: {target}", color=target_color)

        # Build the edge tooltip from the requested keys.
        edge_parts = []
        for key in relationship_keys:
            value = record.get(key, "Unknown")
            # Match both "timestamp" and projected property keys like "t.timestamp".
            # NOTE(review): assumes Unix seconds — confirm against the data.
            if key.rsplit(".", 1)[-1] == "timestamp" and isinstance(value, int):
                value = datetime.fromtimestamp(value, tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
            edge_parts.append(f"{key}: {value}")

        net.add_edge(source, target, title=", ".join(edge_parts), color='black')

    # Show buttons for customization
    net.show_buttons(filter_=['physics'])
    # Return the show() result so the graph renders when this is a cell's last expression.
    return net.show(output_path)

In [4]:
# Natural language -> Cypher -> results. The cleaned query is printed so a failed
# or empty result can be traced back to the exact text that was executed.
generatedcypher = text2cypher("what transaction id '31b768384f5ae35efd4834f0f374ac4d4468233c2a8d8b01ebe968bebc01e6fa' have transactions between address?")
clean_query = preprocess_query(generatedcypher)
print(clean_query)
records = execute_query(clean_query)
print(f"{len(records)} rows returned")
records[:10]