In [6]:
# %pip install -q networkx pyvis matplotlib neo4j

# Start `neo4j` before executing this notebook

- Create a directory to persist the database:
```bash
mkdir -p $(pwd)/data/neo4j
```

- Run the Neo4j Docker container:
```bash
docker run -d \
  --name neo4j-local \
  -p 7474:7474 -p 7687:7687 \
  -e NEO4J_AUTH=none \
  -v "$(pwd)/data/neo4j:/data" \
  neo4j:latest
```

- Grant permissions to the data directory:
```bash
sudo chown -R 7474:7474 ./data/neo4j
```

In [None]:
import json
import networkx as nx
from pyvis.network import Network
import pandas as pd
import matplotlib.pyplot as plt
from neo4j import GraphDatabase
from IPython.display import display, HTML
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Load the data
data_path = "../data/raw/graph_documents.json"
with open(data_path, 'r', encoding='utf-8') as f:
    graph_documents = json.load(f)

print(f"Loaded {len(graph_documents)} documents.")

Loaded 604 documents.


# Config

In [8]:
class Neo4jLoader:
    def __init__(self, uri, auth):
        try:
            # If auth is None, we don't pass it to the driver
            if auth is None:
                self.driver = GraphDatabase.driver(uri)
            else:
                self.driver = GraphDatabase.driver(uri, auth=auth)
            self.verify_connection()
        except Exception as e:
            print(f"Initialization Error: {e}")
            self.driver = None

    def verify_connection(self):
        try:
            self.driver.verify_connectivity()
            print("Connected to Neo4j successfully.")
        except Exception as e:
            print(f"Failed to connect to Neo4j: {e}")
            print("1. Is the Neo4j database running?")
            print("2. Are the URI and PORT correct? (default: bolt://localhost:7687)")
            print("3. Is the password correct?")
            raise e

    def close(self):
        if self.driver:
            self.driver.close()

    def clear_database(self):
        if not self.driver: return
        with self.driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")
            print("Database cleared.")

    def load_data(self, documents):
        if not self.driver: return
        with self.driver.session() as session:
            for i, doc in enumerate(documents):

                # Create nodes
                for node in doc['nodes']:
                    # remove "(Phụ Lục ...)" from nodes
                    node['id'] = node['id'].split(' (Phụ Lục')[0].strip()
                    # Sanitize label
                    label = "".join(x for x in node['type'] if x.isalnum() or x == "_").upper()
                    if not label: label = "ENTITY"
                    
                    query = (
                        f"MERGE (n:`{label}` {{id: $id}}) "
                        "SET n.type = $type "
                        "SET n += $properties"
                    )
                    session.run(query, id=node['id'], type=node['type'], properties=node.get('properties', {}))
                
                # Create relationships
                for rel in doc['relationships']:
                    rel_type = "".join(x for x in rel['type'] if x.isalnum() or x == "_").upper()
                    if not rel_type: rel_type = "RELATED_TO"
                    
                    # Clean source/target to match nodes
                    source_id = rel['source'].split(' (Phụ Lục')[0].strip()
                    target_id = rel['target'].split(' (Phụ Lục')[0].strip()

                    query = (
                        "MATCH (a {id: $source}), (b {id: $target}) "
                        f"MERGE (a)-[r:`{rel_type}`]->(b) "
                        "SET r += $properties"
                    )
                    session.run(
                        query, 
                        source=source_id, 
                        target=target_id, 
                        properties=rel.get('properties', {})
                    )
                
                if (i + 1) % 10 == 0:
                    print(f"Processed {i + 1} documents...", end = "\r")

