In [None]:
import json
import os
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from groq import Groq 
# Configuration & Setup
# Location of the pre-computed chunk embeddings read by load_embedded_chunks().
EMBEDDED_CHUNKS_FILE = Path("data/embedded_chunks.json")

# API KEY FOR GROQ
# Prefer the GROQ_API_KEY environment variable so the secret never has to be
# written into this file; the placeholder is only used when the variable is
# unset or empty, which keeps the module importable without configuration.
client = Groq(api_key=os.environ.get("GROQ_API_KEY") or "replace your key here")

# Load Data
def load_embedded_chunks():
    """Return the parsed JSON content of EMBEDDED_CHUNKS_FILE.

    Prints an error message and returns None when the file does not exist.
    """
    if EMBEDDED_CHUNKS_FILE.exists():
        raw_json = EMBEDDED_CHUNKS_FILE.read_text(encoding="utf-8")
        return json.loads(raw_json)

    print(f"Error: Embedded chunks file not found at {EMBEDDED_CHUNKS_FILE}")
    return None

# Semantic Search
# Load the sentence-embedding model used to encode queries. The import is
# attempted lazily so the rest of the module still loads when the optional
# `sentence_transformers` package is missing.
try:
    from sentence_transformers import SentenceTransformer
    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
except ImportError:
    print("SentenceTransformer not found. Please ensure it's installed (pip install sentence-transformers).")
    print("Assuming `embed_model` is provided by the execution environment or `embedder.ipynb` has been run.")
    # None marks the model as unavailable; agent_query() checks for it and
    # skips retrieval instead of crashing.
    # NOTE(review): this assignment rebinds `embed_model` to None even if the
    # environment already provided one, which contradicts the message printed
    # just above — confirm which behaviour is intended.
    embed_model = None

def semantic_search(query_embedding, embedded_data, top_k=15):
    """Rank embedded chunks by cosine similarity to a query embedding.

    The search works purely on precomputed vectors, so it does not require an
    embedding model to be loaded.

    Args:
        query_embedding: Query vector as a NumPy array or any array-like
            (for example a list of floats). When it has more than one
            dimension, only its first row is used.
        embedded_data: Iterable of chunk dicts. Chunks whose "embedding" key
            is missing or None are skipped. ``None`` yields an empty result.
        top_k: Maximum number of results to return.

    Returns:
        A list of ``(chunk, score)`` tuples sorted by descending score, with
        ``score`` as a Python float. Chunks with equal scores keep their
        input order. The list is empty when there is nothing to rank.

    Raises:
        ValueError: If the chunk embeddings do not all share the query's
            dimensionality.
    """
    if embedded_data is None:
        return []

    candidates = [
        chunk for chunk in embedded_data
        if "embedding" in chunk and chunk["embedding"] is not None
    ]
    if not candidates:
        return []

    # Normalise every input to a 1-D float vector (first row of 2-D inputs).
    query_vec = np.atleast_2d(np.asarray(query_embedding, dtype=np.float64))[0]
    matrix = np.stack([
        np.atleast_2d(np.asarray(chunk["embedding"], dtype=np.float64))[0]
        for chunk in candidates
    ])

    # One matrix-vector product scores all chunks at once instead of issuing
    # a separate similarity call per chunk.
    query_norm = np.linalg.norm(query_vec)
    chunk_norms = np.linalg.norm(matrix, axis=1)
    # A zero-length vector gets norm 1 so its similarity is 0 rather than NaN.
    safe_query_norm = query_norm if query_norm else 1.0
    safe_chunk_norms = np.where(chunk_norms == 0.0, 1.0, chunk_norms)
    scores = (matrix @ query_vec) / (safe_chunk_norms * safe_query_norm)

    ranked = sorted(
        zip(candidates, scores.tolist()),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:top_k]

