In [2]:

import os
import json
from typing import List, Dict, Tuple, Any

import networkx as nx
from dotenv import load_dotenv
from pymongo import MongoClient

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.vectorstores import MongoDBAtlasVectorSearch
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain_core.prompts import PromptTemplate


import utils  # your utility module

In [3]:
load_dotenv(override=True)
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MONGODB_URI = os.environ.get("MONGODB_URI")


if not OPENAI_API_KEY:
    raise RuntimeError("Missing OPENAI_API_KEY in environment.")
if not MONGODB_URI:
    raise RuntimeError("Missing MONGODB_URI in environment.")


mongo_client = MongoClient(MONGODB_URI)
DB_NAME = "RAG-evaluation"
COLL_NAME = "RAG-graph"
collection = mongo_client[DB_NAME][COLL_NAME]

In [4]:
text = utils.read_data()
print(f"Data loaded: {len(text)} characters")


splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = splitter.create_documents([text])
# inject stable ids for cross-referencing in graph
for i, d in enumerate(docs):
    d.metadata = d.metadata or {}
    d.metadata.update({"chunk_id": i, "source": "project_corpus"})
print(f"Number of chunks: {len(docs)}")

Data loaded: 45744 characters
Number of chunks: 101


In [5]:
text = utils.read_data()
print(f"Data loaded: {len(text)} characters")


splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = splitter.create_documents([text])
# inject stable ids for cross-referencing in graph
for i, d in enumerate(docs):
    d.metadata = d.metadata or {}
    d.metadata.update({"chunk_id": i, "source": "project_corpus"})
print(f"Number of chunks: {len(docs)}")

Data loaded: 45744 characters
Number of chunks: 101


In [6]:
embedding = OpenAIEmbeddings(model="text-embedding-3-small")


VECTOR_INDEX = "vector_index_graph"
vectorstore = MongoDBAtlasVectorSearch.from_documents(
    documents=docs,
    embedding=embedding,
    collection=collection,
    index_name=VECTOR_INDEX,
)
print("Vector store created / updated.")


retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 8})

  embedding = OpenAIEmbeddings(model="text-embedding-3-small")


Vector store created / updated.


In [7]:
TRIPLE_PROMPT = PromptTemplate.from_template(
"""
    Extract up to 8 salient knowledge triples from the text.
    Use concise entities; prefer proper nouns; avoid pronouns.
    Return ONLY JSON in the following schema (no prose):
{{"triples":[{{"subject":"...","relation":"...","object":"...","chunk_id":CHUNK_ID}}]}}


Text (chunk_id={chunk_id}):
{chunk_text}
"""
)

In [8]:
llm_extractor = OpenAI(temperature=0)

  llm_extractor = OpenAI(temperature=0)


In [13]:
def _message_to_str(msg: Any) -> str:
    """Robustly convert AIMessage/content to text for JSON parsing."""
    content = getattr(msg, "content", msg)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for p in content:
            if isinstance(p, dict) and p.get("type") == "text":
                parts.append(p.get("text", ""))
            elif hasattr(p, "text"):
                parts.append(getattr(p, "text"))
        return "\n".join(parts)
    return str(content)

In [14]:
def extract_triples_for_chunk(chunk_text: str, chunk_id: int) -> List[Dict[str, Any]]:
    prompt = TRIPLE_PROMPT.format(chunk_text=chunk_text, chunk_id=chunk_id)
    raw = llm_extractor.invoke(prompt)
    text_out = _message_to_str(raw).strip()
    try:
        data = json.loads(text_out)
        triples = data.get("triples", [])
        for t in triples:
            t.setdefault("chunk_id", chunk_id)
        return triples
    except Exception:
        # If the model didn't return clean JSON, ignore this chunk's triples
        return []


In [15]:
def build_graph_from_docs(documents) -> Tuple[nx.MultiDiGraph, Dict[str, set]]:
    G = nx.MultiDiGraph()
    entity_to_chunks: Dict[str, set] = {}

    for d in documents:
        cid = d.metadata.get("chunk_id")
        triples = extract_triples_for_chunk(d.page_content, cid)
        for t in triples:
            s = t.get("subject", "").strip()
            r = t.get("relation", "").strip()
            o = t.get("object", "").strip()
            if not (s and r and o):
                continue
            G.add_node(s)
            G.add_node(o)
            G.add_edge(s, o, relation=r, chunk_id=cid)

            entity_to_chunks.setdefault(s, set()).add(cid)
            entity_to_chunks.setdefault(o, set()).add(cid)

    return G, entity_to_chunks




In [16]:
G, entity_index = build_graph_from_docs(docs)
print(f"Graph built: |V|={G.number_of_nodes()} |E|={G.number_of_edges()}")


KeyboardInterrupt: 

In [17]:
def persist_graph(G: nx.MultiDiGraph):
    nodes_coll = mongo_client[DB_NAME]["graph_nodes"]
    edges_coll = mongo_client[DB_NAME]["graph_edges"]
    nodes_coll.delete_many({"graph": COLL_NAME})
    edges_coll.delete_many({"graph": COLL_NAME})

    nodes_payload = [{"graph": COLL_NAME, "node": n, "degree": int(G.degree(n))} for n in G.nodes()]
    edges_payload = [
        {"graph": COLL_NAME, "u": u, "v": v, "relation": data.get("relation"), "chunk_id": data.get("chunk_id")}
        for u, v, data in G.edges(data=True)
    ]
    if nodes_payload:
        nodes_coll.insert_many(nodes_payload)
    if edges_payload:
        edges_coll.insert_many(edges_payload)
    print("Graph persisted to Mongo (graph_nodes, graph_edges).")


persist_graph(G)

Graph persisted to Mongo (graph_nodes, graph_edges).


