In [6]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from sklearn.metrics import accuracy_score
from neo4j import GraphDatabase

In [7]:
# 连接到Neo4j数据库
class Neo4jGraph:
    def __init__(self, uri, user, pwd, db_name):
        self._driver = GraphDatabase.driver(uri, auth=(user, pwd), database=db_name)

    def close(self):
        self._driver.close()

    def query(self, query):
        with self._driver.session() as session:
            result = session.run(query)
            return [record for record in result]

In [8]:
uri = "bolt://localhost:7687"
user = "neo4j"
pwd = "1742359208ys"
db_name = "neo4j"

# 创建Neo4jGraph实例
neo4j_graph = Neo4jGraph(uri, user, pwd, db_name)

In [25]:
def extract_graph_data_from_neo4j(neo4j_graph, node_query, feature_node_query, edge_query):
    # 查询Hole节点
    hole_nodes_result = neo4j_graph.query(node_query)
    # 查询特征节点
    feature_nodes_result = neo4j_graph.query(feature_node_query)
    # 查询边，确保返回节点的ID
    edges_result = neo4j_graph.query(edge_query)
    feature_nodes = {}

    # 创建特征节点字典
    # for node in feature_nodes_result:
    #     new_props = {k: v for k, v in node['properties'].items() if k != 'id'}
    #     feature_nodes[node['id']] = new_props
    feature_nodes = {node['id']: node['properties'] for node in feature_nodes_result}

    print(f"Fno:{feature_nodes}")

    # 创建边列表，只包含起点和终点的ID
    edges = []
    for edge in edges_result:
        start_node = edge['startNodeElementId']
        end_node = edge['endNodeElementId']
        # 检查节点是否存在，并获取ID
        # if start_node is not None and 'id' in start_node:
        if start_node is not None:
            # start_node_id = start_node['id']
            start_node_id = start_node
        else:
            start_node_id = None  # 或者选择一个合适的默认值或占位符

        # if end_node is not None and 'id' in end_node:
        if end_node is not None:
            # end_node_id = end_node['id']
            end_node_id = end_node
        else:
            end_node_id = None  # 或者选择一个合适的默认值或占位符

        edges.append((start_node_id, end_node_id))

    # 过滤掉None的边
    edges = [edge for edge in edges if edge[0] is not None and edge[1] is not None]

    # 创建特征矩阵x
    x = torch.tensor([list(features.values()) for features in feature_nodes.values()], dtype=torch.float)
    
    # 创建边索引edge_index
    edge_index = torch.tensor([edge[:2] for edge in edges], dtype=torch.long).t().contiguous()

    # 创建Data对象，只包含x和edge_index
    data = Data(x=x, edge_index=edge_index)

    return data

# 其余代码保持不变...

# 定义查询语句
node_query = "MATCH (n:Hole) RETURN n.hole_id AS id, properties(n) AS properties"
feature_node_query = "MATCH (n:III) RETURN n.id AS id, properties(n) AS properties"

edge_query = """
MATCH path = (n:III)<-[:HAS_FEATURE]-(h1:Hole)-[:CONNECTED]->(h2:Hole)-[:HAS_FEATURE]->(m:III)
WHERE n.id <> m.id AND NOT (n)-[:HAS_FEATURE*0..1]->(m) // 确保n和m不是直接相连
RETURN n.id AS startNodeElementId, m.id AS endNodeElementId
"""
# MATCH (n:III)<-[:HAS_FEATURE]-(h1:Hole)->[:CONNECTED]-(h2:Hole)-[:HAS_FEATURE]->(m:III)

# 提取图数据
data = extract_graph_data_from_neo4j(neo4j_graph, node_query, feature_node_query, edge_query)

# 输出Data对象的内容
print(data.x.shape)  # 输出特征矩阵
print(data.edge_index)  # 输出边索引
# print(data.edge_attr)  # 这行被注释掉了，因为我们不再使用边权重

Fno:{1609868.0: {'id': 1609868.0, 'time': 35.0, 'waterPressure': 23.0, 'waterFlow': 77.0, 'strikePressure': 146.0, 'velocity': 3.429, 'rotationPressure': 57.0, 'rotationVelocity': 0.0, 'propelPressure': 60.0, 'depth': 0.384}, 1609988.0: {'id': 1609988.0, 'time': 207.0, 'waterPressure': 28.0, 'waterFlow': 77.0, 'strikePressure': 85.0, 'velocity': 0.58, 'rotationPressure': 31.0, 'rotationVelocity': 0.0, 'propelPressure': 18.0, 'depth': 0.302}, 1609989.0: {'id': 1609989.0, 'time': 207.0, 'waterPressure': 23.0, 'waterFlow': 74.0, 'strikePressure': 94.0, 'velocity': 0.58, 'rotationPressure': 33.0, 'rotationVelocity': 0.0, 'propelPressure': 18.0, 'depth': 0.322}, 1609990.0: {'id': 1609990.0, 'time': 188.0, 'waterPressure': 23.0, 'waterFlow': 73.0, 'strikePressure': 91.0, 'velocity': 0.638, 'rotationPressure': 30.0, 'rotationVelocity': 0.0, 'propelPressure': 19.0, 'depth': 0.342}, 1609991.0: {'id': 1609991.0, 'time': 203.0, 'waterPressure': 23.0, 'waterFlow': 73.0, 'strikePressure': 96.0, 've

  edge_index = torch.tensor([edge[:2] for edge in edges], dtype=torch.long).t().contiguous()
