In [14]:
import neo4j
from pathlib import Path
import polars as pl
from rich import console

# Shared rich console; every later cell prints through cons.print().
cons = console.Console()

In [15]:
data_path = Path("../data/radgraphXL/cleaned_data.jsonl")
# The path is relative to the kernel's working directory, so report the
# resolved absolute location on failure; is_file() also rejects a directory
# that happens to have this name (exists() would accept it).
assert data_path.is_file(), f"Data file not found at {data_path.resolve()}"

In [16]:
df = pl.read_ndjson(data_path)

# Peek at the first rows to confirm the columns loaded as expected
first_rows = df.head()
cons.print(first_rows)

In [17]:
# polars' unique() does not preserve order by default; sort so the printed
# list is identical on every Restart & Run All.
unique_list = df["dataset"].unique().sort().to_list()
cons.print(f"Unique datasets: {unique_list}")

In [19]:
# Neo4j connection setup
import os

from neo4j import GraphDatabase

# Connection settings. Each one can be overridden with an environment variable
# so real credentials never have to be committed with the notebook; the
# literals below are local-development fallbacks only.
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://127.0.0.1:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "radgraphkg")  # TODO: drop this fallback once NEO4J_PASSWORD is always set

# Create driver
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

# Test connection
def test_connection():
    with driver.session() as session:
        result = session.run("RETURN 'Connection successful!' as message")
        return result.single()["message"]

try:
    message = test_connection()
    cons.print(f"[green]{message}[/green]")
except Exception as e:
    cons.print(f"[red]Connection failed: {e}[/red]")
    cons.print("[yellow]Make sure Neo4j is running and credentials are correct[/yellow]")
    # Stop here: every later cell needs a working connection, so continuing
    # would only produce less obvious failures further down the notebook.
    raise

In [20]:
# Helper functions to parse entities and relations
def parse_entity(entity_data, tokens):
    """
    Parse one NER span of the form [start_idx, end_idx, entity_type].

    start_idx and end_idx are inclusive token positions (the span text is
    tokens[start_idx:end_idx + 1] joined with spaces); both are coerced with
    int(). entity_type is split on "::" into a category and an optional
    modifier, e.g. "Anatomy::definite" -> category "Anatomy", modifier "definite".
    An empty category (e.g. "" or "::definite") is reported as "Unknown".

    Returns:
        dict with keys start_idx, end_idx, text, category, modifier
        (None when there is no "::" part) and entity_type (unchanged input).
    """
    start_idx = int(entity_data[0])
    end_idx = int(entity_data[1])
    entity_type = entity_data[2]

    # Extract the text from tokens (end index is inclusive)
    text = " ".join(tokens[start_idx:end_idx + 1])

    # str.split always returns at least one element, so a length check can
    # never reach the fallback; test for an empty category string instead.
    parts = entity_type.split("::")
    category = parts[0] or "Unknown"
    modifier = parts[1] if len(parts) > 1 else None

    return {
        "start_idx": start_idx,
        "end_idx": end_idx,
        "text": text,
        "category": category,
        "modifier": modifier,
        "entity_type": entity_type
    }

def parse_relation(relation_data):
    """
    Parse one relation entry: [head_start, head_end, tail_start, tail_end, relation_type].

    The four span boundaries are converted with int(); the relation type is
    passed through unchanged. Returns a dict keyed by those five names.
    """
    span_keys = ("head_start", "head_end", "tail_start", "tail_end")
    parsed = {key: int(relation_data[position]) for position, key in enumerate(span_keys)}
    parsed["relation_type"] = relation_data[4]
    return parsed

# Smoke-test both parsers on the first document before touching the database
sample_row = df.row(0, named=True)

cons.print("\n[bold]Sample Entity Parsing:[/bold]")
sample_ner = sample_row["ner"]
if sample_ner:
    sample_entity = parse_entity(sample_ner[0], sample_row["tokens"])
    cons.print(sample_entity)

