In [None]:
import boto3
import time
import json

In [None]:
# Client for the "bedrock" service. It is not referenced by the cells shown
# below; kept so that any later cell relying on it keeps working.
design_client = boto3.client(service_name="bedrock")

# Client for the "bedrock-runtime" service; every model and guardrail call in
# this notebook goes through it.
runtime_client = boto3.client(service_name="bedrock-runtime")

In [None]:
# Identifier and version of the guardrail passed to every guardrail-enabled call.
guardrail_id = "7l2cg7arccsk"
guardrail_version = "DRAFT"

# Model used by both the invoke_model and the converse_stream examples.
modelID = "anthropic.claude-3-haiku-20240307-v1:0"

In [None]:
# User question sent to the model in the invoke_model example.
prompt = "Is the AB503 Product a better investment than the S&P 500?"

In [None]:

# Request description for invoke_model. Only the "body" entry is serialised;
# the remaining entries are handed to invoke_model as keyword arguments.
payload = dict(
    modelId=modelID,
    contentType="application/json",
    accept="application/json",
    body=dict(
        anthropic_version="bedrock-2023-05-31",
        max_tokens=1000,
        messages=[
            dict(
                role="user",
                content=[dict(type="text", text=prompt)],
            ),
        ],
    ),
)

# JSON-encode the request body and turn it into UTF-8 bytes.
body_bytes = json.dumps(payload["body"]).encode("utf-8")

#### Use the guardrail with `invoke_model`

In [None]:
# Call the model with the guardrail attached and guardrail tracing turned on.
response = runtime_client.invoke_model(
    modelId=payload["modelId"],
    contentType=payload["contentType"],
    accept=payload["accept"],
    body=body_bytes,
    guardrailIdentifier=guardrail_id,
    guardrailVersion=guardrail_version,
    trace="ENABLED",
)

# The body is a stream: read it once, decode it, and pretty-print the JSON.
response_body = response["body"].read().decode("utf-8")
print(json.dumps(json.loads(response_body), indent=2))

#### Streaming Output

In [None]:
# Size, in characters, of the chunks that are buffered before each guardrail check.
TEXT_UNIT = 1_000

In [None]:
def check_severe_violations(violations):
    """Count the violations whose action blocked the request.

    Args:
        violations: Iterable of dicts, each carrying an ``'action'`` key.

    Returns:
        int: Number of entries whose ``'action'`` equals ``'BLOCKED'``.
    """
    # Entries with any other action (e.g. 'NONE') were detected but did not
    # block the request, so they are not counted.
    return sum(1 for entry in violations if entry['action'] == 'BLOCKED')

In [None]:
def is_policy_assessement_blocked(assessments):
    """Tell whether any policy in the guardrail assessments blocked the request.

    Every assessment is inspected for topic, word, sensitive-information and
    content policy results. A single blocking violation in any of them is
    enough for the whole request to count as blocked.

    Args:
        assessments: Iterable of assessment dicts, keyed by policy name.

    Returns:
        bool: ``True`` when at least one blocking violation was found.
    """
    # (policy key, result-list key, may the result list be absent?)
    # Listed in the order in which the results are examined.
    policy_checks = (
        ('topicPolicy', 'topics', False),
        ('wordPolicy', 'customWords', True),
        ('wordPolicy', 'managedWordLists', True),
        ('sensitiveInformationPolicy', 'piiEntities', True),
        ('sensitiveInformationPolicy', 'regexes', True),
        ('contentPolicy', 'filters', False),
    )

    severe_violation_count = 0
    for assessment in assessments:
        for policy_key, list_key, optional in policy_checks:
            if policy_key not in assessment:
                continue
            policy_result = assessment[policy_key]
            if optional and list_key not in policy_result:
                continue
            severe_violation_count += check_severe_violations(policy_result[list_key])

    print(f'\033[91m::Guardrail:: {severe_violation_count} severe violations detected\033[0m')
    return severe_violation_count>0

