# Notebook to test the Bedrock Agents and UC Functions.

In [1]:
from unitycatalog.ai.core.client import UnitycatalogFunctionClient
from unitycatalog.client import ApiClient, Configuration
from unitycatalog.ai.bedrock.toolkit import UCFunctionToolkit


In [2]:
config = Configuration()
config.host = "http://localhost:8080/api/2.1/unity-catalog"

# The base ApiClient is async
api_client = ApiClient(configuration=config)

client = UnitycatalogFunctionClient(api_client=api_client)

CATALOG = "AICatalog"
SCHEMA = "AISchema"


In [3]:
# Sample function
    
def location_weather_in_c(location_id: str, fetch_date: str) -> str:
    """Test function for AWS Bedrock integration.

    Args:
        location_id (str): The name to be included in the greeting message.
        fetch_date (str): The date with the location 

    Raises:
        Exception: If there is an error during the function execution.

    Returns:
        str: Weather result.
    """
    try:
        # Fetch from Databricks SQL Warehouse based UC function execution 
        return "23"
    except Exception as e:
        raise Exception(f"Error occurred: {e}")

In [4]:
client.uc.create_catalog(name=CATALOG, comment="Catalog for AI functions")

CatalogInfo(name='AICatalog', comment='Catalog for AI functions', properties={}, owner=None, created_at=1737005234233, created_by=None, updated_at=1737005234233, updated_by=None, id='7dc1e000-2854-46ac-9811-d140c1cb942e')

In [5]:
client.uc.create_schema(catalog_name=CATALOG, name=SCHEMA, comment="Schema for AI functions")

SchemaInfo(name='AISchema', catalog_name='AICatalog', comment='Schema for AI functions', properties={}, full_name='AICatalog.AISchema', owner=None, created_at=1737005278533, created_by=None, updated_at=1737005278533, updated_by=None, schema_id='c164dc56-8ee1-489c-9dad-19734bbcb0d7')

In [6]:
client.create_python_function(
    func=location_weather_in_c, 
    catalog=CATALOG, 
    schema=SCHEMA, 
    replace=True)

  check_docstring_signature_consistency(docstring_info.params, params_in_signature, func_name)