cons.print("\n[bold]Sample Relation Parsing:[/bold]")
sample_relations = sample_row["relations"]
if sample_relations:
    sample_relation = parse_relation(sample_relations[0])
    cons.print(sample_relation)

In [None]:
# Setup database schema with constraints and indexes
def setup_schema(driver):
    """Create the uniqueness constraints and lookup indexes for the graph.

    Every statement uses IF NOT EXISTS, so calling this repeatedly is safe.
    """
    schema_statements = (
        # Unique document ids
        """
            CREATE CONSTRAINT document_id IF NOT EXISTS
            FOR (d:Document) REQUIRE d.id IS UNIQUE
        """,
        # Unique entity ids (the id encodes document + token span)
        """
            CREATE CONSTRAINT entity_id IF NOT EXISTS
            FOR (e:Entity) REQUIRE e.id IS UNIQUE
        """,
        # Indexes for the properties the exploration queries filter/group on
        """
            CREATE INDEX entity_category IF NOT EXISTS
            FOR (e:Entity) ON (e.category)
        """,
        """
            CREATE INDEX entity_text IF NOT EXISTS
            FOR (e:Entity) ON (e.text)
        """,
        """
            CREATE INDEX document_dataset IF NOT EXISTS
            FOR (d:Document) ON (d.dataset)
        """,
    )

    with driver.session() as session:
        for statement in schema_statements:
            session.run(statement)
        cons.print("[green]Schema created successfully![/green]")

# Setup the schema. Safe to re-run: every statement in setup_schema uses IF NOT EXISTS.
setup_schema(driver)

In [None]:
# Function to insert a single document with its entities and relations
def insert_document(tx, dataset, doc_key, tokens, entities, relations):
    """Insert a document with all its entities and relations into Neo4j.

    Meant to be the unit of work passed to ``session.execute_write``. Every
    write is a MERGE keyed on an id, so inserting the same document again
    updates it in place instead of duplicating it.

    Args:
        tx: Transaction supplied by ``session.execute_write``.
        dataset: Dataset name; the document id is ``"<dataset>:<doc_key>"``.
        doc_key: Document key within the dataset.
        tokens: Token strings of the report.
        entities: ``[start_idx, end_idx, entity_type]`` spans (see parse_entity).
        relations: ``[head_start, head_end, tail_start, tail_end, relation_type]``
            entries (see parse_relation).

    Returns:
        Number of relations skipped because their head or tail span does not
        match any entity of this document.
    """
    doc_id = f"{dataset}:{doc_key}"
    _merge_document_node(tx, doc_id, dataset, doc_key, tokens)
    span_to_id = _merge_entity_nodes(tx, doc_id, tokens, entities)
    return _merge_relations(tx, span_to_id, relations)


def _merge_document_node(tx, doc_id, dataset, doc_key, tokens):
    """MERGE the Document node and (re)set its properties."""
    tx.run("""
        MERGE (d:Document {id: $doc_id})
        SET d.dataset = $dataset,
            d.doc_key = $doc_key,
            d.text = $full_text,
            d.token_count = $token_count
    """, doc_id=doc_id, dataset=dataset, doc_key=doc_key,
        full_text=" ".join(tokens), token_count=len(tokens))


def _merge_entity_nodes(tx, doc_id, tokens, entities):
    """MERGE all Entity nodes of a document in one UNWIND query.

    Returns a mapping ``(start_idx, end_idx) -> entity id`` used to resolve
    relation endpoints.
    """
    span_to_id = {}
    rows = []
    for entity_data in entities:
        entity = parse_entity(entity_data, tokens)
        span = (entity["start_idx"], entity["end_idx"])
        entity_id = f"{doc_id}:{span[0]}:{span[1]}"
        span_to_id[span] = entity_id
        rows.append({**entity, "id": entity_id})

    if rows:
        # One round trip per document instead of one per entity.
        tx.run("""
            MATCH (d:Document {id: $doc_id})
            UNWIND $rows AS row
            MERGE (e:Entity {id: row.id})
            SET e.text = row.text,
                e.category = row.category,
                e.modifier = row.modifier,
                e.entity_type = row.entity_type,
                e.start_idx = row.start_idx,
                e.end_idx = row.end_idx
            MERGE (d)-[:HAS_ENTITY]->(e)
        """, doc_id=doc_id, rows=rows)
    return span_to_id


