In [1]:
import json
import logging
import os
import sys
import warnings
from typing import Any, Dict, List

# --- 1. GLOBAL SILENCING CONFIGURATION ---
# Environment switches are set before the project imports further down;
# presumably tqdm / transformers read them at import time -- TODO confirm.
os.environ.update(
    {
        "TQDM_DISABLE": "1",
        "TRANSFORMERS_VERBOSITY": "error",
    }
)
warnings.filterwarnings("ignore")

# --- 2. LOGGING SETUP ---
# Root logger: INFO and above, written to stdout so messages appear in the
# cell output together with print() output.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)

# Third-party loggers that are raised to ERROR so only failures get through.
noisy_loggers = [
    "sentence_transformers",
    "transformers",
    "urllib3",
    "requests",
    "huggingface_hub",
    "filelock",
    "tqdm",
]
for logger_name in noisy_loggers:
    logging.getLogger(logger_name).setLevel(logging.ERROR)

# --- 3. IMPORTS ---
try:
    from causal_graph.builder import CausalGraphBuilder
    from causal_graph.retriever import CausalPathRetriever
    from causal_graph.explainer import CausalGraphExplainer
except ImportError:
    from builder import CausalGraphBuilder
    from retriever import CausalPathRetriever
    from explainer import CausalGraphExplainer

In [5]:
class CausalRAGChain:
    """
    Question-answering chain over a pre-built causal graph.

    Wires together a CausalGraphBuilder, a CausalPathRetriever and a
    CausalGraphExplainer, and records every answered question in
    ``self.history``.
    """

    def __init__(self, graph_file: str, embedding_model: str):
        """
        Initialize the RAG chain by loading a pre-built graph.

        Args:
            graph_file: Path passed to ``CausalGraphBuilder.load``.
            embedding_model: Model name passed to ``CausalGraphBuilder``.

        Raises:
            RuntimeError: If ``builder.load`` returns a falsy value.
        """
        # One entry per query() call. Values are not all strings
        # ("raw_paths" holds whatever the retriever returned), hence Any.
        self.history: List[Dict[str, Any]] = []

        # 1. Initialize the builder with the requested embedding model.
        print(f"Initializing with model: {embedding_model}...")
        self.builder = CausalGraphBuilder(
            model_name=embedding_model,
            normalize_nodes=True
        )

        # 2. Load the pre-built graph; fail fast when the builder reports failure.
        print(f"Loading pre-built causal graph from: {graph_file}...")
        if not self.builder.load(graph_file):
            raise RuntimeError(f"Failed to load graph from {graph_file}")

        # Fetch the graph once and reuse it for the summary and the explainer.
        graph = self.builder.get_graph()
        print(f"Graph loaded successfully: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges.")

        # 3. Initialize Retriever and Explainer
        self.retriever = CausalPathRetriever(self.builder)
        self.explainer = CausalGraphExplainer(graph, self.builder.node_text)

    def _describe_relevant_nodes(self, user_question: str, top_k: int) -> str:
        """Build the fallback answer listing the most relevant nodes, if any."""
        relevant_nodes = self.retriever.retrieve_nodes(user_question, top_k=top_k)
        if not relevant_nodes:
            return "No relevant causal information found in the graph."

        header = "No complete causal chains found, but these concepts are relevant:\n"
        # Fall back to the raw node id when the builder has no text for it.
        bullets = "".join(
            f"- {self.builder.node_text.get(node_id, node_id)} (Relevance: {score:.2f})\n"
            for node_id, score in relevant_nodes
        )
        return header + bullets

    def query(self, user_question: str, max_paths: int = 5, top_k: int = 5) -> Dict[str, Any]:
        """
        Answer a question from the loaded graph and record the interaction.

        1. Retrieves relevant causal paths from the retriever.
        2. Uses the retriever's causal explanation when paths exist,
           otherwise falls back to a list of relevant nodes.
        3. Appends the interaction to ``self.history``.

        Args:
            user_question: The question text.
            max_paths: Upper bound passed to ``retriever.retrieve_paths``.
            top_k: Number of nodes requested from ``retriever.retrieve_nodes``
                when no paths are found.

        Returns:
            Dict with keys "question", "output" and "raw_paths"; the same
            dict object is stored as the last element of ``self.history``.
        """
        print(f"\nProcessing Query: {user_question}")

        # Retrieve relevant paths (e.g., "smoking" -> "lung cancer")
        relevant_paths = self.retriever.retrieve_paths(user_question, max_paths=max_paths)

        if relevant_paths:
            answer_text = self.retriever.get_causal_explanation(user_question)
        else:
            # Fallback to node relevance if no connected paths found
            answer_text = self._describe_relevant_nodes(user_question, top_k)

        # Store interaction history
        interaction = {
            "question": user_question,
            "output": answer_text,
            "raw_paths": relevant_paths
        }
        self.history.append(interaction)

        return interaction



