In [1]:
import os
import json
import logging
import sys
import warnings
from typing import Dict, Any

# --- 1. GLOBAL SILENCING CONFIGURATION ---
os.environ["TQDM_DISABLE"] = "1"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
warnings.filterwarnings("ignore")

# --- 2. LOGGING SETUP ---
logging.basicConfig(
    level=logging.INFO, 
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)

noisy_loggers = ["sentence_transformers", "transformers", "urllib3", "requests", "huggingface_hub", "filelock", "tqdm"]
for logger_name in noisy_loggers:
    logging.getLogger(logger_name).setLevel(logging.ERROR)

# --- 3. IMPORTS ---
try:
    from causal_graph.builder import CausalGraphBuilder
    from causal_graph.retriever import CausalPathRetriever
    from causal_graph.explainer import CausalGraphExplainer
except ImportError:
    from builder import CausalGraphBuilder
    from retriever import CausalPathRetriever
    from explainer import CausalGraphExplainer

In [2]:
import re

class CausalRAGChain:
    def __init__(self, model_name: str = "all-mpnet-base-v2"):
        """
        Args:
            model_name: The SentenceTransformer model. 
        """
        self.logger = logging.getLogger(__name__)
        
        self.logger.info(f"Initializing Causal Graph Builder with model: {model_name}...")
        self.builder = CausalGraphBuilder(
            model_name=model_name, 
            normalize_nodes=True
        )
        self.retriever = None
        self.documents = []  # NEW: Store original documents for context lookup

    def load_graph_state(self, filepath: str):
        """Loads an existing graph state (nodes/edges) from JSON."""
        self.logger.info(f"Loading graph state from {filepath}...")
        if os.path.exists(filepath):
            success = self.builder.load(filepath)
            if success:
                self.logger.info(f"Graph loaded successfully: {self.builder.get_graph().number_of_nodes()} nodes.")
            else:
                self.logger.error("Failed to parse graph file. Starting with empty graph.")
        else:
            self.logger.warning(f"Graph file not found: {filepath}. Starting with empty graph.")
        
        self.retriever = CausalPathRetriever(self.builder)

    def save_graph_state(self, filepath: str):
        """Saves the current graph state to JSON."""
        self.builder.save(filepath)

    def ingest_wiki_knowledge(self, json_path: str, limit: int = None, auto_save_path: str = None):
        """
        Loads wiki json, stores raw text for retrieval, and builds the graph.
        """
        self.logger.info(f"Ingesting knowledge from {json_path}...")
        
        if not os.path.exists(json_path):
            self.logger.error(f"Knowledge base file not found: {json_path}")
            return

        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            
            # Extract text
            self.documents = [] # Reset documents
            if isinstance(data, list):
                for item in data:
                    if 'raw_text' in item:
                        self.documents.append(item['raw_text'])
            
            if not self.documents:
                self.logger.warning("No 'raw_text' fields found in JSON.")
                return

            # Apply limit if specified
            if limit:
                self.documents = self.documents[:limit]
                self.logger.info(f"Limiting ingestion to first {limit} documents.")

            self.logger.info(f"Indexing {len(self.documents)} documents into the graph...")
            
            # Index documents into graph
            self.builder.index_documents(self.documents, show_progress=False)
            
            self.logger.info(f"Ingestion complete. Graph size: {self.builder.get_graph().number_of_nodes()} nodes.")
            self.retriever = CausalPathRetriever(self.builder)
            
            if auto_save_path:
                self.save_graph_state(auto_save_path)
            
        except Exception as e:
            self.logger.error(f"Error during ingestion: {e}")

    def _get_context_for_path(self, path: list[str], window_size: int = 300) -> str:
        """
        NEW: Finds the snippet in the source documents that contains the nodes in the path.
        This provides the 'narrative' context surrounding the causal arrow.
        """
        best_snippet = ""
        max_matches = 0
        
        # Convert path nodes to a set of keywords (lowercase for matching)
        path_keywords = [node.lower() for node in path]
        
        # Heuristic: Search documents for sentences containing the Cause and Effect
        for doc in self.documents:
            doc_lower = doc.lower()
            
            # Count how many path nodes appear in this document
            matches = sum(1 for keyword in path_keywords if keyword in doc_lower)
            
            if matches >= 2 and matches > max_matches:
                # If we find a document containing multiple nodes from the chain, extract context
                max_matches = matches
                
                # Find the position of the first keyword occurrence
                first_pos = doc_lower.find(path_keywords[0])
                if first_pos != -1:
                    start = max(0, first_pos - window_size)
                    end = min(len(doc), first_pos + window_size * 2)
                    best_snippet = f"...{doc[start:end]}..."
        
        return best_snippet if best_snippet else "Context not found in source text."

    def run(self, query: str):
        """Runs the retrieval chain with Context Enrichment."""
        if not self.retriever:
            self.retriever = CausalPathRetriever(self.builder)
            
        print(f"\nProcessing query: {query}")
        
        # 1. Retrieve Causal Paths (The "Skeleton" of the answer)
        paths = self.retriever.retrieve_paths(
            query, 
            max_paths=5, 
            min_path_length=2, 
            max_path_length=4
        )
        
        # 2. Retrieve Source Context (The "Flesh" of the answer)
        # We look up the original text for each path found
        context_blocks = []
        for i, path in enumerate(paths):
            arrow_chain = " -> ".join(path)
            source_snippet = self._get_context_for_path(path)
            
            block = (
                f"PATH {i+1}: {arrow_chain}\n"
                f"SOURCE CONTEXT: {source_snippet}\n"
            )
            context_blocks.append(block)
        
        paths_context_text = "\n".join(context_blocks)
        
        if not paths_context_text:
            paths_context_text = "No direct causal paths found in the knowledge graph."
            
        # 3. Enhanced Prompt
        prompt = f"""You are a Causal AI Expert. 
Using the provided Causal Paths and their Source Context, write a coherent, detailed answer.
Do not just list the paths; weave them into a narrative explanation.

USER QUERY: {query}

=== RETRIEVED CAUSAL EVIDENCE ===
{paths_context_text}
=================================

ANSWER:"""

        return {
            "query": query,
            "paths": paths, 
            "context_text": paths_context_text, # Return context for debugging
            "final_prompt": prompt
        }

In [3]:
if __name__ == "__main__":
    import os
    
    # --- Configuration ---
    GRAPH_STATE_FILE = "causal_math_graph_state_llm.json"
    WIKI_KB_FILE = "wiki_math_knowledge_base_api.json"
    OUTPUT_FILE = "rag_output_with_context.txt"
    
    # 1. Initialize Chain
    # We use the same model as before
    chain = CausalRAGChain(model_name="all-mpnet-base-v2")
    
    # 2. Load Existing Graph State
    # This loads the nodes and edges you've already built
    chain.load_graph_state(GRAPH_STATE_FILE)
    
    # 3. Ingest Data (CRITICAL STEP)
    # Even if the graph is loaded, we MUST run this to populate 'self.documents'
    # so the chain can look up the original text context.
    # We use limit=20 to match your previous test; remove 'limit' for full run.
    chain.ingest_wiki_knowledge(WIKI_KB_FILE, limit=20, auto_save_path=GRAPH_STATE_FILE)
    
    # 4. Define Queries
    queries = [
        'What happens when the circumcenter is on the side of the triangle?',
        "What influences the velocity of a Brownian particle?",
        "Tell me about surface tension and minimal surfaces."
    ]
    
    print(f"\nProcessing {len(queries)} queries... (Saving results to {OUTPUT_FILE})")
    
    # 5. Run and Save
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        f.write("=== CAUSAL RAG RESULTS WITH SOURCE CONTEXT ===\n\n")
        
        for i, q in enumerate(queries, 1):
            # The run() method now returns 'context_text' containing the source snippets
            result = chain.run(q)
            
            output_block = []
            output_block.append(f"QUERY {i}: {result['query']}")
            output_block.append("-" * 40)
            
            # Display the Retrieved Evidence (Paths + Source Text)
            if result.get('context_text'):
                output_block.append("RETRIEVED EVIDENCE & CONTEXT:")
                output_block.append(result['context_text'])
            else:
                output_block.append("  [INFO]: No evidence found.")
            
            output_block.append("-" * 40)
            
            # Display the Final Prompt (What you would send to an LLM)
            output_block.append("FINAL GENERATED PROMPT:")
            output_block.append(result['final_prompt'])
            
            output_block.append("=" * 60 + "\n")
            
            # Write to file
            full_text = "\n".join(output_block)
            f.write(full_text)
            f.flush()
            
            print(f"Finished Query {i}")

    print(f"\nDone! Check '{OUTPUT_FILE}' to see the paths linked with their original text.")

2026-01-28 13:46:54,767 - INFO - Initializing Causal Graph Builder with model: all-mpnet-base-v2...
2026-01-28 13:46:55,214 - INFO - HTTP Request: HEAD https://huggingface.co/sentence-transformers/all-mpnet-base-v2/resolve/main/modules.json "HTTP/1.1 307 Temporary Redirect"
2026-01-28 13:46:55,224 - INFO - HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/sentence-transformers/all-mpnet-base-v2/e8c3b32edf5434bc2275fc9bab85f82640a19130/modules.json "HTTP/1.1 200 OK"
2026-01-28 13:46:55,330 - INFO - HTTP Request: HEAD https://huggingface.co/sentence-transformers/all-mpnet-base-v2/resolve/main/config_sentence_transformers.json "HTTP/1.1 307 Temporary Redirect"
2026-01-28 13:46:55,342 - INFO - HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/sentence-transformers/all-mpnet-base-v2/e8c3b32edf5434bc2275fc9bab85f82640a19130/config_sentence_transformers.json "HTTP/1.1 200 OK"
2026-01-28 13:46:55,459 - INFO - HTTP Request: HEAD https://huggingface.co/sentence-

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

2026-01-28 13:46:56,479 - INFO - HTTP Request: HEAD https://huggingface.co/sentence-transformers/all-mpnet-base-v2/resolve/main/config.json "HTTP/1.1 307 Temporary Redirect"
2026-01-28 13:46:56,488 - INFO - HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/sentence-transformers/all-mpnet-base-v2/e8c3b32edf5434bc2275fc9bab85f82640a19130/config.json "HTTP/1.1 200 OK"
2026-01-28 13:46:56,593 - INFO - HTTP Request: HEAD https://huggingface.co/sentence-transformers/all-mpnet-base-v2/resolve/main/tokenizer_config.json "HTTP/1.1 307 Temporary Redirect"
2026-01-28 13:46:56,603 - INFO - HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/sentence-transformers/all-mpnet-base-v2/e8c3b32edf5434bc2275fc9bab85f82640a19130/tokenizer_config.json "HTTP/1.1 200 OK"
2026-01-28 13:46:56,718 - INFO - HTTP Request: GET https://huggingface.co/api/models/sentence-transformers/all-mpnet-base-v2/tree/main/additional_chat_templates?recursive=false&expand=false "HTTP/1.1 404 Not Fo