In [None]:
# Environment variables
from dotenv import load_dotenv
import os

# Read key/value pairs from the local ".env" file into the process environment.
load_dotenv(dotenv_path=".env")

# DeepSeek credential used by the chat model below; None when the variable is unset.
api_key = os.environ.get("DEEPSEEK_API_KEY")

In [None]:
# Setup
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import TokenTextSplitter
from langchain_openai import ChatOpenAI

from collections import defaultdict

import matplotlib.pyplot as plt
import networkx as nx
import regex as re
import json
import json_repair
import pickle
import uuid

from prompt import *
import utils

# Chat model shared by every cell below. The OpenAI-compatible client is pointed
# at the DeepSeek endpoint and authenticated with the key read from the environment.
llm_settings = {
    "model": "deepseek-chat",
    "api_key": api_key,
    "base_url": "https://api.deepseek.com/v1",
}
llm = ChatOpenAI(**llm_settings)

In [None]:
# Load and chunk text
filename = "symposium_short.txt"

# Build the loader and splitter first, then run them: the source file becomes a
# list of documents, which is cut into token-based chunks of at most 1200 tokens.
loader = TextLoader(filename)
text_splitter = TokenTextSplitter(chunk_size=1200)

docs = loader.load()
chunks = text_splitter.split_documents(docs)

# Embedding function initialization (reused by the Chroma and FAISS stores below)
embedding_function = OllamaEmbeddings(model="bge-m3:567m")

In [None]:
# Extraction
# Build one entity-extraction prompt per chunk and obtain the LLM responses,
# reusing a pickle cache on disk when it matches the current chunking.
tuple_delimiter = PROMPTS["DEFAULT_TUPLE_DELIMITER"]
record_delimiter = PROMPTS["DEFAULT_RECORD_DELIMITER"]
completion_delimiter = PROMPTS["DEFAULT_COMPLETION_DELIMITER"]

# Everything except the chunk text is identical for every prompt, so the
# templates and the formatted examples are prepared once, outside the loop.
prompt_template = PROMPTS["entity_extraction"]
examples_template = PROMPTS["entity_extraction_examples"]
delimiter_fields = {
    "tuple_delimiter": tuple_delimiter,
    "record_delimiter": record_delimiter,
    "completion_delimiter": completion_delimiter,
}
examples = "\n".join(examples_template).format(**delimiter_fields)

prompts = [
    prompt_template.format(
        language="English",
        entity_types=PROMPTS["DEFAULT_ENTITY_TYPES"],
        examples=examples,
        input_text=chunk.page_content,
        **delimiter_fields,
    )
    for chunk in chunks
]

cache_file = f"{filename} responses_cache.pkl"

responses = None
if os.path.exists(cache_file):
    print("Loading responses from cache...")
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # load cache files that this notebook produced itself.
    with open(cache_file, 'rb') as f:
        responses = pickle.load(f)
    # The parsing cell uses the response index as the chunk id, so a cache
    # written for a different chunking would silently misattribute chunks.
    if len(responses) != len(prompts):
        print("Cached responses do not match the current chunks; refetching...")
        responses = None

if responses is None:
    print("Fetching responses from LLM and caching...")
    responses = llm.batch(prompts)
    with open(cache_file, 'wb') as f:
        pickle.dump(responses, f)
print("Responses loaded successfully.")

In [None]:
# Parsing extraction output
# Turn every LLM response into raw graph data:
#   graph_nodes: entity name      -> descriptions, type, chunk ids
#   graph_edges: (source, target) -> descriptions, keywords, weight, chunk ids
# The index of a response in `responses` is used as the id of its chunk.
graph_nodes = defaultdict(lambda: {"descriptions": [], "type": "", "chunk_ids": []})
graph_edges = defaultdict(lambda: {"descriptions": [], "keywords": [], "weight": 1.0, "chunk_ids": []})

