In [43]:
import os
import sys

try:
    get_ipython
    current_dir = os.getcwd()
except NameError:
    current_dir = os.path.dirname(os.path.abspath(__file__))

# Set path，temporary path expansion
project_dir = os.path.abspath(os.path.join(current_dir, "../../"))
if project_dir not in sys.path:
    sys.path.append(project_dir)


from neo4j import GraphDatabase
import numpy as np
from sentence_transformers import SentenceTransformer, util
from typing import List, Dict, Tuple, Optional
import os

USER ="neo4j"
PWD ="neo4j123"




#### 安装java
sudo apt update
sudo apt install openjdk-21-jdk


#### 解压
tar -xzf neo4j-community-2025.05.0-unix.tar.gz
cd neo4j-community-2025.05.0

#### 设置初始密码
./bin/neo4j-admin set-initial-password your_password

#### 启动服务
./bin/neo4j start


In [44]:
import neo4j
neo4j.__version__

'5.28.2'

In [45]:

driver = GraphDatabase.driver("bolt://localhost:7687", auth=(USER, PWD))
def test_connection():
    try:
        with driver.session() as session:
            result = session.run("RETURN 'Neo4j connection successful!' AS message")
            print(result.single()["message"])
    except Exception as e:
        print("连接失败：", e)
    finally:
        driver.close()
    
test_connection()

Neo4j connection successful!


In [46]:


