In [68]:
import os
import re
import time

import awswrangler as wr
import boto3
import requests

# Shared Athena client used by every query cell below.
client = boto3.client(service_name="athena")

In [28]:
# Hugging Face Inference Endpoint that serves the model.
API_URL = "https://yhkqe6os38yw017b.us-east-1.aws.endpoints.huggingface.cloud"

# The bearer token is read from the environment so it never lands in the
# notebook file or its history. Export HF_TOKEN before starting the kernel.
# NOTE(review): a token was previously committed in this cell; treat it as
# leaked and revoke it.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if not HF_TOKEN:
    print("HF_TOKEN is not set; the endpoint will reject unauthenticated requests.")

headers = {
    "Authorization": "Bearer " + HF_TOKEN,
    "Content-Type": "application/json",
}

def query(payload, timeout=60):
    """POST ``payload`` as JSON to the inference endpoint and return the decoded reply.

    Args:
        payload: JSON-serialisable dict sent as the request body (the cells
            below pass ``"inputs"`` and optionally ``"parameters"``).
        timeout: Seconds to wait for the endpoint before giving up, so a
            stalled endpoint cannot hang the kernel indefinitely.

    Returns:
        The response body parsed from JSON.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    # Fail here rather than handing an error body to callers that index the
    # result as a list of generations.
    response.raise_for_status()
    return response.json()
	
# Smoke-test the endpoint with a short free-text prompt.
output = query({"inputs": "Can you please let us know more details about your "})

In [29]:
# Display the raw response returned for the smoke-test prompt.
output

[{'generated_text': '1970 Chevro Chevelle SS 454.\n\n## 1'}]

Before we can start experimenting with text-to-SQL, I'm going to implement a way to get the Athena table's DDL so that I don't have to manually enter it in the prompt.

In [57]:
def get_ddl(
    table_name,
    database="capstone_v3",
    output_location="s3://aws-athena-query-results-820103179345-us-east-1/text2sql/",
    poll_interval=1.0,
    timeout=60.0,
    athena_client=None,
):
    """Return the ``CREATE ... TABLE name(columns)`` head of an Athena table's DDL.

    Runs ``SHOW CREATE TABLE`` on Athena, waits for the query to finish,
    concatenates the returned rows and keeps everything up to and including
    the parenthesis that closes the column list. Storage clauses that follow
    the column list are dropped.

    Args:
        table_name: Table to describe, optionally qualified as ``db.table``.
        database: Athena database used as the query context.
        output_location: S3 prefix where Athena writes the query results.
        poll_interval: Seconds to sleep between status checks.
        timeout: Maximum seconds to wait for the query to finish.
        athena_client: Athena client to use; defaults to the notebook-level
            ``client``.

    Returns:
        The truncated DDL as a single string.

    Raises:
        ValueError: If ``table_name`` contains characters outside
            letters, digits, ``_``, ``.``, ``-`` and backticks.
        RuntimeError: If Athena reports the query as FAILED or CANCELLED.
        TimeoutError: If the query does not finish within ``timeout``.
    """
    # The name is interpolated into SQL, so reject anything that could carry
    # extra statements or whitespace.
    if not re.fullmatch(r"[\w.`-]+", table_name):
        raise ValueError("Invalid Athena table name: {!r}".format(table_name))

    athena = client if athena_client is None else athena_client

    started = athena.start_query_execution(
        QueryString="SHOW CREATE TABLE {}".format(table_name),
        QueryExecutionContext={"Database": database},
        ResultConfiguration={"OutputLocation": output_location},
    )
    execution_id = started["QueryExecutionId"]
    _wait_for_athena_query(athena, execution_id, poll_interval, timeout)

    # Collect the first cell of every row, following NextToken across pages.
    ddl_lines = []
    page_args = {"QueryExecutionId": execution_id}
    while True:
        page = athena.get_query_results(**page_args)
        for row in page["ResultSet"]["Rows"]:
            cells = row.get("Data") or [{}]
            # A cell without a value has no 'VarCharValue' key.
            ddl_lines.append(cells[0].get("VarCharValue", ""))
        next_token = page.get("NextToken")
        if not next_token:
            break
        page_args["NextToken"] = next_token

    return _truncate_after_column_list("".join(ddl_lines))


def _wait_for_athena_query(athena, execution_id, poll_interval, timeout):
    """Block until the Athena query succeeds; raise if it fails or times out."""
    deadline = time.monotonic() + timeout
    while True:
        execution = athena.get_query_execution(QueryExecutionId=execution_id)
        status = execution["QueryExecution"]["Status"]
        state = status["State"]
        if state == "SUCCEEDED":
            return
        if state in ("FAILED", "CANCELLED"):
            raise RuntimeError(
                "Athena query {} ended in state {}: {}".format(
                    execution_id, state, status.get("StateChangeReason", "no reason given")
                )
            )
        if time.monotonic() >= deadline:
            raise TimeoutError(
                "Athena query {} did not finish within {} seconds".format(execution_id, timeout)
            )
        time.sleep(poll_interval)


def _truncate_after_column_list(ddl_text):
    """Cut ``ddl_text`` right after the ')' that closes its first parenthesised group."""
    depth = 0
    open_quote = None
    for position, char in enumerate(ddl_text):
        if open_quote:
            # Parentheses inside quoted identifiers or string literals do not count.
            if char == open_quote:
                open_quote = None
        elif char in "`'\"":
            open_quote = char
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth <= 0:
                return ddl_text[: position + 1]
    # No closing parenthesis found: return the text unchanged.
    return ddl_text

In [58]:
# Column list of the predictions table, fetched from Athena.
pred_ddl = get_ddl(table_name="predictions")

In [60]:
# Column list of the telemetry table, fetched from Athena.
telemetry_ddl = get_ddl(table_name="telemetry_extended_v3")

In [61]:
# Natural-language question that the text-to-SQL prompt asks the model to answer.
question = "Which machines had the highest pressure and when?"

In [62]:
# Text-to-SQL prompt: the task, both table DDLs with explanatory notes, and an
# opening ```sql fence that the model is expected to continue.
prompt = f"""### Task
Generate a SQL query to answer the following question:
`{question}`

