In [None]:
# !pip install mysql-connector
# !pip install pydantic==1.10.12
# !pip install langchain
# !pip install langchain-openai

In [None]:
# Start the local MySQL server before running the cells below (run in a terminal): mysql.server start

## Libraries

In [1]:
import csv
import os
import mysql.connector
import subprocess

from langchain import LLMChain
from langchain.chains import LLMChain, LLMMathChain, SequentialChain, TransformChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import BaseModel, Field, validator
from langchain.tools import Tool
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_openai_tools_agent, create_openai_functions_agent
from langchain_community.agent_toolkits.sql.prompt import SQL_FUNCTIONS_SUFFIX
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)

from langchain import hub
from langchain_experimental.tools import PythonREPLTool
from langchain import PromptTemplate



## Setup SQL Connection

In [3]:
# Bootstrap connection without a default schema, used only to make sure the
# `ecommerce` database exists. It is closed explicitly instead of being left
# open (and unreachable) once `mydb` is bound to the working connection below.
bootstrap_db = mysql.connector.connect(
    host="localhost",
    user="root"
)
try:
    bootstrap_cursor = bootstrap_db.cursor()
    # IF NOT EXISTS makes this statement safe to execute on every run.
    bootstrap_cursor.execute("CREATE DATABASE IF NOT EXISTS ecommerce")
    bootstrap_cursor.close()
finally:
    bootstrap_db.close()

# Working connection and cursor, bound to the `ecommerce` schema; the
# table-creation and data-loading cells use these two names.
mydb = mysql.connector.connect(
    host="localhost",
    user="root",
    database="ecommerce"
)
mycursor = mydb.cursor()

## Create all tables

In [None]:
# Column definitions for every table, keyed by table name. The column order
# matters: the loading cell inserts CSV rows positionally (INSERT ... VALUES
# without a column list).
table_columns = {
    "distribution_centers": "id INT, name VARCHAR(255), latitude FLOAT, longitude FLOAT",
    "events": "id INT, user_id INT, sequence_number INT, session_id VARCHAR(255), created_at TIMESTAMP, ip_address VARCHAR(255), city VARCHAR(255), state VARCHAR(255), postal_code VARCHAR(255), browser VARCHAR(255), traffic_source VARCHAR(255), uri VARCHAR(255), event_type VARCHAR(255)",
    "inventory_items": "id INT, product_id INT, created_at TIMESTAMP, sold_at TIMESTAMP, cost FLOAT, product_category VARCHAR(255), product_name VARCHAR(255), product_brand VARCHAR(255), product_retail_price FLOAT, product_department VARCHAR(255), product_sku VARCHAR(255), product_distribution_center_id INT",
    "order_items": "id INT, order_id INT, user_id INT, product_id INT, inventory_item_id INT, status VARCHAR(255), created_at TIMESTAMP, shipped_at TIMESTAMP, delivered_at TIMESTAMP, returned_at TIMESTAMP, sale_price FLOAT",
    "orders": "order_id INT, user_id INT, status VARCHAR(255), gender VARCHAR(255), created_at TIMESTAMP, returned_at TIMESTAMP, shipped_at TIMESTAMP, delivered_at TIMESTAMP, num_of_item INT",
    "products": "id INT, cost FLOAT, category VARCHAR(255), name VARCHAR(255), brand VARCHAR(255), retail_price FLOAT, department VARCHAR(255), sku VARCHAR(255), distribution_center_id INT",
    "users": "id INT, first_name VARCHAR(255), last_name VARCHAR(255), email VARCHAR(255), age INT, gender VARCHAR(255), state VARCHAR(255), street_address VARCHAR(255), postal_code VARCHAR(255), city VARCHAR(255), country VARCHAR(255), latitude FLOAT, longitude FLOAT, traffic_source VARCHAR(255), created_at TIMESTAMP",
}

# IF NOT EXISTS keeps this cell re-runnable (a plain CREATE TABLE fails with
# "table already exists" on the second run), matching the
# CREATE DATABASE IF NOT EXISTS used when the connection is set up.
for table, columns in table_columns.items():
    mycursor.execute("CREATE TABLE IF NOT EXISTS %s(%s)" % (table, columns))

## Load CSV data into MySQL tables