for chunk_id, response in enumerate(responses):
    response_text = response.content

    records = utils.split_string_by_multi_markers(response_text, [record_delimiter, completion_delimiter])

    for record in records:
        # A usable record is wrapped in parentheses; skip fragments that are not.
        match = re.search(r"\((.*)\)", record)
        if not match:
            continue
        # Markers are passed as a list, the same way as for the record split above.
        record_attributes = utils.split_string_by_multi_markers(match.group(1), [tuple_delimiter])

        # Parse entity
        entity_data = utils._handle_single_entity_extraction(record_attributes, llm)
        if entity_data:
            name = entity_data["entity_name"]
            entity_type = entity_data["entity_type"]
            description = entity_data["description"]

            # Deduplication by name
            node = graph_nodes[name]
            node["descriptions"].append(description)
            node["chunk_ids"].append(chunk_id)
            # Keep the first type reported for this name.
            if not node["type"]:
                node["type"] = entity_type

        # Parse relation
        relation_data = utils._handle_single_relation_extraction(record_attributes, llm)
        if relation_data:
            source = relation_data["src_id"]
            target = relation_data["tgt_id"]
            weight = relation_data["weight"]
            description = relation_data["description"]
            keywords = relation_data["keywords"]

            # Deduplication by key; the most recent weight for the pair wins.
            key = (source, target)
            edge = graph_edges[key]
            edge["keywords"].append(keywords)
            edge["weight"] = weight
            edge["descriptions"].append(description)
            edge["chunk_ids"].append(chunk_id)

In [None]:
# Deduplication
# Cluster near-duplicate entity names by embedding similarity, merge each
# cluster into one node, re-point the relations at the canonical names and
# summarise the merged descriptions. The result is cached with pickle.
processed_graph_cache_file = f"{filename} processed_graph.pkl"

temp_store_settings = {
    "collection_name": "temp_entities",
    "embedding_function": embedding_function,
    "persist_directory": "./graph",
    "collection_metadata": {"hnsw:space": "cosine"},
}
temp_entities_vectorstore = Chroma(**temp_store_settings)

if os.path.exists(processed_graph_cache_file):
    print("Loading processed graph nodes, edges, and summaries from cache...")
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # load cache files that this notebook produced itself.
    with open(processed_graph_cache_file, 'rb') as f:
        cached_data = pickle.load(f)
    graph_nodes = cached_data['graph_nodes']
    graph_edges = cached_data['graph_edges']
    summaries = cached_data['summaries']
    print("Success.")
