# 🧠 Historical GraphRAG Notebook
This notebook demonstrates how to use a lightweight knowledge graph with LangChain + Azure OpenAI to extract and reason over relationships between historical entities.

## 🔧 Environment Setup

In [None]:
%pip install --upgrade langchain langchain-core



In [None]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from dotenv import load_dotenv
from langchain.chat_models import AzureChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import CSVLoader
from langchain_core.output_parsers import StrOutputParser  # ✅ CORRECT

# Load variables from a .env file using python-dotenv's default lookup.
# The Azure OpenAI setup cell below loads the parent directory's .env again
# with override=True, so values from that file take precedence.
load_dotenv()

## 🤖 Azure OpenAI Setup

In [None]:
import os
from pathlib import Path
from dotenv import load_dotenv

# The .env file is expected one level above the notebook's working directory.
parent_dir = Path.cwd().parent
env_path = parent_dir / ".env"

# override=True lets values from the file replace variables already set in the process.
load_dotenv(dotenv_path=env_path, override=True)
if not env_path.is_file():
    # Not fatal: the variables may already be present in the process environment.
    print(f"⚠️ No .env file found at {env_path}; relying on variables already set in the environment.")

AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
AzureOpenAIEmbeddingsModel = os.getenv("AzureOpenAIEmbeddingsModel", "text-embedding-ada-002")
AzureChatOpenAIModel = os.getenv("AzureChatOpenAIModel")

# Report unset settings here rather than letting them surface later as an
# opaque failure when the chat model is constructed or first invoked.
_required_settings = {
    "AZURE_OPENAI_API_KEY": AZURE_OPENAI_API_KEY,
    "AZURE_OPENAI_ENDPOINT": AZURE_OPENAI_ENDPOINT,
    "AZURE_OPENAI_API_VERSION": AZURE_OPENAI_API_VERSION,
    "AzureChatOpenAIModel": AzureChatOpenAIModel,
}
_missing_settings = [name for name, value in _required_settings.items() if not value]
if _missing_settings:
    print(f"⚠️ Missing Azure OpenAI settings: {', '.join(_missing_settings)}")

In [None]:
# Create the chat model shared by every chain in this notebook.
# This import rebinds the AzureChatOpenAI name that the first import cell
# took from langchain.chat_models.
from langchain_openai import AzureChatOpenAI

# All values except the temperature come from the environment variables read
# in the Azure OpenAI setup cell.
_llm_settings = dict(
    api_key=AZURE_OPENAI_API_KEY,
    api_version=AZURE_OPENAI_API_VERSION,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    model=AzureChatOpenAIModel,
    temperature=1,
)
llm = AzureChatOpenAI(**_llm_settings)

## 📄 Load and Chunk Historical Text

In [None]:
# Chunking policy: pieces of at most 500 characters, with 50 characters of overlap.
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)

# Read the CSV (path is relative to the working directory) into LangChain documents,
# then cut those documents into the chunks the extraction loop iterates over.
loader = CSVLoader(file_path="historical_figures.csv")
docs = loader.load()
documents = splitter.split_documents(docs)

## 🏗️ Entity Extraction Prompt (with Example)

In [None]:
# Prompt used for every document chunk. It has two template variables:
#   {input} - the chunk text to analyse
#   {known} - entity names already collected, passed in by the graph-building cell
# The response layout requested here ("Triples:" followed by "New Entities:") is what
# the graph-building cell splits on, so the two must be changed together.
dynamic_entity_prompt = PromptTemplate.from_template("""
You are an expert in information extraction. Analyze the following text and extract factual triples and new entities.

Extract factual triples and new entities from the following text.

- Return the triples as plain lines in the format: Subject, Predicate, Object
- Do not use dashes, bullets, or numbering.
- Avoid extra punctuation like periods at the end.
- Return the new entities as a plain list, one per line.
- Do not include the triples in the entity list.

Text:
{input}

Known Entities:
{known}

Respond in this format exactly:

Triples:
Subject1, Predicate1, Object1
Subject2, Predicate2, Object2

New Entities:
Entity1
Entity2
Entity3
""")
# Pipeline: fill the prompt -> call the chat model -> reduce the reply to a plain string.
dynamic_entity_chain = dynamic_entity_prompt | llm | StrOutputParser()

## 🌐 Build Knowledge Graph from Triples

In [None]:
import os
import time
import json
from networkx.readwrite import json_graph

# === CONFIG ===
REINDEX = True  # Set to False to skip LLM re-indexing if data exists
GRAPH_PATH = "historical_graph.json"
ENTITIES_PATH = "known_entities.json"

