# Create Graph-RAG with Neo4j

## Code to Natural Language

In [1]:
import os

def generate_repo_tree(repo_path, indent=""):
    """Render the directory tree under ``repo_path`` as a printable string.

    ``__pycache__`` and dot-prefixed directories are pruned from the walk and
    dot-prefixed files are omitted. Every entry sits on its own line, prefixed
    with one ``"│   "`` per nesting level followed by ``"├── "``; directory
    names carry a trailing ``/``. Entries appear in ``os.walk`` order.

    Args:
        repo_path: Directory to walk. A trailing path separator is accepted.
        indent: Text prepended to every emitted line (default: nothing).

    Returns:
        The rendered tree, each line terminated by a newline. The result is an
        empty string when ``os.walk`` yields nothing for ``repo_path``.
    """
    lines = []
    for root, dirs, files in os.walk(repo_path):
        # Prune in place so os.walk does not descend into skipped directories.
        dirs[:] = [d for d in dirs if d != "__pycache__" and not d.startswith(".")]
        visible_files = [f for f in files if not f.startswith(".")]

        # Depth is derived from the path relative to the walk root; this stays
        # correct when repo_path ends with a separator or its text happens to
        # reappear deeper in the path.
        relative_root = os.path.relpath(root, repo_path)
        if relative_root == os.curdir:
            level = 0
        else:
            level = relative_root.count(os.sep) + 1

        # normpath drops a trailing separator so the root keeps its name.
        dir_name = os.path.basename(os.path.normpath(root))
        dir_prefix = indent + "│   " * level + "├── "
        file_prefix = indent + "│   " * (level + 1) + "├── "

        lines.append(f"{dir_prefix}{dir_name}/\n")
        lines.extend(f"{file_prefix}{file_name}\n" for file_name in visible_files)

    return "".join(lines)

# Location of the cloned repository's source tree; point this at your own checkout.
repo_path = "./assignment-2-mcdonald-s/src"

# Render the tree once and keep it in a variable: later cells embed it in
# the prompts sent to the LLM. Printing gives a quick visual sanity check.
repo_tree_string = generate_repo_tree(repo_path)
print(repo_tree_string)

# The string can be fed to an LLM directly, e.g.:
# llm_input = f"Here is the repository structure:\n{repo_tree_string}"


├── src/
│   ├── deduplication/
│   │   ├── bloom_filter.py
│   │   ├── dedup.py
│   │   ├── LSH.py
│   │   ├── LSHForest.py
│   │   ├── LSHImproved.py
│   │   ├── __init__.py
│   │   ├── __main__.py
│   ├── utils/
│   │   ├── use_cases.py
│   │   ├── utils.py
│   │   ├── visualizations.py
│   │   ├── visualization_lsh.py



In [98]:
from langchain.chat_models import init_chat_model
from typing_extensions import Annotated, TypedDict
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.documents import Document
from langchain_experimental.graph_transformers import LLMGraphTransformer


# Chat model used by the description, refinement and graph-extraction cells.
llm = init_chat_model(
    "gpt-4o-mini",
    model_provider="openai",
)