# Query Decomposition Function
def decompose_query(user_query, llm_client):
    """Ask the LLM to split a question into atomic sub-queries.

    Args:
        user_query: The user's original question.
        llm_client: Client object exposing ``chat.completions.create(...)``
            and returning a response whose ``choices[0].message.content`` is
            the model's text output.

    Returns:
        A non-empty list of stripped, non-empty sub-query strings. Whenever
        the LLM call fails or its output is not a JSON object holding a
        usable "sub_queries" list, the fallback ``[user_query]`` is returned,
        so this function never raises for LLM-related problems.
    """
    # The example outputs below must be valid JSON (no inline comments),
    # because the model is asked to reply with a strict JSON object.
    prompt_decompose = f"""
    You are a financial query analyzer. Your task is to analyze a user's question about financial filings
    and determine if it requires decomposition into multiple sub-questions.

    If the question is simple and can be answered with a single direct search (e.g., "What was Microsoft's revenue in 2023?"),
    return only the original question in a JSON object with the key "sub_queries".

    If the question is complex, comparative, or requires information from multiple companies/years
    (e.g., "How did NVIDIA's data center revenue grow from 2022 to 2023?", "Which company had the highest operating margin in 2023?"),
    break it down into specific, atomic sub-questions. Each sub-question should be answerable by a single retrieval.

    Consider company names (Microsoft, Google, NVIDIA) and fiscal years (e.g., 2022, 2023, 2024) for decomposition.
    NVIDIA often reports data center revenue instead of "cloud" revenue; phrase its sub-queries accordingly.

    Format your output strictly as a JSON object with a single key "sub_queries" whose value is a JSON list of strings.
    Do NOT include any conversational text, preamble, or any other keys. Only the JSON object.

    Example 1 (Simple):
    Question: What was Google's total revenue in 2023?
    Output: {{"sub_queries": ["What was Google's total revenue in 2023?"]}}

    Example 2 (Comparative):
    Question: How did NVIDIA's data center revenue grow from 2022 to 2023?
    Output: {{"sub_queries": ["NVIDIA data center revenue 2022", "NVIDIA data center revenue 2023"]}}

    Example 3 (Cross-Company):
    Question: Which company had the highest operating margin in 2023?
    Output: {{"sub_queries": ["Microsoft operating margin 2023", "Google operating margin 2023", "NVIDIA operating margin 2023"]}}

    Example 4 (Complex Multi-aspect):
    Question: Compare cloud revenue growth rates across all three companies from 2022 to 2023
    Output: {{
        "sub_queries": [
            "Microsoft cloud revenue 2022",
            "Microsoft cloud revenue 2023",
            "Google cloud revenue 2022",
            "Google cloud revenue 2023",
            "NVIDIA data center revenue 2022",
            "NVIDIA data center revenue 2023"
        ]
    }}

    Question: {user_query}
    Output:
    """
    # Bound before the try block so the JSON error handler can always report
    # it, even if the decode error is raised before any content was read.
    raw_response = None
    try:
        response = llm_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt_decompose}],
            # Using a smaller model for decomposition might be faster.
            # NOTE(review): confirm this model ID is still served by Groq.
            model="llama3-8b-8192",
            response_format={"type": "json_object"}
        )
        raw_response = response.choices[0].message.content
        parsed_output = json.loads(raw_response)

        # Strictly expect a dictionary with a "sub_queries" list.
        if not (
            isinstance(parsed_output, dict)
            and isinstance(parsed_output.get("sub_queries"), list)
        ):
            raise ValueError(f"LLM did not return a JSON object with 'sub_queries' key as expected. Got: {parsed_output}")

        # Keep only usable entries: downstream code embeds each sub-query as
        # text, so non-strings and blank strings are dropped here.
        sub_queries = [
            item.strip()
            for item in parsed_output["sub_queries"]
            if isinstance(item, str) and item.strip()
        ]
        if not sub_queries:
            raise ValueError("No sub-queries extracted from LLM response.")

        return sub_queries
    except json.JSONDecodeError as e:
        print(f"JSON decoding error, LLM output might not be valid JSON: {raw_response}. Error: {e}")
        return [user_query]  # Fallback
    except Exception as e:
        # Deliberate best-effort boundary: any LLM/transport/validation
        # failure degrades to searching with the original question.
        print(f"Error decomposing query, defaulting to original query: {e}")
        return [user_query]  # Fallback

