In [9]:
import os

# Jupyter kernels do not define __file__, so fake one in the current working
# directory; load_xml_instructions() resolves XML_instructions/ relative to it.
__file__ = os.path.abspath(os.path.join(os.curdir, "notebook_name.ipynb"))


In [10]:
"""
CFOLytics_reportgenerator_no_manual_fields.py

LangGraph workflow that:
1) Takes an initial user prompt from the conversation (state["messages"]).
2) Checks clarity using the LLM instructions in 'verify_instructions.xml'.
3) If unclear, the LLM asks a clarifying question (clarification_prompt.xml); the graph pauses before
   'get_user_clarification' so the user's reply can be added to the conversation, then clarity is re-checked.
4) If clear, the LLM generates a layout (render_layout.xml).
5) (Planned — not yet wired into the graph) The LLM asks the user “Is this layout okay?”; the user answers
   in the conversation; we parse yes/no and regenerate on "no" or proceed on "yes".
6) Generate per-component configs (component_content_gen.xml), unify lists, and finalize the JSON.

No need for manually adding `clarification_answer` or `layout_confirm` to the state. The conversation itself is the source of truth.

Requires:
- langgraph
- langchain-core
- langchain-community
- langchain-openai
- langchain-groq (ChatGroq), or whichever LLM wrapper you use
"""

import os
import json
import re
from typing import List
from typing_extensions import TypedDict

# LangChain / LangGraph
from langchain_core.messages import (
    AIMessage, 
    HumanMessage,
    SystemMessage,
    BaseMessage
)
from langchain_openai import ChatOpenAI
from langgraph.constants import START, END, Send
from langgraph.graph import StateGraph, MessagesState


# Example: your custom ChatGroq usage
from langchain_groq import ChatGroq
# Read the Groq API key from the environment instead of hard-coding it in the
# notebook. NOTE(review): the key that was previously committed here is exposed
# and should be revoked/rotated.
_groq_api_key = os.environ.get("GROQ_API_KEY")
if not _groq_api_key:
    raise RuntimeError(
        "GROQ_API_KEY is not set; export it (or set os.environ['GROQ_API_KEY']) "
        "before running this cell."
    )
llm = ChatGroq(
    temperature=0,
    model_name="llama-3.3-70b-versatile",
    api_key=_groq_api_key,
)
# llm = ChatOpenAI(model="gpt-4o",
#         temperature=0)

def load_xml_instructions(filename: str) -> str:
    """
    Load a prompt / system-instruction file from the ``XML_instructions`` folder.

    The folder is resolved next to ``__file__`` when that name is defined (a plain
    script, or a notebook that set it manually). Jupyter kernels do not define
    ``__file__``; in that case the current working directory is used instead of
    failing with ``NameError``.

    Args:
        filename: Name of the file inside ``XML_instructions`` (e.g. "render_layout.xml").

    Returns:
        The file's full text, decoded as UTF-8.

    Raises:
        FileNotFoundError: If the instructions file does not exist.
    """
    module_file = globals().get("__file__")
    if module_file:
        base_dir = os.path.dirname(os.path.abspath(module_file))
    else:
        # Fresh notebook kernel without the __file__ shim.
        base_dir = os.getcwd()
    file_path = os.path.join(base_dir, "XML_instructions", filename)
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()

class FinalReportState(TypedDict):
    """
    Custom (non-message) fields carried through the report graph.
    """
    # Clarity verdict from verify_instructions; instructions_decider routes on it
    # (True -> generate_layout_json, False -> ask_clarification).
    instructions_clear: bool
    # Report layout built by generate_layout_json and enriched by later nodes;
    # holds an "error" key instead when the LLM output could not be parsed.
    layout_json: dict
    # Final report JSON; finalize_report_json copies layout_json into it.
    final_json: dict