In [None]:
# Read data from the CSV files under data/ and insert each file's rows into the
# table of the same name. One commit is issued per table.
# NOTE: re-running this cell inserts every row again; the tables have no keys
# that would reject duplicates.

table_names = ["distribution_centers", "events", "inventory_items", "order_items", "orders", "products", "users"]
for table_name in table_names:
    print("Currently inserting data into table %s" % (table_name))
    # The with-block guarantees the file is closed; newline="" is what the csv
    # module requires so that newlines embedded in quoted fields parse correctly.
    with open("data/%s.csv" % table_name, newline="") as csv_file:
        csv_data = csv.reader(csv_file)
        next(csv_data, None)  # skip the header row; tolerate a completely empty file
        for counter, row in enumerate(csv_data):
            if counter % 10000 == 0:
                print("Progress is", counter)
            # Empty cells become SQL NULL; the " UTC" suffix is removed so that
            # timestamp strings are accepted by the TIMESTAMP columns.
            row = [None if cell == '' else cell.replace(" UTC", "") for cell in row]
            # One %s placeholder per cell; the values themselves are passed
            # separately so the driver escapes them.
            postfix = ','.join(["%s"] * len(row))
            mycursor.execute("INSERT INTO %s VALUES(%s)" % (table_name, postfix), row)
    mydb.commit()

## SQL table connection

In [6]:
# LangChain wrapper around the local `ecommerce` MySQL database (user root, no
# password in the URI); this object is what SQLQueryEngine receives as `db`.
# NOTE(review): a bare "mysql://" scheme presumably selects SQLAlchemy's default
# MySQL driver rather than the mysql-connector package used for loading the
# data above — confirm that driver is installed in this environment.
sql_db = SQLDatabase.from_uri("mysql://localhost:3306/ecommerce?user=root")

In [7]:
# Only fall back to the placeholder when no key is exported yet, so a real
# OPENAI_API_KEY already present in the environment is never overwritten.
# Prefer exporting the key outside the notebook over pasting it here.
os.environ.setdefault('OPENAI_API_KEY', "YOUR_API_KEY")
# Chat model used by both the SQL agent and the dashboard agent.
model_name = "gpt-4-0125-preview"