# Agent Query Engine
def agent_query(user_query, top_k_per_subquery=3):
    """
    Answer a question with a decompose -> retrieve -> synthesize pipeline.

    Steps:
      1. Load the embedded chunks and ask the LLM to split the question into sub-queries.
      2. Run a semantic search (top_k_per_subquery hits) for every sub-query.
      3. Deduplicate the hits, keeping the best score per chunk, and join their text into one context.
      4. Ask the Groq chat model to write the answer from that context.

    Returns a dict with the keys "query", "answer", "reasoning", "sub_queries" and "sources".
    Progress information is printed to stdout along the way.
    """
    corpus = load_embedded_chunks()
    if corpus is None:
        return {"query": user_query, "answer": "Could not load embedded chunks.", "reasoning": "", "sub_queries": [], "sources": []}

    # Step 1: query decomposition
    sub_queries = decompose_query(user_query, client)
    used_sub_queries = list(sub_queries)  # independent copy reported back to the caller

    print(f"\n--- Decomposed Sub-queries for '{user_query}': ---")
    for position, sub_query in enumerate(sub_queries, start=1):
        print(f"  {position}. {sub_query}")
    print("--------------------------------------------------")

    if len(sub_queries) > 1:
        trace = [f"Original query decomposed into {len(sub_queries)} sub-queries."]
    else:
        trace = ["Original query processed as a single search."]

    # Step 2: one retrieval pass per sub-query
    hits = []
    for sub_query in sub_queries:
        trace.append(f"Retrieving information for sub-query: '{sub_query}'")
        if embed_model is None:
            trace.append("Embedding model not available, skipping semantic search for sub-query.")
            continue
        query_vector = embed_model.encode(sub_query)
        hits.extend(semantic_search(query_vector, corpus, top_k=top_k_per_subquery))

    print("\n--- All Retrieved Chunks (before deduplication): ---")
    for chunk, score in hits:
        print(f"  Company: {chunk.get('company', 'N/A')}, Year: {chunk.get('year', 'N/A')}, Score: {score:.4f}, Excerpt: {chunk.get('text', '')[:100]}...")
    if not hits:
        print("  No chunks retrieved for any sub-query.")
    print("----------------------------------------------------")

    if not hits:
        return {"query": user_query, "answer": "No relevant chunks found across all sub-queries.", "reasoning": "Semantic search returned no results.", "sub_queries": used_sub_queries, "sources": []}

    # Step 3: collapse duplicates. A chunk is identified by (company, year,
    # hash of its text); when it was retrieved several times the highest
    # score wins.
    best_by_id = {}
    for chunk, score in hits:
        identity = (chunk.get("company"), chunk.get("year"), hash(chunk.get("text", "")))
        current = best_by_id.get(identity)
        if current is None or score > current[1]:
            best_by_id[identity] = (chunk, score)

    ranked = list(best_by_id.values())
    ranked.sort(key=lambda item: item[1], reverse=True)

    # Context handed to the LLM: chunk texts, best match first
    context = "\n\n---\n\n".join(chunk.get("text", "") for chunk, _ in ranked)

    # Source attribution entries (excerpt capped at 300 characters)
    sources = []
    for chunk, score in ranked:
        text = chunk.get("text", "")
        suffix = "..." if len(text) > 300 else ""
        sources.append({
            "company": chunk.get("company", "N/A"),
            "year": chunk.get("year", "N/A"),
            "fiscal_year": chunk.get("fiscal_year", "N/A"),
            "excerpt": text[:300] + suffix,
            "score": float(score),
        })

    system_prompt = """
    You are an AI financial analyst. Answer the user's question precisely based on the provided context from 10-K filings.
    If the question involves comparison or calculation, perform it accurately and show your work if appropriate.
    State any numerical values clearly, including units (e.g., billions, millions) and percentages.
    If information for a specific company or year is not available in the context, state that clearly for that specific entity.
    Do not make up information. If you cannot answer the question from the provided context, state that you do not have enough information.
    """

    user_prompt = f"""
    Question: {user_query}

    Context from 10-K filings:
    {context}

    Synthesize a comprehensive answer based ONLY on the provided context.
    If the question involved comparing multiple entities or periods, provide a clear and concise comparison, including specific figures and growth rates if found.
    """

    # Step 4: answer synthesis; an LLM failure is reported in the answer text
    # instead of being raised.
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        completion = client.chat.completions.create(
            messages=conversation,
            model="llama3-8b-8192",
            temperature=0.0,
            max_tokens=1000,
        )
        answer = completion.choices[0].message.content.strip()
        trace.append("Used Groq model to synthesize answer from combined retrieved context.")
    except Exception as e:
        answer = f"Could not generate answer due to LLM error: {e}. Please check your Groq API key and model access."
        trace.append(f"LLM synthesis failed: {e}")

    return {
        "query": user_query,
        "answer": answer,
        "reasoning": "\n".join(trace),
        "sub_queries": used_sub_queries,
        "sources": sources,
    }

# Example Usage
if __name__ == '__main__':
    # Demo run: send a handful of representative questions through
    # agent_query() and pretty-print each result dict.
    import pprint

    # Each entry: (banner text, label used in the error message, question).
    demo_queries = [
        (
            "a simple query",
            "simple query",
            "What was Microsoft's total revenue in 2023?",
        ),
        (
            "a comparative query (YoY growth)",
            "comparative query",
            "How did NVIDIA's data center revenue grow from 2022 to 2023?",
        ),
        (
            "a complex cross-company/segment query",
            "complex query",
            "Compare cloud revenue growth rates across all three companies (Google, Microsoft, NVIDIA) from 2022 to 2023.",
        ),
        (
            "a segment analysis query",
            "segment analysis query",
            "What percentage of Google's revenue came from cloud in 2023?",
        ),
        (
            "an AI Strategy query",
            "AI strategy query",
            "Compare AI investments mentioned by all three companies (Google, Microsoft, NVIDIA) in their 2024 10-Ks.",
        ),
    ]

    for banner, error_label, question in demo_queries:
        print(f"\n--- Testing with {banner} ---")
        # Top-level boundary: a failure in one demo must not stop the others.
        try:
            pprint.pprint(agent_query(question))
        except Exception as e:
            print(f"Error during {error_label} execution: {e}")
