In [3]:
!pip install openai awswrangler --quiet

import pandas as pd
import numpy as np
import datetime

import os
import openai

import awswrangler as wr

import time
import boto3

import requests
import json

In [4]:
# Setup of GPT-4 configurations for quizzing
key_file = open('/home/ec2-user/SageMaker/open_key.txt')
open_key = key_file.read()

os.environ["OPENAI_API_KEY"] = open_key
openai.api_key = os.environ['OPENAI_API_KEY']

In [5]:
from openai import OpenAI
client = OpenAI()

athena = boto3.client("athena")
runtime = boto3.client("runtime.sagemaker")

In [6]:
# Setup of sqlcoder-7b configuration for quizzing
sql_endpoint = "https://yhkqe6os38yw017b.us-east-1.aws.endpoints.huggingface.cloud"

headers = {
	"Authorization": "Bearer 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",
	"Content-Type": "application/json"
}

In [28]:
# Setup of llama2-13b-chat endpoint for quizzing
llama_endpoint = "interpreter-endpoint"

In [15]:
# Helper functions

def get_ddl(table_name):
    """ Get the DDL (table definition) of the Athena table to feed into the
        Text-to-SQL prompt
    """
    queryStart = athena.start_query_execution(
        QueryString = "SHOW CREATE TABLE {}".format(table_name),
        QueryExecutionContext = {
            'Database': 'capstone_v3'
        },
        ResultConfiguration = { 'OutputLocation': 's3://aws-athena-query-results-820103179345-us-east-1/text2sql/'}
    )
    # Give enough time for the SQL query to execute
    time.sleep(5)
    queryExecution = athena.get_query_execution(QueryExecutionId=queryStart['QueryExecutionId'])
    results = athena.get_query_results(QueryExecutionId=queryStart['QueryExecutionId'])
        
    ddl = ""
    for i in range(len(results['ResultSet']['Rows'])):
        ddl += results['ResultSet']['Rows'][i]['Data'][0]['VarCharValue']
    ddl = ddl.split(')')[0]
    
    return ddl


def ask_gpt(question):
    """ Ask GPT-4 a question
    """
    start = time.time()
    response = client.chat.completions.create(
        model='gpt-4',
        messages=[
            {
                "role": "system",
                "content": "Given the following SQL tables, your job is to write queries given a user's request. Only return SQL queries- no extra text \n {} {};".format(get_ddl('predictions'), get_ddl('telemetry_extended_v3'))
            },
            {
                "role": "user",
                "content": question,
            }
        ],
        temperature=0.1,
        max_tokens=1000,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0
        )

    output_query = response.choices[0].message.content
    gpt_text2sql_time = time.time() - start
    print("gpt time: ", str(gpt_text2sql_time))
    
    print(output_query)
    
    context = "Here is the context: We are looking at industrial factory machine data. Speed differences of 232 and over are worse as you get higher. You will be provided a string representation of a dataframe with SQL query results. Your job is to summarize thes results and give recommendations for how to address these issues. Here is the corresponding SQL query: " + output_query

    athena_result = ""
    
    start = time.time()
    try:
        df = wr.athena.read_sql_query(
            sql=output_query,
            database='capstone_v3',
            ctas_approach=True)
        athena_result = str(df.head(15))
    except:
        athena_result = "The query outputted by sqlcoder could not be processed by Athena!"
    
    athena_time = time.time() - start
    print("athena_time: ", str(athena_time))
    
    start = time.time()
    response = client.chat.completions.create(
        model='gpt-4',
        messages=[
            {
                "role": "system",
                "content": context
            },
            {
                "role": "user",
                "content": "Evaluate the following output: {}. If the query was successful, interpret the dataframe. If the query failed, communicate that to the user:".format(athena_result)
            }
        ],
        temperature=0.1,
        max_tokens=1000,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0)
    interpretation_time = time.time() - start
    print("interpretation_time: ", str(interpretation_time))

    return response.choices[0].message.content

def query(payload):
	response = requests.post(sql_endpoint, headers=headers, json=payload)
	return response.json()