### Database Schema
This query will run on a database whose schema is represented in this string:
`{pred_ddl}`;

`{telemetry_ddl}`;

-- the prediction table's columns each represent the predictions of a vehicle's risk scores
-- the telemetry table contains a timeseries of machines along with metrics collected for each record
;

### SQL
Given the database schema, here is the SQL query that answers `{question}`:
```sql
"""

In [63]:
# Print the assembled prompt to check that both DDLs were substituted in.
print(prompt)

### Task
Generate a SQL query to answer the following question:
`Which machines had the highest pressure and when?`

### Database Schema
This query will run on a database whose schema is represented in this string:
`CREATE EXTERNAL TABLE `predictions`(  `m_0000` double,   `m_0001` double,   `m_0002` double,   `m_0003` double,   `m_0004` double,   `m_0005` double,   `m_0006` double,   `m_0007` double,   `m_0008` double,   `m_0009` double,   `m_0010` double,   `m_0011` double,   `m_0012` double,   `m_0013` double,   `m_0014` double,   `m_0015` double,   `m_0016` double,   `m_0017` double,   `m_0018` double,   `m_0019` double,   `m_0020` double,   `m_0021` double,   `m_0022` double,   `m_0023` double,   `m_0024` double`;

`CREATE EXTERNAL TABLE `telemetry_extended_v3`(  `timestamp` timestamp,   `speed` double,   `temperature` double,   `pressure` double,   `machineid` string,   `speed_difference` double`;

-- the prediction table's columns each represent the predictions of a vehicle's ris

In [64]:
# Send the text-to-SQL prompt; low top_p and temperature keep sampling narrow.
output = query(
    {
        "inputs": prompt,
        "parameters": {"max_new_tokens": 500, "top_p": 0.1, "temperature": 0.1},
    }
)

In [65]:
# The endpoint returns a list of generations; keep the text of the first one.
output_query = output[0]["generated_text"]

In [76]:
# Drop the Markdown code fence and flatten the statement onto one line.
# Newlines become spaces: deleting them outright would fuse tokens across
# line breaks (e.g. "machineid\nFROM" -> "machineidFROM") and break the SQL.
output_query = output_query.replace('```', '').replace('\n', ' ').strip()
output_query

'SELECT machineid, MAX(pressure) AS max_pressure, MAX(timestamp) AS max_timestamp FROM telemetry_extended_v3 GROUP BY machineid ORDER BY max_pressure DESC NULLS LAST;'

In [77]:
# Run the model-generated SQL on Athena and fetch the result set via boto3.
queryStart = client.start_query_execution(
    QueryString = output_query,
    QueryExecutionContext = {
        'Database': 'capstone_v3'
    },
    ResultConfiguration = { 'OutputLocation': 's3://aws-athena-query-results-820103179345-us-east-1/text2sql/'}
)

# Poll until Athena reports a terminal state instead of sleeping a fixed five
# seconds: results are only readable once the query has succeeded, and a
# failed query should stop the notebook here with its reason.
pollDeadline = time.monotonic() + 120
while True:
    queryExecution = client.get_query_execution(QueryExecutionId=queryStart['QueryExecutionId'])
    queryStatus = queryExecution['QueryExecution']['Status']
    if queryStatus['State'] in ('SUCCEEDED', 'FAILED', 'CANCELLED'):
        break
    if time.monotonic() > pollDeadline:
        raise TimeoutError('Athena query {} did not finish in time'.format(queryStart['QueryExecutionId']))
    time.sleep(1)

if queryStatus['State'] != 'SUCCEEDED':
    raise RuntimeError('Athena query {} ended in state {}: {}'.format(
        queryStart['QueryExecutionId'],
        queryStatus['State'],
        queryStatus.get('StateChangeReason', 'no reason given'),
    ))

results = client.get_query_results(QueryExecutionId=queryStart['QueryExecutionId'])

In [79]:
# Load the result of the model-generated SQL (executed as-is) into a DataFrame.
df = wr.athena.read_sql_query(output_query, database="capstone_v3")

In [82]:
# Keep only the first 15 rows so the frame fits in the follow-up prompt.
df = df.iloc[:15]

In [86]:
# Re-assigns the same question string used for the text-to-SQL prompt earlier in the notebook.
question = "Which machines had the highest pressure and when?"

As the output below shows, the SQL fine-tuned LLM is poor at interpreting query results. This is because it is a base open-source LLM fine-tuned on a dataset of natural-language questions paired with SQL queries, so its chat capabilities are nowhere near GPT-4's.

In [99]:
# Interpretation prompt: a fixed instruction followed by the frame's text rendering.
prompt = f"Interpret the following SQL query results: {df!s}"

In [100]:
# Send the interpretation prompt with the same sampling settings as the SQL request.
output = query(
    {
        "inputs": prompt,
        "parameters": {"max_new_tokens": 500, "top_p": 0.1, "temperature": 0.1},
    }
)

In [101]:
# Display the raw response to the interpretation prompt.
output

[{'generated_text': '\n15    M_0009       2587.57 2023-10-31 18:25:59\n16    M_0012       2585.05 2023-10-31 18:30:59\n17    M_0013       2584.05 2023-10-31 18:31:59\n18    M_0014       2583.05 2023-10-31 18:32:59\n19    M_0016       2582.05 2023-10-31 18:33:59\n20    M_0019       2581.05 2023-10-31 18:34:59\n21    M_0022       2579.05 2023-10-31 18:36:59\n22    M_0024       2578.05 2023-10-31 18:37:59\n23    M_0025       2577.05 2023-10-31 18:38:59\n24    M_0026       2576.05 2023-10-31 18:39:59\n25    M_0027       2575.05 2023-10-31 18:40:59\n26    M_0028       2574.05 2023-10-31 18:41:59\n27    M_0029       2573.05 2023-10-31 18:42:59\n28    M_'}]