In [1]:
# ! pip install langchain streamlit==1.24.0 snowflake-connector-python ai21
# ! pip install -U pyarrow

In [1]:
# Enter your AI21 API Key
from getpass import getpass

AI21_API_KEY = getpass()

········


In [12]:
# Define Snowflake Credentials, table and column details here
sf_dict = {'user': 'username', 'password': 'password', 'account': 'account',
           'role': 'ACCOUNTADMIN', 'warehouse': 'COMPUTE_WH', 'database': 'SNOWFLAKE_SAMPLE_DATA', 'schema': 'TPCDS_SF10TCL'}
table = '"SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF10TCL"."CATALOG_SALES"'
columns = ["CS_SOLD_DATE_SK", "CS_SOLD_TIME_SK", "CS_SHIP_DATE_SK", "CS_BILL_CUSTOMER_SK", "CS_BILL_CDEMO_SK", 
           "CS_BILL_HDEMO_SK", "CS_BILL_ADDR_SK", "CS_SHIP_CUSTOMER_SK", "CS_SHIP_CDEMO_SK", "CS_SHIP_HDEMO_SK", 
           "CS_SHIP_ADDR_SK", "CS_CALL_CENTER_SK", "CS_CATALOG_PAGE_SK", "CS_SHIP_MODE_SK", "CS_WAREHOUSE_SK", "CS_ITEM_SK", 
           "CS_PROMO_SK", "CS_ORDER_NUMBER", "CS_QUANTITY", "CS_WHOLESALE_COST", "CS_LIST_PRICE", "CS_SALES_PRICE", 
           "CS_EXT_DISCOUNT_AMT", "CS_EXT_SALES_PRICE", "CS_EXT_WHOLESALE_COST", "CS_EXT_LIST_PRICE", "CS_EXT_TAX", 
           "CS_COUPON_AMT", "CS_EXT_SHIP_COST", "CS_NET_PAID", "CS_NET_PAID_INC_TAX", "CS_NET_PAID_INC_SHIP", 
           "CS_NET_PAID_INC_SHIP_TAX", "CS_NET_PROFIT"]

In [2]:
from langchain.llms import AI21
import streamlit as st
from langchain import PromptTemplate, LLMChain
from snowflake.connector import connect
import re
import json
import urllib
import pandas as pd
import base64

Failed to import ArrowResult. No Apache Arrow result set format can be used. ImportError: No module named 'snowflake.connector.arrow_iterator'


In [3]:
llm = AI21(ai21_api_key=AI21_API_KEY)

In [4]:
# Execute Snowflake Query
def execute_snowflake_query(query, sf_dict):
    # Connect to Snowflake
    snowflake_conn = connect(
        user=sf_dict['user'],
        password=sf_dict['password'],
        account=sf_dict['account'],
        warehouse=sf_dict['warehouse'],
        database=sf_dict['database'],
        schema=sf_dict['schema']
    )

    # Execute the Snowflake query
    cursor = snowflake_conn.cursor()
    cursor.execute(query)
    result = cursor.fetchall()
    df = pd.DataFrame(result, columns=[desc[0] for desc in cursor.description])

    # Disconnect from Snowflake
    snowflake_conn.close()
    return df

In [5]:
def format_snowflake_query(query, column_names):
    # Remove new line characters & leading/trailing whitespace
    
    for chars in ["`", "sql"]:
        query = query.replace("`", "")
        
    formatted_query = query.replace("\n", " ").strip()

    # Add a semicolon at the end of the query if missing
    if not formatted_query.endswith(';'):
        formatted_query += ';'
    try:
        # Execute Query
        query_result = execute_snowflake_query(formatted_query, sf_dict)
    except:
        # Enquote column names and table name
        formatted_query = enquote_identifiers(formatted_query, column_names)
    return formatted_query


def find_matching_columns(query, column_names):
    # Construct a regular expression pattern to match the column names
    pattern = r'\b(?:' + '|'.join(column_names) + r')\b'

    # Find all matches of the pattern in the query
    matches = re.findall(pattern, query, re.IGNORECASE)
    return list(set(matches))


def replace_word(query, word_to_replace, replacement):
    pattern = r"(?<![A-Za-z])" + re.escape(word_to_replace) + r"(?![A-Za-z])"
    replaced_query = re.sub(pattern, replacement, query)
    return replaced_query


def remove_extra_spaces(query):
    cleaned_query = re.sub(r'\s+', ' ', query)
    return cleaned_query.strip()


def enquote_identifiers(query, column_names):
    def enquote(column_name):
        if column_name[0] == ('\"') and column_name[-1] == ('\"'):
            return column_name
        else:
            return f'"{column_name}"'

    matched_columns = find_matching_columns(query, column_names)
    # Enquote column names
    for column_name in matched_columns:
        formatted_column_name = enquote(column_name)
        query = replace_word(query, column_name, formatted_column_name)

    query = remove_extra_spaces(query)
    return query

