In [135]:

!pip3 install langchain neo4j py2neo openai langchain-community langchain-experimental langchain-openai langchain-neo4j langchain-openai neo4j unstructured[all-docs] unstructured[openai] jsonlines wandb python-magic
!brew install libmagic



In [222]:
import os
import json
import uuid

from dotenv import load_dotenv
load_dotenv(override=True)

True

In [223]:
import os
import json
from uuid import uuid4
import pandas as pd
from langchain.llms import OpenAI
from langchain.graphs import Neo4jGraph
from dotenv import load_dotenv
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain.schema import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings

# Load environment variables
load_dotenv()

# Initialize LLM for entity and relationship extraction
# llm_extraction = ChatOpenAI(model="gpt-4o-mini", temperature=0)
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
llm_extraction = ChatOpenAI(model="gpt-4o", temperature=0)

In [224]:
# Directory containing the files
directory_path = "./docs5"  # Update this to your directory path

# Build the Ontology

In [225]:
import os
import json
from collections import defaultdict

# Ontology prompt template
ontology_prompt = """
<OntologyPrompt>
    <Instruction>
        Given the provided TEXT, identify and define entities, relationships, and schema for constructing a knowledge graph ontology. Output the results in JSON format as follows:
        <Guidelines>
            <RelationshipClarity>
                <Point>Avoid relationships that are overly generic (e.g., "MENTIONS") unless essential for describing entity interactions.</Point>
                <Point>Focus on verbs or phrases that express meaningful actions or dependencies (e.g., "GUIDES", "COMMUNICATES_WITH", "DEPENDS_ON").</Point>
                <Point>Use a hierarchical structure where broad relationships (e.g., "COMMANDS") can have subtypes (e.g., "SETS", "ADJUSTS").</Point>
            </RelationshipClarity>
            <EntityCoverage>
                <Point>Include entities that represent distinct and meaningful concepts from the TEXT.</Point>
                <Point>Group similar entities under higher-level abstractions where granularity is unnecessary (e.g., consolidate "Altitude" and "SpeedParameter" into "Parameter").</Point>
            </EntityCoverage>
            <SchemaValidation>
                <Point>Assign relationships based on clear and explicit interactions or dependencies described in the TEXT.</Point>
                <Point>Avoid redundant relationships unless they add distinct value to the schema.</Point>
            </SchemaValidation>
        </Guidelines>
    </Instruction>
    <InputText>
        <![CDATA[
        {text}
        ]]>
    </InputText>
    <OutputFormat>
        <![CDATA[
        {{
            "VALID_ENTITIES": [
                "Entity1",
                "Entity2"
                // Add more entities as needed
            ],
            "VALID_RELATIONS": [
                "RELATIONSHIP1",
                "RELATIONSHIP2"
                // Add more relationships as needed
            ],
            "VALIDATION_SCHEMA": {{
                "Entity1": ["RELATIONSHIP1", "RELATIONSHIP2"],
                "Entity2": ["RELATIONSHIP1", "RELATIONSHIP3"]
                // Add more entities and their valid relationships as needed
            }}
        }}
        ]]>
    </OutputFormat>
</OntologyPrompt>
"""

# Initialize running sets for deduplication
entities_set = set()
relations_set = set()
schema_dict = defaultdict(set)

# Helper function to clean LLM response
def clean_json_response(raw_response):
    """Remove Markdown-style backticks and ensure valid JSON."""
    try:
        cleaned_response = raw_response.strip()
        if cleaned_response.startswith("```json"):
            cleaned_response = cleaned_response[7:]  # Remove leading ```json
        if cleaned_response.endswith("```"):
            cleaned_response = cleaned_response[:-3]  # Remove trailing ```
        return cleaned_response
    except Exception as e:
        print("Error cleaning response:", e)
        return raw_response

# Process files in the directory
for file in os.listdir(directory_path):
    file_path = os.path.join(directory_path, file)
    if os.path.isfile(file_path):
        # Read file content
        with open(file_path, "r", encoding='utf-8') as f:
            text = f.read()
        
        print(f"\nProcessing file: {file}")
        
        # Format the prompt for the current document
        formatted_prompt = ontology_prompt.format(text=text)
        
        try:
            # Invoke LLM (assume llm_extraction.invoke exists)
            response = llm_extraction.invoke([{"role": "system", "content": formatted_prompt}])
            
            # Debugging: Print raw response content
            print("Raw response:", response.content)
            
            # Clean the response
            cleaned_content = clean_json_response(response.content)
            
            # Debugging: Print cleaned response content
            print("Cleaned response:", cleaned_content)
            
            # Attempt to parse response JSON
            response_data = json.loads(cleaned_content)
            
            # Update running sets and schema
            entities_set.update(response_data.get("VALID_ENTITIES", []))
            relations_set.update(response_data.get("VALID_RELATIONS", []))
            for entity, relations in response_data.get("VALIDATION_SCHEMA", {}).items():
                schema_dict[entity].update(relations)
        
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON for file {file}: {e}")
            print("Cleaned response content:", cleaned_content)
        except Exception as e:
            print(f"Unexpected error processing file {file}: {e}")

# Deduplicate and sort results
final_entities = sorted(entities_set)
final_relations = sorted(relations_set)
final_schema = {entity: sorted(relations) for entity, relations in schema_dict.items()}

# Combine everything into a single JSON object
combined_output = {
    "VALID_ENTITIES": final_entities,
    "VALID_RELATIONS": final_relations,
    "VALIDATION_SCHEMA": final_schema
}

# Save the combined and deduplicated ontology to a JSON file
output_file = "combined_ontology.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(combined_output, f, indent=4)

print(f"\nCombined ontology saved to {output_file}.")



