References: 
1. https://machinelearningmastery.com/building-graph-rag-system-step-by-step-approach/
2. https://learnopencv.com/graphrag-explained-knowledge-graphs-medical/

---

## Comparing to Regular Vector Database Retrieval

<img src="./media/basic_retrieval.png" width=600>
 
For comparison, let's first look back at traditional RAG, which chunks documents, embeds the chunks, and retrieves them by similarity.

<img src="./media/RAG_QA.png" width=600 style="background-color: white;">
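The retrieval step of this traditional pipeline can be sketched in a few lines. This is a minimal illustration, not a production setup: the `embed` function below is a toy bag-of-words counter standing in for a learned embedding model, and all function names are made up for this example.

```python
import math
import re
from collections import Counter

def embed(text: str) -> Counter:
    """Toy 'embedding': lowercase word counts (a stand-in for a learned vector)."""
    return Counter(re.findall(r"[a-z]+", text.lower()))

def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[w] * b[w] for w in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def retrieve(query: str, chunks: list[str], top_k: int = 2) -> list[str]:
    """Return the top_k chunks most similar to the query."""
    q = embed(query)
    ranked = sorted(chunks, key=lambda c: cosine(q, embed(c)), reverse=True)
    return ranked[:top_k]

chunks = [
    "Lions often live in groups called prides.",
    "Hyenas live in clans and frequently compete with lions.",
    "Zebras and wildebeests often herd together for protection.",
]
top = retrieve("How do lions and hyenas interact?", chunks)
print(top)
```

Each chunk is scored against the query on its own, so nothing in this step records how one chunk relates to another.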

# Understanding Graph RAG: Beyond Traditional RAG Systems

## The Limitations of LLMs and Traditional RAG

LLMs rely on static knowledge: they can only draw on the data they were trained on. This makes them prone to hallucinations, that is, generating incorrect or fabricated information. RAG systems were developed to address this. Unlike a standalone LLM, a RAG system retrieves data from external knowledge bases at query time and uses this fresh context to generate more accurate and relevant responses.

Traditional RAG systems use text embeddings to retrieve specific information. While powerful, they come with limitations. If you've worked on RAG-related projects, you'll probably relate to this: the quality of the system's response depends heavily on the clarity and specificity of the query. An even bigger challenge is the inability to reason effectively across multiple documents.

## The Challenge of Multi-Document Reasoning

Imagine asking:

**"Which animals live together, and how do lions and hyenas interact?"** 🦁

### Traditional RAG:

The system might retrieve individual facts independently:

- **Document 1:** "Lions often live in groups called prides."
- **Document 2:** "Hyenas live in clans and frequently compete with lions."
- **Document 3:** "Zebras and wildebeests often herd together for protection."

<img src="./media/zebraGR.png" width=600 style="background-color: white;">

Because traditional RAG sees these documents separately, its response might be fragmented like:

> "Lions live in prides, and hyenas live in clans. Zebras and wildebeests herd together. Lions and hyenas compete."

This response gives basic facts but doesn't clearly explain how these animals interact.

---

### Graph RAG:

Graph RAG organizes information as entities and the relationships that connect them:

- **Nodes (Entities and their facts):**
  - Lions: live in prides.
  - Hyenas: live in clans.
  - Zebras and wildebeests: herd together.

- **Edges (Relationships):**
  - Lions ↔ compete with ↔ Hyenas.
  - Lions → hunt → Zebras and Wildebeests.
  - Hyenas → scavenge from → Lions.

<img src="./media/zebraRAG.png" width=600 style="background-color: white;">

Using these connected facts, Graph RAG can deliver a richer answer:

> "Lions live together in prides, while hyenas live in clans. Both animals often compete directly, as lions hunt animals like zebras and wildebeests, and hyenas frequently scavenge food left by lions. Zebras and wildebeests typically herd together to protect themselves from predators like lions."

---

By explicitly connecting related facts, Graph RAG provides a deeper, more coherent answer than traditional RAG. 🌿🐯
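To make the node-and-edge idea concrete, here is a small sketch that stores the example as plain Python data and gathers everything connected to the entities in a question. The data structures and the `context_for` helper are illustrative assumptions, not part of any library.

```python
from collections import defaultdict

# (source, relation, target) triples taken from the example above
edges = [
    ("Lions", "compete with", "Hyenas"),
    ("Lions", "hunt", "Zebras and Wildebeests"),
    ("Hyenas", "scavenge from", "Lions"),
]
# Facts attached to each node
facts = {
    "Lions": "Lions live in prides.",
    "Hyenas": "Hyenas live in clans.",
    "Zebras and Wildebeests": "Zebras and wildebeests herd together.",
}

# Index every edge under both endpoints so it can be found from either side
by_node = defaultdict(list)
for src, rel, dst in edges:
    by_node[src].append((src, rel, dst))
    by_node[dst].append((src, rel, dst))

def context_for(entities):
    """Collect node facts plus every edge touching the given entities."""
    lines = [facts[e] for e in entities]
    seen = set()
    for e in entities:
        for triple in by_node[e]:
            if triple not in seen:
                seen.add(triple)
                lines.append(" ".join(triple) + ".")
    return lines

context = context_for(["Lions", "Hyenas"])
print("\n".join(context))
```

Asking about lions and hyenas also surfaces the hunting edge to zebras and wildebeests, which is the kind of connected context a richer answer can be built from.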

### Graph RAG: Step-by-Step with Animals 🐾

#### Step 1: Source Documents → Text Chunks
Imagine we have a book about the jungle. It’s too big for a lion 🦁 to read at once, so we break it into smaller chunks:

- "Lions live in prides and hunt in groups."
- "Hyenas scavenge and sometimes challenge lions."
- "Zebras and wildebeests often herd together."

#### Step 2: Text Chunks → Element Instances
From each chunk, the system finds the important animals (nodes) and how they interact (edges):

<img src="./media/d1.png" width=300 style="background-color: white;">

#### Step 3: Element Instances → Element Summaries
Now we condense each entity and each relationship into a short summary:
- **Lion:** "A big cat that hunts in groups."
- **Hyena:** "A scavenger that competes with lions."
- **Lion → Zebra:** "Lions hunt zebras and wildebeests."
- **Hyena → Lion:** "Hyenas often challenge lions for food."

#### Step 4: Element Summaries → Graph Communities
We group related animals together:

<img src="./media/d2.png" width=300 style="background-color: white;">

#### Step 5: Graph Communities → Community Summaries
- **Predator Community:** "Lions and hyenas often compete. Lions hunt; hyenas scavenge."
- **Prey Community:** "Zebras and wildebeests herd for safety."

#### Step 6: Community Summaries → Community Answers → Global Answer
If someone asks: **"How do animals survive in the jungle?"**

The system gathers:
- "Lions hunt in groups."
- "Hyenas scavenge and sometimes steal food."
- "Zebras and wildebeests stick together for protection."

And replies:

> "In the jungle, animals survive by forming groups. Lions hunt together, hyenas scavenge and compete, while zebras and wildebeests herd for protection."
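Step 6 follows a map-reduce pattern: each community summary produces a partial answer, and the partial answers are then combined. The sketch below shows only that control flow. Both functions are hypothetical stand-ins for LLM calls: the map step simply keeps sentences that share a word with the query, and the reduce step joins the non-empty results.

```python
import re

community_summaries = {
    "predators": "Lions and hyenas often compete. Lions hunt; hyenas scavenge.",
    "prey": "Zebras and wildebeests herd for safety.",
}

def words(text: str) -> set:
    """Lowercase word set used for the toy relevance check."""
    return set(re.findall(r"[a-z]+", text.lower()))

def answer_from_summary(summary: str, query: str) -> str:
    """Map step (stand-in for an LLM call): keep sentences sharing a word with the query."""
    sentences = re.split(r"(?<=\.)\s+", summary)
    kept = [s for s in sentences if words(query) & words(s)]
    return " ".join(kept)

def aggregate(partials: list[str]) -> str:
    """Reduce step (stand-in for an LLM call): join the non-empty partial answers."""
    return " ".join(p for p in partials if p)

query = "How do zebras survive?"
partials = [answer_from_summary(s, query) for s in community_summaries.values()]
global_answer = aggregate(partials)
print(global_answer)
```

Communities with nothing relevant contribute an empty partial answer and drop out during aggregation.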

# Step-by-Step Implementation of GraphRAG with LlamaIndex

1. Data Preparation:
- Loading news articles as sample data
- Converting them to LlamaIndex Document objects
- Splitting documents into nodes using a SentenceSplitter

In [21]:
import pandas as pd
from llama_index.core import Document

# Load sample dataset
news = pd.read_csv("https://raw.githubusercontent.com/tomasonjo/blog-datasets/main/news_articles.csv")[:50]

# Convert data into LlamaIndex Document objects
documents = [
    Document(text=f"{row['title']}: {row['text']}")
    for _, row in news.iterrows()
]

In [22]:
from llama_index.core.node_parser import SentenceSplitter

splitter = SentenceSplitter(
    chunk_size=1024,
    chunk_overlap=20,
)
nodes = splitter.get_nodes_from_documents(documents)
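To see what `chunk_size` and `chunk_overlap` mean, here is a simplified word-based sketch of overlapping windows. It is not how `SentenceSplitter` is implemented (that class counts tokens and tries to keep sentences intact); the `chunk_words` helper is a hypothetical illustration of the overlap idea only.

```python
def chunk_words(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    """Split text into windows of chunk_size words that share chunk_overlap words."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    words = text.split()
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            break  # the last window already reaches the end of the text
    return chunks

text = "one two three four five six seven eight nine ten"
chunks = chunk_words(text, chunk_size=4, chunk_overlap=1)
print(chunks)
```

Each window repeats the last word of the previous one, so a sentence that straddles a boundary is less likely to lose its context.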

2. LLM Configuration:
- Setting up OpenAI's GPT-4 as the language model
- Configuring API keys using environment variables

In [23]:
from llama_index.llms.openai import OpenAI
from dotenv import load_dotenv
import os

# Load variables from a local .env file (if present) into the environment
load_dotenv()
openai_api_key = os.getenv('OPENAI_API_KEY')

llm = OpenAI(model="gpt-4", api_key=openai_api_key)

3. Entity and Relationship Extraction:
- Defining regex patterns to extract entities and relationships
- Creating a parsing function to process LLM responses

In [24]:
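import re

# Each field is expected on its own line. The final group is greedy so it
# captures the rest of the line; a lazy (.+?) with nothing after it would
# match only a single character.
entity_pattern = r'entity_name:\s*(.+?)\s*entity_type:\s*(.+?)\s*entity_description:\s*(.+)'
relationship_pattern = r'source_entity:\s*(.+?)\s*target_entity:\s*(.+?)\s*relation:\s*(.+?)\s*relationship_description:\s*(.+)'

def parse_fn(response_str: str):
    # entities: (name, type, description)
    # relationships: (source, target, relation, description)
    entities = re.findall(entity_pattern, response_str)
    relationships = re.findall(relationship_pattern, response_str)
    return entities, relationships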
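A quick way to sanity-check this kind of parser is to run it on a hand-written response. The sketch below uses patterns of the same shape, with the last group written as `(.+)` so that it captures the rest of the line. The sample response and the `parse_response` name are made up for illustration; real LLM output may be formatted differently.

```python
import re

entity_pattern = (
    r"entity_name:\s*(.+?)\s*entity_type:\s*(.+?)\s*entity_description:\s*(.+)"
)
relationship_pattern = (
    r"source_entity:\s*(.+?)\s*target_entity:\s*(.+?)\s*"
    r"relation:\s*(.+?)\s*relationship_description:\s*(.+)"
)

def parse_response(response_str: str):
    """Return (entities, relationships) as lists of tuples of captured fields."""
    entities = re.findall(entity_pattern, response_str)
    relationships = re.findall(relationship_pattern, response_str)
    return entities, relationships

# Hypothetical LLM output with one field per line
sample_response = """\
entity_name: LION
entity_type: Animal
entity_description: A big cat that hunts in groups.
entity_name: ZEBRA
entity_type: Animal
entity_description: A striped grazer that lives in herds.
source_entity: LION
target_entity: ZEBRA
relation: hunts
relationship_description: Lions hunt zebras for food.
"""

entities, relationships = parse_response(sample_response)
print(entities)
print(relationships)
```

Relationship tuples come back in the order (source, target, relation, description), so any code that unpacks them has to use the same order.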

4. Graph RAG Implementation:
- Creating a GraphRAGExtractor class that:
  - Extracts triples from text using an LLM
  - Parses the output to identify entities and relationships
  - Processes documents in parallel for efficiency
- Implementing a GraphRAGStore class that:
  - Builds communities from the graph using hierarchical clustering
  - Generates summaries for each community
  - Provides methods to access the graph structure

In [25]:
import asyncio
import nest_asyncio

nest_asyncio.apply()

from typing import Any, List, Callable, Optional, Union, Dict
from IPython.display import Markdown, display

from llama_index.core.async_utils import run_jobs
from llama_index.core.indices.property_graph.utils import (
    default_parse_triplets_fn,
)
from llama_index.core.graph_stores.types import (
    EntityNode,
    KG_NODES_KEY,
    KG_RELATIONS_KEY,
    Relation,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.default_prompts import (
    DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
)
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.bridge.pydantic import BaseModel, Field


class GraphRAGExtractor(TransformComponent):
    """Extract triples from text.

    Uses an LLM and a simple prompt + output parsing to extract paths (i.e. triples) and entity, relation descriptions from text.

    Args:
        llm (LLM):
            The language model to use.
        extract_prompt (Union[str, PromptTemplate]):
            The prompt to use for extracting triples.
        parse_fn (callable):
            A function to parse the output of the language model.
        num_workers (int):
            The number of workers to use for parallel processing.
        max_paths_per_chunk (int):
            The maximum number of paths to extract per chunk.
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 4,
    ) -> None:
        """Init params."""
        from llama_index.core import Settings

        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes."""
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract triples from a node."""
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            entities = []
            entities_relationship = []

        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        for entity, entity_type, description in entities:
            # Copy per entity so each node keeps its own description.
            metadata = node.metadata.copy()
            # Not used in the current implementation. But will be useful in future work.
            metadata["entity_description"] = description
            entity_node = EntityNode(
                name=entity, label=entity_type, properties=metadata
            )
            existing_nodes.append(entity_node)

        for triple in entities_relationship:
            # parse_fn returns (source_entity, target_entity, relation, description),
            # following the group order in relationship_pattern.
            subj, obj, rel, description = triple
            # Copy per relation so descriptions are not shared between edges.
            metadata = node.metadata.copy()
            subj_node = EntityNode(name=subj, properties=metadata)
            obj_node = EntityNode(name=obj, properties=metadata)
            metadata["relationship_description"] = description
            rel_node = Relation(
                label=rel,
                source_id=subj_node.id,
                target_id=obj_node.id,
                properties=metadata,
            )

            existing_nodes.extend([subj_node, obj_node])
            existing_relations.append(rel_node)

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes async."""
        jobs = []
        for node in nodes:
            jobs.append(self._aextract(node))

        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )
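The extractor hands its per-node coroutines to `run_jobs` with a worker limit. The general idea of bounded concurrency can be sketched with a semaphore, as below. This is not LlamaIndex's implementation of `run_jobs`; `run_bounded` and `fake_extract` are hypothetical names, and the sleep stands in for an LLM call.

```python
import asyncio

async def run_bounded(coros, workers: int):
    """Await all coroutines, letting at most `workers` run at the same time."""
    semaphore = asyncio.Semaphore(workers)

    async def guarded(coro):
        async with semaphore:
            return await coro

    # gather returns results in the same order as the inputs
    return await asyncio.gather(*(guarded(c) for c in coros))

# Track how many fake calls are in flight at once
stats = {"active": 0, "peak": 0}

async def fake_extract(chunk: str) -> str:
    """Hypothetical stand-in for one LLM extraction call."""
    stats["active"] += 1
    stats["peak"] = max(stats["peak"], stats["active"])
    await asyncio.sleep(0.01)  # simulate network latency
    stats["active"] -= 1
    return chunk.upper()

chunks = ["lions hunt", "hyenas scavenge", "zebras herd", "wildebeests graze", "prides rest"]
results = asyncio.run(run_bounded([fake_extract(c) for c in chunks], workers=2))
print(results, "peak concurrency:", stats["peak"])
```

Inside a notebook, where an event loop is already running, `asyncio.run` needs `nest_asyncio.apply()` first, which is why the cell above applies it. Keeping the worker count low also helps when the LLM provider enforces rate limits.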

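5. Query Engine:
- Creating a GraphRAGQueryEngine that:
   - Processes queries against the graph structure
   - Generates answers from community summaries
   - Aggregates responses for comprehensive answers

The next cells first define the extraction prompt and the GraphRAGStore from step 4, then build the index and the query engine.
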
In [26]:
KG_TRIPLET_EXTRACT_TMPL = """
-Goal-
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: Type of the entity
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity as ("entity")

2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relation: relationship between source_entity and target_entity
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other

Format each relationship as ("relationship")

3. When finished, output.

-Real Data-
######################
text: {text}
######################
output:"""

In [8]:
kg_extractor = GraphRAGExtractor(
    llm=llm,
    extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
    max_paths_per_chunk=2,
    parse_fn=parse_fn,
)

In [10]:
import re
from llama_index.core.graph_stores import SimplePropertyGraphStore
import networkx as nx
from graspologic.partition import hierarchical_leiden

from llama_index.core.llms import ChatMessage
class GraphRAGStore(SimplePropertyGraphStore):
    community_summary = {}
    max_cluster_size = 5

    def generate_community_summary(self, text):
        """Generate summary for a given text using an LLM."""
        messages = [
            ChatMessage(
                role="system",
                content=(
                    "You are provided with a set of relationships from a knowledge graph, each represented as "
                    "entity1->entity2->relation->relationship_description. Your task is to create a summary of these "
                    "relationships. The summary should include the names of the entities involved and a concise synthesis "
                    "of the relationship descriptions. The goal is to capture the most critical and relevant details that "
                    "highlight the nature and significance of each relationship. Ensure that the summary is coherent and "
                    "integrates the information in a way that emphasizes the key aspects of the relationships."
                ),
            ),
            ChatMessage(role="user", content=text),
        ]
        response = OpenAI().chat(messages)
        clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
        return clean_response

    def build_communities(self):
        """Builds communities from the graph and summarizes them."""
        nx_graph = self._create_nx_graph()
        community_hierarchical_clusters = hierarchical_leiden(
            nx_graph, max_cluster_size=self.max_cluster_size
        )
        community_info = self._collect_community_info(
            nx_graph, community_hierarchical_clusters
        )
        self._summarize_communities(community_info)

    def _create_nx_graph(self):
        """Converts internal graph representation to NetworkX graph."""
        nx_graph = nx.Graph()
        for node in self.graph.nodes.values():
            nx_graph.add_node(str(node))
        for relation in self.graph.relations.values():
            nx_graph.add_edge(
                relation.source_id,
                relation.target_id,
                relationship=relation.label,
                description=relation.properties["relationship_description"],
            )
        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """Collect detailed information for each node based on their community."""
        community_mapping = {item.node: item.cluster for item in clusters}
        community_info = {}
        for item in clusters:
            cluster_id = item.cluster
            node = item.node
            if cluster_id not in community_info:
                community_info[cluster_id] = []

            for neighbor in nx_graph.neighbors(node):
                if community_mapping[neighbor] == cluster_id:
                    edge_data = nx_graph.get_edge_data(node, neighbor)
                    if edge_data:
                        detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
                        community_info[cluster_id].append(detail)
        return community_info

    def _summarize_communities(self, community_info):
        """Generate and store summaries for each community."""
        for community_id, details in community_info.items():
            details_text = (
                "\n".join(details) + "."
            )  # Ensure it ends with a period
            self.community_summary[
                community_id
            ] = self.generate_community_summary(details_text)

    def get_community_summaries(self):
        """Returns the community summaries, building them if not already done."""
        if not self.community_summary:
            self.build_communities()
        return self.community_summary
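The grouping that `_collect_community_info` performs can be shown on a toy graph without NetworkX or graspologic. In this sketch the edge data and the `community_of` mapping are invented for illustration, with the mapping standing in for the cluster assignments that `hierarchical_leiden` would produce.

```python
# Hypothetical toy graph: undirected edges with relationship data
edges = {
    ("Lion", "Hyena"): {
        "relationship": "competes with",
        "description": "Both fight over kills.",
    },
    ("Lion", "Zebra"): {
        "relationship": "hunts",
        "description": "Lions prey on zebras.",
    },
    ("Zebra", "Wildebeest"): {
        "relationship": "herds with",
        "description": "They graze together for safety.",
    },
}

# Hypothetical cluster assignment, standing in for hierarchical_leiden output
community_of = {"Lion": 0, "Hyena": 0, "Zebra": 1, "Wildebeest": 1}

def collect_community_info(edges, community_of):
    """Group edge descriptions by community, skipping edges that cross communities."""
    info = {cluster: [] for cluster in set(community_of.values())}
    for (a, b), data in edges.items():
        if community_of[a] == community_of[b]:
            info[community_of[a]].append(
                f"{a} -> {b} -> {data['relationship']} -> {data['description']}"
            )
    return info

community_info = collect_community_info(edges, community_of)
for cluster, details in community_info.items():
    print(cluster, details)
```

The Lion-Zebra edge crosses communities and is left out of both lists. One difference from the class above: the sketch records each edge once, whereas the class walks every node's neighbours on an undirected graph, so an edge inside a community is visited from both endpoints.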


In [11]:
from llama_index.core import PropertyGraphIndex

index = PropertyGraphIndex(
    nodes=nodes,
    property_graph_store=GraphRAGStore(),
    kg_extractors=[kg_extractor],
    show_progress=True,
)

Extracting paths from text: progress output abbreviated. With a low tokens-per-minute quota (a limit of 10,000 TPM for gpt-4 in this run), the OpenAI API returned `RateLimitError` (HTTP 429) partway through the 50 chunks, and LlamaIndex retried the failed calls after a short delay.

In [13]:
index.property_graph_store.build_communities()

In [14]:
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.llms import LLM
class GraphRAGQueryEngine(CustomQueryEngine):
    graph_store: GraphRAGStore
    llm: LLM

    def custom_query(self, query_str: str) -> str:
        """Process all community summaries to generate answers to a specific query."""
        community_summaries = self.graph_store.get_community_summaries()
        community_answers = [
            self.generate_answer_from_summary(community_summary, query_str)
            for _, community_summary in community_summaries.items()
        ]

        final_answer = self.aggregate_answers(community_answers)
        return final_answer

    def generate_answer_from_summary(self, community_summary, query):
        """Generate an answer from a community summary based on a given query using LLM."""
        prompt = (
            f"Given the community summary: {community_summary}, "
            f"how would you answer the following query? Query: {query}"
        )
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content="I need an answer based on the above information.",
            ),
        ]
        response = self.llm.chat(messages)
        cleaned_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
        return cleaned_response

    def aggregate_answers(self, community_answers):
        """Aggregate individual community answers into a final, coherent response."""
        # intermediate_text = " ".join(community_answers)
        prompt = "Combine the following intermediate answers into a final, concise response."
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content=f"Intermediate answers: {community_answers}",
            ),
        ]
        final_response = self.llm.chat(messages)
        cleaned_final_response = re.sub(
            r"^assistant:\s*", "", str(final_response)
        ).strip()
        return cleaned_final_response

In [15]:
query_engine = GraphRAGQueryEngine(
    graph_store=index.property_graph_store, llm=llm
)
response = query_engine.query("What are news related to financial sector?")
display(Markdown(f"{response.response}"))

The only news related to the financial sector is that Nirmal Bang has given a Buy Rating to Tata Chemicals Ltd. (TTCH), indicating a positive investment recommendation. The rest of the provided information does not contain any news related to the financial sector.

---
## Discussion

**Traditional/Naive RAG:**

Benefits:
- Simpler implementation and deployment
- Works well for straightforward information retrieval tasks
- Good at handling unstructured text data
- Lower computational overhead

Drawbacks:
- Loses structural information when chunking documents
- Can break up related content during text segmentation
- Limited ability to capture relationships between different pieces of information
- May struggle with complex reasoning tasks requiring connecting multiple facts
- Potential for incomplete or fragmented answers due to chunking boundaries

**GraphRAG:**

Benefits:
- Preserves structural relationships and hierarchies in the knowledge
- Better at capturing connections between related information
- Can provide more complete and contextual answers
- Improved retrieval accuracy by leveraging graph structure
- Better supports complex reasoning across multiple facts
- Can maintain document coherence better than chunk-based approaches
- More interpretable due to explicit knowledge representation

Drawbacks:
- More complex to implement and maintain
- Requires additional processing to construct and update knowledge graphs
- Higher computational overhead for graph operations
- May require domain expertise to define graph schema/structure
- More challenging to scale to very large datasets
- Additional storage requirements for graph structure

**Key Differentiators:**
1. Knowledge Representation: Traditional RAG treats everything as flat text chunks, while GraphRAG maintains structured relationships in a graph format

2. Context Preservation: GraphRAG better preserves context and relationships between different pieces of information compared to the chunking approach of traditional RAG

3. Reasoning Capability: GraphRAG enables better multi-hop reasoning and connection of related facts through graph traversal, while traditional RAG is more limited to direct retrieval

4. Answer Quality: GraphRAG tends to produce more complete and coherent answers since it can access related information through graph connections rather than being limited by chunk boundaries

The choice between traditional RAG and GraphRAG often depends on the specific use case, with GraphRAG being particularly valuable when maintaining relationships between information is important or when complex reasoning is required. It is also worth noting that GraphRAG approaches still rely on regular embedding and retrieval methods themselves. The two complement each other!