FunctionInfo(name='location_weather_in_c', catalog_name='AICatalog', schema_name='AISchema', input_params=FunctionParameterInfos(parameters=[FunctionParameterInfo(name='location_id', type_text='STRING', type_json='{"name": "location_id", "type": "string", "nullable": false, "metadata": {"comment": "The name to be included in the greeting message."}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=0, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The name to be included in the greeting message.'), FunctionParameterInfo(name='fetch_date', type_text='STRING', type_json='{"name": "fetch_date", "type": "string", "nullable": false, "metadata": {"comment": "The date with the location"}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=1, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The date with t

## Check if UC function exists

In [7]:
function_name = f"{CATALOG}.{SCHEMA}.location_weather_in_c"

In [8]:
try:
    print(client.get_function(function_name))
except Exception as e:
    print(f"Function {function_name} does not exist, creating it...")

name='location_weather_in_c' catalog_name='AICatalog' schema_name='AISchema' input_params=FunctionParameterInfos(parameters=[FunctionParameterInfo(name='location_id', type_text='STRING', type_json='{"name": "location_id", "type": "string", "nullable": false, "metadata": {"comment": "The name to be included in the greeting message."}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=0, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The name to be included in the greeting message.'), FunctionParameterInfo(name='fetch_date', type_text='STRING', type_json='{"name": "fetch_date", "type": "string", "nullable": false, "metadata": {"comment": "The date with the location"}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=1, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The date with the location')]) 

In [9]:
client.list_functions(catalog=CATALOG, schema=SCHEMA)

[FunctionInfo(name='bedrock_test_function', catalog_name='AICatalog', schema_name='AISchema', input_params=FunctionParameterInfos(parameters=[FunctionParameterInfo(name='name', type_text='STRING', type_json='{"name": "name", "type": "string", "nullable": false, "metadata": {"comment": "The name to be included in the greeting message."}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=0, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The name to be included in the greeting message.')]), data_type=<ColumnTypeName.STRING: 'STRING'>, full_data_type='STRING', return_params=None, routine_body='EXTERNAL', routine_definition='try:\n    # Fetch from Databricks SQL Warehouse based UC function execution \n    return "hello: " + name\nexcept Exception as e:\n    raise Exception(f"Error occurred: {e}")', routine_dependencies=None, parameter_style='S', is_deterministic=True, sql_data_access='NO_SQL', i

In [10]:
function_name = f"{CATALOG}.{SCHEMA}.location_weather_in_c"
toolkit = UCFunctionToolkit(function_names=[function_name], client=client)

In [11]:
print(toolkit.client.get_function(function_name))

name='location_weather_in_c' catalog_name='AICatalog' schema_name='AISchema' input_params=FunctionParameterInfos(parameters=[FunctionParameterInfo(name='location_id', type_text='STRING', type_json='{"name": "location_id", "type": "string", "nullable": false, "metadata": {"comment": "The name to be included in the greeting message."}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=0, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The name to be included in the greeting message.'), FunctionParameterInfo(name='fetch_date', type_text='STRING', type_json='{"name": "fetch_date", "type": "string", "nullable": false, "metadata": {"comment": "The date with the location"}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=1, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The date with the location')]) 

# Adding Bedrock SDK calls 


In [12]:
import boto3, pprint, json, time, uuid

In [11]:
#boto3.setup_default_session()

In [13]:
agent_id = "AP5RQUVNTU"
agent_alias_id = "O6EXN8DJVZ"

In [14]:
# Bedrock agent configuration
session = toolkit.create_session(agent_id=agent_id,
                                agent_alias_id=agent_alias_id)
# Generate unique session ID
session_id = str(uuid.uuid1())

In [17]:
response = session.invoke_agent(
                input_text="What is the weather for location 1234 and date of 2024-11-19",
                enable_trace=True,
                session_id=session_id,
                uc_client = toolkit.client
                )

Tool Results Enter:
Tool Call Function Name Local: location_weather_in_c


ServiceException: (500)
Reason: Internal Server Error
HTTP response headers: <CIMultiDictProxy('Content-Type': 'application/json', 'Content-Length': '5622', 'Server': 'Armeria/1.28.4', 'Date': 'Mon, 3 Feb 2025 15:34:13 GMT')>
HTTP response body: {"error_code":"INTERNAL","details":[{"reason":"INTERNAL","metadata":{},"@type":"google.rpc.ErrorInfo"}],"stack_trace":"[java.base/java.util.Objects.requireNonNull(Objects.java:235), com.linecorp.armeria.common.HttpResponse.ofJson(HttpResponse.java:661), com.linecorp.armeria.common.HttpResponse.ofJson(HttpResponse.java:618), com.linecorp.armeria.common.HttpResponse.ofJson(HttpResponse.java:603), io.unitycatalog.server.service.FunctionService.getFunction(FunctionService.java:99), com.linecorp.armeria.internal.server.annotation.AnnotatedService.invoke(AnnotatedService.java:382), com.linecorp.armeria.internal.server.annotation.AnnotatedService.serve0(AnnotatedService.java:296), com.linecorp.armeria.internal.server.annotation.AnnotatedService.serve(AnnotatedService.java:272), com.linecorp.armeria.internal.server.annotation.AnnotatedService.serve(AnnotatedService.java:78), com.linecorp.armeria.internal.server.annotation.AnnotatedService$ExceptionHandlingHttpService.serve(AnnotatedService.java:536), com.linecorp.armeria.server.HttpServerHandler.serve0(HttpServerHandler.java:463), com.linecorp.armeria.server.HttpServerHandler.handleRequest(HttpServerHandler.java:398), com.linecorp.armeria.server.HttpServerHandler.channelRead(HttpServerHandler.java:281), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420), io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412), com.linecorp.armeria.server.Http1RequestDecoder.channelRead(Http1RequestDecoder.java:282), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:442), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420), io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412), com.linecorp.armeria.server.HttpServerUpgradeHandler.channelRead(HttpServerUpgradeHandler.java:227), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420), io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412), io.netty.channel.CombinedChannelDuplexHandler$DelegatingChannelHandlerContext.fireChannelRead(CombinedChannelDuplexHandler.java:436), io.netty.handler.codec.ByteToMessageDecoder.fireChannelRead(ByteToMessageDecoder.java:346), io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:318), io.netty.channel.CombinedChannelDuplexHandler.channelRead(CombinedChannelDuplexHandler.java:251), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:442), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420), io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412), io.netty.handler.codec.ByteToMessageDecoder.handlerRemoved(ByteToMessageDecoder.java:266), io.netty.handler.codec.ByteToMessageDecoder.decodeRemovalReentryProtection(ByteToMessageDecoder.java:537), io.netty.handler.codec.ByteToMessageDecoder.callDecode(ByteToMessageDecoder.java:469), io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:290), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420), io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412), io.netty.handler.logging.LoggingHandler.channelRead(LoggingHandler.java:280), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:442), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420), io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412), io.netty.handler.flush.FlushConsolidationHandler.channelRead(FlushConsolidationHandler.java:152), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:442), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420), io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412), io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1407), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:440), io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420), io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:918), io.netty.channel.epoll.AbstractEpollStreamChannel$EpollStreamUnsafe.epollInReady(AbstractEpollStreamChannel.java:799), io.netty.channel.epoll.EpollEventLoop.processReady(EpollEventLoop.java:501), io.netty.channel.epoll.EpollEventLoop.run(EpollEventLoop.java:399), io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:994), io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74), io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30), java.base/java.lang.Thread.run(Thread.java:840)]","message":"content"}


In [22]:
import pprint 
pprint.pp(response)

NameError: name 'response' is not defined

In [17]:
event_stream = response['completion']

for event in event_stream:
   pprint.pp(event)

{'trace': {'agentAliasId': 'O6EXN8DJVZ',
           'agentId': 'AP5RQUVNTU',
           'agentVersion': '1',
           'callerChain': [{'agentAliasArn': 'arn:aws:bedrock:us-east-1:288584671716:agent-alias/AP5RQUVNTU/O6EXN8DJVZ'}],
           'sessionId': '11beed2a-ddc1-11ef-ba47-4e3e215d14d1',
           'trace': {'orchestrationTrace': {'modelInvocationInput': {'inferenceConfiguration': {'maximumLength': 2048,
                                                                                                'stopSequences': ['</invoke>',
                                                                                                                  '</answer>',
                                                                                                                  '</error>'],
                                                                                                'temperature': 0.0,
                                                                                        

In [44]:

function_input = event["returnControl"]["invocationInputs"][0]["functionInvocationInput"]

print(f"Function to call: {function_input['function']}")

print(f"Parameters: {function_input['parameters']}")

# Simulate weather result (replace with actual API call)
weather_result = "23"  # Example temperature
# Send result back to agent
final_response = session.invoke_agent(
    input_text="",
    session_id=session_id,
    enable_trace=True,
    session_state={
        'invocationId': event["returnControl"]["invocationId"],
        'returnControlInvocationResults': [{
            'functionResult': {
                'actionGroup': function_input["actionGroup"],
                'function': function_input["function"],
                'confirmationState': 'CONFIRM',
                'responseBody': {
                    "TEXT": {
                        'body':
                        f"weather_in_centigrade: {weather_result}"
                    }
                }
            }
        }]
    }
)



Function to call: location_weather_in_c
Parameters: [{'name': 'fetch_date', 'type': 'string', 'value': '2024-11-19'}, {'name': 'location_id', 'type': 'integer', 'value': '1234'}]


In [45]:
# Print final response
print("Agent Response:")
for final_event in final_response.get('completion', []):
    print(f"  {final_event}")

Agent Response:
  {'trace': {'agentAliasId': 'O6EXN8DJVZ', 'agentId': 'AP5RQUVNTU', 'agentVersion': '1', 'callerChain': [{'agentAliasArn': 'arn:aws:bedrock:us-east-1:288584671716:agent-alias/AP5RQUVNTU/O6EXN8DJVZ'}], 'sessionId': '18915a4e-d762-11ef-aef3-4e3e215d14d1', 'trace': {'orchestrationTrace': {'modelInvocationInput': {'inferenceConfiguration': {'maximumLength': 2048, 'stopSequences': ['</invoke>', '</answer>', '</error>'], 'temperature': 0.0, 'topK': 250, 'topP': 1.0}, 'text': '{"system":"You are a weather agent to fetche the current weather in celsius for a given locationYou have been provided with a set of functions to answer the user\'s question.You will ALWAYS follow the below guidelines when you are answering a question:<guidelines>- Think through the user\'s question, extract all data from the question and the previous conversations before creating a plan.- ALWAYS optimize the plan by using multiple function calls at the same time whenever possible.- Never assume any para

# Function Execution

## Listing Function

In [49]:
functions = client.list_functions(
    catalog=CATALOG,
    schema=SCHEMA,
    max_results=10  # Paginated results will contain a continuation token that can be submitted with additional requests
)

print(functions)

[FunctionInfo(name='bedrock_test_function', catalog_name='AICatalog', schema_name='AISchema', input_params=FunctionParameterInfos(parameters=[FunctionParameterInfo(name='name', type_text='STRING', type_json='{"name": "name", "type": "string", "nullable": false, "metadata": {"comment": "The name to be included in the greeting message."}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=0, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The name to be included in the greeting message.')]), data_type=<ColumnTypeName.STRING: 'STRING'>, full_data_type='STRING', return_params=None, routine_body='EXTERNAL', routine_definition='try:\n    # Fetch from Databricks SQL Warehouse based UC function execution \n    return "hello: " + name\nexcept Exception as e:\n    raise Exception(f"Error occurred: {e}")', routine_dependencies=None, parameter_style='S', is_deterministic=True, sql_data_access='NO_SQL', i

In [63]:
for f in functions:
    if f.name == "location_weather_in_c":
      print(f"Function: {f.name}")
      print(f"Parameters: {f.input_params}")

Function: location_weather_in_c
Parameters: parameters=[FunctionParameterInfo(name='location_id', type_text='STRING', type_json='{"name": "location_id", "type": "string", "nullable": false, "metadata": {"comment": "The name to be included in the greeting message."}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=0, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The name to be included in the greeting message.'), FunctionParameterInfo(name='fetch_date', type_text='STRING', type_json='{"name": "fetch_date", "type": "string", "nullable": false, "metadata": {"comment": "The date with the location"}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=1, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The date with the location')]


## Executing Function

In [74]:

location_id = "New Jersey"
fetch_date = "2025-01-01"

parameters = {"location_id": "New Jersey", "fetch_date": "2025-01-01"}
result = client.execute_function(function_name, parameters)


In [77]:
value =  result.value
print(f"The weather in centigrade for {location_id} on {fetch_date} is: {value}")

The weather in centigrade for New Jersey on 2025-01-01 is: 23
