# Sample Text2SQL Agent Walkthrough

This notebook walks through setting up a generic Text2SQL agent and running it against the BirdSQL - Mini Dev Dataset (https://github.com/bird-bench/mini_dev).

#### Confirm that recent versions of boto3 and pandas are installed

##### If they are not, run setup_environment.ipynb in the 0-Notebook-environment/ folder

In [1]:
!pip freeze | grep -E "boto3|pandas"

boto3==1.36.24
pandas==1.5.3


#### Load environment variables into the notebook

In [17]:
# Retrieve import path
%store -r IMPORTS_PATH

# Retrieve account info
%store -r account_id
%store -r region

# Retrieve model lists
%store -r agent_foundation_model

In [18]:
import os

# Store the variables in the process environment
os.environ['REGION'] = region
os.environ['BASE_BUCKET_NAME'] = 'text2sql-agent'
os.environ['ATHENA_RESULTS_BUCKET_NAME'] = 'text2sql-athena-results'
os.environ['BASE_DIR'] = 'dev_databases'
os.environ['DATABASE_NAME'] = 'california_schools'

### Retrieve BirdSQL - Mini Dev Dataset

In [4]:
# Download the .zip file to the current working directory.
# Note: `!cd` runs in a temporary subshell and does not affect later commands;
# use `%cd` if the notebook needs to switch directories first.

!wget https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip

--2025-02-21 18:04:43--  https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip
Resolving bird-bench.oss-cn-beijing.aliyuncs.com (bird-bench.oss-cn-beijing.aliyuncs.com)... 8.141.181.247
Connecting to bird-bench.oss-cn-beijing.aliyuncs.com (bird-bench.oss-cn-beijing.aliyuncs.com)|8.141.181.247|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 346207293 (330M) [application/zip]
Saving to: ‘dev.zip’


2025-02-21 18:05:05 (15.3 MB/s) - ‘dev.zip’ saved [346207293/346207293]



## Set up the services needed to run the Text2SQL agent

To run the Text2SQL agent, we first need Athena databases that it can query. The following script will:
1. Unzip the downloaded folder
2. Create S3 buckets
3. Convert .sqlite files into individual .parquet files for each table
4. Upload the .parquet files to the database S3 bucket
5. Set up appropriate Athena permissions
6. Create databases in Athena
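
The contents of `data_prep.py` are not shown in this notebook, so the following is only a minimal sketch of what step 3 might look like. It assumes one output folder per table (Athena external tables point at an S3 prefix rather than a single file). The function names `list_tables`, `write_parquet`, and `export_tables` are hypothetical, and the default writer assumes pandas with a Parquet engine such as pyarrow is installed.

```python
import sqlite3
from contextlib import closing
from pathlib import Path
from typing import Callable, List


def list_tables(conn: sqlite3.Connection) -> List[str]:
    """Return the user tables in a SQLite database, skipping sqlite_* internals."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows if not name.startswith("sqlite_")]


def write_parquet(conn: sqlite3.Connection, table: str, out_path: Path) -> None:
    """Default writer (assumption): load the table with pandas and save it as Parquet."""
    import pandas as pd  # imported lazily; requires pyarrow or fastparquet

    df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
    df.to_parquet(out_path, index=False)


def export_tables(
    sqlite_path: str,
    out_dir: str,
    writer: Callable[[sqlite3.Connection, str, Path], None] = write_parquet,
) -> List[Path]:
    """Write each table to <out_dir>/<table>/<table>.parquet and return the paths."""
    written = []
    with closing(sqlite3.connect(sqlite_path)) as conn:
        for table in list_tables(conn):
            target = Path(out_dir) / table / f"{table}.parquet"
            target.parent.mkdir(parents=True, exist_ok=True)
            writer(conn, table, target)
            written.append(target)
    return written
```

The `writer` argument is only there so the export loop can be exercised without a Parquet engine; in practice the default would be used and the resulting folders uploaded to S3.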

In [22]:
%run data_prep.py

Created input.json for agent


### Create Text2SQL Agent

In [88]:
import json
import time
import zipfile
from io import BytesIO

import boto3

# Get boto3 clients for the required AWS services
sts_client = boto3.client('sts')
iam_client = boto3.client('iam')
s3_client = boto3.client('s3')
lambda_client = boto3.client('lambda')
bedrock_agent_client = boto3.client('bedrock-agent')
bedrock_agent_runtime_client = boto3.client('bedrock-agent-runtime')

In [89]:
session = boto3.session.Session()
region = session.region_name
account_id = sts_client.get_caller_identity()["Account"]
region, account_id

('us-west-2', '761018876800')

In [105]:
# Build a suffix from the region and account ID so that the IAM role, policy,
# Lambda function, and S3 bucket names are unique, then assign variables
suffix = f"{region}-{account_id}"
agent_name = "sample-text2sql-agent"
agent_alias_name = "sample-alias"
bucket_name = f'{agent_name}-{suffix}'
bucket_key = f'{agent_name}-schema.json'
schema_name = 'sample_text2sql_agent_openapi_schema.json'
schema_arn = f'arn:aws:s3:::{bucket_name}/{bucket_key}'
bedrock_agent_bedrock_allow_policy_name = f"{agent_name}-allow-{suffix}"
bedrock_agent_s3_allow_policy_name = f"{agent_name}-s3-allow-{suffix}"
lambda_role_name = f'{agent_name}-lambda-role-{suffix}'
agent_role_name = f'AmazonBedrockExecutionRoleForAgents_{suffix}'
lambda_code_path = "lambda_function.py"
lambda_name = f'{agent_name}-{suffix}'
model_id = "anthropic.claude-3-5-sonnet-20241022-v2:0"


#### Create S3 bucket and upload API Schema

In [92]:
# Create an S3 bucket for the OpenAPI schema
# (us-east-1 does not accept a LocationConstraint, so it is handled separately)
if region == "us-east-1":
    s3bucket = s3_client.create_bucket(
        Bucket=bucket_name
    )
else:
    s3bucket = s3_client.create_bucket(
        Bucket=bucket_name,
        CreateBucketConfiguration={ 'LocationConstraint': region } 
    )

In [113]:
# Upload the OpenAPI schema to the S3 bucket
s3_client.upload_file(schema_name, bucket_name, bucket_key)
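
The schema file itself is not shown in this notebook. As a rough illustration only, an OpenAPI document for this kind of action group might look like the sketch below. The `/getschema` path comes from the agent instruction later in the notebook; the `/queryathena` path, the `query` property, and every description are assumptions, not the contents of the actual file.

```python
import json


def build_openapi_schema() -> dict:
    """Build a hypothetical OpenAPI 3.0 schema with two operations."""
    json_object = {"application/json": {"schema": {"type": "object"}}}
    return {
        "openapi": "3.0.0",
        "info": {
            "title": "Text2SQL agent API",
            "version": "1.0.0",
            "description": "Tools for inspecting and querying an Athena database.",
        },
        "paths": {
            "/getschema": {
                "get": {
                    "operationId": "getschema",
                    "description": "Return the tables, columns and data types "
                                   "of the database.",
                    "responses": {
                        "200": {"description": "Database schema",
                                "content": json_object}
                    },
                }
            },
            # Hypothetical path name; the real schema may use a different one.
            "/queryathena": {
                "post": {
                    "operationId": "queryathena",
                    "description": "Run a SQL query in Athena and return the rows.",
                    "requestBody": {
                        "required": True,
                        "content": {
                            "application/json": {
                                "schema": {
                                    "type": "object",
                                    "properties": {
                                        "query": {
                                            "type": "string",
                                            "description": "SQL query to run",
                                        }
                                    },
                                    "required": ["query"],
                                }
                            }
                        },
                    },
                    "responses": {
                        "200": {"description": "Query results",
                                "content": json_object}
                    },
                }
            },
        },
    }


schema_json = json.dumps(build_openapi_schema(), indent=2)
```

The operation descriptions matter in practice, because the agent relies on them to decide which tool to call.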


#### Create Lambda function for Action Group

In [97]:
# Create IAM Role for the Lambda function
try:
    assume_role_policy_document = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {
                    "Service": "lambda.amazonaws.com"
                },
                "Action": "sts:AssumeRole"
            }
        ]
    }

    assume_role_policy_document_json = json.dumps(assume_role_policy_document)

    lambda_iam_role = iam_client.create_role(
        RoleName=lambda_role_name,
        AssumeRolePolicyDocument=assume_role_policy_document_json
    )

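    # Pause to give the new role time to propagate
    time.sleep(10)
except iam_client.exceptions.EntityAlreadyExistsException:
    # Reuse the role if it already exists (for example, when re-running this cell)
    lambda_iam_role = iam_client.get_role(RoleName=lambda_role_name)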

iam_client.attach_role_policy(
    RoleName=lambda_role_name,
    PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole'
)

{'ResponseMetadata': {'RequestId': 'f9e17745-6ef6-4225-8a6b-c211c177acb1',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'date': 'Fri, 14 Feb 2025 21:22:07 GMT',
   'x-amzn-requestid': 'f9e17745-6ef6-4225-8a6b-c211c177acb1',
   'content-type': 'text/xml',
   'content-length': '212'},
  'RetryAttempts': 0}}

In [103]:
# Package up the lambda function code
s = BytesIO()
z = zipfile.ZipFile(s, 'w')
z.write(lambda_code_path)
z.close()
zip_content = s.getvalue()

# Create Lambda Function
lambda_function = lambda_client.create_function(
    FunctionName=lambda_name,
    Runtime='python3.12',
    Timeout=180,
    Role=lambda_iam_role['Role']['Arn'],
    Code={'ZipFile': zip_content},
    Handler='lambda_function.lambda_handler'
)
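
The `lambda_function.py` file packaged above is not shown here. The sketch below illustrates the general shape such a handler could take for an OpenAPI-based action group: route on `apiPath`, then wrap the result in the response envelope that Bedrock agents expect. The `/queryathena` path, the `query` property, and the two placeholder handlers are assumptions; a real handler would call Athena instead of returning stub data.

```python
import json
from typing import Any, Dict, Optional


def get_property(event: Dict[str, Any], name: str) -> Optional[str]:
    """Read one request-body property from the agent's invocation event."""
    properties = (
        event.get("requestBody", {})
        .get("content", {})
        .get("application/json", {})
        .get("properties", [])
    )
    for prop in properties:
        if prop.get("name") == name:
            return prop.get("value")
    return None


def handle_get_schema(event: Dict[str, Any]) -> Dict[str, Any]:
    # Placeholder: a real handler would read information_schema in Athena.
    return {"tables": []}


def handle_query(event: Dict[str, Any]) -> Dict[str, Any]:
    query = get_property(event, "query")
    if not query:
        raise ValueError("Missing required property 'query'")
    # Placeholder: a real handler would execute the query in Athena.
    return {"query": query, "rows": []}


# Hypothetical routing table keyed by the OpenAPI path.
ROUTES = {"/getschema": handle_get_schema, "/queryathena": handle_query}


def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
    api_path = event.get("apiPath", "")
    handler = ROUTES.get(api_path)
    if handler is None:
        status, body = 404, {"error": f"Unknown apiPath: {api_path}"}
    else:
        try:
            status, body = 200, handler(event)
        except ValueError as exc:
            status, body = 400, {"error": str(exc)}
        except Exception as exc:  # report unexpected failures to the agent
            status, body = 500, {"error": str(exc)}
    return {
        "messageVersion": "1.0",
        "response": {
            "actionGroup": event.get("actionGroup"),
            "apiPath": api_path,
            "httpMethod": event.get("httpMethod"),
            "httpStatusCode": status,
            "responseBody": {"application/json": {"body": json.dumps(body)}},
        },
    }
```

Returning errors inside the envelope, rather than raising, lets the agent see the failure message and retry with a corrected query.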

#### Create Agent

In [106]:
# Create IAM policies for agent

bedrock_agent_bedrock_allow_policy_statement = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "AmazonBedrockAgentBedrockFoundationModelPolicy",
            "Effect": "Allow",
            "Action": "bedrock:InvokeModel",
            "Resource": [
                f"arn:aws:bedrock:{region}::foundation-model/{model_id}"
            ]
        }
    ]
}

bedrock_policy_json = json.dumps(bedrock_agent_bedrock_allow_policy_statement)

agent_bedrock_policy = iam_client.create_policy(
    PolicyName=bedrock_agent_bedrock_allow_policy_name,
    PolicyDocument=bedrock_policy_json
)

In [107]:
# Policy that allows the agent to fetch its OpenAPI schema from S3

bedrock_agent_s3_allow_policy_statement = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "AllowAgentAccessOpenAPISchema",
            "Effect": "Allow",
            "Action": ["s3:GetObject"],
            "Resource": [
                schema_arn
            ]
        }
    ]
}


bedrock_agent_s3_json = json.dumps(bedrock_agent_s3_allow_policy_statement)
agent_s3_schema_policy = iam_client.create_policy(
    PolicyName=bedrock_agent_s3_allow_policy_name,
    Description="Policy that allows the agent to read its OpenAPI schema from S3.",
    PolicyDocument=bedrock_agent_s3_json
)

In [108]:
# Create IAM Role for the agent and attach IAM policies
assume_role_policy_document = {
    "Version": "2012-10-17",
    "Statement": [{
          "Effect": "Allow",
          "Principal": {
            "Service": "bedrock.amazonaws.com"
          },
          "Action": "sts:AssumeRole"
    }]
}

assume_role_policy_document_json = json.dumps(assume_role_policy_document)
agent_role = iam_client.create_role(
    RoleName=agent_role_name,
    AssumeRolePolicyDocument=assume_role_policy_document_json
)

# Pause to make sure role is created
time.sleep(10)
    
iam_client.attach_role_policy(
    RoleName=agent_role_name,
    PolicyArn=agent_bedrock_policy['Policy']['Arn']
)

iam_client.attach_role_policy(
    RoleName=agent_role_name,
    PolicyArn=agent_s3_schema_policy['Policy']['Arn']
)

{'ResponseMetadata': {'RequestId': 'da2e77fa-3d49-4299-b255-e78f2f1fee76',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'date': 'Mon, 17 Feb 2025 21:26:54 GMT',
   'x-amzn-requestid': 'da2e77fa-3d49-4299-b255-e78f2f1fee76',
   'content-type': 'text/xml',
   'content-length': '212'},
  'RetryAttempts': 0}}

In [123]:
agent_foundation_model = "anthropic.claude-3-5-sonnet-20241022-v2:0"
agent_description = "Text2SQL agent to run against Bird-SQL Mini-Dev benchmark dataset"
agent_instruction = """
You are an AI Agent specialized in generating SQL queries for a database. 
Your primary task is to interpret user queries, 
generate appropriate SQL queries, and provide a relevant answer based on the data. 
Use only the appropriate tools as required by the specific question. Follow these instructions carefully: 
1. Before generating any SQL query, use the /getschema tool to familiarize yourself with the database structure. 
This will ensure your queries are correctly formatted and target the appropriate columns. 
2. When generating an SQL query: a. Write the query as a single line, removing all newline ("\\n") characters. 
b. Column names should remain consistent, do not modify the column names in the generated SQL query. 
3. When providing your response: a. Start with a brief summary of your understanding of 
the user's query. b. Explain the steps you're taking to address the query. c. Ask for clarifications from the 
user if required.
"""

response = bedrock_agent_client.create_agent(
    agentName=agent_name,
    description=agent_description,
    instruction=agent_instruction,
    foundationModel=agent_foundation_model,
    agentResourceRoleArn=agent_role['Role']['Arn']
)


In [124]:
agent_id = response['agent']['agentId']
agent_arn = response['agent']['agentArn']
agent_id, agent_arn

('AQIMYSEPOK', 'arn:aws:bedrock:us-west-2:761018876800:agent/AQIMYSEPOK')

#### Create Agent Action Group

In [125]:
# The agent must finish creating before an action group can be added.
# If the call below fails because the agent is still being created,
# uncomment the pause and re-run this cell.
# time.sleep(30)

print(lambda_function['FunctionArn'])
print(bucket_name, bucket_key)

# Create the action group that links the agent to the Lambda function
# and to the OpenAPI schema stored in S3
agent_action_group_response = bedrock_agent_client.create_agent_action_group(
    agentId=agent_id,
    agentVersion='DRAFT',
    actionGroupExecutor={
        'lambda': lambda_function['FunctionArn']
    },
    actionGroupName='QueryAthena',
    apiSchema={
        's3': {
            's3BucketName': bucket_name,
            's3ObjectKey': bucket_key
        }
    },
    description='Execute SQL queries in Athena database'
)



arn:aws:lambda:us-west-2:761018876800:function:sample-text2sql-agent-us-west-2-761018876800
sample-text2sql-agent-us-west-2-761018876800 sample-text2sql-agent-schema.json


In [126]:
# Allow the Bedrock agent to invoke the Lambda function
response = lambda_client.add_permission(
    FunctionName=lambda_name,
    StatementId='allow_bedrock',
    Action='lambda:InvokeFunction',
    Principal='bedrock.amazonaws.com',
    SourceArn=f"arn:aws:bedrock:{region}:{account_id}:agent/{agent_id}",
)

#### Prepare Agent

In [127]:
agent_prepare = bedrock_agent_client.prepare_agent(agentId=agent_id)
agent_prepare

{'ResponseMetadata': {'RequestId': '194cea38-3947-4aee-ad01-3b86d494664b',
  'HTTPStatusCode': 202,
  'HTTPHeaders': {'date': 'Mon, 17 Feb 2025 21:53:32 GMT',
   'content-type': 'application/json',
   'content-length': '119',
   'connection': 'keep-alive',
   'x-amzn-requestid': '194cea38-3947-4aee-ad01-3b86d494664b',
   'x-amz-apigw-id': 'GJhqhGNGvHcEADA=',
   'x-amzn-trace-id': 'Root=1-67b3afdc-63233e2f37f3a60e3d0eacee'},
  'RetryAttempts': 0},
 'agentId': 'AQIMYSEPOK',
 'agentStatus': 'PREPARING',
 'agentVersion': 'DRAFT',
 'preparedAt': datetime.datetime(2025, 2, 17, 21, 53, 32, 819336, tzinfo=tzlocal())}
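
`prepare_agent` returns while the agent is still in the `PREPARING` state, so later steps may need to wait. Rather than sleeping for a fixed time, one option is to poll `get_agent` until the status changes. The helper below is a sketch, not part of the original notebook; the name `wait_for_agent_status` and its timeout defaults are assumptions.

```python
import time
from typing import Any, Callable


def wait_for_agent_status(
    client: Any,
    agent_id: str,
    target: str = "PREPARED",
    timeout: float = 120.0,
    poll_interval: float = 5.0,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll get_agent until the agent reaches `target`, fails, or times out."""
    deadline = time.monotonic() + timeout
    while True:
        status = client.get_agent(agentId=agent_id)["agent"]["agentStatus"]
        if status == target:
            return status
        if status == "FAILED":
            raise RuntimeError(f"Agent {agent_id} entered the FAILED state")
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Agent {agent_id} is still {status} after {timeout} seconds"
            )
        sleep(poll_interval)
```

With the clients defined earlier, this would be called as `wait_for_agent_status(bedrock_agent_client, agent_id)` before creating the alias.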

#### Create Agent Alias

In [128]:
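# The agent must reach the PREPARED status before an alias can be created.
# Uncomment the pause below if this call fails because the agent is still preparing.
# time.sleep(30)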
agent_alias = bedrock_agent_client.create_agent_alias(
    agentId=agent_id,
    agentAliasName=agent_alias_name
)

#### Invoke Agent

In [130]:
# time.sleep(30)
# Extract the agentAliasId from the response
agent_alias_id = agent_alias['agentAlias']['agentAliasId']

agent_alias_id

'T9J04ATBNJ'

In [None]:
import uuid

# Create a random ID to identify this session
session_id: str = str(uuid.uuid1())

# Invoke the agent API with a question from the california_schools database
agentResponse = bedrock_agent_runtime_client.invoke_agent(
    inputText="What is the highest eligible free rate for K-12 students in the schools in Alameda County?",
    agentId=agent_id,
    agentAliasId=agent_alias_id, 
    sessionId=session_id,
    enableTrace=True,
    endSession=False,
    sessionState={}
)


print("Request sent to Agent:\n{}".format(agentResponse))
print("====================")
print("Agent processing query now")
print("====================")

# Initialize an empty string to store the answer
answer = ""

# Iterate through the event stream
for event in agentResponse['completion']:
    # Check if the event is a 'chunk' event
    if 'chunk' in event:
        chunk_obj = event['chunk']
        if 'bytes' in chunk_obj:
            # Decode the bytes and append to the answer
            chunk_data = chunk_obj['bytes'].decode('utf-8')
            answer += chunk_data

# Now 'answer' contains the full response from the agent
print("Agent Answer: {}".format(answer))
print("====================")
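
Because the call above sets `enableTrace=True`, the event stream also carries `trace` events alongside the `chunk` events, and the loop above discards them. The sketch below shows one way to keep both; the helper name `collect_completion` is hypothetical, and it treats each trace event as an opaque dictionary.

```python
from typing import Any, Dict, Iterable, List, Tuple


def collect_completion(
    event_stream: Iterable[Dict[str, Any]],
) -> Tuple[str, List[Dict[str, Any]]]:
    """Split an agent event stream into the answer text and the trace events."""
    parts: List[str] = []
    traces: List[Dict[str, Any]] = []
    for event in event_stream:
        if "chunk" in event:
            data = event["chunk"].get("bytes")
            if data:
                parts.append(data.decode("utf-8"))
        elif "trace" in event:
            # Keep the raw trace so it can be inspected later, for example
            # to find the SQL query the agent generated.
            traces.append(event["trace"])
    return "".join(parts), traces
```

Note that the stream can only be consumed once, so this would replace the loop above rather than run after it: `answer, traces = collect_completion(agentResponse["completion"])`.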


### Create the input .json file for Text2SQL evaluation using the ground truth provided by the BirdSQL Mini-Dev dataset

In [20]:
ATHENA_RESULTS_BUCKET_NAME = os.environ.get('ATHENA_RESULTS_BUCKET_NAME')
DATABASE_NAME = os.environ.get('DATABASE_NAME')

ATHENA_RESULTS_BUCKET_NAME, DATABASE_NAME

('text2sql-athena-results', 'california_schools')

In [24]:
import boto3
import time
import json
from typing import Dict

def run_query(query: str, athena_client):
    """
    Run an Athena query and return the result values as a single
    comma-separated string, or None if the query fails
    """
    
    try:
        response = athena_client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={
                'Database': DATABASE_NAME
            },
            ResultConfiguration={
                'OutputLocation': f's3://{ATHENA_RESULTS_BUCKET_NAME}/athena-results/'
            }
        )
        
        query_execution_id = response['QueryExecutionId']
        
        while True:
            response = athena_client.get_query_execution(QueryExecutionId=query_execution_id)
            state = response['QueryExecution']['Status']['State']
            
            if state in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
                break
                
            time.sleep(1)
            
        if state == 'SUCCEEDED':
            response = athena_client.get_query_results(QueryExecutionId=query_execution_id)
            # Process data
            results = []
            for row in response['ResultSet']['Rows'][1:]:  # Skip header row
                for field in row['Data']:
                    value = field.get('VarCharValue', '')
                    
                    results.append(str(value))
            
            # Join results with commas
            return ', '.join(results)
        else:
            print(f"Query failed with state: {state}")
            return None
            
    except Exception as e:
        print(f"Error running query: {e}")
        return None

def generate_ground_truth_answer(question_id: int, question: str, sql_query: str, 
                               sql_context: str, query_results: str) -> Dict:
    """
    Generate ground truth answer using AWS Bedrock based on question and query results
    """
    bedrock_runtime = boto3.client(
        service_name='bedrock-runtime'
    )
        
    # Construct prompt for Bedrock
    prompt = f"""You are generating ground truth answers that will be used to evaluate the factual correctness of Text2SQL agent responses.

Question: {question}
Query Results: {query_results}

Generate a natural language answer that:
1. States all numerical values and facts from the query results explicitly
2. Uses consistent formatting for numbers (maintain exact precision from results)
3. Includes all relevant values if multiple results are returned
4. States the answer in a clear, declarative way that directly addresses the question
5. Avoids additional interpretations or information not present in the query results

Remember:
- Focus only on the facts present in the query results
- Use the exact numbers shown in the results
- Structure the answer to make fact-checking straightforward
- Be explicit about any percentages, counts, or measurements
- Make sure every number in the query results is mentioned in your answer

Your answer should be easy to compare with other responses for factual accuracy."""

    # Create request body for Claude model
    body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 512,
        "temperature": 0.5,
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ],
    })

    try:
        # Call Bedrock
        response = bedrock_runtime.invoke_model(
            modelId='anthropic.claude-3-5-sonnet-20241022-v2:0',  # or your preferred model
            body=body
        )
        
        # Parse response
        response_body = json.loads(response['body'].read())
        answer = response_body['content'][0]['text']
        
        # Format the response in the required structure
        formatted_response = {
            "question_id": question_id,
            "question": question,
            "question_type": "TEXT2SQL",
            "ground_truth": {
                "ground_truth_sql_query": sql_query,
                "ground_truth_sql_context": sql_context,
                "ground_truth_query_result": query_results,
                "ground_truth_answer": answer
            }
        }
        
        return formatted_response
        
    except Exception as e:
        print(f"Error generating answer: {e}")
        return None

        
def get_schema(athena_client):
    """
    Get schema information (table, column, data type) for all tables in the Athena database
    """

    sql = f"""
        SELECT
            table_name,
            column_name,
            data_type
        FROM information_schema.columns
        WHERE table_schema = '{DATABASE_NAME}'
        ORDER BY table_name, ordinal_position;
        """
        
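    try:
        # Start query execution (write results to the same S3 location as run_query)
        response = athena_client.start_query_execution(
            QueryString=sql,
            QueryExecutionContext={
                'Database': DATABASE_NAME
            },
            ResultConfiguration={
                'OutputLocation': f's3://{ATHENA_RESULTS_BUCKET_NAME}/athena-results/'
            }
        )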
            
        query_execution_id = response['QueryExecutionId']
            
        def wait_for_query_completion(query_execution_id):
            while True:
                response = athena_client.get_query_execution(
                    QueryExecutionId=query_execution_id
                )
                state = response['QueryExecution']['Status']['State']
                
                if state in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
                    return state
                    
                time.sleep(2)
            
        # Wait for query completion
        state = wait_for_query_completion(query_execution_id)

        if state == 'SUCCEEDED':
            # Get query results
            results = athena_client.get_query_results(
                QueryExecutionId=query_execution_id
            )
            
            database_structure = []
            table_dict = {}

            # Skip the header row
            rows = results['ResultSet']['Rows'][1:]

            for row in rows:
                # Extract values from the Data structure
                table_name = row['Data'][0]['VarCharValue']
                column_name = row['Data'][1]['VarCharValue']
                data_type = row['Data'][2]['VarCharValue']
                
                # Initialize table if not exists
                if table_name not in table_dict:
                    table_dict[table_name] = []
                
                # Append column information
                table_dict[table_name].append((column_name, data_type))

            # Convert to the desired format
            for table_name, columns in table_dict.items():
                database_structure.append({
                    "table_name": table_name,
                    "columns": columns
                })

            return database_structure

        else:
            raise Exception(f"Query failed with state: {state}")
    except Exception as e:
        print(f"Error getting schema: {e}")
        raise

def generate_dataset(input_file: str, output_file: str, athena_client):
    """
    Generate dataset with ground truth answers in trajectory format
    """
    try:
        # Read input file
        with open(input_file, 'r') as f:
            questions_data = json.load(f)
        
        # Initialize trajectories dictionary
        trajectories = {}
        
        # Process each question
        for idx, item in enumerate(questions_data):
            # Only process the first 5 questions while testing;
            # remove this check to process the full dataset
            if idx == 5:
                break

            question_id = item.get('question_id', 0)
            question = item['question']
            sql_query = item['SQL']
            
            print(f"\nProcessing question {question_id}: {question}")
            
            # Get table schema
            sql_context = get_schema(athena_client)
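            # Run query (Athena quotes identifiers with double quotes in
            # SELECT statements, so replace the SQLite-style backticks)
            query_results = run_query(sql_query.replace('`', '"'), athena_client)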
            if query_results is not None:
                # Generate answer with formatted response
                response = generate_ground_truth_answer(
                    question_id=question_id,
                    question=question,
                    sql_query=sql_query,
                    sql_context=str(sql_context),
                    query_results=query_results
                )
                
                if response:
                    # Create trajectory key
                    trajectory_key = f"Trajectory{idx + 1}"
                    
                    # Format the response for this trajectory
                    trajectory_response = [response]
                    
                    # Add to trajectories dictionary
                    trajectories[trajectory_key] = trajectory_response
                    print(f"Generated ground truth for question {question_id}")
            
        # Write results to output file
        with open(output_file, 'w') as f:
            json.dump(trajectories, f, indent=2)
            
        print(f"\nProcessed {len(trajectories)} questions. Results saved to {output_file}")
        
    except Exception as e:
        print(f"Error generating dataset: {e}")

INPUT_FILE = "input.json"
OUTPUT_FILE = "text2sql_data_file_auto.json"

athena_client = boto3.client('athena')

generate_dataset(INPUT_FILE, OUTPUT_FILE, athena_client)



Processing question 0: What is the highest eligible free rate for K-12 students in the schools in Alameda County?
Generated ground truth for question 0

Processing question 1: Please list the lowest three eligible free rates for students aged 5-17 in continuation schools.
Generated ground truth for question 1

Processing question 2: Please list the zip code of all the charter schools in Fresno County Office of Education.
Generated ground truth for question 2

Processing question 3: What is the unabbreviated mailing street address of the school with the highest FRPM count for K-12 students?
Generated ground truth for question 3

Processing question 4: Please list the phone numbers of the direct charter-funded schools that are opened after 2000/1/1.
Generated ground truth for question 4

Processed 5 questions. Results saved to text2sql_data_file_auto.json
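
`run_query` above flattens every value into one comma-separated string and reads only the first page of results. If column names or larger result sets matter, a variant along the following lines could be used instead. This is a sketch under two assumptions: the query is a `SELECT`, so the first row of the first page is the header, and the hypothetical helper `fetch_all_records` is called only after the query has succeeded.

```python
from typing import Any, Dict, List, Optional


def fetch_all_records(
    athena_client: Any, query_execution_id: str
) -> List[Dict[str, Optional[str]]]:
    """Page through Athena results and return one dict per row, keyed by column."""
    paginator = athena_client.get_paginator("get_query_results")
    header: Optional[List[str]] = None
    records: List[Dict[str, Optional[str]]] = []
    for page in paginator.paginate(QueryExecutionId=query_execution_id):
        rows = page["ResultSet"]["Rows"]
        if header is None:
            if not rows:
                continue
            # Only the first page starts with the header row.
            header = [cell.get("VarCharValue", "") for cell in rows[0]["Data"]]
            rows = rows[1:]
        for row in rows:
            # NULL values come back without a VarCharValue key.
            values = [cell.get("VarCharValue") for cell in row["Data"]]
            records.append(dict(zip(header, values)))
    return records
```

The records can then be serialized with `json.dumps(records)` when building the ground-truth entry, which keeps column names next to their values.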
