## Set up a cross-account profile (Experiments)

As in the preceding stages, you can test cross-account access by invoking a SageMaker endpoint that belongs to a different AWS account.

However, configuring your SageMaker client to call an endpoint in another AWS account has limitations and security risks: it raises permission and access-control concerns, and cross-account communication can expose security vulnerabilities. Exercise caution and limit this setup to testing.

![SageMaker terminal](sagemaker_terminal.png)

Follow the instructions provided by the workshop instructor.

To call the endpoint from another AWS account, create an AWS credentials profile for that account. For example, set up a profile named `cross_account_endpoint`:

```commandline
sh-4.2$ aws configure --profile cross_account_endpoint
AWS Access Key ID [None]: [Your Access ID]
AWS Secret Access Key [None]: [Your Secret Access Key]
Default region name [None]: us-east-1
Default output format [None]: json
```

Confirm that your profile has been set up successfully.

```python
!aws configure get region --profile cross_account_endpoint
```

This should print `us-east-1`.
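`aws configure --profile` stores the region and output format in `~/.aws/config` under a `[profile cross_account_endpoint]` section, and the access keys in `~/.aws/credentials` under `[cross_account_endpoint]`. If you prefer to check the profile from Python, a minimal sketch with the standard-library `configparser` is below. The helper name `profile_region` is hypothetical, and the sample text stands in for your real config file.

```python
import configparser
from typing import Optional


def profile_region(config_text: str, profile: str) -> Optional[str]:
    """Return the region configured for `profile` in AWS config-file text.

    In ~/.aws/config, named profiles use a "[profile NAME]" header,
    while the default profile uses a plain "[default]" header.
    """
    parser = configparser.ConfigParser()
    parser.read_string(config_text)
    section = "default" if profile == "default" else f"profile {profile}"
    if not parser.has_section(section):
        return None
    return parser.get(section, "region", fallback=None)


# Sample contents in the format `aws configure --profile cross_account_endpoint` writes.
SAMPLE_CONFIG = """
[profile cross_account_endpoint]
region = us-east-1
output = json
"""

print(profile_region(SAMPLE_CONFIG, "cross_account_endpoint"))  # us-east-1

# To check your real file instead (assumes the default file location):
# import os
# with open(os.path.expanduser("~/.aws/config")) as f:
#     print(profile_region(f.read(), "cross_account_endpoint"))
```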

## Choose your endpoint name

Copy either the name of your own endpoint or the name of the endpoint in the cross account into `ENDPOINT_NAME` below.

### Create a SageMaker client

To use your current AWS credentials:

```python
import boto3

ENDPOINT_NAME = '[SageMaker endpoint you just deployed]'
# "sagemaker-runtime" is the SageMaker Runtime service, which provides invoke_endpoint.
client = boto3.client("sagemaker-runtime")
```

Alternatively, use the credentials of another profile, such as `cross_account_endpoint`.

Run the following cell only if you want to try this; it replaces the `ENDPOINT_NAME` and `client` defined above.

```python
import boto3

ENDPOINT_NAME = '[SageMaker endpoint in the cross account]'
# Build the client from the named profile's credentials and region.
session = boto3.Session(profile_name='cross_account_endpoint')
client = session.client("sagemaker-runtime")
```

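```python
import json

def query_endpoint_and_parse_response(payload_dict, endpoint_name):
    # Serialize the request payload as UTF-8 JSON.
    encoded_json = json.dumps(payload_dict).encode("utf-8")

    # Invoke the endpoint with the client created above.
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/json", Body=encoded_json
    )

    # The body is a stream; this assumes a JSON list like [{"generated_text": "..."}].
    return json.loads(response['Body'].read().decode())[0]['generated_text']
```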
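`invoke_endpoint` returns the response body as a stream under the `Body` key; the function reads it, decodes the JSON, and takes `[0]['generated_text']`. That indexing assumes the response is a JSON list of objects with a `generated_text` field, as the Hugging Face text-generation containers return. To see the round trip without calling an endpoint, the sketch below fakes the response with `io.BytesIO`. The helper name `parse_generated_text` and the fake text are made up for illustration.

```python
import io
import json


def parse_generated_text(response: dict) -> str:
    # Same parsing as query_endpoint_and_parse_response: read the streaming
    # body, decode the JSON list, and take the first result's text.
    return json.loads(response["Body"].read().decode())[0]["generated_text"]


# The request body the notebook sends: the payload dict serialized as UTF-8 JSON.
payload = {"inputs": "Write a SQL query", "parameters": {"max_new_tokens": 200}}
request_body = json.dumps(payload).encode("utf-8")

# A hand-built stand-in for the invoke_endpoint response (not real model output).
fake_response = {
    "Body": io.BytesIO(json.dumps([{"generated_text": "SELECT 1;"}]).encode("utf-8"))
}

print(parse_generated_text(fake_response))  # SELECT 1;
```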



## Set up model parameters


The following parameters, defined by the `GenerationConfig` class, control text generation:

- `do_sample` (bool, optional, defaults to `False`): whether to use sampling; if `False`, greedy decoding is used.
- `temperature` (float, optional, defaults to 1.0): the value used to modulate the next-token probabilities; lower values make the output more deterministic.
- `max_new_tokens` (int, optional): the maximum number of tokens to generate, excluding the tokens in the prompt.
- `top_k` (int, optional, defaults to 50): the number of highest-probability vocabulary tokens to keep for top-k filtering.
- `top_p` (float, optional, defaults to 1.0): if set to a float less than 1, only the smallest set of most probable tokens whose probabilities add up to `top_p` or higher is kept for generation.

For a complete list of available parameters and their descriptions, refer to the GenerationConfig class documentation at https://huggingface.co/docs/transformers/v4.30.0/main_classes/text_generation.
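To make `temperature`, `top_k`, and `top_p` concrete, here is a toy sketch of how they can reshape a next-token distribution before sampling: divide the logits by the temperature, apply softmax, keep the `top_k` most probable tokens, keep the smallest prefix whose cumulative probability reaches `top_p`, then renormalize. It is a simplified pure-Python illustration with made-up logits, not the `transformers` implementation, which works on tensors.

```python
import math


def filter_distribution(logits, temperature=1.0, top_k=50, top_p=1.0):
    """Return {token: probability} after temperature, top-k and top-p filtering."""
    # Temperature: divide logits before softmax; values below 1 sharpen the distribution.
    scaled = {tok: value / temperature for tok, value in logits.items()}
    # Softmax (subtract the max for numerical stability), sorted most probable first.
    peak = max(scaled.values())
    exps = {tok: math.exp(value - peak) for tok, value in scaled.items()}
    total = sum(exps.values())
    probs = sorted(((tok, e / total) for tok, e in exps.items()),
                   key=lambda item: item[1], reverse=True)
    # Top-k: keep only the k most probable tokens.
    if top_k > 0:
        probs = probs[:top_k]
    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    if top_p < 1.0:
        kept, cumulative = [], 0.0
        for tok, p in probs:
            kept.append((tok, p))
            cumulative += p
            if cumulative >= top_p:
                break
        probs = kept
    # Renormalize the surviving tokens so they sum to 1.
    mass = sum(p for _, p in probs)
    return {tok: p / mass for tok, p in probs}


# Made-up logits for four candidate next tokens.
logits = {"SELECT": 2.0, "WITH": 1.5, "Here": 1.0, "The": 0.5}

print(filter_distribution(logits))  # all four tokens keep some probability
print(filter_distribution(logits, temperature=0.01, top_k=5, top_p=0.15))  # {'SELECT': 1.0}
```

With a very low temperature and a small `top_p`, essentially all probability mass lands on the single most likely token, so sampling behaves almost like greedy decoding.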

```python
parameters = {
    "max_new_tokens": 200,
    "top_k": 5,
    "top_p": 0.15,
    "do_sample": True,
    # A very low temperature with small top_k/top_p makes sampling nearly greedy.
    "temperature": 0.01,
}
```
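Before sending these values to the endpoint, a quick client-side check can catch typos and out-of-range settings. The helper below is hypothetical, and its ranges are conservative assumptions based on the parameter descriptions above; the serving container may enforce different limits, so check its documentation.

```python
def _is_int(value):
    # bool is a subclass of int in Python, so exclude it explicitly.
    return isinstance(value, int) and not isinstance(value, bool)


def check_generation_parameters(params):
    """Return a list of problems found in a generation-parameters dict (empty if none)."""
    problems = []
    max_new = params.get("max_new_tokens")
    if max_new is not None and not (_is_int(max_new) and max_new >= 1):
        problems.append("max_new_tokens should be a positive integer")
    top_k = params.get("top_k")
    if top_k is not None and not (_is_int(top_k) and top_k >= 0):
        problems.append("top_k should be a non-negative integer")
    top_p = params.get("top_p")
    if top_p is not None and not (0 < top_p <= 1):
        problems.append("top_p should be greater than 0 and at most 1")
    temperature = params.get("temperature")
    if temperature is not None and not temperature > 0:
        problems.append("temperature should be greater than 0")
    if "do_sample" in params and not isinstance(params["do_sample"], bool):
        problems.append("do_sample should be True or False")
    return problems


# Same values as the notebook's parameters cell.
parameters = {"max_new_tokens": 200, "top_k": 5, "top_p": 0.15, "do_sample": True, "temperature": 0.01}
print(check_generation_parameters(parameters))  # []
print(check_generation_parameters({"top_p": 1.5, "temperature": 0}))
```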


## Prompt with a plain-language description

```python
prompt_data = """
I have a table called patient with fields ID, AGE, WEIGHT, HEIGHT.
Write me a SQL query which will return the entry with the highest age

"""  # If you'd like to try your own prompt, edit this string!
```

```python
payload = {"inputs": prompt_data, "parameters": parameters}
generated_texts = query_endpoint_and_parse_response(payload, ENDPOINT_NAME)
```

```python
print(f"Result: {generated_texts}")
```
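Before trusting a generated query, you can run it against a tiny in-memory table. The sketch below uses Python's built-in `sqlite3` with made-up rows; the query shown is one reasonable answer to the prompt, not the model's actual output, and SQLite's dialect differs from other databases in places.

```python
import sqlite3

# In-memory table matching the prompt's description, with made-up rows.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE patient (ID INTEGER, AGE INTEGER, WEIGHT REAL, HEIGHT REAL)")
conn.executemany(
    "INSERT INTO patient VALUES (?, ?, ?, ?)",
    [(1, 34, 70.5, 175.0), (2, 81, 62.0, 160.0), (3, 57, 88.2, 182.0)],
)

# One plausible answer; paste the model's query here to test it instead.
candidate_sql = "SELECT * FROM patient ORDER BY AGE DESC LIMIT 1"
print(conn.execute(candidate_sql).fetchall())  # [(2, 81, 62.0, 160.0)]
```

`ORDER BY ... LIMIT 1` returns a single row even when several patients share the highest age; `WHERE AGE = (SELECT MAX(AGE) FROM patient)` returns all of them.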

## Prompt with a table schema

```python
payload = """You are an expert in Presto databases. Your task is to generate a SQL query.

Pay attention to use only the column names that you can see in the schema description.
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

The sales table schema is as follows:

CREATE EXTERNAL TABLE sales (
    transaction_date DATE COMMENT 'Transaction date',
    user_id STRING COMMENT 'The user who made the purchase',
    product STRING COMMENT 'Product name, e.g. "Fruits", "Ice cream", "Milk"',
    price DOUBLE COMMENT 'The price of the product'
)

Question: What is the total sales amount of Fruits?
SQLQuery:

"""
```


```python
payload = {"inputs": payload, "parameters": parameters}
generated_texts = query_endpoint_and_parse_response(payload, ENDPOINT_NAME)
```

```python
print(f"Result: {generated_texts}")
```
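The Hive-style DDL in the prompt (`CREATE EXTERNAL TABLE`, `STRING`, `COMMENT`) does not run in SQLite, so to sanity-check a generated aggregate, the sketch below recreates the `sales` table with SQLite types and a few made-up rows. The query is a plausible answer, not actual model output. Note that the `WHERE` filter only matches if it uses the exact stored value (`'Fruits'`), which is why listing example values in the schema comment helps.

```python
import sqlite3

# Made-up rows for the sales table described in the prompt (dates stored as ISO text).
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE sales (transaction_date TEXT, user_id TEXT, product TEXT, price REAL)"
)
conn.executemany(
    "INSERT INTO sales VALUES (?, ?, ?, ?)",
    [
        ("2023-06-01", "u1", "Fruits", 4.5),
        ("2023-06-01", "u2", "Milk", 2.0),
        ("2023-06-02", "u1", "Fruits", 3.0),
        ("2023-06-03", "u3", "Ice cream", 5.25),
    ],
)

# A plausible answer to the Fruits question; replace it with the model's query.
candidate_sql = "SELECT SUM(price) AS total_sales FROM sales WHERE product = 'Fruits'"
print(conn.execute(candidate_sql).fetchone()[0])  # 7.5
```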

### Another example

Can the model write a query that joins two tables?

```python
payload = """
You are an expert in MySQL databases. Your task is to generate a SQL query.

Pay attention to use only the column names that you can see in the schema description.
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

The sales table schema is as follows:

CREATE EXTERNAL TABLE sales (
    transaction_date DATE COMMENT 'The transaction date in the format yyyy-mm-dd',
    user_id STRING COMMENT 'The user who made the purchase',
    product STRING COMMENT 'Product name, e.g. "Fruits", "Ice cream", "Milk"',
    sales_amount DOUBLE COMMENT 'The price of the product'
)

The users table schema is as follows:

CREATE EXTERNAL TABLE users (
    user_id STRING COMMENT 'User ID',
    name STRING COMMENT 'User name'
)

Question: What is the total purchase amount of "John"?
SQLQuery:
"""
```
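As the number of tables grows, hand-editing these prompt strings becomes error-prone. A hypothetical helper that assembles the same kind of instructions, one schema block per table, and the question might look like this. The function name `build_sql_prompt` and its exact wording are assumptions; adjust them to match whatever prompt works best with your model.

```python
def build_sql_prompt(dialect: str, schemas: dict, question: str) -> str:
    """Assemble a text-to-SQL prompt in the same layout as the prompt cells.

    `schemas` maps a table name to its CREATE TABLE statement.
    """
    parts = [
        f"You are an expert in {dialect} databases. Your task is to generate a SQL query.",
        "",
        "Pay attention to use only the column names that you can see in the schema description.",
        "Be careful to not query for columns that do not exist. "
        "Also, pay attention to which column is in which table.",
    ]
    # One schema block per table, in the order given.
    for table, ddl in schemas.items():
        parts += ["", f"The {table} table schema is as follows:", "", ddl.strip()]
    # End with the question and the "SQLQuery:" cue for the model to complete.
    parts += ["", f"Question: {question}", "SQLQuery:", ""]
    return "\n".join(parts)


prompt = build_sql_prompt(
    "MySQL",
    {
        "sales": "CREATE TABLE sales (user_id VARCHAR(64), product VARCHAR(64), sales_amount DOUBLE)",
        "users": "CREATE TABLE users (user_id VARCHAR(64), name VARCHAR(128))",
    },
    'What is the total purchase amount of "John"?',
)
print(prompt)
```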

```python
payload = {"inputs": payload, "parameters": parameters}
generated_texts = query_endpoint_and_parse_response(payload, ENDPOINT_NAME)
```

```python
print(f"Result: {generated_texts}")
```
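To sanity-check a generated join, load a few made-up rows for both tables into an in-memory `sqlite3` database and compare the result with the total you expect by hand. The tables use SQLite types rather than the prompt's Hive-style DDL, and the query is a plausible answer, not actual model output.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Made-up data: John (u1) bought two items, Mary (u2) bought one.
conn.executescript(
    """
    CREATE TABLE sales (transaction_date TEXT, user_id TEXT, product TEXT, sales_amount REAL);
    CREATE TABLE users (user_id TEXT, name TEXT);
    INSERT INTO users VALUES ('u1', 'John'), ('u2', 'Mary');
    INSERT INTO sales VALUES
        ('2023-06-01', 'u1', 'Fruits', 4.5),
        ('2023-06-02', 'u1', 'Milk', 2.0),
        ('2023-06-02', 'u2', 'Fruits', 3.0);
    """
)

# A plausible join for John's total purchases; replace it with the model's query.
candidate_sql = """
    SELECT SUM(s.sales_amount) AS total_purchase
    FROM sales AS s
    JOIN users AS u ON s.user_id = u.user_id
    WHERE u.name = 'John'
"""
print(conn.execute(candidate_sql).fetchone()[0])  # 6.5
```

If several users share the name "John", this query sums purchases across all of them; filtering on `user_id` avoids that ambiguity.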