# Experimenting with SparkSQLAgent on Complex Cases

This experiment is based on [langchain pull request #4381](https://github.com/hwchase17/langchain/pull/4381).

In [1]:
from pyspark.sql import SparkSession

# Obtain a Spark session (getOrCreate reuses an existing one if present).
spark = SparkSession.builder.getOrCreate()

# Create the example database if it is missing and make it the current one.
schema = "langchain_example"
spark.sql(f"CREATE DATABASE IF NOT EXISTS {schema}")
spark.sql(f"USE {schema}")

# Read the Titanic CSV (first row is the header, column types are inferred)
# and save it as the `titanic` table, overwriting any previous copy.
csv_file_path = "../titanic.csv"
table = "titanic"
(
    spark.read
    .csv(csv_file_path, header=True, inferSchema=True)
    .write
    .mode('overwrite')
    .saveAsTable(table)
)

# Print a preview of the saved table to confirm the load worked.
spark.table(table).show()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/05/10 21:15:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|
|          6|       0|     3|    Moran, Mr. James|  male|null|    0|    0|      

In [2]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.chat_models import ChatOpenAI
from langchain.spark_sql import SparkSQL

# Wrap the Spark schema populated in the first cell so LangChain can query it.
db = SparkSQL(schema=schema)

# Chat model used by both the toolkit and the agent; temperature is set to 0.
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

# Bundle the database wrapper and the model into the SQL toolkit.
toolkit = SQLDatabaseToolkit(llm=llm, db=db)

# verbose=True makes the executor print each intermediate step of a run.
agent_executor = create_sql_agent(toolkit=toolkit, llm=llm, verbose=True)

In [3]:
# Case 1: ask for a single value from the titanic table.
# NOTE(review): in the recorded run below the returned value is the text
# 'None', i.e. the agent did not come back with a usable cabin name.
# Presumably the highest-fare row has a null Cabin -- verify against the data.
agent_executor.run("which cabin is the most expensive?")



[1m> Entering new AgentExecutor chain...[0m

MK,  Action: list_tables_sql_db
Action Input: "" 

[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mtitanic[0m
Thought:
MK,  I should query the titanic table to find the most expensive cabin.
Action: query_checker_sql_db
Action Input: "SELECT cabin, fare FROM titanic ORDER BY fare DESC LIMIT 1" 

[32;1m[1;3mI should query the titanic table to find the most expensive cabin.
Action: query_checker_sql_db
Action Input: "SELECT cabin, fare FROM titanic ORDER BY fare DESC LIMIT 1"[0m
Observation: [31;1m[1;3mThe original query is correct and does not contain any of the mentioned mistakes.[0m
Thought:
MK,  I can now execute the query to get the answer.
Action: query_sql_db
Action Input: "SELECT cabin, fare FROM titanic ORDER BY fare DESC LIMIT 1" 

[32;1m[1;3mI can now execute the query to get the answer.
Action: query_sql_db
Action Input: "SELECT cabin, fare FROM titanic ORDER BY fare DESC LIMI

"'None'"

In [5]:
# Case 2: a count that depends on a nested "most expensive" lookup.
# NOTE(review): the recorded run answers "3 people", using a query that
# counts rows whose Fare equals MAX(Fare) restricted to Pclass = 1; that
# Pclass filter was added by the agent and is not part of the question.
agent_executor.run("how many people bought the most expensive cabin?")



[1m> Entering new AgentExecutor chain...[0m

MK,  Action: list_tables_sql_db, "" 

[32;1m[1;3mAction: list_tables_sql_db, ""[0m
Observation: [38;5;200m[1;3mtitanic[0m
Thought:
MK,  I should query the titanic table to find the number of people who bought the most expensive cabin.
Action: query_checker_sql_db, "SELECT COUNT(*) FROM titanic WHERE Fare = (SELECT MAX(Fare) FROM titanic WHERE Pclass = 1)" 

[32;1m[1;3mI should query the titanic table to find the number of people who bought the most expensive cabin.
Action: query_checker_sql_db, "SELECT COUNT(*) FROM titanic WHERE Fare = (SELECT MAX(Fare) FROM titanic WHERE Pclass = 1)"[0m
Observation: [31;1m[1;3mThe original query is correct and does not contain any of the common mistakes listed. Therefore, the query is:

SELECT COUNT(*) FROM titanic WHERE Fare = (SELECT MAX(Fare) FROM titanic WHERE Pclass = 1)[0m
Thought:
MK,  I should execute the query to get the final answer.
Action: query_sql_db, "SELECT COUNT(*) FROM tit

'3 people bought the most expensive cabin.'

In [6]:
# Case 3: a ratio over the group selected by the nested "most expensive" lookup.
# NOTE(review): the recorded run below ends with a Spark AnalysisException
# ([MISSING_GROUP_BY]); the logged plan also shows Cabin being cast to double
# and compared with max(Fare), so the generated SQL was wrong for this case.
agent_executor.run("what's the percentage of female in the group who bought the most expensive cabin?")



[1m> Entering new AgentExecutor chain...[0m

MK,  Action: list_tables_sql_db, "" 

[32;1m[1;3mAction: list_tables_sql_db, ""[0m
Observation: [38;5;200m[1;3mtitanic[0m
Thought:
MK,  I should query the titanic table to get the relevant information.
Action: schema_sql_db, "titanic" 

[32;1m[1;3mI should query the titanic table to get the relevant information.
Action: schema_sql_db, "titanic"[0mCREATE TABLE langchain_example.titanic (
  PassengerId INT,
  Survived INT,
  Pclass INT,
  Name STRING,
  Sex STRING,
  Age DOUBLE,
  SibSp INT,
  Parch INT,
  Ticket STRING,
  Fare DOUBLE,
  Cabin STRING,
  Embarked STRING)
USING parquet


Observation: [33;1m[1;3mCREATE TABLE langchain_example.titanic (
  PassengerId INT,
  Survived INT,
  Pclass INT,
  Name STRING,
  Sex STRING,
  Age DOUBLE,
  SibSp INT,
  Parch INT,
  Ticket STRING,
  Fare DOUBLE,
  Cabin STRING,
  Embarked STRING)
;

/*
3 rows from titanic table:
PassengerId	Survived	Pclass	Name	Sex	Age	SibSp	Parch	Ticket	Fare	C

AnalysisException: [MISSING_GROUP_BY] The query does not include a GROUP BY clause. Add GROUP BY or turn it into the window functions using OVER clauses.; line 1 pos 0;
Aggregate [Sex#69, ((cast(count(1) as decimal(20,0)) * 100.0) / cast(scalar-subquery#248 [] as decimal(20,0))) AS percentage#249]
:  +- Aggregate [count(1) AS count(1)#281L]
:     +- Filter (cast(Cabin#292 as double) = scalar-subquery#247 [])
:        :  +- Aggregate [max(Fare#278) AS max(Fare)#268]
:        :     +- SubqueryAlias spark_catalog.langchain_example.titanic
:        :        +- Relation spark_catalog.langchain_example.titanic[PassengerId#269,Survived#270,Pclass#271,Name#272,Sex#273,Age#274,SibSp#275,Parch#276,Ticket#277,Fare#278,Cabin#279,Embarked#280] parquet
:        +- SubqueryAlias t2
:           +- SubqueryAlias spark_catalog.langchain_example.titanic
:              +- Relation spark_catalog.langchain_example.titanic[PassengerId#282,Survived#283,Pclass#284,Name#285,Sex#286,Age#287,SibSp#288,Parch#289,Ticket#290,Fare#291,Cabin#292,Embarked#293] parquet
+- Filter ((cast(Cabin#75 as double) = scalar-subquery#250 []) AND (Sex#69 = female))
   :  +- Aggregate [max(Fare#263) AS max(Fare)#253]
   :     +- SubqueryAlias spark_catalog.langchain_example.titanic
   :        +- Relation spark_catalog.langchain_example.titanic[PassengerId#254,Survived#255,Pclass#256,Name#257,Sex#258,Age#259,SibSp#260,Parch#261,Ticket#262,Fare#263,Cabin#264,Embarked#265] parquet
   +- SubqueryAlias t1
      +- SubqueryAlias spark_catalog.langchain_example.titanic
         +- Relation spark_catalog.langchain_example.titanic[PassengerId#65,Survived#66,Pclass#67,Name#68,Sex#69,Age#70,SibSp#71,Parch#72,Ticket#73,Fare#74,Cabin#75,Embarked#76] parquet