def get_subgraph_from_neo4j(uri, auth, limit=50):
    """Fetches a subgraph from Neo4j into NetworkX for visualization."""
    driver = GraphDatabase.driver(uri, auth=auth)
    G = nx.DiGraph()
    
    query = f"""
    MATCH (n)-[r]->(m)
    RETURN n, r, m
    LIMIT {limit}
    """
    
    try:
        with driver.session() as session:
            result = session.run(query)
            for record in result:
                n = record['n']
                m = record['m']
                r = record['r']
                
                # Add nodes with properties
                G.add_node(n['id'], label=n['id'], title=n['id'], group=n.get('type', 'Entity'), **dict(n))
                G.add_node(m['id'], label=m['id'], title=m['id'], group=m.get('type', 'Entity'), **dict(m))
                
                # Add edge
                G.add_edge(n['id'], m['id'], title=r.type, label=r.type, **dict(r))
    except Exception as e:
        print(f"Error fetching data: {e}")
    finally:
        driver.close()
        
    return G

def visualize_neo4j_subgraph(uri, auth, limit=50, output_file="neo4j_viz.html"):
    G = get_subgraph_from_neo4j(uri, auth, limit)
    
    if G.number_of_nodes() == 0:
        print("No data found or connection failed.")
        return

    print(f"Visualizing subgraph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")

    net = Network(notebook=True, height="600px", width="100%", bgcolor="#222222", font_color="white", cdn_resources='in_line')
    net.from_nx(G)
    net.repulsion(node_distance=150, spring_length=200)
    net.show(output_file)
    return output_file

# ------ Print full schema ------
def print_neo4j_schema(uri, auth):
    driver = GraphDatabase.driver(uri, auth=auth)
    try:
        with driver.session() as session:
            result = session.run("CALL db.schema.visualization()")
            nodes = set()
            relationships = set()
            for record in result:
                for node in record['nodes']:
                    nodes.add(tuple(node.labels))
                for rel in record['relationships']:
                    relationships.add((rel.start_node.labels, rel.type, rel.end_node.labels))
            print("Node Labels:")
            for label in nodes:
                print(f" - {label}")
            print("\nRelationships:")
            for start_labels, rel_type, end_labels in relationships:
                print(f" - ({start_labels})-[:{rel_type}]->({end_labels})")
            # for rel_type in relationships:
                # print(f" - {rel_type}")
    except Exception as e:
        print(f"Error fetching schema: {e}")
    finally:
        driver.close()
        
# ------ Get Statistics ------
def get_neo4j_statistics(uri, auth):
    driver = GraphDatabase.driver(uri, auth=auth)
    stats = {}
    try:
        with driver.session() as session:
            node_count = session.run("MATCH (n) RETURN count(n) AS count").single()["count"]
            rel_count = session.run("MATCH ()-[r]->() RETURN count(r) AS count").single()["count"]
            stats['node_count'] = node_count
            stats['relationship_count'] = rel_count
    except Exception as e:
        print(f"Error fetching statistics: {e}")
    finally:
        driver.close()
    return stats

# ------ Find orphan nodes ------
def find_orphans_neo4j(uri, auth, n=10):
    driver = GraphDatabase.driver(uri, auth=auth)
    query = f"""
    MATCH (n)
    WHERE NOT (n)--()
    RETURN n.id AS id, n.type AS type
    LIMIT {n}
    """
    
    try:
        with driver.session() as session:
            result = session.run(query)
            orphans = [record['id'] for record in result]
            
            # Get count
            count_result = session.run("MATCH (n) WHERE NOT (n)--() RETURN count(n) as count")
            total = count_result.single()['count']
            
            print(f"Total orphan nodes in DB: {total}")
            if orphans:
                print(f"First {n} orphan nodes:")
                for oid in orphans:
                    print(f" - {oid}")
    except Exception as e:
        print(f"Error: {e}")
    finally:
        driver.close()

# ------ Remove orphan nodes ------
def remove_orphans_neo4j(uri, auth):
    driver = GraphDatabase.driver(uri, auth=auth)
    # Cypher query to match nodes with no relationships and delete them
    query = "MATCH (n) WHERE NOT (n)--() DELETE n"
    
    try:
        with driver.session() as session:
            # Check count before
            count_before = session.run("MATCH (n) WHERE NOT (n)--() RETURN count(n) as count").single()['count']
            print(f"Orphans before deletion: {count_before}")
            
            if count_before > 0:
                session.run(query)
                print(f"Successfully deleted {count_before} orphan nodes.")
            else:
                print("No orphan nodes found to delete.")
                
    except Exception as e:
        print(f"Error removing orphans: {e}")
    finally:
        driver.close()
        
