In [1]:
import ast
import os
import re
from typing import List, Dict, Any

import pandas as pd
from neo4j import GraphDatabase


In [None]:
# Read the preprocessed dataset and print a column/dtype summary.
processed_data_filepath = '../data/processed/processed_data_with_embeddings.csv'
df = pd.read_csv(processed_data_filepath)
df.info()

# Both structured columns come back from the CSV as Python-literal strings,
# so they are parsed into real objects here.
# Named_Entities: string -> list of dictionaries
df['Named_Entities'] = df['Named_Entities'].apply(lambda raw: ast.literal_eval(raw))
# Embeddings: string -> list of floats
df['Embeddings'] = df['Embeddings'].apply(lambda raw: list(map(float, ast.literal_eval(raw))))

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1983 entries, 0 to 1982
Data columns (total 4 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   Image_ID        1983 non-null   object
 1   Document        1983 non-null   object
 2   Named_Entities  1983 non-null   object
 3   Embeddings      1983 non-null   object
dtypes: object(4)
memory usage: 62.1+ KB


In [None]:
def load_data_into_neo4j(df: pd.DataFrame, uri: str, user: str, password: str):
    """Load the processed dataframe into a Neo4j graph database.

    Every row is written in its own managed write transaction:

    * a ``Document`` node is merged on ``id`` (the row's ``Image_ID``) and its
      ``text`` and ``embedding`` properties are set;
    * each element of the row's ``Named_Entities`` list is merged as a
      ``NamedEntity`` node keyed on ``(text, type)``, read from the element's
      ``ne_span`` and ``ne_type`` fields;
    * a ``HAS_ENTITY`` relationship is merged from the document to the entity.

    Only ``MERGE`` is used for nodes and relationships, so running the load
    again updates existing documents instead of adding duplicates.

    Args:
        df: Frame providing the ``Image_ID``, ``Document``, ``Embeddings`` and
            ``Named_Entities`` columns.
        uri: Connection URI of the Neo4j server.
        user: Neo4j user name.
        password: Password for ``user``.

    Raises:
        Any exception raised by the driver or by a failing transaction is
        propagated; rows written before the failure stay committed because
        each row has its own transaction.
    """
    query = (
        "MERGE (d:Document {id: $doc_id}) "
        "SET d.text = $document, d.embedding = $doc_embedding "
        "WITH d "
        "UNWIND $named_entities AS ne "
        "MERGE (e:NamedEntity {text: ne.ne_span, type: ne.ne_type}) "
        "MERGE (d)-[:HAS_ENTITY]->(e) "
    )

    def insert_data(tx, doc_id, document, doc_embedding, named_entities):
        # Unit of work for one row; executed inside a managed write transaction.
        tx.run(
            query,
            doc_id=doc_id,
            document=document,
            doc_embedding=doc_embedding,
            named_entities=named_entities,
        )

    # Context managers guarantee the session and the driver are closed even
    # when a write fails part-way through the frame.
    with GraphDatabase.driver(uri, auth=(user, password)) as driver:
        with driver.session() as session:
            for _, row in df.iterrows():
                session.execute_write(
                    insert_data,
                    row['Image_ID'],
                    row['Document'],
                    row['Embeddings'],
                    row['Named_Entities'],
                )


# Load into Neo4j
# Connection settings are read from the environment when set (NEO4J_URI,
# NEO4J_USER, NEO4J_PASSWORD). The fallbacks are the local development
# defaults; do not rely on them for a shared or production database.
neo4j_uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.environ.get("NEO4J_USER", "neo4j")
neo4j_password = os.environ.get("NEO4J_PASSWORD", "12345678")
load_data_into_neo4j(df, neo4j_uri, neo4j_user, neo4j_password)

In [9]:
def delete_graph(uri: str, user: str, password: str):
    """Delete all nodes and relationships in the Neo4j database.

    Runs ``MATCH (n) DETACH DELETE n`` as an auto-commit query. This is
    destructive and cannot be undone from this notebook.

    Args:
        uri: Connection URI of the Neo4j server.
        user: Neo4j user name.
        password: Password for ``user``.

    Raises:
        Any exception raised by the driver while connecting or while running
        the delete query is propagated to the caller.
    """
    # Context managers close the session and the driver even if the query fails.
    with GraphDatabase.driver(uri, auth=(user, password)) as driver:
        with driver.session() as session:
            # consume() drives the auto-commit query to completion here, so a
            # failure raises at this line instead of being deferred to the
            # moment the session is closed.
            session.run("MATCH (n) DETACH DELETE n").consume()

# WARNING: this cell wipes the whole database, including anything loaded by
# the earlier cell of this notebook. Run it only when a reset is intended.
# Connection settings are read from the environment when set (NEO4J_URI,
# NEO4J_USER, NEO4J_PASSWORD); the fallbacks are the local development defaults.
neo4j_uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.environ.get("NEO4J_USER", "neo4j")
neo4j_password = os.environ.get("NEO4J_PASSWORD", "12345678")
delete_graph(neo4j_uri, neo4j_user, neo4j_password)