In [None]:
def apply_guardrail(text, text_source_type, guardrail_id, guardrail_version="DRAFT"):
    """Screen ``text`` with the guardrail and summarise the verdict.

    Uses the module-level ``runtime_client`` and ``TEXT_UNIT``.

    Args:
        text: Text to evaluate.
        text_source_type: Passed as ``source``; can be ``'INPUT'`` or ``'OUTPUT'``.
        guardrail_id: Identifier of the guardrail to apply.
        guardrail_version: Guardrail version, ``"DRAFT"`` by default.

    Returns:
        tuple: ``(is_blocked, text_to_use, response)``. When the guardrail
        intervened, ``is_blocked`` comes from ``is_policy_assessement_blocked``
        and ``text_to_use`` is the guardrail's output texts joined by spaces.
        Otherwise ``is_blocked`` is ``False`` and ``text_to_use`` is the
        original ``text``. ``response`` is always the raw API response.
    """
    # Round up to whole text units instead of always adding one, so a chunk of
    # exactly TEXT_UNIT characters is reported as 1 unit rather than 2. An
    # empty string is still reported as one unit.
    text_units = max(1, -(-len(text) // TEXT_UNIT))
    print(f'\n\n\033[91m::Guardrail:: Applying guardrail with {text_units} text units\033[0m\n')
    response = runtime_client.apply_guardrail(
        guardrailIdentifier=guardrail_id,
        guardrailVersion=guardrail_version,
        source=text_source_type, # can be 'INPUT' or 'OUTPUT'
        content=[{"text": {"text": text}}]
    )
    if response['action'] == 'GUARDRAIL_INTERVENED':
        is_blocked = is_policy_assessement_blocked(response['assessments'])
        # The guardrail supplies replacement text (possibly several pieces).
        alternate_text = ' '.join(output['text'] for output in response['outputs'])
        return is_blocked, alternate_text, response

    # No intervention: hand back the text unchanged.
    return False, text, response

In [None]:
def stream_conversation(messages,
                        system_prompts):
    """Stream a model reply and screen it with the guardrail before printing it.

    Text deltas from ``converse_stream`` are collected in a buffer. Whenever
    the next delta would push the buffer past ``TEXT_UNIT`` characters, the
    buffered text is sent to ``apply_guardrail`` as ``"OUTPUT"`` and the text
    it returns is printed. When the model stops, whatever is left in the
    buffer is screened and printed the same way, so short replies and the tail
    of long replies are no longer lost. A blocking verdict replaces the
    accumulated text with the guardrail's text and is reported as the stop
    reason; a mid-stream block also stops reading the stream.

    Relies on the module-level ``runtime_client``, ``modelID``,
    ``guardrail_id``, ``guardrail_version`` and ``TEXT_UNIT``.

    Args:
        messages: Message list forwarded to ``converse_stream``.
        system_prompts: System prompt list forwarded to ``converse_stream``.

    Returns:
        tuple: ``(full_text, applied_guardrails)`` - the screened text that
        was emitted, and the raw ``apply_guardrail`` responses, one per
        screened chunk, in call order.
    """
    response = runtime_client.converse_stream(
        modelId=modelID,
        messages=messages,
        system=system_prompts
    )

    stream = response.get('stream')
    full_text = ""
    buffer_text = ""
    applied_guardrails = []
    if not stream:
        return full_text, applied_guardrails

    for event in stream:
        if 'messageStart' in event:
            print(f"\nRole: {event['messageStart']['role']}")

        if 'contentBlockDelta' in event:
            # Deltas without a 'text' entry contribute nothing to the buffer.
            new_text = event['contentBlockDelta']['delta'].get('text', '')

            # Flush only a non-empty buffer; an oversized first delta is
            # simply buffered and screened on the next flush.
            if buffer_text and len(buffer_text) + len(new_text) > TEXT_UNIT:
                is_blocked, alt_text, guardrail_response = apply_guardrail(buffer_text, "OUTPUT", guardrail_id, guardrail_version)
                applied_guardrails.append(guardrail_response)
                print(alt_text, end="")
                if is_blocked:
                    # The guardrail's text replaces everything collected so far
                    # and the rest of the stream is abandoned.
                    full_text = alt_text
                    print(f"\nStop reason: {guardrail_response['action']}")
                    break
                full_text += alt_text
                buffer_text = new_text
            else:
                buffer_text += new_text

        if 'messageStop' in event:
            stop_reason = event['messageStop']['stopReason']
            final_guardrail_usage = None
            if buffer_text:
                # Screen and emit the tail of the reply that never reached
                # the flush threshold.
                is_blocked, alt_text, guardrail_response = apply_guardrail(buffer_text, "OUTPUT", guardrail_id, guardrail_version)
                applied_guardrails.append(guardrail_response)
                print(alt_text, end="")
                if is_blocked:
                    full_text = alt_text
                    stop_reason = guardrail_response['action']
                    final_guardrail_usage = guardrail_response['usage']
                else:
                    full_text += alt_text
                buffer_text = ""
            print(f"\nStop reason: {stop_reason}")
            if final_guardrail_usage is not None:
                print(final_guardrail_usage)
            # No break here: the metadata event with token usage follows.

        if 'metadata' in event:
            metadata = event['metadata']
            if 'usage' in metadata:
                # Whole text units, rounded up (at least one).
                total_text_units = max(1, -(-len(full_text) // TEXT_UNIT))
                print("\nToken usage")
                print(f"Input tokens: {metadata['usage']['inputTokens']}")
                print(
                    f":Output tokens: {metadata['usage']['outputTokens']}")
                print(f":Total tokens: {metadata['usage']['totalTokens']}")
                print(f":Total text units: {total_text_units}")
            if 'metrics' in metadata:
                print(
                    f"Latency: {metadata['metrics']['latencyMs']} milliseconds")
    return full_text, applied_guardrails

In [None]:
# Alternative prompt to experiment with:
#prompt = "List 3 names of prominent CEOs and later tell me what is a bank and what are the benefits of opening a savings account?"
prompt = (
    "Tell me about why financial independence is important and only at the very end "
    "ask the question if you can help me to invest after retirement?"
)

# Single-turn conversation holding just the user's prompt.
message = dict(role="user", content=[dict(text=prompt)])
messages = [message]

# System prompt that asks the model for long, detailed answers.
system_prompt = "You are an assistant that helps with tasks from users. Be as elaborate as possible"
system_prompts = [dict(text=system_prompt)]

# Stream the reply, screening it with the guardrail chunk by chunk.
full_text, applied_guardrails = stream_conversation(messages, system_prompts)