# System prompt for the code-to-description step. It is passed verbatim to
# SystemMessage (not through a template), so the {placeholders} below are
# literal text shown to the model. Every few-shot example follows the naming
# rules stated in the prompt: module = dotted file path, class =
# <module>.<class>, method = <module>.<class>.<method>.
system_prompt = """
    You are an expert in analyzing Python code and generating structured natural language descriptions for graph-based querying in Cypher.
    Given a Python codebase, extract meaningful relationships between functions, classes, and imported modules.

    Only use the list of types provided below:
    - class : classes defined in a module
    - method : methods defined in a class
    - function : functions defined in a module
    - module : python scripts defined within the repository. Exclude .py when mentioning the module name.
    - package : packages imported that are not modules.
    Do not include information on variables, parameters, arguments.
    Python scripts must be modules and pre defined packages such as numpy and pandas must be packages

    When generating the structured natural language description, follow these rules:
    - Do not give explanations for the code logic or functionality.
    - Do not use adjectives and adverbs.
    - Only describe the code and do not give an overall summary.
    - Do not use ambiguous pronouns and use exact names in every description.
    - Explain each class, function separately and do not include explanations such as 'as mentioned before' or anything that refers to a previous explanation.
    - make each description sufficient for a standalone statement for one relationship in the graph.
    - Each class and function should be connected to the module where it was defined.
    - Each imported package should be connected to the function, method or class where it was used.
    - Always include an explanation on how the outermost class or method is connected to the module where it is defined.
    - If the outermost layer is an 'if __name__ == "__main__":' block, then the outermost layer is whatever is inside the block. Plus whatever is defined outside the block. Make sure to mention the connection between the module and the classes and functions.
    - When mentioning modules, take note of the current file path(relative repository) given in input, and change slashes to dots and remove the .py extension.
    - If a function or class is used in another function or class, make sure to mention the connection between them.

    Natural language should follow a similar format as below:
        {source.id} is a {source.type} with properties {source.properties} defined in {target.id} which is a {target.type}.
    Naming conventions:
    - When mentioning classes, always refer them as {relative_repository}.{module_name}.{class_name}
    - When mentioning methods, always refer to them as {relative_repository}.{module_name}.{class_name}.{method_name}
    - When mentioning functions, always refer them as {relative_repository}.{module_name}.{function_name}
    - If the file path is deduplication/LSH.py and there is a class LSH in it, the module is deduplication.LSH and the class is deduplication.LSH.LSH.

    Example:
    deduplication.LSH is a module that defines the class deduplication.LSH.lsh_base, which consists of the method deduplication.LSH.lsh_base.hash_function.
    deduplication.LSH.lsh_base is a class and inherits from the class utils.utils.BaseLSH.
    numpy is a package and is used in the method deduplication.LSH.lsh_base.hash_function.
    bitarray is a package and is used in the deduplication.bloom_filter.BloomFilter_KM_Opt.__init__ method of the class deduplication.bloom_filter.BloomFilter_KM_Opt.
    deduplication.LSH.lsh_base is a class defined in the module deduplication.LSH.

    If a module from our repository is imported in another module in our repository, refer to it as the entire path of the module.
    Example:
    code within deduplication\\__main__.py : from deduplication.LSHImproved import LSHImproved
    Natural Language Description:
    The Class deduplication.LSHImproved.LSHImproved is imported into the module deduplication.__main__.
    """



import os

# Running count of successfully described files; only consulted by the
# commented-out early-exit check at the bottom of the loop.
i = 0
# Maps each file path (relative to repo_path) to the LLM's description of it.
descriptions = {}

for current_dir, subdirs, filenames in os.walk(repo_path):
    # Prune in place so os.walk never descends into hidden or __pycache__ folders.
    subdirs[:] = [
        name for name in subdirs
        if name != "__pycache__" and not name.startswith(".")
    ]

    # Hidden files are skipped.
    for filename in (name for name in filenames if not name.startswith(".")):
        file_path = os.path.join(current_dir, filename)
        relative_path = os.path.relpath(file_path, repo_path)
        print(relative_path)

        try:
            with open(file_path, "r", encoding="utf-8") as source_file:
                lsh_code = source_file.read()

            # The model gets the whole tree, this file's relative path and its source.
            human_text = f'''
                    Tree:
                    {repo_tree_string}

                    Current File Path:
                    {relative_path}

                    Code:
                    {lsh_code}
                '''
            response = llm.invoke(
                [SystemMessage(system_prompt), HumanMessage(human_text)]
            )
            descriptions[relative_path] = response.content
            i += 1
        except Exception as e:
            # Best effort: report the failure and carry on with the next file.
            print(f"Error processing {file_path}: {e}")

        # Uncomment to stop after the first couple of files while debugging:
        # if i > 1:
        #     break


deduplication\bloom_filter.py
deduplication\dedup.py
deduplication\LSH.py
deduplication\LSHForest.py
deduplication\LSHImproved.py
deduplication\__init__.py
deduplication\__main__.py
utils\use_cases.py
utils\utils.py
utils\visualizations.py
utils\visualization_lsh.py


In [99]:
# Dump every generated description, separated by a rule, for manual review.
for path, description in descriptions.items():
    print(path, description, "-------------------", sep="\n")