class ReportGraphState(MessagesState, FinalReportState):
    """
    Full graph state: the conversation plus the custom report fields.

    'messages' holds the conversation (HumanMessage / AIMessage entries). The nodes
    prepend their SystemMessage only to the LLM call, not to the stored state.
    NOTE(review): MessagesState is documented to merge 'messages' updates with
    LangGraph's add_messages reducer, so nodes should return new messages rather
    than append to state["messages"] in place -- confirm for your langgraph version.
    """
    pass

# -----------------------------------------------------------------------------
# 1) verify_instructions
# -----------------------------------------------------------------------------

def verify_instructions(state: ReportGraphState):
    """
    Ask the LLM whether the conversation so far describes a clear report request.

    The system prompt comes from 'verify_instructions.xml' and is expected to make
    the LLM end its answer with a verdict line such as "clear" or "not clear".

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        A state update that appends the LLM's analysis to ``messages`` (as an
        AIMessage named "clarity-check") and sets ``instructions_clear``.
    """
    system_msg = SystemMessage(content=load_xml_instructions("verify_instructions.xml"))

    # Pass the entire conversation, prefixed by the system instructions.
    result = llm.invoke([system_msg] + state["messages"])
    analysis = result.content

    # Judge the verdict on the final non-empty line first, so an analysis that
    # merely mentions "unclear" (e.g. "nothing here is unclear ... clear") is not
    # misread as a negative verdict.
    lines = [
        line.strip(" \t*_.:!`\"'").lower()
        for line in analysis.splitlines()
        if line.strip()
    ]
    verdict = lines[-1] if lines else ""
    if "not clear" in verdict or "unclear" in verdict:
        instructions_clear = False
    elif "clear" in verdict:
        instructions_clear = True
    else:
        # No recognizable verdict line: fall back to scanning the whole answer.
        lowered = analysis.lower()
        instructions_clear = not ("not clear" in lowered or "unclear" in lowered)

    # Return the message as an update instead of mutating state["messages"]:
    # LangGraph applies what a node returns; in-place mutation of the input
    # state is not a supported way to update it.
    return {
        "messages": [AIMessage(content=analysis, name="clarity-check")],
        "instructions_clear": instructions_clear,
    }


# -----------------------------------------------------------------------------
# 2) ask_clarification
# -----------------------------------------------------------------------------
def ask_clarification(state: ReportGraphState):
    """
    Have the LLM write a clarifying question for the user.

    Uses 'clarification_prompt.xml' as the system prompt and the stored
    conversation as context.

    Returns:
        A state update appending the question as an AIMessage named
        "clarification_question" (the name get_user_clarification searches for).
    """
    system_msg = SystemMessage(content=load_xml_instructions("clarification_prompt.xml"))
    conversation = [system_msg] + state["messages"]

    # Reuse the module-level client: the per-call ChatGroq built here before had
    # the same model/temperature but duplicated a hard-coded API key.
    result = llm.invoke(conversation)

    question_msg = AIMessage(content=result.content, name="clarification_question")
    # Return the message as an update instead of mutating state["messages"] in place.
    return {"messages": [question_msg]}


# -----------------------------------------------------------------------------
# 3) get_user_clarification
# -----------------------------------------------------------------------------
def get_user_clarification(state: ReportGraphState):
    """
    Checkpoint node for the user's answer to the latest clarifying question.

    Execution is meant to pause *before* this node (the graph is compiled with
    ``interrupt_before=["get_user_clarification"]``) so the caller can add the
    user's reply (a HumanMessage) to ``messages`` before resuming. Returning None
    from here does not keep the graph on this node: the static edge to
    verify_instructions is followed regardless, so this node always returns ``{}``.

    Returns:
        ``{}`` (no state change). If no HumanMessage follows the latest
        "clarification_question", a RuntimeWarning is emitted because the clarity
        check will re-run on an unchanged conversation.
    """
    import warnings

    messages = state["messages"]
    # Index of the most recent clarifying question, or None if there is none.
    question_idx = next(
        (
            i
            for i in range(len(messages) - 1, -1, -1)
            if isinstance(messages[i], AIMessage)
            and messages[i].name == "clarification_question"
        ),
        None,
    )
    if question_idx is None:
        return {}

    answered = any(isinstance(msg, HumanMessage) for msg in messages[question_idx + 1:])
    if not answered:
        warnings.warn(
            "No user reply found after the latest clarification question; "
            "add a HumanMessage to the conversation before resuming the graph.",
            RuntimeWarning,
            stacklevel=2,
        )
    return {}