else:
    print("Deduplicating and summarizing...")

    # Clear the temporary collection, then recreate it empty
    temp_entities_vectorstore.delete_collection()
    temp_entities_vectorstore = Chroma(**temp_store_settings)

    temp_entities_vectorstore.add_texts([str(key) for key in graph_nodes.keys()])

    similarity_threshold = 0.7 # tunable

    # Cluster entities by names
    entity_clusters = {}
    processed_names = set()
    for name in graph_nodes.keys():
        if name in processed_names:
            continue

        search_results = temp_entities_vectorstore.similarity_search_with_relevance_scores(name, k=5)

        cluster = [name]
        processed_names.add(name)
        for document, score in search_results:
            variant = document.page_content
            # A variant that already belongs to a cluster is not claimed again;
            # otherwise its data would be merged into several nodes and its
            # canonical name would be overwritten.
            if score > similarity_threshold and variant != name and variant not in processed_names:
                cluster.append(variant)
                processed_names.add(variant)
        entity_clusters[name] = sorted(set(cluster))

    # Every variant (the canonical name included) maps to its cluster's canonical name.
    name_to_canonical = {}
    for name, variants in entity_clusters.items():
        for variant in variants:
            name_to_canonical[variant] = name

    merged_graph_nodes = defaultdict(lambda: {"canonical": "", "type": "", "descriptions": [], "chunk_ids": []})
    # The weight starts at 0.0 so that sum / count below is a true average.
    merged_graph_edges = defaultdict(lambda: {"descriptions": [], "keywords": [], "weight": 0.0, "count": 0, "chunk_ids": []})

    # Merge entities
    for name, variants in entity_clusters.items():
        descriptions = []
        merged_chunk_ids = []
        entity_type = graph_nodes[name]["type"]
        for variant in variants:
            if variant not in graph_nodes:
                continue
            variant_data = graph_nodes[variant]
            descriptions.extend(variant_data["descriptions"])
            # Keep the chunk ids of every variant, not only the canonical one.
            for chunk_id in variant_data["chunk_ids"]:
                if chunk_id not in merged_chunk_ids:
                    merged_chunk_ids.append(chunk_id)
            # Fall back to a variant's type when the canonical name has none.
            if not entity_type:
                entity_type = variant_data["type"]
            if name != variant:
                print(f"Merged nodes: {name} <- {variant}")

        merged_graph_nodes[tuple(variants)] = {
            "canonical": name,
            "type": entity_type,
            "descriptions": descriptions,
            "chunk_ids": merged_chunk_ids
        }

    # Reconnect relations
    for (source, target), data in graph_edges.items():
        canonical_source = name_to_canonical.get(source)
        canonical_target = name_to_canonical.get(target)

        # Relations whose endpoints were never extracted as entities are dropped.
        if canonical_source and canonical_target:
            merged_edge = merged_graph_edges[(canonical_source, canonical_target)]
            merged_edge["descriptions"].extend(data["descriptions"])
            merged_edge["keywords"].extend(data["keywords"])
            merged_edge["weight"] += data["weight"]
            merged_edge["count"] += 1
            merged_edge["chunk_ids"].extend(data["chunk_ids"])

    # Average the weight of every merged edge (count is at least 1 for each key).
    for key in merged_graph_edges:
        merged_graph_edges[key]["weight"] /= merged_graph_edges[key]["count"]

    graph_nodes = merged_graph_nodes
    graph_edges = merged_graph_edges

    print(f"Number of nodes: {len(graph_nodes)}")
    print(f"Number of edges: {len(graph_edges)}")

    # Summarization: nodes first, then edges. The storage cell reads
    # `summaries` back in exactly this order.
    entity_or_relation_names = []
    entity_or_relation_descriptions = []

    for name, data in graph_nodes.items():
        entity_or_relation_names.append(", ".join(name))
        entity_or_relation_descriptions.append(" ".join(data["descriptions"]))
    for key, data in graph_edges.items():
        entity_or_relation_names.append(" -> ".join(key))
        entity_or_relation_descriptions.append(" ".join(data["descriptions"]))

    summaries = utils._handle_entity_relation_summary(entity_or_relation_names, entity_or_relation_descriptions, filename, llm)

    # Plain dicts are pickled because the defaultdict factories are lambdas.
    with open(processed_graph_cache_file, 'wb') as f:
        pickle.dump({
            'graph_nodes': dict(graph_nodes),
            'graph_edges': dict(graph_edges),
            'summaries': summaries
        }, f)
    print("Processed graph data cached successfully.")

In [None]:
# Storage: Networkx and FAISS
# Build (or load from cache) the directed graph G, then derive the entity and
# relation vector stores from G so that they exist on both paths.
graph_storage_file = f"{filename} graph_storage.pkl"
if os.path.exists(graph_storage_file):
    print("Loading graph from cache...")
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # load cache files that this notebook produced itself.
    with open(graph_storage_file, 'rb') as f:
        G = pickle.load(f)
    print("Success.")