# === Setup containers ===
G = nx.MultiDiGraph()
all_triples = []
entity_usage = defaultdict(int)   # entity name -> number of edge endpoints it occupies
known_entities = set()
MAX_ENTITIES_FOR_PROMPT = 40      # cap on how many known entities are sent with each prompt


def _parse_triple_line(line):
    """Parse one 'Subject, Predicate, Object' line into a normalised triple.

    Returns a (subject, relation, object) tuple, or None when the line does not
    contain three non-empty fields.
    """
    # maxsplit=2: commas after the second one are kept as part of the object
    # (e.g. "Einstein, born on, March 14, 1879") instead of invalidating the line.
    parts = [p.strip(" ()\n") for p in line.split(",", 2)]
    if len(parts) != 3:
        return None
    subject = parts[0].lstrip("- ").strip()
    relation = parts[1].strip(" .-").lower()
    obj = parts[2].lstrip("- ").strip()
    if not (subject and relation and obj):
        return None
    return (subject, relation, obj)


def _parse_extraction_output(output):
    """Split the raw LLM reply into (triples, new_entities).

    Everything before "New Entities:" is read as triple lines; everything after
    it is read as one entity name per line. Malformed triple lines are reported
    and skipped; entity lines that are empty after cleaning are dropped.
    """
    sections = output.strip().split("New Entities:")
    triples_block = sections[0].replace("Triples:", "").strip()
    new_entities_block = sections[1].strip() if len(sections) > 1 else ""

    triples = []
    for line in triples_block.split("\n"):
        if not line.strip():
            continue
        triple = _parse_triple_line(line)
        if triple is None:
            print(f"⚠️ Skipping malformed triple line: {line}")
        else:
            triples.append(triple)

    # Strip list bullets, surrounding whitespace and quotes from each entity line.
    cleaned = (e.lstrip("- ").strip(" \"'\n") for e in new_entities_block.split("\n") if e.strip())
    new_entities = [name for name in cleaned if name]
    return triples, new_entities


if not REINDEX and os.path.exists(GRAPH_PATH) and os.path.exists(ENTITIES_PATH):
    print("🔁 Loading graph and entities from disk...")
    with open(GRAPH_PATH, "r", encoding="utf-8") as f:
        data = json.load(f)
    G = json_graph.node_link_graph(data, directed=True, multigraph=True)
    with open(ENTITIES_PATH, "r", encoding="utf-8") as f:
        known_entities = set(json.load(f))

    # The usage counts are not persisted. Rebuild them from the edges so that
    # entity ranking after a reload matches what the build loop below produces
    # (one increment per edge endpoint).
    for s, o in G.edges():
        entity_usage[s] += 1
        entity_usage[o] += 1

    print(f"✅ Loaded graph with {len(G.nodes)} nodes and {len(G.edges)} edges.")
else:
    print("⚙️ Rebuilding graph from documents using LLM...")
    for i, doc in enumerate(documents):
        start = time.time()

        # Send only the most-used entities; ties are broken by name so the
        # prompt does not depend on set iteration order.
        sorted_entities = sorted(known_entities, key=lambda e: (-entity_usage.get(e, 0), e))
        limited_entities = sorted_entities[:MAX_ENTITIES_FOR_PROMPT]
        known_str = ", ".join(limited_entities) if limited_entities else "(none)"

        output = dynamic_entity_chain.invoke({"input": doc.page_content, "known": known_str})
        print(f"\n--- LLM Output for doc {i} ---\n{output.strip()}")

        triples, new_entities = _parse_extraction_output(output)

        all_triples.extend(triples)
        known_entities.update(new_entities)

        for s, r, o in triples:
            G.add_node(s)
            G.add_node(o)
            G.add_edge(s, o, relation=r)
            entity_usage[s] += 1
            entity_usage[o] += 1
            # Every graph node is also a known entity, so anchor detection can
            # only pick from names that include all nodes with recorded usage.
            known_entities.update((s, o))

        print(f"⏱️ Processed doc {i} in {round(time.time() - start, 2)}s")

    # === Save graph and entity index ===
    print("💾 Saving graph and entities to disk...")
    with open(GRAPH_PATH, "w", encoding="utf-8") as f:
        json.dump(json_graph.node_link_data(G), f, ensure_ascii=False, indent=2)
    with open(ENTITIES_PATH, "w", encoding="utf-8") as f:
        json.dump(sorted(list(known_entities)), f, ensure_ascii=False, indent=2)

    print(f"✅ Graph built with {len(G.nodes)} nodes, {len(G.edges)} edges, and {len(known_entities)} tracked entities.")


