In [None]:
import os
from src.vectorstore import VectorstoreHandler
from src.models import init_emb, init_llm
from src.rag_tools import build_rag_chain
import textwrap

# TODO: The general config is quite shit. Make it so that it automatically identifies subfodlders to sources_dir, and there should be one config file in each. if is none, then create a standard config file. 
# TODO: Fix how to deal with system prompts for each rag.
# TODO: Wrap aggregate model to better deal with different providers.
# TODO: Consider options to plaintext. Maybe enforcing latex is better, to prepare for mathematical notation and code? Research options. 
# TODO: Modularize

SOURCES_DIR = "sources"

# Config JSON for RAG specifications
RAG_CONFIG = {
    "book": {
        "subfolder": "book",
        "llm_name": "Llama3.2-3b",
        "emb_name": "hf-minilm-l6-v2",
        "k": 10,
    },
    "lectures": {
        "subfolder": "lectures",
        "llm_name": "Llama3.2-3b",
        "emb_name": "hf-minilm-l6-v2",
        "k": 10,
    },
}

# Functions 
def create_rag(config, handler):
    """
    Create a RAG chain based on the provided configuration.

    Args:
        config (dict): Configuration for the RAG (subfolder, LLM, embedding, k, etc.).
        handler (VectorstoreHandler): Handler for managing vectorstores.

    Returns:
        tuple: A tuple containing the RAG chain and its retriever.
    """
    subfolder = config["subfolder"]
    llm_name = config["llm_name"]
    emb_name = config["emb_name"]
    k = config["k"]

    # Paths and model initialization
    dir_path = os.path.join(SOURCES_DIR, subfolder)
    emb = init_emb(emb_name)
    llm = init_llm(llm_name)

    # Build vectorstore and retriever
    vs = handler.build_vectorstore(dir_path, emb, emb_name)
    retriever = handler._init_retriever(vs, dir_path, k)

    # Build and return the RAG chain
    chain = build_rag_chain(retriever, llm)
    return chain, retriever

    import textwrap

def generate_answers(prompt, rags):

    rag_responses = {}
    print(f"\nQuerying all RAGs for prompt: '{prompt}'\n{'=' * 80}")
    for rag_name, (chain, retriever) in rags.items():
        print(f"--- Querying RAG: {rag_name} ---")
        output = chain.invoke(prompt)
        # Store the raw response and associated documents
        rag_responses[rag_name] = {
            "answer": output["answer"],
            "docs": output["docs"],
        }
    return rag_responses

def generate_and_display_aggregate_response(aggregation_llm, agent_answers, prompt):
    """
    Generates an aggregated response from the answers of individual agents and prints it.
    
    Args:
        aggregation_llm: The LLM used for aggregating responses.
        agent_answers (dict): A dictionary with agent names as keys and their answers as values.
        prompt (str): The prompt provided to all agents.

    Returns:
        None
    """
    # Step 1: Construct the aggregation prompt
    aggregation_prompt = f"Here are the answers from different sources for the prompt: '{prompt}':\n"
    for agent_name, answer in agent_answers.items():
        aggregation_prompt += f"\n[{agent_name}] {answer}\n"
    aggregation_prompt += (
        "\nPlease provide a concise, technical, and structured summary of the above, "
        "formatted for terminal display with a wrapping width of 80 characters."
        "If the answer contains mathematical notation, default to using LaTeX-syntax."
    )

    # Step 2: Generate the aggregated response
    aggregated_output = aggregation_llm.invoke(aggregation_prompt)
    if isinstance(aggregated_output, str):
        aggregated_answer = aggregated_output
    else:
        aggregated_answer = aggregated_output.content['answer']

    # Step 3: Print the aggregated answer first
    print("\nAggregated Answer:\n" + "=" * 80)
    wrapped_aggregated_answer = textwrap.fill(aggregated_answer, width=80)
    print(wrapped_aggregated_answer + "\n")

    # Step 4: Print detailed answers from each agent
    print("\nDetailed Responses from Agents:\n" + "=" * 80)
    for agent_name, answer in agent_answers.items():
        if isinstance(answer, str):
            answer = answer
        else:
            answer = answer['answer']

        wrapped_answer = textwrap.fill(f"[{agent_name}] {answer}", width=80)
        print(wrapped_answer + "\n")

In [None]:
# Initialize the vectorstore handler and construct the rags
handler = VectorstoreHandler(SOURCES_DIR, force_rebuild=False)

rags = {}
for rag_name, rag_config in RAG_CONFIG.items():
    print(f"Setting up RAG: {rag_name}")
    rags[rag_name] = create_rag(rag_config, handler)

In [23]:
AGGREGATION_LLM_NAME = "Llama3.2-3b"  # Change to your preferred LLM
aggregation_llm = init_llm(AGGREGATION_LLM_NAME)

# Example usage: querying all RAGs
test_prompt = "What are the assumptions used for krieging?"

rag_responses = generate_answers(test_prompt,rags)
generate_and_display_aggregate_response(aggregation_llm,rag_responses, test_prompt)


Querying all RAGs for prompt: 'What are the assumptions used for krieging?'
--- Querying RAG: book ---
--- Querying RAG: lectures ---

Aggregated Answer:
**Kriging**  Kriging is a method for estimating the value of a random field at
unobserved locations, given observed data.  **Key Concepts**  * **Joint Gaussian
distribution**: The underlying assumption that the observed and unobserved
values are jointly Gaussian. * **Conditional mean**: The expected value of the
unobserved values given the observed values, denoted as E[X |X ]. * **Kriging
predictor**: The conditional mean, which is a weighted sum of the observed
values.  **Mathematical Notation**  Let X and X be jointly Gaussian random
fields with means (µ, µ) and covariances (Σ, Σ). Then, the kriging predictor at
unobserved locations s is given by:  E[X |X ] = µ + Σ Σ−1(X −µ )  **Properties**
* **Point prediction**: The kriging predictor provides a point estimate for the
value of the random field at the unobserved location. * **Cond