In [None]:
from langchain_aws import ChatBedrock
from langchain.prompts import PromptTemplate
from langchain.chains import TransformChain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from rich.pretty import pprint
import boto3

In [None]:
# Look up the ID of the Bedrock guardrail named GUARDRAIL_NAME; the model below
# is configured with it.
GUARDRAIL_NAME = 'sanitizer'


def find_guardrail_id(client, name):
    """Return the ID of the Bedrock guardrail called `name`, or None if there is none.

    Args:
        client: a boto3 'bedrock' client.
        name: exact guardrail name to match.

    Returns:
        The matching guardrail's 'id' string, or None when no guardrail has that name.

    `list_guardrails()` returns results in pages; this follows `nextToken`
    so guardrails beyond the first page are searched too.
    """
    request = {}
    while True:
        page = client.list_guardrails(**request)
        for guardrail in page['guardrails']:
            if guardrail['name'] == name:
                return guardrail['id']
        next_token = page.get('nextToken')
        if not next_token:
            return None
        request = {'nextToken': next_token}


bedrock_client = boto3.client('bedrock')
this_guardrail_id = find_guardrail_id(bedrock_client, GUARDRAIL_NAME)
if this_guardrail_id is None:
    # Fail here with a clear message. Otherwise the print below dies with a
    # "str + None" TypeError, and the model would get guardrailIdentifier=None.
    raise LookupError(f"no Bedrock guardrail named {GUARDRAIL_NAME!r} was found")
print("guardrail ID: " + this_guardrail_id)

# Claude 3 Sonnet, served through Amazon Bedrock.
# The sampling parameters in model_kwargs are described at:
# https://docs.aws.amazon.com/bedrock/latest/userguide/inference-parameters.html
# The guardrail looked up above filters both the prompts and the responses.
model = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    model_kwargs=dict(
        temperature=0.95,
        top_p=0.95,
        top_k=250,
    ),
    guardrails=dict(
        guardrailIdentifier=this_guardrail_id,
        guardrailVersion="1",
    ),
)

In [None]:
# Pass-through template: the user's prompt goes to the model verbatim.
# The single input variable, "prompt", is inferred from the template text.
unchanged_prompt = PromptTemplate.from_template("{prompt}")

In [None]:
# Meta-prompt: asks the model to rewrite the user's prompt rather than answer it.
# The user's text is wrapped in <prompt> tags to separate the material being
# rewritten from the instructions around it. Without a delimiter, a plain
# question can be mistaken for something to answer directly.
rewrite_prompt = PromptTemplate(
    input_variables=["prompt"],
    template="""\
act as an expert in prompt engineering who is skilled at
getting the most useful responses possible from a generative AI model.

rewrite the prompt inside the <prompt> tags below to be a more effective prompt.
the rewritten prompt should emphasize a succinct response.

return only the text of the rewritten prompt, without a preamble and without the <prompt> tags.

<prompt>
{prompt}
</prompt>
"""
)

In [None]:
# Identity template: forwards whatever text arrives under "content" unchanged.
mapper = PromptTemplate(
    template="{content}",
    input_variables=["content"],
)


def _message_to_text(data):
    """Replace the chat message stored under 'content' with its `.content` payload."""
    message = data['content']
    return {'content': message.content}


# Turns the model's message output into the {'content': ...} text that `mapper` expects.
transform_chain = TransformChain(
    transform=_message_to_text,
    input_variables=["content"],
    output_variables=["content"],
)

In [None]:
# Two branches are fed the same input:
#   original - the user's prompt is sent to the model as-is
#   rewrite  - the model first rewrites the prompt, then answers the rewritten version
chain = RunnableParallel(
    {
        "original": unchanged_prompt | model | StrOutputParser(),
        "rewrite": (
            rewrite_prompt
            | model
            | transform_chain
            | mapper
            | model
            | StrOutputParser()
        ),
    }
)

In [None]:
# Ask one question through both branches and show the two answers one after the other.
response = chain.invoke({"prompt": "is long distance space travel possible?"})

print(
    'ORIGINAL:',
    response['original'],
    '---------------------------------------',
    'REWRITE:',
    response['rewrite'],
    sep='\n',
)