In [11]:
def results_from_snowflake(question, sf_dict, table, columns):
    '''
    This function will take the user question and return the result from Snowflake table.
    '''
    llm_code = llm
    llm_text = llm

    # SQL Prompt
    query_template = """
    Generate a Snowflake query for the below table, schema and question:\n
    - Table name {Table}\n
    - List of columns: {Columns}\n
    - Question: {question}\n

    Query should adhere to the following rules:\n
    - Column names should be case sensitive.\n
    - Don't Order the results if not necessary.\n
    - Ensure proper filtering conditions are applied to retrieve the desired subset of data.\n
    - Don't use groupby and where if not needed.\n
    - Table names to be wrapped in double quotes.\n
    - Enquote database and schema also in double quotes.\n
    - Don't add any additional keyword like copy and all if not added.\n
    - Don't add any braces unnecessarily.\n
    - Ensure groupby or where if used are applied properly.\n
    - Don't add any numbers unnecessarily.\n
    Please provide the query to retrieve the requested information:

    """

    sql_prompt = PromptTemplate(template=query_template, input_variables=["Table", "question", "Columns"])

    # Generate SQL Query
    def get_sql_query(table, question, columns):
        llm_chain = LLMChain(prompt=sql_prompt, llm=llm_code)
        response = llm_chain.run({"Table": table, "question": question, "Columns": columns})
        return response

    query = get_sql_query(table, question, columns)
    # Format Query
    formatted_query = format_snowflake_query(query, columns)

    # Execute Query
    try:
        query_result = execute_snowflake_query(formatted_query, sf_dict)
    except Exception as e:
        query_result = f"Error: {e}"

    # Generate answer from questions and query result
    def convert_sql_response_to_english(question, query_result):

        # Answer Prompts
        answer_template1 = """
        Given a question and the result of an SQL query, your task is to convert them into plain English format. You are to provide a human-readable explanation that accurately conveys the meaning of the question and the information contained in the query result. \n
        
        Example 1:
        Question: "What is the average total delivery cases?"
        Query Result: [("QuarterTotalDeliveryCases": 256)]

        Plain English Representation: "The average of total delivery cases is 256."

        Example 2:
        Question: "What are the total number of row counts?"
        Query Result: [("Count": "660")]

        Plain English Representation: "The total number of row counts is 660."

        Example 3:
        Question: "What is the Count of unique 'TrademarkID'?"
        Query Result: [("TrademarkID": 50)]

        Plain English Representation: "The count of unique 'TrademarkID' is 50." \n

        Question: {question}\n
        Query result: {query_result}\n\n
        
        Plain English Explanation:
        """

        answer_template2 = """
            Translate given question into a statement.\n\n
            Example 1-\n
            Question: What is the sales for brands last year?\n
            Answer: Sales for brands in last year is:\n\n
            Example 2-\n
            Question: What is the weekly avergae sale for this year?\n
            Answer: Weekly average sale for this year is:\n\n
            Example 3-\n
            Question: Show all unique trademark ids starting with `A` along with it's description.\n
            Answer: Below is a list of all unique trademark ids starting with `A` along with it's description:\n\n

            Question: {question}\n\n\n
            Answer:
            """
        if len(query_result) == 1:
            answer_prompt = PromptTemplate(template=answer_template1, input_variables=["query_result", "question"])
            llm_chain = LLMChain(prompt=answer_prompt, llm=llm_text)
            response = llm_chain.run({"query_result": query_result, "question": question})
        else:
            answer_prompt = PromptTemplate(template=answer_template2, input_variables=["question"])
            llm_chain = LLMChain(prompt=answer_prompt, llm=llm_text)
            response = llm_chain.run({"query_result": query_result, "question": question})
        return response

    response = convert_sql_response_to_english(question, query_result)
    return formatted_query, response, query_result


def return_schema_df(db_details_query, sf_dict):
    df = execute_snowflake_query(db_details_query, sf_dict)
    return df

In [13]:
question1 = "What is the average of sales price?"
question2 = "What are the sold date wise averages of quantity?"
question3 = "How many unique bill addresses are there?"
question4 = "What is max discount amount?"
question5 = "What are 5 lowest wholesale cost?"

In [14]:
formatted_query1, response1, query_result1 = results_from_snowflake(question1, sf_dict, table, columns)
print("Question:", question1, "\n","Snowflake Query:", formatted_query1, "\n", "Query Result:", query_result1, "\n","Response:", response1)

Question: What is the average of sales price? 
 Snowflake Query: (SELECT AVG(CS_SALES_PRICE) FROM "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF10TCL"."CATALOG_SALES"); 
 Query Result:   AVG(CS_SALES_PRICE)
0         50.49307022 
 Response: 
        The average sales price is 50.49307022


In [16]:
formatted_query2, response2, query_result2 = results_from_snowflake(question2, sf_dict, table, columns)
print("Question:", question2, "\n","Snowflake Query:", formatted_query2, "\n", "Query Result:", query_result2, "\n","Response:", response2)

Question: What are the sold date wise averages of quantity? 
 Snowflake Query: SEL AVG(Quantity) , Sold Date FROM (SELECT "CS_SOLD_DATE_SK" , "CS_SOLD_TIME_SK" , "CS_QUANTITY" FROM table_name) GROUP BY "CS_SOLD_DATE_SK"; 
 Query Result: Error: 001003 (42000): SQL compilation error:
syntax error line 1 at position 0 unexpected 'SEL'. 
 Response:  The sold date wise averages of quantity are:


In [18]:
formatted_query3, response3, query_result3 = results_from_snowflake(question4, sf_dict, table, columns)
print("Question:", question4, "\n","Snowflake Query:", formatted_query3, "\n", "Query Result:", query_result3, "\n","Response:", response3)

Question: What is max discount amount? 
 Snowflake Query: (SELECT MAX("CS_EXT_DISCOUNT_AMT") FROM "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF10TCL"."CATALOG_SALES"); 
 Query Result:   MAX("CS_EXT_DISCOUNT_AMT")
0                   29982.00 
 Response: 
        The maximum discount amount is $29,982.