class Neo4jGraphRAG:
    def __init__(self, uri: str, user: str, password: str, 
                 model_name: str = os.path.join(project_dir, "model/BAAI/bge-small-zh")):
        """
        初始化基于Neo4j的Graph RAG系统
        
        Args:
            uri: Neo4j数据库连接URI
            user: 数据库用户名
            password: 数据库密码
            model_name: 用于生成文本嵌入的模型名称
        """
        # 初始化Neo4j连接
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        
        # 初始化嵌入模型
        self.embedding_model = SentenceTransformer(model_name)
        
        # 确保嵌入向量的索引存在
        self._create_vector_index()
    
    def close(self):
        """关闭数据库连接"""
        if self.driver:
            self.driver.close()
    
    def _create_vector_index(self):
        """创建向量索引以加速相似性搜索"""
        with self.driver.session() as session:
            # 检查索引是否已存在
            result = session.run("""
                SHOW INDEXES WHERE name = 'node_embeddings_index'
            """)
            
            if  list(result):
                result = session.run("""
                    DROP INDEX `node_embeddings_index`
            """)
           
             
            # 创建向量索引 
            session.run("""
                    CREATE VECTOR INDEX `node_embeddings_index`
                    FOR (n:Entity) ON (n.embedding)
                    OPTIONS {
                    indexConfig: {
                        `vector.dimensions`: 512,
                        `vector.similarity_function`: 'cosine'
                    }
                    }

                """)


    
    def add_entity(self, entity_id: str, entity_type: str, properties: Dict, embedding: Optional[np.ndarray] = None):
        """
        向知识图谱添加实体
        
        Args:
            entity_id: 实体唯一标识符
            entity_type: 实体类型（标签）
            properties: 实体属性字典
            embedding: 实体的预计算嵌入（可选）
        """
        with self.driver.session() as session:
            # 如果没有提供嵌入，则从实体描述生成
            if embedding is None and "description" in properties:
                embedding = self.embedding_model.encode(properties["description"])
            
            # 将嵌入向量转换为列表以便存储
            embedding_list = embedding.tolist() if embedding is not None else None
            
            # 添加实体节点 - 修复了字符串格式化问题
            query = f"""
                MERGE (e:{entity_type} {{id: $entity_id}})
                SET e += $properties
                { 'SET e.embedding = $embedding' if embedding_list else '' }
            """
            session.run(query, {
                "entity_id": entity_id,
                "properties": properties,
                "embedding": embedding_list
            })
    
    def add_relationship(self, source_id: str, target_id: str, relationship_type: str, properties: Dict = None):
        """
        向知识图谱添加实体间的关系
        
        Args:
            source_id: 源实体ID
            target_id: 目标实体ID
            relationship_type: 关系类型
            properties: 关系属性（可选）
        """
        if properties is None:
            properties = {}
            
        with self.driver.session() as session:
            # 修复了关系创建的Cypher查询
            session.run("""
                MATCH (s) WHERE s.id = $source_id
                MATCH (t) WHERE t.id = $target_id
                MERGE (s)-[r:%s]->(t)
                SET r += $properties
            """ % relationship_type, {
                "source_id": source_id,
                "target_id": target_id,
                "properties": properties
            })
    
    def add_text_chunk(self, chunk_id: str, text: str, related_entities: List[str] = None):
        """
        添加文本块并与相关实体关联
        
        Args:
            chunk_id: 文本块唯一标识符
            text: 文本内容
            related_entities: 相关实体ID列表
        """
        # 生成文本嵌入
        embedding = self.embedding_model.encode(text)
        
        with self.driver.session() as session:
            # 添加文本块节点并确保其有Entity标签，以便被向量索引包含
            session.run("""
                MERGE (c:TextChunk:Entity {id: $chunk_id})
                SET c.content = $text,
                    c.embedding = $embedding
            """, {
                "chunk_id": chunk_id,
                "text": text,
                "embedding": embedding.tolist()
            })
        
        # 与相关实体建立连接
        if related_entities:
            for entity_id in related_entities:
                self.add_relationship(
                    chunk_id, 
                    entity_id, 
                    "MENTIONS"
                )
    
    def retrieve_relevant_nodes(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        根据查询检索相关节点
        
        Args:
            query: 查询文本
            top_k: 返回的相关节点数量
            
        Returns:
            相关节点信息的列表
        """
        # 生成查询嵌入
        query_embedding = self.embedding_model.encode(query)
        
        with self.driver.session() as session:
            result = session.run("""
                CALL db.index.vector.queryNodes('node_embeddings_index', $top_k, $query_embedding)
                YIELD node, score
                RETURN node.id AS id, labels(node) AS labels, properties(node) AS properties, score
            """, {
                "top_k": top_k,
                "query_embedding": query_embedding.tolist()
            })
            
            # print("生成查询嵌入")
            return [record.data() for record in result]
    
    def get_connected_nodes(self, node_id: str, depth: int = 1) -> List[Dict]:
        """
        获取与指定节点连接的节点
        
        Args:
            node_id: 起始节点ID
            depth: 探索深度
            
        Returns:
            连接的节点信息列表
        """
        with self.driver.session() as session:
            # 构建Cypher查询，根据深度获取连接节点
            match_pattern = ""
            for i in range(1, depth + 1):
                match_pattern += f"-[r{i}]->(n{i})"
            
            nodes_list = [f"n{i}" for i in range(1, depth + 1)]
            nodes_str = ", ".join(nodes_list)

            query = f"""
                MATCH (n0) WHERE n0.id = $node_id
                MATCH (n0){match_pattern}
                UNWIND [{nodes_str}] AS connected_node
                WITH DISTINCT connected_node
                RETURN connected_node.id AS id, labels(connected_node) AS labels, properties(connected_node) AS properties
            """
            
            result = session.run(query, {"node_id": node_id})
            return [record.data() for record in result]
    
    def build_context(self, query: str, top_k: int = 5, context_depth: int = 1) -> str:
        """
        构建回答查询的上下文
        
        Args:
            query: 查询文本
            top_k: 检索的相关节点数量
            context_depth: 上下文扩展深度
            
        Returns:
            构建的上下文文本
        """
        # 检索相关节点
        relevant_nodes = self.retrieve_relevant_nodes(query, top_k)
        
        # 收集所有相关节点及其连接节点
        context_nodes = {}
        
        # 添加相关节点
        for node in relevant_nodes:
            context_nodes[node["id"]] = node
        
        # 添加连接节点
        for node in relevant_nodes:
            connected_nodes = self.get_connected_nodes(node["id"], context_depth)
            for conn_node in connected_nodes:
                if conn_node["id"] not in context_nodes:
                    context_nodes[conn_node["id"]] = conn_node
        # print(relevant_nodes)
        
        # 构建上下文文本
        context_parts = []
        
        for node_id, node_data in context_nodes.items():
            labels = [label for label in node_data["labels"] if label != "Entity"]  # 排除基础标签
            node_type = labels[0] if labels else "Node"
            
            if node_type == "TextChunk":
                context_parts.append(f"文本块: {node_data['properties'].get('content', '')}")
            else:
                entity_info = [f"{node_type} {node_id}:"]
                for key, value in node_data["properties"].items():
                    if key not in ["id", "embedding"]:  # 排除不需要的属性
                        entity_info.append(f"  {key}: {value}")
                
                # 添加关系信息
                with self.driver.session() as session:
                    rel_result = session.run("""
                        MATCH (n)-[r]->(m) WHERE n.id = $node_id
                        RETURN type(r) AS rel_type, m.id AS target_id, labels(m) AS target_labels
                    """, {"node_id": node_id})
                    
                    for rel in rel_result:
                        target_type = rel["target_labels"][0] if rel["target_labels"] else "节点"
                        entity_info.append(f"  与 {target_type} {rel['target_id']} 存在 {rel['rel_type']} 关系")
                
                context_parts.append("\n".join(entity_info))
        
        return "\n\n".join(context_parts)
    
    def generate_response(self, query: str, top_k: int = 5, context_depth: int = 1) -> str:
        """
        生成基于图谱上下文的回答
        
        Args:
            query: 查询文本
            top_k: 检索的相关节点数量
            context_depth: 上下文扩展深度
            
        Returns:
            生成的回答
        """
        # 构建上下文
        context = self.build_context(query, top_k, context_depth)
        
        # 构建提示词（实际应用中应替换为真实LLM调用）
        prompt = f"""基于以下上下文信息回答问题:
        
        上下文:{context}
        
        问题: {query}
        
        回答:"""
        
        # 模拟LLM输出
        simulated_response = f"关于'{query}'的回答如下：\n"
        simulated_response += "这是一个基于Neo4j知识图谱的回答，综合了检索到的实体、关系和文本信息。\n"
        simulated_response += "在实际应用中，这里会是大型语言模型生成的详细回答。"
        
        return simulated_response


In [47]:
# 初始化Neo4j Graph RAG系统

neo4j_rag = Neo4jGraphRAG(
    uri="bolt://localhost:7687",
    user=USER,
    password=PWD
)

try:
    # 清空现有数据（仅用于演示）
    with neo4j_rag.driver.session() as session:
        session.run("MATCH (n) DETACH DELETE n")
    
    # 添加一些实体
    neo4j_rag.add_entity(
        "einstein", 
        "Scientist", 
        {"name": "阿尔伯特·爱因斯坦", "birth_year": 1879, 
         "description": "著名物理学家，提出了相对论"}
    )
    
    neo4j_rag.add_entity(
        "relativity", 
        "Theory", 
        {"name": "相对论", "description": "关于时空和引力的物理理论"}
    )
    
    neo4j_rag.add_entity(
        "newton", 
        "Scientist", 
        {"name": "艾萨克·牛顿", "birth_year": 1643,
         "description": "物理学家，提出了万有引力定律和三大运动定律"}
    )
    
    # 添加实体间的关系
    neo4j_rag.add_relationship(
        "einstein", 
        "relativity", 
        "DEVELOPED",
        {"year": 1905}
    )
    
    neo4j_rag.add_relationship(
        "einstein", 
        "newton", 
        "WAS_INFLUENCED_BY"
    )
    
    # 添加文本块
    neo4j_rag.add_text_chunk(
        "chunk1", 
        "爱因斯坦在1905年发表了狭义相对论，后来在1915年提出了广义相对论。",
        ["einstein", "relativity"]
    )
    
    neo4j_rag.add_text_chunk(
        "chunk2", 
        "牛顿的力学理论在低速宏观情况下非常有效，但在高速或强引力场中需要相对论来解释。",
        ["newton", "relativity"]
    )
    
    # 测试查询
    query = "相对论是谁提出的？它与牛顿的理论有什么关系？"
    print(f"查询: {query}")
    
    # 获取回答
    response = neo4j_rag.generate_response(query, top_k=3, context_depth=2)
    print("\n回答:")
    print(response)
    
    # 打印构建的上下文（用于演示）
    print("\n构建的上下文:")
    print(neo4j_rag.build_context(query))
    
finally:
    # 关闭连接
    neo4j_rag.close()

查询: 相对论是谁提出的？它与牛顿的理论有什么关系？

回答:
关于'相对论是谁提出的？它与牛顿的理论有什么关系？'的回答如下：
这是一个基于Neo4j知识图谱的回答，综合了检索到的实体、关系和文本信息。
在实际应用中，这里会是大型语言模型生成的详细回答。

构建的上下文:
文本块: 牛顿的力学理论在低速宏观情况下非常有效，但在高速或强引力场中需要相对论来解释。

文本块: 爱因斯坦在1905年发表了狭义相对论，后来在1915年提出了广义相对论。

Scientist newton:
  description: 物理学家，提出了万有引力定律和三大运动定律
  name: 艾萨克·牛顿
  birth_year: 1643

Theory relativity:
  description: 关于时空和引力的物理理论
  name: 相对论

Scientist einstein:
  description: 著名物理学家，提出了相对论
  name: 阿尔伯特·爱因斯坦
  birth_year: 1879
  与 Theory relativity 存在 DEVELOPED 关系
  与 Scientist newton 存在 WAS_INFLUENCED_BY 关系