else:
    print("Generating graph...")

    # `summaries` holds one entry per node followed by one entry per edge.
    expected_summaries = len(graph_nodes) + len(graph_edges)
    if len(summaries) < expected_summaries:
        raise ValueError(
            f"Expected at least {expected_summaries} summaries, got {len(summaries)}"
        )

    G = nx.DiGraph()
    summary_index = 0

    for data in graph_nodes.values():
        summary = summaries[summary_index].content
        canonical = data["canonical"]

        # Add node to graph
        # TODO: switch to a generated id; the canonical name is the node id for now
        G.add_node(
            canonical,
            graph_id=canonical,
            type=data['type'],
            description=summary,
            chunk_ids=data["chunk_ids"]
        )

        print(f"Adding node: {canonical} with summary: {summary}...")
        summary_index += 1

    for (source, target), data in graph_edges.items():
        summary = summaries[summary_index].content
        edge_id = str(uuid.uuid4())

        # Add edge to graph
        G.add_edge(
            source,
            target,
            graph_id=edge_id,
            description=summary,
            keywords=data['keywords'],
            weight=float(data['weight']),
            chunk_ids=data["chunk_ids"]
        )

        print(f"Adding edge: ({source}, {target}) with summary: {summary}...")
        summary_index += 1

    print("Caching graph...")
    # Write to the same path that is checked above so the cache is found next run.
    with open(graph_storage_file, "wb") as f:
        pickle.dump(G, f)
    print("Success.")

print("Adding nodes and edges to vectorstores...")

# Prepare node documents. Nodes without attributes (created implicitly by
# add_edge) carry no summary and are skipped.
nodes = [
    Document(
        page_content=attrs["description"],
        metadata={
            "graph_id": attrs["graph_id"], # Name as id for now
            "type": attrs["type"],
            "value": attrs["description"],
            "chunk_ids": " ".join(str(chunk_id) for chunk_id in attrs["chunk_ids"])
        }
    )
    for _, attrs in G.nodes(data=True)
    if "graph_id" in attrs
]

# Prepare edge documents
edges = [
    Document(
        page_content=attrs["description"],
        metadata={
            "graph_id": attrs["graph_id"],
            "value": attrs["description"],
            "source": source,
            "target": target,
            "weight": attrs["weight"],
            "chunk_ids": " ".join(str(chunk_id) for chunk_id in attrs["chunk_ids"])
        }
    )
    for source, target, attrs in G.edges(data=True)
]

# Populate vectorstore
entities_vectorstore = FAISS.from_documents(documents=nodes, embedding=embedding_function)
relations_vectorstore = FAISS.from_documents(documents=edges, embedding=embedding_function)

In [None]:
# Graph visualization
plt.figure(figsize=(50, 50))

# Seeded spring layout so the drawing is reproducible between runs.
pos = nx.spring_layout(G, k=0.01, iterations=1000, seed=1172)
# Pass the computed positions explicitly; otherwise nx.draw computes a layout
# of its own and the seeded one above is never used.
nx.draw(G, pos=pos, with_labels=True, font_size=12)

plt.show()

In [None]:
# Query keywords extraction
# Ask the LLM for high-level and low-level keywords of the query and parse
# its JSON answer into `hl_keywords` / `ll_keywords`.

# TODO: query from input
query = "Compare and contrast the different theories about love mentioned in the text, illustrating how each of them helps conveying the main philosophical thought."

history = "" # TODO

prompt_template = PROMPTS["keywords_extraction"]
examples = PROMPTS["keywords_extraction_examples"]
# Join a list of examples into one text block, as the entity-extraction cell
# does; a plain string is used as is.
if isinstance(examples, (list, tuple)):
    examples = "\n".join(examples)
prompt = prompt_template.format(
    examples=examples,
    history=history,
    query=query
)

response = llm.invoke(prompt)

# Default to an empty mapping so the keyword lookups below always work.
keywords_data = {}
try:
    parsed_keywords = json_repair.loads(response.content)
except ValueError as e:  # json.JSONDecodeError is a subclass of ValueError
    print(f"JSON parsing error: {e}")
    print(f"LLM response: {response.content}")
else:
    # Only a non-empty JSON object is usable as keyword data.
    if isinstance(parsed_keywords, dict) and parsed_keywords:
        keywords_data = parsed_keywords
    else:
        print("No JSON-like structure found in the LLM response.")

hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])

In [None]:
# Keyword matching
# Retrieve graph elements for the extracted keywords:
#   local  - low-level keywords -> entities, plus the outgoing edges of each entity
#   global - high-level keywords -> relations, plus their endpoint entities
# All four result lists contain plain dicts.
# TODO: mode from input
query_mode = ""  # "local", "global"; any other value runs both


def _parse_chunk_ids(raw):
    """Return chunk ids as a list of ints.

    Accepts the space-joined string stored in vector-store metadata, a list
    as stored on the graph, or None.
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        return [int(token) for token in raw.split()]
    return [int(chunk_id) for chunk_id in raw]


def _match_local(keywords, k=3):
    """Search entities by keyword and collect the out-edges of every hit.

    Returns (entities, relations). Each relation is the edge attribute dict
    extended with "source", "target" and "value" (the edge description), so
    it has the same keys as a relation taken from the vector store.
    """
    entities, relations = [], []
    for keyword in keywords:
        # Only the hits of this keyword are expanded, so entities found by an
        # earlier keyword do not get their edges appended again.
        for hit in entities_vectorstore.similarity_search(keyword, k=k):
            entity = dict(hit.metadata)
            entity["chunk_ids"] = _parse_chunk_ids(entity.get("chunk_ids"))
            entities.append(entity)

            node = entity.get("graph_id")
            if node is None or node not in G:
                continue
            for neighbor in G.neighbors(node):
                edge_data = G.get_edge_data(node, neighbor)
                if edge_data:
                    relations.append({
                        **edge_data,
                        "source": node,
                        "target": neighbor,
                        "value": edge_data.get("description", "UNKNOWN"),
                    })
    return entities, relations


def _match_global(keywords, k=3):
    """Search relations by keyword and collect the endpoint nodes of every hit.

    Returns (entities, relations). Each entity is the node attribute dict
    extended with "value" (the node description), so it has the same keys as
    an entity taken from the vector store.
    """
    entities, relations = [], []
    for keyword in keywords:
        for hit in relations_vectorstore.similarity_search(keyword, k=k):
            relation = dict(hit.metadata)
            relation["chunk_ids"] = _parse_chunk_ids(relation.get("chunk_ids"))
            relations.append(relation)

            for endpoint in (relation.get("source"), relation.get("target")):
                if endpoint and endpoint in G.nodes:
                    node_data = G.nodes[endpoint]
                    entities.append({
                        **node_data,
                        "value": node_data.get("description", "UNKNOWN"),
                    })
    return entities, relations


local_entities, local_relations = [], []
global_entities, global_relations = [], []

# "local" skips the global search, "global" skips the local one; anything
# else (including the empty default) runs both.
if query_mode != "global":
    local_entities, local_relations = _match_local(ll_keywords)
if query_mode != "local":
    global_entities, global_relations = _match_global(hl_keywords)

In [None]:
# Round-robin merge 
# borrowed from https://github.com/HKUDS/LightRAG/blob/main/lightrag/operate.py
# Interleave the local and global results, drop duplicates by graph_id and
# assemble `context_data` (entities, relations and their source chunks).


def _chunk_id_list(raw):
    """Return chunk ids as a list of ints.

    Vector-store metadata stores them as a space-joined string while graph
    attributes store a list; feeding the string to set.update() would add
    single characters instead of ids.
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        return [int(token) for token in raw.split()]
    return [int(chunk_id) for chunk_id in raw]


def _round_robin_merge(local_items, global_items, require_id):
    """Interleave two lists (local first) keeping the first item per graph_id.

    With require_id=True, items without a truthy graph_id are skipped.
    """
    merged, seen_ids = [], set()
    for i in range(max(len(local_items), len(global_items))):
        for items in (local_items, global_items):
            if i >= len(items):
                continue
            item = items[i]
            item_id = item.get("graph_id")
            if require_id and not item_id:
                continue
            if item_id in seen_ids:
                continue
            merged.append(item)
            seen_ids.add(item_id)
    return merged


