## import

In [1]:
from pyspark.sql import SparkSession

## Read data into DataFrame

In [2]:
import os
spark = SparkSession.Builder().appName("Spark Dataframe LLM NLQ Demo").enableHiveSupport().master("local[*]").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/11/17 08:31:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Read kaggle credit_card_complaints.csv 

In [3]:
credit_card_complaints_file = '/Users/sharattadimalla/github/nlq-app/public_data/credit_card_complaints.csv'
complaints_df = spark.read.format("csv").option("delimiter", ",").option("header", True).load(credit_card_complaints_file)

schema = "langchain_example"
spark.sql(f"CREATE DATABASE IF NOT EXISTS {schema}")
spark.sql(f"USE {schema}")
table = "credit_card_complaints"
spark.sql(f"DROP TABLE IF EXISTS {table}")
spark.read.csv(credit_card_complaints_file, header=True, inferSchema=True).write.saveAsTable(table)
spark.table(table).select("date_received", "complaint_id","issue","company").show(5, truncate=False)

                                                                                

+-------------+------------+-------------------------------------+---------------------+
|date_received|complaint_id|issue                                |company              |
+-------------+------------+-------------------------------------+---------------------+
|2016-08-02   |2042660     |Other                                |Citibank             |
|2016-08-18   |2070223     |Billing disputes                     |Citibank             |
|2016-08-11   |2058230     |Identity theft / Fraud / Embezzlement|Wells Fargo & Company|
|2016-08-18   |2070292     |Billing disputes                     |Citibank             |
|2016-08-15   |2063158     |Billing disputes                     |Synchrony Financial  |
+-------------+------------+-------------------------------------+---------------------+
only showing top 5 rows



## Natural Language Query on Spark SQL

In [4]:
from langchain.agents import create_spark_sql_agent
from langchain.agents.agent_toolkits import SparkSQLToolkit
from langchain.chat_models import ChatOpenAI
from langchain.utilities.spark_sql import SparkSQL
from langchain.callbacks import get_openai_callback
import dotenv

In [5]:
dotenv.load_dotenv()
spark_sql = SparkSQL(schema=schema)
llm = ChatOpenAI(temperature=0)
toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
agent_executor = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True)

## Natural language Queries

