# Querying a Database with Natural Language
## Experimentation with Weights & Biases 

In this notebook, we will use LLMs to generate SQL from natural-language questions. Developing LLM apps requires experimentation, for example with chain architecture and prompt engineering, so we will use the W&B Prompts Tracer to log our experiments and debug errors. After running the code, you should see a screen like this one in your W&B dashboard:

![prompts.jpg](prompts.jpg)

### Data
- [TPCH_SF1](https://docs.snowflake.com/en/user-guide/sample-data-tpch) - Contains data related to **orders, customers, suppliers, and inventory** in a manufacturing and distribution business environment.
- See [TPC Benchmark H](https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf) for details
- Available in Snowflake or via [SQLite download from here](https://github.com/lovasoa/TPCH-sqlite/releases/tag/v1.0)

## Setup

In [None]:
!pip install wandb
!pip install openai
!pip install langchain==0.0.147

In [None]:
import os
from getpass import getpass
from types import SimpleNamespace

from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import TransformChain, LLMChain, SequentialChain

from utils import SQLConnector

Create a basic config

In [None]:
config = SimpleNamespace(
    model_name="text-davinci-003",
    WANDB_PROJECT="mt-pocono",
    WANDB_ENTITY=None, # Your W&B Team if you have one, e.g. "prompt-eng",
    WANDB_JOB_TYPE="production",
    SNOWFLAKE_WAREHOUSE='COMPUTE_WH',
    SNOWFLAKE_DATABASE='SNOWFLAKE_SAMPLE_DATA',
    SNOWFLAKE_DATABASE_PREFIX='TPCH_SF1',
    SNOWFLAKE_SCHEMA='INFORMATION_SCHEMA',
    SQLITE_DB_PATH='data/TPC-H-small.db'  # Downloaded from https://github.com/lovasoa/TPCH-sqlite
)

Configure the OpenAI API key

In [None]:
if os.getenv("OPENAI_API_KEY") is None:
    if any('VSCODE' in x for x in os.environ.keys()):
        print('Please enter your API key in the VS Code prompt at the top of your VS Code window!')
    os.environ["OPENAI_API_KEY"] = getpass("Paste your OpenAI key from: https://platform.openai.com/account/api-keys\n")

assert os.getenv("OPENAI_API_KEY", "").startswith("sk-"), "This doesn't look like a valid OpenAI API key"
print("OpenAI API key configured")

Configure the database
- Set `DB_TYPE` to either `'sqlite'` or `'snowflake'`

In [None]:
# Choose the backend: Snowflake, or the SQLite database file in ./data
config.DB_TYPE = 'sqlite'  # 'sqlite' or 'snowflake'

# If using Snowflake, set your Snowflake credentials here
config.SNOWFLAKE_PASSWORD = os.environ.get('SNOWFLAKE_PASSWORD')
config.SNOWFLAKE_ACCOUNT = os.environ.get('SNOWFLAKE_ACCOUNT')  # ORG-ACCOUNT
config.SNOWFLAKE_USER = os.environ.get('SNOWFLAKE_USER')


## Start W&B Monitoring

In [None]:
from wandb.integration.langchain import WandbTracer

WandbTracer.init({"project": config.WANDB_PROJECT, "entity": config.WANDB_ENTITY})

### Set Up Database
- Set either `DB_TYPE = 'sqlite'` or `DB_TYPE = 'snowflake'`
- Connect to our SQL database
- Pull the database schema for the relevant tables; this will be used as context in the prompt
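The notebook's `utils.SQLConnector.get_schema` is not shown here, so the exact format of `schema_str` is unknown. As a rough sketch of what "pulling the schema" can mean for the SQLite backend, the hypothetical helper `get_sqlite_schema` below uses only Python's built-in `sqlite3` module: it lists user tables from `sqlite_master` and describes each one with `PRAGMA table_info`, producing one `TABLE(column TYPE, ...)` line per table. The demo tables are toy stand-ins that borrow TPC-H column names.

```python
import sqlite3


def get_sqlite_schema(conn: sqlite3.Connection) -> str:
    """Hypothetical helper: one 'TABLE(col TYPE, ...)' line per user table."""
    tables = [
        row[0]
        for row in conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' "
            "AND name NOT LIKE 'sqlite_%' ORDER BY name"
        )
    ]
    lines = []
    for table in tables:
        # PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk) per column
        cols = conn.execute(f'PRAGMA table_info("{table}")').fetchall()
        col_desc = ", ".join(f"{name} {col_type}" for _, name, col_type, *_ in cols)
        lines.append(f"{table}({col_desc})")
    return "\n".join(lines)


if __name__ == "__main__":
    # Toy in-memory tables that borrow column names from the TPC-H spec
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE CUSTOMER (C_CUSTKEY INTEGER, C_NAME TEXT)")
    conn.execute("CREATE TABLE ORDERS (O_ORDERKEY INTEGER, O_CUSTKEY INTEGER, O_ORDERDATE DATE)")
    print(get_sqlite_schema(conn))
    # CUSTOMER(C_CUSTKEY INTEGER, C_NAME TEXT)
    # ORDERS(O_ORDERKEY INTEGER, O_CUSTKEY INTEGER, O_ORDERDATE DATE)
```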

In [None]:
sql_conn = SQLConnector(config, db_type=config.DB_TYPE)  # db_type can be 'sqlite' or 'snowflake'
# sql_conn(f"select * from {config.SNOWFLAKE_DATABASE_PREFIX}.ORDERS limit 1")  # Test the SQL connection

# Get the schema for every table in the selected database
if config.DB_TYPE == 'sqlite':
    schema_str = sql_conn.get_schema(database_name="TPC-H-small", verbose=False)
elif config.DB_TYPE == 'snowflake':
    schema_str = sql_conn.get_schema(config.SNOWFLAKE_DATABASE, config.SNOWFLAKE_DATABASE_PREFIX, verbose=False)
else:
    raise ValueError(f"Unsupported DB_TYPE: {config.DB_TYPE!r}; use 'sqlite' or 'snowflake'")

# schema_str

## 1. Question -> SQL generation
- Add basic schema info about a limited set of Tables to a simple prompt
- Run the generated SQL against the database (SQLite or Snowflake) and log the success/failure result
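The prompt in the next cell is built in two stages, which is easy to misread. A minimal sketch with a hypothetical `build_prompt` helper: the f-string substitutes the database type immediately, and the doubled braces survive as `{schema_str}` and `{question}` placeholders that are filled in later. Plain `str.format` is used here as an approximation of what `PromptTemplate` does with its default f-string format.

```python
def build_prompt(db_type: str, schema_str: str, question: str) -> str:
    # Stage 1: the f-string fills in the database type now; the doubled braces
    # come out as single-brace placeholders, e.g.
    # "Here is a sqlite database schema: {schema_str}.{question}"
    template = f"Here is a {db_type} database schema: {{schema_str}}.{{question}}"
    # Stage 2: fill the remaining placeholders, as the prompt template would at run time
    return template.format(schema_str=schema_str, question=question)


if __name__ == "__main__":
    print(build_prompt(
        "sqlite",
        "ORDERS(O_ORDERKEY INTEGER, O_CUSTKEY INTEGER)",
        "Get my last 10 orders",
    ))
    # Here is a sqlite database schema: ORDERS(O_ORDERKEY INTEGER, O_CUSTKEY INTEGER).Get my last 10 orders
```

With this template the schema and the question are joined by a bare period, with no space or line break between them.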

In [None]:
# Chain 1: Generate SQL query from a user question
llm = OpenAI(openai_api_key=os.environ.get('OPENAI_API_KEY'),
             model_name=config.model_name,
             temperature=0,
             verbose=True)

template = f"Here is a {config.DB_TYPE} database schema: {{schema_str}}.{{question}}"

generate_sql_chain = LLMChain(
    llm=llm, 
    prompt=PromptTemplate(input_variables=["schema_str", "question"], template=template), 
    output_key="sql",
    verbose=True)

# Chain 2: Run the SQL query
def run_sql(inputs: dict) -> dict:
    return {"sql_result": sql_conn(inputs["sql"])}

run_sql_chain = TransformChain(
    input_variables=["sql"], 
    output_variables=["sql_result"], 
    transform=run_sql, 
    verbose=True)

# Wrap the two chains into a SequentialChain
sql_chain = SequentialChain(
    chains=[generate_sql_chain, run_sql_chain], 
    input_variables=["schema_str", "question"], 
    output_variables=["sql_result", "sql"], 
    verbose=True)
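Conceptually, the `SequentialChain` above passes a shared dictionary of variables from step to step: each step reads its input variables and adds its output key. The sketch below is a hypothetical, stripped-down model of that idea (`run_sequential` is not LangChain's implementation), with plain functions standing in for the LLM call and for `sql_conn`.

```python
from typing import Callable, Dict, List, Tuple

# A hypothetical step: (input keys, output key, function taking those inputs as kwargs)
Step = Tuple[List[str], str, Callable[..., object]]


def run_sequential(steps: List[Step], inputs: Dict[str, object]) -> Dict[str, object]:
    """Thread a shared dict of variables through the steps, in order."""
    state = dict(inputs)
    for input_keys, output_key, fn in steps:
        args = {k: state[k] for k in input_keys}  # KeyError if a step's input is missing
        state[output_key] = fn(**args)
    return state


def fake_generate_sql(schema_str: str, question: str) -> str:
    # Stands in for generate_sql_chain (the LLM call)
    return "SELECT 1"


def fake_run_sql(sql: str) -> list:
    # Stands in for sql_conn(sql)
    return [(1,)]


if __name__ == "__main__":
    steps = [
        (["schema_str", "question"], "sql", fake_generate_sql),
        (["sql"], "sql_result", fake_run_sql),
    ]
    out = run_sequential(steps, {"schema_str": "...", "question": "..."})
    print(out["sql"], out["sql_result"])  # SELECT 1 [(1,)]
```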

### Run the Chain

In [None]:
# question = f"Write a {config.DB_TYPE} sql query to find the most recent ship date for every customer"
question = f"Write a {config.DB_TYPE} sql query to find the id of the first product we sent every customer, only return the first 10 rows"

- Let's try running the chain with our basic user question.
- If there is an error, we can inspect what happened in Weights & Biases.

In [None]:
try:
    sql_chain({"question": question, "schema_str": schema_str})
except Exception as e:
    print(f'\nError running the chain:\n{e}')
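When inspecting the generated SQL in W&B, it can help to have a hand-written answer to compare against. The sketch below is one possible reading of the question, taking "first product sent" to mean the line item with the earliest `L_SHIPDATE` across a customer's orders. It assumes the table and column names of the TPC-H specification (`ORDERS`, `LINEITEM`, `O_CUSTKEY`, `L_PARTKEY`, `L_SHIPDATE`, and so on); check them against `schema_str` before relying on it. The demo runs on toy rows in an in-memory SQLite database and needs a SQLite version with window-function support.

```python
import sqlite3

# Hand-written reference query (an assumption, not output from the notebook):
# "first product sent" = the line item with the earliest L_SHIPDATE per customer.
FIRST_PRODUCT_SQL = """
WITH ranked AS (
    SELECT
        o.O_CUSTKEY AS custkey,
        l.L_PARTKEY AS partkey,
        ROW_NUMBER() OVER (
            PARTITION BY o.O_CUSTKEY
            ORDER BY l.L_SHIPDATE, l.L_PARTKEY  -- tie-break for a stable answer
        ) AS rn
    FROM ORDERS AS o
    JOIN LINEITEM AS l ON l.L_ORDERKEY = o.O_ORDERKEY
)
SELECT custkey, partkey
FROM ranked
WHERE rn = 1
ORDER BY custkey
LIMIT 10
"""

if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE ORDERS (O_ORDERKEY INTEGER, O_CUSTKEY INTEGER)")
    conn.execute("CREATE TABLE LINEITEM (L_ORDERKEY INTEGER, L_PARTKEY INTEGER, L_SHIPDATE TEXT)")
    # Toy rows: customer 1 has two orders, customer 2 has one
    conn.executemany("INSERT INTO ORDERS VALUES (?, ?)", [(10, 1), (11, 1), (20, 2)])
    conn.executemany(
        "INSERT INTO LINEITEM VALUES (?, ?, ?)",
        [(10, 500, "1995-03-01"), (11, 501, "1994-07-15"), (20, 700, "1996-01-02")],
    )
    print(conn.execute(FIRST_PRODUCT_SQL).fetchall())  # [(1, 501), (2, 700)]
```

`CUSTOMER` is not joined because `O_CUSTKEY` already identifies the customer; customers with no orders therefore do not appear.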

#### Let's run a few more queries

In [None]:
questions = [
    "Find the top 10 customers who have spent the most money",
    "Get my last 10 orders",
    "What is my best performing region?",
    "Find the top 1 customer who has spent the most money",
]

outputs = []
for q in questions:
    try:
        outputs.append(sql_chain({"question": q, "schema_str": schema_str}))
    except Exception as e:
        print(f'\nError running the chain:\n{e}')

# outputs

## 2. LLM Self-Correction
- Some of the generated SQL is malformed. Can we get the LLM to correct itself?
- In some cases, the LLM tries to complete the question before generating the SQL
- Let's explore whether we can get valid SQL by either:
  - (A) using the LLM to clean up the output text, fixing the SQL and removing extraneous characters, or
  - (B) using the LLM to clarify the user's input so that it generates the correct SQL
- Let's chain these calls together using LangChain
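For comparison with approach (A), a purely deterministic cleanup is also possible. The hypothetical `extract_sql` below is a heuristic of our own, not what this notebook does: it strips Markdown code fences, skips any text before the first line starting with `SELECT` or `WITH`, and keeps only the first statement. It will misfire on, for example, a semicolon inside a string literal, or a prose line that happens to begin with "With".

```python
import re

FENCE = "`" * 3  # a Markdown code fence, built this way to keep the snippet fence-safe

# A line that starts (after optional whitespace) with SELECT or WITH, in any case
_SQL_START = re.compile(r"^\s*(SELECT|WITH)\b", re.IGNORECASE | re.MULTILINE)


def extract_sql(raw: str) -> str:
    """Heuristically pull the first SQL statement out of free-form LLM output."""
    # Drop Markdown code fences the model may wrap around its answer
    text = raw.replace(FENCE + "sql", "").replace(FENCE, "")
    match = _SQL_START.search(text)
    if match is None:
        raise ValueError("no line starting with SELECT or WITH in the model output")
    sql = text[match.start():].strip()
    # Keep only the first statement; this naive split would also cut at a ';'
    # inside a string literal
    return sql.split(";", 1)[0].strip() + ";"


if __name__ == "__main__":
    raw = " who placed an order?\n\nSELECT c.C_NAME\nFROM CUSTOMER c;\n\nThis returns names."
    print(extract_sql(raw))
    # SELECT c.C_NAME
    # FROM CUSTOMER c;
```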

In [None]:
# Chain Step 2: Clean up and format the SQL: {raw_sql} -> {sql}
clean_sql_template = f"""Please correct any syntax errors in the following SQL and format it nicely: 

{{raw_sql}}

Correct SQL:"""

generate_sql_chain = LLMChain(
    llm=llm, 
    prompt=PromptTemplate(input_variables=["schema_str", "question"], template=template), 
    output_key="raw_sql",
    verbose=True)

clean_sql_chain = LLMChain(
    llm=llm, 
    prompt=PromptTemplate(input_variables=["raw_sql"], template=clean_sql_template), 
    output_key="sql",
    verbose=True)

run_sql_chain = TransformChain(
    input_variables=["sql"], 
    output_variables=["sql_result"], 
    transform=run_sql, 
    verbose=True)

# Wrap the three chains into a SequentialChain
sql_chain = SequentialChain(
    chains=[generate_sql_chain, clean_sql_chain, run_sql_chain], 
    input_variables=["schema_str", "question"], 
    output_variables=["sql_result", "sql"], 
    verbose=True)

Run the chain

In [None]:
try:
    clean_sql_output = sql_chain({"question": question, "schema_str": schema_str})
except Exception as e:
    print(f'\nError running the chain:\n{e}')
    clean_sql_output = None

# clean_sql_output
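Whether the cleanup step really produced valid SQL can be checked before executing anything. A minimal sketch for the SQLite backend only, using a hypothetical `sql_compile_error` helper: prefixing a statement with SQLite's `EXPLAIN` makes SQLite compile it (parsing it and resolving table and column names) and return the compiled program instead of running the query, so syntax errors and unknown names surface as exceptions. Snowflake would need its own approach, which is not covered here.

```python
import sqlite3
from typing import Optional


def sql_compile_error(conn: sqlite3.Connection, sql: str) -> Optional[str]:
    """Return SQLite's error message if `sql` fails to compile, else None.

    EXPLAIN asks SQLite to prepare the statement and list its bytecode
    rather than execute it, so the query itself is never run.
    """
    try:
        conn.execute("EXPLAIN " + sql).fetchall()
    except (sqlite3.Error, sqlite3.Warning) as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE ORDERS (O_ORDERKEY INTEGER)")
    print(sql_compile_error(conn, "SELECT O_ORDERKEY FROM ORDERS"))  # None
    print(sql_compile_error(conn, "SELECT * FROM LINEITEM"))  # no such table: LINEITEM
```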

## LLM Question Clarification

Let's try another approach: we'll ask the LLM to clarify the user's question, and see whether the resulting SQL can be run without any cleanup.

In [None]:
# Chain Step 1.b: Clarify the user's question: {raw_question} -> {question}
clarify_template = f"""Please re-write this user request, if needed, to better clarify the SQL question they are asking.\
    Please also make sure to include the word "sql" in the question. Add any additional context you think is necessary.\
    Add any punctuation you think is necessary.:

{{raw_question}}

A better request would be:"""


clarify_chain = LLMChain(
    llm=llm, 
    prompt=PromptTemplate(input_variables=["raw_question"], template=clarify_template), 
    output_key="question",
    verbose=True)

generate_sql_chain = LLMChain(
    llm=llm, 
    prompt=PromptTemplate(input_variables=["schema_str", "question"], template=template), 
    output_key="sql",
    verbose=True)

run_sql_chain = TransformChain(
    input_variables=["sql"], 
    output_variables=["sql_result"], 
    transform=run_sql, 
    verbose=True)

# Wrap the three chains into a SequentialChain
sql_chain = SequentialChain(
    chains=[clarify_chain, generate_sql_chain, run_sql_chain], 
    input_variables=["schema_str", "raw_question"], 
    output_variables=["sql_result", "sql"], 
    verbose=True)
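Wiring mistakes between chains are easy to make here: the clarify step's `output_key="question"` must feed the `question` input of `generate_sql_chain`, while the whole chain now takes `raw_question`. The hypothetical `check_wiring` helper below sketches a static check that every step's inputs are provided by the initial inputs or by an earlier step's outputs. LangChain's `SequentialChain` performs a similar check when it is constructed, but this is not its code.

```python
from typing import Dict, List, Sequence


def check_wiring(initial_inputs: Sequence[str], steps: Sequence[Dict[str, List[str]]]) -> List[str]:
    """Return a message for every step input that nothing earlier provides."""
    available = set(initial_inputs)
    problems = []
    for i, step in enumerate(steps):
        missing = [k for k in step["inputs"] if k not in available]
        if missing:
            problems.append(f"step {i} is missing {missing}")
        available.update(step["outputs"])
    return problems


if __name__ == "__main__":
    clarify_pipeline = [
        {"inputs": ["raw_question"], "outputs": ["question"]},       # clarify_chain
        {"inputs": ["schema_str", "question"], "outputs": ["sql"]},  # generate_sql_chain
        {"inputs": ["sql"], "outputs": ["sql_result"]},              # run_sql_chain
    ]
    print(check_wiring(["schema_str", "raw_question"], clarify_pipeline))  # []
    # Passing "question" instead of "raw_question" leaves the first step unfed
    print(check_wiring(["schema_str", "question"], clarify_pipeline))
    # ["step 0 is missing ['raw_question']"]
```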

Run the chain

In [None]:
try:
    clarify_output = sql_chain({"raw_question": question, "schema_str": schema_str})
except Exception as e:
    print(f'\nError running the chain:\n{e}')
    clarify_output = None

# clarify_output

## 3. Iterate on the Prompt Template

Let's take what we learned back to the original prompt template, improving it so that we need fewer calls to the LLM service.

In [None]:
# For reference, this was our old template
# template = f"Here is a {config.DB_TYPE} database schema: {{schema_str}}.{{question}}"

In [None]:
new_template = f"""You are a data analyst working on a {config.DB_TYPE} database with the following schema: {{schema_str}}

Please produce a sql query to answer the following question from a colleague in the business: {{question}}

Please ensure to use only correct SQL syntax without errors or strange punctuation. Only return SQL and please format it nicely: 

Correct SQL:
"""
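Because `new_template` is an f-string, it is worth confirming that the placeholders left after the f-string step are exactly the `input_variables` passed to `PromptTemplate` in the next cell. A small sketch using the standard library's `string.Formatter().parse`, with a hypothetical `template_fields` helper and a shortened copy of the template:

```python
from string import Formatter
from typing import Set


def template_fields(template: str) -> Set[str]:
    """Names of the {placeholders} that str.format would fill in."""
    # parse() yields (literal_text, field_name, format_spec, conversion);
    # field_name is None for trailing literal text
    return {field for _, field, _, _ in Formatter().parse(template) if field}


if __name__ == "__main__":
    db_type = "sqlite"  # stands in for config.DB_TYPE
    template = (
        f"You are a data analyst working on a {db_type} database "
        f"with the following schema: {{schema_str}}\n\n"
        f"Please produce a sql query to answer the following question: {{question}}"
    )
    assert template_fields(template) == {"schema_str", "question"}
    print(sorted(template_fields(template)))  # ['question', 'schema_str']
```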

Re-instantiate our SQL chain with the new prompt template

In [None]:
generate_sql_chain = LLMChain(
    llm=llm, 
    prompt=PromptTemplate(input_variables=["schema_str", "question"], template=new_template), 
    output_key="sql",
    verbose=True)

run_sql_chain = TransformChain(
    input_variables=["sql"], 
    output_variables=["sql_result"], 
    transform=run_sql, 
    verbose=True)

# Wrap the two chains into a SequentialChain
sql_chain = SequentialChain(
    chains=[generate_sql_chain, run_sql_chain], 
    input_variables=["schema_str", "question"], 
    output_variables=["sql_result", "sql"], 
    verbose=True)

In [None]:
# Using the same question as before
question = f"Write a {config.DB_TYPE} sql query to find the id of the first product we sent every customer, only return the first 10 rows"

Run the chain

In [None]:
better_prompt_output = sql_chain({"question": question, "schema_str": schema_str})
# better_prompt_output

### Testing on More User Queries

In [None]:
questions = [
    "Find the top 10 customers who have spent the most money",
    "Get my last 10 orders",
    "What is my best performing region?",
    "Find the top 1 customer who has spent the most money",
]

outputs = []
for q in questions:
    try:
        outputs.append(sql_chain({"question": q, "schema_str": schema_str}))
    except Exception as e:
        print(f'\nError running the chain:\n{e}')
# outputs
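To compare prompt templates on more than one question, it helps to turn the loop above into a number. The hypothetical `run_batch` helper below runs every question through a chain-like callable, records failures instead of stopping, and reports a success rate; `fake_chain` stands in for `sql_chain` so the sketch runs without an API key. Note that "success" here only means the chain raised no exception, not that the returned rows are correct.

```python
from typing import Callable, Dict, List


def run_batch(chain: Callable[[Dict[str, str]], Dict], questions: List[str],
              schema_str: str) -> Dict[str, object]:
    """Run every question through `chain`, recording failures instead of stopping."""
    outputs, failures = [], []
    for q in questions:
        try:
            outputs.append(chain({"question": q, "schema_str": schema_str}))
        except Exception as exc:  # broad on purpose: any failure counts against the prompt
            failures.append((q, str(exc)))
    rate = len(outputs) / len(questions) if questions else 0.0
    return {"outputs": outputs, "failures": failures, "success_rate": rate}


def fake_chain(inputs: Dict[str, str]) -> Dict:
    # Hypothetical stand-in for sql_chain: fails on one question to exercise the error path
    if "region" in inputs["question"]:
        raise RuntimeError("no such column: REGION")
    return {"sql": "SELECT 1", "sql_result": [(1,)]}


if __name__ == "__main__":
    qs = ["Get my last 10 orders", "What is my best performing region?"]
    report = run_batch(fake_chain, qs, schema_str="...")
    print(report["success_rate"])  # 0.5
    print(report["failures"])  # [('What is my best performing region?', 'no such column: REGION')]
```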

#### Finally, once you're finished, it is best practice to call `WandbTracer.finish()` to close the wandb process

In [None]:
WandbTracer.finish()