def _merge_relations(tx, span_to_id, relations):
    """MERGE entity-to-entity relationships, one UNWIND query per relation type.

    Returns the number of relations whose head or tail span is not in
    ``span_to_id`` (these are skipped).
    """
    pairs_by_type = {}
    skipped = 0
    for relation_data in relations:
        relation = parse_relation(relation_data)
        head_id = span_to_id.get((relation["head_start"], relation["head_end"]))
        tail_id = span_to_id.get((relation["tail_start"], relation["tail_end"]))
        if head_id is None or tail_id is None:
            skipped += 1
            continue
        pairs_by_type.setdefault(relation["relation_type"], []).append(
            {"head_id": head_id, "tail_id": tail_id}
        )

    for relation_type_orig, pairs in pairs_by_type.items():
        # Relationship types cannot be passed as query parameters, so the type
        # is interpolated. Backtick-quote it (doubling embedded backticks) so
        # characters coming from the data file cannot break or alter the query.
        rel_label = relation_type_orig.upper().replace(" ", "_").replace("`", "``")
        tx.run(f"""
            UNWIND $pairs AS pair
            MATCH (head:Entity {{id: pair.head_id}})
            MATCH (tail:Entity {{id: pair.tail_id}})
            MERGE (head)-[r:`{rel_label}`]->(tail)
            SET r.type = $relation_type_orig
        """, pairs=pairs, relation_type_orig=relation_type_orig)
    return skipped

# Convenience wrapper: write one dataframe row in its own session/transaction
def insert_single_document(driver, row_dict):
    """Write one row dict (e.g. from ``df.row(i, named=True)``) to Neo4j.

    Falsy ``ner`` / ``relations`` values (such as None) are passed on as
    empty lists.
    """
    with driver.session() as session:
        document_args = (
            row_dict["dataset"],
            row_dict["doc_key"],
            row_dict["tokens"],
            row_dict["ner"] or [],
            row_dict["relations"] or [],
        )
        session.execute_write(insert_document, *document_args)

# Smoke-test the write path with the first row before the full load.
# All writes are MERGE-based, so populate_database() below re-merges this
# document rather than creating a duplicate.
test_row = df.row(0, named=True)
insert_single_document(driver, test_row)
cons.print("[green]Successfully inserted test document![/green]")

In [None]:
# Populate Neo4j with all documents from the dataframe
from tqdm import tqdm