In [19]:
import matplotlib.pyplot as plt


In [20]:
def draw_graph(G: nx.MultiDiGraph, out_path: str = "graph.png"):
    plt.figure(figsize=(12, 10))
    pos = nx.spring_layout(G, k=0.6, seed=42)
    degrees = dict(G.degree())
    node_sizes = [200 + 30 * degrees[n] for n in G.nodes()]

    nx.draw_networkx_nodes(G, pos, node_size=node_sizes, alpha=0.8)
    nx.draw_networkx_edges(G, pos, alpha=0.3)
    top_nodes = sorted(degrees, key=degrees.get, reverse=True)[:20]
    nx.draw_networkx_labels(G, pos, labels={n: n for n in top_nodes}, font_size=9)

    edge_labels = {}
    for u, v, data in G.edges(data=True):
        if u in top_nodes and v in top_nodes:
            edge_labels[(u, v)] = data.get("relation", "")
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)

    plt.axis("off")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()
    print(f"Graph image saved to {out_path}")


draw_graph(G)

Graph image saved to graph.png


In [21]:
ENTITY_PROMPT_TMPL = PromptTemplate.from_template(
    """
    List the 3-7 most relevant entities (proper nouns or concise concepts) in the query.
    Return ONLY a JSON list of strings, e.g. ["Entity A","Entity B"].

    Query: {query}
    """
)

In [22]:
def extract_query_entities(query: str) -> List[str]:
    raw = llm_extractor.invoke(ENTITY_PROMPT_TMPL.format(query=query))
    text_out = _message_to_str(raw).strip()
    try:
        ents = json.loads(text_out)
        if isinstance(ents, list):
            return [str(e) for e in ents][:10]
    except Exception:
        pass
    return []

In [23]:
def graph_expand_chunks(entities: List[str], hops: int = 1) -> List[int]:
    chunk_ids: set = set()
    frontier = set(entities)
    visited = set()

    for _ in range(max(1, hops)):
        next_frontier = set()
        for e in frontier:
            if e in visited:
                continue
            visited.add(e)
            for cid in entity_index.get(e, set()):
                chunk_ids.add(cid)
            if e in G:
                for _, nbr, data in G.out_edges(e, data=True):
                    next_frontier.add(nbr)
                for nbr, _, data in G.in_edges(e, data=True):
                    next_frontier.add(nbr)
        frontier = next_frontier

    return sorted(chunk_ids)

In [24]:
def retrieve_graph_augmented(query: str, k_vector: int = 5, hops: int = 1) -> List[Dict[str, Any]]:
    vec_docs = retriever.get_relevant_documents(query)
    vec_hits = [(d.metadata.get("chunk_id"), d.page_content) for d in vec_docs]

    ents = extract_query_entities(query)
    graph_chunk_ids = graph_expand_chunks(ents, hops=hops)
    graph_docs = [d for d in docs if d.metadata.get("chunk_id") in graph_chunk_ids]

    by_id: Dict[int, Dict[str, Any]] = {}
    for cid, content in vec_hits:
        if cid is None:
            continue
        by_id[cid] = {"chunk_id": cid, "content": content, "source": "vector"}
    for d in graph_docs:
        cid = d.metadata.get("chunk_id")
        if cid not in by_id:
            by_id[cid] = {"chunk_id": cid, "content": d.page_content, "source": "graph"}

    ordered = [by_id[cid] for cid, _ in vec_hits] + [v for cid, v in by_id.items() if v["source"] == "graph" and cid not in [id for id, _ in vec_hits]]
    return ordered


In [25]:
ANSWER_PROMPT = PromptTemplate.from_template(
    """
    You are answering based on the provided context passages from a vector+graph retriever.
    - Cite entities and relationships when relevant.
    - If unsure, say you are unsure.

    Question: {question}

    Context:\n{context}

    Helpful, concise answer:
    """
)


In [27]:
from langchain.chat_models import ChatOpenAI

qa_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

  qa_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)


In [28]:
def answer_with_graphrag(question: str, k_vector: int = 5, hops: int = 1) -> str:
    ctx_items = retrieve_graph_augmented(question, k_vector=k_vector, hops=hops)
    ctx_lines = []
    for item in ctx_items[:12]:
        ctx_lines.append(f"[chunk {item['chunk_id']} | {item['source']}] {item['content']}")
    context = "\n".join(ctx_lines)

    prompt = ANSWER_PROMPT.format(question=question, context=context)
    resp = qa_llm.invoke(prompt)
    return _message_to_str(resp)



In [29]:
if __name__ == "__main__":
    q = "how to delay skin aging"
    print("\n>>> GraphRAG answer:\n")
    print(answer_with_graphrag(q, k_vector=5, hops=1))



>>> GraphRAG answer:



  vec_docs = retriever.get_relevant_documents(query)


To delay skin aging, consider the following strategies:

1. **Sun Protection**: Use broad-spectrum sunscreen with an SPF of at least 30 daily to protect against UV damage, which can accelerate skin aging.

2. **Moisturization**: Keep the skin hydrated with moisturizers that contain hyaluronic acid, glycerin, or ceramides to maintain skin elasticity and prevent dryness.

3. **Healthy Diet**: Consume a balanced diet rich in antioxidants (found in fruits and vegetables), omega-3 fatty acids (found in fish and nuts), and vitamins (especially vitamins C and E) to support skin health.

4. **Hydration**: Drink plenty of water to keep the skin hydrated from the inside out.

5. **Avoid Smoking and Excessive Alcohol**: Both can contribute to premature skin aging by reducing blood flow and depleting essential nutrients.

6. **Regular Exercise**: Physical activity improves circulation and can promote a healthier complexion.

7. **Skincare Products**: Use products containing retinoids, peptides, an