In [1]:
!pip install torch transformers faiss-cpu networkx numpy


Collecting faiss-cpu
  Downloading faiss_cpu-1.8.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.7 kB)
Downloading faiss_cpu-1.8.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.0 MB)
Installing collected packages: faiss-cpu
Successfully installed faiss-cpu-1.8.0.post1


# What is GraphRAG?

**Graph Retrieval-Augmented Generation (GraphRAG)** combines graph-based retrieval with retrieval-augmented generation (RAG). Information is stored as a graph, relevant nodes are retrieved by following the graph's structure, and a generative model then produces a response conditioned on the retrieved content.

Here's an overview of **GraphRAG**, followed by an example implementation.

**Overview of GraphRAG**

**Graph Representation**

**Nodes** Represent entities, documents, or concepts.

**Edges** Represent relationships or similarities between nodes.

**Graph Construction** Construct a graph where nodes are connected based on semantic similarity or other criteria.
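
As a minimal sketch of the representation described above, the graph can be held in plain dictionaries: one mapping node ids to text and one adjacency mapping that stores edge weights. The three-dimensional vectors, the `build_graph` helper, and the 0.8 threshold are made up for illustration; a real pipeline would obtain the vectors from an embedding model.

```python
import math
from itertools import combinations

def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def build_graph(texts, vectors, threshold):
    """Return (nodes, adjacency) with an edge wherever similarity > threshold."""
    nodes = {i: text for i, text in enumerate(texts)}
    adjacency = {i: {} for i in nodes}
    for i, j in combinations(nodes, 2):  # each unordered pair once
        weight = cosine(vectors[i], vectors[j])
        if weight > threshold:
            adjacency[i][j] = weight  # undirected: store both directions
            adjacency[j][i] = weight
    return nodes, adjacency

# Toy data (hypothetical): two related documents and one unrelated one
texts = ["skin cancer overview", "skin cancer treatment", "tax filing deadlines"]
vectors = [[1.0, 0.2, 0.0], [0.9, 0.4, 0.0], [0.0, 0.1, 1.0]]
nodes, adjacency = build_graph(texts, vectors, threshold=0.8)
print(adjacency)
```

With these toy vectors only the two skin-cancer documents end up connected; the unrelated document stays isolated.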

**Graph-Based Retrieval**

**Graph Traversal** Retrieve relevant nodes/documents by traversing the graph based on a query.

**Node Embeddings** Use embeddings to represent nodes, enabling efficient similarity searches.
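
The two ideas above can be combined: pick seed nodes by embedding similarity, then traverse outward from them along the graph's edges. Below is a rough sketch on assumed toy data; `retrieve_with_expansion`, the adjacency dictionary, and the `n_seeds` and `max_hops` parameters are hypothetical names, not part of any library.

```python
import math
from collections import deque

def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def retrieve_with_expansion(query_vec, vectors, adjacency, n_seeds=1, max_hops=1):
    """Seed by similarity to the query, then breadth-first expand up to max_hops."""
    ranked = sorted(vectors, key=lambda n: cosine(query_vec, vectors[n]), reverse=True)
    visited = {}  # node -> hop distance from a seed, in visiting order
    queue = deque((seed, 0) for seed in ranked[:n_seeds])
    while queue:
        node, hops = queue.popleft()
        if node in visited:
            continue
        visited[node] = hops
        if hops < max_hops:
            for neighbor in adjacency.get(node, ()):
                if neighbor not in visited:
                    queue.append((neighbor, hops + 1))
    return list(visited)

# Toy embeddings and edges (illustrative values only)
vectors = {"a": [1.0, 0.0], "b": [0.8, 0.6], "c": [0.0, 1.0], "d": [-1.0, 0.0]}
adjacency = {"a": ["b"], "b": ["a", "c"], "c": ["b"], "d": []}
print(retrieve_with_expansion([1.0, 0.1], vectors, adjacency, max_hops=1))
```

Raising `max_hops` pulls in nodes that are not directly similar to the query but are connected to something that is, which is the main thing a graph adds over a flat similarity search.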

**Generation**

**Contextual Augmentation** Use the retrieved nodes or documents to augment the query.

**Text Generation** Generate responses using a generative model (e.g., GPT-2) based on the augmented query.
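
To make the augmentation step concrete, here is a small sketch of turning a query and retrieved texts into one prompt string. The `build_augmented_prompt` helper and its "Context / Question / Answer" template are assumptions chosen for illustration; any layout that keeps the retrieved text and the query clearly separated would serve.

```python
def build_augmented_prompt(query, retrieved_docs, max_docs=3):
    """Combine retrieved passages and the query into a single prompt string."""
    context_lines = [f"- {doc.strip()}" for doc in retrieved_docs[:max_docs]]
    context = "\n".join(context_lines) if context_lines else "(no context retrieved)"
    return f"Context:\n{context}\n\nQuestion: {query.strip()}\nAnswer:"

prompt = build_augmented_prompt(
    "Describe the treatment options for skin cancer.",
    ["discussing treatment options for skin cancer.", "explaining symptoms of skin cancer."],
)
print(prompt)
```

The resulting string is what would be tokenized and passed to the generative model.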

**Example Implementation**

Here’s a step-by-step example of implementing GraphRAG using Python:

# Create Graph

Use networkx to create a graph and add nodes and edges.

In [6]:
import networkx as nx
import numpy as np
import torch
from transformers import BertTokenizer, BertModel, GPT2LMHeadModel, GPT2Tokenizer
import faiss  # installed above; not used in this minimal example

# Initialize models: BERT for embeddings, GPT-2 for generation
tokenizer_bert = BertTokenizer.from_pretrained('bert-base-uncased')
model_bert = BertModel.from_pretrained('bert-base-uncased')
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create an undirected graph
G = nx.Graph()

# Sample documents; each becomes a node with its text stored as an attribute
documents = [
    "about skin cancer.",
    "discussing treatment options for skin cancer.",
    "explaining symptoms of skin cancer."
]

for i, doc in enumerate(documents):
    G.add_node(i, text=doc)

def get_node_embedding(node):
    """Embed a node's text with BERT and return a 1-D NumPy vector."""
    inputs = tokenizer_bert(G.nodes[node]['text'], return_tensors='pt', max_length=512, truncation=True)
    with torch.no_grad():
        # pooler_output is the [CLS] hidden state passed through BERT's pooling layer
        embedding = model_bert(**inputs).pooler_output
    return embedding.squeeze(0).numpy()

def compute_similarity(node1, node2):
    """Cosine similarity between the embeddings of two nodes."""
    embedding1 = get_node_embedding(node1)
    embedding2 = get_node_embedding(node2)
    return float(np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)))

# Add an edge between every pair of nodes whose similarity exceeds the threshold
for i in G.nodes:
    for j in G.nodes:
        if i < j:  # visit each unordered pair once
            similarity = compute_similarity(i, j)
            if similarity > 0.5:  # Threshold for adding an edge
                G.add_edge(i, j, weight=similarity)
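
The pairwise loop above calls `get_node_embedding` twice per pair, so each document is re-encoded for every pair it appears in. One possible alternative, sketched here with a stand-in `toy_embed` function in place of BERT, embeds every text once and reuses the vectors. `similarity_edges` and `toy_embed` are hypothetical helpers written for this illustration.

```python
import math
from itertools import combinations

def similarity_edges(texts, embed, threshold=0.5):
    """Embed each text exactly once, then compare all unordered pairs."""
    vectors = {node: embed(text) for node, text in texts.items()}
    norms = {node: math.sqrt(sum(x * x for x in vec)) for node, vec in vectors.items()}
    edges = []
    for i, j in combinations(vectors, 2):
        if norms[i] == 0 or norms[j] == 0:
            continue  # cosine similarity is undefined for a zero vector
        dot = sum(a * b for a, b in zip(vectors[i], vectors[j]))
        similarity = dot / (norms[i] * norms[j])
        if similarity > threshold:
            edges.append((i, j, similarity))
    return edges

calls = []  # records every embed call so the count can be inspected

def toy_embed(text):
    """Stand-in for a model call: keyword counts (illustrative only)."""
    calls.append(text)
    lowered = text.lower()
    return [float(lowered.count(k)) for k in ("cancer", "treatment", "tax")]

texts = {
    0: "about skin cancer.",
    1: "discussing treatment options for skin cancer.",
    2: "tax filing deadlines",
}
edges = similarity_edges(texts, toy_embed)
print(len(calls), edges)
```

The `(i, j, weight)` tuples are in the shape that networkx's `G.add_weighted_edges_from(edges)` accepts, so the same precomputation could feed the graph built above.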


# Graph-Based Retrieval 

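Retrieve documents by adding the query as a temporary node, linking it to sufficiently similar document nodes, and ranking its neighbors by edge weight.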

In [7]:
def retrieve_documents(query, top_k=2):
    # Add the query as a temporary node and link it to similar document nodes
    G.add_node('query', text=query)
    for i in list(G.nodes):
        if i != 'query':
            similarity = compute_similarity('query', i)
            if similarity > 0.5:
                G.add_edge('query', i, weight=similarity)

    # Rank the query node's neighbors by edge weight, highest first
    neighbors = sorted(G.neighbors('query'), key=lambda x: G['query'][x]['weight'], reverse=True)
    top_neighbors = neighbors[:top_k]

    # Collect the text of the top-k nodes, then remove the query node so that
    # edges from this query do not leak into later calls
    results = [G.nodes[n]['text'] for n in top_neighbors]
    G.remove_node('query')
    return results


# Generate Response
Augment the query with retrieved documents and generate a response.

In [8]:
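def generate_response(query, retrieved_docs):
    # Build the augmented prompt: the query followed by the retrieved texts
    augmented_query = query + " " + " ".join(retrieved_docs)
    # Truncate the prompt to 100 tokens so it always fits under max_length below
    inputs_gpt2 = tokenizer_gpt2(augmented_query, return_tensors='pt', max_length=100, truncation=True)
    # max_length counts prompt tokens plus generated tokens
    response = model_gpt2.generate(**inputs_gpt2, max_length=150, num_beams=5, early_stopping=True)
    # The decoded text starts with the prompt, followed by the continuation
    return tokenizer_gpt2.decode(response[0], skip_special_tokens=True)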

# Example usage
query = "Describe the treatment options for skin cancer."
retrieved_docs = retrieve_documents(query, top_k=2)
generated_response = generate_response(query, retrieved_docs)

print("Query:", query)
print("Retrieved Documents:")
for doc in retrieved_docs:
    print(f"- {doc}")
print("Generated Response:", generated_response)


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Query: Describe the treatment options for skin cancer.
Retrieved Documents:
- discussing treatment options for skin cancer.
- explaining symptoms of skin cancer.
Generated Response: Describe the treatment options for skin cancer. discussing treatment options for skin cancer. explaining symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. discussing symptoms of skin cancer. disc
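
The repeated phrase in the output above is typical of beam search with a small model and no repetition constraint. A possible variant, sketched below, passes `no_repeat_ngram_size` and `max_new_tokens` to `generate` and decodes only the newly generated tokens. The function name `generate_response_v2` and the specific values are assumptions, and the quality of the result will still be limited by the base GPT-2 model. The model and tokenizer are taken as arguments so the function can be called with `model_gpt2` and `tokenizer_gpt2` from the cells above.

```python
def generate_response_v2(model, tokenizer, query, retrieved_docs, max_new_tokens=60):
    """Generate a continuation and return only the newly generated text."""
    prompt = query + " " + " ".join(retrieved_docs)
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
    prompt_len = len(inputs["input_ids"][0])  # number of prompt tokens
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,        # budget for new tokens only
        num_beams=5,
        early_stopping=True,
        no_repeat_ngram_size=3,               # block any 3-gram from repeating
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS
    )
    new_tokens = output[0][prompt_len:]       # drop the echoed prompt
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```

Usage would look like `generate_response_v2(model_gpt2, tokenizer_gpt2, query, retrieved_docs)`. Setting `pad_token_id` explicitly also avoids the `Setting pad_token_id to eos_token_id` notice printed above.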