In [None]:
def search_graph(anchor_entity: str, depth: int = 2) -> nx.MultiDiGraph:
    """Return the neighbourhood of ``anchor_entity`` reachable along outgoing edges.

    Nodes are collected with a breadth-first search from the anchor that follows
    edge direction and stops after ``depth`` hops. The result is an independent
    copy of the subgraph of ``G`` induced by those nodes, so it always contains
    the anchor itself and every edge (including parallel ones) between in-scope
    nodes, not only the edges the search happened to traverse.

    Note: a later cell defines another ``search_graph`` that ignores edge
    direction; when the notebook is run top to bottom, that definition replaces
    this one.

    Args:
        anchor_entity: Node name to start from; must be a node of ``G``.
        depth: Maximum number of hops from the anchor.

    Returns:
        A new ``nx.MultiDiGraph`` containing the in-scope nodes and their edges.

    Raises:
        ValueError: If ``anchor_entity`` is not a node of ``G``.
    """
    if anchor_entity not in G:
        raise ValueError(f"Entity '{anchor_entity}' not found in the graph.")

    # bfs_edges yields only tree edges, so use it purely to discover nodes and
    # let the induced subgraph supply the complete edge set.
    nodes_in_scope = {anchor_entity}
    for _, reached in nx.bfs_edges(G, source=anchor_entity, depth_limit=depth):
        nodes_in_scope.add(reached)

    return G.subgraph(nodes_in_scope).copy()


In [None]:
def search_graph(anchor: str, depth: int = 2) -> nx.Graph:
    """Return the neighbourhood of ``anchor`` within ``depth`` hops, ignoring edge direction.

    The search expands level by level over both successors and predecessors of
    each frontier node. The result is an independent copy of the subgraph of
    ``G`` induced by the visited nodes, so edge directions and attributes are
    preserved in the returned graph.

    Args:
        anchor: Node name to start from; must be a node of ``G``.
        depth: Maximum number of hops from the anchor.

    Returns:
        A copy of the induced subgraph of ``G`` over the visited nodes.

    Raises:
        ValueError: If ``anchor`` is not a node of ``G``.
    """
    if anchor not in G:
        raise ValueError(f"Entity '{anchor}' not found in the graph.")

    visited_nodes = {anchor}
    current_level = {anchor}

    for _ in range(depth):
        next_level = set()
        for node in current_level:
            # Successors plus predecessors gives the undirected neighbourhood
            # without copying the whole graph via G.to_undirected() on each call.
            neighbors = set(G.successors(node)) | set(G.predecessors(node))
            next_level |= neighbors - visited_nodes
        if not next_level:
            # Nothing new was reached, so deeper levels cannot add anything.
            break
        visited_nodes |= next_level
        current_level = next_level

    return G.subgraph(visited_nodes).copy()


## 📌 Anchor Detection from Question

In [None]:
def detect_anchor_entity(question: str) -> str:
    """Ask the LLM which known entity a question is primarily about.

    The most-used entities (by ``entity_usage``, capped at 40) are listed in the
    prompt and the model is asked to return one of them verbatim. The reply is
    cleaned of a leading bullet and surrounding quotes, then matched against
    ``known_entities`` exactly and, failing that, case-insensitively.

    Args:
        question: The user's natural-language question.

    Returns:
        The matching known entity name, or the cleaned model reply when no
        known entity matches (callers should check membership in the graph).
    """
    MAX_ENTITIES_FOR_ANCHOR_PROMPT = 40

    # Rank by usage; ties are broken by name so the prompt is deterministic.
    sorted_entities = sorted(known_entities, key=lambda e: (-entity_usage.get(e, 0), e))
    limited_entities = sorted_entities[:MAX_ENTITIES_FOR_ANCHOR_PROMPT]
    known_entities_list = "\n".join(f"- {e}" for e in limited_entities) if limited_entities else "(none)"

    # The entity list is passed as a template variable rather than interpolated
    # into the template text, so entity names containing "{" or "}" cannot break
    # template parsing or be mistaken for input variables.
    detect_prompt = PromptTemplate.from_template("""
You are a semantic matcher. Your task is to identify the main entity (person, event, or discovery) from the list below that the question is primarily about.

Only return the exact entity name from the list. Do not explain.

Known Entities:
{known_entities_list}

Question: {question}
""")

    detect_chain = detect_prompt | llm | StrOutputParser()
    raw_answer = detect_chain.invoke(
        {"known_entities_list": known_entities_list, "question": question}
    )

    # The list is shown with "- " bullets, so the model may echo the bullet or add quotes.
    candidate = raw_answer.strip().lstrip("- ").strip(" \"'")
    if candidate in known_entities:
        return candidate

    # Fall back to a case-insensitive match, preferring the entities shown in the prompt.
    lowered = candidate.lower()
    for name in limited_entities + sorted_entities[MAX_ENTITIES_FOR_ANCHOR_PROMPT:]:
        if name.lower() == lowered:
            return name

    return candidate


