# Sample Text2SQL Agent Walkthrough

This notebook walks through setting up a Text2SQL agent with the BirdSQL - Mini Dev dataset (https://github.com/bird-bench/mini_dev).

#### Confirm that the latest version of boto3 is installed (the version is printed below)

##### If it is not, first run setup_environment.ipynb in the agents_catalog/0-Notebook-environment/ folder

In [None]:
!pip freeze | grep -E "boto3"
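The same check can be done from Python, which also works where `grep` is unavailable. This is a minimal sketch using the standard library's `importlib.metadata`; the helper name `installed_version` is illustrative and not part of this notebook.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Prints the boto3 version, or None when boto3 is missing from this kernel.
print("boto3:", installed_version("boto3"))
```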

#### Load stored variables into the notebook

In [None]:
# Retrieve import path
%store -r IMPORTS_PATH

# Retrieve account info
%store -r account_id
%store -r region

# Retrieve model lists
%store -r agent_foundation_model

#### Retrieve imports environment variable and bring libraries into notebook

In [None]:
%run $IMPORTS_PATH

In [None]:
# Build unique bucket names, then export the settings as environment variables

random_suffix_1 = uuid.uuid4().hex[:6]
base_bucket_name = f"text2sql-agent-{account_id}-{random_suffix_1}"
random_suffix_2 = uuid.uuid4().hex[:6]
athena_results_bucket_name = f"text2sql-athena-results-{account_id}-{random_suffix_2}"
athena_database_name = 'california_schools'

os.environ['REGION'] = region
os.environ['BASE_BUCKET_NAME'] = base_bucket_name
os.environ['ATHENA_RESULTS_BUCKET_NAME'] = athena_results_bucket_name
os.environ['BASE_DIR'] = 'dev_databases'
os.environ['DATABASE_NAME'] = athena_database_name

%store base_bucket_name
%store athena_results_bucket_name
%store athena_database_name
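The random suffix keeps bucket names globally unique, but the result still has to satisfy S3 naming rules. Below is a minimal sketch of building and validating a name of the same shape. The helper `make_bucket_name`, the simplified regular expression (it ignores dots and some reserved patterns), and the placeholder account ID are assumptions for illustration only.

```python
import re
import uuid

# Simplified S3 naming check: 3-63 characters, lowercase letters, digits and
# hyphens, starting and ending with a letter or digit (dots are left out here).
_BUCKET_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9-]{1,61}[a-z0-9]$")


def make_bucket_name(prefix: str, account_id: str) -> str:
    """Build '<prefix>-<account_id>-<6 hex chars>' and validate the result."""
    suffix = uuid.uuid4().hex[:6]
    name = f"{prefix}-{account_id}-{suffix}"
    if not _BUCKET_NAME_RE.match(name):
        raise ValueError(f"Invalid S3 bucket name: {name!r}")
    return name


# '123456789012' is a placeholder account ID used only for illustration.
print(make_bucket_name("text2sql-agent", "123456789012"))
```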

### Retrieve BirdSQL - Mini Dev Dataset

In [None]:
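# Download the .zip file into the notebook's working directory.
# Note: `!cd` runs in a throwaway subshell and does not change the notebook's
# directory, so it is omitted here; run `%cd <folder>` first if you need to switch.

!wget https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip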

## Set up the services needed to run the Text2SQL agent

To run the Text2SQL agent, we need Athena databases to run SQL queries against. The following script will:
1. Unzip the downloaded folder
2. Create S3 buckets
3. Convert .sqlite files into individual .parquet files for each table
4. Upload the .parquet files to the database S3 bucket
5. Set up appropriate Athena permissions
6. Create databases in Athena
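
The contents of `data_prep.py` are not shown in this notebook. As a rough illustration of the first half of step 3, the sketch below lists the tables in a SQLite file using only the standard library; each table would then be written to its own .parquet file. The function names and the S3 key layout in `parquet_key` are assumptions, not the script's actual implementation.

```python
import sqlite3
from typing import List


def list_tables(sqlite_path: str) -> List[str]:
    """Return the user-defined table names in a SQLite database file."""
    conn = sqlite3.connect(sqlite_path)
    try:
        rows = conn.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
            "ORDER BY name"
        ).fetchall()
    finally:
        conn.close()
    return [name for (name,) in rows]


def parquet_key(base_dir: str, database: str, table: str) -> str:
    """Hypothetical S3 key layout with one folder per table.

    Athena maps a table to an S3 prefix, so each table gets its own folder.
    """
    return f"{base_dir}/{database}/{table}/{table}.parquet"
```

Calling `list_tables` on a database file and passing each name to `parquet_key` gives one upload target per table.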

In [None]:
%run data_prep.py

# Create Text2SQL Agent

In [None]:
agent_name = 'sample-text2sql-agent'
agent_description = "Text2SQL agent to run against Bird-SQL Mini-Dev benchmark dataset"
agent_instruction = """
You are an AI Agent specialized in generating SQL queries for Amazon Athena against Amazon S3 .parquet files. 
Your primary task is to interpret user queries, generate appropriate SQL queries, and provide the executed sql 
query as well as relevant answers based on the data. Follow these instructions carefully: 1. Before generating any 
SQL query, use the /getschema tool to familiarize yourself with the data structure. 2. When generating an SQL query: 
a. Write the query as a single line, removing all newline characters. b. Column names must be exactly as they appear 
in the schema, including spaces. Do not replace spaces with underscores. c. Always enclose column names that contain 
spaces in double quotes ("). d. Be extra careful with column names containing special characters or spaces. 
3. Column name handling: a. Never modify column names. Use them exactly as they appear in the schema. 
b. If a column name contains spaces or special characters, always enclose it in double quotes ("). 
c. Do not use underscores in place of spaces in column names. 4. Query output format: 
a. Always include the exact query that was run in your response. Start your response with 
"Executed SQL Query:" followed by the exact query that was run. b. Format the SQL query in a code block 
using three backticks (```). c. After the query, provide your explanation and analysis. 
5. When providing your response: a. Start with the executed SQL query as specified in step 
4. b. Double-check that all column names in your generated query match the schema exactly. 
c. Ask for clarifications from the user if required. 6. Error handling: a. 
If a query fails due to column name issues: - Review the schema and correct any mismatched column names. - 
Ensure all column names with spaces are enclosed in double quotes. - Regenerate the query with corrected column names. - 
Display both the failed query and the corrected query. b. Implement retry logic with up to 3 attempts for failed queries. 
Here are a few examples of generating SQL queries based on a question: 
Question: What is the highest eligible free rate for K-12 students in the schools in Alameda County? 
Executed SQL Query: SELECT "Free Meal Count (K-12)" / "Enrollment (K-12)" FROM frpm WHERE "County Name" = 'Alameda' ORDER BY (CAST("Free Meal Count (K-12)" AS REAL) / "Enrollment (K-12)") DESC LIMIT 1 
Question: Please list the zip code of all the charter schools in Fresno County Office of Education. 
Executed SQL Query: SELECT T2.Zip FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T1."District Name" = 'Fresno County Office of Education' AND T1."Charter School (Y/N)" = 1 
Question: Consider the average difference between K-12 enrollment and 15-17 enrollment of schools that are locally funded, list the names and DOC type of schools which has a difference above this average. 
Executed SQL Query: SELECT T2.School, T2.DOC FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T2.FundingType = 'Locally funded' AND (T1."Enrollment (K-12)" - T1."Enrollment (Ages 5-17)") > (SELECT AVG(T3."Enrollment (K-12)" - T3."Enrollment (Ages 5-17)") FROM frpm AS T3 INNER JOIN schools AS T4 ON T3.CDSCode = T4.CDSCode WHERE T4.FundingType = 'Locally funded')
"""

In [None]:
agents = AgentsForAmazonBedrock()

text2sql_agent = agents.create_agent(
    agent_name,
    agent_description,
    agent_instruction,
    agent_foundation_model,
    code_interpretation=False,
    verbose=False
)

text2sql_agent
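
The agent instruction asks for single-line queries in which column names containing spaces are wrapped in double quotes. Below is a minimal sketch of those two rules in plain Python, which can help when writing test queries by hand. The helper names `quote_identifier` and `to_single_line` are illustrative and not part of this notebook or any AWS library.

```python
def quote_identifier(name: str) -> str:
    """Wrap a column name in double quotes, doubling any embedded double quotes."""
    return '"' + name.replace('"', '""') + '"'


def to_single_line(query: str) -> str:
    """Collapse every run of whitespace, including newlines, into one space.

    Caveat: this also collapses whitespace inside string literals.
    """
    return " ".join(query.split())


meal_count = quote_identifier("Free Meal Count (K-12)")
county = quote_identifier("County Name")
query = to_single_line(f"""
    SELECT MAX({meal_count})
    FROM frpm
    WHERE {county} = 'Alameda'
""")
print(query)
```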

In [None]:
text2sql_agent_id = text2sql_agent[0]
text2sql_agent_arn = f"arn:aws:bedrock:{region}:{account_id}:agent/{text2sql_agent_id}"

text2sql_agent_id, text2sql_agent_arn

In [None]:
api_schema_string = '''{
    "openapi": "3.0.1",
    "info": {
      "title": "Database schema look up and query APIs",
      "version": "1.0.0",
      "description": "APIs for looking up database table schemas and making queries to database tables."
    },
    "paths": {
      "/getschema": {
        "get": {
          "summary": "Get a list of all columns in the athena database",
          "description": "Get the list of all columns in the athena database table. Return all the column information in database table.",
          "operationId": "getschema",
          "responses": {
            "200": {
              "description": "Gets the list of table names and their schemas in the database",
              "content": {
                "application/json": {
                  "schema": {
                    "type": "array",
                    "items": {
                      "type": "object",
                      "properties": {
                        "Table": {
                          "type": "string",
                          "description": "The name of the table in the database."
                        },
                        "Schema": {
                          "type": "string",
                          "description": "The schema of the table in the database. Contains all columns needed for making queries."
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      },
      "/queryathena": {
        "get": {
          "summary": "API to send query to the athena database table",
          "description": "Send a query to the database table to retrieve information pertaining to the users question. The API takes in only one SQL query at a time, sends the SQL statement and returns the query results from the table. This API should be called for each SQL query to a database table.",
          "operationId": "queryathena",
          "parameters": [
            {
              "name": "query",
              "in": "query",
              "required": true,
              "schema": {
                "type": "string"
              },
              "description": "SQL statement to query database table."
            }
          ],
          "responses": {
            "200": {
              "description": "Query sent successfully",
              "content": {
                "application/json": {
                  "schema": {
                    "type": "object",
                    "properties": {
                      "responseBody": {
                        "type": "string",
                        "description": "The query response from the database."
                      }
                    }
                  }
                }
              }
            },
            "400": {
              "description": "Bad request. One or more required fields are missing or invalid."
            }
          }
        }
      }
    }
  } '''

In [None]:
api_schema = {"payload": api_schema_string}
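
Because the schema is passed as a raw string, a JSON typo would only surface when the action group is created. The sketch below is one way to sanity-check such a string locally: it parses the JSON and lists the operations it defines. `list_operations` and the trimmed-down `demo_schema` are illustrative stand-ins; in the notebook you would pass `api_schema_string` instead.

```python
import json
from typing import List, Tuple

HTTP_METHODS = {"get", "put", "post", "delete", "patch", "head", "options"}


def list_operations(schema_string: str) -> List[Tuple[str, str, str]]:
    """Parse an OpenAPI JSON string and return (METHOD, path, operationId) tuples."""
    schema = json.loads(schema_string)  # raises json.JSONDecodeError on malformed JSON
    operations = []
    for path, item in schema.get("paths", {}).items():
        for method, operation in item.items():
            if method.lower() in HTTP_METHODS:
                operations.append((method.upper(), path, operation.get("operationId", "")))
    return operations


# A trimmed-down stand-in for api_schema_string, for illustration only.
demo_schema = '''{
  "openapi": "3.0.1",
  "info": {"title": "demo", "version": "1.0.0"},
  "paths": {
    "/getschema": {"get": {"operationId": "getschema"}},
    "/queryathena": {"get": {"operationId": "queryathena"}}
  }
}'''
print(list_operations(demo_schema))
```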

### Attach Lambda function and create ActionGroup

In [None]:
text2sql_lambda_function_name = "text2sql"
text2sql_lambda_function_arn = f"arn:aws:lambda:{region}:{account_id}:function:{text2sql_lambda_function_name}"
%store text2sql_lambda_function_name
%store text2sql_lambda_function_arn

In [None]:
agents.add_action_group_with_lambda(
    agent_name=agent_name,
    lambda_function_name=text2sql_lambda_function_name,
    source_code_file="lambda_function.py",
    agent_action_group_name="queryAthena",
    agent_action_group_description="Action for getting the database schema and querying with Athena",
    api_schema=api_schema,
    verbose=True
)
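
The contents of `lambda_function.py` are not shown in this notebook. As a rough sketch of what a handler for an API-schema action group can look like, the code below routes on `apiPath` and returns the response envelope that Amazon Bedrock Agents expects from a Lambda function. `get_schema` and `query_athena` are hypothetical stubs; the real file presumably calls Athena instead.

```python
import json
from typing import Any, Dict


def get_schema() -> list:
    """Hypothetical stub; a real implementation would read table metadata."""
    return [{"Table": "frpm", "Schema": '"County Name" string'}]


def query_athena(query: str) -> str:
    """Hypothetical stub; a real implementation would run the query in Athena."""
    return f"(stub) would run: {query}"


def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
    api_path = event.get("apiPath", "")
    # Parameters arrive as a list of {"name": ..., "type": ..., "value": ...}
    params = {p["name"]: p["value"] for p in (event.get("parameters") or [])}

    if api_path == "/getschema":
        status, body = 200, get_schema()
    elif api_path == "/queryathena" and "query" in params:
        status, body = 200, {"responseBody": query_athena(params["query"])}
    else:
        status, body = 400, {"error": f"Unsupported request: {api_path}"}

    return {
        "messageVersion": "1.0",
        "response": {
            "actionGroup": event.get("actionGroup"),
            "apiPath": api_path,
            "httpMethod": event.get("httpMethod"),
            "httpStatusCode": status,
            "responseBody": {"application/json": {"body": json.dumps(body)}},
        },
    }
```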

### Add a resource-based policy to the Lambda function so the agent can invoke it

In [None]:
lambda_client = boto3.client('lambda', region)

# add_permission appends a statement to the function's resource-based policy,
# so there is no need to fetch and rebuild the policy document by hand.
# The statement lets bedrock.amazonaws.com invoke the function, but only when
# the call originates from this agent (SourceArn).
try:
    response = lambda_client.add_permission(
        FunctionName=text2sql_lambda_function_arn,
        StatementId="AllowText2SQLAgentAccess",
        Action="lambda:InvokeFunction",
        Principal="bedrock.amazonaws.com",
        SourceArn=text2sql_agent_arn
    )
    print("Resource policy added successfully.")
    print("Response:", response)
except lambda_client.exceptions.ResourceConflictException:
    # A statement with this StatementId already exists, e.g. the cell was re-run
    print("Resource policy statement already exists; nothing to add.")

### Add permissions to Lambda function execution role

In [None]:
# Create clients
iam_client = boto3.client('iam')
lambda_client = boto3.client('lambda', region)

# Get the function configuration
response = lambda_client.get_function_configuration(FunctionName=text2sql_lambda_function_name)
role_arn = response['Role']
role_name = role_arn.split('/')[-1]

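# Policy ARNs to attach (broad managed policies, convenient for a demo;
# scope them down to the specific buckets and workgroup for production use)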
policy_arns = [
    'arn:aws:iam::aws:policy/AmazonAthenaFullAccess',
    'arn:aws:iam::aws:policy/AmazonS3FullAccess'
]

# Attach each policy
for policy_arn in policy_arns:
    try:
        iam_client.attach_role_policy(
            RoleName=role_name,
            PolicyArn=policy_arn
        )
        print(f"Successfully attached {policy_arn} to role {role_name}")
    except Exception as e:
        print(f"Error attaching {policy_arn}: {str(e)}")

# Verify attached policies
try:
    response = iam_client.list_attached_role_policies(RoleName=role_name)
    print("\nAttached policies:")
    for policy in response['AttachedPolicies']:
        print(f"- {policy['PolicyName']}")
except Exception as e:
    print(f"Error listing policies: {str(e)}")

### Invoke the Text2SQL agent's test alias to verify that it answers a question correctly

In [None]:
%%time

bedrock_agent_runtime_client = boto3.client("bedrock-agent-runtime", region)

session_id: str = str(uuid.uuid1())

query = "What is the highest eligible free rate for K-12 students in the schools in Alameda County?"
response = bedrock_agent_runtime_client.invoke_agent(
    inputText=query,
    agentId=text2sql_agent_id,
    agentAliasId="TSTALIASID",  # built-in test alias that points at the draft version
    sessionId=session_id,
    enableTrace=True,
    endSession=False,
    sessionState={}
)

print("Request sent to Agent")
print("====================")
print("Agent processing query now")
print("====================")

# Initialize an empty string to store the answer
answer = ""

# Iterate through the event stream
for event in response['completion']:
    # Check if the event is a 'chunk' event
    if 'chunk' in event:
        chunk_obj = event['chunk']
        if 'bytes' in chunk_obj:
            # Decode the bytes and append to the answer
            chunk_data = chunk_obj['bytes'].decode('utf-8')
            answer += chunk_data

# 'answer' now contains the full response from the agent
print(f"Agent Answer: {answer}")
print("====================")
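
The streaming loop above can be wrapped in a small reusable helper and exercised without calling AWS. The sketch below assumes only the event shape used in this notebook (`{'chunk': {'bytes': ...}}`); the name `collect_answer` and the simulated events are illustrative.

```python
from typing import Any, Dict, Iterable


def collect_answer(events: Iterable[Dict[str, Any]]) -> str:
    """Join the bytes of every 'chunk' event and decode them once as UTF-8.

    Other event types, such as trace events, are ignored.
    """
    parts = []
    for event in events:
        chunk = event.get("chunk", {})
        if "bytes" in chunk:
            parts.append(chunk["bytes"])
    return b"".join(parts).decode("utf-8")


# Simulated events standing in for response['completion'].
fake_events = [
    {"trace": {"note": "trace events carry no answer text"}},
    {"chunk": {"bytes": b"Executed SQL Query: "}},
    {"chunk": {"bytes": b"SELECT 1"}},
]
print(collect_answer(fake_events))
```

Decoding once after joining avoids an error if a multi-byte character happens to be split across two chunks.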

### Now that the agent has been tested through a direct invocation, create an alias for it

In [None]:
text2sql_agent_alias_id, text2sql_agent_alias_arn = agents.create_agent_alias(
    text2sql_agent[0], 'v1'
)

text2sql_agent_alias_id, text2sql_agent_alias_arn

# View question and answer file
Note: The 'birdsql_data.json' file generated in this folder contains a list of questions and answers. Try asking the Text2SQL agent some of these questions!