# Hybrid Graph-RAG for DPP PoC

This notebook combines:
- Full-text search (BM25) for evidence retrieval
- Vector semantic search using pre-computed embeddings
- Hybrid search combining both methods
- Graph neighborhood expansion for context
- LLM synthesis with proper citations

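**Prerequisites**:
1. Run `vector-embeddings.ipynb` first to create the chunks, embeddings, and Neo4j indexes (the vector search below queries an index named `chunk_vec`).
2. Ensure the Neo4j database contains the knowledge-graph entities from the charge bundle.
3. Provide `NEO4J_URI`, `NEO4J_USER`, `NEO4J_PASS` and, for LLM synthesis, `OPENAI_API_KEY` — either in `~/.env` or in a local `.env` file.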

In [9]:
from neo4j import GraphDatabase
from sentence_transformers import SentenceTransformer
from pathlib import Path
import json, textwrap, re, os
from typing import List, Dict
from dotenv import load_dotenv

# Load .env from home directory in AzureML
home_env = Path.home() / '.env'
if home_env.exists():
    load_dotenv(home_env)
else:
    load_dotenv()  # fallback to local .env
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USER = os.getenv('NEO4J_USER')
NEO4J_PASS = os.getenv('NEO4J_PASS')

# Fail fast with a clear message rather than an opaque driver error further down.
_missing_settings = [name for name, value in (('NEO4J_URI', NEO4J_URI),
                                              ('NEO4J_USER', NEO4J_USER),
                                              ('NEO4J_PASS', NEO4J_PASS)) if not value]
if _missing_settings:
    raise RuntimeError(f"Missing Neo4j settings in environment/.env: {', '.join(_missing_settings)}")

driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS))
# GraphDatabase.driver() connects lazily; verify now so the "Connected" message below is truthful.
driver.verify_connectivity()
EMBED_DIM = 384
EMBED_MODEL = "all-MiniLM-L6-v2"
# Optional directory holding a pre-downloaded copy of EMBED_MODEL; set EMBED_MODEL_PATH to use it
# instead of hard-coding a machine-specific absolute path.
EMBED_MODEL_PATH = os.getenv('EMBED_MODEL_PATH')

# Initialize embedding model for query-time vector search
# Note: Vector embeddings for chunks should already exist from vector-embeddings.ipynb
print("Loading embedding model...")
embedding_model = None
if EMBED_MODEL_PATH and Path(EMBED_MODEL_PATH).exists():
    try:
        embedding_model = SentenceTransformer(EMBED_MODEL_PATH)
        print("✅ Loaded local cached model")
    except Exception as e:  # narrow from bare except: don't swallow KeyboardInterrupt/SystemExit
        print(f"⚠️ Could not load local model at {EMBED_MODEL_PATH}: {e}")
if embedding_model is None:
    # Fallback to download (for AzureML)
    embedding_model = SentenceTransformer(EMBED_MODEL)
    print("✅ Downloaded model from HuggingFace")

print("✅ Connected to Neo4j")

Loading embedding model...
✅ Downloaded model from HuggingFace
✅ Connected to Neo4j


## Evidence Retrieval System

Supports three retrieval methods: BM25 full-text search, vector semantic search, and a weighted hybrid of the two. Results can optionally be boosted toward preferred exhibits.

In [10]:
# Create the full-text (Lucene) index used by bm25_chunks. IF NOT EXISTS makes re-runs a no-op.
with driver.session(database="dpppoc") as s:
    try:
        # .consume() forces each statement to complete inside this try block; otherwise a
        # server error may only surface when the session closes, outside the handler.
        s.run("""
          CREATE FULLTEXT INDEX chunk_fulltext IF NOT EXISTS
          FOR (c:Chunk) ON EACH [c.text, c.source_file, c.exhibit_id]
        """).consume()
        # Index population is asynchronous; wait (up to 300 s) so "ready" is actually true.
        s.run("CALL db.awaitIndex('chunk_fulltext', 300)").consume()
        print("✅ Full-text index ready")
    except Exception as e:
        # An existing index is not an error under IF NOT EXISTS, so anything caught here is a real failure.
        print("Full-text index setup failed:", e)

✅ Full-text index ready


In [11]:
# Lucene query-syntax metacharacters (the same set Lucene's QueryParser.escape handles).
_LUCENE_SPECIAL_CHARS = re.compile(r'([+\-!(){}\[\]^"~*?:\\/&|])')
# Upper-case boolean keywords are parsed as operators; e.g. a dangling "AND" is a syntax error.
_LUCENE_KEYWORDS = re.compile(r'\b(AND|OR|NOT)\b')


def _escape_lucene(text):
    """Escape Lucene query syntax so a free-text question is searched literally."""
    escaped = _LUCENE_SPECIAL_CHARS.sub(r'\\\1', text)
    # NOTE(review): lower-casing AND/OR/NOT assumes the index analyzer lower-cases terms
    # (Neo4j's default analyzer does) — confirm if a custom analyzer is configured.
    return _LUCENE_KEYWORDS.sub(lambda m: m.group(0).lower(), escaped)


def bm25_chunks(query, k=8, escape=True):
    """Search chunks using the `chunk_fulltext` full-text (Lucene) index.

    Args:
        query: Free-text query string.
        k: Maximum number of rows to return.
        escape: When True (default), Lucene metacharacters are escaped so natural-language
            questions are matched literally (e.g. a trailing '?' would otherwise be a
            single-character wildcard). Pass False to use raw Lucene query syntax.

    Returns:
        List of dicts with keys 'id', 'ex', 'file', 'text', 'score', highest score first.
        An empty list when the query contains no searchable text.
    """
    lucene_query = _escape_lucene(query or "") if escape else query
    if not lucene_query or not lucene_query.strip():
        # Lucene rejects empty queries; nothing to search for.
        return []

    from neo4j import Query  # Query objects carry a server-side transaction timeout

    # Session.run() treats extra keyword arguments as Cypher parameters, so passing
    # `timeout=15` there is silently ignored; attach the timeout to the Query instead.
    cypher = Query("""
            CALL db.index.fulltext.queryNodes('chunk_fulltext', $q)
            YIELD node AS c, score
            RETURN c.id AS id, c.exhibit_id AS ex, c.source_file AS file,
                   c.text AS text, score
            ORDER BY score DESC LIMIT $k
        """, timeout=15)
    with driver.session(database="dpppoc") as s:
        return s.run(cypher, q=lucene_query, k=k).data()

def vector_chunks(query, k=8):
    """Search chunks by embedding similarity using the `chunk_vec` vector index.

    Args:
        query: Free-text query; it is embedded with the global `embedding_model`.
        k: Number of nearest chunks to request from the index.

    Returns:
        List of dicts with keys 'id', 'ex', 'file', 'text', 'score', highest score first.
        Falls back to `bm25_chunks` when no embedding model is loaded; returns [] for k < 1.
    """
    if k < 1:
        # Nothing requested; avoid sending an invalid neighbour count to the index.
        return []
    if embedding_model is None:
        print("Vector search not available - falling back to BM25")
        return bm25_chunks(query, k)

    from neo4j import Query  # Query objects carry a server-side transaction timeout

    # Generate query embedding
    query_vector = embedding_model.encode(query, show_progress_bar=False).tolist()

    # Session.run() treats extra keyword arguments as Cypher parameters, so passing
    # `timeout=15` there is silently ignored; attach the timeout to the Query instead.
    cypher = Query("""
            CALL db.index.vector.queryNodes('chunk_vec', $k, $embedding)
            YIELD node AS c, score
            RETURN c.id AS id, c.exhibit_id AS ex, c.source_file AS file,
                   c.text AS text, score
            ORDER BY score DESC
        """, timeout=15)
    with driver.session(database="dpppoc") as s:
        return s.run(cypher, k=k, embedding=query_vector).data()

def hybrid_chunks(query, k=8, vector_weight=0.7):
    """Fuse BM25 and vector rankings into a single weighted score.

    Both searches fetch 2*k candidates. Within each method, scores are divided by that
    method's best score, so its top hit scores 1.0. BM25 contributions are weighted by
    (1 - vector_weight) and vector contributions by vector_weight. A chunk found by both
    methods receives the sum of its two contributions and keeps the BM25 row's fields.

    Returns the k highest fused rows. If no embedding model is loaded, this is plain BM25.
    """
    if embedding_model is None:
        return bm25_chunks(query, k)

    pool_size = k * 2
    keyword_hits = bm25_chunks(query, pool_size)
    semantic_hits = vector_chunks(query, pool_size)

    fused = {}

    # BM25 contributions seed the table.
    if keyword_hits:
        best_keyword = max(hit['score'] for hit in keyword_hits)
        for hit in keyword_hits:
            scaled = hit['score'] / best_keyword if best_keyword > 0 else 0
            fused[hit['id']] = {**hit, 'score': scaled * (1 - vector_weight)}

    # Vector contributions are added to existing entries, or create new ones.
    if semantic_hits:
        best_semantic = max(hit['score'] for hit in semantic_hits)
        for hit in semantic_hits:
            scaled = hit['score'] / best_semantic if best_semantic > 0 else 0
            contribution = scaled * vector_weight
            existing = fused.get(hit['id'])
            if existing is None:
                fused[hit['id']] = {**hit, 'score': contribution}
            else:
                existing['score'] += contribution

    ranked = sorted(fused.values(), key=lambda row: row['score'], reverse=True)
    return ranked[:k]

def dedupe_chunks(rows):
    """Drop repeated chunks, keeping the first row seen for each ``id`` (input order preserved)."""
    first_by_id = {}
    for row in rows:
        # setdefault only stores the row the first time its id appears.
        first_by_id.setdefault(row["id"], row)
    return list(first_by_id.values())

def rerank_by_exhibit(chunks, preferred=None, boost=0.12):
    """Boost scores for chunks from preferred exhibits and re-sort.

    Args:
        chunks: Chunk dicts with at least 'ex' and 'score' keys.
        preferred: An iterable of exhibit IDs (e.g. ['E7', 'E5']) or a single ID string.
            If falsy, `chunks` is returned unchanged.
        boost: Amount added to the score of each chunk whose 'ex' is preferred.
            NOTE(review): the boost is absolute, so its effect depends on the score scale
            of the retrieval method (raw BM25 scores vs. normalised hybrid scores).

    Returns:
        New list of shallow-copied chunk dicts sorted by boosted score, highest first.
        The input dicts are not mutated.
    """
    if not preferred:
        return chunks
    # A bare string such as "E7" would otherwise be split into characters by set().
    preferred_set = {preferred} if isinstance(preferred, str) else set(preferred)
    boosted = [
        {**c, "score": float(c["score"]) + (boost if c["ex"] in preferred_set else 0.0)}
        for c in chunks
    ]
    return sorted(boosted, key=lambda x: x["score"], reverse=True)

def retrieve_chunks(question, k=8, preferred_exhibits=None, method="bm25", candidate_pool=20):
    """Main retrieval function with multiple search methods.

    Over-fetches candidates, removes duplicate chunk IDs, applies the exhibit boost,
    then keeps the top k.

    Args:
        question: Query string
        k: Number of chunks to return
        preferred_exhibits: List of exhibit IDs to boost
        method: 'bm25', 'vector', or 'hybrid' (any other value falls back to 'bm25')
        candidate_pool: Minimum number of candidates fetched before re-ranking. The pool is
            never smaller than k; with a fixed pool of 20, any k > 20 was silently capped at 20.

    Returns:
        Up to k chunk dicts, highest score first.
    """
    n_candidates = max(k, candidate_pool)

    # Choose search method
    if method == "vector":
        rows = vector_chunks(question, k=n_candidates)
    elif method == "hybrid":
        rows = hybrid_chunks(question, k=n_candidates)
    else:  # default to bm25
        rows = bm25_chunks(question, k=n_candidates)

    # Post-process results
    rows = dedupe_chunks(rows)
    rows = rerank_by_exhibit(rows, preferred_exhibits)
    return rows[:k]

## Graph Context Expansion

Expands retrieved chunks with related graph entities for additional context.

In [12]:
def expand_neighborhood(chunk_ids, max_nodes=50):
    """Get graph context for retrieved chunks.

    Args:
        chunk_ids: Chunk IDs, typically in retrieval-rank order.
        max_nodes: Maximum number of mentioned nodes returned per chunk.

    Returns:
        One dict per chunk that has a FROM_EXHIBIT relationship, with keys 'id', 'exhibit'
        and 'nodes' (a list of {'labels', 'props'} maps for MENTIONS targets). Rows follow
        the order of `chunk_ids`, so callers that keep only the first few rows get the
        top-ranked chunks. Chunks with no mentions get an empty 'nodes' list.
    """
    if not chunk_ids:
        return []
    with driver.session(database="dpppoc") as s:
        return s.run("""
            UNWIND range(0, size($ids) - 1) AS i
            WITH i, $ids[i] AS cid
            MATCH (c:Chunk {id: cid})-[:FROM_EXHIBIT]->(e:Exhibit)
            OPTIONAL MATCH (c)-[:MENTIONS]->(n)
            // When OPTIONAL MATCH finds nothing, n is null; map it to null so collect()
            // drops it instead of collecting a {labels: null, props: null} placeholder.
            WITH i, c, e,
                 collect(DISTINCT CASE WHEN n IS NULL THEN NULL
                                       ELSE {labels: labels(n), props: n{.*}} END)[0..$m] AS nodes
            ORDER BY i
            RETURN c.id AS id, e.id AS exhibit, nodes
        """, ids=list(chunk_ids), m=max_nodes).data()

def format_evidence(chunks, max_chars=800, max_chunks=8):
    """Format evidence chunks for the LLM prompt.

    Args:
        chunks: Chunk dicts with 'ex', 'id' and 'text' keys.
        max_chars: Width passed to `textwrap.shorten`. Whitespace is collapsed, and longer
            text is cut on a word boundary and ends with the " [...]" placeholder.
        max_chunks: Maximum number of chunks included (default 8, the previous fixed cap).

    Returns:
        Blocks headed "[Exhibit <ex>, Chunk <id>]", separated by blank lines, or "(none)"
        when there is nothing to show.
    """
    blocks = [
        f"[Exhibit {ch['ex']}, Chunk {ch['id']}]\n"
        # A chunk with missing text would otherwise crash textwrap.shorten.
        + textwrap.shorten(ch.get("text") or "", max_chars)
        for ch in chunks[:max_chunks]
    ]
    return "\n\n".join(blocks) if blocks else "(none)"

def format_graph_facts(neighborhoods, max_items=3):
    """Format graph context for the LLM prompt as indented JSON.

    Args:
        neighborhoods: Rows as returned by expand_neighborhood.
        max_items: Only the first `max_items` rows are included.

    Returns:
        A JSON string. Values that JSON cannot encode natively (e.g. date/time property
        values coming back from the graph) are rendered with str() instead of raising.
    """
    return json.dumps(neighborhoods[:max_items], indent=2, default=str)

In [13]:
def answer_question(question: str, preferred_exhibits=None, k=8, method="bm25") -> Dict:
    """Main function for answering questions using hybrid graph-RAG.

    Args:
        question: The question to answer
        preferred_exhibits: List of exhibit IDs to boost (e.g., ['E7', 'E5'])
        k: Number of chunks to retrieve
        method: Retrieval method - 'bm25', 'vector', or 'hybrid'

    Returns:
        Dict with keys:
            'answer': synthesized answer text,
            'citations': sorted distinct exhibit IDs of the retrieved chunks
                (chunks without an exhibit ID are skipped),
            'method': the retrieval method used,
            'chunks_used': number of retrieved chunks,
            'graph_entities': total number of graph nodes attached to the expanded chunks.
    """
    # Retrieve relevant chunks using specified method
    chunks = retrieve_chunks(question, k=k, preferred_exhibits=preferred_exhibits, method=method)

    # Expand the top chunks (at most 8, the same cap format_evidence applies) with graph context
    neighborhoods = expand_neighborhood([c["id"] for c in chunks[:8]])

    # Generate grounded answer
    answer = synthesize_answer(question, chunks, neighborhoods)

    # Extract citations; a None exhibit ID cannot be sorted alongside strings, so skip it
    citations = sorted({c['ex'] for c in chunks if c.get('ex') is not None})

    return {
        "answer": answer,
        "citations": citations,
        "method": method,
        "chunks_used": len(chunks),
        "graph_entities": sum(len(n.get('nodes') or []) for n in neighborhoods)
    }

## LLM Synthesis with Citations

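Combines retrieved evidence and graph context to generate grounded answers.

Note: `answer_question` (defined above) calls `synthesize_answer`, which is defined in this section, so run this section's cell before calling `answer_question`.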

In [14]:
# System prompt for answer synthesis. The "{ex}"/"{id}" placeholders are literal text
# shown to the model as a citation template; they are not Python format fields.
SYSTEM_INSTRUCTIONS = (
    "You are assisting the DPP with charge brief QA.\n"
    "Answer ONLY using the EVIDENCE and GRAPH FACTS provided.\n"
    "Every factual sentence MUST end with a citation like [Exhibit {ex}, Chunk {id}].\n"
    "If evidence is insufficient, explicitly say what exhibits/witnesses to check next.\n"
    "Be concise, formal, and neutral; avoid speculation beyond the evidence."
)

# Setup OpenAI client. Despite its name, `OpenAI` holds a client *instance* (or None when
# no API key is set or the SDK fails to initialise); _call_openai reads this global.
OpenAI = None
if os.getenv("OPENAI_API_KEY"):
    try:
        from openai import OpenAI as _OpenAI
        OpenAI = _OpenAI()
    except Exception as e:
        print("OpenAI client not available:", e)
    else:
        print("✅ OpenAI client configured")

def _call_openai(messages, model="gpt-4o-mini", temperature=0.2):
    """Call OpenAI API with messages."""
    if OpenAI is None:
        raise RuntimeError("OpenAI client not configured")
    resp = OpenAI.chat.completions.create(
        model=model,
        temperature=temperature,
        messages=messages
    )
    return resp.choices[0].message.content.strip()

def synthesize_answer(question: str, chunks: List[Dict], neighborhoods: List[Dict],
                      model="gpt-4o-mini") -> str:
    """Ask the LLM for an answer grounded in the retrieved evidence and graph facts.

    The prompt is the system instructions plus a user message with EVIDENCE, GRAPH FACTS
    and QUESTION sections.

    On an LLM failure, the error is printed and a "(fallback)" string is returned. That
    string contains the first 140 characters of the top chunk with its citation, or a
    notice when there are no chunks.

    If the reply contains no "[Exhibit E<n>, Chunk <hex-id>]" citation, a citation for
    the top chunk is appended.
    """
    user_prompt = (
        f"EVIDENCE:\n{format_evidence(chunks)}\n\n"
        f"GRAPH FACTS:\n{format_graph_facts(neighborhoods)}\n\n"
        f"QUESTION:\n{question}"
    )
    conversation = [
        {"role": "system", "content": SYSTEM_INSTRUCTIONS},
        {"role": "user", "content": user_prompt},
    ]
    lead = chunks[0] if chunks else None

    try:
        reply = _call_openai(conversation, model=model)
    except Exception as err:
        print(f"LLM error: {err}")
        # Fallback using top chunk
        if not lead:
            return "(fallback) Insufficient LLM availability."
        return f"(fallback) {lead['text'][:140]}... [Exhibit {lead['ex']}, Chunk {lead['id']}]"

    # Ensure at least one citation
    citation_found = re.search(r"\[Exhibit\s+E\d+,\s*Chunk\s+[0-9a-f-]+\]", reply, re.I) is not None
    if not citation_found and chunks:
        reply += f" [Exhibit {lead['ex']}, Chunk {lead['id']}]"
    return reply

✅ OpenAI client configured


In [15]:
# Compare different retrieval methods on the same question.
# Each method runs the full pipeline via answer_question (retrieve -> graph expansion ->
# LLM synthesis), so this cell attempts one LLM call per method.
question = "What are some suspicious findings in the evidence?"

print(f"🔍 QUESTION: {question}\n")
print("=" * 80)

methods = ["bm25", "vector", "hybrid"]
for method in methods:
    print(f"\n📊 {method.upper()} METHOD:")
    print("-" * 40)
    
    try:
        result = answer_question(question, method=method)
        
        print(f"ANSWER: {result['answer']}")
        print(f"CITATIONS: {', '.join(result['citations'])}")
        print(f"STATS: {result['chunks_used']} chunks, {result['graph_entities']} graph entities")
    except Exception as e:
        # Report and continue so one failing method does not hide the other methods' results.
        print(f"ERROR: {e}")

🔍 QUESTION: What are some suspicious findings in the evidence?


📊 BM25 METHOD:
----------------------------------------
ANSWER: The evidence presents several suspicious findings:

1. The vehicle belonging to Dr. Valerie Somers, a blue Honda Civic, was found near Lantana Park, which raises questions about her disappearance and potential foul play [Exhibit E0, Chunk a8b83294-0944-4260-95cc-11f95ecf432f].

2. A partial latent fingerprint matching Patrick Phelan was found on the door frame of the recovered vehicle, indicating his presence at the scene [Exhibit E4, Chunk edfa2ecf-8540-4840-8dbf-8944bcb0bf8a].

3. Trace fibres consistent with a scarf worn by the victim were discovered in the boot lining of the vehicle, suggesting a connection between the victim and the accused [Exhibit E4, Chunk edfa2ecf-8540-4840-8dbf-8944bcb0bf8a].

4. Patrick Phelan initially denied contact with Dr. Somers on the day of her disappearance but later admitted to meeting her to "settle differences," which ra