def ask_open_source(question):
    
    prompt = """### Task
    Generate a SQL query to answer the following question:
    `{question}`

    ### Database Schema
    This query will run on a database whose schema is represented in this string:
    `{pred_ddl}`;

    `{telemetry_ddl}`;

    -- the prediction table's columns each represent the predictions of a vehicle's risk scores
    -- the telemetry table contains a timeseries of machines along with metrics collected for each record
    ;

    ### SQL
    Given the database schema, here is the SQL query that answers `{question}`:
    ```sql
    """.format(question=question, pred_ddl=get_ddl('predictions'), telemetry_ddl=get_ddl('telemetry_extended_v3'))
    
    output = query({
        "inputs": prompt,
        "parameters": {'max_new_tokens': 500, "top_p": 0.1, "temperature": 0.2}
    })
    
    output_query = output[0]['generated_text']
    output_query = output_query.replace('\n','')
    output_query = output_query.replace('```','')
    
    print(output_query)
    
    try:
        df = wr.athena.read_sql_query(
            sql=question1_open,
            database='capstone_v3',
            ctas_approach=True)
        athena_result = str(df)
    except:
        athena_result = "The query outputted by sqlcoder could not be processed by Athena!"
    
    input_prompt = f"""[INST] <<SYS>>
    Your job is to interpret the results from the following query results: {athena_result}. If the query failed, just say that and nothing else. If the query didn't fail, provide recommendations on how to fix these issues.
    <</SYS>>{question} [/INST]"""
    
    payload = {
        "inputs": input_prompt,
        "parameters": {"max_new_tokens": 500, "top_p": 0.1, "temperature": 0.5}
    }
    
    response = runtime.invoke_endpoint(
        EndpointName=llama_endpoint,
        Body=json.dumps(payload),
        ContentType='application/json')
    
    result = json.loads(response['Body'].read().decode())
    
    return result[0]['generated_text'].split('[/INST]')[1]

### Question 1: Give me the machines that had the highest pressure

In [16]:
question1 = "Give me the machines that had the highest pressure"

In [17]:
question1_gpt = ask_gpt(question1)
print(question1_gpt)

gpt time:  11.452497243881226
SELECT machineid, MAX(pressure) as max_pressure
FROM telemetry_extended_v3
GROUP BY machineid
ORDER BY max_pressure DESC;
athena_time:  4.024524688720703
interpretation_time:  7.920696258544922
The SQL query was successful and it returned a dataframe with two columns: 'machineid' and 'max_pressure'. The dataframe lists the maximum pressure recorded by each machine in descending order. 

The machine with the highest recorded pressure is M_0017 with a pressure of 2694.20. The machine with the lowest maximum pressure in this list is M_0008 with a pressure of 2591.06. 

If high pressure is a concern for the operation of these machines, it would be recommended to investigate the machines with the highest pressures, starting with M_0017. It would be beneficial to understand why these machines are experiencing such high pressures and if there are any common factors among them. 

If the speed differences of 232 and over are worse as you get higher, it would be imp

In [78]:
question1_open = ask_open_source(question1)
print(question1_open)

 SELECT machineid, MAX(pressure) AS max_pressure FROM telemetry_extended_v3 WHERE timestamp BETWEEN '2020-01-01' AND '2020-12-31' GROUP BY machineid ORDER BY max_pressure DESC NULLS LAST;    
  The query outputted by sqlcoder could not be processed by Athena!

The error message from Athena was:

"Invalid SQL: Unknown column 'pressure' in 'where clause'"

This error message indicates that the column 'pressure' does not exist in the table or view being queried.

To fix this issue, you can try the following:

1. Check the table or view definition to ensure that the column 'pressure' exists and is of the correct data type.
2. Verify that the column name is spelled correctly and is not case-sensitive.
3. If the column does not exist, you may need to add it to the table or view definition.

Once you have resolved the issue, you can try running the query again to retrieve the machines with the highest pressure.


### Question 2: Which machines had the highest temperature and when?

In [79]:
question2 = "Which machines had the highest temperature?"

In [80]:
print(ask_gpt(question2))

SELECT machineid, MAX(temperature) as max_temperature
FROM telemetry_extended_v3
GROUP BY machineid
ORDER BY max_temperature DESC;
The query was successful. The dataframe shows the maximum temperature recorded for each machine in the factory. The machines are ordered by their maximum temperature in descending order. 

The machine with the highest recorded temperature is M_0002 and M_0005, both reaching a maximum temperature of 215.00. The machine with the lowest maximum temperature is M_0018 with a temperature of 214.91.