def populate_database(driver, dataframe, batch_size=100):
    """Insert all documents from ``dataframe`` into Neo4j with progress tracking.

    Rows are walked in slices of ``batch_size`` (one progress-bar step per
    slice). Each row is written by ``insert_single_document`` in its own
    session and transaction, so a failing document is reported and skipped
    without affecting the others. Node/relationship counts are printed at
    the end.

    Args:
        driver: Neo4j driver.
        dataframe: DataFrame with ``dataset``, ``doc_key``, ``tokens``,
            ``ner`` and ``relations`` columns.
        batch_size: Rows per progress step; must be at least 1.

    Raises:
        ValueError: If ``batch_size`` < 1. A zero step would crash ``range``,
            and a negative one would silently insert nothing.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    total_rows = len(dataframe)
    cons.print(f"[bold]Populating Neo4j with {total_rows} documents...[/bold]")

    failed_docs = []
    for batch_start in tqdm(range(0, total_rows, batch_size), desc="Processing batches"):
        batch_end = min(batch_start + batch_size, total_rows)
        batch_df = dataframe.slice(batch_start, batch_end - batch_start)

        for row_dict in batch_df.iter_rows(named=True):
            try:
                insert_single_document(driver, row_dict)
            except Exception as e:
                # Keep going so one bad document does not abort the whole load,
                # but record it so the final summary reflects the failure.
                doc_label = f"{row_dict['dataset']}:{row_dict['doc_key']}"
                failed_docs.append(doc_label)
                cons.print(f"[red]Error inserting doc {doc_label}: {e}[/red]")

    if failed_docs:
        cons.print(
            f"[yellow]Database population finished with {len(failed_docs)} "
            f"failed document(s) out of {total_rows}.[/yellow]"
        )
    else:
        cons.print("[green]Database population complete![/green]")

    # Print summary statistics
    with driver.session() as session:
        doc_count = session.run("MATCH (d:Document) RETURN count(d) as count").single()["count"]
        entity_count = session.run("MATCH (e:Entity) RETURN count(e) as count").single()["count"]
        relation_count = session.run("MATCH ()-[r]->() WHERE type(r) <> 'HAS_ENTITY' RETURN count(r) as count").single()["count"]

        cons.print("\n[bold cyan]Database Statistics:[/bold cyan]")
        cons.print(f"  Documents: {doc_count}")
        cons.print(f"  Entities: {entity_count}")
        cons.print(f"  Relations: {relation_count}")

# Populate the entire database. Writes use MERGE keyed on document/entity ids,
# so re-running this cell updates existing nodes instead of duplicating them.
populate_database(driver, df)

In [None]:
# Example queries to explore the knowledge graph
def run_example_queries(driver):
    """Print a few summary queries that demonstrate the knowledge graph.

    Sections: top entity categories, top relation types, the entity
    (text, category) pairs with the most entity-to-entity relationships in
    either direction, and document counts per dataset. Ties in each ORDER BY
    are broken by name so repeated runs print the same rows.
    """
    sections = (
        (
            "Top 10 Entity Categories",
            """
            MATCH (e:Entity)
            RETURN e.category as category, count(e) as count
            ORDER BY count DESC, category
            LIMIT 10
            """,
            lambda record: f"  {record['category']}: {record['count']}",
        ),
        (
            "Top 10 Relation Types",
            """
            MATCH ()-[r]->()
            WHERE type(r) <> 'HAS_ENTITY'
            RETURN type(r) as relation_type, count(r) as count
            ORDER BY count DESC, relation_type
            LIMIT 10
            """,
            lambda record: f"  {record['relation_type']}: {record['count']}",
        ),
        (
            "Top 10 Most Connected Entities",
            # Undirected pattern: count relationships where the entity is the
            # head OR the tail; a directed pattern would count outgoing only.
            """
            MATCH (e:Entity)-[r]-()
            WHERE type(r) <> 'HAS_ENTITY'
            RETURN e.text as entity, e.category as category, count(r) as relation_count
            ORDER BY relation_count DESC, entity, category
            LIMIT 10
            """,
            lambda record: f"  {record['entity']} ({record['category']}): {record['relation_count']} relations",
        ),
        (
            "Documents per Dataset",
            """
            MATCH (d:Document)
            RETURN d.dataset as dataset, count(d) as count
            ORDER BY count DESC, dataset
            """,
            lambda record: f"  {record['dataset']}: {record['count']}",
        ),
    )

    with driver.session() as session:
        for title, query, format_row in sections:
            cons.print(f"\n[bold cyan]{title}:[/bold cyan]")
            for record in session.run(query):
                cons.print(format_row(record))

# Run example queries (read-only: every query is MATCH ... RETURN)
run_example_queries(driver)

In [None]:
# Clean up: close the driver connection when done
# Uncomment this when you're finished working with the database
# driver.close()
# cons.print("[yellow]Driver connection closed[/yellow]")