# -----------------------------------------------------------------------------
# 4) generate_layout_json
# -----------------------------------------------------------------------------
def generate_layout_json(state: ReportGraphState):
    """
    Generate the report layout as JSON using structured output (json_mode).

    The system prompt comes from 'render_layout.xml'; the whole conversation is
    passed as context. The LLM reply is validated against ``ReportConfig``
    (reportTitle, layout.gridColumns, layout.rows, numberFormat).

    Returns:
        ``{"layout_json": ...}`` with the validated layout serialized by alias, or,
        if parsing failed, an error dict with "error", "raw_output" (the raw reply
        text) and "parsing_error" keys. Downstream nodes check for "error".
    """
    # NOTE(review): langchain_core.pydantic_v1 is a compatibility shim that newer
    # langchain-core releases deprecate; consider plain `pydantic` models.
    from langchain_core.pydantic_v1 import BaseModel, Field

    # Schema the LLM's JSON reply must satisfy.
    class LayoutConfig(BaseModel):
        gridColumns: dict
        rows: list

    class ReportConfig(BaseModel):
        reportTitle: str = Field(alias="reportTitle")
        layout: LayoutConfig
        numberFormat: dict = Field(alias="numberFormat")

        class Config:
            allow_population_by_field_name = True

    system_msg = SystemMessage(content=load_xml_instructions("render_layout.xml"))
    conversation = [system_msg] + state["messages"]

    structured_llm = llm.with_structured_output(
        ReportConfig,
        method="json_mode",
        include_raw=True,
    )
    # With include_raw=True the result is a dict with "raw", "parsed" and
    # "parsing_error" keys -- not an object with attributes.
    output = structured_llm.invoke(conversation)

    parsed_output = output.get("parsed")
    if parsed_output:
        return {"layout_json": parsed_output.dict(by_alias=True)}

    # Keep the failure payload JSON-serializable: store the raw reply text,
    # not the message object.
    raw = output.get("raw")
    parsing_error = output.get("parsing_error")
    return {
        "layout_json": {
            "error": "Failed to parse layout",
            "raw_output": getattr(raw, "content", raw),
            "parsing_error": str(parsing_error) if parsing_error is not None else None,
        }
    }



In [11]:
# -----------------------------------------------------------------------------
# 8) identify_and_unify_lists
# -----------------------------------------------------------------------------
def identify_and_unify_lists(state: ReportGraphState):
    """
    Pass-through placeholder for list unification.

    Returns the current layout unchanged (an empty dict when none is set yet).
    """
    # TODO: implement the actual list unification.
    return {"layout_json": state.get("layout_json", {})}


# -----------------------------------------------------------------------------
# 9) create_lists_contents
# -----------------------------------------------------------------------------
def create_lists_contents(state: ReportGraphState):
    """
    Populate every entry under a "lists" key in the layout with dimension members.

    Still a hard-coded demo: each list entry gets ``"list": ["Jan", "Feb", "Mar"]``
    and, if missing, a default "AI Generation Description".

    Returns:
        ``{}`` when there is no layout yet, otherwise ``{"layout_json": ...}`` with
        a populated copy of the layout. Entries that are not dicts (e.g. a bare
        string emitted by the LLM) are left as-is instead of raising TypeError.
    """
    import copy

    layout_json = state.get("layout_json", {})
    if not layout_json:
        return {}

    # Work on a copy so the layout held in the incoming state is not mutated.
    layout_json = copy.deepcopy(layout_json)

    def collect_list_items(obj, found):
        # Depth-first search for "lists" arrays anywhere in the layout tree.
        if isinstance(obj, dict):
            if isinstance(obj.get("lists"), list):
                found.extend(obj["lists"])
            for value in obj.values():
                collect_list_items(value, found)
        elif isinstance(obj, list):
            for value in obj:
                collect_list_items(value, found)
        return found

    for item in collect_list_items(layout_json, []):
        if not isinstance(item, dict):
            continue
        item["list"] = ["Jan", "Feb", "Mar"]
        item.setdefault("AI Generation Description", "Populated with months for demonstration.")

    return {"layout_json": layout_json}


# -----------------------------------------------------------------------------
# 10) finalize_report_json
# -----------------------------------------------------------------------------
def finalize_report_json(state: ReportGraphState):
    """
    Publish the finished layout as the report's final JSON.

    Returns:
        ``{"final_json": <layout_json>}`` -- the same object, not a copy; an empty
        dict is used when no layout exists.
    """
    return dict(final_json=state.get("layout_json", {}))


In [12]:
def process_component(state: dict):
    """
    Generate the configuration for a single layout component with the LLM.

    ``state`` is one component dict from the layout. Only its
    "AI Generation Description" is sent to the LLM, with
    'component_content_gen.xml' as the system prompt.

    Returns:
        ``{"layout_json": {"config": ...}}`` on success, or
        ``{"layout_json": {"error": ..., "raw_output": ..., "parsing_error": ...}}``
        when the reply could not be parsed.
        NOTE(review): the result is keyed "layout_json" even though it holds one
        component's config, so writing it straight into the graph state would
        replace the whole layout.
    """
    from langchain_core.pydantic_v1 import BaseModel
    from typing import Optional

    class ComponentConfig(BaseModel):
        # Free-form config; None when the LLM reply has no "config" key.
        config: Optional[dict] = None

    system_msg = SystemMessage(content=load_xml_instructions("component_content_gen.xml"))
    desc = state.get("AI Generation Description", "No AI Generation Description was provided.")
    conversation = [system_msg, HumanMessage(content=desc, name="component-desc")]

    structured_llm = llm.with_structured_output(
        ComponentConfig,
        method="json_mode",
        include_raw=True,
    )
    # include_raw=True yields a dict with "raw", "parsed" and "parsing_error" keys.
    output = structured_llm.invoke(conversation)

    parsed_output = output.get("parsed")
    if parsed_output:
        config = parsed_output.dict(by_alias=True)
    else:
        # Keep the failure payload JSON-serializable: raw reply text, not the message.
        raw = output.get("raw")
        parsing_error = output.get("parsing_error")
        config = {
            "error": "Failed to parse component configuration",
            "raw_output": getattr(raw, "content", raw),
            "parsing_error": str(parsing_error) if parsing_error is not None else None,
        }

    return {"layout_json": config}


In [13]:
def generate_components_config(state: ReportGraphState):
    """
    Generate a config for every component in the layout and write it back.

    Previously this node returned ``[Send("process_component", c), ...]``. That
    could not work:
    - LangGraph's Send API is meant for conditional-edge routing functions, not
      for a plain node return value.
    - "process_component" is not a node of the main graph; it is only registered
      in a separate subgraph.

    The fan-out is therefore done here, calling ``process_component`` for each
    component on a small thread pool.

    Returns:
        ``{"layout_json": ...}`` with one of two values:
        - the layout unchanged, when it is empty or an error payload;
        - otherwise a copy in which every component dict has its "config" set
          from the LLM result. When parsing failed, "config" is ``{}`` and
          "config_error" holds the failure details.
    """
    import copy
    from concurrent.futures import ThreadPoolExecutor

    layout_json = state.get("layout_json", {})
    if not layout_json or "error" in layout_json:
        return {"layout_json": layout_json}  # no-op

    # Mutate a copy, never the layout held in the incoming state.
    layout_json = copy.deepcopy(layout_json)

    components = []

    def walk(obj):
        # Collect every entry of any "components" array, at any depth.
        if isinstance(obj, dict):
            if "components" in obj and isinstance(obj["components"], list):
                components.extend(obj["components"])
            for v in obj.values():
                walk(v)
        elif isinstance(obj, list):
            for v in obj:
                walk(v)

    walk(layout_json)
    # process_component reads the component with .get(), so only dicts qualify.
    components = [c for c in components if isinstance(c, dict)]
    if not components:
        return {"layout_json": layout_json}

    # LLM calls are I/O-bound; a small pool keeps the original "in parallel"
    # intent without flooding the provider with requests.
    with ThreadPoolExecutor(max_workers=min(4, len(components))) as pool:
        results = list(pool.map(process_component, components))

    for component, result in zip(components, results):
        produced = (result or {}).get("layout_json") or {}
        if "error" in produced:
            component["config"] = {}
            component["config_error"] = produced
        else:
            component["config"] = produced.get("config", {})

    return {"layout_json": layout_json}


def merge_components_config(state: ReportGraphState):
    """
    Merge per-component results from ``state["tasks"]`` back into the layout.

    Each entry of ``tasks`` is zipped, in order, with the components found in
    ``layout_json``. Its "config" value is written onto that component and, via
    walk_and_update, onto every dict sharing the component's "id".

    NOTE(review): "tasks" is not declared on ReportGraphState, so LangGraph
    presumably never supplies it and this node acts as a pass-through -- confirm.

    Returns:
        ``{"layout_json": ...}`` -- the (possibly updated) layout.
    """
    layout_json = state.get("layout_json", {})
    tasks_results = state.get("tasks", [])

    if not tasks_results:
        # No parallel results (or no components): pass the layout through.
        return {"layout_json": layout_json}

    # Re-collect the components in the same order they were sent out.
    components = []

    def walk(obj):
        if isinstance(obj, dict):
            if "components" in obj and isinstance(obj["components"], list):
                components.extend(obj["components"])
            for v in obj.values():
                walk(v)
        elif isinstance(obj, list):
            for v in obj:
                walk(v)

    walk(layout_json)

    for component, result in zip(components, tasks_results):
        if not isinstance(component, dict) or not isinstance(result, dict):
            continue  # nothing to merge into / from a non-dict entry
        new_config = result.get("config", {})
        comp_id = component.get("id")
        if comp_id is None:
            # No id to match on: update this component object directly
            # instead of raising KeyError.
            component["config"] = new_config
        else:
            walk_and_update(layout_json, comp_id, new_config)

    return {"layout_json": layout_json}


def walk_and_update(obj, comp_id, new_config):
    """
    Set "config" to ``new_config`` on every dict in ``obj`` whose "id" equals ``comp_id``.

    Walks nested dicts and lists recursively and mutates matches in place. On a
    matching dict, neither the config being replaced nor the new config is
    descended into. Previously the freshly assigned ``new_config`` was walked too:
    a config that echoed the component's own "id" then got a self-referencing
    "config" key, and the walk recursed forever.

    Args:
        obj: Layout tree (dicts/lists) to search; other values are ignored.
        comp_id: Component id compared against each dict's "id".
        new_config: Value stored under "config" on every match (shared, not copied).

    Returns:
        None.
    """
    if isinstance(obj, dict):
        matched = obj.get("id") == comp_id
        # Snapshot the items so assigning "config" below cannot affect iteration.
        for key, value in list(obj.items()):
            if matched and key == "config":
                continue  # about to be replaced; don't descend into it
            walk_and_update(value, comp_id, new_config)
        if matched:
            obj["config"] = new_config
    elif isinstance(obj, list):
        for value in obj:
            walk_and_update(value, comp_id, new_config)


In [20]:
# -----------------------------------------------------------------------------
# Build the Graph
# -----------------------------------------------------------------------------
# Main graph; its state schema is ReportGraphState (messages + custom report fields).
builder = StateGraph(ReportGraphState)

# 1. START => verify_instructions (every run begins with the clarity check)
builder.add_node("verify_instructions", verify_instructions)
builder.add_edge(START, "verify_instructions")

# 2. Decide: If instructions_clear => generate_layout_json, else => ask_clarification
def instructions_decider(state: ReportGraphState):
    """Route to layout generation when the request is clear, otherwise ask the user."""
    if state["instructions_clear"]:
        return "generate_layout_json"
    return "ask_clarification"

# Route on the clarity verdict; the list names every possible destination node.
builder.add_conditional_edges(
    "verify_instructions",
    instructions_decider,
    ["generate_layout_json", "ask_clarification"]
)

builder.add_node("ask_clarification", ask_clarification)
builder.add_node("get_user_clarification", get_user_clarification)
builder.add_node("generate_layout_json", generate_layout_json)
builder.add_node("merge_components_config", merge_components_config)


# Standalone single-node graph wrapping process_component.
# NOTE(review): it is compiled but never added to `builder`, so the main graph
# has no "process_component" node.
component_builder = StateGraph(dict)
component_builder.add_node("process_component", process_component)
component_builder.add_edge(START, "process_component")
component_builder.add_edge("process_component", END)

process_component_subgraph = component_builder.compile()

# Add nodes for generate_components_config and other steps
builder.add_node("generate_components_config", generate_components_config)

builder.add_node("identify_and_unify_lists", identify_and_unify_lists)
builder.add_node("create_lists_contents", create_lists_contents)
builder.add_node("finalize_report_json", finalize_report_json)

# Clarification loop: ask -> (pause for the user's reply) -> re-verify.
# Layout path: generate layout -> component configs -> merge -> lists.
builder.add_edge("ask_clarification", "get_user_clarification")
builder.add_edge("get_user_clarification", "verify_instructions")
builder.add_edge("generate_layout_json", "generate_components_config")
builder.add_edge("generate_components_config", "merge_components_config")
builder.add_edge("merge_components_config", "identify_and_unify_lists")

# Continue with the rest of the flow
builder.add_edge("identify_and_unify_lists", "create_lists_contents")
builder.add_edge("create_lists_contents", "finalize_report_json")
builder.add_edge("finalize_report_json", END)

# Compile with a static interrupt before get_user_clarification, so the user's
# reply can be added to `messages` before the clarity check re-runs.
# NOTE(review): no checkpointer is passed to compile(); LangGraph needs one to
# resume an interrupted run -- confirm how this graph is meant to be resumed.
graph = builder.compile(interrupt_before=["get_user_clarification"])


In [None]:
from IPython.display import Image, display

# Render the compiled graph's Mermaid diagram as a PNG.
display(Image(data=graph.get_graph().draw_mermaid_png()))

In [None]:
# Initial graph input: the user's report request as the only message, plus
# default values for the custom fields.
state = ReportGraphState(
    messages=[
        HumanMessage(content="Create a report showing the profit and loss in a table comparing actuals to budget. Next to the table I want to see a chart with 12 periods comparing Actuals to Budget for the current selected row in the table. Below the chart I want to see a small table breaking down the current selected line in to the product dimension.", name="user")
    ],
    instructions_clear=False,
    layout_json={},
    final_json={}
)

# Run the graph. If the request is judged unclear, execution stops at the
# interrupt before get_user_clarification, so result_state may still hold the
# initial empty layout_json.
result_state = graph.invoke(state)



In [None]:
from IPython.display import Markdown
# Show the layout only when one was actually produced: the initial state already
# has layout_json={}, so a plain key check would render an empty "{}" after an
# interrupted run. default=str keeps non-JSON values from breaking the display.
if result_state.get("layout_json"):
    display(Markdown(f"**Layout JSON**:\n```json\n{json.dumps(result_state['layout_json'], indent=2, default=str)}\n```"))
else:
    display(Markdown("**No Layout JSON Generated**"))