# Self-reflection pattern

The self-reflection pattern uses at least two prompts: one that performs the task itself and one that reflects on (reviews) the previous response. Both prompts are run by the same LLM, as two separate calls.
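A minimal sketch of the two passes follows. It assumes a hypothetical `llm(system_prompt, user_prompt) -> str` callable. `self_reflect` and `fake_llm` are illustrative names, and `fake_llm` is a stand-in so the code runs without a real model. The reflection prompt layout mirrors the one used later in this notebook.

```python
from typing import Callable, Tuple

# Hypothetical model interface: (system prompt, user prompt) -> generated text.
LLM = Callable[[str, str], str]


def self_reflect(llm: LLM, task_system: str, reflection_system: str, query: str) -> Tuple[str, str]:
    """Run the task prompt, then ask the same model to review its own answer."""
    # Pass 1: perform the task.
    draft = llm(task_system, query)
    # Pass 2: a separate call to the same model, fed both the query and the draft.
    reflection_user = f"Input query: {query}\n\nFunction implementation: {draft}"
    final = llm(reflection_system, reflection_user)
    return draft, final


def fake_llm(system: str, user: str) -> str:
    # Offline stand-in: "reviews" the draft when given the reviewer system prompt.
    if system.startswith("review"):
        return "reviewed: " + user.split("Function implementation: ")[-1]
    return "draft for " + user


draft, final = self_reflect(fake_llm, "solve", "review", "add two ints")
print(draft)  # draft for add two ints
print(final)  # reviewed: draft for add two ints
```

Because both passes go through the same `llm` object, only the prompts differ. This is what distinguishes self-reflection from a setup that uses a separate critic model.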

In this example, we use the Hugging Face transformers library to implement the LLMChain logic.

In [1]:
import os
# Move into src/ so the agent_design_pattern package can be imported
os.chdir("../../src/")


In [None]:
from agent_design_pattern.agent import AgentMessage, LLMChain
from agent_design_pattern.orchestration import ReflectionAgent
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
class CasualOllamaChain(LLMChain):
    def __init__(self, model, system_prompt, user_prompt_template="{query}", device="cuda", **kwargs):
        super().__init__()
        # device is forwarded to from_pretrained as device_map (e.g. "cuda" or "cpu");
        # any extra keyword arguments are accepted but ignored
        self.device = device
        if isinstance(model, str):
            self.tokenizer = AutoTokenizer.from_pretrained(model)
            # drop device_map if running on CPU
            self.model = AutoModelForCausalLM.from_pretrained(model, device_map=self.device)
            self.model.eval()
        else:
            self.tokenizer, self.model = model
        self.system_prompt = system_prompt
        self.user_prompt_template = user_prompt_template

    def invoke(self, message: AgentMessage, **kwargs) -> AgentMessage:
        user_prompt = self.user_prompt_template.format(**message.to_dict())
        chat = [
            { "role": "system", "content": self.system_prompt },
            { "role": "user", "content": user_prompt },
        ]
        chat = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        # tokenize the text and move it to the device the model lives on
        input_tokens = self.tokenizer(chat, return_tensors="pt").to(self.model.device)
        # generate output tokens
        output = self.model.generate(**input_tokens, **kwargs)
        # decode output tokens into text
        output = self.tokenizer.batch_decode(output[:, input_tokens.input_ids.shape[-1]:], skip_special_tokens=True)
        # keep the single decoded sequence as the response
        message.response = output[0]
        message.execution_result = "success"
        return message
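
Around the model call, invoke does two pieces of plain data handling. First, it fills the user template from every field of the message. Second, it slices the prompt tokens off the generated sequences (output[:, input_tokens.input_ids.shape[-1]:]), because for decoder-only models generate returns the prompt followed by the continuation. The sketch below reproduces both steps on plain Python lists. `Message`, `fill_template` and `strip_prompt` are hypothetical names, and `Message` only imitates the AgentMessage fields that the templates in this notebook use.

```python
from dataclasses import asdict, dataclass
from typing import List


@dataclass
class Message:
    """Hypothetical stand-in for AgentMessage (only the fields used by the templates)."""
    query: str
    context_response: str = ""
    response: str = ""


def fill_template(template: str, message: Message) -> str:
    # str.format ignores keyword arguments it does not use, so passing every
    # field is safe; a placeholder without a matching field raises KeyError.
    return template.format(**asdict(message))


def strip_prompt(output_ids: List[List[int]], prompt_len: int) -> List[List[int]]:
    # Each generated row is the prompt tokens followed by the new tokens;
    # keep only the new tokens, as output[:, prompt_len:] does on a tensor.
    return [row[prompt_len:] for row in output_ids]


msg = Message(query="reverse a string", context_response="def rev(s): return s[::-1]")
print(fill_template("Input query: {query}\n\nFunction implementation: {context_response}", msg))
print(strip_prompt([[101, 7, 8, 9, 42, 43]], prompt_len=4))  # [[42, 43]]
```

Without the slice, the decoded response would repeat the system and user prompts in front of the model's answer.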

In [7]:
system_prompt_task = """You are a helpful coding assistant.
Your task is to write a Python function and return its implementation.
Some requirements:
- The logic is clear and easy to understand.
- The function arguments and return values (if any) should be typed.
- If the function is too long (for example greater than 80 lines), split the logic into multiple smaller functions.
- All functions should have a docstring. The docstring should include a simple example that illustrates the function and shows how to call it.
- The response should contain only the function(s) with their docstrings. DO NOT include explanations outside of the code.
"""
user_prompt_task = "{query}"
task_chain = CasualOllamaChain("ibm-granite/granite-4.0-h-1b", system_prompt_task, user_prompt_task)

system_prompt_reflection = """You are an excellent code reviewer and refactorer.
Given a function implementation and its explanation, your task is to review the code and correct any mistakes.
Some notes:
- For the implementation, check that the original query and the suggested implementation match.
- Check whether the code contains any syntax errors.
- For the explanation, verify that the docstring follows the Google docstring style.
- In the docstring, make sure there is an example showing how to call the function.

Make sure the final output contains only the full function code, inline code comments and the docstring, nothing else."""
user_prompt_reflection = "Input query: {query}\n\nFunction implementation: {context_response}"
# Reuse the already loaded tokenizer and model: the same LLM performs both the task and the self-reflection
reflection_chain = CasualOllamaChain((task_chain.tokenizer, task_chain.model), system_prompt_reflection, user_prompt_reflection)

def state_callback(state: str):
    print(f"agent state: {state}")
reflection_agent = ReflectionAgent(task_chain, reflection_chain, state_change_callback=state_callback)

In [None]:
# Take a LeetCode problem as an example. Source: https://leetcode.com/problems/palindrome-partitioning-ii/description/
query = """Write python function(s) to solve the following problem:
Given a string s, partition s such that every substring of the partition is a palindrome.
Return the minimum cuts needed for a palindrome partitioning of s.

Example 1:
Input: s = "aab"
Output: 1
Explanation: The palindrome partitioning ["aa","b"] could be produced using 1 cut.

Example 2:
Input: s = "a"
Output: 0

Example 3:
Input: s = "ab"
Output: 1

Constraints:
1 <= s.length <= 2000
s consists of lowercase English letters only."""

final_message = reflection_agent.execute(AgentMessage(query=query), max_new_tokens=16384)
print(final_message)
print("final response")
print(final_message.response)

agent state: running
agent state: reflecting
agent state: idle
query='Write python function(s) to solve the following problem:\nGiven a string s, partition s such that every substring of the partition is a palindrome.\nReturn the minimum cuts needed for a palindrome partitioning of s.\n\nExample 1:\nInput: s = "aab"\nOutput: 1\nExplanation: The palindrome partitioning ["aa","b"] could be produced using 1 cut.\n\nExample 2:\nInput: s = "a"\nOutput: 0\n\nExample 3:\nInput: s = "ab"\nOutput: 1\n\nConstraints:\n1 <= s.length <= 2000\ns consists of lowercase English letters only.' origin='ReflectionAgent_2' response='Here\'s the full function code with inline comments and docstring:\n\n```python\ndef minCut(s: str) -> int:\n    """\n    Returns the minimum cuts needed for a palindrome partitioning of the given string s.\n\n    Args:\n    s (str): The input string.\n\n    Returns:\n    int: The minimum cuts needed for palindrome partitioning.\n    """\n    n = len(s)\n    dp = [float(\'inf\'
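
The model's answer above is cut off, so it cannot be checked here. The sketch below is a standard O(n²) dynamic-programming solution to the same problem. It is our own reference for comparing against the reflected answer, not the model's output, and `min_cut` is a name chosen for illustration. `is_pal[i][j]` records whether `s[i..j]` is a palindrome, and `cuts[j]` holds the minimum number of cuts for the prefix ending at index `j`.

```python
def min_cut(s: str) -> int:
    """Return the minimum cuts so that every piece of s is a palindrome (len(s) >= 1)."""
    n = len(s)
    # is_pal[i][j] is True when s[i:j + 1] is a palindrome
    is_pal = [[False] * n for _ in range(n)]
    # cuts[j] = minimum cuts needed for the prefix s[:j + 1]
    cuts = [0] * n
    for j in range(n):
        best = j  # worst case: cut between every character
        for i in range(j + 1):
            # s[i..j] is a palindrome if its ends match and the inside is one
            if s[i] == s[j] and (j - i < 2 or is_pal[i + 1][j - 1]):
                is_pal[i][j] = True
                # a palindrome starting at 0 needs no cut; otherwise cut before i
                best = 0 if i == 0 else min(best, cuts[i - 1] + 1)
        cuts[j] = best
    return cuts[-1]


print(min_cut("aab"))  # 1
print(min_cut("a"))    # 0
print(min_cut("ab"))   # 1
```

With s.length up to 2000, the n × n table has about four million entries, which fits comfortably in memory.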