In [1]:
from decimal import Decimal
import math
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
import os
from dotenv import load_dotenv
from core.constants import ChatMessage
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
import uuid

load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

_printed = set()
thread_id = str(uuid.uuid4())

def _print_event(event: dict, _printed: set, max_length=1500):
    current_state = event.get("dialog_state")
    if current_state:
        print("Currently in: ", current_state[-1])
    message = event.get("messages")
    if message:
        if isinstance(message, list):
            message = message[-1]
        if message.id not in _printed:
            msg_repr = message.pretty_repr(html=True)
            if len(msg_repr) > max_length:
                msg_repr = msg_repr[:max_length] + " ... (truncated)"
            print(msg_repr)
            _printed.add(message.id)
            


template = """You are an AI trained in Retail Store Data. You job is to first generate sql query and get the desired answer by running the generated query based on tools provided.
Format any currency values into Indian currency format and then provide a concise answer basis the following guidelines:

Given the following user question, SQL query, and SQL results:
1. Keep the tone conversational and friendly.
2. Analyse the result data and try to provide a trend or analysis if possible.
3. Provide suggestions on similar questions related to the question asked.
4. If Sql Result is None or empty, provide a response that "{SQL_RESULT_ERROR}"

"""

prompt = ChatPromptTemplate.from_messages(
  [
      ("system", template),
      ("placeholder", "{messages}")
  ]
).partial(SQL_RESULT_ERROR=ChatMessage.SQL_RESULT_ERROR)

@tool
def format_in_indian_currency(amount):
    """Converts a numeric amount into Indian rupee format with Lakhs/Crores."""
    amount = math.ceil(amount)
    if amount >= 10000000:  # 1e7 = 1,00,00,000
        return f"{amount / 10000000:.2f} Crores"
    elif amount >= 100000:  # 1e5 = 1,00,000
        return f"{amount / 100000:.2f} Lakhs"
    else:
        return f"₹{amount:,}"
    
@tool
def get_sql_query(question):
    """Returns the SQL query for the user question."""
    return "SELECT SUM(`Total Order Value`) AS `Total Sales` FROM  `rpt_order_details_report` WHERE  `Transaction Date` BETWEEN '2023-04-01' AND '2024-03-31';"

@tool
def run_query(query):
    """Runs the sql query"""
    return "[(Decimal('2991221641.22'),)]"

tools = [format_in_indian_currency, get_sql_query, run_query]

llm = ChatOpenAI(model="gpt-4o")
chain_runnable = prompt | llm.bind_tools(tools)
tools = [format_in_indian_currency, run_query, get_sql_query]

In [10]:
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.constants import START
from langgraph.graph import add_messages, StateGraph, MessagesState
from langchain_core.messages import AnyMessage
from typing import TypedDict, Annotated
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import tools_condition, ToolNode

class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    
# class Assistant:
#     def __init__(self, runnable: Runnable):
#         self.runnable = runnable
# 
#     def __call__(self, state: State, config: RunnableConfig):
#         while True:
#             configuration = config.get("configurable", {})
#             # question = configuration.get("question", None)
#             # result = configuration.get("result", None)
#             # query = configuration.get("query", None)
#             # state = {**state, "question": question, "result": result,"query": query}
#             state = {**state}
#             result = self.runnable.invoke(state)
#             # If the LLM happens to return an empty response, we will re-prompt it
#             # for an actual response.
#             if not result.tool_calls and (
#                 not result.content
#                 or isinstance(result.content, list)
#                 and not result.content[0].get("text")
#             ):
#                 messages = state["messages"] + [("user", "Respond with a real output.")]
#                 state = {**state, "messages": messages}
#             else:
#                 break
#         return {"messages": result}
    
    
def call_model(state: State):
    messages = state["messages"]
    print(messages)
    response = chain_runnable.invoke({"messages": messages})
    return {"messages": [response]}


builder = StateGraph(MessagesState)
builder.add_node("agent", call_model)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "agent")
builder.add_conditional_edges("agent", tools_condition)
builder.add_edge("tools", "agent")

# memory = MemorySaver()
graph = builder.compile()


In [12]:
from langchain_core.tracers.context import tracing_v2_enabled


config = {
    "configurable": {
        # "question": "sales for financial year 2023",
        # "result": "[(Decimal('2991221641.22'),)]",
        # "query": "SELECT SUM(`Total Order Value`) AS `Total Sales` FROM  `rpt_order_details_report` WHERE  `Transaction Date` BETWEEN '2023-04-01' AND '2024-03-31';",
        "thread_id": thread_id
    }
}
with tracing_v2_enabled():
    # final_state = graph.invoke({"messages": ("user", "sales for financial year 2023?")}, config)
    # print(final_state)
    events = graph.stream({"messages": ("user", "sales for financial year 2023?")}, config, stream_mode="values")
    # print(state)
    for event in events:
            _print_event(event, _printed)


[HumanMessage(content='sales for financial year 2023?', additional_kwargs={}, response_metadata={}, id='12e94a64-5fb8-4b50-bef9-111398e80b75')]
[HumanMessage(content='sales for financial year 2023?', additional_kwargs={}, response_metadata={}, id='12e94a64-5fb8-4b50-bef9-111398e80b75'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_zUfyNDdHFOoV2wAefF3qZIzY', 'function': {'arguments': '{"question":"sales for financial year 2023?"}', 'name': 'get_sql_query'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 23, 'prompt_tokens': 239, 'total_tokens': 262, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_4691090a87', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-5b59cc45-dd13-442c-b506-96ebb651da