In [6]:
with get_openai_callback() as cb:
    question1 = "How many customer complaints are there?"
    agent_executor.run(question1)
    print(f"Total Tokens: {cb.total_tokens}")
    print(f"Prompt Tokens: {cb.prompt_tokens}")
    print(f"Completion Tokens: {cb.completion_tokens}")
    print(f"Total Cost (USD): ${cb.total_cost}")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mcomplaints, credit_card_complaints[0m
Thought:[32;1m[1;3mI can query the "complaints" and "credit_card_complaints" tables to find the number of customer complaints.
Action: query_sql_db
Action Input: "SELECT COUNT(*) FROM complaints"[0m
Observation: [36;1m[1;3m[('115195',)][0m
Thought:[32;1m[1;3mThere are 115,195 customer complaints.
Final Answer: 115,195[0m

[1m> Finished chain.[0m
Total Tokens: 1785
Prompt Tokens: 1716
Completion Tokens: 69
Total Cost (USD): $0.002712


In [7]:
question2 = "What are the top 5 issues?"
agent_executor.run(question2)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mcomplaints, credit_card_complaints[0m
Thought:[32;1m[1;3mI can query the "complaints" and "credit_card_complaints" tables.
Action: schema_sql_db
Action Input: "complaints, credit_card_complaints"[0m
Observation: [33;1m[1;3mCREATE TABLE spark_catalog.langchain_example.complaints (
  date_received STRING,
  product STRING,
  sub_product STRING,
  issue STRING,
  sub_issue STRING,
  consumer_complaint_narrative STRING,
  company_public_response STRING,
  company STRING,
  state STRING,
  zip_code STRING,
  tags STRING,
  consumer_consent_provided STRING,
  submitted_via STRING,
  date_sent_to_company STRING,
  company_response_to_consumer STRING,
  timely_response STRING,
  consumer_disputed STRING,
  complaint_id STRING)
;

/*
3 rows from complaints table:
date_received	product	sub_product	issue	sub_issue	consumer_complaint_narrative	company_pub

                                                                                


Observation: [36;1m[1;3m[('None', '20117'), ('Billing disputes', '14827'), ('Other', '9144'), ('Identity theft / Fraud / Embezzlement', '8330'), ('Closing/Cancelling account', '6291')][0m
Thought:[32;1m[1;3mThe top 5 issues are: 'None', 'Billing disputes', 'Other', 'Identity theft / Fraud / Embezzlement', and 'Closing/Cancelling account'.
Final Answer: The top 5 issues are 'None', 'Billing disputes', 'Other', 'Identity theft / Fraud / Embezzlement', and 'Closing/Cancelling account'.[0m

[1m> Finished chain.[0m


"The top 5 issues are 'None', 'Billing disputes', 'Other', 'Identity theft / Fraud / Embezzlement', and 'Closing/Cancelling account'."

In [8]:
question3 = "which companies had the top 5 complaints?"
agent_executor.run(question3)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mcomplaints, credit_card_complaints[0m
Thought:[32;1m[1;3mI can query the "complaints" table to find the companies with the most complaints.
Action: schema_sql_db
Action Input: "complaints"[0m
Observation: [33;1m[1;3mCREATE TABLE spark_catalog.langchain_example.complaints (
  date_received STRING,
  product STRING,
  sub_product STRING,
  issue STRING,
  sub_issue STRING,
  consumer_complaint_narrative STRING,
  company_public_response STRING,
  company STRING,
  state STRING,
  zip_code STRING,
  tags STRING,
  consumer_consent_provided STRING,
  submitted_via STRING,
  date_sent_to_company STRING,
  company_response_to_consumer STRING,
  timely_response STRING,
  consumer_disputed STRING,
  complaint_id STRING)
;

/*
3 rows from complaints table:
date_received	product	sub_product	issue	sub_issue	consumer_complaint_narrative	company_public_res

'None, Citibank, Capital One, JPMorgan Chase & Co., Bank of America'

In [9]:
question4 = "How many issues were resolved in timely fashion?"
agent_executor.run(question4)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mcomplaints, credit_card_complaints[0m
Thought:[32;1m[1;3mI can query the "complaints" and "credit_card_complaints" tables to find the number of resolved issues.
Action: schema_sql_db
Action Input: "complaints, credit_card_complaints"[0m
Observation: [33;1m[1;3mCREATE TABLE spark_catalog.langchain_example.complaints (
  date_received STRING,
  product STRING,
  sub_product STRING,
  issue STRING,
  sub_issue STRING,
  consumer_complaint_narrative STRING,
  company_public_response STRING,
  company STRING,
  state STRING,
  zip_code STRING,
  tags STRING,
  consumer_consent_provided STRING,
  submitted_via STRING,
  date_sent_to_company STRING,
  company_response_to_consumer STRING,
  timely_response STRING,
  consumer_disputed STRING,
  complaint_id STRING)
;

/*
3 rows from complaints table:
date_received	product	sub_product	issue	sub_issue	co

'74,407'

In [10]:
question5 = "what are the unique channels by which complaints are submmited?"
agent_executor.run(question5)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mcomplaints, credit_card_complaints[0m
Thought:[32;1m[1;3mI can query the "complaints" table to find the unique channels by which complaints are submitted.
Action: schema_sql_db
Action Input: "complaints"[0m
Observation: [33;1m[1;3mCREATE TABLE spark_catalog.langchain_example.complaints (
  date_received STRING,
  product STRING,
  sub_product STRING,
  issue STRING,
  sub_issue STRING,
  consumer_complaint_narrative STRING,
  company_public_response STRING,
  company STRING,
  state STRING,
  zip_code STRING,
  tags STRING,
  consumer_consent_provided STRING,
  submitted_via STRING,
  date_sent_to_company STRING,
  company_response_to_consumer STRING,
  timely_response STRING,
  consumer_disputed STRING,
  complaint_id STRING)
;

/*
3 rows from complaints table:
date_received	product	sub_product	issue	sub_issue	consumer_complaint_narrative	com

InvalidRequestError: This model's maximum context length is 4097 tokens. However, your messages resulted in 33751 tokens. Please reduce the length of the messages.