## Set up your endpoint name

Please either copy your own endpoint name or follow the instructions provided by the workshop instructor.

In [172]:
# Name of the deployed endpoint that the queries in this notebook are sent to.
# Replace the value with your own endpoint name if you have one.
ENDPOINT_NAME = "huggingface-pytorch-tgi-inference-2023-07-16-05-23-44-657"

In [185]:
import boto3
import json

def query_endpoint_and_parse_response(payload_dict, endpoint_name, client=None):
    """Send a JSON payload to a SageMaker endpoint and return the generated text.

    Args:
        payload_dict: JSON-serializable request body. The cells in this
            notebook pass ``{"inputs": <prompt>, "parameters": <dict>}``.
        endpoint_name: Name of the endpoint to invoke.
        client: Optional object exposing ``invoke_endpoint`` (for example a
            ``boto3.client("sagemaker-runtime")`` instance). Passing one lets
            callers reuse a single client across requests; when omitted, a
            new boto3 client is created for this call.

    Returns:
        The ``generated_text`` value of the first element of the decoded
        JSON response.
    """
    encoded_json = json.dumps(payload_dict).encode("utf-8")

    if client is None:
        # "sagemaker-runtime" is the canonical boto3 service name for the
        # runtime API (the older "runtime.sagemaker" spelling is an alias).
        client = boto3.client("sagemaker-runtime")

    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=encoded_json,
    )

    # The response body is a stream; read it once and decode it as UTF-8 JSON.
    decoded = json.loads(response["Body"].read().decode("utf-8"))
    return decoded[0]["generated_text"]



## Set up model parameters


In [188]:
# Generation settings sent as the "parameters" field of every request below.
parameters = dict(
    max_new_tokens=200,
    top_k=5,
    top_p=0.15,
    do_sample=True,
    temperature=0.01,
)


## Prompt with plain-language (layman's) input

In [189]:
# Plain-language request that only lists the table's field names.
# If you'd like to try your own prompt, edit this parameter!
prompt_data = (
    "\n"
    "I have a table called patient with fields ID, AGE, WEIGHT, HEIGHT. \n"
    "Write me a SQL Query which will return the entry with the highest age\n"
    "\n"
)

In [190]:
# Bundle the prompt with the generation settings, then query the endpoint.
payload = dict(inputs=prompt_data, parameters=parameters)
generated_texts = query_endpoint_and_parse_response(
    payload_dict=payload,
    endpoint_name=ENDPOINT_NAME,
)

In [191]:
# Show the text returned by the endpoint for the prompt above.
print("Result:", generated_texts)

Result: SELECT ID FROM patient WHERE AGE > (SELECT max(AGE) FROM patient)


## Prompt with Table Schema

In [213]:
import json
import boto3
# Runtime client for SageMaker endpoints.
# NOTE(review): no cell visible here reads `sagemaker_client` — confirm whether
# a later cell uses it before removing it.
sagemaker_client = boto3.client(service_name="sagemaker-runtime")
# Prompt that embeds the table DDL so the model can only pick column names that
# appear in the schema. The text ends with "SQLQuery:" for the model to complete.
# NOTE(review): the prompt says MySQL, but the DDL uses EXTERNAL TABLE / STRING,
# which look like Hive/Athena syntax — confirm the intended SQL dialect.
payload = """You are an expert of MySQL Database. Your task is to generate a SQL query

Pay attention to use only the column names that you can see in the schema description. 
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Your Table sales schema as follows:

CREATE EXTERNAL TABLE sales (
	transaction_date DATE COMMENT 'Transaction date',
	user_id STRING COMMENT 'The user who make the purchase',
	product STRING COMMENT 'product name, e.g "Fruits", "Ice cream", "Milk"',
	price DOUBLE COMMENT 'The price of the product'
)

Question: What is total sale amount of Fruits
SQLQuery:

"""


In [214]:
# Keep `payload` bound to the prompt string and build the request under its own
# name, so this cell can be re-run safely. Rebinding `payload` to the request
# dict here would nest the previous request inside "inputs" on a second run.
request_body = {"inputs": payload, "parameters": parameters}
generated_texts = query_endpoint_and_parse_response(request_body, ENDPOINT_NAME)

In [215]:
# Show the text returned by the endpoint for the schema-based prompt.
print("Result:", generated_texts)

Result: SELECT sum(price) FROM sales WHERE product = 'Fruits'


Another example: the same `sales` schema, this time with a question that filters by product and by date.

In [222]:
# Same schema-grounded prompt as before, with a question that also filters on a
# date. The text ends with "SQLQuery:" for the model to complete.
# NOTE(review): the prompt says MySQL, but the DDL uses EXTERNAL TABLE / STRING,
# which look like Hive/Athena syntax — confirm the intended SQL dialect.
payload = """
You are an expert of MySQL Database. Your task is to generate a SQL query

Pay attention to use only the column names that you can see in the schema description. 
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Your Table sales schema as follows:

CREATE EXTERNAL TABLE sales (
	transaction_date DATE COMMENT 'Transaction date',
	user_id STRING COMMENT 'The user who make the purchase',
	product STRING COMMENT 'product name, e.g "Fruits", "Ice cream", "Milk"',
	price DOUBLE COMMENT 'The price of the product'
)

Question: What is total sales of "Fruits" on 2022-10-05
SQLQuery:
"""

In [223]:
# Keep `payload` bound to the prompt string and build the request under its own
# name, so this cell can be re-run safely. Rebinding `payload` to the request
# dict here would nest the previous request inside "inputs" on a second run.
request_body = {"inputs": payload, "parameters": parameters}
generated_texts = query_endpoint_and_parse_response(request_body, ENDPOINT_NAME)

In [224]:
# Show the text returned by the endpoint for the date-filtered question.
print("Result:", generated_texts)

Result: SELECT total_sales FROM sales WHERE product = "Fruits" AND transaction_date = "202210-05"