Processing file: ac_throttlemid.rst
Raw response: ```json
{
    "VALID_ENTITIES": [
        "Copter",
        "Hover Throttle",
        "MOT_THST_HOVER",
        "MOT_HOVER_LEARN",
        "Dataflash Log",
        "CTUN.ThO",
        "Flight Mode"
    ],
    "VALID_RELATIONS": [
        "LEARNS",
        "MOVES_TOWARDS",
        "SETS",
        "DISABLES",
        "DOWNLOADS",
        "OBSERVES"
    ],
    "VALIDATION_SCHEMA": {
        "Copter": ["LEARNS", "MOVES_TOWARDS"],
        "Hover Throttle": ["LEARNS", "MOVES_TOWARDS"],
        "MOT_THST_HOVER": ["SETS"],
        "MOT_HOVER_LEARN": ["DISABLES"],
        "Dataflash Log": ["DOWNLOADS"],
        "CTUN.ThO": ["OBSERVES"],
        "Flight Mode": ["LEARNS"]
    }
}
```
Cleaned response: 
{
    "VALID_ENTITIES": [
        "Copter",
        "Hover Throttle",
        "MOT_THST_HOVER",
        "MOT_HOVER_LEARN",
        "Dataflash Log",
        "CTUN.ThO",
        "Flight Mode"
    ],
    "VALID_RELATIONS": [
        "LEARNS",
        

# Build the Text to Node Model (Fine Tuning)

In [226]:
import os
import logging
import json
import glob
from pathlib import Path
import pandas as pd
from dotenv import load_dotenv
from uuid import uuid4
from datetime import datetime
import jsonlines
from time import sleep

from unstructured.partition.rst import partition_rst
from unstructured.documents.elements import Text
from unstructured.partition.auto import partition
from typing import List, Optional

load_dotenv()

###################################
# Define Schema from ontology
###################################
# load the ontology
with open("combined_ontology.json", "r") as f:
    ontology = json.load(f)

VALID_ENTITIES = ontology["VALID_ENTITIES"]
VALID_RELATIONS = ontology["VALID_RELATIONS"]
VALIDATION_SCHEMA = ontology["VALIDATION_SCHEMA"]


###################################
# LLM Extraction Functions
###################################


def create_prompt(text_chunk):
    # Create entity tags dynamically
    entity_tags = "\n".join(f"    <entity>{entity}</entity>" for entity in VALID_ENTITIES)
    
    # Create relationship tags dynamically
    relationship_tags = "\n".join(f"    <relationship>{rel}</relationship>" for rel in VALID_RELATIONS)
    
    # Create schema tags dynamically
    schema_tags = []
    for entity, relationships in VALIDATION_SCHEMA.items():
        schema_tags.append(f"    <entitySchema name=\"{entity}\">")
        for rel in relationships:
            schema_tags.append(f"      <allowedRelation>{rel}</allowedRelation>")
        schema_tags.append("    </entitySchema>")
    schema_section = "\n".join(schema_tags)

    prompt_text = (
        "<prompt>\n"
        "  <role>You are an expert in creating knowledge graphs. Respond with a valid JSON object only.</role>\n"
        "  <instructions>Do not include any extra text outside of the JSON. The JSON must have 'Nodes' and 'Relationships'.</instructions>\n"
        "  <validEntities>\n"
        f"{entity_tags}\n"
        "  </validEntities>\n"
        "  <validRelationships>\n"
        f"{relationship_tags}\n"
        "  </validRelationships>\n"
        "  <validationSchema>\n"
        f"{schema_section}\n"
        "  </validationSchema>\n"
        "  <exampleFormat>\n"
        "    {\n"
        "      \"Nodes\": [\n"
        "        {\"id\": \"1\", \"label\": \"FlightMode\", \"properties\": {\"name\": \"Loiter\"}}\n"
        "      ],\n"
        "      \"Relationships\": [\n"
        "        {\"from_id\": \"1\", \"to_id\": \"2\", \"type\": \"REQUIRES\", \"properties\": {}}\n"
        "      ]\n"
        "    }\n"
        "  </exampleFormat>\n"
        f"  <inputText>{text_chunk.strip()}</inputText>\n"
        "</prompt>"
    )
    print(prompt_text)
    return prompt_text


def create_jsonl_entry(text_chunk, llm_response):
    """Create a JSONL entry from the text chunk and LLM response."""
    entry = {
        "messages": [
            {
                "role": "system",
                "content": create_prompt(text_chunk)
            },
            {
                "role": "assistant",
                "content": f"```json\n{llm_response}\n```"
            }
        ]
    }
    return entry


def sanitize_llm_response(response_content):
    """Sanitize LLM response to ensure valid JSON."""
    # Remove leading/trailing whitespace
    sanitized = response_content.strip()

    # Handle responses wrapped in code block formatting
    if sanitized.startswith("```json"):
        sanitized = sanitized[7:]  # Remove leading "```json"
    if sanitized.endswith("```"):
        sanitized = sanitized[:-3]  # Remove trailing "```"

    # Validate basic JSON structure
    if not ('"Nodes"' in sanitized and '"Relationships"' in sanitized):
        raise ValueError(f"Invalid JSON response structure: {sanitized[:100]}...")

    return sanitized



def validate_extraction(nodes, relationships):
    """Filter nodes and relationships based on the new schema."""
    valid_nodes = [node for node in nodes if node["label"] in VALID_ENTITIES]
    valid_relationships = []
    for rel in relationships:
        if rel["type"] in VALID_RELATIONS:
            # Check if relationship is allowed from the from_id node type
            from_node_label = None
            for n in valid_nodes:
                if n["id"] == rel["from_id"]:
                    from_node_label = n["label"]
                    break
            if from_node_label and rel["type"] in VALIDATION_SCHEMA.get(from_node_label, []):
                valid_relationships.append(rel)
    return valid_nodes, valid_relationships


def extract_nodes_relationships(llm, text_chunk, jsonl_writer=None):
    """Extract nodes and relationships from text, optionally writing to JSONL."""
    prompt = create_prompt(text_chunk)
    print("\nGenerated Prompt:")
    print(prompt)

    response = llm.invoke([{"role": "system", "content": prompt}])
    
    print("\nRaw LLM Response:")
    print(response.content)

    try:
        sanitized_response = sanitize_llm_response(response.content)
        print("\nSanitized Response:")
        print(sanitized_response)

        # Write to JSONL if writer is provided
        if jsonl_writer:
            entry = create_jsonl_entry(text_chunk, sanitized_response)
            jsonl_writer.write(entry)

        result = json.loads(sanitized_response)
        nodes = result.get("Nodes", [])
        relationships = result.get("Relationships", [])
        return validate_extraction(nodes, relationships)
    except json.JSONDecodeError as e:
        print(f"JSON Decode Error: {e}")
        return [], []
    except ValueError as ve:
        print(f"Response Validation Error: {ve}")
        return [], []


def extract_nodes_relationships_with_retry(llm, text_chunk, retries=3, delay=1, jsonl_writer=None):
    """Extract nodes and relationships with retry logic."""
    for attempt in range(retries):
        try:
            return extract_nodes_relationships(llm, text_chunk, jsonl_writer)
        except (ValueError, json.JSONDecodeError) as e:
            print(f"Retry {attempt + 1}/{retries} failed: {e}")
            sleep(delay)
    print("All retry attempts failed.")
    return [], []

###################################
# Processing RST Files
###################################

def extract_text_from_file(filepath: str) -> List[str]:
    """
    Extract text from file with fallback methods for different formats.
    Returns a list of text elements.
    """
    file_extension = Path(filepath).suffix.lower()
    
    try:
        if file_extension == '.rst':
            # Use RST-specific parser
            elements = partition_rst(filename=filepath)
            return [elem.text for elem in elements if hasattr(elem, 'text') and elem.text.strip()]
        else:
            # Try using unstructured's auto-detection for other formats
            try:
                elements = partition(filename=filepath)
                return [elem.text for elem in elements if hasattr(elem, 'text') and elem.text.strip()]
            except Exception as auto_error:
                # Fallback to simple text reading
                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read()
                return [content] if content.strip() else []
                
    except Exception as e:
        logging.error(f"Error extracting text from {filepath}: {str(e)}")
        # Final fallback: try simple text reading
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
            return [content] if content.strip() else []
        except Exception as read_error:
            logging.error(f"Failed to read file {filepath}: {str(read_error)}")
            return []

# def chunk_text(text_elements: List[str], max_characters: int = 1200, 
#                new_after_n_chars: int = 500) -> List[str]:
def chunk_text(text_elements: List[str], max_characters: int = 3000, 
               new_after_n_chars: int = 1000) -> List[str]:
    """
    Chunk text elements into smaller pieces.
    """
    chunks = []
    current_chunk = []
    current_length = 0

    for text in text_elements:
        text = text.strip()
        if not text:
            continue
        text_len = len(text)

        if current_length + text_len > max_characters:
            if current_chunk:
                chunks.append(" ".join(current_chunk))
            current_chunk = [text]
            current_length = text_len
        else:
            current_chunk.append(text)
            current_length += text_len

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks

def process_file(filepath: str, llm, jsonl_path: Optional[str] = None) -> bool:
    """
    Process any text file and write to JSONL.
    """
    print(f"Processing file: {filepath}")
    
    # Extract text from file
    text_elements = extract_text_from_file(filepath)
    if not text_elements:
        print(f"No text content extracted from {filepath}")
        return False

    # Create text chunks
    text_chunks = chunk_text(text_elements)

    # Skip if no JSONL path provided
    if not jsonl_path:
        print("No JSONL path provided - skipping processing")
        return False

    # Process chunks and write to JSONL
    with jsonlines.open(jsonl_path, mode='a') as jsonl_writer:
        for i, text_chunk in enumerate(text_chunks):
            print(f"Processing chunk {i+1}/{len(text_chunks)}...")
            nodes, relationships = extract_nodes_relationships_with_retry(
                llm, text_chunk, jsonl_writer=jsonl_writer
            )

    print("JSONL generation complete.")
    return True

def process_directory(directory_path: str, llm, jsonl_path: str = "training_data.jsonl",
                     file_extensions: Optional[List[str]] = None) -> List[str]:
    """
    Process all files with specified extensions in the directory and create JSONL file.
    
    Args:
        directory_path: Path to directory containing files
        llm: Language model instance
        jsonl_path: Path to output JSONL file
        file_extensions: List of file extensions to process (e.g., ['.rst', '.txt', '.md'])
                        If None, processes all files
    """
    processed_files = []
    directory_path = os.path.abspath(directory_path)
    
    if not os.path.exists(directory_path):
        raise ValueError(f"Directory does not exist: {directory_path}")

    # Get all files with specified extensions
    if file_extensions:
        files = []
        for ext in file_extensions:
            files.extend(glob.glob(os.path.join(directory_path, f"*{ext}")))
    else:
        files = [os.path.join(directory_path, f) for f in os.listdir(directory_path)
                if os.path.isfile(os.path.join(directory_path, f))]

    print(f"Found {len(files)} files in {directory_path}")

    for fpath in files:
        try:
            print(f"Processing {fpath}...")
            success = process_file(fpath, llm, jsonl_path)
            if success:
                processed_files.append(fpath)
        except Exception as e:
            print(f"Error processing {fpath}: {str(e)}")
            logging.error(f"Error processing {fpath}: {str(e)}", exc_info=True)

    print(f"Completed processing. Total files processed: {len(processed_files)}")
    return processed_files

# Usage example:
process_directory(
    "./docs5",
    llm_extraction,
    jsonl_path="training_data.jsonl",
    file_extensions=['.rst', '.txt', '.md']
)


Found 5 files in /Users/trevormcgirr/Desktop/operator/graphRAG/docs5
Processing /Users/trevormcgirr/Desktop/operator/graphRAG/docs5/ac_throttlemid.rst...
Processing file: /Users/trevormcgirr/Desktop/operator/graphRAG/docs5/ac_throttlemid.rst
Processing chunk 1/1...
<prompt>
  <role>You are an expert in creating knowledge graphs. Respond with a valid JSON object only.</role>
  <instructions>Do not include any extra text outside of the JSON. The JSON must have 'Nodes' and 'Relationships'.</instructions>
  <validEntities>
    <entity>3DR Power Module</entity>
    <entity>APM</entity>
    <entity>APMPlanner</entity>
    <entity>ATC_THR_MIX_MAN</entity>
    <entity>Aircraft</entity>
    <entity>Alt Hold Mode</entity>
    <entity>Altitude</entity>
    <entity>ArduPilot</entity>
    <entity>Auto Trim</entity>
    <entity>Autotune</entity>
    <entity>Axis</entity>
    <entity>Barometer</entity>
    <entity>Battery</entity>
    <entity>Battery Voltage Monitor</entity>
    <entity>CTUN.ThO</ent

['/Users/trevormcgirr/Desktop/operator/graphRAG/docs5/ac_throttlemid.rst',
 '/Users/trevormcgirr/Desktop/operator/graphRAG/docs5/ac_rollpitchtuning.rst',
 '/Users/trevormcgirr/Desktop/operator/graphRAG/docs5/ac_tipsfornewpilots.rst',
 '/Users/trevormcgirr/Desktop/operator/graphRAG/docs5/ac2_guidedmode.rst',
 '/Users/trevormcgirr/Desktop/operator/graphRAG/docs5/ac2_followme.rst']

# Fine tuning the model

In [None]:
# Fine tune the model
from openai import OpenAI
import datetime
client = OpenAI()

base_model_name = "gpt-4o-mini-2024-07-18"
fine_tuned_model_name = f"{base_model_name}-finetune-{datetime.now().strftime('%Y-%m-%d')}"


# Upload the training file
file = client.files.create(
    file=open("training_data.jsonl", "rb"),
    purpose="fine-tune"
)

# Create fine-tuning job with W&B integration
client.fine_tuning.jobs.create(
    training_file=file.id,
    model=base_model_name, 
    integrations=[{
        "type": "wandb",
        "wandb": {
            "project": f"{fine_tuned_model_name}",  
            "tags": ["knowledge-graph", "neo4j"]  
        }
    }]
)

In [None]:
# check the status of the fine-tuning job

# client.fine_tuning.jobs.retrieve(id=fine_tuning_job.id)




In [259]:
# Update imports
from langchain_openai import ChatOpenAI
from langchain_community.graphs import Neo4jGraph
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
import os
from dotenv import load_dotenv

load_dotenv(override=True)

# Initialize Neo4j graph with the correct database
graph = Neo4jGraph(
    url=os.getenv("NEO4J_URI"),
    username=os.getenv("NEO4J_USERNAME"),
    password=os.getenv("NEO4J_PASSWORD"),
    database="graphrag"

)

# verify connection
graph.query("RETURN 1")




[{'1': 1}]

In [None]:
# llm_extraction = ChatOpenAI(model="ft:gpt-4o-mini-2024-07-18:personal::AcRnftp2", temperature=0)
# Might need to use the fine-tuned model name  that is generated by the fine-tuning job 
# llm_extraction_finetuned = ChatOpenAI(model=fine_tuned_model_name, temperature=0)
llm_extraction_finetuned = ChatOpenAI(model="gpt-4o", temperature=0)

# Verify model
# llm_extraction_finetuned.invoke([])


# directory to process
directory_path = "./docs5"

In [245]:
# ###################################
# Neo4j Insertion Functions
# ###################################
import re


# Embedding Model
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

def insert_source_node(graph, source_id, source_text, filename):
    """Insert a source node representing the text chunk."""
    graph.query(
        f"""
        MERGE (src:Source {{id: $source_id}})
        SET src.text = $source_text, src.filename = $filename
        """,
        {"source_id": source_id, "source_text": source_text, "filename": filename}
    )

def sanitize_label(label: str) -> str:
    """Sanitize node labels for Neo4j compatibility."""
    # Remove any special characters and replace with underscore
    sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', label)
    
    # Ensure label starts with a letter
    if not sanitized[0].isalpha():
        sanitized = 'N_' + sanitized
    
    # Remove consecutive underscores and trailing underscore
    sanitized = re.sub(r'_+', '_', sanitized.rstrip('_'))
    
    return sanitized


def insert_relationships_with_source(graph, relationships, source_id):
    """Insert relationships into Neo4j and link entities to the Source node."""
    for rel in relationships:
        # Add the main relationship
        graph.query(
            f"""
            MATCH (a {{id: $from_id}}), (b {{id: $to_id}})
            MERGE (a)-[r:{rel['type']}]->(b)
            SET r += $properties
            """,
            {
                "from_id": rel["from_id"],
                "to_id": rel["to_id"],
                "properties": rel.get("properties", {}),
            },
        )

        # Link both nodes to the source
        for node_id in [rel["from_id"], rel["to_id"]]:
            graph.query(
                f"""
                MATCH (n {{id: $node_id}}), (src:Source {{id: $source_id}})
                MERGE (n)-[:EXTRACTED_FROM]->(src)
                """,
                {"node_id": node_id, "source_id": source_id},
            )

def process_file(filepath: str, llm, graph: Neo4jGraph, embedding_model, jsonl_path: Optional[str] = None) -> bool:
    """Process file with embedding generation."""
    print(f"Processing file: {filepath}")
    
    text_elements = extract_text_from_file(filepath)
    if not text_elements:
        print(f"No text content extracted from {filepath}")
        return False

    text_chunks = chunk_text(text_elements)
    jsonl_writer = jsonlines.open(jsonl_path, mode='a') if jsonl_path else None

    try:
        for i, text_chunk in enumerate(text_chunks):
            print(f"Processing chunk {i+1}/{len(text_chunks)}...")
            source_id = str(uuid4())
            filename = Path(filepath).name
            
            insert_source_node(graph, source_id, text_chunk, filename)
            nodes, relationships = extract_nodes_relationships_with_retry(
                llm, text_chunk, jsonl_writer=jsonl_writer
            )
            
            # Pass embedding_model to insert_nodes_with_source
            insert_nodes_with_source(graph, nodes, source_id, embedding_model)
            insert_relationships_with_source(graph, relationships, source_id)

    finally:
        if jsonl_writer:
            jsonl_writer.close()

    return True

def process_directory(
    directory_path: str, 
    llm, 
    graph: Neo4jGraph,
    embedding_model,
    jsonl_path: str = "training_data_finetuned.jsonl",
    file_extensions: Optional[List[str]] = None
) -> List[str]:
    """Process all files with specified extensions in the directory."""
    processed_files = []
    directory_path = os.path.abspath(directory_path)
    
    if not os.path.exists(directory_path):
        raise ValueError(f"Directory does not exist: {directory_path}")

    if file_extensions:
        files = []
        for ext in file_extensions:
            files.extend(glob.glob(os.path.join(directory_path, f"*{ext}")))
    else:
        files = [os.path.join(directory_path, f) for f in os.listdir(directory_path)
                if os.path.isfile(os.path.join(directory_path, f))]

    print(f"Found {len(files)} files in {directory_path}")

    for fpath in files:
        try:
            print(f"Processing {fpath}...")
            success = process_file(fpath, llm, graph, embedding_model, jsonl_path)
            if success:
                processed_files.append(fpath)
        except Exception as e:
            print(f"Error processing {fpath}: {str(e)}")
            logging.error(f"Error processing {fpath}: {str(e)}", exc_info=True)

    print(f"Completed processing. Total files processed: {len(processed_files)}")
    return processed_files
  
def create_vector_index(graph):
    """Create vector index for node embeddings."""
    try:
        graph.query(
            """
            CREATE VECTOR INDEX nodeEmbeddings
            FOR (n:Node)
            ON (n.embedding)
            OPTIONS {
                indexConfig: {
                    `vector.dimensions`: 1536,
                    `vector.similarity_function`: 'cosine'
                }
            }
            """
        )
    except Exception as e:
        print(f"Error creating vector index: {str(e)}")
        # Try alternative syntax for older Neo4j versions
        try:
            graph.query(
                """
                CALL db.index.vector.createNodeIndex(
                    'nodeEmbeddings',
                    'Node',
                    'embedding',
                    1536,
                    'cosine'
                )
                """
            )
        except Exception as nested_e:
            print(f"Error with alternative syntax: {str(nested_e)}")

def insert_nodes_with_source(graph, nodes, source_id, embedding_model):
    """Insert nodes and link them to a Source node, including embeddings."""
    for node in nodes:
        try:
            # Sanitize the label
            sanitized_label = sanitize_label(node['label'])
            
            # Generate embedding from node properties or label
            node_text = (
                node["properties"].get("name", "") or 
                node["properties"].get("description", "") or 
                node["label"]
            )
            
            # Add display properties
            node["properties"]["name"] = node["label"]  # Use original label as name
            node["properties"]["displayName"] = node["label"]  # For visualization
            
            # Generate and store embedding
            if node_text:
                try:
                    embedding = embedding_model.embed_query(node_text)
                    if embedding:
                        node["properties"]["embedding"] = embedding
                except Exception as e:
                    print(f"Error generating embedding: {str(e)}")

            # Store original label
            node["properties"]["original_label"] = node["label"]

            # Insert node with sanitized label and Node base label
            graph.query(
                f"""
                MERGE (n:Node:{sanitized_label} {{id: $id}})
                SET n += $properties
                SET n.displayName = $displayName
                MERGE (src:Source {{id: $source_id}})
                MERGE (n)-[:EXTRACTED_FROM]->(src)
                """,
                {
                    "id": node["id"], 
                    "properties": node["properties"], 
                    "source_id": source_id,
                    "displayName": node["label"]  # Original, unsanitized label
                },
            )
            
        except Exception as e:
            print(f"Error processing node {node['id']} with label {node['label']}: {str(e)}")
            continue
        
# After processing, verify embeddings
def verify_embeddings(graph):
    result = graph.query(
        """
        MATCH (n:Node)
        WHERE n.embedding IS NOT NULL
        RETURN count(n) as nodes_with_embeddings
        """
    )
    print(f"\nNodes with embeddings: {result[0]['nodes_with_embeddings']}")
    
    result = graph.query(
        """
        MATCH (n:Node)
        WHERE n.embedding IS NULL
        RETURN count(n) as nodes_without_embeddings
        """
    )
    print(f"Nodes without embeddings: {result[0]['nodes_without_embeddings']}")

# Run processing
graph.query("MATCH (n) DETACH DELETE n")

process_directory(
    directory_path=directory_path,
    llm=llm_extraction_finetuned,
    graph=graph,
    embedding_model=embedding_model,
    jsonl_path="training_data_finetuned.jsonl",
    file_extensions=['.rst', '.txt', '.md']
)

create_vector_index(graph)
verify_embeddings(graph)

Found 5 files in /Users/trevormcgirr/Desktop/operator/graphRAG/docs5
Processing /Users/trevormcgirr/Desktop/operator/graphRAG/docs5/ac_throttlemid.rst...
Processing file: /Users/trevormcgirr/Desktop/operator/graphRAG/docs5/ac_throttlemid.rst
Processing chunk 1/1...
<prompt>
  <role>You are an expert in creating knowledge graphs. Respond with a valid JSON object only.</role>
  <instructions>Do not include any extra text outside of the JSON. The JSON must have 'Nodes' and 'Relationships'.</instructions>
  <validEntities>
    <entity>3DR Power Module</entity>
    <entity>APM</entity>
    <entity>APMPlanner</entity>
    <entity>ATC_THR_MIX_MAN</entity>
    <entity>Aircraft</entity>
    <entity>Alt Hold Mode</entity>
    <entity>Altitude</entity>
    <entity>ArduPilot</entity>
    <entity>Auto Trim</entity>
    <entity>Autotune</entity>
    <entity>Axis</entity>
    <entity>Barometer</entity>
    <entity>Battery</entity>
    <entity>Battery Voltage Monitor</entity>
    <entity>CTUN.ThO</ent

# Test the Graph with GraphCypherQAChain

In [239]:
# Add these imports at the top
from langchain_community.graphs import Neo4jGraph

# from langchain.chains import GraphCypherQAChain
from langchain.graphs import Neo4jGraph
from langchain.llms import OpenAI
from neo4j import GraphDatabase

# Initialize QA Chain for Neo4j
llm_chat = ChatOpenAI(model="gpt-4o", temperature=0)
chain = GraphCypherQAChain.from_llm(
    llm=llm_chat,
    graph=graph,
    verbose=True,
    allow_dangerous_requests=True,
)


In [240]:
from langchain.prompts import PromptTemplate
from langchain.chains import GraphCypherQAChain
import json

# Define the Cypher prompt template
cypher_prompt = PromptTemplate(
    template="""
Given the following schema for entities and relationships:

Entities: {entities}
Relationships: {relationships}
Validation Schema: {validation_schema}

Generate a Cypher query to answer this question: {query}

Requirements:
- Match relevant nodes and relationships based on the schema
- Include source text for context in the query
- Handle case-insensitive matching for string comparisons
- Return a limited number of results (e.g., LIMIT 10)

Example structure for querying based on relationships:
MATCH (node1:FlightMode)-[rel:REQUIRES]->(node2:Parameter)
WHERE toLower(node1.name) CONTAINS toLower('Loiter')
RETURN DISTINCT
    node1.name as node_name,
    node1.source_text as context,
    collect(DISTINCT {{
        related_node: node2.name,
        relationship_type: type(rel),
        details: properties(rel)
    }}) as relationships
LIMIT 10

Your task is to generate a Cypher query that adheres to the schema and satisfies the question.
""",
    input_variables=["entities", "relationships", "validation_schema", "query"]
)

# Define the user's query
query = "List all parameters that affect the flight mode 'Loiter' and provide their relationships."

# Define the chain
chain = GraphCypherQAChain.from_llm(
    llm=llm_chat,
    graph=graph,
    verbose=True,
    cypher_prompt=cypher_prompt,
    top_k=10,  # Max number of results to retrieve
    allow_dangerous_requests=True,
    return_direct=True
)

# Test the chain with the schema-aware prompt
response = chain.invoke({
    "query": query,
    "entities": VALID_ENTITIES,
    "relationships": VALID_RELATIONS,
    "validation_schema": VALIDATION_SCHEMA
})

# Extract and print the generated query
print("Generated Cypher Query:")
print(response['query'])

# Extract and print the result from Neo4j
print("Result:")
print(response['result'])

# Use the result to craft a user-friendly answer
answer_prompt = f"""
The user asked the following question:
{query}

Based on the data retrieved, provide a detailed and structured answer.
Here is the result from the database:
{response['result']}
"""

# Get the final answer from the LLM
answer = llm_chat.invoke(answer_prompt)

# Output the result in JSON format
result_json = {
    "query": query,
    "source": response['result'],
    "answer": answer.content  # Extract the content attribute for the answer
}

# Print the formatted result
print(json.dumps(result_json, indent=4))

# Print the answer directly
print("Final Answer:")
print(answer.content)




[1m> Entering new GraphCypherQAChain chain...[0m




Generated Cypher:
[32;1m[1;3mcypher
MATCH (loiterMode:FlightMode)-[rel]->(parameter:Parameter)
WHERE toLower(loiterMode.name) CONTAINS toLower('Loiter')
AND type(rel) IN ['SETS', 'ADJUSTS', 'REQUIRES', 'AFFECTS']
RETURN DISTINCT
    loiterMode.name as node_name,
    loiterMode.source_text as context,
    collect(DISTINCT {
        related_node: parameter.name,
        relationship_type: type(rel),
        details: properties(rel)
    }) as relationships
LIMIT 10
[0m

[1m> Finished chain.[0m
Generated Cypher Query:
List all parameters that affect the flight mode 'Loiter' and provide their relationships.
Result:
[]
{
    "query": "List all parameters that affect the flight mode 'Loiter' and provide their relationships.",
    "source": [],
    "answer": "In the context of drone flight control, particularly with systems like ArduPilot, the 'Loiter' mode is a flight mode that allows the drone to maintain a stable position and altitude, effectively hovering in place. Several parameters 

In [247]:
!pip3 install python-Levenshtein

Collecting python-Levenshtein
  Downloading python_Levenshtein-0.26.1-py3-none-any.whl.metadata (3.7 kB)
Collecting Levenshtein==0.26.1 (from python-Levenshtein)
  Downloading levenshtein-0.26.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (3.2 kB)
Downloading python_Levenshtein-0.26.1-py3-none-any.whl (9.4 kB)
Downloading levenshtein-0.26.1-cp311-cp311-macosx_11_0_arm64.whl (157 kB)
Installing collected packages: Levenshtein, python-Levenshtein
Successfully installed Levenshtein-0.26.1 python-Levenshtein-0.26.1


In [308]:
import os
import logging
import asyncio
from typing import Dict, List, Tuple, Optional, Any
from langchain_neo4j import Neo4jGraph
from langchain_openai import OpenAIEmbeddings, OpenAI
from Levenshtein import distance as levenshtein_distance
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain


# Initialize logging
logger = logging.getLogger('graph_rag_test')
logger.setLevel(logging.DEBUG)
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())
     
     

VALID_ENTITIES = [
        "3DR Power Module",
        "APM",
        "APMPlanner",
        "ATC_THR_MIX_MAN",
        "Aircraft",
        "Alt Hold Mode",
        "Altitude",
        "ArduPilot",
        "Auto Trim",
        "Autotune",
        "Axis",
        "Barometer",
        "Battery",
        "Battery Voltage Monitor",
        "CTUN.ThO",
        "CompanionComputers",
        "Copter",
        "D_Term",
        "Dataflash Log",
        "DroidPlanner",
        "ESC (Electronic Speed Controller)",
        "Flight Mode",
        "FlightMode",
        "FollowMeMode",
        "GCS",
        "GPSDevice",
        "GUID_OPTIONS",
        "GUID_TIMEOUT",
        "GroundStation",
        "GroundStationApplication",
        "GuidedMode",
        "Guided_NoGPS",
        "Gyroscope",
        "Home Position",
        "Hover Throttle",
        "I_Term",
        "LUA Scripts",
        "Laptop",
        "LiPo Battery",
        "Loiter Mode",
        "LoiterMode",
        "MAVLink",
        "MOT_HOVER_LEARN",
        "MOT_THST_HOVER",
        "ManualTuning",
        "Mission Planner",
        "MissionPlanner",
        "Motor",
        "Oscillation",
        "P-PCB",
        "P_Term",
        "Parameter",
        "Phone",
        "Pilot",
        "Propeller",
        "RC Transmitter",
        "Simple Mode",
        "SpeedControl",
        "Stabilize Mode",
        "Tablet",
        "TelemetryRadio",
        "TransmitterTuning",
        "Video",
    "Wind"
]

VALID_RELATIONS = [
    "ADJUSTS",
        "AFFECTS",
        "ARMED_AT",
        "CALCULATES",
        "CALIBRATES",
        "COMMANDS",
        "COMMUNICATES_WITH",
        "CONNECTS_TO",
        "CONTROLS",
        "DISABLES",
        "DOWNLOADS",
        "ENABLES",
        "ENSURES",
        "FOLLOWS",
        "FUNCTIONS_AS",
        "GENERATES_LIFT",
        "GUIDES",
        "INSTALLS",
        "INTERACTS_WITH",
        "IS_INDIVIDUAL",
        "IS_STABLE",
        "IS_SYMMETRICAL",
        "IS_UNSTABLE",
        "LEARNS",
        "MONITORS",
        "MOVES_TOWARDS",
        "OBSERVES",
        "PROVIDES",
        "RECEIVES",
        "RECOMMENDS",
        "REQUIRES",
        "SENDS",
        "SETS",
        "SUPPORTS",
    "USES"
]

VALIDATION_SCHEMA = {
    "Copter": [
            "ARMED_AT",
            "CALIBRATES",
            "COMMANDS",
            "COMMUNICATES_WITH",
            "CONNECTS_TO",
            "CONTROLS",
            "FOLLOWS",
            "GENERATES_LIFT",
            "LEARNS",
            "MOVES_TOWARDS",
            "RECEIVES",
            "USES"
        ],
        "Hover Throttle": [
            "LEARNS",
            "MOVES_TOWARDS"
        ],
        "MOT_THST_HOVER": [
            "SETS"
        ],
        "MOT_HOVER_LEARN": [
            "DISABLES"
        ],
        "Dataflash Log": [
            "DOWNLOADS"
        ],
        "CTUN.ThO": [
            "OBSERVES"
        ],
        "Flight Mode": [
            "LEARNS"
        ],
        "ManualTuning": [
            "ADJUSTS",
            "ENSURES",
            "REQUIRES",
            "USES"
        ],
        "Autotune": [
            "PROVIDES"
        ],
        "Aircraft": [
            "IS_INDIVIDUAL",
            "IS_SYMMETRICAL"
        ],
        "Pilot": [
            "CONTROLS",
            "ENSURES"
        ],
        "ATC_THR_MIX_MAN": [
            "SETS"
        ],
        "Oscillation": [
            "IS_UNSTABLE",
            "OBSERVES"
        ],
        "Axis": [
            "OBSERVES"
        ],
        "P_Term": [
            "ADJUSTS"
        ],
        "D_Term": [
            "ADJUSTS"
        ],
        "I_Term": [
            "ADJUSTS"
        ],
        "GCS": [
            "USES"
        ],
        "TransmitterTuning": [
            "SETS",
            "USES"
        ],
        "Parameter": [
            "ADJUSTS",
            "SETS"
        ],
        "Video": [
            "PROVIDES"
        ],
        "Battery": [
            "CONNECTS_TO",
            "MONITORS"
        ],
        "Gyroscope": [
            "CALIBRATES"
        ],
        "RC Transmitter": [
            "CONTROLS"
        ],
        "Motor": [
            "GENERATES_LIFT"
        ],
        "Propeller": [
            "GENERATES_LIFT"
        ],
        "ESC (Electronic Speed Controller)": [
            "CONTROLS"
        ],
        "Mission Planner": [
            "CALIBRATES"
        ],
        "Stabilize Mode": [
            "RECOMMENDS"
        ],
        "Alt Hold Mode": [
            "RECOMMENDS"
        ],
        "Loiter Mode": [
            "RECOMMENDS"
        ],
        "Simple Mode": [
            "RECOMMENDS"
        ],
        "Home Position": [
            "ARMED_AT"
        ],
        "Wind": [
            "AFFECTS"
        ],
        "Auto Trim": [
            "AFFECTS"
        ],
        "ArduPilot": [
            "SUPPORTS"
        ],
        "Battery Voltage Monitor": [
            "INSTALLS",
            "MONITORS"
        ],
        "3DR Power Module": [
            "SUPPORTS"
        ],
        "LiPo Battery": [
            "CONNECTS_TO"
        ],
        "P-PCB": [
            "CONNECTS_TO"
        ],
        "APM": [
            "FUNCTIONS_AS"
        ],
        "GuidedMode": [
            "COMMANDS",
            "ENABLES",
            "GUIDES",
            "USES"
        ],
        "TelemetryRadio": [
            "COMMUNICATES_WITH",
            "CONNECTS_TO"
        ],
        "GroundStationApplication": [
            "INTERACTS_WITH",
            "USES"
        ],
        "MissionPlanner": [
            "INTERACTS_WITH",
            "SENDS",
            "USES"
        ],
        "MAVLink": [
            "COMMANDS",
            "USES"
        ],
        "LUA Scripts": [
            "COMMANDS"
        ],
        "CompanionComputers": [
            "COMMANDS"
        ],
        "FlightMode": [
            "ENABLES"
        ],
        "SpeedControl": [
            "ADJUSTS"
        ],
        "GUID_OPTIONS": [
            "ENABLES"
        ],
        "GUID_TIMEOUT": [
            "REQUIRES"
        ],
        "Guided_NoGPS": [
            "COMMANDS",
            "USES"
        ],
        "GroundStation": [
            "CONTROLS",
            "SENDS"
        ],
        "APMPlanner": [
            "USES"
        ],
        "DroidPlanner": [
            "USES"
        ],
        "Laptop": [
            "CONNECTS_TO"
        ],
        "Phone": [
            "CONNECTS_TO"
        ],
        "Tablet": [
            "CONNECTS_TO"
        ],
        "GPSDevice": [
            "CONNECTS_TO",
            "SENDS"
        ],
        "LoiterMode": [
            "SETS"
        ],
        "FollowMeMode": [
            "SETS"
        ],
        "Altitude": [
            "CALCULATES"
        ],
        "Barometer": [
            "USES"
        ]
    }     



class GraphQueryTester:
    def __init__(self, entities: List[str], relationships: List[str], validation_schema: dict):
        logger.info("Initializing GraphQueryTester")
        try:
            self.graph = Neo4jGraph(
                url=os.getenv("NEO4J_URI"),
                username=os.getenv("NEO4J_USERNAME"),
                password=os.getenv("NEO4J_PASSWORD"),
                database=os.getenv("NEO4J_DATABASE")
            )
            logger.info("Connected to Neo4j")

            # Attempt to create indexes on initialization. 
            # If they already exist, 'IF NOT EXISTS' will ensure no error.
            self.create_indexes()
            self.verify_database_structure()

        except Exception as e:
            logger.error(f"Error connecting to Neo4j: {str(e)}", exc_info=True)
            raise

        self.embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
        self.llm = ChatOpenAI(model="gpt-4o", temperature=0)

        self.entities = entities
        self.relationships = relationships
        self.validation_schema = validation_schema
        self.template_multiple_general = """
        Given these exact node names and relationships:
        Node: "Stabilize Mode" (not StabilizeMode)
        Relationships: {relationships}
        
        The user wants to know: "{user_query}"
        
        {fallback_context}
        
        Generate 5 Cypher queries to find information about Stabilize Mode issues and fixes.
        Return as a raw JSON array of strings.
        
        Requirements:
        - Use exact node name 'Stabilize Mode' with quotes
        - Use relationships from the provided list
        - Return exactly 5 queries
        - No markdown, no code blocks
        
        Example format:
        [
            "MATCH (n {{name: 'Stabilize Mode'}})-[r:REQUIRES]->(m) RETURN n, r, m",
            "MATCH (n {{name: 'Stabilize Mode'}})<-[r:USES]-(m) RETURN n, r, m"
        ]
        """

    def create_indexes(self):
        """Create necessary indexes for the graph database"""
        try:
            # First check if vector index exists
            check_index = """
            SHOW INDEXES 
            WHERE type = 'VECTOR' 
            AND name = 'nodeEmbeddings'
            """
            
            # Create vector index if it doesn't exist
            vector_index_cypher = """
            CREATE VECTOR INDEX nodeEmbeddings IF NOT EXISTS
            FOR (n:Node) ON (n.embedding)
            OPTIONS {
            indexConfig: {
                `vector.dimensions`: 1536,
                `vector.similarity_function`: 'cosine'
            }
            }
            """

            # Create name index
            name_index_cypher = """
            CREATE INDEX node_name_idx IF NOT EXISTS 
            FOR (n:Node) ON (n.name)
            """

            # Execute in order
            existing_indexes = self.graph.query(check_index)
            if not existing_indexes:
                self.graph.query(vector_index_cypher)
                logger.info("Created vector index")
            
            self.graph.query(name_index_cypher)
            logger.info("Successfully created/verified indexes")

        except Exception as e:
            logger.error(f"Error creating indexes: {str(e)}", exc_info=True)
            # Continue even if index creation fails
            pass
        
    def execute_cypher_query(self, cypher_query: str, params: dict = None) -> List[Dict]:
        try:
            logger.debug(f"Executing Cypher query:\n{cypher_query}")
            cypher_query = self._fix_relationship_types(cypher_query)
            results = list(self.graph.query(cypher_query, params))
            return self._sanitize_results(results)
        except Exception as e:
            logger.error(f"Error executing Cypher query: {str(e)}", exc_info=True)
            return []

    def _fix_relationship_types(self, query: str) -> str:
        pattern = r":([A-Za-z0-9_]+)\|:([A-Za-z0-9_]+)"
        while re.search(pattern, query):
            query = re.sub(pattern, r":\1|\2", query)
        return query

    def semantic_search(self, query: str, top_n: int = 5) -> List[Dict]:
        """Perform semantic search using embeddings"""
        try:
            query_embedding = self.embedding_model.embed_query(query)
            logger.info(f"Generated query embedding (showing first 5 dimensions): {query_embedding[:5]}...")

            # Vector search query that works with any labeled node
            vector_query = """
            CALL {
                CALL db.index.vector.queryNodes('nodeEmbeddings', $k, $embedding)
                YIELD node, score
                RETURN node, score
            }
            RETURN 
                node.id AS id,
                node.name AS name,
                labels(node)[0] AS label,
                score
            ORDER BY score DESC
            LIMIT $k
            """
            
            params = {
                "k": top_n,
                "embedding": query_embedding
            }
            
            logger.debug(f"Executing vector query with embedding: {query_embedding}")
            results = self.execute_cypher_query(vector_query, params)
            logger.debug(f"Semantic search results: {results}")
            
            return results

        except Exception as e:
            logger.error(f"Error in semantic search: {str(e)}", exc_info=True)
            # Fallback to text-based search
            fallback_query = f"""
            MATCH (n)
            WHERE n.name STARTS WITH '{query[:3]}' 
            OR n.name ENDS WITH '{query[-3:]}' 
            OR n.name CONTAINS '{query}'
            RETURN 
                n.id as id,
                n.name as name
            LIMIT {top_n}
            """
            return self.execute_cypher_query(fallback_query)

    def label_search(self, label: str, value: str) -> List[Dict]:
        # Uses the label/property index
        # Ensure that the node label and property align with how your data is modeled.
        query = f"""
        MATCH (n:{label} {{name: $value}})
        RETURN n
        """
        params = {"value": value}
        return self.execute_cypher_query(query, params)

    def find_similar_nodes(self, term: str, limit: int = 50) -> List[str]:
        cypher = f"""
        MATCH (n)
        WHERE n.name STARTS WITH '{term[:3]}' 
           OR n.name ENDS WITH '{term[-3:]}' 
           OR n.name CONTAINS '{term[:3]}'
        RETURN DISTINCT n.name AS name
        LIMIT {limit}
        """
        results = self.execute_cypher_query(cypher)
        candidate_names = [r["name"] for r in results if "name" in r]
        return candidate_names

    def get_best_match_for_term(self, term: str) -> Tuple[str, float]:
        semantic_results = self.semantic_search(term, top_n=5)
        if semantic_results:
            best_semantic = max(semantic_results, key=lambda x: x.get("score", 0))
            if best_semantic.get("score", 0) > 0.8:
                logger.info(f"Semantic best match for {term}: {best_semantic['name']} (score={best_semantic['score']})")
                return best_semantic["name"], best_semantic["score"]

        candidates = self.find_similar_nodes(term, limit=50)
        if not candidates:
            logger.info(f"No candidates found for {term} using wildcard search.")
            return "", 0.0

        best_candidate = None
        best_distance = float('inf')
        term_lower = term.lower()
        for c in candidates:
            d = levenshtein_distance(term_lower, c.lower())
            if d < best_distance:
                best_distance = d
                best_candidate = c

        max_length = max(len(term), len(best_candidate)) if best_candidate else len(term)
        similarity_score = 1 - (best_distance / max_length) if max_length > 0 else 0
        logger.info(f"Best fuzzy match for {term}: {best_candidate} with similarity {similarity_score:.2f}")
        return best_candidate, similarity_score

    def generate_multiple_cypher_queries(self, user_query: str, search_terms: List[str]) -> List[str]:
        try:
            if search_terms:
                fallback_context = f"Consider these related nodes: {', '.join(search_terms)}"
            else:
                fallback_context = "No related nodes found."

            chain = LLMChain(llm=self.llm, prompt=PromptTemplate(
                template=self.template_multiple_general,
                input_variables=["relationships", "user_query", "fallback_context"]
            ))

            # Generate queries using valid relationships
            response = chain.run(
                relationships=", ".join(self.relationships),
                user_query=user_query,
                fallback_context=fallback_context
            )

            # Clean and parse response
            response = response.strip()
            response = response.replace('```json', '').replace('```', '')
            response = response.strip('[]')
            response = f'[{response}]'

            queries = json.loads(response)
            
            # Validate and fix queries
            fixed_queries = []
            for q in queries:
                if isinstance(q, str):
                    # Fix relationship syntax
                    q = self._fix_relationship_types(q.strip())
                    fixed_queries.append(q)

            return fixed_queries[:5]

        except Exception as e:
            logger.error(f"Error generating queries: {str(e)}\nResponse was: {response}", exc_info=True)
            return []

    def run_multiple_queries_in_parallel(self, queries: List[str]) -> List[Dict]:
        combined_results = []
        for q in queries:
            res = self.execute_cypher_query(q)
            if res:
                combined_results.extend(res)
        return combined_results

    def _sanitize_value(self, value: Any) -> Any:
        if isinstance(value, dict):
            if 'embedding' in value:
                value = {k: v for k, v in value.items() if k != 'embedding'}
            clean_dict = {}
            for k, v in value.items():
                clean_dict[k] = self._sanitize_value(v)
            return clean_dict
        elif isinstance(value, list):
            return [self._sanitize_value(item) for item in value]
        else:
            return value

    def _sanitize_results(self, results: List[Dict]) -> List[Dict]:
        sanitized_results = []
        for result in results:
            clean_result = {}
            for key, value in result.items():
                clean_result[key] = self._sanitize_value(value)
            sanitized_results.append(clean_result)
        return sanitized_results
    
    def verify_database_structure(self):
        """Verify database structure and indexes"""
        try:
            # Check labels
            labels_query = "CALL db.labels()"
            labels = self.graph.query(labels_query)
            logger.info(f"Found labels: {labels}")
            
            # Check indexes
            indexes_query = "SHOW INDEXES"
            indexes = self.graph.query(indexes_query)
            logger.info(f"Found indexes: {indexes}")
            
            # Check a sample node with embedding
            sample_query = """
            MATCH (n:Node)
            RETURN n
            """
            sample = self.graph.query(sample_query)
            logger.info(f"Sample node with embedding: {sample}")
            
            return True
        except Exception as e:
            logger.error(f"Error verifying database structure: {str(e)}")
            return False


# Example usage:
entities = VALID_ENTITIES
relationships = VALID_RELATIONS
validation_schema = VALIDATION_SCHEMA

tester = GraphQueryTester(entities, relationships, validation_schema)
user_query = "Stabilize mode is not working, what can I do to fix it?"
key_term = "stabilize mode"
matched_name, score = tester.get_best_match_for_term(key_term)
search_terms = tester.find_similar_nodes(matched_name if matched_name else key_term, limit=10)
queries = tester.generate_multiple_cypher_queries(user_query, search_terms)
if queries:
    all_results = tester.run_multiple_queries_in_parallel(queries)
    print("Combined Results:", all_results)
else:
    print("No queries generated.")

Initializing GraphQueryTester
INFO:graph_rag_test:Initializing GraphQueryTester
Connected to Neo4j
INFO:graph_rag_test:Connected to Neo4j
Successfully created/verified indexes
INFO:graph_rag_test:Successfully created/verified indexes
Found labels: [{'label': 'Source'}, {'label': 'FlightMode'}, {'label': 'Parameter'}, {'label': 'VehicleType'}, {'label': 'Procedure'}, {'label': 'Guide'}, {'label': 'Hardware'}, {'label': 'Issue'}, {'label': 'Node'}]
INFO:graph_rag_test:Found labels: [{'label': 'Source'}, {'label': 'FlightMode'}, {'label': 'Parameter'}, {'label': 'VehicleType'}, {'label': 'Procedure'}, {'label': 'Guide'}, {'label': 'Hardware'}, {'label': 'Issue'}, {'label': 'Node'}]
Found indexes: [{'id': 0, 'name': 'index_343aff4e', 'state': 'ONLINE', 'populationPercent': 100.0, 'type': 'LOOKUP', 'entityType': 'NODE', 'labelsOrTypes': None, 'properties': None, 'indexProvider': 'token-lookup-1.0', 'owningConstraint': None, 'lastRead': neo4j.time.DateTime(2024, 12, 12, 9, 28, 13, 163000000,

Combined Results: [{'n': {'name': 'Stabilize Mode', 'id': '7'}, 'r': ({'name': 'Stabilize Mode', 'id': '7'}, 'REQUIRES', {'name': 'Return To Launch', 'id': '4'}), 'm': {'name': 'Return To Launch', 'id': '4'}}, {'n': {'name': 'Stabilize Mode', 'id': '7'}, 'r': ({'name': 'Stabilize Mode', 'id': '7'}, 'REQUIRES', {'name': "Mission Planner's Config/Tuning | Extended Tuning screen", 'id': '8'}), 'm': {'name': "Mission Planner's Config/Tuning | Extended Tuning screen", 'id': '8'}}, {'n': {'name': 'Stabilize Mode', 'id': '7'}, 'r': ({'name': 'Stabilize Mode', 'id': '7'}, 'REQUIRES', {'name': 'Copter', 'id': '2'}), 'm': {'name': 'Copter', 'id': '2'}}, {'n': {'name': 'Stabilize Mode', 'id': '7'}, 'r': ({'name': 'Stabilize Mode', 'id': '7'}, 'REQUIRES', {'name': 'Common Lua Scripts', 'id': '4'}), 'm': {'name': 'Common Lua Scripts', 'id': '4'}}, {'n': {'name': 'Stabilize Mode', 'id': '7'}, 'r': ({'name': 'Stabilize Mode', 'id': '7'}, 'REQUIRES', {'name': 'Active braking/Damped light', 'id': '2', 