# Round-robin merge entities and relations
final_entities = _round_robin_merge(local_entities, global_entities, require_id=True)
final_relations = _round_robin_merge(local_relations, global_relations, require_id=False)

# Chunks referenced by anything that made it into the final lists
chunk_ids = set()
for item in final_entities + final_relations:
    chunk_ids.update(_chunk_id_list(item.get("chunk_ids")))

# Generate entities context
# Entities taken from the graph carry "description", those from the vector
# store carry "value"; accept either.
# TODO: add file_path once it is tracked on the nodes
entities_context = [
    {
        "id": i + 1,
        "entity": n["graph_id"],
        "type": n.get("type", "UNKNOWN"),
        "description": n.get("value", n.get("description", "UNKNOWN")),
    }
    for i, n in enumerate(final_entities)
]

# Generate relations context
# Relations taken from graph edge data may have no "source"/"target"; recover
# the endpoints from G through the edge's graph_id.
# TODO: add file_path once it is tracked on the edges
edge_endpoints = {
    attrs.get("graph_id"): (source, target)
    for source, target, attrs in G.edges(data=True)
}
relations_context = []
for i, e in enumerate(final_relations):
    fallback_source, fallback_target = edge_endpoints.get(e.get("graph_id"), (None, None))
    relations_context.append(
        {
            "id": i + 1,
            "entity1": e.get("source") or fallback_source,
            "entity2": e.get("target") or fallback_target,
            "description": e.get("value", e.get("description", "UNKNOWN")),
        }
    )

# Generate chunks context
# Sorted for a deterministic order; the chunk text is stored rather than the
# Document object.
chunk_ids = sorted(chunk_ids)
merged_chunks = [{"content": chunks[chunk_id].page_content} for chunk_id in chunk_ids]

context_data = {
    "entities": entities_context,
    "relations": relations_context,
    "chunks": merged_chunks
}

In [None]:
# Build context from context data
# borrowed from https://github.com/HKUDS/LightRAG/blob/main/lightrag/operate.py
# Render the three sections of `context_data` into one prompt string. A section
# that is missing or empty contributes nothing.
entities_str = ""
relations_str = ""
chunks_str = ""

context_entities = context_data.get("entities")
if context_entities:
    entities_str = f"--- Entities ---\n{json.dumps(context_entities, indent=2, ensure_ascii=False)}\n"

context_relations = context_data.get("relations")
if context_relations:
    relations_str = f"--- Relations ---\n{json.dumps(context_relations, indent=2, ensure_ascii=False)}\n"

# A separate name is used here so the module-level `chunks` (the split
# documents) is not overwritten and earlier cells can be re-run.
context_chunks = context_data.get("chunks")
if context_chunks:
    chunks_str_list = []
    for chunk in context_chunks:
        # Each entry is a {"content": ...} dict; the content may be a string
        # or an object exposing page_content.
        content = chunk.get("content", "") if isinstance(chunk, dict) else chunk
        content = getattr(content, "page_content", content)
        chunks_str_list.append(f"Text: {content}")
    chunks_str = "--- Text Chunks ---\n" + "\n\n".join(chunks_str_list)

context_str = f"{entities_str}{relations_str}{chunks_str}".strip()

In [None]:
# Final query
# Fill the RAG response template with the assembled context; the remaining
# fields are passed as empty strings.
prompt_template = PROMPTS["rag_response"]
rag_prompt_fields = {
    "history": "",
    "context_data": context_str,
    "response_type": "",
    "user_prompt": "",
}
final_query = prompt_template.format(**rag_prompt_fields)

In [None]:
# Final generation
# The filled template only carries the context and empty fields, so it is sent
# as the system message and the actual question as the human message.
answer = llm.invoke([("system", final_query), ("human", query)])
# Print the generated text rather than the repr of the message object.
print(answer.content)