deduplication\bloom_filter.py
deduplication.bloom_filter.BloomFilter is a class defined in the module deduplication.bloom_filter. 
deduplication.bloom_filter.BloomFilter.__init__ is a method defined in the class deduplication.bloom_filter.BloomFilter. 
math is a package and is used in the method deduplication.bloom_filter.BloomFilter.__init__. 
bitarray is a package and is used in the method deduplication.bloom_filter.BloomFilter.__init__.
mmh3 is a package and is used in the method deduplication.bloom_filter.BloomFilter.add.
ngrams is a package and is used in the method deduplication.bloom_filter.BloomFilter.add.
mmh3 is a package and is used in the method deduplication.bloom_filter.BloomFilter.query.
ngrams is a package and is used in the method deduplication.bloom_filter.BloomFilter.query.

deduplication.bloom_filter.BloomFilter_KM_Opt is a class defined in the module deduplication.bloom_filter. 
deduplication.bloom_filter.BloomFilter_KM_Opt.__init__ is a method defined in the class

In [100]:
# System prompt for the refinement pass: the model rewrites a description so
# that every node id carries the full dotted module path of the current file.
# Passed verbatim to SystemMessage, so the {placeholders} are literal text.
system_prompt = """
    You are an expert text editor. Your goal is to modify the text so that it is consistent with the given repository tree and file path.
    Keep in mind this natural language is meant to be used for graph-based querying in Cypher.
    Only output the modified text, do not give any explanations or anything else.
    The only change you have to make is to modify potential node ids, so they are consistent and can be connected by a graph.

    Rules:
    - When mentioning modules, take note of the current file path(relative repository) given in input, and change slashes to dots and remove the .py extension.
    - This means that the current file path or module name should be the beginning of all classes/methods/functions defined in the module.
    - When mentioning classes, always refer them as {module_location_name}.{class_name}
    - When mentioning methods, always refer to them as {module_location_name}.{class_name}.{method_name}
    - When mentioning functions, always refer them as {module_location_name}.{function_name}

    Examples:
    - If the relative path is deduplication/LSH.py and there is a class LSH in it, the module is deduplication.LSH and the class is deduplication.LSH.LSH.
    - Given the current file path deduplication\\dedup.py
        'deduplication.Baseline is a class defined in deduplication.dedup' is extremely incorrect.
        The correct description is 'deduplication.dedup.Baseline is a class defined in deduplication.dedup'
    - If the original text is bloom_filter.BloomFilter is a class, and the current file is deduplication\\bloom_filter.py, the modified text should be deduplication.bloom_filter.BloomFilter is a class.
    - If my current module imports a function from a module 'utils.utils import clean_document' the modified text should be utils.utils.clean_document is a function.
"""

# Second LLM pass: normalise the node ids in every description so they agree
# with the repository tree and the file each description was generated from.
# Maps relative file path -> refined description text.
refined_descriptions = {}
for relative_path, description in descriptions.items():
    messages = [
        SystemMessage(system_prompt),
        HumanMessage(f'''

            Repository Tree:
            {repo_tree_string}

            Current File Path:
            {relative_path}

            Natural Language Description:
            {description}
        ''')
    ]

    response = llm.invoke(messages)
    refined_descriptions[relative_path] = response.content

In [101]:
# Dump every refined description, separated by a rule, for manual review.
for path, refined_text in refined_descriptions.items():
    print(path, refined_text, "-------------------", sep="\n")

deduplication\bloom_filter.py
deduplication.bloom_filter.BloomFilter is a class defined in the module deduplication.bloom_filter. 
deduplication.bloom_filter.BloomFilter.__init__ is a method defined in the class deduplication.bloom_filter.BloomFilter. 
math is a package and is used in the method deduplication.bloom_filter.BloomFilter.__init__. 
bitarray is a package and is used in the method deduplication.bloom_filter.BloomFilter.__init__.
mmh3 is a package and is used in the method deduplication.bloom_filter.BloomFilter.add.
ngrams is a package and is used in the method deduplication.bloom_filter.BloomFilter.add.
mmh3 is a package and is used in the method deduplication.bloom_filter.BloomFilter.query.
ngrams is a package and is used in the method deduplication.bloom_filter.BloomFilter.query.

deduplication.bloom_filter.BloomFilter_KM_Opt is a class defined in the module deduplication.bloom_filter. 
deduplication.bloom_filter.BloomFilter_KM_Opt.__init__ is a method defined in the class

In [102]:
from langchain_core.documents import Document