In [None]:
def visualize_subgraph(query: str, depth: int = 2):
    """Detect the anchor entity for ``query`` and plot its neighbourhood in ``G``.

    The neighbourhood comes from ``search_graph``, so the picture shows the same
    nodes and edges that ``answer_question_with_graph`` turns into context. When
    several edges connect the same ordered pair of nodes, their relations are
    joined into one label instead of overwriting each other.

    Args:
        query: Natural-language question used to detect the anchor entity.
        depth: Maximum number of hops from the anchor to include.

    Returns:
        None. Prints a warning and returns early when the detected anchor is
        not a node of ``G``; otherwise displays a matplotlib figure.
    """
    anchor_entity = detect_anchor_entity(query)
    print(f"🔍 Detected anchor entity: {anchor_entity}")

    if anchor_entity not in G:
        print(f"⚠️ Entity '{anchor_entity}' not found in the graph.")
        return

    # Reuse the shared neighbourhood search: the previous hand-rolled traversal
    # also pulled in nodes one hop beyond the requested depth.
    SG = search_graph(anchor_entity, depth)

    # Collect the relations of parallel edges per (source, target) pair.
    relations_by_pair = {}
    for u, v, data in SG.edges(data=True):
        relation = data.get("relation", "")
        pair_relations = relations_by_pair.setdefault((u, v), [])
        if relation and relation not in pair_relations:
            pair_relations.append(relation)
    edge_labels = {pair: ", ".join(rels) for pair, rels in relations_by_pair.items()}

    # Visualize
    plt.figure(figsize=(10, 7))
    pos = nx.spring_layout(SG, seed=42)  # fixed seed keeps the layout reproducible
    nx.draw_networkx_nodes(SG, pos, node_size=600, node_color="#FFEEEE")
    nx.draw_networkx_edges(SG, pos, arrows=True, arrowstyle='-|>', edge_color="#666")
    nx.draw_networkx_labels(SG, pos, font_size=10)
    nx.draw_networkx_edge_labels(SG, pos, edge_labels=edge_labels, font_size=9)
    plt.title(f"Subgraph for: '{anchor_entity}'", fontsize=14)
    plt.axis("off")
    plt.show()




## 🤔 Ask a Question Using the Graph

In [None]:
def answer_question_with_graph(query: str, depth: int = 2):
    """Answer ``query`` using facts taken from the graph around its anchor entity.

    Steps: detect the anchor entity, fetch its neighbourhood with
    ``search_graph``, render each edge as a "<source> <relation> <target>."
    sentence, and ask the LLM to answer from those sentences.

    Args:
        query: The user's natural-language question.
        depth: Number of hops around the anchor to include as context.

    Returns:
        A ``(facts, answer)`` tuple: the list of fact sentences given to the
        model, and whatever ``llm.invoke`` returned.

    Raises:
        ValueError: Propagated from ``search_graph`` when the detected anchor
            is not a node of the graph.
    """
    anchor_entity = detect_anchor_entity(query)
    print(f"🔍 Detected anchor entity: {anchor_entity}")

    neighborhood = search_graph(anchor_entity, depth=depth)

    # One sentence per edge, in the order the subgraph reports its edges.
    facts = []
    for source, target, attrs in neighborhood.edges(data=True):
        facts.append(f"{source} {attrs['relation']} {target}.")

    facts_text = "\n".join(facts)
    llm_prompt = f"Answer the question based on these facts:\n\n{facts_text}\n\nQuestion: {query}"
    answer = llm.invoke(llm_prompt)

    return facts, answer


## 🚀 Try It Out

In [None]:
question = "How did Albert Einstein contribute to the development of the atomic bomb?"
# Set the depth for the subgraph search
depth = 2
retrieved_chunks, response = answer_question_with_graph(question, depth)

# Facts from the graph that were handed to the model as context.
print("\nRetrieved Chunks:")
for line in retrieved_chunks:
    print("-", line)

print("\nAnswer:")
# llm.invoke returns a message object; print its text rather than the whole
# object (content plus metadata). Falls back to the object itself if it has
# no `content` attribute.
print(getattr(response, "content", response))

# Plot the neighbourhood of the entity detected for the same question.
visualize_subgraph(question, depth)


In [None]:
# Debug aid: print the name of every node currently stored in the graph.
print(list(G.nodes()))


In [None]:
# Debug aid: show every stored relation whose source or target name mentions
# "relativity" (case-insensitive).
for u, v, data in G.edges(data=True):
    mentions_relativity = any("relativity" in name.lower() for name in (u, v))
    if not mentions_relativity:
        continue
    print(f"{u} --{data['relation']}--> {v}")