In [8]:
# Completion-style LLM with deterministic sampling (temperature 0).
# NOTE(review): none of the cells visible below reference `llm` (both engines
# build their own ChatOpenAI instances), and the recorded output of this cell
# shows a deprecation warning — confirm whether this object is still needed.
llm = OpenAI(temperature=0.0)

  warn_deprecated(


## SQL Query Engine

- Turns the user's natural-language question into a SQL query and returns the agent's answer

In [9]:
class SQLQueryEngine:
    """
    Answers natural-language questions about a SQL database through an OpenAI tools agent.

    Typical use is ``engine = SQLQueryEngine(model_name, db)`` followed by
    ``engine.get_query_data(question)``. ``set_prompt`` and ``initialize_agent``
    can still be called explicitly (in that order); if they were skipped they are
    run on demand.

    Attributes:
        llm (ChatOpenAI): Chat model used by the toolkit and the agent (temperature 0).
        toolkit (SQLDatabaseToolkit): SQL database toolkit built from ``db`` and ``llm``.
        context (dict): Contextual information obtained from the SQL database toolkit.
        tools (list): Tools obtained from the toolkit; shared by the agent and its executor.
        prompt (ChatPromptTemplate): Prompt for the agent; ``None`` until ``set_prompt`` runs.
        agent_executor (AgentExecutor): Executor for the agent; ``None`` until ``initialize_agent`` runs.
    """

    # Extra guidance appended to any question that mentions returns, so the
    # agent uses one fixed definition of "return percentage".
    _RETURN_HINT = "return percentage is defined as total number of returns divided by total number of orders. You can join orders table with users table to know more about each user"

    def __init__(self, model_name, db):
        """Create the chat model and the SQL toolkit for ``db``.

        Args:
            model_name (str): Name of the OpenAI chat model to use.
            db: Database object handed to ``SQLDatabaseToolkit`` (the notebook passes a ``SQLDatabase``).
        """
        self.llm = ChatOpenAI(model=model_name, temperature=0)
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
        self.context = self.toolkit.get_context()
        self.tools = self.toolkit.get_tools()
        self.prompt = None
        self.agent_executor = None

    def set_prompt(self):
        """Build the agent prompt: user input, the SQL functions suffix, and the agent scratchpad."""
        messages = [
            HumanMessagePromptTemplate.from_template("{input}"),
            AIMessage(content=SQL_FUNCTIONS_SUFFIX),
            MessagesPlaceholder(variable_name="agent_scratchpad")
            ]
        self.prompt = ChatPromptTemplate.from_messages(messages)
        # Pre-fill the prompt with the toolkit's context values.
        self.prompt = self.prompt.partial(**self.context)

    def initialize_agent(self):
        """Create the tools agent and its executor, building the prompt first if needed."""
        if self.prompt is None:
            self.set_prompt()
        agent = create_openai_tools_agent(self.llm, self.tools, self.prompt)
        self.agent_executor = AgentExecutor(
            agent=agent,
            # Same tool list the agent was created with, rather than a second
            # list from another get_tools() call.
            tools=self.tools,
            verbose=True,
        )

    def get_query_data(self, query):
        """Run ``query`` through the agent and return its final answer.

        Args:
            query (str): Natural-language question about the database.

        Returns:
            The ``'output'`` entry of the executor's result.
        """
        if self.agent_executor is None:
            self.initialize_agent()
        # Case-insensitive so that "Return percentage ..." gets the hint as well.
        if 'return' in query.lower():
            query = query + "\n" + self._RETURN_HINT
        return self.agent_executor.invoke({"input": query})['output']

## Python Dashboard Engine

- Creates a Streamlit dashboard from the data output by SQLQueryEngine

In [10]:
# REPL -> Read Evaluate Print Loop
class PythonDashboardEngine:
    """
    Turns a textual data answer into a Streamlit app by asking an agent for plotting code.

    The agent has a Python REPL tool. Its final answer is expected to contain a
    fenced code block; that code is wrapped with a Streamlit header and footer,
    written to ``app.py`` and served with ``streamlit run``.

    Attributes:
        tools (list): A list of tools available for the dashboard engine (a Python REPL tool).
        instructions (str): Instructions guiding the behavior of the dashboard engine.
        prompt: The hub prompt template with ``instructions`` filled in.
        agent_executor (AgentExecutor): Executor for the agent; ``None`` until ``initialize`` runs.
    """
    def __init__(self):
        """Create the REPL tool and pull the base prompt from the LangChain hub."""
        self.tools = [PythonREPLTool()]
        self.instructions = """You are an agent designed to write a python code to answer questions.
        You have access to a python REPL, which you can use to execute python code.
        If you get an error, debug your code and try again.
        You might know the answer without running any code, but you should still run the code to get the answer.
        If it does not seem like you can write code to answer the question, just return "I don't know" as the answer.
        Always output the python code only.
        """
        base_prompt = hub.pull("langchain-ai/openai-functions-template")
        self.prompt = base_prompt.partial(instructions=self.instructions)
        self.agent_executor = None

    def initialize(self):
        """Create the functions agent and its executor.

        Reads the notebook-level ``model_name`` variable, which must be defined
        before this method is called.
        """
        agent = create_openai_functions_agent(ChatOpenAI(model=model_name, temperature=0), self.tools, self.prompt)
        self.agent_executor = AgentExecutor(agent=agent, tools=self.tools, verbose=True)

    def get_output(self, query):
        """Ask the agent for Python plotting code for ``query`` and return its final answer.

        Args:
            query (str): The data to plot, as text.

        Returns:
            The ``'output'`` entry of the executor's result.
        """
        if self.agent_executor is None:
            self.initialize()
        output = self.agent_executor.invoke({"input": "Write a code in python to plot the following data\n\n" + query})
        return output['output']

    def parse_output(self, inp):
        """Extract the generated code from ``inp`` and wrap it as a Streamlit script.

        The first fenced block is used; when ``inp`` contains no code fence the
        whole text is treated as code (the agent is instructed to output code
        only). A language tag on the fence's first line is dropped, and
        ``plt.show()`` calls are removed because the figure is rendered with
        ``st.pyplot()`` instead.

        Args:
            inp (str): The agent's final answer.

        Returns:
            str: Source text of the Streamlit app.
        """
        pieces = inp.split('```')
        code = pieces[1] if len(pieces) > 1 else inp

        # Drop only the fence's language tag; a blanket replace of "python"
        # would also corrupt the generated code wherever that word appears.
        first_line, _, remainder = code.partition("\n")
        if first_line.strip().lower() in ("python", "python3", "py"):
            code = remainder

        code = code.replace("plt.show()", "").strip("\n")

        header = "import streamlit as st\nst.set_option('deprecation.showPyplotGlobalUse', False)\nst.title('E-commerce Company[insights]')\nst.write('Here is our LLM generated dashboard')"
        # Explicit newlines keep header, code and footer on separate lines no
        # matter how the fenced block started or ended.
        outp = header + "\n" + code + "\n" + "st.pyplot()\n"
        return outp

    def export_to_streamlit(self, data):
        """Write the parsed app to ``app.py`` and launch it with Streamlit.

        A new ``streamlit run`` process is started on every call and is not
        waited for.

        Args:
            data (str): The agent's final answer, as returned by ``get_output``.
        """
        with open("app.py", "w") as text_file:
            text_file.write(self.parse_output(data))

        # An argument list without shell=True: a one-element list combined with
        # shell=True only works by accident on POSIX shells.
        subprocess.Popen(["streamlit", "run", "app.py"], stdin=None, stdout=None, stderr=None, close_fds=True)

In [11]:
def init_engines():
    """Build and initialize both engines.

    Uses the notebook-level ``model_name`` and ``sql_db`` variables. Nothing is
    stored globally by this function; callers bind the returned pair themselves.

    Returns:
        tuple: ``(sql_query_engine, dashboard_engine)``, both ready to use.
    """
    sql_query_engine = SQLQueryEngine(model_name, sql_db)
    sql_query_engine.set_prompt()
    sql_query_engine.initialize_agent()

    dashboard_engine = PythonDashboardEngine()
    dashboard_engine.initialize()
    return sql_query_engine, dashboard_engine

### Query 1

- Number of users with their gender

In [12]:
# Build both engines; the following query cells reuse these two objects
# instead of creating new ones.
sql_query_engine, dashboard_engine = init_engines()
query = "Number of users with their gender"
# Ask the SQL agent; the result is the agent's final answer.
sql_query_engine_output = sql_query_engine.get_query_data(query)
# print(sql_query_engine_output)  # uncomment to inspect the SQL agent's answer

# Ask the dashboard agent for plotting code for that answer, then write it to
# app.py and launch it with `streamlit run`.
dashboard_engine_output = dashboard_engine.get_output(sql_query_engine_output)
dashboard_engine.export_to_streamlit(dashboard_engine_output)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mdistribution_centers, events, inventory_items, message_store, order_items, orders, products, users[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'users'}`


[0m[33;1m[1;3m
CREATE TABLE users (
	id INTEGER, 
	first_name VARCHAR(255), 
	last_name VARCHAR(255), 
	email VARCHAR(255), 
	age INTEGER, 
	gender VARCHAR(255), 
	state VARCHAR(255), 
	street_address VARCHAR(255), 
	postal_code VARCHAR(255), 
	city VARCHAR(255), 
	country VARCHAR(255), 
	latitude FLOAT, 
	longitude FLOAT, 
	traffic_source VARCHAR(255), 
	created_at TIMESTAMP NULL
)COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB

/*
3 rows from users table:
id	first_name	last_name	email	age	gender	state	street_address	postal_code	city	country	latitude	longitude	traffic_source	created_at
9766	Brandon	Phillips	brandonphillips@example.com	44	M	Acre	717 Martinez Street Apt.

### Query 2

- number of users in each country who came via facebook

In [13]:
# Relies on `sql_query_engine` and `dashboard_engine` from an earlier cell;
# to rebuild them first, run: sql_query_engine, dashboard_engine = init_engines()
query = "number of users in each country who came via facebook"
# Ask the SQL agent; the result is the agent's final answer.
sql_query_engine_output = sql_query_engine.get_query_data(query)
# print(sql_query_engine_output)  # uncomment to inspect the SQL agent's answer

# Ask the dashboard agent for plotting code for that answer, then write it to
# app.py and launch it with `streamlit run`.
dashboard_engine_output = dashboard_engine.get_output(sql_query_engine_output)
dashboard_engine.export_to_streamlit(dashboard_engine_output)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mdistribution_centers, events, inventory_items, message_store, order_items, orders, products, users[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'users'}`


[0m[33;1m[1;3m
CREATE TABLE users (
	id INTEGER, 
	first_name VARCHAR(255), 
	last_name VARCHAR(255), 
	email VARCHAR(255), 
	age INTEGER, 
	gender VARCHAR(255), 
	state VARCHAR(255), 
	street_address VARCHAR(255), 
	postal_code VARCHAR(255), 
	city VARCHAR(255), 
	country VARCHAR(255), 
	latitude FLOAT, 
	longitude FLOAT, 
	traffic_source VARCHAR(255), 
	created_at TIMESTAMP NULL
)COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB

/*
3 rows from users table:
id	first_name	last_name	email	age	gender	state	street_address	postal_code	city	country	latitude	longitude	traffic_source	created_at
9766	Brandon	Phillips	brandonphillips@example.com	44	M	Acre	717 Martinez Street Apt.

### Query 3

- number of orders per month since January 2020

In [19]:
# Relies on `sql_query_engine` and `dashboard_engine` from an earlier cell;
# to rebuild them first, run: sql_query_engine, dashboard_engine = init_engines()
query = "number of orders per month since January 2020"
# Ask the SQL agent; the result is the agent's final answer.
sql_query_engine_output = sql_query_engine.get_query_data(query)
# print(sql_query_engine_output)  # uncomment to inspect the SQL agent's answer

# Ask the dashboard agent for plotting code for that answer, then write it to
# app.py and launch it with `streamlit run`.
dashboard_engine_output = dashboard_engine.get_output(sql_query_engine_output)
dashboard_engine.export_to_streamlit(dashboard_engine_output)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mdistribution_centers, events, inventory_items, message_store, order_items, orders, products, users[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'orders'}`


[0m[33;1m[1;3m
CREATE TABLE orders (
	order_id INTEGER, 
	user_id INTEGER, 
	`status` VARCHAR(255), 
	gender VARCHAR(255), 
	created_at TIMESTAMP NULL, 
	returned_at TIMESTAMP NULL, 
	shipped_at TIMESTAMP NULL, 
	delivered_at TIMESTAMP NULL, 
	num_of_item INTEGER
)COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB

/*
3 rows from orders table:
order_id	user_id	status	gender	created_at	returned_at	shipped_at	delivered_at	num_of_item
23	15	Cancelled	F	2023-09-23 08:12:00	None	None	None	1
45	30	Cancelled	F	2024-02-24 10:54:17	None	None	None	1
60	40	Cancelled	F	2023-06-04 14:56:00	None	None	None	4
*/[0m[32;1m[1;3m
Invoking: `sql_db_query_checker` with `{'query': "SELECT DA

### Query 4

- top 3 product categories with highest number of returns by count

In [20]:
# Relies on `sql_query_engine` and `dashboard_engine` from an earlier cell;
# to rebuild them first, run: sql_query_engine, dashboard_engine = init_engines()
query = "top 3 product categories with highest number of returns by count"
# The question contains "return", so get_query_data appends its fixed
# definition of return percentage before asking the SQL agent.
sql_query_engine_output = sql_query_engine.get_query_data(query)
# print(sql_query_engine_output)  # uncomment to inspect the SQL agent's answer

# Ask the dashboard agent for plotting code for that answer, then write it to
# app.py and launch it with `streamlit run`.
dashboard_engine_output = dashboard_engine.get_output(sql_query_engine_output)
dashboard_engine.export_to_streamlit(dashboard_engine_output)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mdistribution_centers, events, inventory_items, message_store, order_items, orders, products, users[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'orders, order_items, products'}`


[0m[33;1m[1;3m
CREATE TABLE order_items (
	id INTEGER, 
	order_id INTEGER, 
	user_id INTEGER, 
	product_id INTEGER, 
	inventory_item_id INTEGER, 
	`status` VARCHAR(255), 
	created_at TIMESTAMP NULL, 
	shipped_at TIMESTAMP NULL, 
	delivered_at TIMESTAMP NULL, 
	returned_at TIMESTAMP NULL, 
	sale_price FLOAT
)COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB

/*
3 rows from order_items table:
id	order_id	user_id	product_id	inventory_item_id	status	created_at	shipped_at	delivered_at	returned_at	sale_price
162569	112164	89224	14235	438986	Cancelled	2023-10-25 04:27:30	None	None	None	0.02
25143	17365	13804	14235	67816	Complete	2021-02-16 04:09:02	2021-0

### Query 5

- return percentage country wise

In [21]:
# Relies on `sql_query_engine` and `dashboard_engine` from an earlier cell;
# to rebuild them first, run: sql_query_engine, dashboard_engine = init_engines()
query = "return percentage country wise"
# The question contains "return", so get_query_data appends its fixed
# definition of return percentage before asking the SQL agent.
sql_query_engine_output = sql_query_engine.get_query_data(query)
# print(sql_query_engine_output)  # uncomment to inspect the SQL agent's answer

# Ask the dashboard agent for plotting code for that answer, then write it to
# app.py and launch it with `streamlit run`.
dashboard_engine_output = dashboard_engine.get_output(sql_query_engine_output)
dashboard_engine.export_to_streamlit(dashboard_engine_output)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mdistribution_centers, events, inventory_items, message_store, order_items, orders, products, users[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'orders'}`


[0m[33;1m[1;3m
CREATE TABLE orders (
	order_id INTEGER, 
	user_id INTEGER, 
	`status` VARCHAR(255), 
	gender VARCHAR(255), 
	created_at TIMESTAMP NULL, 
	returned_at TIMESTAMP NULL, 
	shipped_at TIMESTAMP NULL, 
	delivered_at TIMESTAMP NULL, 
	num_of_item INTEGER
)COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB

/*
3 rows from orders table:
order_id	user_id	status	gender	created_at	returned_at	shipped_at	delivered_at	num_of_item
23	15	Cancelled	F	2023-09-23 08:12:00	None	None	None	1
45	30	Cancelled	F	2024-02-24 10:54:17	None	None	None	1
60	40	Cancelled	F	2023-06-04 14:56:00	None	None	None	4
*/[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'users'}`



### Query 6

- which are my top 5 geographies with highest business

In [41]:
# Relies on `sql_query_engine` and `dashboard_engine` from an earlier cell;
# to rebuild them first, run: sql_query_engine, dashboard_engine = init_engines()
query = "which are my top 5 geographies with highest business"
# Ask the SQL agent; the result is the agent's final answer.
sql_query_engine_output = sql_query_engine.get_query_data(query)
# print(sql_query_engine_output)  # uncomment to inspect the SQL agent's answer

# Ask the dashboard agent for plotting code for that answer, then write it to
# app.py and launch it with `streamlit run`.
dashboard_engine_output = dashboard_engine.get_output(sql_query_engine_output)
dashboard_engine.export_to_streamlit(dashboard_engine_output)

### Query 7

- which product categories have the highest margins in 2024

In [39]:
# Relies on `sql_query_engine` and `dashboard_engine` from an earlier cell;
# to rebuild them first, run: sql_query_engine, dashboard_engine = init_engines()
query = "which product categories have the highest margins in 2024"
# Ask the SQL agent; the result is the agent's final answer.
sql_query_engine_output = sql_query_engine.get_query_data(query)
# print(sql_query_engine_output)  # uncomment to inspect the SQL agent's answer

# Ask the dashboard agent for plotting code for that answer, then write it to
# app.py and launch it with `streamlit run`.
dashboard_engine_output = dashboard_engine.get_output(sql_query_engine_output)
dashboard_engine.export_to_streamlit(dashboard_engine_output)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mdistribution_centers, events, inventory_items, message_store, order_items, orders, products, users[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'products, order_items'}`


[0m[33;1m[1;3m
CREATE TABLE order_items (
	id INTEGER, 
	order_id INTEGER, 
	user_id INTEGER, 
	product_id INTEGER, 
	inventory_item_id INTEGER, 
	`status` VARCHAR(255), 
	created_at TIMESTAMP NULL, 
	shipped_at TIMESTAMP NULL, 
	delivered_at TIMESTAMP NULL, 
	returned_at TIMESTAMP NULL, 
	sale_price FLOAT
)DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci

/*
3 rows from order_items table:
id	order_id	user_id	product_id	inventory_item_id	status	created_at	shipped_at	delivered_at	returned_at	sale_price
162569	112164	89224	14235	438986	Cancelled	2023-10-25 04:27:30	None	None	None	0.02
25143	17365	13804	14235	67816	Complete	2021-02-16 04:09:02	2021-02-17 03: