In [13]:
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import create_react_agent
from intent_entity.intent_entity import IntentEntity
from template_formatter.template_formatter_controller import TemplateSelector


# Intent classifier (find_intent_entity) and greeting-template selector
# (select_template), both used by the question-routing cell further down.
intent_entity = IntentEntity()
template_selector = TemplateSelector()


In [None]:
import os
from getpass import getpass
from urllib.parse import quote_plus

# Connection settings come from the environment so credentials never live in the notebook.
# Only the password has no default; it is prompted for when MYSQL_PASSWORD is unset.
username = os.environ.get('MYSQL_USER', 'root')
password = os.environ.get('MYSQL_PASSWORD') or getpass('MySQL password: ')
host = os.environ.get('MYSQL_HOST', 'localhost')  # or your database host
port = os.environ.get('MYSQL_PORT', '3306')       # default MySQL port
database = os.environ.get('MYSQL_DATABASE', 'edb3')
table_name = 'amazon_sale_report'
# quote_plus keeps characters such as '@', ':' or '/' in the credentials from breaking the URI.
sql_string = (
    f'mysql+mysqlconnector://{quote_plus(username)}:{quote_plus(password)}'
    f'@{host}:{port}/{database}'
)

db = SQLDatabase.from_uri(sql_string, include_tables=[table_name])
# Set OPENAI_MODEL (e.g. "gpt-4o-mini") to try a different model without editing the cell.
llm = ChatOpenAI(model=os.environ.get('OPENAI_MODEL', 'gpt-4o'))


toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

tools

In [None]:
# Dump the CREATE TABLE statements (plus sample rows) the SQL agent will see.
table_info = db.get_table_info()
print(table_info)

In [16]:
# System prompt for the primary-KPI agent, loaded from disk. The encoding is
# explicit so the file reads the same regardless of the OS locale.
with open('primary_kpis_system_message.txt', 'r', encoding='utf-8') as file:
    primary_kpis_prompt = file.read()


primary_kpis_system_message = SystemMessage(content=primary_kpis_prompt)


In [17]:
_mysql_prompt = """
Analyze the following Secondary KPIs for any significant changes or new trends based on Category, Size, State, and Channel: Fill Rate, Average Order Value, Return Rate, Damage Rate, and Shipment Error Rate.

Below is the name, description and schema of each table tables that you can use:

1.  Table : 'amazon_sale_report' 
    Description : This dataset provides detailed insights into Amazon sales data, including SKU Code, Design Number, Stock, Category, Size and Color, to help optimize product profitability.
    Schema : 
CREATE TABLE amazon_sale_report (
	`index` INTEGER, 
	order_id TEXT, 
	date DATE, 
	status TEXT, 
	fulfilment TEXT, 
	sales_channel TEXT, 
	`ship-service-level` TEXT, 
	style TEXT, 
	`SKU` TEXT, 
	category TEXT, 
	size TEXT, 
	`ASIN` TEXT, 
	courier_status TEXT, 
	qty INTEGER, 
	currency TEXT, 
	amount DOUBLE, 
	`ship-city` TEXT, 
	`ship-state` TEXT, 
	`ship-postal-code` INTEGER, 
	`ship-country` TEXT, 
	`promotion-ids` TEXT, 
	`B2B` TEXT, 
	`fulfilled-by` TEXT, 
	unnamed TEXT
)DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci


Identify significant changes:

Track the changes and percentage changes for each KPI across all Category, Size, State, and Channel dimensions.
Report any significant changes (increase or decrease) in these KPIs for any Category, Size, State, or Channel that contributes the most to sales or exhibits a healthy KPI profile.
Detect new trends:

Evaluate the data for new trends in Category, Size, State, or Channel based on changes in sales, Fill Rate, Average Order Value, Return Rate, Damage Rate, and Shipment Error Rate.
Highlight any new emerging patterns or trends across these dimensions.
Below are the KPI definitions:

Total Orders: The number of distinct order IDs.
Average Order Value (AOV): Total sales amount divided by total orders.
Fulfillment Rate: Percentage of orders with shipped status over total orders.
Return Rate: Percentage of returned orders.
Shipment Error Rate: Percentage of orders marked as Shipped - Damaged or Shipped - Lost in Transit.
Damage Rate: Percentage of orders marked as Shipped - Damaged in Transit.
Use the MySQL database to fetch relevant data and generate SQL queries as needed. Present the output in a detailed report summarizing the findings for both changes and trends in the KPIs.

"""

secondary_kpis_system_message = SystemMessage(content=_mysql_prompt)

In [18]:
# General-purpose MySQL system prompt, loaded from disk. The encoding is explicit
# so the file reads the same regardless of the OS locale.
with open('mysql_prompt.txt', 'r', encoding='utf-8') as file:
    mysql_prompt = file.read()

system_message = SystemMessage(content=mysql_prompt)


In [19]:
from token_module import CallbackHandler
from token_module import llm_trace_to_jaeger
from token_module import tracer

# Tracing handler whose collected records are read back with handler.infi();
# presumably it must be attached to an agent call as config={"callbacks": [handler]}.
handler = CallbackHandler()

In [None]:
# One ReAct agent per system prompt: generic MySQL, primary KPIs, secondary KPIs.
agent_executor = create_react_agent(llm, tools, messages_modifier=system_message)
primary_kpis_agent_executor = create_react_agent(llm, tools, messages_modifier=primary_kpis_system_message)
# The secondary-KPI agent needs its own prompt (KPI definitions + schema),
# not the generic system_message.
secondary_kpis_agent_executor = create_react_agent(llm, tools, messages_modifier=secondary_kpis_system_message)



In [None]:
question = 'What is the latest data on sales?'
stream_input = {"messages": [HumanMessage(content=question)]}

# Print each intermediate agent step as it is produced, separated by a rule.
for step in agent_executor.stream(stream_input):
    print(step)
    print("----")

In [None]:


question = 'What is the latest data on sales?'


def _extract_sql_queries(messages):
    """Return the ``query`` argument of every tool call recorded in ``messages``.

    Tool calls are attached to the agent's AI messages as ``message.tool_calls``
    (dicts with an ``args`` mapping); the first message is the user's
    HumanMessage and carries none, so the whole history is scanned. Messages
    without ``tool_calls`` and calls without a ``query`` argument are skipped.
    """
    queries = []
    for message in messages:
        for call in getattr(message, "tool_calls", None) or []:
            args = call.get("args")
            if isinstance(args, dict) and "query" in args:
                queries.append(args["query"])
    return queries


intent_name = intent_entity.find_intent_entity(question=question, intents='general_intents')
print("Intent name:", intent_name)

agent_response = None
result = None
agent_input = {"messages": [HumanMessage(content=question)]}
# Attach the tracing handler to the run so handler.infi() below has records from it.
# NOTE(review): assumes CallbackHandler is a LangChain callback handler — confirm.
run_config = {"callbacks": [handler]}

if intent_name == "greeting":
    result = template_selector.select_template()

elif intent_name == 'simple_sales_analysis':
    agent_response = agent_executor.invoke(input=agent_input, config=run_config)

elif intent_name == 'complex_sales_analysis':
    kpi_intent = intent_entity.find_intent_entity(question=question, intents='kpis_classifier')

    if kpi_intent == 'primary_kpis':
        print("kpi_intent:", kpi_intent)
        agent_response = primary_kpis_agent_executor.invoke(input=agent_input, config=run_config)

    else:
        print("kpi_intent:", 'secondary_kpis')
        agent_response = secondary_kpis_agent_executor.invoke(input=agent_input, config=run_config)

else:
    # Without this branch `result` would be unbound and the final print would raise.
    result = f"Unsupported intent: {intent_name!r}"

if agent_response:
    tracing_list = handler.infi()

    with tracer.start_span("chain_traces") as chain_traces:
        llm_trace_to_jaeger(tracing_list, chain_traces.span_id, chain_traces.trace_id)

    for query in _extract_sql_queries(agent_response['messages']):
        print(query)

    result = agent_response['messages'][-1].content

print(result)


In [None]:
# Display the traces gathered by the routing cell above.
# NOTE(review): tracing_list is only bound when that cell got an agent_response;
# otherwise this cell raises NameError on a fresh kernel.
tracing_list

In [None]:
from flask import Flask, request, jsonify
from flask_cors import CORS
from langchain_core.messages import HumanMessage

app = Flask(__name__)
CORS(app)


@app.route("/qna", methods=["POST"])
def get_response():
    """Answer a sales question posted as the ``question`` form field.

    Returns:
        200 with ``{"answer": ..., "agent_response": ...}`` on success,
        400 with ``{"error": ...}`` when ``question`` is missing or blank,
        500 with ``{"error": ..., "details": ...}`` when the agent call fails.
    """
    # Validate up front: a missing field would otherwise leave `question` unbound
    # and be reported as an internal server error.
    question = request.form.get("question", "").strip()
    if not question:
        return jsonify({"error": "The 'question' form field is required."}), 400

    try:
        agent_response = agent_executor.invoke(input={"messages": [HumanMessage(content=question)]})
        result = agent_response['messages'][-1].content

        # Log the SQL the agent ran. Tool calls live on the AI messages'
        # `tool_calls` lists (dicts with an `args` mapping), not on messages[0].
        for message in agent_response['messages']:
            for call in getattr(message, "tool_calls", None) or []:
                args = call.get("args")
                if isinstance(args, dict) and args.get("query"):
                    print(args["query"])

        response = {'answer': result, 'agent_response': str(agent_response)}
        return jsonify(response), 200

    except Exception as e:

        error_response = {"error": "An error occurred, please try again.", "details": str(e)}
        return jsonify(error_response), 500  # 500 for internal server error