Initializing with model: sentence-transformers/all-mpnet-base-v2...
Loading pre-built causal graph from: causal_math_graph_llm.json...
Graph loaded successfully: 219 nodes, 151 edges.

Processing Query: What happens when the circumcenter is on the side of the triangle?
--------------------------------------------------
Q: What happens when the circumcenter is on the side of the triangle?
A: Causal relationships relevant to 'What happens when the circumcenter is on the side of the triangle?':

1. bent or extended or broken sides → a triangle changes shape
2. broken joints → a triangle changes shape
3. one pair of corresponding sides of two triangles are in the same proportion as another pair of corresponding sides, and their included angles have the same measure → triangles are similar


Processing Query: how many side of squares
--------------------------------------------------
Q: how many side of squares
A: Causal relationships relevant to 'how many side of squares':

1. formula for 

In [6]:
# ==========================================
# Execution Code
# ==========================================

if __name__ == "__main__":
    # Configuration
    GRAPH_FILE = "causal_math_graph_llm.json"
    MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

    # Example queries based on the nodes in your JSON
    test_questions = [
        'What happens when the circumcenter is on the side of the triangle?',
        "how many side of squares",
        "what is the circumference of circle",
        "How is the area of a square related to its side length?",
        "What conditions make two triangles similar?",
        "What characterizes a tangential quadrilateral?",
        "What happens to a graph when a cut-set is removed?",
        "How is Brownian motion related to the diffusion constant?",
        "Why do soap films form minimal area surfaces?",
        "What is the relationship between snarks and the four-color theorem?"
    ]

    try:
        # Instantiate the RAG Chain with the pre-built graph
        rag_chain = CausalRAGChain(GRAPH_FILE, MODEL_NAME)

        # Run queries
        for q in test_questions:
            result = rag_chain.query(q)
            print("-" * 50)
            print(f"Q: {result['question']}")
            print(f"A: {result['output']}")

        # Verify History Storage
        print("\n" + "="*20 + " Interaction History " + "="*20)
        print(f"Total interactions logged: {len(rag_chain.history)}")
        # Print the last interaction structure. "raw_paths" comes straight
        # from the retriever and may hold non-JSON types, so anything json
        # cannot encode natively is rendered through str().
        print(json.dumps(rag_chain.history[-1], indent=2, default=str))

    except Exception as e:
        # logging.exception keeps the traceback, which a bare print(e) loses.
        logging.exception("An error occurred: %s", e)

Initializing with model: sentence-transformers/all-mpnet-base-v2...
Loading pre-built causal graph from: causal_math_graph_llm.json...
Graph loaded successfully: 219 nodes, 151 edges.

Processing Query: What happens when the circumcenter is on the side of the triangle?
--------------------------------------------------
Q: What happens when the circumcenter is on the side of the triangle?
A: Causal relationships relevant to 'What happens when the circumcenter is on the side of the triangle?':

1. bent or extended or broken sides → a triangle changes shape
2. broken joints → a triangle changes shape
3. one pair of corresponding sides of two triangles are in the same proportion as another pair of corresponding sides, and their included angles have the same measure → triangles are similar


Processing Query: how many side of squares
--------------------------------------------------
Q: how many side of squares
A: Causal relationships relevant to 'how many side of squares':

1. formula for 