# One Document per source file: the refined description is the content and the
# file's relative path is recorded under metadata["source"].
# (Swap refined_descriptions for descriptions to use the unrefined text instead.)
description_document = [
    Document(page_content=refined_text, metadata={"source": path})
    for path, refined_text in refined_descriptions.items()
]

In [103]:
# from langchain_text_splitters import RecursiveCharacterTextSplitter

# llm = init_chat_model("gpt-4o-mini", model_provider="openai")

# text_splitter = RecursiveCharacterTextSplitter(
#     # Set a really small chunk size, just to show.
#     chunk_size=1000,
#     # chunk_overlap=100,
#     length_function=len,
#     is_separator_regex=False,
# )

# text_splitter = RecursiveCharacterTextSplitter(
#     separators=[
#         r"\n\n",
#         r"\n",
#         r"\\n"
#     ],
#     is_separator_regex=True,
#     keep_separator=False,
#     chunk_size=500,
#     chunk_overlap=0,
# )

# docs = text_splitter.create_documents([description])

## Natural Language to GraphDB

In [104]:
from langchain_core.prompts import SystemMessagePromptTemplate, PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
# Knowledge-graph extraction instructions used as the system message of chat_prompt.
# The leading backslash keeps the text from starting with a blank line.
a = """\
# Knowledge Graph Instructions for GPT-4
## 1. Overview
You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.
Try to capture as much information from the text as possible without sacrificing accuracy. Do not add any information that is not explicitly mentioned in the text.
- **Nodes** represent entities and concepts.
- The aim is to achieve simplicity and clarity in the knowledge graph, making it
accessible for a vast audience.
## 2. Labeling Nodes
- **Consistency**: Ensure you use available types for node labels.
Ensure you use basic or elementary types for node labels.
- For example, when you identify an entity representing a person, always label it as **'person'**. Avoid using more specific terms like 'mathematician' or 'scientist'.- **Node IDs**: Never utilize integers as node IDs. Node IDs should be names or human-readable identifiers found in the text.
- **Relationships** represent connections between entities or concepts.
Ensure consistency and generality in relationship types when constructing knowledge graphs. Instead of using specific and momentary types such as 'BECAME_PROFESSOR', use more general and timeless relationship types like 'PROFESSOR'. Make sure to use general and timeless relationship types!
## 3. Coreference Resolution
- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.
If an entity, such as "John Doe", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., "Joe", "he"),always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the entity ID.
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
## 4. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination.
"""
# The extraction instructions double as the system message of the chat prompt.
system_prompt = a
prompt = ChatPromptTemplate.from_template(system_prompt)

# Human turn: tells the model that dots are part of node ids and that ids must
# never be shortened; {input} is the template variable for the text to analyse.
human_template = """
                A period does not mean an end of a sentence. It is part of a node id. Only a line break means an end of a sentence.
                Take this into account when identifying node ids.

                When translating names into node ids, do not shorten anything. use the entire name as it is.
                For example, if the class name is deduplication.LSHForest.LSHForest, do not shorten it to LSHForest or deduplication.LSHForest.
                Use the full name deduplication.LSHForest.LSHForest.
                Here is the text to analyze:\n\n{input}"""

chat_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", human_template),
    ]
)


from langchain_experimental.graph_transformers import LLMGraphTransformer


# Node labels the transformer may emit; they mirror the type list given to the
# description-generation prompt.
allowed_node_types = ["class", "method", "function", 'package', 'module']

llm_transformer = LLMGraphTransformer(
    llm=llm,
    allowed_nodes=allowed_node_types,
    # Optional settings, currently disabled:
    # allowed_relationships=["NATIONALITY", "LOCATED_IN", "WORKED_AT", "SPOUSE"],
    # node_properties=['defined_in'],
    # prompt=chat_prompt
)

# Unconstrained alternative: llm_transformer = LLMGraphTransformer(llm=llm)

# Turn every per-file description Document into a graph document (nodes + relationships).
graph_documents = llm_transformer.convert_to_graph_documents(description_document)


In [105]:
# Inspect the relationships extracted for the document at index 6 of graph_documents.
graph_documents[6].relationships