Since the speed differences of 232 and over are worse as you get higher, none of the machines have reached this critical level yet. However, it's important to monitor the machines, especially M_0002 and M_0005, as they have the highest recorded temperatures. 

Recommendations:
1. Regularly monitor the temperature of all machines, especially those with higher maximum temperatures.
3. Regular maintenance and checks should be performed to ensure machines are working effi

In [81]:
print(ask_open_source(question2))

 SELECT machineid, MAX(temperature) AS max_temperature FROM telemetry_extended_v3 WHERE timestamp >= CURRENT_DATE - interval '7 days' GROUP BY machineid ORDER BY max_temperature DESC NULLS LAST;    
  The query outputted by sqlcoder could not be processed by Athena!

The error message suggests that there was a problem with the query itself, rather than an issue with the data or the machine. Without more information, it's difficult to provide specific recommendations on how to fix the issue. However, here are a few general tips that may help:

1. Check the syntax of the query: Make sure that the query is written correctly and that all of the necessary syntax is included.
2. Check the data types of the columns: Ensure that the data types of the columns in the query match the data types of the columns in the table.
3. Check for missing or invalid data: Make sure that all of the required data is present and that it is in the correct format.
4. Check for duplicate or redundant data: Remove 

In [51]:
def ask_7b(llama_endpoint, athena_client, sagemaker_client, question):
    
    prompt = """### Task
    Generate a SQL query to answer the following question:
    `{question}`

    ### Database Schema
    This query will run on a database whose schema is represented in this string:
    `{pred_ddl}`;

    `{telemetry_ddl}`;

    -- the prediction table's columns each represent the predictions of a vehicle's risk scores
    -- the telemetry table contains a timeseries of machines along with metrics collected for each record
    ;

    ### SQL
    Given the database schema, here is the SQL query that answers `{question}`:
    ```sql
    """.format(question=question, pred_ddl=get_ddl('predictions'), telemetry_ddl=get_ddl('telemetry_extended_v3'))
    
    start = time.time()
    output = query({
        "inputs": prompt,
        "parameters": {'max_new_tokens': 500, "top_p": 0.1, "temperature": 0.2}
    })
    
    output_query = output[0]['generated_text']
    output_query = output_query.replace('\n','')
    output_query = output_query.replace('```','')
    output_query = output_query.replace("' '","")
    
    text2sql_time = time.time() - start
    
    start = time.time()
    
    query_success = 0
    
    try:
        df = wr.athena.read_sql_query(
            sql=output_query,
            database='capstone_v3',
            ctas_approach=True)
        athena_result = str(df.head(15))
        query_success = 1
    except:
        athena_result = "The query could not be processed by Athena!"
        
        
        
    query_time = time.time() - start
    

    
    input_prompt = f"""[INST] <<SYS>>
    Your job is to interpret the results from the following query results: {athena_result}. If the query failed, just say that and nothing else. If the query didn't fail, provide recommendations on how to fix these issues.
    <</SYS>>{question} [/INST]"""
    
    start = time.time()
    payload = {
        "inputs": input_prompt,
        "parameters": {"max_new_tokens": 500, "top_p": 0.1, "temperature": 0.5}
    }
    inter_time = time.time() - start
    
    response = sagemaker_client.invoke_endpoint(
        EndpointName=llama_endpoint,
        Body=json.dumps(payload),
        ContentType='application/json')
    
    result = json.loads(response['Body'].read().decode())
    result = result[0]['generated_text'].split('[/INST]')[1]
    
    return {
        "LLM": "sqlcoder-7b",
        "question": question,
        "query": output_query,
        "interpretation": result,
        "text2sql_time": np.round(text2sql_time,2),
        "query_time": np.round(query_time,2),
        "query_success": query_success,
        "interpretation_time": np.round(inter_time,2)
    }

In [52]:
question1 = "what devices have the highest temperatures?"

In [53]:
print(ask_7b(llama_endpoint, athena, runtime, question1))