# ------ Query function ------
def query(uri, auth, cypher_query):
    driver = GraphDatabase.driver(uri, auth=auth)
    results = []
    try:
        with driver.session() as session:
            result = session.run(cypher_query)
            for record in result:
                results.append(record.data())
    except Exception as e:
        print(f"Error executing query: {e}")
    finally:
        driver.close()
    return results

# Configuration

In [9]:
# Neo4j Configuration
URI = os.getenv("Neo4j_URI")
# AUTH = ("neo4j", "password")
AUTH = None

# Initialize and load
# NOTE: Set run_import to True to execute the import. 
# Be careful as this clears the database first.
run_import = True
output_html = "../data/neo4j_subgraph.html"
remove_orphan_nodes = True

# Main execution

In [10]:
# ------ Load data into Neo4j -----
print("\n====== Import Data======\n")
if run_import:
    try:
        loader = Neo4jLoader(URI, AUTH)
        if loader.driver:
            loader.clear_database()
            loader.load_data(graph_documents)
            loader.close()
            print("\nImport finished.")
    except Exception:
        print("Import failed due to connection error.")
else:
    print("Skipping import. Set run_import=True to run.")
    
# ------ Get statistics and schema ------
print("\n====== Statistics======\n")
stats = get_neo4j_statistics(URI, AUTH)
print(f"Total Nodes: {stats.get('node_count', 'N/A')}")
print(f"Total Relationships: {stats.get('relationship_count', 'N/A')}")
        
print("\n====== Full Schema======\n")
print_neo4j_schema(URI, AUTH)

# ------ Find and remove orphan nodes ------
if remove_orphan_nodes:
    print("\n====== Remove Orphan Nodes======\n")
    find_orphans_neo4j(URI, AUTH, 5)
    remove_orphans_neo4j(URI, AUTH)

# ------ Visualize ------
print("\n====== Visualize Subgraph======\n")
output_html = visualize_neo4j_subgraph(URI, AUTH, limit=50, output_file=output_html)



Connected to Neo4j successfully.
Database cleared.
Processed 600 documents...
Import finished.


Total Nodes: 927
Total Relationships: 24452


Node Labels:
 - ('DRUG',)
 - ('CHEMICAL',)
 - ('PRODUCTION_METHOD',)
 - ('STANDARD',)
 - ('TEST_METHOD',)
 - ('STORAGE_CONDITION',)
 - ('DISEASE',)
 - ('ORGANISM',)

Relationships:
 - (frozenset({'STORAGE_CONDITION'}))-[:HAS_STANDARD]->(frozenset({'PRODUCTION_METHOD'}))
 - (frozenset({'DRUG'}))-[:PRODUCED_BY]->(frozenset({'TEST_METHOD'}))
 - (frozenset({'DRUG'}))-[:TESTED_BY]->(frozenset({'TEST_METHOD'}))
 - (frozenset({'PRODUCTION_METHOD'}))-[:HAS_STANDARD]->(frozenset({'ORGANISM'}))
 - (frozenset({'CHEMICAL'}))-[:PRODUCED_BY]->(frozenset({'STORAGE_CONDITION'}))
 - (frozenset({'DISEASE'}))-[:STORED_AT]->(frozenset({'TEST_METHOD'}))
 - (frozenset({'STORAGE_CONDITION'}))-[:PRODUCED_BY]->(frozenset({'STANDARD'}))
 - (frozenset({'STORAGE_CONDITION'}))-[:TESTED_BY]->(frozenset({'STANDARD'}))
 - (frozenset({'STORAGE_CONDITION'}))-[:STORED_AT]->(fro