In [None]:
from langchain_aws import ChatBedrock
from langchain.prompts import PromptTemplate
from langchain.chains import TransformChain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from rich.pretty import pprint

# TODO: consider using a cheaper/faster model for the prompt-rewrite step.

In [None]:
# Chat model used for both the prompt-rewrite step and for answering.
#
# max_tokens: Claude 3 Sonnet on Bedrock caps output at 4096 tokens; asking for
#   more (the previous value was 20000) is rejected with a validation error
#   before any text is generated, so the value is pinned to the model maximum.
# temperature: deliberately high (0.99), so responses vary between runs.
# guardrails: identifier/version of the Bedrock guardrail applied to each call.
#   NOTE(review): the identifier looks account-specific -- confirm it exists in
#   the AWS account/region this notebook runs against.
model = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    model_kwargs={"temperature": 0.99, "max_tokens": 4096},
    guardrails={"guardrailIdentifier": "pparlr97o7gz", "guardrailVersion": "1"},
)

In [None]:
# Pass-through template: forwards the caller's `prompt` value to the model
# verbatim. The single input variable `prompt` is inferred from the template.
unchanged_prompt = PromptTemplate.from_template("{prompt}")

In [None]:
# Meta-prompt: asks the model to rewrite the caller's `prompt` into a more
# effective one that favours a succinct answer, returning only the rewritten
# text. `input_variables` is declared explicitly, matching the other templates
# in this notebook, so construction does not depend on the installed langchain
# version inferring the variables from the template string.
rewrite_prompt = PromptTemplate(
    input_variables=["prompt"],
    template="""
act as an expert in prompt engineering who is skilled at 
getting the most useful responses possible from a generative AI model.

rewrite the following to be a more effective prompt.
the prompt should emphasize a succinct response.

return only the text of the rewritten prompt, without a preamble.
                                           
{prompt}
"""
)

In [None]:
def _extract_message_text(data):
    """Replace the chat message stored under 'content' with its `.content` payload."""
    message = data['content']
    return {'content': message.content}


# Template that forwards the `content` value unchanged as the next prompt.
mapper = PromptTemplate(
    template="{content}",
    input_variables=["content"],
)

# Bridges a model's message output to `mapper`: takes the value under
# 'content' and re-emits its `.content` attribute under the same key.
transform_chain = TransformChain(
    transform=_extract_message_text,
    input_variables=["content"],
    output_variables=["content"],
)

In [None]:
# Branch 1: send the caller's prompt to the model exactly as written.
original_branch = unchanged_prompt | model | StrOutputParser()

# Branch 2: have the model rewrite the prompt first, then feed the rewritten
# text back into the same model and collect its answer.
rewrite_branch = (
    rewrite_prompt
    | model
    | transform_chain
    | mapper
    | model
    | StrOutputParser()
)

# Run both branches on the same input; results are keyed 'original' / 'rewrite'.
chain = RunnableParallel(original=original_branch, rewrite=rewrite_branch)

In [None]:
# Run both branches on one question and show the two answers, one after the other.
question = {"prompt": "is long distance space travel possible?"}
response = chain.invoke(question)

print(
    'ORIGINAL:',
    response['original'],
    '---------------------------------------',
    'REWRITE:',
    response['rewrite'],
    sep='\n',
)