{'LLM': 'sqlcoder-7b', 'question': 'what devices have the highest temperatures?', 'query': ' SELECT machineid, MAX(temperature) AS max_temperature FROM telemetry_extended_v3 GROUP BY machineid ORDER BY max_temperature DESC NULLS LAST;    ', 'interpretation': "  Based on the query results you provided, the devices with the highest temperatures are:\n\n1. M_0005 with a maximum temperature of 215.00\n2. M_0024 with a maximum temperature of 214.99\n3. M_0012 with a maximum temperature of 214.99\n4. M_0015 with a maximum temperature of 214.99\n5. M_0000 with a maximum temperature of 214.99\n\nAll of these devices have a maximum temperature of 214.99 or higher, indicating that they are operating at high temperatures.\n\nTo fix these issues, you may want to consider the following recommendations:\n\n1. Check the cooling system of these devices to ensure that it is functioning properly and that there are no blockages or malfunctions that could be causing the high temperatures.\n2. Consider upg

In [75]:
def ask_34b(llama_endpoint, athena_client, sagemaker_client, question):
    
    prompt = """### Task
    Generate a SQL query to answer the following question:
    `{question}`

    ### Database Schema
    This query will run on a database whose schema is represented in this string:
    `{pred_ddl}`;

    `{telemetry_ddl}`;

    -- the prediction table's columns each represent the predictions of a vehicle's risk scores
    -- the telemetry table contains a timeseries of machines along with metrics collected for each record
    ;

    ### SQL
    Given the database schema, here is the SQL query that answers `{question}`:
    ```sql
    """.format(question=question, pred_ddl=get_ddl('predictions'), telemetry_ddl=get_ddl('telemetry_extended_v3'))
    
    start = time.time()
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": 500, "top_p": 0.1, "temperature": 0.5}
    }
    
    response = sagemaker_client.invoke_endpoint(
        EndpointName="text2sql-34b",
        Body=json.dumps(payload),
        ContentType='application/json')
    
    result = json.loads(response['Body'].read().decode())
    output_query = result[0]['generated_text'].split('```sql')[1].split('```')[0]
    
    text2sql_time = time.time() - start
    
    start = time.time()
    
    query_success = 0
    
    try:
        df = wr.athena.read_sql_query(
            sql=output_query,
            database='capstone_v3',
            ctas_approach=True)
        athena_result = str(df.head(15))
        query_success = 1
    except:
        athena_result = "The query could not be processed by Athena!"
        
        
        
    query_time = time.time() - start
    

    
    input_prompt = f"""[INST] <<SYS>>
    Your job is to interpret the results from the following query results: {athena_result}. If the query failed, just say that and nothing else. If the query didn't fail, provide recommendations on how to fix these issues.
    <</SYS>>{question} [/INST]"""
    

    payload = {
        "inputs": input_prompt,
        "parameters": {"max_new_tokens": 500, "top_p": 0.1, "temperature": 0.5}
    }
    
    start = time.time()
    response = sagemaker_client.invoke_endpoint(
        EndpointName=llama_endpoint,
        Body=json.dumps(payload),
        ContentType='application/json')
    inter_time = time.time() - start
    
    result = json.loads(response['Body'].read().decode())
    result = result[0]['generated_text'].split('[/INST]')[1]
    
    return {
        "LLM": "sqlcoder-34b",
        "question": question,
        "query": output_query,
        "interpretation": result,
        "text2sql_time": np.round(text2sql_time,2),
        "query_time": np.round(query_time,2),
        "query_success": query_success,
        "interpretation_time": np.round(inter_time,2)
    }

In [76]:
question1 = "what devices have the highest temperatures?"
print(ask_34b(llama_endpoint, athena, runtime, question1))

{'LLM': 'sqlcoder-34b', 'question': 'what devices have the highest temperatures?', 'query': '\n     SELECT machineid, MAX(temperature) AS max_temperature FROM telemetry_extended_v3 GROUP BY machineid ORDER BY max_temperature DESC;\n    ', 'interpretation': "  Based on the query results you provided, the devices with the highest temperatures are:\n\n1. M_0005 with a maximum temperature of 215.00\n2. M_0002 with a maximum temperature of 215.00\n3. M_0024 with a maximum temperature of 214.99\n\nAll of these devices have a maximum temperature of 215.00, which is the highest temperature recorded in the query results.\n\nThere are no obvious issues with the query results, but it's worth noting that the temperatures are all very close to each other, which may indicate that the temperature readings are not very precise. Additionally, there are several devices with maximum temperatures of 214.99, which is just below the highest temperature recorded. This may indicate that there is some variabil