# NOTE(review): inside a Jupyter kernel __name__ == "__main__", so this starts the
# development server and blocks the kernel until it is interrupted.
if __name__ == "__main__":
    app.run(debug=False)


In [16]:
import matplotlib.pyplot as plt
# import seaborn as sns
import pandas as pd
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
    QuerySQLDataBaseTool,
)
class SQLVisualizationTool:
    """Answer a natural-language question with SQL and render the result as a chart.

    Flow: the LLM drafts a MySQL query from the question and the table schema,
    ``QuerySQLCheckerTool`` reviews it, the query is read into a DataFrame, and
    the DataFrame is plotted with matplotlib.

    NOTE(review): this is a plain helper class, not a LangChain ``BaseTool``; wrap
    ``run`` in a tool before handing it to an agent.
    """

    # Prompt used to draft the SQL; table_info is the text of db.get_table_info().
    _SQL_PROMPT = (
        "You are a MySQL expert. Using only the tables described below, write one "
        "syntactically correct MySQL query that answers the question. Return only "
        "the SQL query, with no explanation and no markdown.\n\n"
        "{table_info}\n\n"
        "Question: {question}"
    )

    def __init__(self, db, llm):
        """Store the ``SQLDatabase`` to query and the chat model used to write SQL."""
        self.db = db
        self.llm = llm

    @staticmethod
    def _strip_code_fences(text):
        """Return ``text`` stripped of a surrounding Markdown code fence (```sql ... ```)."""
        import re

        match = re.search(r"```(?:[A-Za-z]+[ \t]*\n)?(.*?)```", text, re.DOTALL)
        return (match.group(1) if match else text).strip()

    def generate_sql_query(self, question):
        """Draft a SQL query for ``question`` with the LLM, then validate it.

        ``QuerySQLCheckerTool`` only reviews/corrects an existing SQL query, so the
        query is written first by the LLM and then passed through the checker.
        """
        prompt = self._SQL_PROMPT.format(table_info=self.db.get_table_info(), question=question)
        draft_sql = self._strip_code_fences(self.llm.invoke(prompt).content)
        query_checker_tool = QuerySQLCheckerTool(db=self.db, llm=self.llm)
        return self._strip_code_fences(query_checker_tool.run(draft_sql))

    def execute_sql_query(self, query):
        """Run ``query`` and return its rows as a DataFrame with column names.

        ``SQLDatabase.run`` returns the rows rendered as one string, which
        ``pd.DataFrame`` cannot turn into a table, so pandas reads the query
        through the database connection instead.
        NOTE(review): relies on the private ``SQLDatabase._engine`` attribute.
        """
        return pd.read_sql_query(query, self.db._engine)

    def create_visualization(self, data, chart_type="bar"):
        """Plot ``data`` as a ``"bar"`` or ``"pie"`` chart and show it.

        When ``data`` has both non-numeric and numeric columns, the first
        non-numeric column becomes the category labels. A pie chart plots the
        first numeric column.

        Raises:
            ValueError: if ``chart_type`` is not ``"bar"`` or ``"pie"``.
        """
        if chart_type not in ("bar", "pie"):
            raise ValueError(f"Unsupported chart_type {chart_type!r}; use 'bar' or 'pie'.")

        frame = data
        label_columns = frame.select_dtypes(exclude="number").columns
        if 0 < len(label_columns) < len(frame.columns):
            frame = frame.set_index(label_columns[0])

        fig, ax = plt.subplots()
        if chart_type == "bar":
            frame.plot(kind="bar", ax=ax)
        else:
            # DataFrame pie plots need an explicit value column (or subplots=True).
            value_column = frame.select_dtypes(include="number").columns[0]
            frame.plot(kind="pie", y=value_column, autopct='%1.1f%%', ax=ax)
        plt.show()

    def run(self, question, chart_type="bar"):
        """Generate SQL for ``question``, execute it, chart it, and return the DataFrame."""
        query = self.generate_sql_query(question)
        data = self.execute_sql_query(query)
        self.create_visualization(data, chart_type)
        return data

# Example usage
# visualization_tool = SQLVisualizationTool(db, llm)
# visualization_tool.run("Show me the sales by product category", chart_type="bar")


In [None]:
# Add the visualization tool to the toolkit
visualization_tool = SQLVisualizationTool(db=db, llm=llm)
tools.append(visualization_tool)

# Run the agent with the new tool
agent_executor = create_react_agent(llm, tools, messages_modifier=system_message)