[Relationship(source=Node(id='Deduplication.__Main__', type='Module', properties={}), target=Node(id='Deduplication.__Main__.Log_Memory_Usage', type='Function', properties={}), type='DEFINES', properties={}),
 Relationship(source=Node(id='Deduplication.__Main__.Log_Memory_Usage', type='Function', properties={}), target=Node(id='Psutil', type='Package', properties={}), type='USES', properties={}),
 Relationship(source=Node(id='Deduplication.__Main__', type='Module', properties={}), target=Node(id='Deduplication.__Main__.Model', type='Function', properties={}), type='DEFINES', properties={}),
 Relationship(source=Node(id='Deduplication.__Main__.Model', type='Function', properties={}), target=Node(id='Deduplication.Lshimproved.Lsh', type='Class', properties={}), type='USES', properties={}),
 Relationship(source=Node(id='Deduplication.__Main__.Model', type='Function', properties={}), target=Node(id='Deduplication.Lshforest.Lshforest', type='Class', properties={}), type='USES', properties={

In [106]:
# Keys of refined_descriptions are OS-native relative paths (they originate from
# os.path.relpath), so the lookup key is built with os.path.join rather than a
# hard-coded Windows backslash, which would raise KeyError on POSIX systems.
print(refined_descriptions[os.path.join("deduplication", "LSHForest.py")])

deduplication.LSHForest.LSHForest is a class defined in the module deduplication.LSHForest. 
deduplication.LSHForest.LSHForest.__init__ is a method defined in the class deduplication.LSHForest.LSHForest.
deduplication.LSHForest.LSHForest.banding is a method defined in the class deduplication.LSHForest.LSHForest.
deduplication.LSH.LSH is a class and is inherited by the class deduplication.LSHForest.LSHForest.
utils.utils.split_dict is a function used in the method deduplication.LSHForest.LSHForest.banding.
utils.utils.majority_vote is a function used in the method deduplication.LSHForest.LSHForest.banding.
collections.defaultdict is a package and is used in the method deduplication.LSHForest.LSHForest.banding.
itertools is a package and is used in the method deduplication.LSHForest.LSHForest.banding.


## Create Visualization

In [107]:
from pyvis.network import Network
import networkx as nx
import matplotlib.pyplot as plt

# Pyvis network that is filled from the NetworkX graph and saved to HTML further down.
net = Network(
    height="1000px",
    width="100%",
    notebook=True,
    cdn_resources='in_line',
)

# Undirected NetworkX graph used to collect nodes and edges and to compute degrees.
G = nx.Graph()

# node key -> {"id", "type", "properties"} for every node added to G.
node_metadata = dict()

def make_node_key(node):
    """Return a hashable identity tuple for ``node``.

    The key is ``(node.type, node.id, properties)`` where ``properties`` is the
    node's property dict frozen into a tuple of ``(name, value)`` pairs sorted
    by name, so nodes differing in type, id or properties get distinct keys.
    """
    node_type, node_id = node.type, node.id
    frozen_properties = tuple(sorted(node.properties.items()))
    return node_type, node_id, frozen_properties

# Populate the NetworkX graph, leaving out any relationship that touches a package.
for graph_doc in graph_documents:
    for relationship in graph_doc.relationships:
        endpoints = (relationship.source, relationship.target)
        if any(endpoint.type == "Package" for endpoint in endpoints):
            continue

        # Full (type, id, properties) keys keep same-named nodes of different kinds apart.
        endpoint_keys = [make_node_key(endpoint) for endpoint in endpoints]

        # Remember the original attributes behind every key for rendering later.
        for key, endpoint in zip(endpoint_keys, endpoints):
            node_metadata[key] = {
                "id": endpoint.id,
                "type": endpoint.type,
                "properties": endpoint.properties,
            }

        for key in endpoint_keys:
            G.add_node(key)
        G.add_edge(endpoint_keys[0], endpoint_keys[1], label=relationship.type)

# Distinct node types, sorted so that each type keeps the same colour on every
# run. Iterating a bare set of strings follows hash order, which varies between
# interpreter sessions and would reshuffle the colours and the legend.
unique_types = sorted({meta["type"] for meta in node_metadata.values()})
color_map = plt.get_cmap("tab10")
# Spread the types across the colormap; each entry is an RGBA tuple of floats.
type_colors = {t: color_map(i / len(unique_types)) for i, t in enumerate(unique_types)}
# CSS colour strings with a fixed 0.8 alpha.
type_colors_rgba = {
    t: f'rgba({int(c[0]*255)}, {int(c[1]*255)}, {int(c[2]*255)}, 0.8)' for t, c in type_colors.items()
}

# Degree-based sizing: node size grows linearly with degree, up to max_size
# for the best-connected node.
degrees = dict(G.degree())
min_size, max_size = 10, 50
# Fall back to 1 when the graph is empty so the expression below never sees max() of nothing.
max_degree = max(degrees.values()) if degrees else 1
size_scale = {
    node: min_size + (max_size - min_size) * (deg / max_degree)
    for node, deg in degrees.items()
}

# Mirror every NetworkX node into Pyvis with its label, size, colour and tooltip.
for node_key in G:
    meta = node_metadata[node_key]
    node_label, node_type = meta["id"], meta["type"]
    node_props = meta.get("properties", {})

    # Tooltip body: one "name: value" line per property, or a placeholder.
    if node_props:
        props_html = "<br>".join(f"{k}: {v}" for k, v in node_props.items())
    else:
        props_html = "No properties"

    net.add_node(
        str(node_key),  # Pyvis node ids are passed as the stringified key
        label=node_label,
        size=size_scale[node_key],
        color=type_colors_rgba.get(node_type, "gray"),
        title=f"<b>{node_type}</b> ({node_label})<br>{props_html}",
    )

# Mirror every edge, using the relationship type as both label and tooltip.
for source, target, rel_label in G.edges(data="label", default=""):
    net.add_edge(str(source), str(target), title=rel_label, label=rel_label)

# Write the Pyvis page first; the legend is spliced into the saved file afterwards.
output_path = "graph_simple.html"
net.save_graph(output_path)

# Legend: an absolutely positioned box with one coloured dot per node type.
legend_header = """
<div id="legend" style="position: absolute; top: 10px; left: 10px; background: white; padding: 10px; border-radius: 8px; box-shadow: 0px 0px 5px rgba(0,0,0,0.2); font-family: Arial, sans-serif; z-index: 1000;">
    <h4 style="margin: 0; padding-bottom: 5px;">Node Legend</h4>
"""
legend_rows = [
    f'<div style="display: flex; align-items: center; margin-bottom: 5px;"><div style="width: 15px; height: 15px; background:{color}; margin-right: 5px; border-radius: 50%;"></div> {node_type}</div>'
    for node_type, color in type_colors_rgba.items()
]
legend_html = legend_header + "".join(legend_rows) + "</div>"

# Insert the legend just before the closing body tag of the saved page.
with open(output_path, "r", encoding="utf-8") as saved_page:
    html_content = saved_page.read()

html_content = html_content.replace("</body>", legend_html + "</body>")

with open(output_path, "w", encoding="utf-8") as saved_page:
    saved_page.write(html_content)

print("Graph with fully disambiguated nodes saved as graph_simple.html")


Graph with fully disambiguated nodes saved as graph_simple.html


In [3]:
# from dotenv import load_dotenv
# import os
# from neo4j import GraphDatabase

# load_dotenv()

# URI = os.getenv("NEO4J_URI")
# USER = os.getenv("NEO4J_USERNAME")
# PWD = os.getenv("NEO4J_PASSWORD")

# print("Trying:", URI)

# driver = GraphDatabase.driver(uri=URI, auth=(USER, PWD))

# try:
#     driver.verify_connectivity()
#     print("✅ Connected to Aura!")
# except Exception as e:
#     print("❌ Still not working:", e)
# finally:
#     driver.close()


# Create / Connect to Database

In [108]:
from langchain_neo4j import Neo4jGraph

import os
from dotenv import load_dotenv

# Read the connection settings from the .env file two directories up.
load_dotenv(dotenv_path='../../.env')

# Pull the three Neo4j settings from the environment (None when a variable is unset).
url, username, password = (
    os.getenv(variable)
    for variable in ('NEO4J_URI', 'NEO4J_USERNAME', 'NEO4J_PASSWORD')
)

graph = Neo4jGraph(url=url, username=username, password=password)



## Add graph documents to Neo4j Database

In [None]:
# Write the extracted nodes and relationships of every graph document to Neo4j.
graph.add_graph_documents(graph_documents)



# Maintenance snippets, kept disabled:
# the first deletes every node together with its relationships, the second is
# an example of a query call with a keyword parameter.
# graph.query("MATCH (n) DETACH DELETE n")
# graph.query(QUERY, genre="action")


In [110]:
# Refresh the schema held by the Neo4jGraph wrapper, then print it (node
# properties and relationship patterns) to check what was imported.
graph.refresh_schema()
print(graph.schema)

Node properties:
Class {id: STRING}
Method {id: STRING}
Package {id: STRING}
Module {id: STRING}
Function {id: STRING}
Relationship properties:

The relationships:
(:Class)-[:INCLUDES]->(:Method)
(:Class)-[:USES]->(:Package)
(:Class)-[:CONTAINS]->(:Method)
(:Class)-[:INHERITS]->(:Class)
(:Class)-[:DEFINES]->(:Method)
(:Class)-[:IMPORTS]->(:Package)
(:Class)-[:METHOD_OF]->(:Method)
(:Class)-[:DEFINED_IN]->(:Module)
(:Class)-[:DEFINED_IN]->(:Method)
(:Method)-[:USES]->(:Package)
(:Method)-[:USES]->(:Method)
(:Method)-[:USES]->(:Module)
(:Method)-[:USES]->(:Function)
(:Method)-[:USED_IN]->(:Function)
(:Method)-[:USED_IN]->(:Package)
(:Method)-[:CALLS]->(:Method)
(:Method)-[:CALLS]->(:Function)
(:Method)-[:READS]->(:Class)
(:Method)-[:PROCESSES]->(:Class)
(:Method)-[:INITIALIZES]->(:Class)
(:Method)-[:INITIALIZES]->(:Function)
(:Method)-[:DEFINED_IN]->(:Class)
(:Module)-[:CONTAINS]->(:Class)
(:Module)-[:USES]->(:Package)
(:Module)-[:DEFINES]->(:Function)
(:Module)-[:DEFINES]->(:Method)
(:M

In [111]:
# Second Neo4jGraph handle created with enhanced_schema=True; its printed schema
# additionally lists example / available values for each property.
# NOTE(review): no url or credentials are passed here — presumably they are picked
# up from the NEO4J_* environment variables loaded earlier; confirm.
enhanced_graph = Neo4jGraph(enhanced_schema=True)
print(enhanced_graph.schema)

Node properties:
- **Class**
  - `id`: STRING Example: "Deduplication.Bloom_Filter.Bloomfilter"
- **Method**
  - `id`: STRING Example: "Deduplication.Bloom_Filter.Bloomfilter.__Init__"
- **Package**
  - `id`: STRING Example: "Math"
- **Module**
  - `id`: STRING Available options: ['Deduplication.Lsh', 'Utils.Utils', 'Joblib', 'Deduplication.__Init__', 'Deduplication.__Main__', 'Utils.Use_Cases', 'Src.Utils.Utils', 'Utils.Visualizations', 'Utils.Visualization_Lsh']
- **Function**
  - `id`: STRING Example: "Utils.Utils.Split_Dict"
Relationship properties:

The relationships:
(:Class)-[:INCLUDES]->(:Method)
(:Class)-[:USES]->(:Package)
(:Class)-[:CONTAINS]->(:Method)
(:Class)-[:INHERITS]->(:Class)
(:Class)-[:DEFINES]->(:Method)
(:Class)-[:IMPORTS]->(:Package)
(:Class)-[:METHOD_OF]->(:Method)
(:Class)-[:DEFINED_IN]->(:Module)
(:Class)-[:DEFINED_IN]->(:Method)
(:Method)-[:USES]->(:Package)
(:Method)-[:USES]->(:Method)
(:Method)-[:USES]->(:Module)
(:Method)-[:USES]->(:Function)
(:Method)-[:U

# Implement GRAPH-RAG
https://python.langchain.com/v0.1/docs/use_cases/graph/prompting/

**Manual GraphRAG Chain**

https://python.langchain.com/docs/tutorials/graph/

In [147]:
from langchain_neo4j import GraphCypherQAChain
from langchain_openai import ChatOpenAI
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT


from langchain_core.prompts import SystemMessagePromptTemplate, PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate

# Custom Cypher-generation prompt. {schema} and {question} are template
# variables; the doubled braces in the example query render as literal braces.
# The module id in the example question matches the id in the example query.
cypher_prompt = """Task:Generate Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Translate user input into available nodes.
Do not use any other relationship types or properties that are not provided.

Schema:
{schema}
Note: Do not include any explanations or apologies in your responses.

Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.

Guidelines:
Always try and use MATCH path ..... RETURN path to get the entire relationship and do not generate queries that return pairs of nodes.
Never use relationships that include packages : WHERE NONE(n IN nodes(path) WHERE n:Package)

Example:
Question : "How is the Utils.Utils module related to the Deduplication.__Main__ module?"
Cypher Query : MATCH path = (m1:Module {{id: 'Utils.Utils'}})-[*..8]-(m2:Module {{id: 'Deduplication.__Main__'}}) WHERE NONE(n IN nodes(path) WHERE n:Package) RETURN path


The question is:
{question}
"""


# Template built from the custom Cypher-generation instructions. It only takes
# effect if the cypher_prompt argument below is enabled.
prompt = ChatPromptTemplate.from_template(cypher_prompt)


# Rebinds llm to a deterministic (temperature 0) gpt-4o model for the QA chain.
llm = ChatOpenAI(model="gpt-4o", temperature=0)

# Question-answering chain over the graph; it uses the enhanced-schema handle
# and prints intermediate steps because verbose is on.
chain = GraphCypherQAChain.from_llm(
    graph=enhanced_graph, 
    llm=llm, 
    verbose=True, 
    allow_dangerous_requests=True, 
    validate_cypher=True,
    # cypher_prompt=prompt
)
# The module ids in the question are spelled exactly as they are stored in the
# graph ('Utils.Utils', 'Deduplication.__Main__') so they can be matched to nodes.
response = chain.invoke({"query": "How is the Utils.Utils module related to the Deduplication.__Main__ module?"})
response



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3m[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


{'query': 'How is the Utils.Utils module related to the Deduplicatioin.__Main__ module?',
 'result': "I don't know the answer."}

In [130]:
# Hand-written Cypher: all shortest paths of at most 8 hops, in either direction,
# between the Utils.Utils and Deduplication.__Main__ modules that pass through
# no Package node.
shortest_paths_query = """
MATCH path = allShortestPaths(
  (m1:Module {id: 'Utils.Utils'})-[*..8]-(m2:Module {id: 'Deduplication.__Main__'})
)
WHERE NONE(n IN nodes(path) WHERE n:Package)
RETURN path
"""
results = graph.query(shortest_paths_query)

In [131]:
# Number of rows (one per path) returned by the shortest-path query.
len(results)

3

In [132]:
# Display the returned paths: alternating node dicts and relationship type names.
results

[{'path': [{'id': 'Utils.Utils'},
   'USES',
   {'id': 'Deduplication.Lsh.Lsh.Compute_Minhash_Signatures'},
   'CONTAINS',
   {'id': 'Deduplication.Lsh.Lsh'},
   'USES',
   {'id': 'Deduplication.__Main__.Model'},
   'DEFINES',
   {'id': 'Deduplication.__Main__'}]},
 {'path': [{'id': 'Utils.Utils'},
   'USES',
   {'id': 'Deduplication.Lsh.Lsh.Compute_Minhash_Signatures'},
   'CONTAINS',
   {'id': 'Deduplication.Lsh.Lsh'},
   'CALLS',
   {'id': 'Deduplication.__Main__.Model'},
   'DEFINES',
   {'id': 'Deduplication.__Main__'}]},
 {'path': [{'id': 'Utils.Utils'},
   'USES',
   {'id': 'Deduplication.Lsh.Lsh.Compute_Minhash_Signatures'},
   'CONTAINS',
   {'id': 'Deduplication.Lsh.Lsh'},
   'INHERITS',
   {'id': 'Deduplication.Lshforest.Lshforest'},
   'IMPORTS',
   {'id': 'Deduplication.__Main__'}]}]

In [124]:
# Show the text of the CYPHER_QA_PROMPT template imported from langchain.chains.graph_qa.prompts.
print(CYPHER_QA_PROMPT.template)

You are an assistant that helps to form nice and human understandable answers.
The information part contains the provided information that you must use to construct an answer.
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
Here is an example:

Question: Which managers own Neo4j stocks?
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.

Follow this example when generating answers.
If the provided information is empty, say that you don't know the answer.
Information:
{context}

Question: {question}
Helpful Answer:
